package throttle import ( "context" "database/sql" "time" "atlas9.dev/c/core/dbi" ) const sqliteTimeFormat = "2006-01-02 15:04:05" // SqliteBucket is a SQLite-backed TokenBucket implementation. // Requires a table created with SqliteBucketSchema. type SqliteBucket struct { DB *sql.DB } func NewSqliteBucket(db *sql.DB) *SqliteBucket { return &SqliteBucket{DB: db} } // SqliteBucketSchema returns the SQL to create the token_buckets table. const SqliteBucketSchema = `CREATE TABLE IF NOT EXISTS token_buckets ( key TEXT PRIMARY KEY, tokens REAL NOT NULL, last_refill TEXT NOT NULL )` func (s *SqliteBucket) Take(ctx context.Context, key string, capacity int, refillRate float64) (bool, error) { return dbi.ReadWrite(ctx, s.DB, func(tx dbi.DBI) (bool, error) { now := time.Now() nowStr := now.UTC().Format(sqliteTimeFormat) cap := float64(capacity) var tokens float64 var lastRefillStr string err := tx.QueryRow(ctx, `SELECT tokens, last_refill FROM token_buckets WHERE key = $1`, key, ).Scan(&tokens, &lastRefillStr) if err == sql.ErrNoRows { // New bucket: start with capacity-1 (this request uses one token). _, err = tx.Exec(ctx, `INSERT INTO token_buckets (key, tokens, last_refill) VALUES ($1, $2, $3)`, key, cap-1, nowStr, ) return err == nil, err } if err != nil { return false, err } lastRefill, err := time.Parse(sqliteTimeFormat, lastRefillStr) if err != nil { return false, err } // Refill tokens. elapsed := now.Sub(lastRefill).Seconds() tokens += elapsed * refillRate if tokens > cap { tokens = cap } if tokens >= 1 { tokens-- _, err = tx.Exec(ctx, `UPDATE token_buckets SET tokens = $1, last_refill = $2 WHERE key = $3`, tokens, nowStr, key, ) return err == nil, err } // Denied — still update last_refill so next call computes elapsed correctly. _, err = tx.Exec(ctx, `UPDATE token_buckets SET tokens = $1, last_refill = $2 WHERE key = $3`, tokens, nowStr, key, ) return false, err }) } func (s *SqliteBucket) Reset(ctx context.Context, key string) error { _, err := dbi.ReadWrite(ctx, s.DB, func(tx dbi.DBI) (struct{}, error) { _, err := tx.Exec(ctx, `DELETE FROM token_buckets WHERE key = $1`, key) return struct{}{}, err }) return err }