package tasks

import (
	"bytes"
	"context"
	"database/sql"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"atlas9.dev/c/core/dbi"
	"atlas9.dev/c/demo/lib"
)

// Workers configures and launches one Worker per task queue.
//
// Tables maps a task name to the table backing that task's queue. The other
// fields are copied into every Worker started by RunAll; see Worker for their
// meaning and defaults.
type Workers struct {
	DB            *sql.DB
	Handler       http.Handler
	Access        lib.Access
	Tables        map[string]string
	Interval      time.Duration
	LeaseDuration time.Duration
	MaxAttempts   int
	BatchSize     int
}

// RunAll starts one Worker goroutine per entry in ws.Tables and returns
// without waiting for them; each goroutine stops once ctx is cancelled.
//
// A task name maps to the route "POST /" + name. For routes of the form
// "POST /Tasks.{name}", the map keys are presumably already prefixed
// (e.g. "Tasks.Foo") — NOTE(review): confirm against the handler's routes.
func (ws *Workers) RunAll(ctx context.Context) {
	for name, table := range ws.Tables {
		// Copy the loop variable captured by the Consumer closure below.
		// Modules declaring go < 1.22 share one variable across iterations,
		// which would make every worker consume the last table visited.
		table := table
		w := &Worker{
			DB: ws.DB,
			Consumer: func(tx dbi.DBI) Consumer {
				return NewSqliteConsumer(tx, table)
			},
			Handler:       ws.Handler,
			Route:         "POST /" + name,
			Access:        ws.Access,
			Interval:      ws.Interval,
			LeaseDuration: ws.LeaseDuration,
			MaxAttempts:   ws.MaxAttempts,
			BatchSize:     ws.BatchSize,
		}
		go w.Run(ctx)
	}
}

// Worker polls a single task queue and delivers each claimed task to Handler
// as an in-process HTTP request.
//
//   - DB: database holding the queue; each queue operation runs in its own
//     dbi.ReadWrite call.
//   - Consumer: builds the queue accessor bound to a dbi.DBI.
//   - Handler, Route: the task payload is sent as a JSON body to Handler using
//     Route, formatted "METHOD /path".
//   - Access: capabilities attached to each request via lib.PutAccess.
//   - Interval: wait after a poll that claimed nothing (default 5s).
//   - LeaseDuration: lease duration passed to Consumer.Claim (default 30s).
//   - MaxAttempts: when > 0, a failing task with Attempts >= MaxAttempts is
//     marked failed instead of retried; 0 retries indefinitely.
//   - BatchSize: maximum number of tasks claimed per poll (default 1).
type Worker struct {
	DB            *sql.DB
	Consumer      func(dbi.DBI) Consumer
	Handler       http.Handler
	Route         string
	Access        lib.Access
	Interval      time.Duration
	LeaseDuration time.Duration
	MaxAttempts   int
	BatchSize     int
}

// Run polls the queue until ctx is cancelled. After a poll that claimed work
// it polls again immediately; otherwise it waits w.interval() first.
func (w *Worker) Run(ctx context.Context) {
	for ctx.Err() == nil {
		if w.poll(ctx) {
			continue
		}
		select {
		case <-time.After(w.interval()):
		case <-ctx.Done():
			return
		}
	}
}

// interval returns w.Interval, or 5s when it is not positive.
func (w *Worker) interval() time.Duration {
	if w.Interval > 0 {
		return w.Interval
	}
	return 5 * time.Second
}

// batchSize returns w.BatchSize, or 1 when it is not positive.
func (w *Worker) batchSize() int {
	if w.BatchSize > 0 {
		return w.BatchSize
	}
	return 1
}

// leaseDuration returns w.LeaseDuration, or 30s when it is not positive.
func (w *Worker) leaseDuration() time.Duration {
	if w.LeaseDuration > 0 {
		return w.LeaseDuration
	}
	return 30 * time.Second
}

// poll claims up to w.batchSize() tasks and processes them one at a time. It
// reports whether any task was claimed, so Run can skip its idle wait while
// work is available.
//
// If ctx is cancelled mid-batch, the remaining claimed tasks are not executed
// and nothing is recorded for them; they presumably become claimable again
// when their lease expires (NOTE(review): depends on Consumer.Claim).
func (w *Worker) poll(ctx context.Context) bool {
	var claimed []Task
	err := dbi.ReadWrite(ctx, w.DB, func(tx dbi.DBI) error {
		var err error
		claimed, err = w.Consumer(tx).Claim(ctx, w.leaseDuration(), w.batchSize())
		return err
	})
	if err != nil {
		// A claim interrupted by shutdown is expected; don't report it.
		if ctx.Err() == nil {
			slog.Error("claiming tasks", "err", err)
		}
		return false
	}
	if len(claimed) == 0 {
		return false
	}
	for _, task := range claimed {
		if ctx.Err() != nil {
			break
		}
		w.process(ctx, task)
	}
	return true
}

