5 Commits

Author SHA1 Message Date
5c1b28ee0a Added support for Postgres polling sources 2026-03-29 10:53:13 -05:00
247937b65e Upgraded feedkit's handling of stream sources 2026-03-29 08:34:35 -05:00
eb9a7cb349 Refactor feedkit boundaries ahead of v1
Remove global Postgres schema registration in favor of explicit schema-aware sink factory wiring, and update weatherfeeder to register the Postgres sink explicitly. Add optional per-source HTTP timeout and response body limit overrides while keeping feedkit defaults. Remove remaining legacy source/config compatibility surfaces, including singular kind support and old source registry/type aliases, and migrate weatherfeeder sources to plural `Kinds()` metadata. Clean up related docs, tests, and sample config to match the new Postgres, HTTP, and NATS configuration model.
2026-03-28 13:52:48 -05:00
3281368922 Cleaned up documentation and removed stubs and TODOs throughout the application 2026-03-28 13:02:37 -05:00
3ef93faf69 Moved broadly useful helper functions upstream into feedkit 2026-03-28 11:29:09 -05:00
41 changed files with 2916 additions and 656 deletions

177
README.md
View File

@@ -1,127 +1,92 @@
 # feedkit
-
-`feedkit` provides domain-agnostic plumbing for feed-processing daemons.
-
-A daemon built on feedkit typically:
-- ingests upstream input (polling APIs or consuming streams)
-- emits domain-agnostic `event.Event` values
-- applies optional processing (normalization, dedupe, policy)
-- routes events to sinks (stdout, NATS, files, databases, etc.)
-
-## Philosophy
-
-feedkit is not a framework. It provides small composable packages and leaves
-lifecycle, domain schemas, and domain-specific validation in your daemon.
-
-## Conceptual pipeline
-
-Collect -> Process (optional stages, including dedupe + normalize) -> Route -> Emit
-
-| Stage | Package(s) |
-|---|---|
-| Collect | `sources`, `scheduler` |
-| Process | `pipeline`, `processors`, `processors/dedupe`, `processors/normalize` (optional stages) |
-| Route | `dispatch` |
-| Emit | `sinks` |
-| Configure | `config` |
-
-## Core packages
-
-### `config`
-
-Loads YAML config with strict decoding and domain-agnostic validation.
-
-`SourceConfig` supports both source modes:
-- `mode: poll` requires `every`
-- `mode: stream` forbids `every`
-- omitted `mode` means auto (inferred from the registered driver type)
-
-It also supports optional expected source kinds:
-- `kinds: ["observation", "alert"]` (preferred)
-- `kind: "observation"` (legacy fallback)
-
-### `event`
-
-Defines the domain-agnostic event envelope (`event.Event`) used across the system.
-
-### `sources`
-
-Defines source interfaces and driver registry:
-
-```go
-type Input interface {
-    Name() string
-}
-
-type PollSource interface {
-    Input
-    Poll(ctx context.Context) ([]event.Event, error)
-}
-
-type StreamSource interface {
-    Input
-    Run(ctx context.Context, out chan<- event.Event) error
-}
-```
-
-Notes:
-- a poll can emit `0..N` events
-- stream sources emit events continuously
-- a single source may emit multiple event kinds
-- driver implementations live in downstream daemons and are registered via `sources.Registry`
-
-### `scheduler`
-
-Runs one goroutine per source job:
-- poll sources: cadence driven (`every` + jitter)
-- stream sources: continuous run loop
-
-### `pipeline`
-
-Optional processing chain between collection and dispatch.
-Processors can transform, drop, or reject events.
-
-### `processors`
-
-Defines the generic processor interface and a named-driver registry used by
-daemons to build ordered processor chains.
-
-### `processors/dedupe`
-
-Built-in in-memory LRU dedupe processor that drops repeated events by `Event.ID`.
-
-### `processors/normalize`
-
-Concrete normalization processor implementation. Typical use: sources emit raw
-payload events, then a normalize stage maps them to canonical schemas.
-
-### `dispatch`
-
-Compiles routes and fans out events to sinks with per-sink queue/worker isolation.
-
-### `sinks`
-
-Defines sink interface and sink registry. Built-ins include:
-- `stdout`
-- `nats`
-- `postgres`
-
-Detailed Postgres configuration and wiring examples live in package docs:
-`sinks/doc.go`.
-
-## Typical wiring
-
-1. Load config.
-2. Register/build sources from `cfg.Sources`.
-3. Register/build sinks from `cfg.Sinks`.
-4. Compile routes.
-5. Start scheduler (`sources -> bus`).
-6. Start dispatcher (`bus -> pipeline -> sinks`).
-
-## Non-goals
-
-feedkit intentionally does not:
-- define domain payload schemas
-- enforce domain-specific event kinds
-- own application lifecycle
-- prescribe observability stack choices
+
+`feedkit` is a small Go toolkit for building feed-processing daemons.
+It gives you the reusable plumbing around collection, processing, routing, and
+emission, while leaving domain concepts, schemas, and application wiring in
+your daemon. The intended shape is a family of sibling applications such as
+`weatherfeeder`, `newsfeeder`, or `earthquakefeeder` that all share the same
+infrastructure patterns without sharing domain logic.
+
+## What It Does
+
+A daemon built on `feedkit` typically:
+- ingests upstream input by polling HTTP APIs or consuming streams
+- emits domain-agnostic `event.Event` values
+- optionally processes those events with stages like dedupe or normalization
+- routes events to one or more sinks such as stdout, NATS, or Postgres
+
+Conceptually, the pipeline is:
+
+`Collect -> Process -> Route -> Emit`
+
+## Philosophy
+
+`feedkit` is intentionally not a framework.
+
+It does not try to own:
+- your domain payload schemas
+- your domain event kinds
+- your daemon lifecycle or `main.go`
+- your observability stack or deployment model
+
+Instead, it provides small composable packages that are easy to wire together in
+different daemons.
+
+## When To Use It
+
+`feedkit` is a good fit when you want:
+- multiple small ingestion daemons with shared infrastructure patterns
+- clear separation between raw upstream payloads and normalized canonical models
+- reusable routing and sink behavior across domains
+- strong config and event-envelope conventions without centralizing domain rules
+
+It is a poor fit if you want a monolithic framework that dictates application
+structure end-to-end.
+
+## Built-In Capabilities
+
+`feedkit` currently includes:
+- strict YAML config loading and validation
+- polling and streaming source abstractions
+- scheduler orchestration for configured sources and supervised stream workers
+- optional pipeline processors
+- built-in dedupe and normalization processors
+- route compilation and sink fanout
+- built-in sinks for `stdout`, `nats`, and `postgres`
+
+The Postgres sink is intentionally split between feedkit-owned infrastructure
+and daemon-owned schema mapping. `feedkit` manages connection setup, DDL,
+writes, and pruning; downstream applications define the schema and event mapper.
+
+## Typical Wiring
+
+At a high level, a daemon built on `feedkit` does this:
+
+1. Load config.
+2. Register domain-specific source drivers.
+3. Register built-in and/or custom sinks.
+4. Build sources, sinks, and optional processor chain from config.
+5. Compile routes.
+6. Start the scheduler and dispatcher.
+
+The package docs are the better source of truth for code-level details. In
+particular, each subpackage `doc.go` describes its external API surface and any
+optional helper APIs in `helpers.go`.
+
+## Package Layout
+
+The major packages are:
+- `config`: config loading and validation
+- `event`: the domain-agnostic event envelope
+- `sources`: source interfaces and reusable source helpers
+- `scheduler`: source execution and cadence management
+- `processors`: processor interfaces and registry
+- `processors/dedupe`: built-in in-memory dedupe processor
+- `processors/normalize`: built-in normalization processor and helpers
+- `pipeline`: optional processor chain
+- `dispatch`: route compilation and fanout
+- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers
+
+The root package docs in `doc.go` provide a concise package-by-package map for
+Go documentation consumers.
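The wiring steps in the new README compress a fair amount of code. A condensed sketch follows, adapted from the wiring example this commit removes from the root doc.go and from `scheduler.JobFromSourceConfig` added below; route compilation and the dispatcher are elided, error handling is shortened, and the assumption that `log.Printf` satisfies the scheduler's `Logf` type is ours, not the library's.

```go
// Hedged wiring sketch, not the canonical main.go.
package main

import (
    "context"
    "log"

    "gitea.maximumdirect.net/ejr/feedkit/config"
    "gitea.maximumdirect.net/ejr/feedkit/event"
    "gitea.maximumdirect.net/ejr/feedkit/scheduler"
    "gitea.maximumdirect.net/ejr/feedkit/sinks"
    "gitea.maximumdirect.net/ejr/feedkit/sources"
)

func main() {
    cfg, err := config.Load("config.yml")
    if err != nil {
        log.Fatalf("load config: %v", err)
    }

    srcReg := sources.NewRegistry()
    // ... register domain-specific poll/stream drivers here ...

    sinkReg := sinks.NewRegistry()
    sinks.RegisterBuiltins(sinkReg)
    // ... register the Postgres sink explicitly via sinks.PostgresFactory if needed ...

    var jobs []scheduler.Job
    for _, sc := range cfg.Sources {
        src, err := srcReg.BuildInput(sc)
        if err != nil {
            log.Fatalf("build source %q: %v", sc.Name, err)
        }
        job, err := scheduler.JobFromSourceConfig(src, sc)
        if err != nil {
            log.Fatalf("build job %q: %v", sc.Name, err)
        }
        jobs = append(jobs, job)
    }

    bus := make(chan event.Event, 256)
    s := &scheduler.Scheduler{Jobs: jobs, Out: bus, Logf: log.Printf}
    // Start the dispatcher (bus -> pipeline -> sinks) alongside the scheduler.
    if err := s.Run(context.Background()); err != nil {
        log.Fatalf("scheduler: %v", err)
    }
}
```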

View File

@@ -69,39 +69,31 @@ type SourceConfig struct {
 	// If set, it describes the expected emitted event kinds for this source.
 	Kinds []string `yaml:"kinds"`
-	// Kind is the legacy singular form. Prefer "kinds".
-	// If both kind and kinds are set, validation fails.
-	Kind string `yaml:"kind"`
 	// Params are driver-specific settings (URL, headers, station IDs, API keys, etc.).
 	// The driver implementation is responsible for reading/validating these.
 	Params map[string]any `yaml:"params"`
 }
 
 // ExpectedKinds returns normalized expected kinds from config.
-// "kinds" takes precedence; "kind" is used as a legacy fallback.
 func (cfg SourceConfig) ExpectedKinds() []string {
-	if len(cfg.Kinds) > 0 {
-		out := make([]string, 0, len(cfg.Kinds))
-		for _, k := range cfg.Kinds {
-			k = strings.TrimSpace(k)
-			if k == "" {
-				continue
-			}
-			out = append(out, k)
-		}
-		return out
-	}
-	if k := strings.TrimSpace(cfg.Kind); k != "" {
-		return []string{k}
-	}
-	return nil
+	out := make([]string, 0, len(cfg.Kinds))
+	for _, k := range cfg.Kinds {
+		k = strings.TrimSpace(k)
+		if k == "" {
+			continue
+		}
+		out = append(out, k)
+	}
+	if len(out) == 0 {
+		return nil
+	}
+	return out
 }
 
 // SinkConfig describes one output sink adapter.
 type SinkConfig struct {
 	Name string `yaml:"name"`
-	Driver string `yaml:"driver"` // "stdout", "file", "postgres", "rabbitmq", ...
+	Driver string `yaml:"driver"` // "stdout", "nats", "postgres", ...
 	Params map[string]any `yaml:"params"` // sink-specific settings
 }
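With the legacy singular `kind` field removed, `ExpectedKinds` is the only normalized accessor callers see. A tiny illustration of its trim/skip/nil behavior, with made-up values:

```go
// Illustrative only: demonstrates ExpectedKinds after the "kind" removal.
package main

import (
    "fmt"

    "gitea.maximumdirect.net/ejr/feedkit/config"
)

func main() {
    src := config.SourceConfig{Kinds: []string{" observation ", "", "forecast"}}
    fmt.Println(src.ExpectedKinds()) // [observation forecast]

    var empty config.SourceConfig
    fmt.Println(empty.ExpectedKinds() == nil) // true: no expected kinds declared
}
```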

View File

@@ -12,20 +12,12 @@ func TestSourceConfigExpectedKinds(t *testing.T) {
want []string want []string
}{ }{
{ {
name: "plural kinds preferred", name: "plural kinds normalized",
cfg: SourceConfig{ cfg: SourceConfig{
Kinds: []string{" observation ", "forecast"}, Kinds: []string{" observation ", "forecast"},
Kind: "alert",
}, },
want: []string{"observation", "forecast"}, want: []string{"observation", "forecast"},
}, },
{
name: "legacy singular fallback",
cfg: SourceConfig{
Kind: " alert ",
},
want: []string{"alert"},
},
{ {
name: "empty kinds", name: "empty kinds",
cfg: SourceConfig{}, cfg: SourceConfig{},

View File

@@ -105,13 +105,7 @@ func (c *Config) Validate() error {
 		}
 	}
 
-	// Kind/Kinds (optional)
-	if s.Kind != "" && len(s.Kinds) > 0 {
-		m.Add(fieldErr(path+".kind", `cannot be set when "kinds" is provided (use only "kinds")`))
-	}
-	if s.Kind != "" && strings.TrimSpace(s.Kind) == "" {
-		m.Add(fieldErr(path+".kind", "cannot be blank (omit it entirely, or provide a non-empty string)"))
-	}
+	// Kinds (optional)
 	for j, k := range s.Kinds {
 		kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
 		if strings.TrimSpace(k) == "" {
@@ -141,7 +135,7 @@ func (c *Config) Validate() error {
 	}
 	if strings.TrimSpace(s.Driver) == "" {
-		m.Add(fieldErr(path+".driver", "is required (stdout|file|postgres|rabbitmq|...)"))
+		m.Add(fieldErr(path+".driver", "is required (stdout|nats|postgres|...)"))
 	}
 
 	// Params can be nil; that's fine.

View File

@@ -114,31 +114,6 @@ func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) {
} }
} }
func TestValidate_SourceKindAndKindsConflict(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{
Name: "src1",
Driver: "driver1",
Every: Duration{Duration: time.Minute},
Kind: "observation",
Kinds: []string{"forecast"},
},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].kind`) {
t.Fatalf("expected error to mention sources[0].kind, got: %v", err)
}
}
func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) { func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) {
cfg := &Config{ cfg := &Config{
Sources: []SourceConfig{ Sources: []SourceConfig{

137
doc.go
View File

@@ -1,130 +1,57 @@
-// Package feedkit provides domain-agnostic plumbing for feed-processing daemons.
-//
-// A feed daemon ingests upstream input, turns it into event.Event values, applies
-// optional processing, and emits to sinks.
-//
-// Conceptual flow:
-//
-//	Collect -> Process (optional stages, including dedupe + normalize) -> Route -> Emit
-//
-// In feedkit this maps to:
-//
-//	Collect: sources + scheduler
-//	Process: pipeline + processors + processors/dedupe + processors/normalize (optional stages)
-//	Route: dispatch
-//	Emit: sinks
-//	Config: config
-//
-// feedkit intentionally does not define domain payload schemas or domain-specific
-// validation rules. Those belong in each concrete daemon.
-//
-// Public packages
-//
-//   - config
-//     YAML config loading/validation (strict decode + domain-agnostic checks).
-//
-//     SourceConfig supports both polling and streaming sources:
-//
-//   - mode: "poll" | "stream" | omitted (auto by driver type)
-//
-//   - every: poll interval (required for mode="poll")
-//
-//   - kinds: optional expected emitted kinds
-//
-//   - kind: legacy singular fallback
-//
-//   - params: driver-specific settings
-//
-//   - event
-//     Domain-agnostic event envelope (ID, Kind, Source, EmittedAt, Schema, Payload).
-//
-//   - sources
-//     Source abstractions and source-driver registry.
-//
-//     There are two source interfaces:
-//
-//   - PollSource: Poll(ctx) ([]event.Event, error)
-//
-//   - StreamSource: Run(ctx, out) error
-//
-//     Both share Input{Name()}. A source may emit 0..N events per poll/run step,
-//     and may emit multiple event kinds.
-//
-//     For HTTP-backed polling sources, sources.NewHTTPSource provides a shared
-//     helper for generic params:
-//
-//   - params.url
-//
-//   - params.user_agent
-//
-//   - params.conditional (optional, default true)
-//
-//     When conditional polling is enabled, feedkit opportunistically uses ETag
-//     and Last-Modified validators. A 304 Not Modified response is treated as a
-//     successful poll that emits no events.
-//
-//   - scheduler
-//     Runs one goroutine per job:
-//
-//   - PollSource jobs run on Every (+ jitter)
-//
-//   - StreamSource jobs run continuously
-//
-//   - pipeline
-//     Processor chain between scheduler and dispatch.
-//     Processors can transform, drop, or reject events.
-//
-//   - processors
-//     Generic processor interface and named factory registry for wiring chains.
-//
-//   - processors/dedupe
-//     Built-in in-memory LRU dedupe processor keyed by Event.ID.
-//
-//   - processors/normalize
-//     Concrete pipeline processor for raw->canonical mapping.
-//     If no normalizer matches, the event passes through unchanged by default.
-//
-//   - dispatch
-//     Routes events to sinks and isolates slow sinks via per-sink queues/workers.
-//
-//   - sinks
-//     Sink abstractions + sink registry.
-//     Built-ins include stdout, NATS, and Postgres. For Postgres, downstream
-//     code registers table schemas/mappers while feedkit manages DDL, writes,
-//     optional automatic retention pruning (via sink params.prune), and
-//     manual prune helpers. Postgres table schemas must declare PruneColumn.
-//
-// Typical wiring (daemon main.go)
-//
-//  1. Load config.
-//  2. Register source drivers and build sources from config.Sources.
-//  3. Register sink drivers and build sinks from config.Sinks.
-//  4. Compile routes.
-//  5. Start scheduler (sources -> bus) and dispatcher (bus -> pipeline -> sinks).
-//
-// Sketch:
-//
-//	cfg, _ := config.Load("config.yml")
-//	srcReg := sources.NewRegistry()
-//	// domain registers poll/stream drivers...
-//
-//	var jobs []scheduler.Job
-//	for _, sc := range cfg.Sources {
-//		src, _ := srcReg.BuildInput(sc)
-//		jobs = append(jobs, scheduler.Job{
-//			Source: src,
-//			Every:  sc.Every.Duration,
-//		})
-//	}
-//
-//	bus := make(chan event.Event, 256)
-//	s := &scheduler.Scheduler{Jobs: jobs, Out: bus, Logf: logf}
-//	// start dispatcher similarly...
-//
-// # Context and cancellation
-//
-// All blocking work should honor context cancellation:
-//   - source polling/streaming I/O
-//   - sink consumption
-//   - any expensive processor work
+// Package feedkit provides a high-level map of the feedkit package set.
+//
+// Most real applications do not import the root package directly. Instead, they
+// compose the subpackages that handle configuration, collection, processing,
+// routing, and sinks.
+//
+// The usual flow through feedkit is:
+//
+//	Collect -> Process -> Route -> Emit
+//
+// That flow maps to packages like this:
+//
+//   - config
+//     Loads and validates daemon config. This package owns domain-agnostic
+//     config shape and consistency checks.
+//
+//   - event
+//     Defines the event.Event envelope shared across sources, processors,
+//     dispatch, and sinks.
+//
+//   - sources
+//     Defines polling and streaming source interfaces, the source registry, and
+//     reusable source helpers.
+//
+//   - scheduler
+//     Runs configured sources on a cadence and supervises long-lived stream
+//     workers with restart/fatal handling.
+//
+//   - processors
+//     Defines the generic processor interface and registry used to build
+//     ordered processor chains.
+//
+//   - processors/dedupe
+//     Built-in in-memory dedupe processor keyed by Event.ID.
+//
+//   - processors/normalize
+//     Built-in normalization processor plus helper APIs for raw-to-canonical
+//     event mapping.
+//
+//   - pipeline
+//     Applies an ordered processor chain between collection and dispatch.
+//
+//   - dispatch
+//     Compiles routes and fans events out to sinks with per-sink isolation.
+//
+//   - sinks
+//     Defines sink interfaces, the sink registry, schema-free built-in sinks,
+//     and explicit Postgres factory helpers.
+//
+// feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds,
+// upstream-specific parsing, and daemon lifecycle remain the responsibility of
+// each concrete application.
+//
+// For repository-level overview and usage narrative, see README.md. For
+// code-level details, each subpackage doc.go is the source of truth for that
+// package's public API surface and optional helpers.
 package feedkit

View File

@@ -0,0 +1,67 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
_ "github.com/lib/pq"
)
const initTimeout = 5 * time.Second
var (
sqlOpen = sql.Open
pingDB = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) }
)
// ConnConfig describes the minimal connection settings shared by feedkit's
// Postgres readers and writers.
type ConnConfig struct {
URI string
Username string
Password string
}
// BuildDSN validates a Postgres URI and injects credentials into it.
func BuildDSN(cfg ConnConfig) (string, error) {
u, err := url.Parse(strings.TrimSpace(cfg.URI))
if err != nil {
return "", fmt.Errorf("invalid uri: %w", err)
}
if u.Scheme == "" {
return "", fmt.Errorf("invalid uri: missing scheme")
}
if u.Host == "" {
return "", fmt.Errorf("invalid uri: missing host")
}
u.User = url.UserPassword(cfg.Username, cfg.Password)
return u.String(), nil
}
// Open builds a DSN, opens a database handle, and verifies connectivity with a
// bounded ping before returning the handle.
func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) {
dsn, err := BuildDSN(cfg)
if err != nil {
return nil, err
}
db, err := sqlOpen("postgres", dsn)
if err != nil {
return nil, err
}
pingCtx, cancel := context.WithTimeout(ctx, initTimeout)
defer cancel()
if err := pingDB(pingCtx, db); err != nil {
_ = db.Close()
return nil, err
}
return db, nil
}
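A short usage sketch for the connection helpers above. The connection details are placeholders, and the import path is inferred from the package name and its position in this diff rather than stated by the source.

```go
// Hedged usage sketch: Open builds the DSN, opens the handle, and verifies it
// with a 5s-bounded ping before returning. URI and credentials are fake.
package main

import (
    "context"
    "log"

    "gitea.maximumdirect.net/ejr/feedkit/sinks/postgres" // assumed import path
)

func main() {
    db, err := postgres.Open(context.Background(), postgres.ConnConfig{
        URI:      "postgres://db.example.local:5432/feedkit?sslmode=disable",
        Username: "app_user",
        Password: "app_pass",
    })
    if err != nil {
        log.Fatalf("postgres: %v", err)
    }
    defer db.Close()
    // db is a ready *sql.DB; hand it to the schema-aware sink factory.
}
```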

View File

@@ -0,0 +1,165 @@
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"net/url"
"strings"
"sync"
"testing"
)
func withPostgresPackageTestState(t *testing.T) {
t.Helper()
oldSQLOpen := sqlOpen
oldPingDB := pingDB
t.Cleanup(func() {
sqlOpen = oldSQLOpen
pingDB = oldPingDB
})
}
func TestBuildDSNInjectsCredentials(t *testing.T) {
dsn, err := BuildDSN(ConnConfig{
URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ",
Username: "app_user",
Password: "app_pass",
})
if err != nil {
t.Fatalf("BuildDSN() error = %v", err)
}
u, err := url.Parse(dsn)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
if u.User == nil || u.User.Username() != "app_user" {
t.Fatalf("username = %q, want app_user", u.User.Username())
}
pass, ok := u.User.Password()
if !ok || pass != "app_pass" {
t.Fatalf("password = %q, want app_pass", pass)
}
}
func TestBuildDSNRejectsInvalidURI(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "invalid uri") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingScheme(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing scheme") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingHost(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing host") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestOpenPropagatesOpenFailure(t *testing.T) {
withPostgresPackageTestState(t)
sqlOpen = func(_, _ string) (*sql.DB, error) {
return nil, errors.New("open failed")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "open failed") {
t.Fatalf("Open() error = %q", err)
}
}
func TestOpenPropagatesPingFailure(t *testing.T) {
withPostgresPackageTestState(t)
const driverName = "feedkit_internal_postgres_ping_fail"
registerPingTestDriver(driverName, errors.New("ping failed"))
sqlOpen = func(_, _ string) (*sql.DB, error) {
return sql.Open(driverName, "")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "ping failed") {
t.Fatalf("Open() error = %q", err)
}
}
var (
pingDriverMu sync.Mutex
pingDriverSeen = map[string]bool{}
)
func registerPingTestDriver(name string, pingErr error) {
pingDriverMu.Lock()
defer pingDriverMu.Unlock()
if pingDriverSeen[name] {
return
}
sql.Register(name, &pingTestDriver{pingErr: pingErr})
pingDriverSeen[name] = true
}
type pingTestDriver struct {
pingErr error
}
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
return &pingTestConn{pingErr: d.pingErr}, nil
}
type pingTestConn struct {
pingErr error
}
func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *pingTestConn) Close() error { return nil }
func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }
func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
return &pingTestRows{}, nil
}
type pingTestRows struct{}
func (r *pingTestRows) Columns() []string { return []string{"ok"} }
func (r *pingTestRows) Close() error { return nil }
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }

View File

@@ -1,5 +0,0 @@
package pipeline
// Placeholder for rate limit processor:
// - per source/kind sink routing limits
// - cooldown windows

View File

@@ -1,16 +1,18 @@
-// Package normalize provides a concrete normalization processor for feedkit pipelines.
+// Package normalize provides the feedkit normalization processor and related
+// helper APIs for raw-to-canonical event mapping.
 //
-// Motivation:
-// Many daemons have sources that:
-//  1. fetch raw upstream data (often JSON), and
-//  2. transform it into a domain's normalized payload format.
+// External API surface:
+//   - Processor: concrete processors.Processor implementation
+//   - Normalizer / Func: normalization interface and ergonomic function adapter
 //
-// Doing both steps inside Source.Poll works, but tends to make sources large and
-// encourages duplication (unit conversions, common mapping helpers, etc.).
+// Optional helpers from helpers.go:
+//   - PayloadJSONBytes: extract supported JSON-shaped payloads into bytes
+//   - DecodeJSONPayload: decode an event payload into a typed struct
+//   - FinalizeEvent: copy the input event envelope onto a normalized output
 //
-// This package lets a source emit a "raw" event (e.g., Schema="raw.openweather.current.v1",
-// Payload=json.RawMessage), and then a normalize.Processor can convert it into a
-// normalized event (e.g., Schema="weather.observation.v1", Payload=WeatherObservation{}).
+// Typical usage:
+// sources emit raw events (often with json.RawMessage payloads), then a
+// normalize.Processor converts matching raw schemas into canonical payloads.
 //
 // Key property: normalization is optional.
 // If no Normalizer matches an event, Processor passes it through unchanged by default.

View File

@@ -0,0 +1,84 @@
package normalize
import (
"encoding/json"
"fmt"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// PayloadJSONBytes extracts a JSON payload into bytes suitable for json.Unmarshal.
//
// Supported payload shapes:
// - json.RawMessage
// - []byte
// - string
// - map[string]any
func PayloadJSONBytes(e event.Event) ([]byte, error) {
if e.Payload == nil {
return nil, fmt.Errorf("payload is nil")
}
switch v := e.Payload.(type) {
case json.RawMessage:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty json.RawMessage")
}
return []byte(v), nil
case []byte:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty []byte")
}
return v, nil
case string:
if v == "" {
return nil, fmt.Errorf("payload is empty string")
}
return []byte(v), nil
case map[string]any:
b, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("marshal map payload: %w", err)
}
return b, nil
default:
return nil, fmt.Errorf("unsupported payload type %T", e.Payload)
}
}
// DecodeJSONPayload extracts the event payload as bytes and unmarshals it into T.
func DecodeJSONPayload[T any](in event.Event) (T, error) {
var zero T
b, err := PayloadJSONBytes(in)
if err != nil {
return zero, fmt.Errorf("extract payload: %w", err)
}
var parsed T
if err := json.Unmarshal(b, &parsed); err != nil {
return zero, fmt.Errorf("decode raw payload: %w", err)
}
return parsed, nil
}
// FinalizeEvent builds the output event envelope by copying the input and applying
// the new schema/payload, plus optional EffectiveAt.
func FinalizeEvent(in event.Event, outSchema string, outPayload any, effectiveAt time.Time) (*event.Event, error) {
out := in
out.Schema = outSchema
out.Payload = outPayload
if !effectiveAt.IsZero() {
t := effectiveAt.UTC()
out.EffectiveAt = &t
}
if err := out.Validate(); err != nil {
return nil, err
}
return &out, nil
}
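A hedged sketch of how a daemon-side mapping function might chain these helpers. Only `DecodeJSONPayload` and `FinalizeEvent` are feedkit APIs shown above; the schema names and the `rawObservation` struct are illustrative.

```go
// Illustrative daemon-side mapping built on the helpers above; everything
// other than the two normalize calls is a placeholder.
package myfeeder

import (
    "time"

    "gitea.maximumdirect.net/ejr/feedkit/event"
    "gitea.maximumdirect.net/ejr/feedkit/processors/normalize"
)

type rawObservation struct {
    TempC      float64   `json:"temp_c"`
    ObservedAt time.Time `json:"observed_at"`
}

// mapObservation turns a raw JSON payload event into its canonical form.
func mapObservation(in event.Event) (*event.Event, error) {
    raw, err := normalize.DecodeJSONPayload[rawObservation](in)
    if err != nil {
        return nil, err
    }
    payload := map[string]any{"temp_c": raw.TempC}
    // FinalizeEvent copies the envelope, swaps in the new schema/payload,
    // sets EffectiveAt (UTC) when non-zero, and validates the result.
    return normalize.FinalizeEvent(in, "example.observation.v1", payload, raw.ObservedAt)
}
```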

View File

@@ -0,0 +1,118 @@
package normalize
import (
"encoding/json"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPayloadJSONBytesSupportedShapes(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "rawmessage", payload: json.RawMessage(`{"a":1}`), want: `{"a":1}`},
{name: "bytes", payload: []byte(`{"a":2}`), want: `{"a":2}`},
{name: "string", payload: `{"a":3}`, want: `{"a":3}`},
{name: "map", payload: map[string]any{"a": 4}, want: `{"a":4}`},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err != nil {
t.Fatalf("PayloadJSONBytes() unexpected error: %v", err)
}
if string(got) != tc.want {
t.Fatalf("PayloadJSONBytes() = %s, want %s", string(got), tc.want)
}
})
}
}
func TestPayloadJSONBytesRejectsInvalidPayloads(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "nil", payload: nil, want: "payload is nil"},
{name: "empty rawmessage", payload: json.RawMessage{}, want: "payload is empty json.RawMessage"},
{name: "empty bytes", payload: []byte{}, want: "payload is empty []byte"},
{name: "empty string", payload: "", want: "payload is empty string"},
{name: "unsupported", payload: 123, want: "unsupported payload type"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err == nil {
t.Fatalf("PayloadJSONBytes() expected error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("PayloadJSONBytes() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestDecodeJSONPayload(t *testing.T) {
type payload struct {
Name string `json:"name"`
}
got, err := DecodeJSONPayload[payload](event.Event{
Payload: json.RawMessage(`{"name":"alice"}`),
})
if err != nil {
t.Fatalf("DecodeJSONPayload() unexpected error: %v", err)
}
if got.Name != "alice" {
t.Fatalf("DecodeJSONPayload() = %#v, want name alice", got)
}
}
func TestFinalizeEventPreservesEnvelopeAndEffectiveAtBehavior(t *testing.T) {
existingEffectiveAt := time.Date(2026, 3, 28, 11, 0, 0, 0, time.UTC)
in := event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-a",
EmittedAt: time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC),
EffectiveAt: &existingEffectiveAt,
Schema: "raw.example.v1",
Payload: map[string]any{"old": true},
}
out, err := FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, time.Time{})
if err != nil {
t.Fatalf("FinalizeEvent() unexpected error: %v", err)
}
if out.ID != in.ID || out.Kind != in.Kind || out.Source != in.Source || out.EmittedAt != in.EmittedAt {
t.Fatalf("FinalizeEvent() changed preserved envelope fields: %#v", out)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(existingEffectiveAt) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want preserved existing value", out.EffectiveAt)
}
nextEffectiveAt := time.Date(2026, 3, 28, 13, 0, 0, 0, time.FixedZone("x", -4*3600))
out, err = FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, nextEffectiveAt)
if err != nil {
t.Fatalf("FinalizeEvent() unexpected overwrite error: %v", err)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(nextEffectiveAt.UTC()) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want %s", out.EffectiveAt, nextEffectiveAt.UTC())
}
payloadMap, ok := out.Payload.(map[string]any)
if !ok {
t.Fatalf("FinalizeEvent() payload type = %T, want map[string]any", out.Payload)
}
if payloadMap["value"] != 1.234567 {
t.Fatalf("FinalizeEvent() payload value = %#v, want unrounded 1.234567", payloadMap["value"])
}
}

25
scheduler/doc.go Normal file
View File

@@ -0,0 +1,25 @@
// Package scheduler runs feedkit sources and forwards their events to the
// daemon event bus.
//
// External API surface:
// - Scheduler: runs configured polling and streaming jobs
// - Job: one scheduler task bound to a source
// - StreamExitPolicy: stream supervision policy for non-fatal exits
// - StreamBackoff: restart pacing for supervised stream sources
//
// Optional helpers from helpers.go:
// - JobFromSourceConfig: build a scheduler job from a configured source and
// feedkit-owned scheduling params
//
// Poll sources are run on a fixed cadence with optional jitter. Stream sources
// are supervised long-lived workers. Their generic feedkit controls live under
// sources[].params:
// - stream_exit_policy: restart|stop|fatal (default restart)
// - stream_backoff_initial: positive duration (default 1s)
// - stream_backoff_max: positive duration (default 1m)
// - stream_backoff_jitter: non-negative duration (default 250ms)
//
// Stream sources can classify exits with sources.StreamRetryable and
// sources.StreamFatal. Plain errors are treated as retryable by default, while
// fatal exits are propagated from Scheduler.Run so the daemon can shut down.
package scheduler
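A hedged sketch of the exit-classification side of that contract, seen from a downstream stream source. The source type and its failure mode are invented for illustration; `sources.StreamFatal` and `sources.StreamRetryable` are the classification helpers named above.

```go
// Illustrative stream source showing exit classification; tickerStream and
// errUnauthorized are placeholders for real transport code.
package myfeeder

import (
    "context"
    "errors"

    "gitea.maximumdirect.net/ejr/feedkit/event"
    "gitea.maximumdirect.net/ejr/feedkit/sources"
)

var errUnauthorized = errors.New("upstream rejected credentials")

type tickerStream struct{ name string }

func (s *tickerStream) Name() string { return s.name }

func (s *tickerStream) Run(ctx context.Context, out chan<- event.Event) error {
    if err := s.connect(ctx); err != nil {
        if errors.Is(err, errUnauthorized) {
            // Unrecoverable: Scheduler.Run propagates this so the daemon can exit.
            return sources.StreamFatal(err)
        }
        // Recoverable: restarted with backoff under the default restart policy.
        return sources.StreamRetryable(err)
    }
    <-ctx.Done()
    return ctx.Err()
}

// connect stands in for real connection setup.
func (s *tickerStream) connect(context.Context) error { return errUnauthorized }
```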

138
scheduler/helpers.go Normal file
View File

@@ -0,0 +1,138 @@
package scheduler
import (
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
// JobFromSourceConfig builds a scheduler Job from a configured source and its
// generic feedkit config.
func JobFromSourceConfig(src sources.Input, cfg config.SourceConfig) (Job, error) {
if src == nil {
return Job{}, fmt.Errorf("scheduler: source %q is nil", cfg.Name)
}
job := Job{
Source: src,
Every: cfg.Every.Duration,
}
if _, ok := src.(sources.StreamSource); ok {
if cfg.Every.Duration > 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be omitted for stream sources", cfg.Name)
}
policy, err := parseStreamExitPolicy(cfg)
if err != nil {
return Job{}, err
}
backoff, err := parseStreamBackoff(cfg)
if err != nil {
return Job{}, err
}
job.StreamExitPolicy = policy
job.StreamBackoff = backoff
return job, nil
}
if _, ok := src.(sources.PollSource); ok {
if cfg.Every.Duration <= 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be > 0 for polling sources", cfg.Name)
}
if err := rejectStreamParams(cfg); err != nil {
return Job{}, err
}
return job, nil
}
return Job{}, fmt.Errorf("scheduler: source %q implements neither PollSource nor StreamSource", cfg.Name)
}
func parseStreamExitPolicy(cfg config.SourceConfig) (StreamExitPolicy, error) {
const key = "stream_exit_policy"
raw, exists := cfg.Params[key]
if !exists || raw == nil {
return StreamExitPolicyRestart, nil
}
s, ok := raw.(string)
if !ok {
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
switch StreamExitPolicy(strings.ToLower(strings.TrimSpace(s))) {
case StreamExitPolicyRestart:
return StreamExitPolicyRestart, nil
case StreamExitPolicyStop:
return StreamExitPolicyStop, nil
case StreamExitPolicyFatal:
return StreamExitPolicyFatal, nil
default:
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
}
func parseStreamBackoff(cfg config.SourceConfig) (StreamBackoff, error) {
initial, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_initial", defaultStreamBackoffInitial)
if err != nil {
return StreamBackoff{}, err
}
max, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_max", defaultStreamBackoffMax)
if err != nil {
return StreamBackoff{}, err
}
jitter, err := parseNonNegativeOrDefaultDuration(cfg, "stream_backoff_jitter", defaultStreamBackoffJitter)
if err != nil {
return StreamBackoff{}, err
}
if max < initial {
return StreamBackoff{}, fmt.Errorf("source %q: params.stream_backoff_max must be >= params.stream_backoff_initial", cfg.Name)
}
return StreamBackoff{
Initial: initial,
Max: max,
Jitter: jitter,
}, nil
}
func rejectStreamParams(cfg config.SourceConfig) error {
streamKeys := []string{
"stream_exit_policy",
"stream_backoff_initial",
"stream_backoff_max",
"stream_backoff_jitter",
}
for _, key := range streamKeys {
if _, ok := cfg.Params[key]; ok {
return fmt.Errorf("source %q: params.%s is only valid for stream sources", cfg.Name, key)
}
}
return nil
}
func parsePositiveOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v <= 0 {
return 0, fmt.Errorf("source %q: params.%s must be a positive duration", cfg.Name, key)
}
return v, nil
}
func parseNonNegativeOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v < 0 {
return 0, fmt.Errorf("source %q: params.%s must be a non-negative duration", cfg.Name, key)
}
return v, nil
}
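A short sketch of feeding a configured stream source through `JobFromSourceConfig`; the stub source, driver name, and parameter values are illustrative.

```go
// Illustrative only: wires one stream source, with feedkit-owned stream
// params, into a scheduler job. nullStream is a placeholder source.
package main

import (
    "context"
    "log"

    "gitea.maximumdirect.net/ejr/feedkit/config"
    "gitea.maximumdirect.net/ejr/feedkit/event"
    "gitea.maximumdirect.net/ejr/feedkit/scheduler"
)

type nullStream struct{ name string }

func (s nullStream) Name() string { return s.name }
func (s nullStream) Run(ctx context.Context, _ chan<- event.Event) error {
    <-ctx.Done()
    return ctx.Err()
}

func main() {
    sc := config.SourceConfig{
        Name:   "ticker",
        Driver: "ticker_stream",
        Mode:   config.SourceModeStream,
        Params: map[string]any{
            "stream_exit_policy":     "restart",
            "stream_backoff_initial": "2s",
            "stream_backoff_max":     "30s",
        },
    }
    job, err := scheduler.JobFromSourceConfig(nullStream{name: sc.Name}, sc)
    if err != nil {
        log.Fatalf("job: %v", err)
    }
    log.Printf("policy=%s backoff=%+v", job.StreamExitPolicy, job.StreamBackoff)
}
```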

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"math/rand" "math/rand"
"sync"
"time" "time"
"gitea.maximumdirect.net/ejr/feedkit/event" "gitea.maximumdirect.net/ejr/feedkit/event"
@@ -28,8 +29,10 @@ type Logger = logging.Logf
// - For stream sources: Jitter is applied once at startup only (optional; useful to avoid // - For stream sources: Jitter is applied once at startup only (optional; useful to avoid
// reconnect storms when many instances start together). // reconnect storms when many instances start together).
type Job struct { type Job struct {
Source sources.Input Source sources.Input
Every time.Duration Every time.Duration
StreamExitPolicy StreamExitPolicy
StreamBackoff StreamBackoff
// Jitter is the maximum additional delay added before each poll. // Jitter is the maximum additional delay added before each poll.
// Example: if Every=15m and Jitter=30s, each poll will occur at: // Example: if Every=15m and Jitter=30s, each poll will occur at:
@@ -41,12 +44,37 @@ type Job struct {
Jitter time.Duration Jitter time.Duration
} }
// StreamExitPolicy controls how the scheduler handles non-fatal stream exits.
type StreamExitPolicy string
const (
StreamExitPolicyRestart StreamExitPolicy = "restart"
StreamExitPolicyStop StreamExitPolicy = "stop"
StreamExitPolicyFatal StreamExitPolicy = "fatal"
)
// StreamBackoff controls restart pacing for stream supervision.
type StreamBackoff struct {
Initial time.Duration
Max time.Duration
Jitter time.Duration
}
type Scheduler struct { type Scheduler struct {
Jobs []Job Jobs []Job
Out chan<- event.Event Out chan<- event.Event
Logf Logger Logf Logger
} }
const (
defaultStreamBackoffInitial = 1 * time.Second
defaultStreamBackoffMax = 1 * time.Minute
defaultStreamBackoffJitter = 250 * time.Millisecond
streamBackoffResetAfter = 5 * time.Minute
)
var timeNow = time.Now
// Run starts one goroutine per job. // Run starts one goroutine per job.
// Poll jobs run on their own interval and emit 0..N events per poll. // Poll jobs run on their own interval and emit 0..N events per poll.
// Stream jobs run continuously and emit events as they arrive. // Stream jobs run continuously and emit events as they arrive.
@@ -58,16 +86,38 @@ func (s *Scheduler) Run(ctx context.Context) error {
return fmt.Errorf("scheduler.Run: no jobs configured") return fmt.Errorf("scheduler.Run: no jobs configured")
} }
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
fatalErrCh := make(chan error, 1)
var wg sync.WaitGroup
for _, job := range s.Jobs { for _, job := range s.Jobs {
job := job // capture loop variable job := job // capture loop variable
go s.runJob(ctx, job) wg.Add(1)
go func() {
defer wg.Done()
s.runJob(runCtx, job, fatalErrCh)
}()
} }
<-ctx.Done() done := make(chan struct{})
return ctx.Err() go func() {
wg.Wait()
close(done)
}()
select {
case err := <-fatalErrCh:
cancel()
<-done
return err
case <-runCtx.Done():
<-done
return runCtx.Err()
}
} }
func (s *Scheduler) runJob(ctx context.Context, job Job) { func (s *Scheduler) runJob(ctx context.Context, job Job, fatalErrCh chan<- error) {
if job.Source == nil { if job.Source == nil {
s.logf("scheduler: job has nil source") s.logf("scheduler: job has nil source")
return return
@@ -75,7 +125,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
// Stream sources: event-driven. // Stream sources: event-driven.
if ss, ok := job.Source.(sources.StreamSource); ok { if ss, ok := job.Source.(sources.StreamSource); ok {
s.runStream(ctx, job, ss) s.runStream(ctx, job, ss, fatalErrCh)
return return
} }
@@ -93,18 +143,51 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
s.runPoller(ctx, job, ps) s.runPoller(ctx, job, ps)
} }
func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource) { func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource, fatalErrCh chan<- error) {
policy := effectiveStreamExitPolicy(job.StreamExitPolicy)
backoff := effectiveStreamBackoff(job.StreamBackoff)
rng := seededRNG(src.Name())
// Optional startup jitter: helps avoid reconnect storms if many daemons start at once. // Optional startup jitter: helps avoid reconnect storms if many daemons start at once.
if job.Jitter > 0 { if job.Jitter > 0 {
rng := seededRNG(src.Name())
if !sleepJitter(ctx, rng, job.Jitter) { if !sleepJitter(ctx, rng, job.Jitter) {
return return
} }
} }
// Stream sources should block until ctx cancel or fatal error. nextDelay := backoff.Initial
if err := src.Run(ctx, s.Out); err != nil && ctx.Err() == nil { for {
s.logf("scheduler: stream source %q exited with error: %v", src.Name(), err) startedAt := timeNow()
err := src.Run(ctx, s.Out)
if ctx.Err() != nil {
return
}
normalizedErr := normalizeStreamExitError(src.Name(), err)
if sources.IsStreamFatal(normalizedErr) {
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited fatally: %w", src.Name(), normalizedErr))
return
}
switch policy {
case StreamExitPolicyStop:
s.logf("scheduler: stream source %q stopped after exit: %v", src.Name(), normalizedErr)
return
case StreamExitPolicyFatal:
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited under fatal policy: %w", src.Name(), normalizedErr))
return
}
if streamRunWasStable(startedAt, timeNow()) {
nextDelay = backoff.Initial
}
delay := nextDelay + randomDuration(rng, backoff.Jitter)
s.logf("scheduler: stream source %q exited; restarting in %s: %v", src.Name(), delay, normalizedErr)
if !sleepDuration(ctx, delay) {
return
}
nextDelay = nextStreamBackoff(nextDelay, backoff.Max)
} }
} }
@@ -164,10 +247,77 @@ func (s *Scheduler) logf(format string, args ...any) {
s.Logf(format, args...) s.Logf(format, args...)
} }
func (s *Scheduler) reportFatal(ch chan<- error, err error) {
if err == nil {
return
}
select {
case ch <- err:
default:
}
}
// ---- helpers ---- // ---- helpers ----
func effectiveStreamExitPolicy(policy StreamExitPolicy) StreamExitPolicy {
switch policy {
case StreamExitPolicyStop, StreamExitPolicyFatal:
return policy
default:
return StreamExitPolicyRestart
}
}
func effectiveStreamBackoff(cfg StreamBackoff) StreamBackoff {
out := cfg
if out.Initial <= 0 {
out.Initial = defaultStreamBackoffInitial
}
if out.Max <= 0 {
out.Max = defaultStreamBackoffMax
}
if out.Max < out.Initial {
out.Max = out.Initial
}
if out.Jitter < 0 {
out.Jitter = 0
}
return out
}
func normalizeStreamExitError(sourceName string, err error) error {
if err != nil {
return err
}
return sources.StreamRetryable(fmt.Errorf("stream source %q exited unexpectedly without error", sourceName))
}
func nextStreamBackoff(current, max time.Duration) time.Duration {
if current <= 0 {
current = defaultStreamBackoffInitial
}
if max <= 0 {
max = defaultStreamBackoffMax
}
if current >= max {
return max
}
next := current * 2
if next < current || next > max {
return max
}
return next
}
func streamRunWasStable(startedAt, endedAt time.Time) bool {
if startedAt.IsZero() || endedAt.IsZero() {
return false
}
return endedAt.Sub(startedAt) >= streamBackoffResetAfter
}
func seededRNG(name string) *rand.Rand { func seededRNG(name string) *rand.Rand {
seed := time.Now().UnixNano() ^ int64(hashStringFNV32a(name)) seed := timeNow().UnixNano() ^ int64(hashStringFNV32a(name))
return rand.New(rand.NewSource(seed)) return rand.New(rand.NewSource(seed))
} }
@@ -206,11 +356,23 @@ func sleepJitter(ctx context.Context, rng *rand.Rand, max time.Duration) bool {
return true return true
} }
return sleepDuration(ctx, randomDuration(rng, max))
}
func randomDuration(rng *rand.Rand, max time.Duration) time.Duration {
if max <= 0 {
return 0
}
// Int63n requires a positive argument. // Int63n requires a positive argument.
// We add 1 so max itself is attainable. // We add 1 so max itself is attainable.
n := rng.Int63n(int64(max) + 1) n := rng.Int63n(int64(max) + 1)
d := time.Duration(n) return time.Duration(n)
}
func sleepDuration(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
timer := time.NewTimer(d) timer := time.NewTimer(d)
defer timer.Stop() defer timer.Stop()

472
scheduler/scheduler_test.go Normal file
View File

@@ -0,0 +1,472 @@
package scheduler
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
type testPollSource struct {
name string
}
func (s testPollSource) Name() string { return s.name }
func (s testPollSource) Poll(context.Context) ([]event.Event, error) { return nil, nil }
type scriptedStreamSource struct {
name string
mu sync.Mutex
calls int
runs []func(context.Context, chan<- event.Event) error
}
func (s *scriptedStreamSource) Name() string { return s.name }
func (s *scriptedStreamSource) Run(ctx context.Context, out chan<- event.Event) error {
s.mu.Lock()
call := s.calls
s.calls++
var run func(context.Context, chan<- event.Event) error
if call < len(s.runs) {
run = s.runs[call]
}
s.mu.Unlock()
if run != nil {
return run(ctx, out)
}
<-ctx.Done()
return ctx.Err()
}
func (s *scriptedStreamSource) CallCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.calls
}
type capturingLogger struct {
mu sync.Mutex
lines []string
}
func (l *capturingLogger) Logf(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.lines = append(l.lines, fmt.Sprintf(format, args...))
}
func (l *capturingLogger) Contains(substr string) bool {
l.mu.Lock()
defer l.mu.Unlock()
for _, line := range l.lines {
if strings.Contains(line, substr) {
return true
}
}
return false
}
func TestSchedulerRunRestartsPlainStreamErrors(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-a",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 3 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if src.CallCount() < 3 {
t.Fatalf("stream call count = %d, want at least 3", src.CallCount())
}
}
func TestSchedulerRunFatalStreamErrorReturns(t *testing.T) {
base := errors.New("fatal failure")
src := &scriptedStreamSource{
name: "stream-fatal",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return sources.StreamFatal(base) },
},
}
s := &Scheduler{
Jobs: []Job{{Source: src}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal error")
}
if !sources.IsStreamFatal(err) {
t.Fatalf("Scheduler.Run() error = %v, want fatal classification", err)
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base fatal error: %v", err)
}
}
func TestSchedulerRunStopPolicyStopsOnlyThatSource(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-stop",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("stop now") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyStop,
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
time.Sleep(20 * time.Millisecond)
select {
case err := <-errCh:
t.Fatalf("Scheduler.Run() returned early: %v", err)
default:
}
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
}
func TestSchedulerRunFatalPolicyTreatsPlainErrorAsFatal(t *testing.T) {
base := errors.New("plain failure")
src := &scriptedStreamSource{
name: "stream-fatal-policy",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return base },
},
}
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyFatal,
}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal-policy error")
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base error: %v", err)
}
}
func TestSchedulerRunNilExitRestartsAsUnexpected(t *testing.T) {
logger := &capturingLogger{}
src := &scriptedStreamSource{
name: "stream-nil-exit",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return nil },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
Logf: logger.Logf,
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 2 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if !logger.Contains("exited unexpectedly without error") {
t.Fatalf("expected log to mention unexpected nil stream exit")
}
}
func TestSchedulerRunContextCancelDuringBackoff(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-backoff-cancel",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("retry me") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Second,
Max: time.Second,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
time.Sleep(20 * time.Millisecond)
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
}
func TestNextStreamBackoffCapsAtMax(t *testing.T) {
if got := nextStreamBackoff(500*time.Millisecond, 2*time.Second); got != time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 1s", got)
}
if got := nextStreamBackoff(time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
if got := nextStreamBackoff(2*time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
}
func TestStreamRunWasStableAfterFiveMinutes(t *testing.T) {
start := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
if streamRunWasStable(start, start.Add(4*time.Minute+59*time.Second)) {
t.Fatalf("streamRunWasStable() = true, want false")
}
if !streamRunWasStable(start, start.Add(5*time.Minute)) {
t.Fatalf("streamRunWasStable() = false, want true")
}
}
func TestJobFromSourceConfigPollSource(t *testing.T) {
job, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.Every != time.Minute {
t.Fatalf("Job.Every = %s, want 1m", job.Every)
}
}
func TestJobFromSourceConfigPollSourceRejectsStreamParams(t *testing.T) {
_, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
Params: map[string]any{
"stream_exit_policy": "restart",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want rejection")
}
if !strings.Contains(err.Error(), "only valid for stream sources") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceParsesDefaultsAndOverrides(t *testing.T) {
src := &scriptedStreamSource{name: "stream-a"}
job, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-a",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "stop",
"stream_backoff_initial": "2s",
"stream_backoff_max": "10s",
"stream_backoff_jitter": "500ms",
},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.StreamExitPolicy != StreamExitPolicyStop {
t.Fatalf("Job.StreamExitPolicy = %q, want %q", job.StreamExitPolicy, StreamExitPolicyStop)
}
if job.StreamBackoff.Initial != 2*time.Second {
t.Fatalf("Job.StreamBackoff.Initial = %s, want 2s", job.StreamBackoff.Initial)
}
if job.StreamBackoff.Max != 10*time.Second {
t.Fatalf("Job.StreamBackoff.Max = %s, want 10s", job.StreamBackoff.Max)
}
if job.StreamBackoff.Jitter != 500*time.Millisecond {
t.Fatalf("Job.StreamBackoff.Jitter = %s, want 500ms", job.StreamBackoff.Jitter)
}
defaultJob, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-default",
Driver: "stream_driver",
Mode: config.SourceModeStream,
})
if err != nil {
t.Fatalf("JobFromSourceConfig() default error = %v", err)
}
if defaultJob.StreamExitPolicy != StreamExitPolicyRestart {
t.Fatalf("default Job.StreamExitPolicy = %q, want restart", defaultJob.StreamExitPolicy)
}
if defaultJob.StreamBackoff.Initial != defaultStreamBackoffInitial {
t.Fatalf("default Job.StreamBackoff.Initial = %s, want %s", defaultJob.StreamBackoff.Initial, defaultStreamBackoffInitial)
}
if defaultJob.StreamBackoff.Max != defaultStreamBackoffMax {
t.Fatalf("default Job.StreamBackoff.Max = %s, want %s", defaultJob.StreamBackoff.Max, defaultStreamBackoffMax)
}
if defaultJob.StreamBackoff.Jitter != defaultStreamBackoffJitter {
t.Fatalf("default Job.StreamBackoff.Jitter = %s, want %s", defaultJob.StreamBackoff.Jitter, defaultStreamBackoffJitter)
}
}
func TestJobFromSourceConfigStreamSourceRejectsInvalidSettings(t *testing.T) {
src := &scriptedStreamSource{name: "stream-b"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "sometimes",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid policy error")
}
if !strings.Contains(err.Error(), "stream_exit_policy") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "0s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid initial backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_initial") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "2s",
"stream_backoff_max": "1s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid max backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_max") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceRejectsEvery(t *testing.T) {
src := &scriptedStreamSource{name: "stream-c"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-c",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Every: config.Duration{Duration: time.Minute},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want every rejection")
}
if !strings.Contains(err.Error(), "sources[].every must be omitted") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func waitFor(t *testing.T, cond func() bool) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("condition not satisfied before timeout")
}

View File

@@ -1,7 +0,0 @@
package scheduler
// Placeholder for per-source worker logic:
// - ticker loop
// - jitter
// - backoff on errors
// - emits events into scheduler.Out

View File

@@ -1,11 +1,6 @@
 package sinks
 
-import (
-	"fmt"
-	"strings"
-
-	"gitea.maximumdirect.net/ejr/feedkit/config"
-)
+import "gitea.maximumdirect.net/ejr/feedkit/config"
 
 // RegisterBuiltins registers sink drivers included in this binary.
 //
@@ -17,39 +12,8 @@ func RegisterBuiltins(r *Registry) {
 		return NewStdoutSink(cfg.Name), nil
 	})
 
-	// File sink: writes/archives events somewhere on disk.
-	r.Register("file", func(cfg config.SinkConfig) (Sink, error) {
-		return NewFileSinkFromConfig(cfg)
-	})
-
-	// Postgres sink: persists events durably.
-	r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) {
-		return NewPostgresSinkFromConfig(cfg)
-	})
-
 	// NATS sink: publishes events to a broker for downstream consumers.
 	r.Register("nats", func(cfg config.SinkConfig) (Sink, error) {
 		return NewNATSSinkFromConfig(cfg)
 	})
 }
-
-// ---- helpers for validating sink params ----
-//
-// These helpers live in sinks (not config) on purpose:
-// - config is domain-agnostic and should not embed driver-specific validation helpers.
-// - sinks are adapters; validating their own params here keeps the logic near the driver.
-func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
-	v, ok := cfg.Params[key]
-	if !ok {
-		return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
-	}
-	s, ok := v.(string)
-	if !ok {
-		return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
-	}
-	if strings.TrimSpace(s) == "" {
-		return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
-	}
-	return s, nil
-}

View File

@@ -1,18 +1,27 @@
-// Package sinks provides sink abstractions, a sink driver registry, and several
-// built-in sink drivers.
+// Package sinks defines the feedkit sink interface, sink driver registry, and
+// built-in infrastructure sinks.
//
-// Built-in drivers:
+// External API surface:
+// - Sink: adapter interface that consumes event.Event values
+// - Registry / NewRegistry: named sink factory registry
+// - RegisterBuiltins: registers the schema-free built-in sink drivers
+//
+// Built-in sink implementations:
// - stdout
// - nats
// - postgres
//
+// Optional helpers from helpers.go:
+// - PostgresFactory: returns a sink factory for the built-in Postgres sink
+//   using a provided downstream schema
+//
// # NATS built-in overview
//
// The NATS sink publishes each event as JSON to a configured subject.
//
// Required params:
// - url: NATS server URL (for example, nats://localhost:4222)
-// - exchange: NATS subject to publish to
+// - subject: NATS subject to publish to
//
// Example config:
//
@@ -21,7 +30,7 @@
// driver: nats
// params:
// url: nats://localhost:4222
-// exchange: feedkit.events
+// subject: feedkit.events
//
// # Postgres built-in overview
//
@@ -50,7 +59,9 @@
//
// Example downstream wiring:
//
-// sinks.MustRegisterPostgresSchema("pg_main", sinks.PostgresSchema{
+// sinkReg := sinks.NewRegistry()
+// sinks.RegisterBuiltins(sinkReg)
+// sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
// Tables: []sinks.PostgresTable{
// {
// Name: "events",
@@ -79,7 +90,7 @@
// },
// }, nil
// },
-// })
+// }))
//
// Manual pruning via type assertion (administrative helpers):
//
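
Taken together, the wiring the updated package doc describes looks roughly like this in a downstream daemon (a sketch only, with the schema trimmed to one table; mapEvent is a placeholder for the daemon's own mapping function):

	sinkReg := sinks.NewRegistry()
	sinks.RegisterBuiltins(sinkReg) // stdout and nats; postgres is registered explicitly below
	sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
		Tables: []sinks.PostgresTable{
			{
				Name: "events",
				Columns: []sinks.PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
			},
		},
		MapEvent: mapEvent,
	}))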

View File

@@ -1,30 +0,0 @@
package sinks
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type FileSink struct {
name string
path string
}
func NewFileSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
path, err := requireStringParam(cfg, "path")
if err != nil {
return nil, err
}
return &FileSink{name: cfg.Name, path: path}, nil
}
func (s *FileSink) Name() string { return s.name }
func (s *FileSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
_ = e
return fmt.Errorf("file sink: TODO implement (path=%s)", s.path)
}

35
sinks/helpers.go Normal file
View File

@@ -0,0 +1,35 @@
package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
// requireStringParam returns a non-empty string sink param.
//
// This helper is intentionally local to sinks rather than config so
// driver-specific validation stays close to the adapters that use it.
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
v, ok := cfg.Params[key]
if !ok {
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
}
s, ok := v.(string)
if !ok {
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
}
if strings.TrimSpace(s) == "" {
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
}
return s, nil
}
// PostgresFactory returns a sink factory that builds the built-in Postgres sink
// using the provided downstream schema definition.
func PostgresFactory(schema PostgresSchema) Factory {
return func(cfg config.SinkConfig) (Sink, error) {
return NewPostgresSinkFromConfig(cfg, schema)
}
}

46
sinks/helpers_test.go Normal file
View File

@@ -0,0 +1,46 @@
package sinks
import (
"context"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) {
factory := PostgresFactory(testPostgresSchema())
if factory == nil {
t.Fatalf("PostgresFactory() returned nil")
}
if _, err := factory(config.SinkConfig{}); err == nil {
t.Fatalf("factory(config) expected parameter validation error")
}
}
func testPostgresSchema() PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
},
MapEvent: func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{
Table: "events",
Values: map[string]any{
"event_id": e.ID,
"emitted_at": e.EmittedAt,
},
},
}, nil
},
}
}

View File

@@ -13,9 +13,9 @@ import (
) )
type NATSSink struct {
name string
url string
-exchange string
+subject string
mu sync.Mutex
conn *nats.Conn
@@ -26,11 +26,11 @@ func NewNATSSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
if err != nil {
return nil, err
}
-ex, err := requireStringParam(cfg, "exchange")
+subject, err := requireStringParam(cfg, "subject")
if err != nil {
return nil, err
}
-return &NATSSink{name: cfg.Name, url: url, exchange: ex}, nil
+return &NATSSink{name: cfg.Name, url: url, subject: subject}, nil
}
func (r *NATSSink) Name() string { return r.name }
@@ -59,7 +59,7 @@ func (r *NATSSink) Consume(ctx context.Context, e event.Event) error {
if err := ctx.Err(); err != nil {
return err
}
-if err := conn.Publish(r.exchange, b); err != nil {
+if err := conn.Publish(r.subject, b); err != nil {
return fmt.Errorf("NATS sink: publish: %w", err)
}
return nil

47
sinks/nats_test.go Normal file
View File

@@ -0,0 +1,47 @@
package sinks
import (
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
func TestNewNATSSinkFromConfigRequiresSubject(t *testing.T) {
sink, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"subject": "feedkit.events",
},
})
if err != nil {
t.Fatalf("NewNATSSinkFromConfig() error = %v", err)
}
natsSink, ok := sink.(*NATSSink)
if !ok {
t.Fatalf("sink type = %T, want *NATSSink", sink)
}
if natsSink.subject != "feedkit.events" {
t.Fatalf("subject = %q, want feedkit.events", natsSink.subject)
}
}
func TestNewNATSSinkFromConfigRejectsLegacyExchange(t *testing.T) {
_, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"exchange": "feedkit.events",
},
})
if err == nil {
t.Fatalf("NewNATSSinkFromConfig() expected error")
}
if !strings.Contains(err.Error(), "params.subject is required") {
t.Fatalf("error = %q, want params.subject is required", err)
}
}

View File

@@ -4,18 +4,15 @@ import (
"context"
"database/sql"
"fmt"
-"net/url"
"strconv"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
-_ "github.com/lib/pq"
+pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
-const postgresInitTimeout = 5 * time.Second
type postgresTx interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Commit() error
@@ -73,8 +70,8 @@ func (w *sqlTxWrapper) Rollback() error {
return w.tx.Rollback()
}
-var openPostgresDB = func(dsn string) (postgresDB, error) {
-db, err := sql.Open("postgres", dsn)
+var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
+db, err := pgconn.Open(ctx, cfg)
if err != nil {
return nil, err
}
@@ -88,7 +85,7 @@ type PostgresSink struct {
pruneWindow time.Duration
}
-func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
+func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
uri, err := requireStringParam(cfg, "uri")
if err != nil {
return nil, err
@@ -106,17 +103,16 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
return nil, err
}
-schema, ok := lookupPostgresSchema(cfg.Name)
-if !ok {
-return nil, fmt.Errorf("postgres sink %q: no schema registered (call sinks.RegisterPostgresSchema before building sinks)", cfg.Name)
-}
-dsn, err := buildPostgresDSN(uri, username, password)
+schema, err := compilePostgresSchema(schemaDef)
if err != nil {
-return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err)
+return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
}
-db, err := openPostgresDB(dsn)
+db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
+URI: uri,
+Username: username,
+Password: password,
+})
if err != nil {
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
}
@@ -264,13 +260,9 @@ func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time)
}
func (p *PostgresSink) initialize() error {
-ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout)
+ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
-if err := p.db.PingContext(ctx); err != nil {
-return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err)
-}
for _, tableName := range p.schema.tableOrder {
tbl := p.schema.tables[tableName]
@@ -302,21 +294,6 @@ func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error)
return tbl, nil
}
-func buildPostgresDSN(uri, username, password string) (string, error) {
-u, err := url.Parse(strings.TrimSpace(uri))
-if err != nil {
-return "", fmt.Errorf("invalid uri: %w", err)
-}
-if u.Scheme == "" {
-return "", fmt.Errorf("invalid uri: missing scheme")
-}
-if u.Host == "" {
-return "", fmt.Errorf("invalid uri: missing host")
-}
-u.User = url.UserPassword(username, password)
-return u.String(), nil
-}
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
raw, ok := cfg.Params["prune"]
if !ok || raw == nil {

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
-"sync"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
@@ -72,51 +71,6 @@ type postgresTableCompiled struct {
indexes []PostgresIndex
}
-var (
-postgresSchemaRegistryMu sync.RWMutex
-postgresSchemaRegistry = map[string]postgresSchemaCompiled{}
-)
-// RegisterPostgresSchema registers one downstream schema by sink name.
-//
-// This should be called by downstream daemon wiring code before sink
-// construction. Duplicate sink-name registrations are rejected.
-func RegisterPostgresSchema(sinkName string, schema PostgresSchema) error {
-sinkName = strings.TrimSpace(sinkName)
-if sinkName == "" {
-return fmt.Errorf("postgres schema: sink name cannot be empty")
-}
-compiled, err := compilePostgresSchema(schema)
-if err != nil {
-return err
-}
-postgresSchemaRegistryMu.Lock()
-defer postgresSchemaRegistryMu.Unlock()
-if _, exists := postgresSchemaRegistry[sinkName]; exists {
-return fmt.Errorf("postgres schema: sink %q already registered", sinkName)
-}
-postgresSchemaRegistry[sinkName] = compiled
-return nil
-}
-func MustRegisterPostgresSchema(sinkName string, schema PostgresSchema) {
-if err := RegisterPostgresSchema(sinkName, schema); err != nil {
-panic(err)
-}
-}
-func lookupPostgresSchema(sinkName string) (postgresSchemaCompiled, bool) {
-postgresSchemaRegistryMu.RLock()
-defer postgresSchemaRegistryMu.RUnlock()
-s, ok := postgresSchemaRegistry[sinkName]
-return s, ok
-}
func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) {
if schema.MapEvent == nil {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required")

View File

@@ -4,13 +4,13 @@ import (
"context"
"database/sql"
"errors"
-"net/url"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
+pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakeResult struct { type fakeResult struct {
@@ -96,20 +96,12 @@ func (d *fakeDB) Close() error {
return nil
}
-func resetPostgresSchemaRegistryForTest() {
-postgresSchemaRegistryMu.Lock()
-defer postgresSchemaRegistryMu.Unlock()
-postgresSchemaRegistry = map[string]postgresSchemaCompiled{}
-}
func withPostgresTestState(t *testing.T) {
t.Helper()
-resetPostgresSchemaRegistryForTest()
oldOpen := openPostgresDB
t.Cleanup(func() {
openPostgresDB = oldOpen
-resetPostgresSchemaRegistryForTest()
})
}
@@ -183,35 +175,8 @@ func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
return compiled
}
-func TestRegisterPostgresSchema(t *testing.T) {
-withPostgresTestState(t)
-err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
-return nil, nil
-}))
-if err != nil {
-t.Fatalf("register schema: %v", err)
-}
-if _, ok := lookupPostgresSchema("pg"); !ok {
-t.Fatalf("expected schema registration")
-}
-err = RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
-return nil, nil
-}))
-if err == nil {
-t.Fatalf("expected duplicate registration error")
-}
-if !strings.Contains(err.Error(), "already registered") {
-t.Fatalf("unexpected duplicate error: %v", err)
-}
-}
-func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) {
-withPostgresTestState(t)
-err := RegisterPostgresSchema("pg", PostgresSchema{
+func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
+_, err := compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
@@ -230,7 +195,7 @@ func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) {
t.Fatalf("unexpected schema validation error: %v", err)
}
-err = RegisterPostgresSchema("pg2", PostgresSchema{
+_, err = compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
@@ -254,7 +219,56 @@ func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) {
}
}
-func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) {
+func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
withPostgresTestState(t)
dbs := []*fakeDB{{}, {}}
var gotCfgs []pgconn.ConnConfig
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
gotCfgs = append(gotCfgs, cfg)
db := dbs[len(gotCfgs)-1]
if err := db.PingContext(ctx); err != nil {
return nil, err
}
return db, nil
}
factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
}))
for _, name := range []string{"pg_a", "pg_b"} {
sink, err := factory(config.SinkConfig{
Name: name,
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
})
if err != nil {
t.Fatalf("factory(%q) error = %v", name, err)
}
if sink == nil {
t.Fatalf("factory(%q) returned nil sink", name)
}
}
if len(gotCfgs) != 2 {
t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
}
if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" {
t.Fatalf("first ConnConfig = %+v", gotCfgs[0])
}
for i, db := range dbs {
if db.pingCalls != 1 {
t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
}
}
}
+func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
@@ -273,7 +287,7 @@ func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) {
Name: "pg",
Driver: "postgres",
Params: tc.params,
-})
+}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
@@ -284,7 +298,7 @@ func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) {
}
}
-func TestNewPostgresSinkFromConfig_MissingSchemaRegistration(t *testing.T) {
+func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
withPostgresTestState(t)
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
@@ -295,29 +309,36 @@ func TestNewPostgresSinkFromConfig_MissingSchemaRegistration(t *testing.T) {
"username": "user",
"password": "pass",
},
+}, PostgresSchema{
+Tables: []PostgresTable{
+{
+Name: "events",
+Columns: []PostgresColumn{
+{Name: "id", Type: "TEXT", Nullable: false},
+},
+PruneColumn: "missing_col",
+},
+},
+MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
-t.Fatalf("expected error")
+t.Fatalf("expected invalid schema error")
}
-if !strings.Contains(err.Error(), "no schema registered") {
+if !strings.Contains(err.Error(), "compile schema") {
t.Fatalf("unexpected error: %v", err)
}
}
-func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) {
+func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
withPostgresTestState(t)
-err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
-return nil, nil
-}))
-if err != nil {
-t.Fatalf("register schema: %v", err)
-}
db := &fakeDB{}
-var gotDSN string
-openPostgresDB = func(dsn string) (postgresDB, error) {
-gotDSN = dsn
+var gotCfg pgconn.ConnConfig
+openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
+gotCfg = cfg
+if err := db.PingContext(ctx); err != nil {
+return nil, err
+}
return db, nil
}
@@ -329,7 +350,7 @@ func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) {
"username": "app_user",
"password": "app_pass",
},
-})
+}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
@@ -350,35 +371,26 @@ func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) {
t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
}
-u, err := url.Parse(gotDSN)
-if err != nil {
-t.Fatalf("parse dsn: %v", err)
-}
-if u.User == nil || u.User.Username() != "app_user" {
-t.Fatalf("dsn missing username: %q", gotDSN)
-}
-pass, ok := u.User.Password()
-if !ok || pass != "app_pass" {
-t.Fatalf("dsn missing password: %q", gotDSN)
-}
+if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
+t.Fatalf("URI = %q", gotCfg.URI)
+}
+if gotCfg.Username != "app_user" {
+t.Fatalf("Username = %q, want app_user", gotCfg.Username)
+}
+if gotCfg.Password != "app_pass" {
+t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
+}
}
-func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) {
+func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
withPostgresTestState(t)
-err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
-return nil, nil
-}))
-if err != nil {
-t.Fatalf("register schema: %v", err)
-}
db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
-openPostgresDB = func(_ string) (postgresDB, error) {
+openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return db, nil
}
-_, err = NewPostgresSinkFromConfig(config.SinkConfig{
+_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
@@ -386,7 +398,7 @@ func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) {
"username": "user",
"password": "pass",
},
-})
+}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected init error")
}
@@ -395,7 +407,7 @@ func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) {
}
}
-func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) {
+func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
tests := []struct {
name string
in string
@@ -410,14 +422,7 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
withPostgresTestState(t)
-err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
-return nil, nil
-}))
-if err != nil {
-t.Fatalf("register schema: %v", err)
-}
-openPostgresDB = func(_ string) (postgresDB, error) {
+openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return &fakeDB{}, nil
}
@@ -430,7 +435,7 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) {
"password": "pass",
"prune": tc.in,
},
-})
+}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
@@ -446,7 +451,7 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) {
}
}
-func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) {
+func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
@@ -472,7 +477,7 @@ func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) {
"password": "pass",
"prune": tc.in,
},
-})
+}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
@@ -483,7 +488,7 @@ func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) {
}
}
-func TestPostgresSinkConsume_InvalidEvent(t *testing.T) {
+func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
db := &fakeDB{}
called := 0
sink := &PostgresSink{
@@ -507,7 +512,7 @@ func TestPostgresSinkConsume_InvalidEvent(t *testing.T) {
}
}
-func TestPostgresSinkConsume_UnmappedEventIsNoOp(t *testing.T) {
+func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
@@ -525,7 +530,7 @@ func TestPostgresSinkConsume_UnmappedEventIsNoOp(t *testing.T) {
}
}
-func TestPostgresSinkConsume_OneEventWritesMultipleTablesAtomically(t *testing.T) {
+func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
@@ -556,7 +561,7 @@ func TestPostgresSinkConsume_OneEventWritesMultipleTablesAtomically(t *testing.T) {
}
}
-func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) {
+func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
@@ -585,13 +590,13 @@ func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) {
}
}
-func TestPostgresSinkConsume_AutoPruneRunsInSameTransaction(t *testing.T) {
+func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
@@ -620,13 +625,13 @@ func TestPostgresSinkConsume_AutoPruneRunsInSameTransaction(t *testing.T) {
}
}
-func TestPostgresSinkConsume_AutoPruneFailureRollsBack(t *testing.T) {
+func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
@@ -650,7 +655,7 @@ func TestPostgresSinkConsume_AutoPruneFailureRollsBack(t *testing.T) {
}
}
-func TestPostgresSinkPrune_PerTable(t *testing.T) {
+func TestPostgresSinkPrunePerTable(t *testing.T) {
db := &fakeDB{execRows: 7}
sink := &PostgresSink{
name: "pg",
@@ -693,7 +698,7 @@ func TestPostgresSinkPrune_PerTable(t *testing.T) {
}
}
-func TestPostgresSinkPrune_AllTables(t *testing.T) {
+func TestPostgresSinkPruneAllTables(t *testing.T) {
db := &fakeDB{execRows: 3}
sink := &PostgresSink{
name: "pg",
@@ -724,7 +729,7 @@ func TestPostgresSinkPrune_AllTables(t *testing.T) {
}
}
-func TestPostgresSinkPrune_Errors(t *testing.T) {
+func TestPostgresSinkPruneErrors(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",

View File

@@ -2,6 +2,7 @@ package sinks
import (
"fmt"
+"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
@@ -21,13 +22,40 @@ func NewRegistry() *Registry {
} }
func (r *Registry) Register(driver string, f Factory) {
+if r == nil {
+panic("sinks.Registry.Register: registry cannot be nil")
+}
+driver = strings.TrimSpace(driver)
+if driver == "" {
+panic("sinks.Registry.Register: driver cannot be empty")
+}
+if f == nil {
+panic(fmt.Sprintf("sinks.Registry.Register: factory cannot be nil (driver=%q)", driver))
+}
+if r.byDriver == nil {
+r.byDriver = map[string]Factory{}
+}
+if _, exists := r.byDriver[driver]; exists {
+panic(fmt.Sprintf("sinks.Registry.Register: driver %q already registered", driver))
+}
r.byDriver[driver] = f
}
func (r *Registry) Build(cfg config.SinkConfig) (Sink, error) {
-f, ok := r.byDriver[cfg.Driver]
-if !ok {
-return nil, fmt.Errorf("unknown sink driver: %q", cfg.Driver)
-}
-return f(cfg)
+if r == nil {
+return nil, fmt.Errorf("sinks registry is nil")
+}
+driver := strings.TrimSpace(cfg.Driver)
+f, ok := r.byDriver[driver]
+if !ok {
+return nil, fmt.Errorf("unknown sink driver: %q", driver)
+}
+sink, err := f(cfg)
+if err != nil {
+return nil, fmt.Errorf("build sink %q: %w", driver, err)
+}
+if sink == nil {
+return nil, fmt.Errorf("build sink %q: factory returned nil sink", driver)
+}
+return sink, nil
}
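
For reference, building sinks from configuration then reduces to a loop like the sketch below; cfg.Sinks and router.Add are placeholders for the daemon's own config struct and dispatch wiring, not feedkit APIs:

	for _, sc := range cfg.Sinks {
		sink, err := sinkReg.Build(sc)
		if err != nil {
			// driver names are trimmed, factory failures come back wrapped as
			// "build sink <driver>: ...", and nil sinks are rejected
			return fmt.Errorf("wire sinks: %w", err)
		}
		router.Add(sink)
	}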

126
sinks/registry_test.go Normal file
View File

@@ -0,0 +1,126 @@
package sinks
import (
"context"
"errors"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testSink struct{ name string }
func (s testSink) Name() string { return s.name }
func (s testSink) Consume(context.Context, event.Event) error { return nil }
func TestRegistryRegisterPanicsOnNilRegistry(t *testing.T) {
var r *Registry
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil registry")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
}
func TestRegistryRegisterPanicsOnEmptyDriver(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on empty driver")
}
}()
r.Register(" ", func(config.SinkConfig) (Sink, error) { return testSink{name: "x"}, nil })
}
func TestRegistryRegisterPanicsOnNilFactory(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil factory")
}
}()
r.Register("stdout", nil)
}
func TestRegistryRegisterPanicsOnDuplicateDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "a"}, nil })
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on duplicate driver")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "b"}, nil })
}
func TestRegistryBuildNilRegistryFails(t *testing.T) {
var r *Registry
_, err := r.Build(config.SinkConfig{Driver: "stdout"})
if err == nil {
t.Fatalf("Build() expected error for nil registry")
}
if !strings.Contains(err.Error(), "registry is nil") {
t.Fatalf("Build() error = %q, want registry is nil", err)
}
}
func TestRegistryBuildTrimsDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
sink, err := r.Build(config.SinkConfig{Name: "sink1", Driver: " stdout "})
if err != nil {
t.Fatalf("Build() error = %v", err)
}
if sink.Name() != "stdout" {
t.Fatalf("Build() sink name = %q, want stdout", sink.Name())
}
}
func TestRegistryBuildWrapsFactoryError(t *testing.T) {
r := NewRegistry()
r.Register("broken", func(config.SinkConfig) (Sink, error) { return nil, errors.New("boom") })
_, err := r.Build(config.SinkConfig{Driver: "broken"})
if err == nil {
t.Fatalf("Build() expected error")
}
if !strings.Contains(err.Error(), `build sink "broken": boom`) {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegistryBuildRejectsNilSink(t *testing.T) {
r := NewRegistry()
r.Register("nil_sink", func(config.SinkConfig) (Sink, error) { return nil, nil })
_, err := r.Build(config.SinkConfig{Driver: "nil_sink"})
if err == nil {
t.Fatalf("Build() expected nil sink error")
}
if !strings.Contains(err.Error(), "factory returned nil sink") {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) {
r := NewRegistry()
RegisterBuiltins(r)
if len(r.byDriver) != 2 {
t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver))
}
for _, driver := range []string{"stdout", "nats"} {
if _, ok := r.byDriver[driver]; !ok {
t.Fatalf("builtins missing driver %q", driver)
}
}
if _, ok := r.byDriver["postgres"]; ok {
t.Fatalf("builtins unexpectedly registered postgres driver")
}
}

View File

@@ -1,10 +1,16 @@
-// Package sources defines feedkit's input-source abstraction.
+// Package sources defines feedkit's input-source abstractions and source
+// registry.
//
-// A source ingests upstream input and emits one or more event.Event values.
-//
-// feedkit supports two source modes:
-// - PollSource: scheduler invokes Poll on a cadence.
-// - StreamSource: source runs continuously and pushes events as input arrives.
+// External API surface:
+// - Input: common source identity surface
+// - PollSource: polling source interface
+// - StreamSource: streaming source interface
+// - StreamRetryable / StreamFatal / IsStreamRetryable / IsStreamFatal:
+//   stream exit classification helpers
+// - Registry / NewRegistry: source driver registry and builders
+// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
+// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling
+//   helper
//
// Source drivers are domain-specific and registered into Registry by driver name.
// Registry can then build configured sources from config.SourceConfig.
@@ -12,13 +18,35 @@
// A single source may emit 0..N events per poll or stream iteration, and those
// events may span multiple event kinds.
//
+// Optional helpers from helpers.go:
+// - DefaultEventID: default event ID policy for source implementations
+// - SingleEvent: construct and validate a one-element event slice
+// - ValidateExpectedKinds: compare configured expected kinds against source
+//   advertised kinds when metadata is available
+//
// HTTP-backed polling sources can share NewHTTPSource for generic HTTP config
// parsing and conditional GET behavior. The helper understands:
// - params.url
// - params.user_agent
// - params.conditional (optional, default true)
+// - params.http_timeout (optional, default transport.DefaultHTTPTimeout)
+// - params.http_response_body_limit_bytes (optional, default
+//   transport.DefaultHTTPResponseBodyLimitBytes)
//
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
// treated as a successful unchanged poll.
+//
+// Postgres-backed polling sources can share NewPostgresQuerySource for generic
+// DB config parsing and query execution. The helper understands:
+// - params.uri
+// - params.username
+// - params.password
+// - params.query
+// - params.query_timeout (optional, default 30s)
+//
+// feedkit does not register a built-in postgres poll driver. Downstream daemons
+// should register domain-specific driver names that call
+// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering,
+// watermark policy, and event construction in their own source types.
package sources
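
To make the HTTP helper concrete, here is a sketch of a polling source built through NewHTTPSource with the new optional overrides; the driver name, source name, URL, and values are illustrative, and feedkit defaults apply when the optional params are omitted:

	src, err := sources.NewHTTPSource("example_http", config.SourceConfig{
		Name:   "example-observations",
		Driver: "example_http",
		Params: map[string]any{
			"url":        "https://api.example.invalid/latest",
			"user_agent": "examplefeeder/1.0",
			// optional overrides
			"http_timeout":                   "15s",
			"http_response_body_limit_bytes": 1 << 20,
		},
	}, "application/json")
	if err != nil {
		return err
	}
	raw, changed, err := src.FetchJSONIfChanged(ctx)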

140
sources/helpers.go Normal file
View File

@@ -0,0 +1,140 @@
package sources
import (
"fmt"
"sort"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// DefaultEventID applies feedkit's default Event.ID policy:
//
// - If upstream provides an ID, use it (trimmed).
// - Otherwise, ID is "<Source>:<EffectiveAt>" when available.
// - If EffectiveAt is unavailable, fall back to "<Source>:<EmittedAt>".
//
// Timestamps are encoded as RFC3339Nano in UTC.
func DefaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
if id := strings.TrimSpace(upstreamID); id != "" {
return id
}
src := strings.TrimSpace(sourceName)
if src == "" {
src = "UNKNOWN_SOURCE"
}
if effectiveAt != nil && !effectiveAt.IsZero() {
return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
}
t := emittedAt.UTC()
if t.IsZero() {
t = time.Now().UTC()
}
return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
}
// SingleEvent constructs, validates, and returns a slice containing exactly one event.
func SingleEvent(
kind event.Kind,
sourceName string,
schema string,
id string,
emittedAt time.Time,
effectiveAt *time.Time,
payload any,
) ([]event.Event, error) {
if emittedAt.IsZero() {
emittedAt = time.Now().UTC()
} else {
emittedAt = emittedAt.UTC()
}
e := event.Event{
ID: id,
Kind: kind,
Source: sourceName,
EmittedAt: emittedAt,
EffectiveAt: effectiveAt,
Schema: schema,
Payload: payload,
}
if err := e.Validate(); err != nil {
return nil, err
}
return []event.Event{e}, nil
}
// ValidateExpectedKinds checks that configured source expected kinds are a subset
// of the kinds advertised by the built source, when the source exposes kind
// metadata. If the source does not advertise kinds, the check is skipped.
func ValidateExpectedKinds(cfg config.SourceConfig, in Input) error {
expectedKinds, err := parseExpectedKinds(cfg.ExpectedKinds())
if err != nil {
return err
}
if len(expectedKinds) == 0 {
return nil
}
advertisedKinds := advertisedSourceKinds(in)
if len(advertisedKinds) == 0 {
return nil
}
for kind := range expectedKinds {
if !advertisedKinds[kind] {
return fmt.Errorf(
"configured expected kind %q not advertised by source (configured=%v advertised=%v)",
kind,
sortedKinds(expectedKinds),
sortedKinds(advertisedKinds),
)
}
}
return nil
}
func parseExpectedKinds(raw []string) (map[event.Kind]bool, error) {
kinds := map[event.Kind]bool{}
for i, k := range raw {
kind, err := event.ParseKind(k)
if err != nil {
return nil, fmt.Errorf("invalid expected kind at index %d (%q): %w", i, k, err)
}
kinds[kind] = true
}
return kinds, nil
}
func advertisedSourceKinds(in Input) map[event.Kind]bool {
if in == nil {
return nil
}
kinds := map[event.Kind]bool{}
if ks, ok := in.(KindsSource); ok {
for _, kind := range ks.Kinds() {
kinds[kind] = true
}
return kinds
}
return nil
}
func sortedKinds(kindSet map[event.Kind]bool) []string {
out := make([]string, 0, len(kindSet))
for kind := range kindSet {
out = append(out, string(kind))
}
sort.Strings(out)
return out
}
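
A daemon-side Poll implementation would typically combine these helpers roughly as follows; ObservationSource, fetch, and the payload fields are hypothetical domain code, and the kind and schema strings are illustrative:

	func (s *ObservationSource) Poll(ctx context.Context) ([]event.Event, error) {
		obs, err := s.fetch(ctx) // domain-specific HTTP/DB work, not part of feedkit
		if err != nil {
			return nil, err
		}
		emittedAt := time.Now().UTC()
		id := sources.DefaultEventID(obs.UpstreamID, s.Name(), &obs.ObservedAt, emittedAt)
		return sources.SingleEvent(
			event.Kind("observation"),
			s.Name(),
			"raw.example.v1",
			id,
			emittedAt,
			&obs.ObservedAt,
			obs,
		)
	}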

112
sources/helpers_test.go Normal file
View File

@@ -0,0 +1,112 @@
package sources
import (
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testInput struct {
name string
}
func (s testInput) Name() string { return s.name }
type testKindsSource struct {
testInput
kinds []event.Kind
}
func (s testKindsSource) Kinds() []event.Kind { return s.kinds }
func TestValidateExpectedKindsSubsetAllowed(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"observation"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestValidateExpectedKindsMismatchFails(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
err := ValidateExpectedKinds(cfg, in)
if err == nil {
t.Fatalf("ValidateExpectedKinds() expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "configured expected kind") {
t.Fatalf("ValidateExpectedKinds() error %q does not include expected message", err)
}
}
func TestValidateExpectedKindsNoMetadataSkipsCheck(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testInput{name: "test"}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestDefaultEventIDUsesUpstreamID(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID(" upstream-id ", "source", nil, emittedAt)
if got != "upstream-id" {
t.Fatalf("DefaultEventID() = %q, want upstream-id", got)
}
}
func TestDefaultEventIDPrefersEffectiveAt(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 987654321, time.FixedZone("x", -6*3600))
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID("", "source", &effectiveAt, emittedAt)
want := "source:" + effectiveAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestDefaultEventIDFallsBackToEmittedAt(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123456789, time.FixedZone("y", 3*3600))
got := DefaultEventID("", "source", nil, emittedAt)
want := "source:" + emittedAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestSingleEventBuildsValidatedSlice(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 0, 0, 0, time.UTC)
emittedAt := time.Date(2026, 3, 28, 15, 0, 0, 0, time.FixedZone("z", -5*3600))
got, err := SingleEvent(
event.Kind("observation"),
"source-a",
"raw.example.v1",
"evt-1",
emittedAt,
&effectiveAt,
map[string]any{"ok": true},
)
if err != nil {
t.Fatalf("SingleEvent() unexpected error: %v", err)
}
if len(got) != 1 {
t.Fatalf("SingleEvent() len = %d, want 1", len(got))
}
if got[0].EmittedAt != emittedAt.UTC() {
t.Fatalf("SingleEvent() emittedAt = %s, want %s", got[0].EmittedAt, emittedAt.UTC())
}
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"net/http"
-"strconv"
"strings"
"sync"
@@ -20,13 +19,14 @@ import (
// setup, and conditional GET validator handling. Concrete daemon sources remain
// responsible for decoding the response body and constructing events.
type HTTPSource struct {
Driver string
Name string
URL string
UserAgent string
Accept string
Conditional bool
-Client *http.Client
+ResponseBodyLimitBytes int64
+Client *http.Client
mu sync.Mutex
validators transport.HTTPValidators
@@ -60,19 +60,42 @@ func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTPSource, error) {
return nil, fmt.Errorf("%s %q: params.user_agent is required", driver, cfg.Name)
}
-conditional, err := parseConditionalParam(cfg)
-if err != nil {
-return nil, err
-}
+conditional := true
+if _, exists := cfg.Params["conditional"]; exists {
+var ok bool
+conditional, ok = cfg.ParamBool("conditional")
+if !ok {
+return nil, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
+}
+}
+timeout := transport.DefaultHTTPTimeout
+if _, exists := cfg.Params["http_timeout"]; exists {
+var ok bool
+timeout, ok = cfg.ParamDuration("http_timeout")
+if !ok || timeout <= 0 {
+return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name)
+}
+}
+bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes
+if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists {
+rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes")
+if !ok || rawLimit <= 0 {
+return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name)
+}
+bodyLimit = int64(rawLimit)
+}
return &HTTPSource{
Driver: driver,
Name: name,
URL: url,
UserAgent: userAgent,
Accept: accept,
Conditional: conditional,
-Client: transport.NewHTTPClient(transport.DefaultHTTPTimeout),
+ResponseBodyLimitBytes: bodyLimit,
+Client: transport.NewHTTPClient(timeout),
}, nil
}
@@ -89,7 +112,12 @@ func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
validators := s.validators
s.mu.Unlock()
-body, changed, next, err := transport.FetchBodyIfChanged(
+bodyLimit := s.ResponseBodyLimitBytes
+if bodyLimit <= 0 {
+bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes
+}
+body, changed, next, err := transport.FetchBodyIfChangedWithLimit(
ctx,
client,
s.URL,
@@ -97,6 +125,7 @@ func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
s.Accept,
s.Conditional,
validators,
+bodyLimit,
)
if err != nil {
return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err)
@@ -121,27 +150,3 @@ func (s *HTTPSource) FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error) {
}
return json.RawMessage(body), true, nil
}
-func parseConditionalParam(cfg config.SourceConfig) (bool, error) {
-raw, ok := cfg.Params["conditional"]
-if !ok || raw == nil {
-return true, nil
-}
-switch v := raw.(type) {
-case bool:
-return v, nil
-case string:
-s := strings.TrimSpace(v)
-if s == "" {
-return false, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
-}
-parsed, err := strconv.ParseBool(s)
-if err != nil {
-return false, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
-}
-return parsed, nil
-default:
-return false, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
-}
-}

View File

@@ -4,9 +4,12 @@ import (
"context"
"net/http"
"net/http/httptest"
+"strings"
"testing"
+"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
+"gitea.maximumdirect.net/ejr/feedkit/transport"
)
func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) {
@@ -39,6 +42,102 @@ func TestNewHTTPSourceRejectsInvalidConditional(t *testing.T) {
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
} }
if !strings.Contains(err.Error(), "params.conditional must be a boolean") {
t.Fatalf("NewHTTPSource() error = %q, want params.conditional must be a boolean", err)
}
}
func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"conditional": false,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Conditional {
t.Fatalf("Conditional = true, want false")
}
}
func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "250ms",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != 250*time.Millisecond {
t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout)
}
}
func TestNewHTTPSourceBodyLimitOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": 12345,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != 12345 {
t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "soon",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": "abc",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
@@ -94,3 +193,68 @@ func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw))
}
}
func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": srv.URL,
"user_agent": "test-agent",
"http_response_body_limit_bytes": 4,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
_, _, err = src.FetchJSONIfChanged(context.Background())
if err == nil {
t.Fatalf("FetchJSONIfChanged() error = nil, want limit error")
}
if !strings.Contains(err.Error(), "response body too large") {
t.Fatalf("FetchJSONIfChanged() error = %q", err)
}
}
func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes {
t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != transport.DefaultHTTPTimeout {
t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout)
}
}

117
sources/postgres.go Normal file
View File

@@ -0,0 +1,117 @@
package sources
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
const defaultPostgresQueryTimeout = 30 * time.Second
type postgresQueryDB interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
return pgconn.Open(ctx, cfg)
}
// PostgresQuerySource is a reusable helper for polling Postgres-backed sources.
//
// It centralizes generic source config parsing and query execution. Concrete
// daemon sources remain responsible for SQL semantics, row scanning, cursoring,
// and event construction.
type PostgresQuerySource struct {
Driver string
Name string
SQL string
QueryTimeout time.Duration
db postgresQueryDB
}
// NewPostgresQuerySource builds a generic Postgres polling helper from
// SourceConfig.
//
// Required params:
// - params.uri
// - params.username
// - params.password
// - params.query
//
// Optional params:
// - params.query_timeout (default 30s)
func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) {
name := strings.TrimSpace(cfg.Name)
if name == "" {
return nil, fmt.Errorf("%s: name is required", driver)
}
if cfg.Params == nil {
return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name)
}
uri, ok := cfg.ParamString("uri")
if !ok {
return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name)
}
username, ok := cfg.ParamString("username")
if !ok {
return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name)
}
password, ok := cfg.ParamString("password")
if !ok {
return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name)
}
query, ok := cfg.ParamString("query")
if !ok {
return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name)
}
queryTimeout := defaultPostgresQueryTimeout
if _, exists := cfg.Params["query_timeout"]; exists {
var ok bool
queryTimeout, ok = cfg.ParamDuration("query_timeout")
if !ok || queryTimeout <= 0 {
return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name)
}
}
db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err)
}
return &PostgresQuerySource{
Driver: driver,
Name: name,
SQL: query,
QueryTimeout: queryTimeout,
db: db,
}, nil
}
func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) {
queryCtx := ctx
if s.QueryTimeout > 0 {
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout {
// We intentionally do not cancel this derived context here because the
// returned rows may still be reading from the database.
queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout)
}
}
rows, err := s.db.QueryContext(queryCtx, s.SQL, args...)
if err != nil {
return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err)
}
return rows, nil
}
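
PostgresQuerySource deliberately stops at Query; row scanning, cursoring, and event construction stay in the daemon (the downstream-polling test below exercises the same pattern). A rough sketch of a daemon-side source built on the helper — the package name, column, kind, and schema string are illustrative, and any additional methods the sources.Input interface requires are omitted:

package weathersources // hypothetical daemon package

import (
	"context"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/event"
	"gitea.maximumdirect.net/ejr/feedkit/sources"
)

type observationSource struct {
	pg *sources.PostgresQuerySource
}

// Kinds advertises this source's event kinds via the plural KindsSource interface.
func (s *observationSource) Kinds() []event.Kind { return []event.Kind{event.Kind("observation")} }

// Poll runs the configured query and maps each row to a domain event.
func (s *observationSource) Poll(ctx context.Context) ([]event.Event, error) {
	rows, err := s.pg.Query(ctx)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var out []event.Event
	for rows.Next() {
		var id string // assumes the configured query selects a single id column
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		out = append(out, event.Event{
			ID:        id,
			Kind:      event.Kind("observation"),
			Source:    s.pg.Name,
			Schema:    "raw.observation.v1",
			EmittedAt: time.Now().UTC(),
			Payload:   map[string]any{"id": id},
		})
	}
	return out, rows.Err()
}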

352
sources/postgres_test.go Normal file
View File

@@ -0,0 +1,352 @@
package sources
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakePostgresQueryDB struct {
queryErr error
lastCtx context.Context
lastQuery string
lastArgs []any
returnRows *sql.Rows
}
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
db.lastCtx = ctx
db.lastQuery = query
db.lastArgs = append([]any(nil), args...)
if db.queryErr != nil {
return nil, db.queryErr
}
return db.returnRows, nil
}
func withPostgresQuerySourceTestState(t *testing.T) {
t.Helper()
oldOpen := openPostgresQueryDB
t.Cleanup(func() {
openPostgresQueryDB = oldOpen
})
}
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
withPostgresQuerySourceTestState(t)
tests := []struct {
name string
params map[string]any
want string
}{
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: tc.params,
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
"query_timeout": "soon",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
withPostgresQuerySourceTestState(t)
db := &fakePostgresQueryDB{}
var gotCfg pgconn.ConnConfig
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
gotCfg = cfg
return db, nil
}
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://db.example.local/feedkit",
"username": "app_user",
"password": "app_pass",
"query": "SELECT * FROM observations",
"query_timeout": "45s",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
if src.Name != "pg-source" {
t.Fatalf("Name = %q, want pg-source", src.Name)
}
if src.QueryTimeout != 45*time.Second {
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
}
if src.SQL != "SELECT * FROM observations" {
t.Fatalf("SQL = %q", src.SQL)
}
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
t.Fatalf("ConnConfig = %+v", gotCfg)
}
}
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
withPostgresQuerySourceTestState(t)
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return nil, errors.New("db unavailable")
}
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx := context.Background()
_, err := src.Query(ctx, "arg1")
if err == nil {
t.Fatalf("Query() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
t.Fatalf("Query() error = %q", err)
}
if db.lastCtx == nil {
t.Fatalf("lastCtx = nil")
}
if _, ok := db.lastCtx.Deadline(); !ok {
t.Fatalf("expected derived deadline on query context")
}
if db.lastQuery != "SELECT 1" {
t.Fatalf("lastQuery = %q", db.lastQuery)
}
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
t.Fatalf("lastArgs = %#v", db.lastArgs)
}
}
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, _ = src.Query(ctx)
if db.lastCtx != ctx {
t.Fatalf("expected source to reuse earlier caller deadline")
}
}
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
withPostgresQuerySourceTestState(t)
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
defer cleanup()
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return db, nil
}
type fakeDownstreamSource struct {
pg *PostgresQuerySource
}
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
rows, err := s.pg.Query(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
var out []event.Event
for rows.Next() {
var eventID string
if err := rows.Scan(&eventID); err != nil {
return nil, err
}
out = append(out, event.Event{
ID: eventID,
Kind: event.Kind("observation"),
Source: s.pg.Name,
Schema: "raw.test.v1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"event_id": eventID},
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT event_id FROM events",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
if err != nil {
t.Fatalf("poll() error = %v", err)
}
if len(events) != 1 {
t.Fatalf("len(events) = %d, want 1", len(events))
}
if events[0].ID != "evt-1" {
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
}
}
var (
rowsDriverMu sync.Mutex
rowsDriverSeen = map[string]bool{}
)
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
t.Helper()
rowsDriverMu.Lock()
if !rowsDriverSeen[driverName] {
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
rowsDriverSeen[driverName] = true
}
rowsDriverMu.Unlock()
db, err := sql.Open(driverName, "")
if err != nil {
t.Fatalf("sql.Open() error = %v", err)
}
return db, func() {
_ = db.Close()
}
}
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
out := make([][]driver.Value, 0, len(in))
for _, row := range in {
copied := append([]driver.Value(nil), row...)
out = append(out, copied)
}
return out
}
type rowsTestDriver struct {
columns []string
rows [][]driver.Value
}
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
}
type rowsTestConn struct {
columns []string
rows [][]driver.Value
}
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *rowsTestConn) Close() error { return nil }
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
}
type rowsTestRows struct {
columns []string
rows [][]driver.Value
idx int
}
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
func (r *rowsTestRows) Close() error { return nil }
func (r *rowsTestRows) Next(dest []driver.Value) error {
if r.idx >= len(r.rows) {
return io.EOF
}
copy(dest, r.rows[r.idx])
r.idx++
return nil
}

View File

@@ -15,9 +15,6 @@ import (
 type PollFactory func(cfg config.SourceConfig) (PollSource, error)
 type StreamFactory func(cfg config.SourceConfig) (StreamSource, error)
-// Factory is the legacy alias for poll source factories.
-type Factory = PollFactory
 type Registry struct {
 	byPollDriver   map[string]PollFactory
 	byStreamDriver map[string]StreamFactory
@@ -30,13 +27,6 @@ func NewRegistry() *Registry {
 	}
 }
-// Register associates a driver name (e.g. "openmeteo_observation") with a factory.
-//
-// The driver string is the "lookup key" used by config.sources[].driver.
-func (r *Registry) Register(driver string, f PollFactory) {
-	r.RegisterPoll(driver, f)
-}
 // RegisterPoll associates a driver name with a polling-source factory.
 func (r *Registry) RegisterPoll(driver string, f PollFactory) {
 	driver = strings.TrimSpace(driver)
@@ -75,11 +65,6 @@ func (r *Registry) RegisterStream(driver string, f StreamFactory) {
 	r.byStreamDriver[driver] = f
 }
-// Build constructs a polling source from a SourceConfig by looking up cfg.Driver.
-func (r *Registry) Build(cfg config.SourceConfig) (PollSource, error) {
-	return r.BuildPoll(cfg)
-}
 // BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver.
 func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) {
 	driver := strings.TrimSpace(cfg.Driver)
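
With the legacy `Register` and `Build` aliases removed, daemon wiring calls the poll-specific names directly. A minimal sketch of the new call shape, assuming a daemon-owned factory (the driver name and the stub factory are illustrative only, not part of feedkit):

package main

import (
	"errors"
	"log"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/sources"
)

func main() {
	reg := sources.NewRegistry()
	// A real daemon would construct and return its own PollSource here;
	// this stub only demonstrates the RegisterPoll/BuildPoll call shape.
	reg.RegisterPoll("example_driver", func(cfg config.SourceConfig) (sources.PollSource, error) {
		return nil, errors.New("example_driver: not implemented")
	})
	if _, err := reg.BuildPoll(config.SourceConfig{Name: "example", Driver: "example_driver"}); err != nil {
		log.Printf("build poll source: %v", err)
	}
}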

View File

@@ -31,24 +31,18 @@ type PollSource interface {
 	Poll(ctx context.Context) ([]event.Event, error)
 }
-// Source is a compatibility alias for the legacy polling-source name.
-type Source = PollSource
 // StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc).
 //
 // Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs.
 // It MUST NOT close out (the scheduler/daemon owns the bus).
+//
+// Stream sources can classify exits by wrapping errors with StreamRetryable or
+// StreamFatal. Plain non-nil errors are treated as retryable by the scheduler.
 type StreamSource interface {
 	Input
 	Run(ctx context.Context, out chan<- event.Event) error
 }
-// KindSource is an optional interface for sources that advertise one "primary" kind.
-// This is legacy-friendly but no longer required.
-type KindSource interface {
-	Kind() event.Kind
-}
 // KindsSource is an optional interface for sources that advertise multiple kinds.
 type KindsSource interface {
 	Kinds() []event.Kind
63
sources/stream_errors.go Normal file
View File

@@ -0,0 +1,63 @@
package sources
import "errors"
type streamRetryableError struct {
err error
}
func (e *streamRetryableError) Error() string {
if e.err == nil {
return "retryable stream error"
}
return e.err.Error()
}
func (e *streamRetryableError) Unwrap() error { return e.err }
type streamFatalError struct {
err error
}
func (e *streamFatalError) Error() string {
if e.err == nil {
return "fatal stream error"
}
return e.err.Error()
}
func (e *streamFatalError) Unwrap() error { return e.err }
// StreamRetryable marks a stream-source exit as retryable.
func StreamRetryable(err error) error {
if err == nil {
return nil
}
return &streamRetryableError{err: err}
}
// StreamFatal marks a stream-source exit as fatal.
func StreamFatal(err error) error {
if err == nil {
return nil
}
return &streamFatalError{err: err}
}
// IsStreamRetryable reports whether err contains a retryable stream marker.
func IsStreamRetryable(err error) bool {
if err == nil {
return false
}
var target *streamRetryableError
return errors.As(err, &target)
}
// IsStreamFatal reports whether err contains a fatal stream marker.
func IsStreamFatal(err error) bool {
if err == nil {
return false
}
var target *streamFatalError
return errors.As(err, &target)
}
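
The markers are intended to be returned from a StreamSource's Run method so the scheduler can decide whether to restart the source or give up. A rough sketch of that usage — the broker-like source below is a stand-in, with msgs and decode taking the place of a real client:

package examplesource // hypothetical daemon package

import (
	"context"
	"errors"
	"fmt"

	"gitea.maximumdirect.net/ejr/feedkit/event"
	"gitea.maximumdirect.net/ejr/feedkit/sources"
)

// brokerSource is illustrative; msgs and decode stand in for a real client.
type brokerSource struct {
	msgs   <-chan []byte
	decode func([]byte) (event.Event, error)
}

func (s *brokerSource) Run(ctx context.Context, out chan<- event.Event) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case raw, ok := <-s.msgs:
			if !ok {
				// The upstream stream closed: worth reconnecting, so mark retryable.
				return sources.StreamRetryable(errors.New("upstream stream closed"))
			}
			evt, err := s.decode(raw)
			if err != nil {
				// Undecodable messages usually mean misconfiguration; stop retrying.
				return sources.StreamFatal(fmt.Errorf("decode message: %w", err))
			}
			out <- evt
		}
	}
}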

View File

@@ -0,0 +1,52 @@
package sources
import (
"errors"
"fmt"
"testing"
)
func TestStreamRetryableWrapsThroughErrorChains(t *testing.T) {
base := errors.New("retry me")
err := fmt.Errorf("outer: %w", StreamRetryable(base))
if !IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = false, want true")
}
if IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamFatalWrapsThroughErrorChains(t *testing.T) {
base := errors.New("fatal")
err := fmt.Errorf("outer: %w", StreamFatal(base))
if !IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = false, want true")
}
if IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamErrorHelpersNil(t *testing.T) {
if StreamRetryable(nil) != nil {
t.Fatalf("StreamRetryable(nil) != nil")
}
if StreamFatal(nil) != nil {
t.Fatalf("StreamFatal(nil) != nil")
}
if IsStreamRetryable(nil) {
t.Fatalf("IsStreamRetryable(nil) = true")
}
if IsStreamFatal(nil) {
t.Fatalf("IsStreamFatal(nil) = true")
}
}

View File

@@ -10,10 +10,10 @@ import (
"time" "time"
) )
// maxResponseBodyBytes is a hard safety limit on HTTP response bodies. // DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies.
// API responses should be small, so this protects us from accidental // API responses should be small, so this protects us from accidental
// or malicious large responses. // or malicious large responses.
const maxResponseBodyBytes = 2 << 21 // 4 MiB const DefaultHTTPResponseBodyLimitBytes int64 = 2 << 21 // 4 MiB
// DefaultHTTPTimeout is the standard timeout used by HTTP sources. // DefaultHTTPTimeout is the standard timeout used by HTTP sources.
// Individual drivers may override this if they have a specific need. // Individual drivers may override this if they have a specific need.
@@ -29,6 +29,10 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
} }
func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) { func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) {
return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes)
}
func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) {
res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "") res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
@@ -39,7 +43,7 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept
return nil, fmt.Errorf("HTTP %s", res.Status) return nil, fmt.Errorf("HTTP %s", res.Status)
} }
return readValidatedBody(res.Body) return readValidatedBody(res.Body, bodyLimitBytes)
} }
// HTTPValidators are cache validators learned from prior successful GET responses. // HTTPValidators are cache validators learned from prior successful GET responses.
@@ -68,6 +72,17 @@ func FetchBodyIfChanged(
url, userAgent, accept string, url, userAgent, accept string,
conditional bool, conditional bool,
validators HTTPValidators, validators HTTPValidators,
) ([]byte, bool, HTTPValidators, error) {
return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes)
}
func FetchBodyIfChangedWithLimit(
ctx context.Context,
client *http.Client,
url, userAgent, accept string,
conditional bool,
validators HTTPValidators,
bodyLimitBytes int64,
) ([]byte, bool, HTTPValidators, error) { ) ([]byte, bool, HTTPValidators, error) {
headerName, headerValue := conditionalHeader(conditional, validators) headerName, headerValue := conditionalHeader(conditional, validators)
@@ -89,7 +104,7 @@ func FetchBodyIfChanged(
} }
} }
b, err := readValidatedBody(res.Body) b, err := readValidatedBody(res.Body, bodyLimitBytes)
if err != nil { if err != nil {
return nil, false, validators, err return nil, false, validators, err
} }
@@ -150,9 +165,13 @@ func refreshValidators(current HTTPValidators, header http.Header) HTTPValidator
return current return current
} }
func readValidatedBody(r io.Reader) ([]byte, error) { func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) {
// Read at most maxResponseBodyBytes + 1 so we can detect overflow. if bodyLimitBytes <= 0 {
limited := io.LimitReader(r, maxResponseBodyBytes+1) bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes
}
// Read at most bodyLimitBytes + 1 so we can detect overflow.
limited := io.LimitReader(r, bodyLimitBytes+1)
b, err := io.ReadAll(limited) b, err := io.ReadAll(limited)
if err != nil { if err != nil {
@@ -163,8 +182,8 @@ func readValidatedBody(r io.Reader) ([]byte, error) {
return nil, fmt.Errorf("empty response body") return nil, fmt.Errorf("empty response body")
} }
if len(b) > maxResponseBodyBytes { if int64(len(b)) > bodyLimitBytes {
return nil, fmt.Errorf("response body too large (>%d bytes)", maxResponseBodyBytes) return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes)
} }
return b, nil return b, nil
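
Existing callers of FetchBody and FetchBodyIfChanged keep the 4 MiB default; sources with a per-source override pass it through the new ...WithLimit variants, and a non-positive limit falls back to the default inside readValidatedBody. A small sketch of the call shape — the import path is assumed (the diff only references this package as transport), and the URL and user agent are placeholders:

package sources

import (
	"context"
	"net/http"

	"gitea.maximumdirect.net/ejr/feedkit/internal/transport" // path assumed
)

// fetchWithOverride is a hypothetical helper showing the widened signature.
func fetchWithOverride(ctx context.Context, client *http.Client, limit int64) ([]byte, error) {
	return transport.FetchBodyWithLimit(
		ctx, client,
		"https://example.invalid/feed", // placeholder URL
		"example-agent",                // placeholder User-Agent
		"application/json",
		limit, // e.g. 512*1024 caps this source at 512 KiB
	)
}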