package tokens

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"atlas9.dev/c/core"
)

// sqliteTimeFormat is the layout used to write expires_at. It matches the
// text produced by SQLite's datetime('now'), which lets Get compare the
// stored value against the current time as plain strings.
const sqliteTimeFormat = "2006-01-02 15:04:05"

// Options configures a SqliteStore.
type Options struct {
	// Table is the name of the table holding tokens; it must have key, data
	// and expires_at columns. The name is concatenated directly into SQL
	// statements, so it must come from trusted configuration, never from
	// user input.
	Table string
	// Expiration is how long a token remains retrievable after Put. A
	// non-positive value produces tokens that are already expired.
	Expiration time.Duration
}

// SqliteStore persists tokens whose payload of type T is stored as JSON text.
type SqliteStore[T any] struct {
	db   *sql.DB
	opts Options
}

// NewSqliteStore returns a store that reads and writes tokens in opts.Table
// through db.
func NewSqliteStore[T any](db *sql.DB, opts Options) *SqliteStore[T] {
	return &SqliteStore[T]{db: db, opts: opts}
}

// Put stores data under key, replacing any existing token with that key
// (INSERT OR REPLACE), and returns the stored token.
//
// The expiry is truncated to whole seconds because that is the precision
// persisted in expires_at; the returned ExpiresAt therefore equals what a
// later Get reports for the same token.
func (s *SqliteStore[T]) Put(ctx context.Context, key string, data T) (*Token[T], error) {
	dataJSON, err := json.Marshal(data)
	if err != nil {
		return nil, fmt.Errorf("marshaling token data: %w", err)
	}
	expiresAt := time.Now().UTC().Add(s.opts.Expiration).Truncate(time.Second)
	_, err = s.db.ExecContext(ctx,
		"INSERT OR REPLACE INTO "+s.opts.Table+" (key, data, expires_at) VALUES ($1, $2, $3)",
		key, string(dataJSON), expiresAt.Format(sqliteTimeFormat))
	if err != nil {
		return nil, fmt.Errorf("storing token: %w", err)
	}
	return &Token[T]{Key: key, ExpiresAt: expiresAt, Data: data}, nil
}

// Get returns the token stored under key if it has not expired.
//
// It returns core.ErrNotFound when no row exists for key or when the row's
// expires_at is not after SQLite's datetime('now') (compared as text).
// Expired rows are left in the table; Get does not delete them.
func (s *SqliteStore[T]) Get(ctx context.Context, key string) (*Token[T], error) {
	var dataJSON, expiresAtStr string
	err := s.db.QueryRowContext(ctx,
		"SELECT data, expires_at FROM "+s.opts.Table+" WHERE key = $1 AND expires_at > datetime('now')",
		key).Scan(&dataJSON, &expiresAtStr)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, core.ErrNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("getting token: %w", err)
	}
	var data T
	if err := json.Unmarshal([]byte(dataJSON), &data); err != nil {
		return nil, fmt.Errorf("unmarshaling token data: %w", err)
	}
	// Previously this error was discarded, which turned an unreadable value
	// into the zero time and made a live token look long expired.
	expiresAt, err := parseExpiresAt(expiresAtStr)
	if err != nil {
		return nil, fmt.Errorf("parsing token expiry %q: %w", expiresAtStr, err)
	}
	return &Token[T]{Key: key, ExpiresAt: expiresAt, Data: data}, nil
}

// parseExpiresAt converts an expires_at value read from the database into a
// UTC time. Put writes sqliteTimeFormat, but if the driver hands the column
// back as a time.Time (e.g. for a column declared DATETIME — depends on the
// driver and schema, confirm), database/sql formats it as RFC 3339 when
// scanning into a string, so that layout is accepted as well. The error from
// the first layout is returned when neither matches.
func parseExpiresAt(s string) (time.Time, error) {
	var firstErr error
	for _, layout := range [...]string{sqliteTimeFormat, time.RFC3339Nano} {
		t, err := time.Parse(layout, s)
		if err == nil {
			return t.UTC(), nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return time.Time{}, firstErr
}

// Delete removes the token stored under key. Deleting a key that does not
// exist is not an error.
func (s *SqliteStore[T]) Delete(ctx context.Context, key string) error {
	_, err := s.db.ExecContext(ctx, "DELETE FROM "+s.opts.Table+" WHERE key = $1", key)
	if err != nil {
		return fmt.Errorf("deleting token: %w", err)
	}
	return nil
}