package tokens import ( "context" "database/sql" "errors" "testing" "time" "atlas9.dev/c/core" _ "github.com/mattn/go-sqlite3" ) func setupTokenDB(t *testing.T) *sql.DB { t.Helper() db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } t.Cleanup(func() { db.Close() }) _, err = db.Exec(`CREATE TABLE test_tokens ( key TEXT PRIMARY KEY, data TEXT NOT NULL, expires_at TEXT NOT NULL )`) if err != nil { t.Fatal(err) } return db } // --- SqliteStore Tests --- func TestSqliteStorePutGet(t *testing.T) { db := setupTokenDB(t) store := NewSqliteStore[string](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) ctx := context.Background() tok, err := store.Put(ctx, "mykey", "mydata") if err != nil { t.Fatal(err) } if tok.Key != "mykey" { t.Errorf("key = %q, want %q", tok.Key, "mykey") } if tok.Data != "mydata" { t.Errorf("data = %q, want %q", tok.Data, "mydata") } if tok.ExpiresAt.IsZero() { t.Error("expected non-zero expiration") } got, err := store.Get(ctx, "mykey") if err != nil { t.Fatal(err) } if got.Data != "mydata" { t.Errorf("got data = %q, want %q", got.Data, "mydata") } } func TestSqliteStoreExpiration(t *testing.T) { db := setupTokenDB(t) store := NewSqliteStore[string](db, Options{Table: "test_tokens", Expiration: 1 * time.Millisecond}) ctx := context.Background() _, err := store.Put(ctx, "expiring", "data") if err != nil { t.Fatal(err) } time.Sleep(10 * time.Millisecond) _, err = store.Get(ctx, "expiring") if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound for expired token, got %v", err) } } func TestSqliteStoreDelete(t *testing.T) { db := setupTokenDB(t) store := NewSqliteStore[string](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) ctx := context.Background() _, err := store.Put(ctx, "deleteme", "data") if err != nil { t.Fatal(err) } if err := store.Delete(ctx, "deleteme"); err != nil { t.Fatal(err) } _, err = store.Get(ctx, "deleteme") if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound after delete, got %v", err) } } func TestSqliteStoreUpsert(t *testing.T) { db := setupTokenDB(t) store := NewSqliteStore[string](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) ctx := context.Background() _, err := store.Put(ctx, "key1", "first") if err != nil { t.Fatal(err) } _, err = store.Put(ctx, "key1", "second") if err != nil { t.Fatal(err) } got, err := store.Get(ctx, "key1") if err != nil { t.Fatal(err) } if got.Data != "second" { t.Errorf("data = %q, want %q (upsert should overwrite)", got.Data, "second") } } func TestSqliteStoreGetNotFound(t *testing.T) { db := setupTokenDB(t) store := NewSqliteStore[string](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) ctx := context.Background() _, err := store.Get(ctx, "nonexistent") if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound, got %v", err) } } // --- SecureStore Tests --- func TestSecureStorePutGet(t *testing.T) { db := setupTokenDB(t) inner := NewSqliteStore[VerifiedData[string]](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) store := NewSecureStore[string](inner, RandomString(32)) ctx := context.Background() tok, err := store.Put(ctx, "mykey", "mydata") if err != nil { t.Fatal(err) } // Combined token should contain a dot if len(tok.Key) <= len("mykey.") { t.Errorf("expected combined token longer than key, got %q", tok.Key) } got, err := store.Get(ctx, tok.Key) if err != nil { t.Fatal(err) } if got.Data != "mydata" { t.Errorf("data = %q, want %q", got.Data, "mydata") } } func TestSecureStoreWrongCode(t *testing.T) { db := setupTokenDB(t) inner := NewSqliteStore[VerifiedData[string]](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) store := NewSecureStore[string](inner, RandomString(32)) ctx := context.Background() _, err := store.Put(ctx, "mykey", "mydata") if err != nil { t.Fatal(err) } // Use correct key but wrong code _, err = store.Get(ctx, "mykey.wrongcode") if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound for wrong code, got %v", err) } } func TestSecureStoreMissingDot(t *testing.T) { db := setupTokenDB(t) inner := NewSqliteStore[VerifiedData[string]](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) store := NewSecureStore[string](inner, RandomString(32)) ctx := context.Background() _, err := store.Get(ctx, "nodottoken") if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound for token without dot, got %v", err) } } func TestSecureStoreDelete(t *testing.T) { db := setupTokenDB(t) inner := NewSqliteStore[VerifiedData[string]](db, Options{Table: "test_tokens", Expiration: 1 * time.Hour}) store := NewSecureStore[string](inner, RandomString(32)) ctx := context.Background() tok, err := store.Put(ctx, "mykey", "mydata") if err != nil { t.Fatal(err) } if err := store.Delete(ctx, tok.Key); err != nil { t.Fatal(err) } _, err = store.Get(ctx, tok.Key) if !errors.Is(err, core.ErrNotFound) { t.Errorf("expected ErrNotFound after delete, got %v", err) } } // --- KeyGen Tests --- func TestRandomString(t *testing.T) { gen := RandomString(32) a, err := gen() if err != nil { t.Fatal(err) } if a == "" { t.Error("expected non-empty string") } b, err := gen() if err != nil { t.Fatal(err) } if a == b { t.Error("expected different strings from two calls") } } func TestNumericCode(t *testing.T) { gen := NumericCode(6) code, err := gen() if err != nil { t.Fatal(err) } if len(code) != 6 { t.Errorf("code length = %d, want 6", len(code)) } for _, c := range code { if c < '0' || c > '9' { t.Errorf("code contains non-digit: %q", c) } } }