package main import ( "context" "fmt" "strings" "atlas9.dev/c/core" "atlas9.dev/c/core/dbi" "atlas9.dev/c/core/iam" ) type SqliteGroupStore struct { db dbi.DBI } var _ iam.GroupStore = (*SqliteGroupStore)(nil) func NewSqliteGroupStore(db dbi.DBI) *SqliteGroupStore { return &SqliteGroupStore{db: db} } func (s *SqliteGroupStore) Save(ctx context.Context, g *iam.Group) (bool, error) { if g.ID.IsZero() { g.ID = core.NewID() } // Only require a parent for paths deeper than namespace.name // (e.g. app.engineering.backend requires app.engineering to exist, // but app.engineering does not require app to exist). parent := g.Path.Parent() if parent != "" && strings.Contains(string(parent), ".") { if _, err := s.GetByPath(ctx, parent); err != nil { return false, fmt.Errorf("parent group %s not found: %w", parent, err) } } res, err := s.db.Exec(ctx, ` INSERT INTO groups (id, path, name) VALUES ($1, $2, $3) ON CONFLICT (path) DO NOTHING `, g.ID, g.Path, g.Name) if err != nil { return false, err } n, err := res.RowsAffected() if err != nil { return false, err } created := n > 0 if !created { err = dbi.Get(ctx, s.db, g, ` SELECT id, path, name FROM groups WHERE path = $1 `, g.Path) } return created, err } func (s *SqliteGroupStore) Get(ctx context.Context, id core.ID) (*iam.Group, error) { var g iam.Group err := dbi.Get(ctx, s.db, &g, ` SELECT id, path, name FROM groups WHERE id = $1 `, id) return &g, err } func (s *SqliteGroupStore) GetByPath(ctx context.Context, path core.Path) (*iam.Group, error) { var g iam.Group err := dbi.Get(ctx, s.db, &g, ` SELECT id, path, name FROM groups WHERE path = $1 `, path) return &g, err } func (s *SqliteGroupStore) Delete(ctx context.Context, id core.ID) error { g, err := s.Get(ctx, id) if err != nil { return err } // Check for children using LIKE with escaped underscores. prefix := escapeLike(string(g.Path)) + "." 
var count int row := s.db.QueryRow(ctx, ` SELECT COUNT(*) FROM groups WHERE path LIKE $1 ESCAPE '\' `, prefix+"%") if err := row.Scan(&count); err != nil { return err } if count > 0 { return iam.ErrGroupHasChildren } _, err = s.db.Exec(ctx, `DELETE FROM group_members WHERE group_id = $1`, id) if err != nil { return err } _, err = s.db.Exec(ctx, `DELETE FROM groups WHERE id = $1`, id) return err } func (s *SqliteGroupStore) List(ctx context.Context, page core.PageReq) (core.Page[iam.Group], error) { limit := page.Limit if limit <= 0 { limit = 100 } rows, err := s.db.Query(ctx, ` SELECT id, path, name FROM groups WHERE id > $1 ORDER BY id LIMIT $2 `, page.Cursor, limit) if err != nil { return core.Page[iam.Group]{}, err } defer rows.Close() var items []iam.Group for rows.Next() { var g iam.Group if err := rows.Scan(&g.ID, &g.Path, &g.Name); err != nil { return core.Page[iam.Group]{}, err } items = append(items, g) } if err := rows.Err(); err != nil { return core.Page[iam.Group]{}, err } var cursor string if len(items) == limit { cursor = items[limit-1].ID.String() } return core.Page[iam.Group]{Items: items, Cursor: cursor}, nil } func (s *SqliteGroupStore) ListByNamespace(ctx context.Context, namespace string, page core.PageReq) (core.Page[iam.Group], error) { limit := page.Limit if limit <= 0 { limit = 100 } pattern := escapeLike(namespace) + ".%" rows, err := s.db.Query(ctx, ` SELECT id, path, name FROM groups WHERE path LIKE $1 ESCAPE '\' AND id > $2 ORDER BY id LIMIT $3 `, pattern, page.Cursor, limit) if err != nil { return core.Page[iam.Group]{}, err } defer rows.Close() var items []iam.Group for rows.Next() { var g iam.Group if err := rows.Scan(&g.ID, &g.Path, &g.Name); err != nil { return core.Page[iam.Group]{}, err } items = append(items, g) } if err := rows.Err(); err != nil { return core.Page[iam.Group]{}, err } var cursor string if len(items) == limit { cursor = items[limit-1].ID.String() } return core.Page[iam.Group]{Items: items, Cursor: cursor}, nil } 
// escapeLike returns s with every LIKE metacharacter made literal, for use in
// a pattern evaluated with ESCAPE '\'.
//
// The escape character itself (\) and the wildcards % and _ are each prefixed
// with a backslash; every other byte is copied through unchanged.
func escapeLike(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch c {
		case '\\', '%', '_':
			// All three metacharacters are single ASCII bytes, so working
			// byte-wise never splits a multi-byte UTF-8 sequence.
			b.WriteByte('\\')
		}
		b.WriteByte(c)
	}
	return b.String()
}