package store

import (
	"context"
	"crypto/subtle"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"atlas9.dev/c/core"
	"atlas9.dev/c/core/dbi"
	"atlas9.dev/c/core/iam"
	"atlas9.dev/c/core/tokens"
	"atlas9.dev/c/demo/lib"
)

// Invitation token and expiry parameters.
//
// A token is "key.code": the key (sized by tenantInvitationKeyLen) locates the
// row, and the code (sized by tenantInvitationCodeLen) is the secret compared
// in constant time. expires_at is stored as text in tenantInvitationTimeFmt,
// the same shape SQLite's datetime('now') produces, so the two can be
// compared directly in SQL.
const (
	tenantInvitationKeyLen   = 16
	tenantInvitationCodeLen  = 32
	tenantInvitationLifetime = 14 * 24 * time.Hour
	tenantInvitationTimeFmt  = "2006-01-02 15:04:05"
)

// TenantInvitationTaskPayload is the payload pushed to the
// tenant_invitation_tasks queue. Shape matches Tasks_SendTenantInvitationReq.
type TenantInvitationTaskPayload struct {
	Email string
	Token string
}

// SqliteTenantInvitationStore implements iam.TenantInvitationStore on top of
// the tenant_invitations table, authorizing calls through guard.
type SqliteTenantInvitationStore struct {
	db    dbi.DBI
	guard lib.Guard
}

var _ iam.TenantInvitationStore = (*SqliteTenantInvitationStore)(nil)

// NewSqliteTenantInvitationStore returns a store that queries db and checks
// capabilities with guard.
func NewSqliteTenantInvitationStore(db dbi.DBI, guard lib.Guard) *SqliteTenantInvitationStore {
	return &SqliteTenantInvitationStore{db: db, guard: guard}
}

// Create records an invitation for email (lower-cased) to join tenantID and
// returns its bearer token in the form "key.code". The invitation expires
// tenantInvitationLifetime after creation.
//
// The caller needs iam.CapTenantInvitationsCreate on tenantID. A UNIQUE
// constraint failure is reported as iam.ErrAlreadyExists and a FOREIGN KEY
// constraint failure as iam.ErrTenantNotFound; other errors are returned as-is.
func (s *SqliteTenantInvitationStore) Create(ctx context.Context, tenantID core.ID, email string) (string, error) {
	if err := s.guard.Check(ctx, iam.CapTenantInvitationsCreate, tenantID, ""); err != nil {
		return "", err
	}
	email = strings.ToLower(email)

	key := tokens.RandomString(tenantInvitationKeyLen)()
	code := tokens.RandomString(tenantInvitationCodeLen)()
	expiresAt := time.Now().UTC().Add(tenantInvitationLifetime)

	_, err := s.db.Exec(ctx, `
		INSERT INTO tenant_invitations (tenant, email, token_key, token_code, expires_at)
		VALUES ($1, $2, $3, $4, $5)
	`, tenantID, email, key, code, expiresAt.Format(tenantInvitationTimeFmt))
	if err != nil {
		// Constraint violations are only distinguishable by message text here.
		msg := err.Error()
		if strings.Contains(msg, "UNIQUE constraint failed") {
			return "", iam.ErrAlreadyExists
		}
		if strings.Contains(msg, "FOREIGN KEY constraint failed") {
			return "", iam.ErrTenantNotFound
		}
		return "", err
	}
	return key + "." + code, nil
}

// GetByToken resolves a "key.code" token to its unexpired invitation and sets
// out.Tenant, out.Email and out.ExpiresAt. A malformed token, an unknown or
// expired key, and a wrong code all yield core.ErrNotFound. out is not
// modified when an error is returned.
func (s *SqliteTenantInvitationStore) GetByToken(ctx context.Context, token string, out *iam.TenantInvitation) error {
	key, code, ok := strings.Cut(token, ".")
	if !ok {
		return core.ErrNotFound
	}

	// Scan into a scratch value: a caller holding a valid key but the wrong
	// code must not get the invitation's tenant or email back in out.
	var (
		found        iam.TenantInvitation
		storedCode   string
		expiresAtStr string
	)
	err := s.db.QueryRow(ctx, `
		SELECT tenant, email, token_code, expires_at
		FROM tenant_invitations
		WHERE token_key = $1 AND expires_at > datetime('now')
	`, key).Scan(&found.Tenant, &found.Email, &storedCode, &expiresAtStr)
	if errors.Is(err, sql.ErrNoRows) {
		return core.ErrNotFound
	}
	if err != nil {
		return err
	}

	// Timing-safe compare. Any failure here collapses into ErrNotFound with
	// the row-missing and expired cases — do not split into a distinct error.
	if subtle.ConstantTimeCompare([]byte(storedCode), []byte(code)) != 1 {
		return core.ErrNotFound
	}

	expiresAt, err := parseInvitationTime(expiresAtStr)
	if err != nil {
		return fmt.Errorf("parsing tenant invitation expires_at %q: %w", expiresAtStr, err)
	}

	out.Tenant = found.Tenant
	out.Email = found.Email
	out.ExpiresAt = expiresAt
	return nil
}

