// Package oidc_provider implements OAuth 2.0 / OpenID Connect login against an
// external identity provider, automatically creating local users or linking
// the provider identity to an existing user.
package oidc_provider

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"net/http"

	"atlas9.dev/c/core"
	"atlas9.dev/c/core/dbi"
	"atlas9.dev/c/core/iam"
	"atlas9.dev/c/core/routes"
	"atlas9.dev/c/core/tokens"

	gooidc "github.com/coreos/go-oidc/v3/oidc"
	"golang.org/x/oauth2"
)

// UserInfo holds the identity information extracted from a provider.
type UserInfo struct {
	// Subject is the provider's identifier for the user. Together with the
	// provider name it is the key used to look up an already-linked account.
	Subject string
	// Email is the email address reported by the provider. It is used to find
	// an existing local user, or as the address of a newly created one.
	Email string
	// EmailVerified reports whether the provider asserts that Email is verified.
	EmailVerified bool
}

// Config holds the configuration for an OAuth provider.
type Config struct {
	// Name identifies this provider (e.g. "google", "apple", "github").
	// It is also used as a path segment in the routes ("/auth/<Name>").
	Name string

	// OAuth2 credentials.
	ClientID     string
	ClientSecret string
	RedirectURL  string

	// Scopes to request (e.g. []string{"openid", "email", "profile"}).
	Scopes []string

	// AuthCodeOptions are additional options passed to AuthCodeURL, after the
	// PKCE and nonce parameters that the provider always sets.
	AuthCodeOptions []oauth2.AuthCodeOption

	// LoginPath is where to redirect on auth errors.
	LoginPath string

	// SuccessPath is where to redirect after successful login.
	SuccessPath string

	// Issuer is the OIDC issuer URL (e.g. "https://accounts.google.com").
	// Used for OIDC discovery to obtain endpoints and token verification.
	// Ignored when Endpoint is set.
	Issuer string

	// Endpoint sets OAuth2 endpoints manually, skipping OIDC discovery.
	// Required for non-OIDC providers like GitHub.
	Endpoint *oauth2.Endpoint

	// FetchUserInfo extracts user identity from the token response.
	// For OIDC providers, leave nil to use automatic ID token verification.
	// For plain OAuth2 providers, set this to call the provider's userinfo API.
	// The nonce argument is the value that was sent with the authorization
	// request for this login attempt.
	FetchUserInfo func(ctx context.Context, token *oauth2.Token, nonce string) (*UserInfo, error)
}