// process executes one claimed task and records the outcome: Complete on
// success; on failure, Fail once MaxAttempts is reached, otherwise Retry no
// earlier than backoff(task.Attempts) from now. Recording errors are logged.
func (w *Worker) process(ctx context.Context, task Task) {
	execErr := w.execute(ctx, task)
	if execErr != nil && ctx.Err() != nil {
		// The failure may only reflect shutdown cancelling the request
		// context; don't fail the task or push it back because of that.
		slog.Warn("task interrupted by shutdown", "id", task.ID, "attempts", task.Attempts, "err", execErr)
		return
	}

	// Record the outcome even if ctx is cancelled from here on: database/sql
	// operations fail on a cancelled context, which would leave an
	// already-successful task unrecorded.
	recCtx := context.WithoutCancel(ctx)

	if execErr == nil {
		err := dbi.ReadWrite(recCtx, w.DB, func(tx dbi.DBI) error {
			return w.Consumer(tx).Complete(recCtx, task.ID)
		})
		if err != nil {
			slog.Error("completing task", "id", task.ID, "err", err)
		}
		return
	}

	slog.Error("processing task", "id", task.ID, "attempts", task.Attempts, "err", execErr)

	if w.MaxAttempts > 0 && task.Attempts >= w.MaxAttempts {
		err := dbi.ReadWrite(recCtx, w.DB, func(tx dbi.DBI) error {
			return w.Consumer(tx).Fail(recCtx, task.ID)
		})
		if err != nil {
			slog.Error("failing task", "id", task.ID, "err", err)
		}
		return
	}

	retryAfter := time.Now().Add(backoff(task.Attempts))
	err := dbi.ReadWrite(recCtx, w.DB, func(tx dbi.DBI) error {
		return w.Consumer(tx).Retry(recCtx, task.ID, retryAfter)
	})
	if err != nil {
		slog.Error("retrying task", "id", task.ID, "err", err)
	}
}

// execute delivers task to w.Handler as an in-process HTTP request built from
// w.Route ("METHOD /path"). The request carries task.Payload as a JSON body,
// and its context carries the configured capabilities. A response status of
// 400 or above, or a panic in the handler, is returned as an error.
func (w *Worker) execute(ctx context.Context, task Task) (err error) {
	method, path, ok := strings.Cut(w.Route, " ")
	if !ok {
		return fmt.Errorf("invalid route %q: want \"METHOD /path\"", w.Route)
	}

	// Grant the configured capabilities to the request.
	reqCtx := lib.PutAccess(ctx, w.Access)
	req, err := http.NewRequestWithContext(reqCtx, method, path, bytes.NewReader(task.Payload))
	if err != nil {
		return fmt.Errorf("building request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Unlike net/http's server, nothing recovers panics on this goroutine. An
	// unrecovered panic would terminate the whole process, so turn it into a
	// task failure that goes through the normal retry/fail path.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("task handler panicked: %v", r)
		}
	}()

	rec := &responseWriter{}
	w.Handler.ServeHTTP(rec, req)
	if rec.status >= 400 {
		return fmt.Errorf("task handler returned %d", rec.status)
	}
	return nil
}

// responseWriter is a minimal http.ResponseWriter that discards the body and
// captures the final status code using net/http's rules. The first final
// WriteHeader wins, 1xx informational codes (except 101) are not final, and a
// Write without a prior WriteHeader implies 200 OK.
type responseWriter struct {
	header http.Header
	status int
}

// Header returns the response header map, allocated on first use so values a
// handler sets stay visible to later Header calls.
func (w *responseWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

// Write discards b, recording an implicit 200 OK if no status was set yet.
func (w *responseWriter) Write(b []byte) (int, error) {
	w.WriteHeader(http.StatusOK)
	return len(b), nil
}

// WriteHeader records statusCode as the final status unless one was already
// recorded or statusCode is informational.
func (w *responseWriter) WriteHeader(statusCode int) {
	if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
		return
	}
	if w.status == 0 {
		w.status = statusCode
	}
}

func backoff(attempts int) time.Duration { d := time.Duration(1<