package throttle import ( "context" "database/sql" "fmt" "testing" "testing/synctest" "time" _ "github.com/mattn/go-sqlite3" ) func TestMemoryBucket_AllowsRequests(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() ok, err := m.Take(ctx, "test", 5, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed") } } func TestMemoryBucket_DeniesAfterCapacity(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() for i := range 5 { ok, err := m.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if !ok { t.Errorf("request %d should be allowed", i+1) } } ok, err := m.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if ok { t.Error("request should be denied after capacity exhausted") } } func TestMemoryBucket_Refill(t *testing.T) { synctest.Test(t, func(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() // Exhaust the bucket (capacity 1). m.Take(ctx, "test", 1, 0) // Verify denied. ok, _ := m.Take(ctx, "test", 1, 0) if ok { t.Fatal("should be denied") } // Advance time so 1 token refills (rate = 1/sec, advance 1 sec). time.Sleep(time.Second) ok, err := m.Take(ctx, "test", 1, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed after refill") } }) } func TestMemoryBucket_RefillCappedAtCapacity(t *testing.T) { synctest.Test(t, func(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() // Use one token from a capacity-3 bucket. m.Take(ctx, "test", 3, 0) // Advance a long time — should not exceed capacity. time.Sleep(time.Hour) // Use all 3 tokens (bucket should be at cap=3, minus one from Take). for i := range 3 { ok, _ := m.Take(ctx, "test", 3, 1.0) if !ok { t.Errorf("request %d should be allowed", i+1) } } ok, _ := m.Take(ctx, "test", 3, 1.0) if ok { t.Error("should be denied, bucket should not exceed capacity") } }) } func TestMemoryBucket_Reset(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() // Exhaust bucket. for i := 0; i < 3; i++ { m.Take(ctx, "test", 3, 0) } ok, _ := m.Take(ctx, "test", 3, 0) if ok { t.Fatal("should be denied") } // Reset and verify allowed again. m.Reset(ctx, "test") ok, err := m.Take(ctx, "test", 3, 0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request allowed after reset") } } func TestMemoryBucket_IndependentKeys(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() // Exhaust key "a". m.Take(ctx, "a", 1, 0) ok, _ := m.Take(ctx, "a", 1, 0) if ok { t.Error("key 'a' should be denied") } // Key "b" should still be allowed. ok, err := m.Take(ctx, "b", 1, 0) if err != nil { t.Fatal(err) } if !ok { t.Error("key 'b' should be allowed") } } func TestMemoryBucket_Cleanup(t *testing.T) { synctest.Test(t, func(t *testing.T) { m := NewMemoryBucket() ctx := context.Background() m.Take(ctx, "old", 5, 0) m.Take(ctx, "recent", 5, 0) // Advance time past cleanup threshold for "old" only. time.Sleep(2 * time.Hour) m.Take(ctx, "recent", 5, 0) // touch "recent" m.Cleanup(time.Hour) // "old" should be cleaned up, "recent" should remain. m.mu.Lock() _, hasOld := m.buckets["old"] _, hasRecent := m.buckets["recent"] m.mu.Unlock() if hasOld { t.Error("expected 'old' to be cleaned up") } if !hasRecent { t.Error("expected 'recent' to still exist") } }) } // --- SQLite backend tests --- func setupSqliteDB(t *testing.T) *sql.DB { t.Helper() db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } t.Cleanup(func() { db.Close() }) _, err = db.Exec(SqliteBucketSchema) if err != nil { t.Fatal(err) } return db } func TestSqliteBucket_AllowsRequests(t *testing.T) { s := NewSqliteBucket(setupSqliteDB(t)) ctx := context.Background() ok, err := s.Take(ctx, "test", 5, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed") } } func TestSqliteBucket_DeniesAfterCapacity(t *testing.T) { s := NewSqliteBucket(setupSqliteDB(t)) ctx := context.Background() for i := 0; i < 5; i++ { ok, err := s.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if !ok { t.Errorf("request %d should be allowed", i+1) } } ok, err := s.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if ok { t.Error("request should be denied after capacity exhausted") } } func TestSqliteBucket_Refill(t *testing.T) { synctest.Test(t, func(t *testing.T) { s := NewSqliteBucket(setupSqliteDB(t)) ctx := context.Background() s.Take(ctx, "test", 1, 0) ok, _ := s.Take(ctx, "test", 1, 0) if ok { t.Fatal("should be denied") } time.Sleep(time.Second) ok, err := s.Take(ctx, "test", 1, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed after refill") } }) } func TestSqliteBucket_Reset(t *testing.T) { s := NewSqliteBucket(setupSqliteDB(t)) ctx := context.Background() for i := 0; i < 3; i++ { s.Take(ctx, "test", 3, 0) } ok, _ := s.Take(ctx, "test", 3, 0) if ok { t.Fatal("should be denied") } s.Reset(ctx, "test") ok, err := s.Take(ctx, "test", 3, 0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request allowed after reset") } } func TestSqliteBucket_IndependentKeys(t *testing.T) { s := NewSqliteBucket(setupSqliteDB(t)) ctx := context.Background() s.Take(ctx, "a", 1, 0) ok, _ := s.Take(ctx, "a", 1, 0) if ok { t.Error("key 'a' should be denied") } ok, err := s.Take(ctx, "b", 1, 0) if err != nil { t.Fatal(err) } if !ok { t.Error("key 'b' should be allowed") } } // --- Redis backend tests (mock) --- type mockRedisClient struct { data map[string]map[string]string } func newMockRedisClient() *mockRedisClient { return &mockRedisClient{data: make(map[string]map[string]string)} } func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...any) (any, error) { // Minimal Lua script emulation for the token bucket script. key := keys[0] capacity := toFloat64(args[0]) refillRate := toFloat64(args[1]) nowStr := args[2].(string) now := parseFloat64(nowStr) h, ok := m.data[key] if !ok { // New bucket. m.data[key] = map[string]string{ "tokens": fmt.Sprintf("%.6f", capacity-1), "last_refill": fmt.Sprintf("%.6f", now), } return int64(1), nil } tokens := parseFloat64(h["tokens"]) lastRefill := parseFloat64(h["last_refill"]) elapsed := now - lastRefill tokens += elapsed * refillRate if tokens > capacity { tokens = capacity } if tokens >= 1 { tokens-- m.data[key]["tokens"] = fmt.Sprintf("%.6f", tokens) m.data[key]["last_refill"] = fmt.Sprintf("%.6f", now) return int64(1), nil } m.data[key]["tokens"] = fmt.Sprintf("%.6f", tokens) m.data[key]["last_refill"] = fmt.Sprintf("%.6f", now) return int64(0), nil } func (m *mockRedisClient) Del(ctx context.Context, keys ...string) error { for _, k := range keys { delete(m.data, k) } return nil } func toFloat64(v any) float64 { switch x := v.(type) { case int: return float64(x) case float64: return x default: return 0 } } func parseFloat64(s string) float64 { var f float64 fmt.Sscanf(s, "%f", &f) return f } func TestRedisBucket_AllowsRequests(t *testing.T) { r := NewRedisBucket(newMockRedisClient()) ctx := context.Background() ok, err := r.Take(ctx, "test", 5, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed") } } func TestRedisBucket_DeniesAfterCapacity(t *testing.T) { r := NewRedisBucket(newMockRedisClient()) ctx := context.Background() for i := 0; i < 5; i++ { ok, err := r.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if !ok { t.Errorf("request %d should be allowed", i+1) } } ok, err := r.Take(ctx, "test", 5, 0) if err != nil { t.Fatal(err) } if ok { t.Error("request should be denied after capacity exhausted") } } func TestRedisBucket_Refill(t *testing.T) { synctest.Test(t, func(t *testing.T) { r := NewRedisBucket(newMockRedisClient()) ctx := context.Background() r.Take(ctx, "test", 1, 0) ok, _ := r.Take(ctx, "test", 1, 0) if ok { t.Fatal("should be denied") } time.Sleep(time.Second) ok, err := r.Take(ctx, "test", 1, 1.0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request to be allowed after refill") } }) } func TestRedisBucket_Reset(t *testing.T) { r := NewRedisBucket(newMockRedisClient()) ctx := context.Background() for i := 0; i < 3; i++ { r.Take(ctx, "test", 3, 0) } ok, _ := r.Take(ctx, "test", 3, 0) if ok { t.Fatal("should be denied") } r.Reset(ctx, "test") ok, err := r.Take(ctx, "test", 3, 0) if err != nil { t.Fatal(err) } if !ok { t.Error("expected request allowed after reset") } }