// parseInvitationTime parses an expires_at value that was scanned into a
// string. Rows written by Create use tenantInvitationTimeFmt (no zone, so
// time.Parse reads it as UTC). If the driver instead returns a time.Time for
// the column, database/sql renders it into the string as RFC 3339, so that
// layout is accepted too and normalized to UTC. On failure the error from the
// primary layout is returned.
func parseInvitationTime(s string) (time.Time, error) {
	t, err := time.Parse(tenantInvitationTimeFmt, s)
	if err == nil {
		return t, nil
	}
	if t, rfcErr := time.Parse(time.RFC3339Nano, s); rfcErr == nil {
		return t.UTC(), nil
	}
	return time.Time{}, err
}

// DeleteByToken removes the invitation identified by a "key.code" token. It
// first resolves the token with GetByToken and returns that call's error, so
// a malformed, expired or wrong-code token yields core.ErrNotFound and
// nothing is deleted.
func (s *SqliteTenantInvitationStore) DeleteByToken(ctx context.Context, token string) error {
	// Look up first so we can verify the code before deleting. A straight
	// DELETE by token_key would let an attacker rescind invitations by
	// guessing keys without needing the code.
	var inv iam.TenantInvitation
	if err := s.GetByToken(ctx, token, &inv); err != nil {
		return err
	}
	key, _, _ := strings.Cut(token, ".")
	_, err := s.db.Exec(ctx, `DELETE FROM tenant_invitations WHERE token_key = $1`, key)
	return err
}

// List returns one page of tenantID's invitations ordered by email, using the
// email as the pagination cursor. The caller needs iam.CapTenantInvitationsRead
// on tenantID. Expired invitations are not filtered out.
func (s *SqliteTenantInvitationStore) List(ctx context.Context, tenantID core.ID, page core.PageReq, out *core.Page[iam.TenantInvitation]) error {
	if err := s.guard.Check(ctx, iam.CapTenantInvitationsRead, tenantID, ""); err != nil {
		return err
	}
	return dbi.Paginate(ctx, s.db, page, out,
		func(inv iam.TenantInvitation) string { return inv.Email },
		`SELECT tenant, email, expires_at FROM tenant_invitations WHERE tenant = $1 AND email > $cursor ORDER BY email LIMIT $limit`,
		tenantID)
}

// ListByEmail returns one page of invitations addressed to email (lower-cased),
// ordered by tenant, using the tenant's String() form as the pagination
// cursor. Expired invitations are not filtered out.
//
// A principal passing the system-level iam.CapTenantInvitationsRead check may
// query any email. Otherwise the principal's own users.email must match email
// case-insensitively; a mismatch or a missing users row yields
// iam.ErrForbidden.
func (s *SqliteTenantInvitationStore) ListByEmail(ctx context.Context, email string, page core.PageReq, out *core.Page[iam.TenantInvitation]) error {
	email = strings.ToLower(email)
	if err := s.guard.System(ctx, iam.CapTenantInvitationsRead); err != nil {
		// Any error from the system check falls back to self-access: the
		// principal's email must match.
		principal := iam.GetPrincipal(ctx)
		var userEmail string
		err := s.db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1`, principal.Subject).Scan(&userEmail)
		if errors.Is(err, sql.ErrNoRows) {
			return iam.ErrForbidden
		}
		if err != nil {
			return err
		}
		if strings.ToLower(userEmail) != email {
			return iam.ErrForbidden
		}
	}
	return dbi.Paginate(ctx, s.db, page, out,
		func(inv iam.TenantInvitation) string { return inv.Tenant.String() },
		`SELECT tenant, email, expires_at FROM tenant_invitations WHERE email = $1 AND tenant > $cursor ORDER BY tenant LIMIT $limit`,
		email)
}

// Delete removes tenantID's invitation for email (lower-cased). The caller
// needs iam.CapTenantInvitationsDelete on tenantID. Deleting an invitation
// that does not exist is not reported as an error.
func (s *SqliteTenantInvitationStore) Delete(ctx context.Context, tenantID core.ID, email string) error {
	if err := s.guard.Check(ctx, iam.CapTenantInvitationsDelete, tenantID, ""); err != nil {
		return err
	}
	email = strings.ToLower(email)
	_, err := s.db.Exec(ctx, `
		DELETE FROM tenant_invitations
		WHERE tenant = $1 AND email = $2
	`, tenantID, email)
	return err
}