package store import ( "context" "database/sql" "errors" "strings" "atlas9.dev/c/core" "atlas9.dev/c/core/dbi" "atlas9.dev/c/core/iam" "atlas9.dev/c/demo/lib" ) type SqliteTenantMemberStore struct { db dbi.DBI guard lib.Guard } var _ iam.TenantMemberStore = (*SqliteTenantMemberStore)(nil) func NewSqliteTenantMemberStore(db dbi.DBI, guard lib.Guard) *SqliteTenantMemberStore { return &SqliteTenantMemberStore{db: db, guard: guard} } func (s *SqliteTenantMemberStore) Create(ctx context.Context, m iam.TenantMember) error { if err := s.guard.Check(ctx, iam.CapTenantMembersCreate, m.Tenant, ""); err != nil { return err } _, err := s.db.Exec(ctx, ` INSERT INTO tenant_members (tenant, user_id, owner) VALUES ($1, $2, $3) `, m.Tenant, m.UserID, m.Owner) if err == nil { return nil } msg := err.Error() if strings.Contains(msg, "UNIQUE constraint failed") { return iam.ErrAlreadyExists } if strings.Contains(msg, "FOREIGN KEY constraint failed") { // Determine whether the tenant or user is missing. var n int checkErr := s.db.QueryRow(ctx, `SELECT 1 FROM tenants WHERE id = $1`, m.Tenant).Scan(&n) if errors.Is(checkErr, sql.ErrNoRows) { return iam.ErrTenantNotFound } if checkErr != nil { return checkErr } return iam.ErrUserNotFound } return err } func (s *SqliteTenantMemberStore) Remove(ctx context.Context, tenant core.ID, userID core.ID) error { if err := s.guard.Check(ctx, iam.CapTenantMembersRemove, tenant, ""); err != nil { return err } if err := s.checkNotLastOwner(ctx, tenant, userID); err != nil { return err } res, err := s.db.Exec(ctx, ` DELETE FROM tenant_members WHERE tenant = $1 AND user_id = $2 `, tenant, userID) if err != nil { return err } n, err := res.RowsAffected() if err != nil { return err } if n == 0 { return core.ErrNotFound } return nil } // checkNotLastOwner returns an error if the user is the last owner of the tenant. func (s *SqliteTenantMemberStore) checkNotLastOwner(ctx context.Context, tenant core.ID, userID core.ID) error { var isOwner bool err := s.db.QueryRow(ctx, ` SELECT owner FROM tenant_members WHERE tenant = $1 AND user_id = $2 `, tenant, userID).Scan(&isOwner) if errors.Is(err, sql.ErrNoRows) { return nil // not a member, nothing to check } if err != nil { return err } if !isOwner { return nil } var otherOwners int err = s.db.QueryRow(ctx, ` SELECT COUNT(*) FROM tenant_members WHERE tenant = $1 AND owner = true AND user_id != $2 `, tenant, userID).Scan(&otherOwners) if err != nil { return err } if otherOwners == 0 { return iam.ErrLastOwner } return nil } func (s *SqliteTenantMemberStore) Get(ctx context.Context, tenant core.ID, userID core.ID, out *iam.TenantMember) error { if err := s.guard.Check(ctx, iam.CapTenantMembersRead, tenant, ""); err != nil { return err } return dbi.Get(ctx, s.db, out, ` SELECT tenant, user_id, owner FROM tenant_members WHERE tenant = $1 AND user_id = $2 `, tenant, userID) } func (s *SqliteTenantMemberStore) List(ctx context.Context, tenant core.ID, page core.PageReq, out *core.Page[iam.TenantMember]) error { if err := s.guard.Check(ctx, iam.CapTenantMembersRead, tenant, ""); err != nil { return err } return dbi.Paginate(ctx, s.db, page, out, func(m iam.TenantMember) string { return m.UserID.String() }, `SELECT tenant, user_id, owner FROM tenant_members WHERE tenant = $1 AND user_id > $cursor ORDER BY user_id LIMIT $limit`, tenant) } func (s *SqliteTenantMemberStore) ListByUser(ctx context.Context, userID core.ID, page core.PageReq, out *core.Page[iam.TenantMember]) error { cursorFn := func(m iam.TenantMember) string { return m.Tenant.String() } // TODO this seems to ignore non-nil errors that are not access denied and let's that fall through // which feels slightly sloppy. The interface also makes distinguishing expected vs non-expected // errors slightly clunky. if err := s.guard.System(ctx, iam.CapTenantMembersRead); err == nil { // System-wide access: return all memberships for this user. return dbi.Paginate(ctx, s.db, page, out, cursorFn, `SELECT tenant, user_id, owner FROM tenant_members WHERE user_id = $1 AND tenant > $cursor ORDER BY tenant LIMIT $limit`, userID) } // TODO could this just be handled by the caller (api layer)? principal := iam.GetPrincipal(ctx) if principal.Subject == userID.String() { // The caller is querying for their own tenants. return dbi.Paginate(ctx, s.db, page, out, cursorFn, `SELECT tm.tenant, tm.user_id, tm.owner FROM tenant_members tm WHERE tm.user_id = $1 AND tm.tenant > $cursor ORDER BY tm.tenant LIMIT $limit`, userID) } // Non-system, non-self: fetch all memberships for this user and filter // to tenants where the caller has read access. Users are typically in // few tenants so fetching all and filtering in memory is acceptable. rows, err := s.db.Query(ctx, ` SELECT tenant, user_id, owner FROM tenant_members WHERE user_id = $1 ORDER BY tenant `, userID) if err != nil { return err } defer rows.Close() var all []iam.TenantMember for rows.Next() { var m iam.TenantMember if err := rows.Scan(&m.Tenant, &m.UserID, &m.Owner); err != nil { return err } if s.guard.Check(ctx, iam.CapTenantMembersRead, m.Tenant, "") == nil { all = append(all, m) } } if err := rows.Err(); err != nil { return err } // Apply cursor and limit manually. limit := page.Limit if limit <= 0 { limit = 100 } start := 0 for start < len(all) && all[start].Tenant.String() <= page.Cursor { start++ } all = all[start:] if len(all) > limit { all = all[:limit] out.Cursor = cursorFn(all[limit-1]) } out.Items = all return nil }