package store

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"atlas9.dev/c/core"
	"atlas9.dev/c/core/dbi"
	"atlas9.dev/c/core/tokens"
)

// SqliteTokenStore is a token store backed by a SQLite table. It accepts a
// dbi.DBI rather than owning its own connection, so it can participate in an
// existing transaction.
//
// Token payloads of type T are stored as JSON text in the "data" column and
// expiry times as formatted UTC strings in the "expires_at" column.
type SqliteTokenStore[T any] struct {
	db dbi.DBI
	// table is the table name. It is concatenated into the SQL text verbatim,
	// so it must come from trusted configuration, never from user input.
	table string
	// ttl is added to the current time to compute a new token's expiry.
	ttl time.Duration
	// keyGen and secretGen produce the key and secret of each new token.
	keyGen    func() string
	secretGen func() string
	// timeFormat is the layout used to write and read expires_at. It has the
	// same "YYYY-MM-DD HH:MM:SS" shape that SQLite's datetime('now') returns,
	// which is what makes the string comparison in Consume meaningful.
	timeFormat string
}

var _ tokens.Store[any] = (*SqliteTokenStore[any])(nil)

// TODO guard

// NewSqliteTokenStore returns a store that keeps tokens in the given table,
// reached through db. Tokens expire ttl after creation; keyGen and secretGen
// are called once per Create to produce the token's key and secret.
func NewSqliteTokenStore[T any](db dbi.DBI, table string, ttl time.Duration, keyGen, secretGen func() string) *SqliteTokenStore[T] {
	return &SqliteTokenStore[T]{
		db:         db,
		table:      table,
		ttl:        ttl,
		keyGen:     keyGen,
		secretGen:  secretGen,
		timeFormat: "2006-01-02 15:04:05",
	}
}

// Create generates a new key and secret, stores data as JSON under that key
// with an expiry of now+ttl (UTC), and returns the resulting token.
//
// The row is written with INSERT OR REPLACE, so a key collision overwrites
// the existing row instead of failing.
//
// NOTE(review): the secret is returned to the caller but is not written to
// the table, so tokens returned by Consume carry an empty Secret. Confirm
// this is intentional.
func (s *SqliteTokenStore[T]) Create(ctx context.Context, data T) (*tokens.Token[T], error) {
	key := s.keyGen()
	secret := s.secretGen()

	dataJSON, err := json.Marshal(data)
	if err != nil {
		return nil, fmt.Errorf("marshaling token data: %w", err)
	}

	expiresAt := time.Now().UTC().Add(s.ttl)

	_, err = s.db.Exec(ctx,
		"INSERT OR REPLACE INTO "+s.table+" (key, data, expires_at) VALUES ($1, $2, $3)",
		key, string(dataJSON), expiresAt.Format(s.timeFormat))
	if err != nil {
		return nil, fmt.Errorf("storing token: %w", err)
	}

	return &tokens.Token[T]{
		Key:       key,
		Secret:    secret,
		ExpiresAt: expiresAt,
		Data:      data,
	}, nil
}

// Consume looks up the unexpired token stored under key, deletes it, and
// returns it. It returns core.ErrNotFound when no row matches, which covers
// both a missing key and an expired token.
//
// The stored row is fully decoded (payload and expiry) before it is deleted,
// so a row that cannot be decoded produces an error without being destroyed.
//
// NOTE(review): the SELECT and DELETE are separate statements. Whether two
// concurrent Consume calls for the same key can both succeed depends on the
// dbi.DBI supplied (e.g. whether it is a transaction) — confirm with callers.
func (s *SqliteTokenStore[T]) Consume(ctx context.Context, key string) (*tokens.Token[T], error) {
	var dataJSON string
	var expiresAtStr string

	err := s.db.QueryRow(ctx,
		"SELECT data, expires_at FROM "+s.table+" WHERE key = $1 AND expires_at > datetime('now')",
		key).Scan(&dataJSON, &expiresAtStr)
	if err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			return nil, core.ErrNotFound
		}
		return nil, fmt.Errorf("getting token: %w", err)
	}

	var data T
	if err := json.Unmarshal([]byte(dataJSON), &data); err != nil {
		return nil, fmt.Errorf("unmarshaling token data: %w", err)
	}

	// Parse before deleting: a malformed timestamp is reported instead of
	// silently yielding a zero ExpiresAt, and the row is left in place.
	expiresAt, err := time.Parse(s.timeFormat, expiresAtStr)
	if err != nil {
		return nil, fmt.Errorf("parsing token expiry: %w", err)
	}

	if _, err := s.db.Exec(ctx, "DELETE FROM "+s.table+" WHERE key = $1", key); err != nil {
		return nil, fmt.Errorf("deleting token: %w", err)
	}

	return &tokens.Token[T]{Key: key, ExpiresAt: expiresAt, Data: data}, nil
}