// Deps holds the dependencies injected by the application.
type Deps struct { DB *sql.DB Tokens tokens.Store[StateData] Users dbi.Factory[iam.UserStore] OAuth dbi.Factory[iam.OAuthStore] Sessions iam.SessionStore } // Provider handles OAuth authentication for a single identity provider. type Provider struct { name string oauth2Config *oauth2.Config authCodeOptions []oauth2.AuthCodeOption fetchUserInfo func(ctx context.Context, token *oauth2.Token, nonce string) (*UserInfo, error) tokens tokens.Store[StateData] users dbi.Factory[iam.UserStore] oauth dbi.Factory[iam.OAuthStore] sessions iam.SessionStore db *sql.DB loginPath string successPath string } // New creates a new OAuth Provider. // For OIDC providers (Issuer set, Endpoint nil), it performs OIDC discovery // and sets up automatic ID token verification. // For plain OAuth2 providers (Endpoint set), FetchUserInfo must be provided. func New(ctx context.Context, cfg Config, deps Deps) (*Provider, error) { var endpoint oauth2.Endpoint fetchUserInfo := cfg.FetchUserInfo if cfg.Endpoint != nil { endpoint = *cfg.Endpoint } else { // OIDC discovery oidcProvider, err := gooidc.NewProvider(ctx, cfg.Issuer) if err != nil { return nil, fmt.Errorf("oidc discovery for %s: %w", cfg.Name, err) } endpoint = oidcProvider.Endpoint() if fetchUserInfo == nil { verifier := oidcProvider.Verifier(&gooidc.Config{ ClientID: cfg.ClientID, }) fetchUserInfo = oidcFetchUserInfo(verifier) } } oauth2Config := &oauth2.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, RedirectURL: cfg.RedirectURL, Endpoint: endpoint, Scopes: cfg.Scopes, } return &Provider{ name: cfg.Name, oauth2Config: oauth2Config, authCodeOptions: cfg.AuthCodeOptions, fetchUserInfo: fetchUserInfo, tokens: deps.Tokens, users: deps.Users, oauth: deps.OAuth, sessions: deps.Sessions, db: deps.DB, loginPath: cfg.LoginPath, successPath: cfg.SuccessPath, }, nil } // oidcFetchUserInfo returns a FetchUserInfo function that verifies an OIDC ID token. 
func oidcFetchUserInfo(verifier *gooidc.IDTokenVerifier) func(context.Context, *oauth2.Token, string) (*UserInfo, error) { return func(ctx context.Context, token *oauth2.Token, nonce string) (*UserInfo, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok || rawIDToken == "" { return nil, fmt.Errorf("no id_token in token response") } idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("verifying ID token: %w", err) } if idToken.Nonce != nonce { return nil, fmt.Errorf("nonce mismatch") } var claims struct { Email string `json:"email"` EmailVerified bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("extracting claims: %w", err) } if !claims.EmailVerified { return nil, fmt.Errorf("email not verified: %s", claims.Email) } return &UserInfo{ Subject: idToken.Subject, Email: claims.Email, EmailVerified: true, }, nil } } // Routes returns the HTTP routes for this provider. func (p *Provider) Routes() []routes.Route { return []routes.Route{ routes.HTTP("GET /auth/"+p.name, p.HandleLogin), routes.HTTP("GET /auth/"+p.name+"/callback", p.HandleCallback), routes.HTTP("POST /auth/"+p.name+"/callback", p.HandleCallback), } } func (p *Provider) HandleLogin(w http.ResponseWriter, r *http.Request) { genKey := tokens.RandomString(32) state, err := genKey() if err != nil { slog.Error("generating state", "error", err, "provider", p.name) http.Error(w, "internal error", http.StatusInternalServerError) return } codeVerifier, err := genKey() if err != nil { slog.Error("generating code verifier", "error", err, "provider", p.name) http.Error(w, "internal error", http.StatusInternalServerError) return } nonce, err := genKey() if err != nil { slog.Error("generating nonce", "error", err, "provider", p.name) http.Error(w, "internal error", http.StatusInternalServerError) return } _, err = p.tokens.Put(r.Context(), state, StateData{ CodeVerifier: codeVerifier, Nonce: nonce, }) if err != nil { 
slog.Error("storing OAuth state", "error", err, "provider", p.name) http.Error(w, "internal error", http.StatusInternalServerError) return } codeChallenge := generateCodeChallenge(codeVerifier) opts := []oauth2.AuthCodeOption{ oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("nonce", nonce), } opts = append(opts, p.authCodeOptions...) url := p.oauth2Config.AuthCodeURL(state, opts...) http.Redirect(w, r, url, http.StatusTemporaryRedirect) } func (p *Provider) HandleCallback(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { slog.Error("parsing callback form", "error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } receivedState := r.FormValue("state") if receivedState == "" { slog.Error("no state parameter in callback", "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } tok, err := p.tokens.Get(r.Context(), receivedState) if err != nil { slog.Error("retrieving OAuth state", "error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } _ = p.tokens.Delete(r.Context(), receivedState) // Exchange authorization code for tokens with PKCE verifier code := r.FormValue("code") token, err := p.oauth2Config.Exchange( r.Context(), code, oauth2.SetAuthURLParam("code_verifier", tok.Data.CodeVerifier), ) if err != nil { slog.Error("code exchange failed", "error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } // Fetch user identity from the provider info, err := p.fetchUserInfo(r.Context(), token, tok.Data.Nonce) if err != nil { slog.Error("fetching user info", "error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } // Find or create user ctx := r.Context() user, err := p.findOrCreateUser(ctx, info.Subject, info.Email) if err != nil { slog.Error("finding or creating user", 
"error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } if err := p.sessions.Put(ctx, user.ID); err != nil { slog.Error("creating session", "error", err, "provider", p.name) http.Redirect(w, r, p.loginPath, http.StatusSeeOther) return } // 303 See Other ensures the browser uses GET for the redirect, // which matters when the callback was a POST (e.g. Apple's form_post). http.Redirect(w, r, p.successPath, http.StatusSeeOther) } // findOrCreateUser implements the three-step cascade: // 1. Lookup by provider+subject (already linked) // 2. Lookup by email (existing user, link provider) // 3. Create new user and link provider func (p *Provider) findOrCreateUser(ctx context.Context, subject, email string) (*iam.User, error) { return dbi.ReadWrite(ctx, p.db, func(tx dbi.DBI) (*iam.User, error) { users := p.users(tx) oauthStore := p.oauth(tx) // Already linked? user, err := oauthStore.GetUserByProvider(ctx, p.name, subject) if err == nil { return user, nil } if !errors.Is(err, core.ErrNotFound) { return nil, fmt.Errorf("looking up provider: %w", err) } // Existing user with this email? user, err = users.GetByEmail(ctx, email) if err != nil && !errors.Is(err, core.ErrNotFound) { return nil, fmt.Errorf("looking up email: %w", err) } // Create new user if not found if errors.Is(err, core.ErrNotFound) { user = &iam.User{Email: email, Verified: true} if _, err = users.Save(ctx, user); err != nil { return nil, fmt.Errorf("creating user: %w", err) } } // Link provider if err := oauthStore.AddProvider(ctx, user.ID, p.name, subject); err != nil { return nil, fmt.Errorf("adding provider: %w", err) } return user, nil }) }