package api import ( "context" "database/sql" "errors" "fmt" "net/http" "strings" "atlas9.dev/c/core" "atlas9.dev/c/core/dbi" "atlas9.dev/c/demo/lib" "atlas9.dev/c/demo/lib/domains" "atlas9.dev/c/demo/lib/sso" "atlas9.dev/c/iam/oidc_provider" "golang.org/x/oauth2" ) type SsoImpl struct { DB *sql.DB Domains dbi.Factory[domains.Store] SsoConfigs dbi.Factory[sso.Store] OidcDeps oidc_provider.Deps } func (s *SsoImpl) Check(w http.ResponseWriter, r *http.Request) { ctx := r.Context() var req Sso_CheckReq if read(w, r, &req) { return } var res Sso_CheckRes sp := strings.Split(req.Email, "@") if len(sp) == 2 { rawName := sp[1] cfg, err := dbi.Read(ctx, s.DB, func(tx dbi.DBI) (*sso.Config, error) { d, err := s.Domains(tx).GetByName(ctx, rawName) if err != nil { return nil, err } return s.SsoConfigs(tx).GetByDomain(ctx, d.Tenant, d.ID) }) if errors.Is(err, core.ErrNotFound) { write(r.Context(), w, nil, res) return } if err != nil { write(ctx, w, err, nil) return } // TODO hard-coded base url res.Redirect = fmt.Sprintf("http://localhost:8010/auth/%s/%s/login", cfg.Tenant, cfg.ID) } write(r.Context(), w, nil, res) } func (s *SsoImpl) Login(w http.ResponseWriter, r *http.Request) { ctx := r.Context() p, err := s.getProvider(ctx, r) if write(ctx, w, err, nil) { return } // TODO should pass ctx into these functions p.HandleLogin(w, r) } func (s *SsoImpl) Callback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() p, err := s.getProvider(ctx, r) if write(ctx, w, err, nil) { return } p.HandleCallback(w, r) } func (s *SsoImpl) getProvider(ctx context.Context, r *http.Request) (*oidc_provider.Provider, error) { tenant, err := core.ParseID(r.PathValue("tenant")) if err != nil { return nil, err } id, err := core.ParseID(r.PathValue("id")) if err != nil { return nil, err } if tenant.IsEmpty() { return nil, errors.New("tenant ID is empty") } if id.IsEmpty() { return nil, errors.New("sso ID is empty") } access := lib.GetAccess(ctx) // All requests need to be able to read the requested sso config. // TODO this modifies the Access, which allows this grant to escape this scope. access.Grant(sso.Cap_Sso_Read, tenant, "") cfg, err := dbi.Read(ctx, s.OidcDeps.DB, func(tx dbi.DBI) (*sso.Config, error) { return s.SsoConfigs(tx).Get(ctx, tenant, id) }) if err != nil { return nil, err } return oidc_provider.New(ctx, oidc_provider.Config{ Name: cfg.Name, // TODO look into whether name is really needed and how it's used. ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, // TODO hard-coded base url RedirectURL: fmt.Sprintf("http://localhost:8010/auth/%s/%s/callback", tenant, id), Scopes: []string{"profile", "openid", "email"}, AuthCodeOptions: []oauth2.AuthCodeOption{}, LoginPath: "/login", SuccessPath: "/", IssuerUrl: cfg.IssuerUrl, JwksUrl: cfg.JwksUrl, Endpoint: oauth2.Endpoint{ AuthURL: cfg.AuthUrl, DeviceAuthURL: cfg.DeviceAuthUrl, TokenURL: cfg.TokenUrl, AuthStyle: oauth2.AuthStyleAutoDetect, }, }, s.OidcDeps), nil }