package throttle import ( "context" "fmt" "time" ) // RedisClient is the minimal interface needed for the Redis token bucket. // Compatible with github.com/redis/go-redis/v9 and similar clients. type RedisClient interface { Eval(ctx context.Context, script string, keys []string, args ...any) (any, error) Del(ctx context.Context, keys ...string) error } // RedisBucket is a Redis-backed TokenBucket implementation. // Uses a Lua script for atomic take operations. type RedisBucket struct { Client RedisClient KeyPrefix string } func NewRedisBucket(client RedisClient) *RedisBucket { return &RedisBucket{ Client: client, KeyPrefix: "throttle:", } } // takeLua atomically refills and takes a token. // KEYS[1] = bucket key // ARGV[1] = capacity // ARGV[2] = refill rate (tokens/sec) // ARGV[3] = now (unix seconds, float) // Returns 1 if allowed, 0 if denied. const takeLua = ` local key = KEYS[1] local capacity = tonumber(ARGV[1]) local refill_rate = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) local data = redis.call('HMGET', key, 'tokens', 'last_refill') local tokens = tonumber(data[1]) local last_refill = tonumber(data[2]) if tokens == nil then -- New bucket redis.call('HSET', key, 'tokens', capacity - 1, 'last_refill', now) return 1 end -- Refill local elapsed = now - last_refill tokens = tokens + elapsed * refill_rate if tokens > capacity then tokens = capacity end if tokens >= 1 then tokens = tokens - 1 redis.call('HSET', key, 'tokens', tokens, 'last_refill', now) return 1 end redis.call('HSET', key, 'tokens', tokens, 'last_refill', now) return 0 ` func (r *RedisBucket) Take(ctx context.Context, key string, capacity int, refillRate float64) (bool, error) { now := float64(time.Now().UnixMicro()) / 1e6 result, err := r.Client.Eval(ctx, takeLua, []string{r.KeyPrefix + key}, capacity, refillRate, fmt.Sprintf("%.6f", now), ) if err != nil { return false, err } switch v := result.(type) { case int64: return v == 1, nil default: return false, fmt.Errorf("unexpected redis result type: %T", result) } } func (r *RedisBucket) Reset(ctx context.Context, key string) error { return r.Client.Del(ctx, r.KeyPrefix+key) }