package throttle

import (
	"context"
	"database/sql"
	"errors"
	"time"

	"atlas9.dev/c/core/dbi"
)

// PostgresBucket is a Postgres-backed TokenBucket implementation.
// Requires a table created with PostgresBucketSchema.
type PostgresBucket struct {
	// DB is the handle passed to dbi.ReadWrite for every operation.
	DB *sql.DB

	// now supplies the current time. NewPostgresBucket sets it to time.Now;
	// Take falls back to time.Now when it is nil, so a PostgresBucket built
	// as a struct literal with only DB set is still usable.
	now func() time.Time
}

// NewPostgresBucket returns a PostgresBucket that stores its state through db
// and reads the time from time.Now.
func NewPostgresBucket(db *sql.DB) *PostgresBucket {
	return &PostgresBucket{DB: db, now: time.Now}
}

// PostgresBucketSchema is the DDL for the token_buckets table that Take and
// Reset operate on: one row per key holding the current token balance and the
// time that balance was last brought up to date.
const PostgresBucketSchema = `CREATE TABLE IF NOT EXISTS token_buckets (
	key TEXT PRIMARY KEY,
	tokens DOUBLE PRECISION NOT NULL,
	last_refill TIMESTAMPTZ NOT NULL
)`

// Take tries to consume one token from the bucket stored under key.
//
// capacity is the maximum number of tokens the bucket can hold, and refillRate
// is the number of tokens added per second of elapsed time. A key that has no
// row yet starts from a full bucket.
//
// It returns true when a token was consumed and false when fewer than one
// token was available. On a non-nil error the boolean is always false.
func (p *PostgresBucket) Take(ctx context.Context, key string, capacity int, refillRate float64) (bool, error) {
	return dbi.ReadWrite(ctx, p.DB, func(tx dbi.DBI) (bool, error) {
		clock := p.now
		if clock == nil {
			clock = time.Now
		}
		now := clock()
		limit := float64(capacity)

		var (
			tokens     float64
			lastRefill time.Time
		)
		// lockRow loads the bucket row and takes a row lock on it (FOR UPDATE),
		// so concurrent takers of the same key are serialized.
		// NOTE(review): this relies on dbi.ReadWrite running the callback in a
		// single transaction — confirm against the dbi package.
		lockRow := func() error {
			return tx.QueryRow(ctx,
				`SELECT tokens, last_refill FROM token_buckets WHERE key = $1 FOR UPDATE`,
				key,
			).Scan(&tokens, &lastRefill)
		}

		err := lockRow()
		if errors.Is(err, sql.ErrNoRows) {
			// FOR UPDATE cannot lock a row that does not exist, so two first
			// takes can both arrive here. Seed a full bucket with ON CONFLICT
			// DO NOTHING so the loser does not fail on the primary key, then
			// lock whichever row ended up stored and continue on the normal
			// path.
			if _, err = tx.Exec(ctx,
				`INSERT INTO token_buckets (key, tokens, last_refill) VALUES ($1, $2, $3) ON CONFLICT (key) DO NOTHING`,
				key, limit, now,
			); err != nil {
				return false, err
			}
			err = lockRow()
		}
		if err != nil {
			return false, err
		}

		// Never let time run backwards. The stored timestamp can be ahead of
		// our clock (skew between writers, or rounding when the value was
		// stored); a negative elapsed time would otherwise drain tokens.
		if now.Before(lastRefill) {
			now = lastRefill
		}

		// Refill for the elapsed time, capped at the bucket capacity.
		tokens += now.Sub(lastRefill).Seconds() * refillRate
		if tokens > limit {
			tokens = limit
		}

		allowed := tokens >= 1
		if allowed {
			tokens--
		}

		// Persist the refreshed balance whether or not a token was granted, so
		// the next call measures elapsed time from this one.
		if _, err = tx.Exec(ctx,
			`UPDATE token_buckets SET tokens = $1, last_refill = $2 WHERE key = $3`,
			tokens, now, key,
		); err != nil {
			return false, err
		}
		return allowed, nil
	})
}

// Reset deletes the stored state for key. The next Take for that key finds no
// row and starts again from a full bucket.
func (p *PostgresBucket) Reset(ctx context.Context, key string) error {
	_, err := dbi.ReadWrite(ctx, p.DB, func(tx dbi.DBI) (struct{}, error) {
		if _, err := tx.Exec(ctx, `DELETE FROM token_buckets WHERE key = $1`, key); err != nil {
			return struct{}{}, err
		}
		return struct{}{}, nil
	})
	return err
}