10 Commits

Author SHA1 Message Date
5c1b28ee0a Added support for Postgres polling sources 2026-03-29 10:53:13 -05:00
247937b65e Upgraded feedkit's handling of stream sources 2026-03-29 08:34:35 -05:00
eb9a7cb349 Refactor feedkit boundaries ahead of v1
Remove global Postgres schema registration in favor of explicit schema-aware sink factory wiring, and update weatherfeeder to register the Postgres sink explicitly. Add optional per-source HTTP timeout and response body limit overrides while keeping feedkit defaults. Remove remaining legacy source/config compatibility surfaces, including singular kind support and old source registry/type aliases, and migrate weatherfeeder sources to plural `Kinds()` metadata. Clean up related docs, tests, and sample config to match the new Postgres, HTTP, and NATS configuration model.
2026-03-28 13:52:48 -05:00
3281368922 Cleaned up documentation and removed stubs and TODOs throughout the application 2026-03-28 13:02:37 -05:00
3ef93faf69 Moved broadly useful helper functions upstream into feedkit 2026-03-28 11:29:09 -05:00
4910440756 Moved common HTTP polling helpers into feedkit and implemented support for ETag and Last-Modified 2026-03-28 09:59:58 -05:00
3b92c2284d Added automatic pruning for configured Postgres sinks 2026-03-28 08:04:15 -05:00
215afe1acf Added a dedupe processor, and moved processor packages under processors/* 2026-03-16 18:17:53 -05:00
4572c53580 Updated sinks to add a functional postgres sink API. 2026-03-16 14:54:57 -05:00
96039f6530 refactor!: introduce generic processors registry and remove normalize registry adapter
- add new `processors` package with canonical `Processor` interface
- add `processors.Registry` with Register/Build/BuildChain factory model
- switch `pipeline.Pipeline` to `[]processors.Processor`
- replace `normalize.Registry` + registry adapter with direct `normalize.Processor`
- remove `normalize/registry.go`
- update root docs to position normalize as one optional processing stage
- add tests for processors registry, normalize processor behavior, and pipeline flow

BREAKING CHANGE:
- `pipeline.Processor` removed; use `processors.Processor`
- `normalize.Registry` and old normalize processor adapter APIs removed
- downstream daemons must update processor wiring to new `processors.Registry`
  and `normalize.NewProcessor(...)`
2026-03-16 13:14:24 -05:00
59 changed files with 5643 additions and 565 deletions

README.md

@@ -1,114 +1,92 @@
# feedkit
`feedkit` provides domain-agnostic plumbing for feed-processing daemons.
`feedkit` is a small Go toolkit for building feed-processing daemons.
A daemon built on feedkit typically:
- ingests upstream input (polling APIs or consuming streams)
It gives you the reusable plumbing around collection, processing, routing, and
emission, while leaving domain concepts, schemas, and application wiring in
your daemon. The intended shape is a family of sibling applications such as
`weatherfeeder`, `newsfeeder`, or `earthquakefeeder` that all share the same
infrastructure patterns without sharing domain logic.
## What It Does
A daemon built on `feedkit` typically:
- ingests upstream input by polling HTTP APIs or consuming streams
- emits domain-agnostic `event.Event` values
- applies optional processing (normalization, dedupe, policy)
- routes events to sinks (stdout, NATS, files, databases, etc.)
- optionally processes those events with stages like dedupe or normalization
- routes events to one or more sinks such as stdout, NATS, or Postgres
Conceptually, the pipeline is:
`Collect -> Process -> Route -> Emit`
## Philosophy
feedkit is not a framework. It provides small composable packages and leaves
lifecycle, domain schemas, and domain-specific validation in your daemon.
`feedkit` is intentionally not a framework.
## Conceptual pipeline
It does not try to own:
- your domain payload schemas
- your domain event kinds
- your daemon lifecycle or `main.go`
- your observability stack or deployment model
Collect -> Normalize (optional) -> Policy -> Route -> Emit
Instead, it provides small composable packages that are easy to wire together in
different daemons.
| Stage | Package(s) |
|---|---|
| Collect | `sources`, `scheduler` |
| Normalize | `normalize` (optional in `pipeline`) |
| Policy | `pipeline` |
| Route | `dispatch` |
| Emit | `sinks` |
| Configure | `config` |
## When To Use It
## Core packages
`feedkit` is a good fit when you want:
- multiple small ingestion daemons with shared infrastructure patterns
- clear separation between raw upstream payloads and normalized canonical models
- reusable routing and sink behavior across domains
- strong config and event-envelope conventions without centralizing domain rules
### `config`
It is a poor fit if you want a monolithic framework that dictates application
structure end-to-end.
Loads YAML config with strict decoding and domain-agnostic validation.
## Built-In Capabilities
`SourceConfig` supports both source modes:
- `mode: poll` requires `every`
- `mode: stream` forbids `every`
- omitted `mode` means auto (inferred from the registered driver type)
`feedkit` currently includes:
- strict YAML config loading and validation
- polling and streaming source abstractions
- scheduler orchestration for configured sources and supervised stream workers
- optional pipeline processors
- built-in dedupe and normalization processors
- route compilation and sink fanout
- built-in sinks for `stdout`, `nats`, and `postgres`
It also supports optional expected source kinds:
- `kinds: ["observation", "alert"]` (preferred)
- `kind: "observation"` (legacy fallback)
The Postgres sink is intentionally split between feedkit-owned infrastructure
and daemon-owned schema mapping. `feedkit` manages connection setup, DDL,
writes, and pruning; downstream applications define the schema and event mapper.
### `event`
## Typical Wiring
Defines the domain-agnostic event envelope (`event.Event`) used across the system.
### `sources`
Defines source interfaces and driver registry:
```go
type Input interface {
	Name() string
}

type PollSource interface {
	Input
	Poll(ctx context.Context) ([]event.Event, error)
}

type StreamSource interface {
	Input
	Run(ctx context.Context, out chan<- event.Event) error
}
```
Notes:
- a poll can emit `0..N` events
- stream sources emit events continuously
- a single source may emit multiple event kinds
- driver implementations live in downstream daemons and are registered via `sources.Registry`
### `scheduler`
Runs one goroutine per source job:
- poll sources: cadence driven (`every` + jitter)
- stream sources: continuous run loop
### `pipeline`
Optional processing chain between collection and dispatch.
Processors can transform, drop, or reject events.
### `normalize`
Optional normalization package (already implemented). Typical use: sources emit raw
payload events, then normalize to canonical schemas in a pipeline stage.
### `dispatch`
Compiles routes and fans out events to sinks with per-sink queue/worker isolation.
### `sinks`
Defines sink interface and sink registry. Built-ins include `stdout` and `nats`, with
additional sink implementations at varying maturity.
## Typical wiring
At a high level, a daemon built on `feedkit` does this:
1. Load config.
2. Register/build sources from `cfg.Sources`.
3. Register/build sinks from `cfg.Sinks`.
4. Compile routes.
5. Start scheduler (`sources -> bus`).
6. Start dispatcher (`bus -> pipeline -> sinks`).
2. Register domain-specific source drivers.
3. Register built-in and/or custom sinks.
4. Build sources, sinks, and optional processor chain from config.
5. Compile routes.
6. Start the scheduler and dispatcher.
## Non-goals
The package docs are the better source of truth for code-level details. In
particular, each subpackage `doc.go` describes its external API surface and any
optional helper APIs in `helpers.go`.
feedkit intentionally does not:
- define domain payload schemas
- enforce domain-specific event kinds
- own application lifecycle
- prescribe observability stack choices
## Package Layout
The major packages are:
- `config`: config loading and validation
- `event`: the domain-agnostic event envelope
- `sources`: source interfaces and reusable source helpers
- `scheduler`: source execution and cadence management
- `processors`: processor interfaces and registry
- `processors/dedupe`: built-in in-memory dedupe processor
- `processors/normalize`: built-in normalization processor and helpers
- `pipeline`: optional processor chain
- `dispatch`: route compilation and fanout
- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers
The root package docs in `doc.go` provide a concise package-by-package map for
Go documentation consumers.
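
The `Collect -> Process -> Route -> Emit` flow described in the README can be illustrated with a minimal, self-contained sketch. It does not use feedkit's packages: `Event`, `Processor`, `Sink`, `runOnce`, and `demo` are hypothetical stand-ins chosen for illustration, and the context argument is omitted for brevity. The only convention borrowed from this diff is that a processor returning a nil event drops it.

```go
package main

import (
	"fmt"
	"strings"
)

// Event is a trimmed, hypothetical stand-in for feedkit's event.Event.
type Event struct {
	ID   string
	Kind string
}

// Processor returns nil to drop an event, mirroring the nil-means-drop
// convention used by the processors in this diff.
type Processor func(in Event) (*Event, error)

// Sink is a hypothetical emit target.
type Sink func(e Event)

// runOnce processes each collected event, then routes survivors by kind.
func runOnce(collected []Event, chain []Processor, routes map[string][]Sink) error {
	for _, e := range collected {
		cur, dropped := e, false
		for _, p := range chain {
			out, err := p(cur)
			if err != nil {
				return fmt.Errorf("process %s: %w", e.ID, err)
			}
			if out == nil {
				dropped = true
				break
			}
			cur = *out
		}
		if dropped {
			continue
		}
		for _, sink := range routes[cur.Kind] {
			sink(cur) // emit to every sink routed for this kind
		}
	}
	return nil
}

// demo wires one processor and one route and returns the emitted event IDs.
func demo() ([]string, error) {
	var emitted []string
	dropBlankKind := func(in Event) (*Event, error) {
		if strings.TrimSpace(in.Kind) == "" {
			return nil, nil
		}
		return &in, nil
	}
	routes := map[string][]Sink{
		"observation": {func(e Event) { emitted = append(emitted, e.ID) }},
	}
	collected := []Event{
		{ID: "evt-1", Kind: "observation"},
		{ID: "evt-2", Kind: " "},
		{ID: "evt-3", Kind: "alert"},
	}
	err := runOnce(collected, []Processor{dropBlankKind}, routes)
	return emitted, err
}

func main() {
	emitted, err := demo()
	fmt.Println(emitted, err)
}
```

In this sketch only `evt-1` reaches a sink: `evt-2` is dropped by the processor and `evt-3` has no matching route.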


@@ -69,39 +69,31 @@ type SourceConfig struct {
// If set, it describes the expected emitted event kinds for this source.
Kinds []string `yaml:"kinds"`
// Kind is the legacy singular form. Prefer "kinds".
// If both kind and kinds are set, validation fails.
Kind string `yaml:"kind"`
// Params are driver-specific settings (URL, headers, station IDs, API keys, etc.).
// The driver implementation is responsible for reading/validating these.
Params map[string]any `yaml:"params"`
}
// ExpectedKinds returns normalized expected kinds from config.
// "kinds" takes precedence; "kind" is used as a legacy fallback.
func (cfg SourceConfig) ExpectedKinds() []string {
if len(cfg.Kinds) > 0 {
out := make([]string, 0, len(cfg.Kinds))
for _, k := range cfg.Kinds {
k = strings.TrimSpace(k)
if k == "" {
continue
}
out = append(out, k)
out := make([]string, 0, len(cfg.Kinds))
for _, k := range cfg.Kinds {
k = strings.TrimSpace(k)
if k == "" {
continue
}
return out
out = append(out, k)
}
if k := strings.TrimSpace(cfg.Kind); k != "" {
return []string{k}
if len(out) == 0 {
return nil
}
return nil
return out
}
// SinkConfig describes one output sink adapter.
type SinkConfig struct {
Name string `yaml:"name"`
Driver string `yaml:"driver"` // "stdout", "file", "postgres", "rabbitmq", ...
Driver string `yaml:"driver"` // "stdout", "nats", "postgres", ...
Params map[string]any `yaml:"params"` // sink-specific settings
}
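
The behavior of `ExpectedKinds` after the singular `kind` fallback is removed can be sketched as a standalone function. The name `expectedKinds` and the bare `[]string` parameter are illustrative assumptions; the body follows the post-change lines in the hunk above.

```go
package main

import (
	"fmt"
	"strings"
)

// expectedKinds is a standalone sketch of the post-change logic: trim each
// entry, skip blanks, and return nil when nothing remains.
func expectedKinds(kinds []string) []string {
	out := make([]string, 0, len(kinds))
	for _, k := range kinds {
		k = strings.TrimSpace(k)
		if k == "" {
			continue
		}
		out = append(out, k)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func main() {
	fmt.Printf("%q\n", expectedKinds([]string{" observation ", "forecast"}))
	fmt.Println(expectedKinds([]string{" ", ""}) == nil)
}
```

Returning nil rather than an empty slice keeps "no expected kinds" distinguishable with a plain nil check, which matches the "empty kinds" test case in the next file.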


@@ -12,20 +12,12 @@ func TestSourceConfigExpectedKinds(t *testing.T) {
want []string
}{
{
name: "plural kinds preferred",
name: "plural kinds normalized",
cfg: SourceConfig{
Kinds: []string{" observation ", "forecast"},
Kind: "alert",
},
want: []string{"observation", "forecast"},
},
{
name: "legacy singular fallback",
cfg: SourceConfig{
Kind: " alert ",
},
want: []string{"alert"},
},
{
name: "empty kinds",
cfg: SourceConfig{},


@@ -105,13 +105,7 @@ func (c *Config) Validate() error {
}
}
// Kind/Kinds (optional)
if s.Kind != "" && len(s.Kinds) > 0 {
m.Add(fieldErr(path+".kind", `cannot be set when "kinds" is provided (use only "kinds")`))
}
if s.Kind != "" && strings.TrimSpace(s.Kind) == "" {
m.Add(fieldErr(path+".kind", "cannot be blank (omit it entirely, or provide a non-empty string)"))
}
// Kinds (optional)
for j, k := range s.Kinds {
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
if strings.TrimSpace(k) == "" {
@@ -141,7 +135,7 @@ func (c *Config) Validate() error {
}
if strings.TrimSpace(s.Driver) == "" {
m.Add(fieldErr(path+".driver", "is required (stdout|file|postgres|rabbitmq|...)"))
m.Add(fieldErr(path+".driver", "is required (stdout|nats|postgres|...)"))
}
// Params can be nil; that's fine.


@@ -114,31 +114,6 @@ func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) {
}
}
func TestValidate_SourceKindAndKindsConflict(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{
Name: "src1",
Driver: "driver1",
Every: Duration{Duration: time.Minute},
Kind: "observation",
Kinds: []string{"forecast"},
},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].kind`) {
t.Fatalf("expected error to mention sources[0].kind, got: %v", err)
}
}
func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{

doc.go

@@ -1,108 +1,57 @@
// Package feedkit provides domain-agnostic plumbing for feed-processing daemons.
// Package feedkit provides a high-level map of the feedkit package set.
//
// A feed daemon ingests upstream input, turns it into event.Event values, applies
// optional processing, and emits to sinks.
// Most real applications do not import the root package directly. Instead, they
// compose the subpackages that handle configuration, collection, processing,
// routing, and sinks.
//
// Conceptual flow:
// The usual flow through feedkit is:
//
// Collect -> Normalize (optional) -> Policy -> Route -> Emit
// Collect -> Process -> Route -> Emit
//
// In feedkit this maps to:
//
// Collect: sources + scheduler
// Normalize: normalize (optional pipeline stage)
// Policy: pipeline
// Route: dispatch
// Emit: sinks
// Config: config
//
// feedkit intentionally does not define domain payload schemas or domain-specific
// validation rules. Those belong in each concrete daemon.
//
// Public packages
// That flow maps to packages like this:
//
// - config
// YAML config loading/validation (strict decode + domain-agnostic checks).
//
// SourceConfig supports both polling and streaming sources:
//
// - mode: "poll" | "stream" | omitted (auto by driver type)
//
// - every: poll interval (required for mode="poll")
//
// - kinds: optional expected emitted kinds
//
// - kind: legacy singular fallback
//
// - params: driver-specific settings
// Loads and validates daemon config. This package owns domain-agnostic
// config shape and consistency checks.
//
// - event
// Domain-agnostic event envelope (ID, Kind, Source, EmittedAt, Schema, Payload).
// Defines the event.Event envelope shared across sources, processors,
// dispatch, and sinks.
//
// - sources
// Source abstractions and source-driver registry.
//
// There are two source interfaces:
//
// - PollSource: Poll(ctx) ([]event.Event, error)
//
// - StreamSource: Run(ctx, out) error
//
// Both share Input{Name()}. A source may emit 0..N events per poll/run step,
// and may emit multiple event kinds.
// Defines polling and streaming source interfaces, the source registry, and
// reusable source helpers.
//
// - scheduler
// Runs one goroutine per job:
// Runs configured sources on a cadence and supervises long-lived stream
// workers with restart/fatal handling.
//
// - PollSource jobs run on Every (+ jitter)
// - processors
// Defines the generic processor interface and registry used to build
// ordered processor chains.
//
// - StreamSource jobs run continuously
// - processors/dedupe
// Built-in in-memory dedupe processor keyed by Event.ID.
//
// - processors/normalize
// Built-in normalization processor plus helper APIs for raw-to-canonical
// event mapping.
//
// - pipeline
// Processor chain between scheduler and dispatch.
// Processors can transform, drop, or reject events.
//
// - normalize
// Optional pipeline processor for raw->canonical mapping.
// If no normalizer matches, the event passes through unchanged by default.
// Applies an ordered processor chain between collection and dispatch.
//
// - dispatch
// Routes events to sinks and isolates slow sinks via per-sink queues/workers.
// Compiles routes and fans events out to sinks with per-sink isolation.
//
// - sinks
// Sink abstractions + sink registry.
// Defines sink interfaces, the sink registry, schema-free built-in sinks,
// and explicit Postgres factory helpers.
//
// Typical wiring (daemon main.go)
// feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds,
// upstream-specific parsing, and daemon lifecycle remain the responsibility of
// each concrete application.
//
// 1. Load config.
// 2. Register source drivers and build sources from config.Sources.
// 3. Register sink drivers and build sinks from config.Sinks.
// 4. Compile routes.
// 5. Start scheduler (sources -> bus) and dispatcher (bus -> pipeline -> sinks).
//
// Sketch:
//
// cfg, _ := config.Load("config.yml")
// srcReg := sources.NewRegistry()
// // domain registers poll/stream drivers...
//
// var jobs []scheduler.Job
// for _, sc := range cfg.Sources {
// src, _ := srcReg.BuildInput(sc)
// jobs = append(jobs, scheduler.Job{
// Source: src,
// Every: sc.Every.Duration,
// })
// }
//
// bus := make(chan event.Event, 256)
// s := &scheduler.Scheduler{Jobs: jobs, Out: bus, Logf: logf}
// // start dispatcher similarly...
//
// # Context and cancellation
//
// All blocking work should honor context cancellation:
// - source polling/streaming I/O
// - sink consumption
// - any expensive processor work
// For repository-level overview and usage narrative, see README.md. For
// code-level details, each subpackage doc.go is the source of truth for that
// package's public API surface and optional helpers.
package feedkit

go.mod

@@ -3,6 +3,7 @@ module gitea.maximumdirect.net/ejr/feedkit
go 1.22
require (
github.com/lib/pq v1.10.9
github.com/nats-io/nats.go v1.34.0
gopkg.in/yaml.v3 v3.0.1
)

go.sum

@@ -1,5 +1,7 @@
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/nats-io/nats.go v1.34.0 h1:fnxnPCNiwIG5w08rlMcEKTUw4AV/nKyGCOJE8TdhSPk=
github.com/nats-io/nats.go v1.34.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=


@@ -0,0 +1,67 @@
package postgres

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"strings"
	"time"

	_ "github.com/lib/pq"
)

const initTimeout = 5 * time.Second

var (
	sqlOpen = sql.Open
	pingDB  = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) }
)

// ConnConfig describes the minimal connection settings shared by feedkit's
// Postgres readers and writers.
type ConnConfig struct {
	URI      string
	Username string
	Password string
}

// BuildDSN validates a Postgres URI and injects credentials into it.
func BuildDSN(cfg ConnConfig) (string, error) {
	u, err := url.Parse(strings.TrimSpace(cfg.URI))
	if err != nil {
		return "", fmt.Errorf("invalid uri: %w", err)
	}
	if u.Scheme == "" {
		return "", fmt.Errorf("invalid uri: missing scheme")
	}
	if u.Host == "" {
		return "", fmt.Errorf("invalid uri: missing host")
	}
	u.User = url.UserPassword(cfg.Username, cfg.Password)
	return u.String(), nil
}

// Open builds a DSN, opens a database handle, and verifies connectivity with a
// bounded ping before returning the handle.
func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) {
	dsn, err := BuildDSN(cfg)
	if err != nil {
		return nil, err
	}

	db, err := sqlOpen("postgres", dsn)
	if err != nil {
		return nil, err
	}

	pingCtx, cancel := context.WithTimeout(ctx, initTimeout)
	defer cancel()
	if err := pingDB(pingCtx, db); err != nil {
		_ = db.Close()
		return nil, err
	}
	return db, nil
}
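
The credential injection performed by `BuildDSN` relies on `net/url` to percent-encode reserved characters in the userinfo section. The sketch below repeats the same steps in a standalone program so that the escaping is visible; `injectCredentials` and the sample password are illustrative assumptions, not part of the change.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// injectCredentials repeats the BuildDSN steps from the diff: trim, parse,
// require a scheme and host, then set the userinfo.
func injectCredentials(uri, user, pass string) (string, error) {
	u, err := url.Parse(strings.TrimSpace(uri))
	if err != nil {
		return "", fmt.Errorf("invalid uri: %w", err)
	}
	if u.Scheme == "" {
		return "", fmt.Errorf("invalid uri: missing scheme")
	}
	if u.Host == "" {
		return "", fmt.Errorf("invalid uri: missing host")
	}
	// url.UserPassword stores the raw values; u.String() escapes them.
	u.User = url.UserPassword(user, pass)
	return u.String(), nil
}

func main() {
	dsn, err := injectCredentials(
		" postgres://db.example.local:5432/feedkit?sslmode=disable ",
		"app_user",
		"p@ss", // hypothetical password containing a reserved character
	)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(dsn)
}
```

Because the `@` in the password is escaped as `%40`, it cannot be confused with the separator between userinfo and host. Note that any credentials already embedded in the input URI are overwritten.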


@@ -0,0 +1,165 @@
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"net/url"
"strings"
"sync"
"testing"
)
func withPostgresPackageTestState(t *testing.T) {
t.Helper()
oldSQLOpen := sqlOpen
oldPingDB := pingDB
t.Cleanup(func() {
sqlOpen = oldSQLOpen
pingDB = oldPingDB
})
}
func TestBuildDSNInjectsCredentials(t *testing.T) {
dsn, err := BuildDSN(ConnConfig{
URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ",
Username: "app_user",
Password: "app_pass",
})
if err != nil {
t.Fatalf("BuildDSN() error = %v", err)
}
u, err := url.Parse(dsn)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
if u.User == nil || u.User.Username() != "app_user" {
t.Fatalf("username = %q, want app_user", u.User.Username())
}
pass, ok := u.User.Password()
if !ok || pass != "app_pass" {
t.Fatalf("password = %q, want app_pass", pass)
}
}
func TestBuildDSNRejectsInvalidURI(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "invalid uri") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingScheme(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing scheme") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingHost(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing host") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestOpenPropagatesOpenFailure(t *testing.T) {
withPostgresPackageTestState(t)
sqlOpen = func(_, _ string) (*sql.DB, error) {
return nil, errors.New("open failed")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "open failed") {
t.Fatalf("Open() error = %q", err)
}
}
func TestOpenPropagatesPingFailure(t *testing.T) {
withPostgresPackageTestState(t)
const driverName = "feedkit_internal_postgres_ping_fail"
registerPingTestDriver(driverName, errors.New("ping failed"))
sqlOpen = func(_, _ string) (*sql.DB, error) {
return sql.Open(driverName, "")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "ping failed") {
t.Fatalf("Open() error = %q", err)
}
}
var (
pingDriverMu sync.Mutex
pingDriverSeen = map[string]bool{}
)
func registerPingTestDriver(name string, pingErr error) {
pingDriverMu.Lock()
defer pingDriverMu.Unlock()
if pingDriverSeen[name] {
return
}
sql.Register(name, &pingTestDriver{pingErr: pingErr})
pingDriverSeen[name] = true
}
type pingTestDriver struct {
pingErr error
}
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
return &pingTestConn{pingErr: d.pingErr}, nil
}
type pingTestConn struct {
pingErr error
}
func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *pingTestConn) Close() error { return nil }
func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }
func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
return &pingTestRows{}, nil
}
type pingTestRows struct{}
func (r *pingTestRows) Columns() []string { return []string{"ok"} }
func (r *pingTestRows) Close() error { return nil }
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }


@@ -1,17 +0,0 @@
// Package normalize provides an OPTIONAL normalization hook for feedkit pipelines.
//
// Motivation:
// Many daemons have sources that:
// 1. fetch raw upstream data (often JSON), and
// 2. transform it into a domain's normalized payload format.
//
// Doing both steps inside Source.Poll works, but tends to make sources large and
// encourages duplication (unit conversions, common mapping helpers, etc.).
//
// This package lets a source emit a "raw" event (e.g., Schema="raw.openweather.current.v1",
// Payload=json.RawMessage), and then a normalization processor can convert it into a
// normalized event (e.g., Schema="weather.observation.v1", Payload=WeatherObservation{}).
//
// Key property: normalization is optional.
// If no registered Normalizer matches an event, it passes through unchanged.
package normalize


@@ -1,140 +0,0 @@
package normalize
import (
"context"
"fmt"
"sync"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Registry holds a set of Normalizers and selects one for a given event.
//
// Selection rule (simple + predictable):
// - iterate in registration order
// - the FIRST Normalizer whose Match(e) returns true is used
//
// If none match, the event passes through unchanged.
//
// Why "first match wins"?
// Normalization is usually a single mapping step from a raw schema/version into
// a normalized schema/version. If you want multiple transformation steps,
// model them as multiple pipeline processors (which feedkit already supports).
type Registry struct {
mu sync.RWMutex
ns []Normalizer
}
// Register adds a normalizer to the registry.
//
// Register panics if n is nil; this is a programmer error and should fail fast.
func (r *Registry) Register(n Normalizer) {
if n == nil {
panic("normalize.Registry.Register: normalizer cannot be nil")
}
r.mu.Lock()
defer r.mu.Unlock()
r.ns = append(r.ns, n)
}
// Normalize finds the first matching Normalizer and applies it.
//
// If no normalizer matches, it returns the input event unchanged.
//
// If a normalizer returns (nil, nil), the event is dropped.
func (r *Registry) Normalize(ctx context.Context, in event.Event) (*event.Event, error) {
if r == nil {
// Nil registry is a valid "feature off" state.
out := in
return &out, nil
}
r.mu.RLock()
ns := append([]Normalizer(nil), r.ns...) // copy for safe iteration outside lock
r.mu.RUnlock()
for _, n := range ns {
if n == nil {
// Shouldn't happen (Register panics), but guard anyway.
continue
}
if !n.Match(in) {
continue
}
out, err := n.Normalize(ctx, in)
if err != nil {
return nil, fmt.Errorf("normalize: normalizer failed: %w", err)
}
// out may be nil to signal "drop".
return out, nil
}
// No match: pass through unchanged.
out := in
return &out, nil
}
// Processor adapts a Registry into a pipeline Processor.
//
// It implements:
//
// Process(ctx context.Context, in event.Event) (*event.Event, error)
//
// which matches feedkit/pipeline.Processor.
//
// Optionality:
// - If Registry is nil, Processor becomes a no-op pass-through.
// - If Registry has no matching normalizer for an event, that event passes through unchanged.
type Processor struct {
Registry *Registry
// If true, events that do not match any normalizer cause an error.
// Default is false (pass-through), which is the behavior you asked for.
RequireMatch bool
}
// Process implements the pipeline.Processor interface.
func (p Processor) Process(ctx context.Context, in event.Event) (*event.Event, error) {
// "Feature off": no registry means no normalization.
if p.Registry == nil {
out := in
return &out, nil
}
out, err := p.Registry.Normalize(ctx, in)
if err != nil {
return nil, err
}
if out == nil {
// Dropped by normalization policy.
return nil, nil
}
if p.RequireMatch {
// Detect "no-op pass-through due to no match" by checking whether a match existed.
// We do this with a cheap second pass to avoid changing Normalize()'s signature.
// (This is rare to enable; correctness/clarity > micro-optimization.)
if !p.Registry.hasMatch(in) {
return nil, fmt.Errorf("normalize: no normalizer matched event (id=%s kind=%s source=%s schema=%q)",
in.ID, in.Kind, in.Source, in.Schema)
}
}
return out, nil
}
func (r *Registry) hasMatch(in event.Event) bool {
if r == nil {
return false
}
r.mu.RLock()
defer r.mu.RUnlock()
for _, n := range r.ns {
if n != nil && n.Match(in) {
return true
}
}
return false
}


@@ -1,5 +0,0 @@
package pipeline
// Placeholder for dedupe processor:
// - key by Event.ID or computed key
// - in-memory store first; later optional Postgres-backed


@@ -5,15 +5,11 @@ import (
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
// Processor can mutate/drop events (dedupe, rate-limit, normalization tweaks).
type Processor interface {
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
}
type Pipeline struct {
Processors []Processor
Processors []processors.Processor
}
func (p *Pipeline) Process(ctx context.Context, e event.Event) (*event.Event, error) {

pipeline/pipeline_test.go

@@ -0,0 +1,115 @@
package pipeline
import (
"context"
"fmt"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
type procFunc func(context.Context, event.Event) (*event.Event, error)
func (f procFunc) Process(ctx context.Context, in event.Event) (*event.Event, error) {
return f(ctx, in)
}
func TestPipelineProcessSequentialOrder(t *testing.T) {
var gotOrder []string
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
gotOrder = append(gotOrder, "first")
out := in
out.Schema = "stage.one.v1"
return &out, nil
}),
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
gotOrder = append(gotOrder, "second")
if in.Schema != "stage.one.v1" {
return nil, fmt.Errorf("expected schema from first stage, got %q", in.Schema)
}
out := in
out.Schema = "stage.two.v1"
return &out, nil
}),
},
}
out, err := p.Process(context.Background(), validEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected output event, got nil")
}
if out.Schema != "stage.two.v1" {
t.Fatalf("unexpected output schema: %q", out.Schema)
}
if strings.Join(gotOrder, ",") != "first,second" {
t.Fatalf("unexpected processor order: %v", gotOrder)
}
}
func TestPipelineProcessInvalidInput(t *testing.T) {
p := &Pipeline{}
_, err := p.Process(context.Background(), event.Event{})
if err == nil {
t.Fatalf("expected input validation error")
}
if !strings.Contains(err.Error(), "invalid input event") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestPipelineProcessDrop(t *testing.T) {
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(context.Context, event.Event) (*event.Event, error) {
return nil, nil
}),
},
}
out, err := p.Process(context.Background(), validEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out != nil {
t.Fatalf("expected nil output for dropped event, got %#v", out)
}
}
func TestPipelineProcessInvalidOutput(t *testing.T) {
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
out := in
out.Payload = nil
return &out, nil
}),
},
}
_, err := p.Process(context.Background(), validEvent())
if err == nil {
t.Fatalf("expected output validation error")
}
if !strings.Contains(err.Error(), "invalid output event") {
t.Fatalf("unexpected error: %v", err)
}
}
func validEvent() event.Event {
return event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"ok": true},
}
}


@@ -1,5 +0,0 @@
package pipeline
// Placeholder for rate limit processor:
// - per source/kind sink routing limits
// - cooldown windows

processors/dedupe/doc.go

@@ -0,0 +1,28 @@
// Package dedupe provides a default in-memory LRU deduplication processor.
//
// The processor keys strictly by event.Event.ID:
// - first-seen IDs pass through
// - repeated IDs are dropped
//
// The in-memory seen-ID set is bounded by a required maxEntries capacity.
// When capacity is exceeded, the least recently used ID is evicted.
//
// Typical registry wiring:
//
// ```go
// reg := processors.NewRegistry()
// reg.Register("dedupe", dedupe.Factory(10_000))
//
// reg.Register("normalize", func() (processors.Processor, error) {
// return normalize.NewProcessor(myNormalizers, false), nil
// })
//
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
//
// if err != nil {
// // handle wiring error
// }
//
// p := &pipeline.Pipeline{Processors: chain}
// ```
package dedupe
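
The registry wiring in the doc comment above follows a name-to-factory model: `Register` stores a factory under a name, and `BuildChain` resolves an ordered list of names into processor instances. The `processors` package itself is not shown in this part of the diff, so the following is only a sketch of how such a registry could work. The method names and the factory signature come from the diff; the internals, the trimmed `Event` type, and the `passthrough` processor are assumptions.

```go
package main

import (
	"context"
	"fmt"
)

// Event is a trimmed, hypothetical stand-in for event.Event.
type Event struct{ ID string }

// Processor matches the Process signature used throughout this diff.
type Processor interface {
	Process(ctx context.Context, in Event) (*Event, error)
}

// Factory matches the func() (Processor, error) shape used in the diff.
type Factory func() (Processor, error)

// Registry maps processor names to factories (internals are assumed).
type Registry struct {
	factories map[string]Factory
}

func NewRegistry() *Registry {
	return &Registry{factories: map[string]Factory{}}
}

// Register stores a factory under a name; a later call with the same name wins.
func (r *Registry) Register(name string, f Factory) {
	r.factories[name] = f
}

// BuildChain instantiates one processor per name, preserving the given order.
func (r *Registry) BuildChain(names []string) ([]Processor, error) {
	chain := make([]Processor, 0, len(names))
	for _, name := range names {
		f, ok := r.factories[name]
		if !ok {
			return nil, fmt.Errorf("processors: unknown processor %q", name)
		}
		p, err := f()
		if err != nil {
			return nil, fmt.Errorf("processors: build %q: %w", name, err)
		}
		chain = append(chain, p)
	}
	return chain, nil
}

// passthrough is a hypothetical no-op processor used only for the demo.
type passthrough struct{}

func (passthrough) Process(_ context.Context, in Event) (*Event, error) {
	return &in, nil
}

func main() {
	reg := NewRegistry()
	reg.Register("noop", func() (Processor, error) { return passthrough{}, nil })

	chain, err := reg.BuildChain([]string{"noop", "noop"})
	fmt.Println(len(chain), err)

	_, err = reg.BuildChain([]string{"missing"})
	fmt.Println(err)
}
```

Building from factories rather than shared instances means each chain gets its own processor state, which matters for stateful processors such as dedupe.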


@@ -0,0 +1,89 @@
package dedupe
import (
"container/list"
"context"
"fmt"
"strings"
"sync"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
// Processor drops duplicate events by Event.ID using an in-memory LRU.
type Processor struct {
maxEntries int
mu sync.Mutex
order *list.List // most-recent at front, least-recent at back
byID map[string]*list.Element // id -> list element (element.Value is string id)
}
var _ processors.Processor = (*Processor)(nil)
// NewProcessor constructs a dedupe processor with a required max entry count.
func NewProcessor(maxEntries int) (*Processor, error) {
if maxEntries <= 0 {
return nil, fmt.Errorf("dedupe: maxEntries must be > 0, got %d", maxEntries)
}
return &Processor{
maxEntries: maxEntries,
order: list.New(),
byID: make(map[string]*list.Element, maxEntries),
}, nil
}
// Factory returns a processors.Factory that constructs Processor instances.
func Factory(maxEntries int) processors.Factory {
return func() (processors.Processor, error) {
return NewProcessor(maxEntries)
}
}
// Process implements processors.Processor.
func (p *Processor) Process(_ context.Context, in event.Event) (*event.Event, error) {
if p == nil {
return nil, fmt.Errorf("dedupe: processor is nil")
}
if p.maxEntries <= 0 {
return nil, fmt.Errorf("dedupe: processor maxEntries must be > 0")
}
id := strings.TrimSpace(in.ID)
if id == "" {
return nil, fmt.Errorf("dedupe: event ID is required")
}
p.mu.Lock()
if p.order == nil || p.byID == nil {
p.mu.Unlock()
return nil, fmt.Errorf("dedupe: processor is not initialized")
}
if elem, exists := p.byID[id]; exists {
p.order.MoveToFront(elem)
p.mu.Unlock()
return nil, nil
}
elem := p.order.PushFront(id)
p.byID[id] = elem
if p.order.Len() > p.maxEntries {
oldest := p.order.Back()
if oldest != nil {
p.order.Remove(oldest)
if oldestID, ok := oldest.Value.(string); ok {
delete(p.byID, oldestID)
}
}
}
p.mu.Unlock()
out := in
return &out, nil
}
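
To make the promotion and eviction order concrete, here is a self-contained sketch of the same LRU idea using plain string IDs instead of event.Event. The names `seenSet`, `newSeenSet`, and `firstSeen` are illustrative stand-ins, not part of feedkit, and the sketch leaves out the mutex and input validation that the real Processor has.

```go
package main

import (
	"container/list"
	"fmt"
)

// seenSet is a hypothetical, single-goroutine stand-in for the dedupe
// Processor's bookkeeping: a bounded LRU of string IDs.
type seenSet struct {
	max   int
	order *list.List               // most recent at front, least recent at back
	byID  map[string]*list.Element // id -> list element holding that id
}

func newSeenSet(max int) *seenSet {
	return &seenSet{max: max, order: list.New(), byID: make(map[string]*list.Element, max)}
}

// firstSeen reports whether id is new. A repeated id is promoted to most
// recent; a new id may push the least recently used id out of the set.
func (s *seenSet) firstSeen(id string) bool {
	if elem, ok := s.byID[id]; ok {
		s.order.MoveToFront(elem)
		return false
	}
	s.byID[id] = s.order.PushFront(id)
	if s.order.Len() > s.max {
		oldest := s.order.Back()
		s.order.Remove(oldest)
		delete(s.byID, oldest.Value.(string))
	}
	return true
}

func main() {
	s := newSeenSet(2)
	// Same sequence as TestProcessorLRUEvictionAndPromotion.
	for _, id := range []string{"a", "b", "a", "c", "a", "b"} {
		fmt.Printf("%s pass=%v\n", id, s.firstSeen(id))
	}
}
```

With a capacity of 2, the repeated "a" is dropped but promoted, so inserting "c" evicts "b" rather than "a". The final "b" therefore passes again, which is the practical limit of a bounded seen-ID set: an ID that has been evicted is treated as new.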

@@ -0,0 +1,163 @@
package dedupe
import (
"context"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
func TestNewProcessorValidation(t *testing.T) {
t.Run("rejects non-positive maxEntries", func(t *testing.T) {
for _, maxEntries := range []int{0, -1} {
p, err := NewProcessor(maxEntries)
if err == nil {
t.Fatalf("expected error for maxEntries=%d, got nil", maxEntries)
}
if p != nil {
t.Fatalf("expected nil processor for maxEntries=%d", maxEntries)
}
if !strings.Contains(err.Error(), "maxEntries") {
t.Fatalf("unexpected error: %v", err)
}
}
})
t.Run("accepts positive maxEntries", func(t *testing.T) {
p, err := NewProcessor(1)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
if p == nil {
t.Fatalf("expected processor, got nil")
}
})
}
func TestProcessorFirstSeenAndDuplicate(t *testing.T) {
p, err := NewProcessor(8)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
ctx := context.Background()
first := testEvent("evt-1")
out, err := p.Process(ctx, first)
if err != nil {
t.Fatalf("Process first error: %v", err)
}
if out == nil {
t.Fatalf("expected first event to pass through")
}
if out.ID != first.ID {
t.Fatalf("expected unchanged ID %q, got %q", first.ID, out.ID)
}
out, err = p.Process(ctx, first)
if err != nil {
t.Fatalf("Process duplicate error: %v", err)
}
if out != nil {
t.Fatalf("expected duplicate to be dropped, got %#v", out)
}
out, err = p.Process(ctx, testEvent("evt-2"))
if err != nil {
t.Fatalf("Process second unique error: %v", err)
}
if out == nil {
t.Fatalf("expected second unique event to pass through")
}
}
func TestProcessorLRUEvictionAndPromotion(t *testing.T) {
p, err := NewProcessor(2)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
ctx := context.Background()
mustPass(t, p, ctx, "a")
mustPass(t, p, ctx, "b")
mustDrop(t, p, ctx, "a") // promote "a" so "b" becomes least-recently-used
mustPass(t, p, ctx, "c") // evicts "b"
mustDrop(t, p, ctx, "a") // "a" should still be tracked after promotion
mustPass(t, p, ctx, "b") // "b" was evicted, so now it passes again
}
func TestProcessorRejectsBlankID(t *testing.T) {
p, err := NewProcessor(4)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
in := testEvent(" ")
out, err := p.Process(context.Background(), in)
if err == nil {
t.Fatalf("expected error for blank ID")
}
if out != nil {
t.Fatalf("expected nil output on error, got %#v", out)
}
if !strings.Contains(err.Error(), "event ID is required") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestFactoryWithRegistry(t *testing.T) {
r := processors.NewRegistry()
r.Register("dedupe", Factory(3))
p, err := r.Build("dedupe")
if err != nil {
t.Fatalf("Build error: %v", err)
}
if p == nil {
t.Fatalf("expected processor, got nil")
}
out, err := p.Process(context.Background(), testEvent("evt-factory-1"))
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected first event to pass through")
}
}
func mustPass(t *testing.T, p *Processor, ctx context.Context, id string) {
t.Helper()
out, err := p.Process(ctx, testEvent(id))
if err != nil {
t.Fatalf("expected pass for id=%q, got error: %v", id, err)
}
if out == nil {
t.Fatalf("expected pass for id=%q, got drop", id)
}
}
func mustDrop(t *testing.T, p *Processor, ctx context.Context, id string) {
t.Helper()
out, err := p.Process(ctx, testEvent(id))
if err != nil {
t.Fatalf("expected drop for id=%q, got error: %v", id, err)
}
if out != nil {
t.Fatalf("expected drop for id=%q, got output", id)
}
}
func testEvent(id string) event.Event {
return event.Event{
ID: id,
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"ok": true},
}
}

24
processors/doc.go Normal file

@@ -0,0 +1,24 @@
// Package processors defines feedkit's generic processor abstraction and registry.
//
// Processors are optional pipeline stages that can transform, drop, or reject
// events before dispatch to sinks.
//
// Registry provides name-based construction so daemons can assemble processor
// chains without embedding switch statements in wiring code.
//
// Example:
//
// reg := processors.NewRegistry()
// reg.Register("dedupe", dedupe.Factory(10_000))
// reg.Register("normalize", func() (processors.Processor, error) {
// // import "gitea.maximumdirect.net/ejr/feedkit/processors/normalize"
// return normalize.NewProcessor(myNormalizers, false), nil
// })
//
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
// if err != nil {
// // handle wiring error
// }
//
// p := &pipeline.Pipeline{Processors: chain}
package processors

@@ -0,0 +1,19 @@
// Package normalize provides the feedkit normalization processor and related
// helper APIs for raw-to-canonical event mapping.
//
// External API surface:
// - Processor: concrete processors.Processor implementation
// - Normalizer / Func: normalization interface and ergonomic function adapter
//
// Optional helpers from helpers.go:
// - PayloadJSONBytes: extract supported JSON-shaped payloads into bytes
// - DecodeJSONPayload: decode an event payload into a typed struct
// - FinalizeEvent: copy the input event envelope onto a normalized output
//
// Typical usage:
// sources emit raw events (often with json.RawMessage payloads), then a
// normalize.Processor converts matching raw schemas into canonical payloads.
//
// Key property: normalization is optional.
// If no Normalizer matches an event, Processor passes it through unchanged by default.
package normalize

@@ -0,0 +1,84 @@
package normalize
import (
"encoding/json"
"fmt"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// PayloadJSONBytes extracts a JSON payload into bytes suitable for json.Unmarshal.
//
// Supported payload shapes:
// - json.RawMessage
// - []byte
// - string
// - map[string]any
func PayloadJSONBytes(e event.Event) ([]byte, error) {
if e.Payload == nil {
return nil, fmt.Errorf("payload is nil")
}
switch v := e.Payload.(type) {
case json.RawMessage:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty json.RawMessage")
}
return []byte(v), nil
case []byte:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty []byte")
}
return v, nil
case string:
if v == "" {
return nil, fmt.Errorf("payload is empty string")
}
return []byte(v), nil
case map[string]any:
b, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("marshal map payload: %w", err)
}
return b, nil
default:
return nil, fmt.Errorf("unsupported payload type %T", e.Payload)
}
}
// DecodeJSONPayload extracts the event payload as bytes and unmarshals it into T.
func DecodeJSONPayload[T any](in event.Event) (T, error) {
var zero T
b, err := PayloadJSONBytes(in)
if err != nil {
return zero, fmt.Errorf("extract payload: %w", err)
}
var parsed T
if err := json.Unmarshal(b, &parsed); err != nil {
return zero, fmt.Errorf("decode raw payload: %w", err)
}
return parsed, nil
}
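// FinalizeEvent builds the output event envelope by copying the input and applying
// the new schema and payload. A non-zero effectiveAt replaces EffectiveAt (stored
// in UTC); a zero effectiveAt leaves the input's EffectiveAt untouched. The result
// is validated before it is returned.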
func FinalizeEvent(in event.Event, outSchema string, outPayload any, effectiveAt time.Time) (*event.Event, error) {
out := in
out.Schema = outSchema
out.Payload = outPayload
if !effectiveAt.IsZero() {
t := effectiveAt.UTC()
out.EffectiveAt = &t
}
if err := out.Validate(); err != nil {
return nil, err
}
return &out, nil
}
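
The helpers above are meant to be called from inside a normalizer. As a rough illustration of the decode step, the sketch below re-creates its shape without importing feedkit: `payloadBytes` and `decode` are trimmed stand-ins for PayloadJSONBytes and DecodeJSONPayload that take the payload value directly, and the `reading` struct and its field names are invented for the example.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// payloadBytes is a trimmed stand-in for PayloadJSONBytes. It takes the
// payload value directly instead of an event.Event and skips the empty checks.
func payloadBytes(payload any) ([]byte, error) {
	switch v := payload.(type) {
	case json.RawMessage:
		return []byte(v), nil
	case []byte:
		return v, nil
	case string:
		return []byte(v), nil
	case map[string]any:
		return json.Marshal(v)
	default:
		return nil, fmt.Errorf("unsupported payload type %T", payload)
	}
}

// decode mirrors the shape of DecodeJSONPayload: extract bytes, then
// unmarshal them into the requested type T.
func decode[T any](payload any) (T, error) {
	var zero T
	b, err := payloadBytes(payload)
	if err != nil {
		return zero, fmt.Errorf("extract payload: %w", err)
	}
	var parsed T
	if err := json.Unmarshal(b, &parsed); err != nil {
		return zero, fmt.Errorf("decode raw payload: %w", err)
	}
	return parsed, nil
}

// reading is a hypothetical upstream payload shape.
type reading struct {
	Station string  `json:"station"`
	TempC   float64 `json:"temp_c"`
}

func main() {
	payloads := []any{
		json.RawMessage(`{"station":"KSTL","temp_c":21.5}`),
		`{"station":"KSTL","temp_c":21.5}`,
		map[string]any{"station": "KSTL", "temp_c": 21.5},
		123, // unsupported shape: reported as an error
	}
	for _, p := range payloads {
		r, err := decode[reading](p)
		fmt.Printf("%T -> %+v err=%v\n", p, r, err)
	}
}
```

The first three payload shapes decode to the same struct, which is the point of routing every normalizer through one extraction helper: the normalizer does not need to know whether a source emitted raw bytes, a string, or an already-parsed map.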

@@ -0,0 +1,118 @@
package normalize
import (
"encoding/json"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPayloadJSONBytesSupportedShapes(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "rawmessage", payload: json.RawMessage(`{"a":1}`), want: `{"a":1}`},
{name: "bytes", payload: []byte(`{"a":2}`), want: `{"a":2}`},
{name: "string", payload: `{"a":3}`, want: `{"a":3}`},
{name: "map", payload: map[string]any{"a": 4}, want: `{"a":4}`},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err != nil {
t.Fatalf("PayloadJSONBytes() unexpected error: %v", err)
}
if string(got) != tc.want {
t.Fatalf("PayloadJSONBytes() = %s, want %s", string(got), tc.want)
}
})
}
}
func TestPayloadJSONBytesRejectsInvalidPayloads(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "nil", payload: nil, want: "payload is nil"},
{name: "empty rawmessage", payload: json.RawMessage{}, want: "payload is empty json.RawMessage"},
{name: "empty bytes", payload: []byte{}, want: "payload is empty []byte"},
{name: "empty string", payload: "", want: "payload is empty string"},
{name: "unsupported", payload: 123, want: "unsupported payload type"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err == nil {
t.Fatalf("PayloadJSONBytes() expected error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("PayloadJSONBytes() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestDecodeJSONPayload(t *testing.T) {
type payload struct {
Name string `json:"name"`
}
got, err := DecodeJSONPayload[payload](event.Event{
Payload: json.RawMessage(`{"name":"alice"}`),
})
if err != nil {
t.Fatalf("DecodeJSONPayload() unexpected error: %v", err)
}
if got.Name != "alice" {
t.Fatalf("DecodeJSONPayload() = %#v, want name alice", got)
}
}
func TestFinalizeEventPreservesEnvelopeAndEffectiveAtBehavior(t *testing.T) {
existingEffectiveAt := time.Date(2026, 3, 28, 11, 0, 0, 0, time.UTC)
in := event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-a",
EmittedAt: time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC),
EffectiveAt: &existingEffectiveAt,
Schema: "raw.example.v1",
Payload: map[string]any{"old": true},
}
out, err := FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, time.Time{})
if err != nil {
t.Fatalf("FinalizeEvent() unexpected error: %v", err)
}
if out.ID != in.ID || out.Kind != in.Kind || out.Source != in.Source || out.EmittedAt != in.EmittedAt {
t.Fatalf("FinalizeEvent() changed preserved envelope fields: %#v", out)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(existingEffectiveAt) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want preserved existing value", out.EffectiveAt)
}
nextEffectiveAt := time.Date(2026, 3, 28, 13, 0, 0, 0, time.FixedZone("x", -4*3600))
out, err = FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, nextEffectiveAt)
if err != nil {
t.Fatalf("FinalizeEvent() unexpected overwrite error: %v", err)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(nextEffectiveAt.UTC()) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want %s", out.EffectiveAt, nextEffectiveAt.UTC())
}
payloadMap, ok := out.Payload.(map[string]any)
if !ok {
t.Fatalf("FinalizeEvent() payload type = %T, want map[string]any", out.Payload)
}
if payloadMap["value"] != 1.234567 {
t.Fatalf("FinalizeEvent() payload value = %#v, want unrounded 1.234567", payloadMap["value"])
}
}

@@ -0,0 +1,57 @@
package normalize
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Processor applies ordered normalization rules to pipeline events.
//
// Selection rule:
// - iterate in Normalizers order
// - the first Normalizer whose Match returns true is applied
//
// If no normalizer matches, the default behavior is pass-through.
type Processor struct {
Normalizers []Normalizer
// If true, events that do not match any normalizer cause an error.
// Default is false (pass-through).
RequireMatch bool
}
// NewProcessor constructs a normalization processor from an ordered normalizer list.
func NewProcessor(normalizers []Normalizer, requireMatch bool) Processor {
return Processor{
Normalizers: append([]Normalizer(nil), normalizers...),
RequireMatch: requireMatch,
}
}
// Process implements processors.Processor.
func (p Processor) Process(ctx context.Context, in event.Event) (*event.Event, error) {
for _, n := range p.Normalizers {
if n == nil {
continue
}
if !n.Match(in) {
continue
}
out, err := n.Normalize(ctx, in)
if err != nil {
return nil, fmt.Errorf("normalize: normalizer failed: %w", err)
}
return out, nil
}
if p.RequireMatch {
return nil, fmt.Errorf("normalize: no normalizer matched event (id=%s kind=%s source=%s schema=%q)",
in.ID, in.Kind, in.Source, in.Schema)
}
out := in
return &out, nil
}

@@ -0,0 +1,139 @@
package normalize
import (
"context"
"errors"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestProcessorFirstMatchWins(t *testing.T) {
var firstCalls, secondCalls int
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
firstCalls++
out := in
out.Schema = "normalized.first.v1"
return &out, nil
},
},
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
secondCalls++
out := in
out.Schema = "normalized.second.v1"
return &out, nil
},
},
}, false)
out, err := p.Process(context.Background(), testEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected output event, got nil")
}
if out.Schema != "normalized.first.v1" {
t.Fatalf("unexpected schema: %q", out.Schema)
}
if firstCalls != 1 {
t.Fatalf("expected first normalizer called once, got %d", firstCalls)
}
if secondCalls != 0 {
t.Fatalf("expected second normalizer skipped, got %d calls", secondCalls)
}
}
func TestProcessorNoMatchPassThroughAndRequireMatch(t *testing.T) {
in := testEvent()
in.Schema = "raw.schema.v1"
passThrough := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return false },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
out := in
out.Schema = "should.not.run"
return &out, nil
},
},
}, false)
out, err := passThrough.Process(context.Background(), in)
if err != nil {
t.Fatalf("pass-through Process error: %v", err)
}
if out == nil {
t.Fatalf("expected pass-through output event, got nil")
}
if out.Schema != "raw.schema.v1" {
t.Fatalf("expected unchanged schema, got %q", out.Schema)
}
required := NewProcessor(nil, true)
_, err = required.Process(context.Background(), in)
if err == nil {
t.Fatalf("expected require-match error")
}
if !strings.Contains(err.Error(), "no normalizer matched") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestProcessorDropAndErrorPropagation(t *testing.T) {
t.Run("drop", func(t *testing.T) {
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
return nil, nil
},
},
}, false)
out, err := p.Process(context.Background(), testEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out != nil {
t.Fatalf("expected nil output for dropped event, got %#v", out)
}
})
t.Run("error", func(t *testing.T) {
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
return nil, errors.New("map failed")
},
},
}, false)
_, err := p.Process(context.Background(), testEvent())
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "normalizer failed") {
t.Fatalf("unexpected error: %v", err)
}
})
}
func testEvent() event.Event {
return event.Event{
ID: "evt-normalize-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"x": 1},
}
}

15
processors/processor.go Normal file

@@ -0,0 +1,15 @@
package processors
import (
"context"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Processor transforms, drops, or rejects events (dedupe, rate-limit, normalization tweaks).
//
// Returning (nil, nil) drops the event; returning a non-nil error rejects it.
type Processor interface {
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
}
// Factory constructs a configured Processor instance.
type Factory func() (Processor, error)

71
processors/registry.go Normal file

@@ -0,0 +1,71 @@
package processors
import (
"fmt"
"strings"
)
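// Registry maps processor driver names to the factories that build them.
type Registry struct {
byDriver map[string]Factory
}
// NewRegistry returns an empty Registry ready for Register calls.
func NewRegistry() *Registry {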
return &Registry{byDriver: map[string]Factory{}}
}
// Register associates a processor driver name with a factory.
//
// Register panics for empty driver names, nil factories, and duplicates.
func (r *Registry) Register(driver string, f Factory) {
if r == nil {
panic("processors.Registry.Register: registry cannot be nil")
}
driver = strings.TrimSpace(driver)
if driver == "" {
panic("processors.Registry.Register: driver cannot be empty")
}
if f == nil {
panic(fmt.Sprintf("processors.Registry.Register: factory cannot be nil (driver=%q)", driver))
}
if r.byDriver == nil {
r.byDriver = map[string]Factory{}
}
if _, exists := r.byDriver[driver]; exists {
panic(fmt.Sprintf("processors.Registry.Register: driver %q already registered", driver))
}
r.byDriver[driver] = f
}
// Build constructs a Processor by driver name.
func (r *Registry) Build(driver string) (Processor, error) {
if r == nil {
return nil, fmt.Errorf("processors registry is nil")
}
driver = strings.TrimSpace(driver)
f, ok := r.byDriver[driver]
if !ok {
return nil, fmt.Errorf("unknown processor driver: %q", driver)
}
p, err := f()
if err != nil {
return nil, fmt.Errorf("build processor %q: %w", driver, err)
}
if p == nil {
return nil, fmt.Errorf("build processor %q: factory returned nil processor", driver)
}
return p, nil
}
// BuildChain constructs an ordered processor chain from a driver list.
func (r *Registry) BuildChain(drivers []string) ([]Processor, error) {
out := make([]Processor, 0, len(drivers))
for i, driver := range drivers {
p, err := r.Build(driver)
if err != nil {
return nil, fmt.Errorf("build processor chain[%d] (%q): %w", i, strings.TrimSpace(driver), err)
}
out = append(out, p)
}
return out, nil
}

100
processors/registry_test.go Normal file

@@ -0,0 +1,100 @@
package processors
import (
"context"
"errors"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testProcessor struct {
name string
}
func (p testProcessor) Process(context.Context, event.Event) (*event.Event, error) {
return nil, nil
}
func TestRegistryRegisterValidation(t *testing.T) {
t.Run("empty driver panics", func(t *testing.T) {
r := NewRegistry()
assertPanics(t, func() {
r.Register(" ", func() (Processor, error) { return testProcessor{name: "x"}, nil })
})
})
t.Run("nil factory panics", func(t *testing.T) {
r := NewRegistry()
assertPanics(t, func() {
r.Register("normalize", nil)
})
})
t.Run("duplicate driver panics", func(t *testing.T) {
r := NewRegistry()
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "a"}, nil })
assertPanics(t, func() {
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "b"}, nil })
})
})
}
func TestRegistryBuildUnknownDriver(t *testing.T) {
r := NewRegistry()
_, err := r.Build("does_not_exist")
if err == nil {
t.Fatalf("expected error for unknown driver")
}
if !strings.Contains(err.Error(), "unknown processor driver") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRegistryBuildChainPreservesOrder(t *testing.T) {
r := NewRegistry()
r.Register("first", func() (Processor, error) { return testProcessor{name: "first"}, nil })
r.Register("second", func() (Processor, error) { return testProcessor{name: "second"}, nil })
chain, err := r.BuildChain([]string{"first", "second"})
if err != nil {
t.Fatalf("BuildChain error: %v", err)
}
if len(chain) != 2 {
t.Fatalf("expected 2 processors, got %d", len(chain))
}
p0, ok := chain[0].(testProcessor)
if !ok || p0.name != "first" {
t.Fatalf("unexpected chain[0]: %#v", chain[0])
}
p1, ok := chain[1].(testProcessor)
if !ok || p1.name != "second" {
t.Fatalf("unexpected chain[1]: %#v", chain[1])
}
}
func TestRegistryBuildChainIndexedFailure(t *testing.T) {
r := NewRegistry()
r.Register("ok", func() (Processor, error) { return testProcessor{name: "ok"}, nil })
r.Register("broken", func() (Processor, error) { return nil, errors.New("boom") })
_, err := r.BuildChain([]string{"ok", "broken"})
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "chain[1]") {
t.Fatalf("expected indexed error, got: %v", err)
}
}
func assertPanics(t *testing.T, fn func()) {
t.Helper()
defer func() {
if recover() == nil {
t.Fatalf("expected panic")
}
}()
fn()
}

25
scheduler/doc.go Normal file

@@ -0,0 +1,25 @@
// Package scheduler runs feedkit sources and forwards their events to the
// daemon event bus.
//
// External API surface:
// - Scheduler: runs configured polling and streaming jobs
// - Job: one scheduler task bound to a source
// - StreamExitPolicy: stream supervision policy for non-fatal exits
// - StreamBackoff: restart pacing for supervised stream sources
//
// Optional helpers from helpers.go:
// - JobFromSourceConfig: build a scheduler job from a configured source and
// feedkit-owned scheduling params
//
// Poll sources are run on a fixed cadence with optional jitter. Stream sources
// are supervised long-lived workers. Their generic feedkit controls live under
// sources[].params:
// - stream_exit_policy: restart|stop|fatal (default restart)
// - stream_backoff_initial: positive duration (default 1s)
// - stream_backoff_max: positive duration (default 1m)
// - stream_backoff_jitter: non-negative duration (default 250ms)
//
// Stream sources can classify exits with sources.StreamRetryable and
// sources.StreamFatal. Plain errors are treated as retryable by default, while
// fatal exits are propagated from Scheduler.Run so the daemon can shut down.
package scheduler

138
scheduler/helpers.go Normal file

@@ -0,0 +1,138 @@
package scheduler
import (
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
// JobFromSourceConfig builds a scheduler Job from a configured source and its
// generic feedkit config.
func JobFromSourceConfig(src sources.Input, cfg config.SourceConfig) (Job, error) {
if src == nil {
return Job{}, fmt.Errorf("scheduler: source %q is nil", cfg.Name)
}
job := Job{
Source: src,
Every: cfg.Every.Duration,
}
if _, ok := src.(sources.StreamSource); ok {
if cfg.Every.Duration > 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be omitted for stream sources", cfg.Name)
}
policy, err := parseStreamExitPolicy(cfg)
if err != nil {
return Job{}, err
}
backoff, err := parseStreamBackoff(cfg)
if err != nil {
return Job{}, err
}
job.StreamExitPolicy = policy
job.StreamBackoff = backoff
return job, nil
}
if _, ok := src.(sources.PollSource); ok {
if cfg.Every.Duration <= 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be > 0 for polling sources", cfg.Name)
}
if err := rejectStreamParams(cfg); err != nil {
return Job{}, err
}
return job, nil
}
return Job{}, fmt.Errorf("scheduler: source %q implements neither PollSource nor StreamSource", cfg.Name)
}
func parseStreamExitPolicy(cfg config.SourceConfig) (StreamExitPolicy, error) {
const key = "stream_exit_policy"
raw, exists := cfg.Params[key]
if !exists || raw == nil {
return StreamExitPolicyRestart, nil
}
s, ok := raw.(string)
if !ok {
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
switch StreamExitPolicy(strings.ToLower(strings.TrimSpace(s))) {
case StreamExitPolicyRestart:
return StreamExitPolicyRestart, nil
case StreamExitPolicyStop:
return StreamExitPolicyStop, nil
case StreamExitPolicyFatal:
return StreamExitPolicyFatal, nil
default:
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
}
func parseStreamBackoff(cfg config.SourceConfig) (StreamBackoff, error) {
initial, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_initial", defaultStreamBackoffInitial)
if err != nil {
return StreamBackoff{}, err
}
max, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_max", defaultStreamBackoffMax)
if err != nil {
return StreamBackoff{}, err
}
jitter, err := parseNonNegativeOrDefaultDuration(cfg, "stream_backoff_jitter", defaultStreamBackoffJitter)
if err != nil {
return StreamBackoff{}, err
}
if max < initial {
return StreamBackoff{}, fmt.Errorf("source %q: params.stream_backoff_max must be >= params.stream_backoff_initial", cfg.Name)
}
return StreamBackoff{
Initial: initial,
Max: max,
Jitter: jitter,
}, nil
}
func rejectStreamParams(cfg config.SourceConfig) error {
streamKeys := []string{
"stream_exit_policy",
"stream_backoff_initial",
"stream_backoff_max",
"stream_backoff_jitter",
}
for _, key := range streamKeys {
if _, ok := cfg.Params[key]; ok {
return fmt.Errorf("source %q: params.%s is only valid for stream sources", cfg.Name, key)
}
}
return nil
}
func parsePositiveOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v <= 0 {
return 0, fmt.Errorf("source %q: params.%s must be a positive duration", cfg.Name, key)
}
return v, nil
}
func parseNonNegativeOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v < 0 {
return 0, fmt.Errorf("source %q: params.%s must be a non-negative duration", cfg.Name, key)
}
return v, nil
}

@@ -5,6 +5,7 @@ import (
"fmt"
"hash/fnv"
"math/rand"
"sync"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
@@ -28,8 +29,10 @@ type Logger = logging.Logf
// - For stream sources: Jitter is applied once at startup only (optional; useful to avoid
// reconnect storms when many instances start together).
type Job struct {
- Source sources.Input
- Every time.Duration
+ Source sources.Input
+ Every time.Duration
+ StreamExitPolicy StreamExitPolicy
+ StreamBackoff StreamBackoff
// Jitter is the maximum additional delay added before each poll.
// Example: if Every=15m and Jitter=30s, each poll will occur at:
@@ -41,12 +44,37 @@ type Job struct {
Jitter time.Duration
}
// StreamExitPolicy controls how the scheduler handles non-fatal stream exits.
type StreamExitPolicy string
const (
StreamExitPolicyRestart StreamExitPolicy = "restart"
StreamExitPolicyStop StreamExitPolicy = "stop"
StreamExitPolicyFatal StreamExitPolicy = "fatal"
)
// StreamBackoff controls restart pacing for stream supervision.
type StreamBackoff struct {
Initial time.Duration
Max time.Duration
Jitter time.Duration
}
type Scheduler struct {
Jobs []Job
Out chan<- event.Event
Logf Logger
}
const (
defaultStreamBackoffInitial = 1 * time.Second
defaultStreamBackoffMax = 1 * time.Minute
defaultStreamBackoffJitter = 250 * time.Millisecond
streamBackoffResetAfter = 5 * time.Minute
)
var timeNow = time.Now
// Run starts one goroutine per job.
// Poll jobs run on their own interval and emit 0..N events per poll.
// Stream jobs run continuously and emit events as they arrive.
@@ -58,16 +86,38 @@ func (s *Scheduler) Run(ctx context.Context) error {
return fmt.Errorf("scheduler.Run: no jobs configured")
}
+ runCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ fatalErrCh := make(chan error, 1)
+ var wg sync.WaitGroup
for _, job := range s.Jobs {
job := job // capture loop variable
- go s.runJob(ctx, job)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s.runJob(runCtx, job, fatalErrCh)
+ }()
}
- <-ctx.Done()
- return ctx.Err()
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+ select {
+ case err := <-fatalErrCh:
+ cancel()
+ <-done
+ return err
+ case <-runCtx.Done():
+ <-done
+ return runCtx.Err()
+ }
}
- func (s *Scheduler) runJob(ctx context.Context, job Job) {
+ func (s *Scheduler) runJob(ctx context.Context, job Job, fatalErrCh chan<- error) {
if job.Source == nil {
s.logf("scheduler: job has nil source")
return
@@ -75,7 +125,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
// Stream sources: event-driven.
if ss, ok := job.Source.(sources.StreamSource); ok {
- s.runStream(ctx, job, ss)
+ s.runStream(ctx, job, ss, fatalErrCh)
return
}
@@ -93,18 +143,51 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
s.runPoller(ctx, job, ps)
}
- func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource) {
+ func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource, fatalErrCh chan<- error) {
+ policy := effectiveStreamExitPolicy(job.StreamExitPolicy)
+ backoff := effectiveStreamBackoff(job.StreamBackoff)
+ rng := seededRNG(src.Name())
// Optional startup jitter: helps avoid reconnect storms if many daemons start at once.
if job.Jitter > 0 {
- rng := seededRNG(src.Name())
if !sleepJitter(ctx, rng, job.Jitter) {
return
}
}
// Stream sources should block until ctx cancel or fatal error.
- if err := src.Run(ctx, s.Out); err != nil && ctx.Err() == nil {
- s.logf("scheduler: stream source %q exited with error: %v", src.Name(), err)
+ nextDelay := backoff.Initial
+ for {
+ startedAt := timeNow()
+ err := src.Run(ctx, s.Out)
+ if ctx.Err() != nil {
+ return
+ }
+ normalizedErr := normalizeStreamExitError(src.Name(), err)
+ if sources.IsStreamFatal(normalizedErr) {
+ s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited fatally: %w", src.Name(), normalizedErr))
+ return
+ }
+ switch policy {
+ case StreamExitPolicyStop:
+ s.logf("scheduler: stream source %q stopped after exit: %v", src.Name(), normalizedErr)
+ return
+ case StreamExitPolicyFatal:
+ s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited under fatal policy: %w", src.Name(), normalizedErr))
+ return
+ }
+ if streamRunWasStable(startedAt, timeNow()) {
+ nextDelay = backoff.Initial
+ }
+ delay := nextDelay + randomDuration(rng, backoff.Jitter)
+ s.logf("scheduler: stream source %q exited; restarting in %s: %v", src.Name(), delay, normalizedErr)
+ if !sleepDuration(ctx, delay) {
+ return
+ }
+ nextDelay = nextStreamBackoff(nextDelay, backoff.Max)
}
}
@@ -164,10 +247,77 @@ func (s *Scheduler) logf(format string, args ...any) {
s.Logf(format, args...)
}
func (s *Scheduler) reportFatal(ch chan<- error, err error) {
if err == nil {
return
}
select {
case ch <- err:
default:
}
}
// ---- helpers ----
func effectiveStreamExitPolicy(policy StreamExitPolicy) StreamExitPolicy {
switch policy {
case StreamExitPolicyStop, StreamExitPolicyFatal:
return policy
default:
return StreamExitPolicyRestart
}
}
func effectiveStreamBackoff(cfg StreamBackoff) StreamBackoff {
out := cfg
if out.Initial <= 0 {
out.Initial = defaultStreamBackoffInitial
}
if out.Max <= 0 {
out.Max = defaultStreamBackoffMax
}
if out.Max < out.Initial {
out.Max = out.Initial
}
if out.Jitter < 0 {
out.Jitter = 0
}
return out
}
func normalizeStreamExitError(sourceName string, err error) error {
if err != nil {
return err
}
return sources.StreamRetryable(fmt.Errorf("stream source %q exited unexpectedly without error", sourceName))
}
func nextStreamBackoff(current, max time.Duration) time.Duration {
if current <= 0 {
current = defaultStreamBackoffInitial
}
if max <= 0 {
max = defaultStreamBackoffMax
}
if current >= max {
return max
}
next := current * 2
if next < current || next > max {
return max
}
return next
}
func streamRunWasStable(startedAt, endedAt time.Time) bool {
if startedAt.IsZero() || endedAt.IsZero() {
return false
}
return endedAt.Sub(startedAt) >= streamBackoffResetAfter
}
func seededRNG(name string) *rand.Rand {
seed := time.Now().UnixNano() ^ int64(hashStringFNV32a(name))
seed := timeNow().UnixNano() ^ int64(hashStringFNV32a(name))
return rand.New(rand.NewSource(seed))
}
@@ -206,11 +356,23 @@ func sleepJitter(ctx context.Context, rng *rand.Rand, max time.Duration) bool {
return true
}
return sleepDuration(ctx, randomDuration(rng, max))
}
func randomDuration(rng *rand.Rand, max time.Duration) time.Duration {
if max <= 0 {
return 0
}
// Int63n requires a positive argument.
// We add 1 so max itself is attainable.
n := rng.Int63n(int64(max) + 1)
d := time.Duration(n)
return time.Duration(n)
}
func sleepDuration(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
timer := time.NewTimer(d)
defer timer.Stop()

472
scheduler/scheduler_test.go Normal file

@@ -0,0 +1,472 @@
package scheduler
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
type testPollSource struct {
name string
}
func (s testPollSource) Name() string { return s.name }
func (s testPollSource) Poll(context.Context) ([]event.Event, error) { return nil, nil }
type scriptedStreamSource struct {
name string
mu sync.Mutex
calls int
runs []func(context.Context, chan<- event.Event) error
}
func (s *scriptedStreamSource) Name() string { return s.name }
func (s *scriptedStreamSource) Run(ctx context.Context, out chan<- event.Event) error {
s.mu.Lock()
call := s.calls
s.calls++
var run func(context.Context, chan<- event.Event) error
if call < len(s.runs) {
run = s.runs[call]
}
s.mu.Unlock()
if run != nil {
return run(ctx, out)
}
<-ctx.Done()
return ctx.Err()
}
func (s *scriptedStreamSource) CallCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.calls
}
type capturingLogger struct {
mu sync.Mutex
lines []string
}
func (l *capturingLogger) Logf(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.lines = append(l.lines, fmt.Sprintf(format, args...))
}
func (l *capturingLogger) Contains(substr string) bool {
l.mu.Lock()
defer l.mu.Unlock()
for _, line := range l.lines {
if strings.Contains(line, substr) {
return true
}
}
return false
}
func TestSchedulerRunRestartsPlainStreamErrors(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-a",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 3 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if src.CallCount() < 3 {
t.Fatalf("stream call count = %d, want at least 3", src.CallCount())
}
}
func TestSchedulerRunFatalStreamErrorReturns(t *testing.T) {
base := errors.New("fatal failure")
src := &scriptedStreamSource{
name: "stream-fatal",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return sources.StreamFatal(base) },
},
}
s := &Scheduler{
Jobs: []Job{{Source: src}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal error")
}
if !sources.IsStreamFatal(err) {
t.Fatalf("Scheduler.Run() error = %v, want fatal classification", err)
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base fatal error: %v", err)
}
}
func TestSchedulerRunStopPolicyStopsOnlyThatSource(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-stop",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("stop now") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyStop,
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
time.Sleep(20 * time.Millisecond)
select {
case err := <-errCh:
t.Fatalf("Scheduler.Run() returned early: %v", err)
default:
}
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
}
func TestSchedulerRunFatalPolicyTreatsPlainErrorAsFatal(t *testing.T) {
base := errors.New("plain failure")
src := &scriptedStreamSource{
name: "stream-fatal-policy",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return base },
},
}
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyFatal,
}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal-policy error")
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base error: %v", err)
}
}
func TestSchedulerRunNilExitRestartsAsUnexpected(t *testing.T) {
logger := &capturingLogger{}
src := &scriptedStreamSource{
name: "stream-nil-exit",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return nil },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
Logf: logger.Logf,
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 2 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if !logger.Contains("exited unexpectedly without error") {
t.Fatalf("expected log to mention unexpected nil stream exit")
}
}
func TestSchedulerRunContextCancelDuringBackoff(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-backoff-cancel",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("retry me") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Second,
Max: time.Second,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
time.Sleep(20 * time.Millisecond)
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
}
func TestNextStreamBackoffCapsAtMax(t *testing.T) {
if got := nextStreamBackoff(500*time.Millisecond, 2*time.Second); got != time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 1s", got)
}
if got := nextStreamBackoff(time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
if got := nextStreamBackoff(2*time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
}
func TestStreamRunWasStableAfterFiveMinutes(t *testing.T) {
start := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
if streamRunWasStable(start, start.Add(4*time.Minute+59*time.Second)) {
t.Fatalf("streamRunWasStable() = true, want false")
}
if !streamRunWasStable(start, start.Add(5*time.Minute)) {
t.Fatalf("streamRunWasStable() = false, want true")
}
}
func TestJobFromSourceConfigPollSource(t *testing.T) {
job, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.Every != time.Minute {
t.Fatalf("Job.Every = %s, want 1m", job.Every)
}
}
func TestJobFromSourceConfigPollSourceRejectsStreamParams(t *testing.T) {
_, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
Params: map[string]any{
"stream_exit_policy": "restart",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want rejection")
}
if !strings.Contains(err.Error(), "only valid for stream sources") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceParsesDefaultsAndOverrides(t *testing.T) {
src := &scriptedStreamSource{name: "stream-a"}
job, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-a",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "stop",
"stream_backoff_initial": "2s",
"stream_backoff_max": "10s",
"stream_backoff_jitter": "500ms",
},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.StreamExitPolicy != StreamExitPolicyStop {
t.Fatalf("Job.StreamExitPolicy = %q, want %q", job.StreamExitPolicy, StreamExitPolicyStop)
}
if job.StreamBackoff.Initial != 2*time.Second {
t.Fatalf("Job.StreamBackoff.Initial = %s, want 2s", job.StreamBackoff.Initial)
}
if job.StreamBackoff.Max != 10*time.Second {
t.Fatalf("Job.StreamBackoff.Max = %s, want 10s", job.StreamBackoff.Max)
}
if job.StreamBackoff.Jitter != 500*time.Millisecond {
t.Fatalf("Job.StreamBackoff.Jitter = %s, want 500ms", job.StreamBackoff.Jitter)
}
defaultJob, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-default",
Driver: "stream_driver",
Mode: config.SourceModeStream,
})
if err != nil {
t.Fatalf("JobFromSourceConfig() default error = %v", err)
}
if defaultJob.StreamExitPolicy != StreamExitPolicyRestart {
t.Fatalf("default Job.StreamExitPolicy = %q, want restart", defaultJob.StreamExitPolicy)
}
if defaultJob.StreamBackoff.Initial != defaultStreamBackoffInitial {
t.Fatalf("default Job.StreamBackoff.Initial = %s, want %s", defaultJob.StreamBackoff.Initial, defaultStreamBackoffInitial)
}
if defaultJob.StreamBackoff.Max != defaultStreamBackoffMax {
t.Fatalf("default Job.StreamBackoff.Max = %s, want %s", defaultJob.StreamBackoff.Max, defaultStreamBackoffMax)
}
if defaultJob.StreamBackoff.Jitter != defaultStreamBackoffJitter {
t.Fatalf("default Job.StreamBackoff.Jitter = %s, want %s", defaultJob.StreamBackoff.Jitter, defaultStreamBackoffJitter)
}
}
func TestJobFromSourceConfigStreamSourceRejectsInvalidSettings(t *testing.T) {
src := &scriptedStreamSource{name: "stream-b"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "sometimes",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid policy error")
}
if !strings.Contains(err.Error(), "stream_exit_policy") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "0s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid initial backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_initial") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "2s",
"stream_backoff_max": "1s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid max backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_max") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceRejectsEvery(t *testing.T) {
src := &scriptedStreamSource{name: "stream-c"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-c",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Every: config.Duration{Duration: time.Minute},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want every rejection")
}
if !strings.Contains(err.Error(), "sources[].every must be omitted") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func waitFor(t *testing.T, cond func() bool) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("condition not satisfied before timeout")
}

@@ -1,7 +0,0 @@
package scheduler
// Placeholder for per-source worker logic:
// - ticker loop
// - jitter
// - backoff on errors
// - emits events into scheduler.Out

@@ -1,11 +1,6 @@
package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
import "gitea.maximumdirect.net/ejr/feedkit/config"
// RegisterBuiltins registers sink drivers included in this binary.
//
@@ -17,39 +12,8 @@ func RegisterBuiltins(r *Registry) {
return NewStdoutSink(cfg.Name), nil
})
// File sink: writes/archives events somewhere on disk.
r.Register("file", func(cfg config.SinkConfig) (Sink, error) {
return NewFileSinkFromConfig(cfg)
})
// Postgres sink: persists events durably.
r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) {
return NewPostgresSinkFromConfig(cfg)
})
// NATS sink: publishes events to a broker for downstream consumers.
r.Register("nats", func(cfg config.SinkConfig) (Sink, error) {
return NewNATSSinkFromConfig(cfg)
})
}
// ---- helpers for validating sink params ----
//
// These helpers live in sinks (not config) on purpose:
// - config is domain-agnostic and should not embed driver-specific validation helpers.
// - sinks are adapters; validating their own params here keeps the logic near the driver.
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
v, ok := cfg.Params[key]
if !ok {
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
}
s, ok := v.(string)
if !ok {
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
}
if strings.TrimSpace(s) == "" {
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
}
return s, nil
}

101
sinks/doc.go Normal file

@@ -0,0 +1,101 @@
// Package sinks defines the feedkit sink interface, sink driver registry, and
// built-in infrastructure sinks.
//
// External API surface:
// - Sink: adapter interface that consumes event.Event values
// - Registry / NewRegistry: named sink factory registry
// - RegisterBuiltins: registers the schema-free built-in sink drivers
//
// Built-in sink implementations:
// - stdout
// - nats
// - postgres
//
// Optional helpers from helpers.go:
// - PostgresFactory: returns a sink factory for the built-in Postgres sink
// using a provided downstream schema
//
// # NATS built-in overview
//
// The NATS sink publishes each event as JSON to a configured subject.
//
// Required params:
// - url: NATS server URL (for example, nats://localhost:4222)
// - subject: NATS subject to publish to
//
// Example config:
//
// sinks:
// - name: nats_main
// driver: nats
// params:
// url: nats://localhost:4222
// subject: feedkit.events
//
// # Postgres built-in overview
//
// The postgres sink is intentionally split between downstream daemon ownership
// and feedkit ownership:
// - downstream code registers table schema + event mapping functions
// - feedkit manages DB connection, create-if-missing DDL, transactional
// inserts, optional automatic retention pruning, and manual prune helpers
//
// Example config:
//
// sinks:
// - name: pg_main
// driver: postgres
// params:
// uri: postgres://localhost:5432/feedkit?sslmode=disable
// username: feedkit_user
// password: feedkit_pass
// prune: 3d # optional: prune rows older than now-3d on each write tx
//
// params.prune supports:
// - Go duration strings (72h, 90m, 30s, ...)
// - day/week suffixes (3d, 2w)
//
// If params.prune is omitted, automatic pruning is disabled.
//
// Example downstream wiring:
//
// sinkReg := sinks.NewRegistry()
// sinks.RegisterBuiltins(sinkReg)
// sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
// Tables: []sinks.PostgresTable{
// {
// Name: "events",
// Columns: []sinks.PostgresColumn{
// {Name: "event_id", Type: "TEXT", Nullable: false},
// {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
// {Name: "payload_json", Type: "JSONB", Nullable: false},
// },
// PrimaryKey: []string{"event_id"},
// PruneColumn: "emitted_at", // required for retention pruning
// },
// },
// MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) {
// b, err := json.Marshal(e.Payload)
// if err != nil {
// return nil, err
// }
// return []sinks.PostgresWrite{
// {
// Table: "events",
// Values: map[string]any{
// "event_id": e.ID,
// "emitted_at": e.EmittedAt,
// "payload_json": string(b),
// },
// },
// }, nil
// },
// }))
//
// Manual pruning via type assertion (administrative helpers):
//
// if p, ok := sink.(sinks.PostgresPruner); ok {
// _, _ = p.PruneKeepLatest(ctx, "events", 10000)
// _, _ = p.PruneOlderThan(ctx, "events", time.Now().UTC().AddDate(0, -1, 0))
// }
package sinks

@@ -1,30 +0,0 @@
package sinks
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type FileSink struct {
name string
path string
}
func NewFileSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
path, err := requireStringParam(cfg, "path")
if err != nil {
return nil, err
}
return &FileSink{name: cfg.Name, path: path}, nil
}
func (s *FileSink) Name() string { return s.name }
func (s *FileSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
_ = e
return fmt.Errorf("file sink: TODO implement (path=%s)", s.path)
}

35
sinks/helpers.go Normal file

@@ -0,0 +1,35 @@
package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
// requireStringParam returns a non-empty string sink param.
//
// This helper is intentionally local to sinks rather than config so
// driver-specific validation stays close to the adapters that use it.
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
v, ok := cfg.Params[key]
if !ok {
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
}
s, ok := v.(string)
if !ok {
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
}
if strings.TrimSpace(s) == "" {
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
}
return s, nil
}
// PostgresFactory returns a sink factory that builds the built-in Postgres sink
// using the provided downstream schema definition.
func PostgresFactory(schema PostgresSchema) Factory {
return func(cfg config.SinkConfig) (Sink, error) {
return NewPostgresSinkFromConfig(cfg, schema)
}
}
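`PostgresFactory` above replaces global schema registration with a closure that captures the schema and is registered explicitly by the downstream daemon. The sketch below shows that closure-factory pattern in isolation. All types here (`Sink`, `Config`, `Schema`, `Factory`, `Registry`, `tableSink`) are simplified, hypothetical stand-ins defined for this example, not the feedkit types.

```go
package main

import "fmt"

// Simplified stand-ins for the sink interface and its configuration.
type Sink interface{ Name() string }

type Config struct{ Name string }

type Schema struct{ Tables []string }

// Factory builds a sink from configuration alone.
type Factory func(Config) (Sink, error)

// Registry maps driver names to factories.
type Registry map[string]Factory

type tableSink struct {
	name   string
	tables []string
}

func (s tableSink) Name() string { return s.name }

// schemaFactory captures schema in a closure, so the registry only ever sees
// the plain Factory signature and needs no global schema state.
func schemaFactory(schema Schema) Factory {
	return func(cfg Config) (Sink, error) {
		if cfg.Name == "" {
			return nil, fmt.Errorf("sink name is required")
		}
		if len(schema.Tables) == 0 {
			return nil, fmt.Errorf("sink %q: schema has no tables", cfg.Name)
		}
		return tableSink{name: cfg.Name, tables: schema.Tables}, nil
	}
}

func main() {
	reg := Registry{
		"postgres": schemaFactory(Schema{Tables: []string{"events"}}),
	}
	sink, err := reg["postgres"](Config{Name: "pg_main"})
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	fmt.Println("built sink:", sink.Name())
}
```

Because the schema travels with the closure, two daemons (or two tests) can register different schemas under the same driver name without sharing package-level state.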

46
sinks/helpers_test.go Normal file

@@ -0,0 +1,46 @@
package sinks
import (
"context"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) {
factory := PostgresFactory(testPostgresSchema())
if factory == nil {
t.Fatalf("PostgresFactory() returned nil")
}
if _, err := factory(config.SinkConfig{}); err == nil {
t.Fatalf("factory(config) expected parameter validation error")
}
}
func testPostgresSchema() PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
},
MapEvent: func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{
Table: "events",
Values: map[string]any{
"event_id": e.ID,
"emitted_at": e.EmittedAt,
},
},
}, nil
},
}
}

@@ -13,9 +13,9 @@ import (
)
type NATSSink struct {
name string
url string
exchange string
name string
url string
subject string
mu sync.Mutex
conn *nats.Conn
@@ -26,11 +26,11 @@ func NewNATSSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
if err != nil {
return nil, err
}
ex, err := requireStringParam(cfg, "exchange")
subject, err := requireStringParam(cfg, "subject")
if err != nil {
return nil, err
}
return &NATSSink{name: cfg.Name, url: url, exchange: ex}, nil
return &NATSSink{name: cfg.Name, url: url, subject: subject}, nil
}
func (r *NATSSink) Name() string { return r.name }
@@ -59,7 +59,7 @@ func (r *NATSSink) Consume(ctx context.Context, e event.Event) error {
if err := ctx.Err(); err != nil {
return err
}
if err := conn.Publish(r.exchange, b); err != nil {
if err := conn.Publish(r.subject, b); err != nil {
return fmt.Errorf("NATS sink: publish: %w", err)
}
return nil

47
sinks/nats_test.go Normal file

@@ -0,0 +1,47 @@
package sinks
import (
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
func TestNewNATSSinkFromConfigRequiresSubject(t *testing.T) {
sink, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"subject": "feedkit.events",
},
})
if err != nil {
t.Fatalf("NewNATSSinkFromConfig() error = %v", err)
}
natsSink, ok := sink.(*NATSSink)
if !ok {
t.Fatalf("sink type = %T, want *NATSSink", sink)
}
if natsSink.subject != "feedkit.events" {
t.Fatalf("subject = %q, want feedkit.events", natsSink.subject)
}
}
func TestNewNATSSinkFromConfigRejectsLegacyExchange(t *testing.T) {
_, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"exchange": "feedkit.events",
},
})
if err == nil {
t.Fatalf("NewNATSSinkFromConfig() expected error")
}
if !strings.Contains(err.Error(), "params.subject is required") {
t.Fatalf("error = %q, want params.subject is required", err)
}
}

@@ -2,36 +2,437 @@ package sinks
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type PostgresSink struct {
name string
dsn string
type postgresTx interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Commit() error
Rollback() error
}
func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
dsn, err := requireStringParam(cfg, "dsn")
type postgresExecer interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type postgresDB interface {
PingContext(ctx context.Context) error
BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Close() error
}
type sqlDBWrapper struct {
db *sql.DB
}
func (w *sqlDBWrapper) PingContext(ctx context.Context) error {
return w.db.PingContext(ctx)
}
func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
tx, err := w.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &PostgresSink{name: cfg.Name, dsn: dsn}, nil
return &sqlTxWrapper{tx: tx}, nil
}
func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return w.db.ExecContext(ctx, query, args...)
}
func (w *sqlDBWrapper) Close() error {
return w.db.Close()
}
type sqlTxWrapper struct {
tx *sql.Tx
}
func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return w.tx.ExecContext(ctx, query, args...)
}
func (w *sqlTxWrapper) Commit() error {
return w.tx.Commit()
}
func (w *sqlTxWrapper) Rollback() error {
return w.tx.Rollback()
}
var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
db, err := pgconn.Open(ctx, cfg)
if err != nil {
return nil, err
}
return &sqlDBWrapper{db: db}, nil
}
type PostgresSink struct {
name string
db postgresDB
schema postgresSchemaCompiled
pruneWindow time.Duration
}
func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
uri, err := requireStringParam(cfg, "uri")
if err != nil {
return nil, err
}
username, err := requireStringParam(cfg, "username")
if err != nil {
return nil, err
}
password, err := requireStringParam(cfg, "password")
if err != nil {
return nil, err
}
pruneWindow, err := parsePostgresPruneWindow(cfg)
if err != nil {
return nil, err
}
schema, err := compilePostgresSchema(schemaDef)
if err != nil {
return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
}
db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
}
s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
if err := s.initialize(); err != nil {
_ = db.Close()
return nil, err
}
return s, nil
}
func (p *PostgresSink) Name() string { return p.name }
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
// Boundary validation: if something upstream violated invariants,
// surface it loudly rather than printing partial nonsense.
// surface it loudly rather than writing corrupt rows.
if err := e.Validate(); err != nil {
return fmt.Errorf("postgres sink: invalid event: %w", err)
}
// TODO implement Postgres transaction
if err := ctx.Err(); err != nil {
return err
}
writes, err := p.schema.mapEvent(ctx, e)
if err != nil {
return fmt.Errorf("postgres sink: map event: %w", err)
}
if len(writes) == 0 {
return nil
}
tx, err := p.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("postgres sink: begin tx: %w", err)
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
for _, w := range writes {
tbl, err := p.schema.validateWrite(w)
if err != nil {
return fmt.Errorf("postgres sink: %w", err)
}
query, args, err := buildInsertSQL(tbl, w)
if err != nil {
return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
}
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
}
}
if p.pruneWindow > 0 {
cutoff := time.Now().UTC().Add(-p.pruneWindow)
for _, tableName := range p.schema.tableOrder {
tbl := p.schema.tables[tableName]
if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
return fmt.Errorf("postgres sink: %w", err)
}
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("postgres sink: commit tx: %w", err)
}
committed = true
return nil
}
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
if keep < 0 {
return 0, fmt.Errorf("postgres sink: keep must be >= 0")
}
tbl, err := p.lookupTable(table)
if err != nil {
return 0, err
}
query := fmt.Sprintf(
`DELETE FROM %s WHERE ctid IN (
SELECT ctid FROM %s
ORDER BY %s DESC
OFFSET $1
)`,
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.pruneColumn),
)
res, err := p.db.ExecContext(ctx, query, keep)
if err != nil {
return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
}
rows, err := res.RowsAffected()
if err != nil {
return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
}
return rows, nil
}
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
tbl, err := p.lookupTable(table)
if err != nil {
return 0, err
}
rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
if err != nil {
return 0, fmt.Errorf("postgres sink: %w", err)
}
return rows, nil
}
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
counts := make(map[string]int64, len(p.schema.tableOrder))
for _, table := range p.schema.tableOrder {
n, err := p.PruneKeepLatest(ctx, table, keep)
if err != nil {
return counts, err
}
counts[table] = n
}
return counts, nil
}
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
counts := make(map[string]int64, len(p.schema.tableOrder))
for _, table := range p.schema.tableOrder {
n, err := p.PruneOlderThan(ctx, table, cutoff)
if err != nil {
return counts, err
}
counts[table] = n
}
return counts, nil
}
func (p *PostgresSink) initialize() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for _, tableName := range p.schema.tableOrder {
tbl := p.schema.tables[tableName]
createTableSQL := buildCreateTableSQL(tbl)
if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
}
for _, idx := range tbl.indexes {
createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
}
}
}
return nil
}
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
table = strings.TrimSpace(table)
if table == "" {
return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
}
tbl, ok := p.schema.tables[table]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
}
return tbl, nil
}
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
raw, ok := cfg.Params["prune"]
if !ok || raw == nil {
return 0, nil
}
s, ok := raw.(string)
if !ok {
return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
}
d, err := parsePostgresPruneDuration(s)
if err != nil {
return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
}
return d, nil
}
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
s := strings.TrimSpace(raw)
if s == "" {
return 0, fmt.Errorf("must not be empty")
}
lower := strings.ToLower(s)
if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
unit := lower[len(lower)-1]
n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
if err != nil {
return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
}
if n <= 0 {
return 0, fmt.Errorf("must be > 0")
}
if unit == 'd' {
return time.Duration(n) * 24 * time.Hour, nil
}
return time.Duration(n) * 7 * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
}
if d <= 0 {
return 0, fmt.Errorf("must be > 0")
}
return d, nil
}
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
return fmt.Sprintf(
`DELETE FROM %s WHERE %s < $1`,
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.pruneColumn),
)
}
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
query := buildPruneOlderThanSQL(tbl)
res, err := execer.ExecContext(ctx, query, cutoff)
if err != nil {
return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
}
rows, err := res.RowsAffected()
if err != nil {
return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
}
return rows, nil
}
func buildCreateTableSQL(tbl postgresTableCompiled) string {
defs := make([]string, 0, len(tbl.columnOrder)+1)
for _, colName := range tbl.columnOrder {
col := tbl.columns[colName]
def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
if !col.Nullable {
def += " NOT NULL"
}
defs = append(defs, def)
}
if len(tbl.primaryKey) > 0 {
defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
}
return fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s (%s)",
quotePostgresIdent(tbl.name),
strings.Join(defs, ", "),
)
}
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
unique := ""
if idx.Unique {
unique = "UNIQUE "
}
return fmt.Sprintf(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique,
quotePostgresIdent(idx.Name),
quotePostgresIdent(tableName),
joinQuotedIdents(idx.Columns),
)
}
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
cols := make([]string, 0, len(tbl.columnOrder))
args := make([]any, 0, len(tbl.columnOrder))
placeholders := make([]string, 0, len(tbl.columnOrder))
for i, colName := range tbl.columnOrder {
v, ok := w.Values[colName]
if !ok {
return "", nil, fmt.Errorf("missing value for column %q", colName)
}
cols = append(cols, quotePostgresIdent(colName))
args = append(args, v)
placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
}
q := fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
quotePostgresIdent(tbl.name),
strings.Join(cols, ", "),
strings.Join(placeholders, ", "),
)
return q, args, nil
}
func joinQuotedIdents(idents []string) string {
quoted := make([]string, 0, len(idents))
for _, s := range idents {
quoted = append(quoted, quotePostgresIdent(s))
}
return strings.Join(quoted, ", ")
}
func quotePostgresIdent(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
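
For readers skimming the diff, the SQL produced by the helpers above can be sketched in isolation. This is a minimal standalone sketch, not the sink's actual code path: `quoteIdent` and `insertSQL` are hypothetical names that mirror `quotePostgresIdent` and the statement-building half of `buildInsertSQL`, and the table and column names are illustrative only.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// quoteIdent mirrors quotePostgresIdent: wrap the identifier in double quotes
// and double any embedded double quote.
func quoteIdent(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}

// insertSQL is a hypothetical, simplified stand-in for buildInsertSQL. It emits
// one positional placeholder ($1, $2, ...) per column, in declaration order.
func insertSQL(table string, columns []string) string {
	cols := make([]string, 0, len(columns))
	placeholders := make([]string, 0, len(columns))
	for i, c := range columns {
		cols = append(cols, quoteIdent(c))
		placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
	}
	return fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		quoteIdent(table),
		strings.Join(cols, ", "),
		strings.Join(placeholders, ", "),
	)
}

func main() {
	fmt.Println(insertSQL("events", []string{"event_id", "emitted_at"}))
	fmt.Println(quoteIdent(`odd"name`))
}
```

Values travel as bind arguments rather than being spliced into the statement, so only identifiers need quoting. Column types are a different matter: `buildCreateTableSQL` writes `col.Type` into the DDL verbatim, so types are expected to come from trusted schema definitions.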

239
sinks/postgres_schema.go Normal file

@@ -0,0 +1,239 @@
package sinks
import (
"context"
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// PostgresMapFunc maps one event into zero or more table writes.
//
// Returning zero writes means "this event is not mapped for this sink" and is
// treated as a non-error no-op.
type PostgresMapFunc func(ctx context.Context, e event.Event) ([]PostgresWrite, error)
// PostgresSchema describes the downstream-provided relational model and mapper
// for one configured postgres sink.
type PostgresSchema struct {
Tables []PostgresTable
MapEvent PostgresMapFunc
}
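// PostgresWrite is one row to insert into Table. Values must supply every
// column declared for that table; nil is allowed only for nullable columns.
type PostgresWrite struct {
Table string
Values map[string]any
}
// PostgresTable declares one table managed by the sink. PruneColumn is
// required and must name one of Columns; PrimaryKey and Indexes are optional.
type PostgresTable struct {
Name string
Columns []PostgresColumn
PrimaryKey []string
PruneColumn string
Indexes []PostgresIndex
}
// PostgresColumn declares one column. Type is written verbatim into the
// generated CREATE TABLE statement.
type PostgresColumn struct {
Name string
Type string
Nullable bool
}
// PostgresIndex declares one index over existing columns. Index names must be
// unique across the whole schema, not just within one table.
type PostgresIndex struct {
Name string
Columns []string
Unique bool
}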
// PostgresPruner is an optional interface exposed by PostgresSink so downstream
// applications can call retention helpers via type assertion.
type PostgresPruner interface {
PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error)
PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error)
PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error)
PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error)
}
type postgresSchemaCompiled struct {
tableOrder []string
tables map[string]postgresTableCompiled
mapEvent PostgresMapFunc
}
type postgresTableCompiled struct {
name string
columns map[string]PostgresColumn
columnOrder []string
primaryKey []string
pruneColumn string
indexes []PostgresIndex
}
func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) {
if schema.MapEvent == nil {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required")
}
if len(schema.Tables) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: at least one table is required")
}
compiled := postgresSchemaCompiled{
tableOrder: make([]string, 0, len(schema.Tables)),
tables: make(map[string]postgresTableCompiled, len(schema.Tables)),
mapEvent: schema.MapEvent,
}
seenTables := map[string]bool{}
seenIndexes := map[string]bool{}
for i, tbl := range schema.Tables {
tableName := strings.TrimSpace(tbl.Name)
if tableName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: tables[%d].name is required", i)
}
if seenTables[tableName] {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate table name %q", tableName)
}
seenTables[tableName] = true
if len(tbl.Columns) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q must define at least one column", tableName)
}
colOrder := make([]string, 0, len(tbl.Columns))
colMap := make(map[string]PostgresColumn, len(tbl.Columns))
for j, col := range tbl.Columns {
colName := strings.TrimSpace(col.Name)
if colName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q columns[%d].name is required", tableName, j)
}
if _, exists := colMap[colName]; exists {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q duplicate column %q", tableName, colName)
}
if strings.TrimSpace(col.Type) == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q column %q type is required", tableName, colName)
}
colOrder = append(colOrder, colName)
colMap[colName] = PostgresColumn{
Name: colName,
Type: strings.TrimSpace(col.Type),
Nullable: col.Nullable,
}
}
pk, err := validatePostgresColumnSet(tableName, "primary key", tbl.PrimaryKey, colMap)
if err != nil {
return postgresSchemaCompiled{}, err
}
pruneCol := strings.TrimSpace(tbl.PruneColumn)
if pruneCol == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column is required", tableName)
}
if _, ok := colMap[pruneCol]; !ok {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column %q not found in columns", tableName, pruneCol)
}
indexes := make([]PostgresIndex, 0, len(tbl.Indexes))
for j, idx := range tbl.Indexes {
idxName := strings.TrimSpace(idx.Name)
if idxName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q indexes[%d].name is required", tableName, j)
}
if len(idx.Columns) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q index %q must include at least one column", tableName, idxName)
}
if seenIndexes[idxName] {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate index name %q", idxName)
}
seenIndexes[idxName] = true
idxCols, err := validatePostgresColumnSet(tableName, fmt.Sprintf("index %q columns", idxName), idx.Columns, colMap)
if err != nil {
return postgresSchemaCompiled{}, err
}
indexes = append(indexes, PostgresIndex{
Name: idxName,
Columns: idxCols,
Unique: idx.Unique,
})
}
compiled.tableOrder = append(compiled.tableOrder, tableName)
compiled.tables[tableName] = postgresTableCompiled{
name: tableName,
columns: colMap,
columnOrder: colOrder,
primaryKey: pk,
pruneColumn: pruneCol,
indexes: indexes,
}
}
return compiled, nil
}
func validatePostgresColumnSet(tableName, label string, cols []string, colMap map[string]PostgresColumn) ([]string, error) {
if len(cols) == 0 {
return nil, nil
}
out := make([]string, 0, len(cols))
seen := map[string]bool{}
for _, c := range cols {
name := strings.TrimSpace(c)
if name == "" {
return nil, fmt.Errorf("postgres schema: table %q %s contains empty column name", tableName, label)
}
if seen[name] {
return nil, fmt.Errorf("postgres schema: table %q %s contains duplicate column %q", tableName, label, name)
}
if _, ok := colMap[name]; !ok {
return nil, fmt.Errorf("postgres schema: table %q %s references unknown column %q", tableName, label, name)
}
seen[name] = true
out = append(out, name)
}
return out, nil
}
func (s postgresSchemaCompiled) validateWrite(w PostgresWrite) (postgresTableCompiled, error) {
tableName := strings.TrimSpace(w.Table)
if tableName == "" {
return postgresTableCompiled{}, fmt.Errorf("write table is required")
}
t, ok := s.tables[tableName]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("table %q is not defined in postgres schema", tableName)
}
if len(w.Values) == 0 {
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include values", tableName)
}
for k := range w.Values {
if _, ok := t.columns[k]; !ok {
return postgresTableCompiled{}, fmt.Errorf("write for table %q includes unknown column %q", tableName, k)
}
}
if len(w.Values) != len(t.columnOrder) {
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include all declared columns", tableName)
}
for _, col := range t.columnOrder {
v, ok := w.Values[col]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("write for table %q is missing column %q", tableName, col)
}
if v == nil {
if c := t.columns[col]; !c.Nullable {
return postgresTableCompiled{}, fmt.Errorf("write for table %q has nil value for non-null column %q", tableName, col)
}
}
}
return t, nil
}
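
The write-validation rules in `validateWrite` can be illustrated with a reduced, standalone sketch. The `column`, `write`, and `checkWrite` names below are hypothetical stand-ins for `PostgresColumn`, `PostgresWrite`, and the method above; the sketch covers a single table and skips the table lookup and the empty-values check.

```go
package main

import "fmt"

// column and write are hypothetical, trimmed-down stand-ins for
// PostgresColumn and PostgresWrite.
type column struct {
	Name     string
	Nullable bool
}

type write struct {
	Table  string
	Values map[string]any
}

// checkWrite applies a simplified version of the validateWrite rules:
// no unknown columns, every declared column present, and nil only where
// the column is nullable.
func checkWrite(cols []column, w write) error {
	declared := make(map[string]column, len(cols))
	for _, c := range cols {
		declared[c.Name] = c
	}
	for k := range w.Values {
		if _, ok := declared[k]; !ok {
			return fmt.Errorf("write for table %q includes unknown column %q", w.Table, k)
		}
	}
	for _, c := range cols {
		v, ok := w.Values[c.Name]
		if !ok {
			return fmt.Errorf("write for table %q is missing column %q", w.Table, c.Name)
		}
		if v == nil && !c.Nullable {
			return fmt.Errorf("write for table %q has nil value for non-null column %q", w.Table, c.Name)
		}
	}
	return nil
}

func main() {
	cols := []column{{Name: "event_id"}, {Name: "note", Nullable: true}}

	ok := write{Table: "events", Values: map[string]any{"event_id": "evt-1", "note": nil}}
	fmt.Println(checkWrite(cols, ok))

	partial := write{Table: "events", Values: map[string]any{"event_id": "evt-1"}}
	fmt.Println(checkWrite(cols, partial))
}
```

The practical consequence for a mapper is that partial rows are rejected: a `MapEvent` function has to populate every declared column, passing an explicit nil for nullable columns it has no value for.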

748
sinks/postgres_test.go Normal file

@@ -0,0 +1,748 @@
package sinks
import (
"context"
"database/sql"
"errors"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakeResult struct {
rows int64
}
func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") }
func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil }
type execCall struct {
query string
args []any
}
type fakeTx struct {
execCalls []execCall
execErr error
execErrOnCall int
execRows int64
commitErr error
rollbackErr error
commitCalls int
rollbackCalls int
}
func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)})
if t.execErr != nil && t.execErrOnCall == len(t.execCalls) {
return nil, t.execErr
}
return fakeResult{rows: t.execRows}, nil
}
func (t *fakeTx) Commit() error {
t.commitCalls++
return t.commitErr
}
func (t *fakeTx) Rollback() error {
t.rollbackCalls++
return t.rollbackErr
}
type fakeDB struct {
pingErr error
beginErr error
execErr error
execErrOnCall int
execRows int64
pingCalls int
beginCalls int
execCalls []execCall
closeCalls int
tx *fakeTx
}
func (d *fakeDB) PingContext(_ context.Context) error {
d.pingCalls++
return d.pingErr
}
func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) {
d.beginCalls++
if d.beginErr != nil {
return nil, d.beginErr
}
if d.tx == nil {
d.tx = &fakeTx{}
}
return d.tx, nil
}
func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)})
if d.execErr != nil && d.execErrOnCall == len(d.execCalls) {
return nil, d.execErr
}
return fakeResult{rows: d.execRows}, nil
}
func (d *fakeDB) Close() error {
d.closeCalls++
return nil
}
func withPostgresTestState(t *testing.T) {
t.Helper()
oldOpen := openPostgresDB
t.Cleanup(func() {
openPostgresDB = oldOpen
})
}
func validTestEvent() event.Event {
now := time.Now().UTC()
return event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: now,
Payload: map[string]any{
"x": 1,
},
}
}
func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
{Name: "payload_json", Type: "JSONB", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
Indexes: []PostgresIndex{
{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
},
},
},
MapEvent: mapFn,
}
}
func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
{
Name: "event_payloads",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "payload_json", Type: "JSONB", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
},
MapEvent: mapFn,
}
}
func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
t.Helper()
compiled, err := compilePostgresSchema(s)
if err != nil {
t.Fatalf("compile schema: %v", err)
}
return compiled
}
func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
_, err := compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
},
PruneColumn: "missing_col",
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid schema error")
}
if !strings.Contains(err.Error(), "prune column") {
t.Fatalf("unexpected schema validation error: %v", err)
}
_, err = compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PruneColumn: "emitted_at",
Indexes: []PostgresIndex{
{Name: "idx_events_empty", Columns: nil},
},
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid index schema error")
}
if !strings.Contains(err.Error(), "at least one column") {
t.Fatalf("unexpected index validation error: %v", err)
}
}
func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
withPostgresTestState(t)
dbs := []*fakeDB{{}, {}}
var gotCfgs []pgconn.ConnConfig
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
gotCfgs = append(gotCfgs, cfg)
db := dbs[len(gotCfgs)-1]
if err := db.PingContext(ctx); err != nil {
return nil, err
}
return db, nil
}
factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
}))
for _, name := range []string{"pg_a", "pg_b"} {
sink, err := factory(config.SinkConfig{
Name: name,
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
})
if err != nil {
t.Fatalf("factory(%q) error = %v", name, err)
}
if sink == nil {
t.Fatalf("factory(%q) returned nil sink", name)
}
}
if len(gotCfgs) != 2 {
t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
}
if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" {
t.Fatalf("first ConnConfig = %+v", gotCfgs[0])
}
for i, db := range dbs {
if db.pingCalls != 1 {
t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
}
}
}
func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
name string
params map[string]any
want string
}{
{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: tc.params,
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("expected %q in error, got: %v", tc.want, err)
}
})
}
}
func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
withPostgresTestState(t)
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
}, PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
},
PruneColumn: "missing_col",
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid schema error")
}
if !strings.Contains(err.Error(), "compile schema") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
withPostgresTestState(t)
db := &fakeDB{}
var gotCfg pgconn.ConnConfig
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
gotCfg = cfg
if err := db.PingContext(ctx); err != nil {
return nil, err
}
return db, nil
}
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://db.example.local:5432/feedkit?sslmode=disable",
"username": "app_user",
"password": "app_pass",
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
if s == nil {
t.Fatalf("expected sink")
}
if db.pingCalls != 1 {
t.Fatalf("expected one ping, got %d", db.pingCalls)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls))
}
if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) {
t.Fatalf("unexpected create table query: %s", db.execCalls[0].query)
}
if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
}
if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
t.Fatalf("URI = %q", gotCfg.URI)
}
if gotCfg.Username != "app_user" {
t.Fatalf("Username = %q, want app_user", gotCfg.Username)
}
if gotCfg.Password != "app_pass" {
t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
}
}
func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
withPostgresTestState(t)
db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return db, nil
}
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected init error")
}
if db.closeCalls != 1 {
t.Fatalf("expected db close on init failure")
}
}
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
tests := []struct {
name string
in string
want time.Duration
}{
{name: "go duration", in: "72h", want: 72 * time.Hour},
{name: "days suffix", in: "3d", want: 72 * time.Hour},
{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
withPostgresTestState(t)
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return &fakeDB{}, nil
}
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
"prune": tc.in,
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
pg, ok := s.(*PostgresSink)
if !ok {
t.Fatalf("expected *PostgresSink, got %T", s)
}
if pg.pruneWindow != tc.want {
t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
}
})
}
}
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
name string
in any
}{
{name: "empty", in: ""},
{name: "zero", in: "0"},
{name: "negative", in: "-1h"},
{name: "malformed", in: "abc"},
{name: "fractional day", in: "1.5d"},
{name: "wrong type", in: 5},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
"prune": tc.in,
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "params.prune") {
t.Fatalf("expected params.prune error, got %v", err)
}
})
}
}
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
db := &fakeDB{}
called := 0
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
called++
return nil, nil
})),
}
err := sink.Consume(context.Background(), event.Event{})
if err == nil {
t.Fatalf("expected invalid event error")
}
if !strings.Contains(err.Error(), "invalid event") {
t.Fatalf("unexpected error: %v", err)
}
if called != 0 {
t.Fatalf("expected mapper not called for invalid events")
}
}
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 0 {
t.Fatalf("expected no transaction for unmapped events")
}
}
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 1 {
t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
}
if len(tx.execCalls) != 2 {
t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
}
err := sink.Consume(context.Background(), validTestEvent())
if err == nil {
t.Fatalf("expected insert error")
}
if !strings.Contains(err.Error(), "insert into") {
t.Fatalf("unexpected error: %v", err)
}
if tx.commitCalls != 0 {
t.Fatalf("expected no commit")
}
if tx.rollbackCalls != 1 {
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
pruneWindow: 24 * time.Hour,
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if len(tx.execCalls) != 4 {
t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
}
if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
}
if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
pruneWindow: 24 * time.Hour,
}
err := sink.Consume(context.Background(), validTestEvent())
if err == nil {
t.Fatalf("expected prune error")
}
if !strings.Contains(err.Error(), "prune older than") {
t.Fatalf("unexpected error: %v", err)
}
if tx.commitCalls != 0 {
t.Fatalf("expected no commit")
}
if tx.rollbackCalls != 1 {
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkPrunePerTable(t *testing.T) {
db := &fakeDB{execRows: 7}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
rows, err := sink.PruneKeepLatest(context.Background(), "events", 10)
if err != nil {
t.Fatalf("prune keep latest: %v", err)
}
if rows != 7 {
t.Fatalf("unexpected rows affected: %d", rows)
}
if len(db.execCalls) != 1 {
t.Fatalf("expected one prune query")
}
if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) {
t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query)
}
if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 {
t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args)
}
cutoff := time.Now().UTC().Add(-24 * time.Hour)
rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff)
if err != nil {
t.Fatalf("prune older than: %v", err)
}
if rows != 7 {
t.Fatalf("unexpected rows affected: %d", rows)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected two prune queries")
}
if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) {
t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query)
}
}
func TestPostgresSinkPruneAllTables(t *testing.T) {
db := &fakeDB{execRows: 3}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
if err != nil {
t.Fatalf("prune all keep latest: %v", err)
}
if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
t.Fatalf("unexpected keep counts: %#v", keepCounts)
}
db.execCalls = nil
olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
if err != nil {
t.Fatalf("prune all older than: %v", err)
}
if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
t.Fatalf("unexpected older-than counts: %#v", olderCounts)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected one prune call per table")
}
}
func TestPostgresSinkPruneErrors(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
t.Fatalf("expected negative keep error")
}
if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
t.Fatalf("expected unknown table error")
}
}


@@ -2,6 +2,7 @@ package sinks
 import (
 "fmt"
+"strings"
 "gitea.maximumdirect.net/ejr/feedkit/config"
 )
@@ -21,13 +22,40 @@ func NewRegistry() *Registry {
 }
 func (r *Registry) Register(driver string, f Factory) {
+if r == nil {
+panic("sinks.Registry.Register: registry cannot be nil")
+}
+driver = strings.TrimSpace(driver)
+if driver == "" {
+panic("sinks.Registry.Register: driver cannot be empty")
+}
+if f == nil {
+panic(fmt.Sprintf("sinks.Registry.Register: factory cannot be nil (driver=%q)", driver))
+}
+if r.byDriver == nil {
+r.byDriver = map[string]Factory{}
+}
+if _, exists := r.byDriver[driver]; exists {
+panic(fmt.Sprintf("sinks.Registry.Register: driver %q already registered", driver))
+}
 r.byDriver[driver] = f
 }
 func (r *Registry) Build(cfg config.SinkConfig) (Sink, error) {
-f, ok := r.byDriver[cfg.Driver]
-if !ok {
-return nil, fmt.Errorf("unknown sink driver: %q", cfg.Driver)
+if r == nil {
+return nil, fmt.Errorf("sinks registry is nil")
 }
-return f(cfg)
+driver := strings.TrimSpace(cfg.Driver)
+f, ok := r.byDriver[driver]
+if !ok {
+return nil, fmt.Errorf("unknown sink driver: %q", driver)
+}
+sink, err := f(cfg)
+if err != nil {
+return nil, fmt.Errorf("build sink %q: %w", driver, err)
+}
+if sink == nil {
+return nil, fmt.Errorf("build sink %q: factory returned nil sink", driver)
+}
+return sink, nil
 }

126
sinks/registry_test.go Normal file

@@ -0,0 +1,126 @@
package sinks
import (
"context"
"errors"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testSink struct{ name string }
func (s testSink) Name() string { return s.name }
func (s testSink) Consume(context.Context, event.Event) error { return nil }
func TestRegistryRegisterPanicsOnNilRegistry(t *testing.T) {
var r *Registry
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil registry")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
}
func TestRegistryRegisterPanicsOnEmptyDriver(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on empty driver")
}
}()
r.Register(" ", func(config.SinkConfig) (Sink, error) { return testSink{name: "x"}, nil })
}
func TestRegistryRegisterPanicsOnNilFactory(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil factory")
}
}()
r.Register("stdout", nil)
}
func TestRegistryRegisterPanicsOnDuplicateDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "a"}, nil })
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on duplicate driver")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "b"}, nil })
}
func TestRegistryBuildNilRegistryFails(t *testing.T) {
var r *Registry
_, err := r.Build(config.SinkConfig{Driver: "stdout"})
if err == nil {
t.Fatalf("Build() expected error for nil registry")
}
if !strings.Contains(err.Error(), "registry is nil") {
t.Fatalf("Build() error = %q, want registry is nil", err)
}
}
func TestRegistryBuildTrimsDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
sink, err := r.Build(config.SinkConfig{Name: "sink1", Driver: " stdout "})
if err != nil {
t.Fatalf("Build() error = %v", err)
}
if sink.Name() != "stdout" {
t.Fatalf("Build() sink name = %q, want stdout", sink.Name())
}
}
func TestRegistryBuildWrapsFactoryError(t *testing.T) {
r := NewRegistry()
r.Register("broken", func(config.SinkConfig) (Sink, error) { return nil, errors.New("boom") })
_, err := r.Build(config.SinkConfig{Driver: "broken"})
if err == nil {
t.Fatalf("Build() expected error")
}
if !strings.Contains(err.Error(), `build sink "broken": boom`) {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegistryBuildRejectsNilSink(t *testing.T) {
r := NewRegistry()
r.Register("nil_sink", func(config.SinkConfig) (Sink, error) { return nil, nil })
_, err := r.Build(config.SinkConfig{Driver: "nil_sink"})
if err == nil {
t.Fatalf("Build() expected nil sink error")
}
if !strings.Contains(err.Error(), "factory returned nil sink") {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) {
r := NewRegistry()
RegisterBuiltins(r)
if len(r.byDriver) != 2 {
t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver))
}
for _, driver := range []string{"stdout", "nats"} {
if _, ok := r.byDriver[driver]; !ok {
t.Fatalf("builtins missing driver %q", driver)
}
}
if _, ok := r.byDriver["postgres"]; ok {
t.Fatalf("builtins unexpectedly registered postgres driver")
}
}

@@ -1,14 +1,52 @@
// Package sources defines feedkit's input-source abstractions and source
// registry.
//
// A source ingests upstream input and emits one or more event.Event values.
//
// feedkit supports two source modes:
// - PollSource: scheduler invokes Poll on a cadence.
// - StreamSource: source runs continuously and pushes events as input arrives.
// External API surface:
// - Input: common source identity surface
// - PollSource: polling source interface
// - StreamSource: streaming source interface
// - StreamRetryable / StreamFatal / IsStreamRetryable / IsStreamFatal:
// stream exit classification helpers
// - Registry / NewRegistry: source driver registry and builders
// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling
// helper
//
// Source drivers are domain-specific and registered into Registry by driver name.
// Registry can then build configured sources from config.SourceConfig.
//
// A single source may emit 0..N events per poll or stream iteration, and those
// events may span multiple event kinds.
//
// Optional helpers from helpers.go:
// - DefaultEventID: default event ID policy for source implementations
// - SingleEvent: construct and validate a one-element event slice
// - ValidateExpectedKinds: compare configured expected kinds against source
// advertised kinds when metadata is available
//
// HTTP-backed polling sources can share NewHTTPSource for generic HTTP config
// parsing and conditional GET behavior. The helper understands:
// - params.url
// - params.user_agent
// - params.conditional (optional, default true)
// - params.http_timeout (optional, default transport.DefaultHTTPTimeout)
// - params.http_response_body_limit_bytes (optional, default
// transport.DefaultHTTPResponseBodyLimitBytes)
//
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
// treated as a successful unchanged poll.
//
// Postgres-backed polling sources can share NewPostgresQuerySource for generic
// DB config parsing and query execution. The helper understands:
// - params.uri
// - params.username
// - params.password
// - params.query
// - params.query_timeout (optional, default 30s)
//
// feedkit does not register a built-in postgres poll driver. Downstream daemons
// should register domain-specific driver names that call
// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering,
// watermark policy, and event construction in their own source types.
package sources
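
The conditional GET behavior described in the package comment (prefer ETag/If-None-Match, fall back to Last-Modified/If-Modified-Since, treat 304 as an unchanged poll) can be sketched with the standard library alone. This is an illustration under stated assumptions, not feedkit's transport implementation: the `validators` type, `applyValidators`, `fetchIfChanged`, and `newTestServer` are hypothetical names introduced for the example.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// validators is a hypothetical stand-in for the cached state that feedkit
// keeps in transport.HTTPValidators; its field names are assumptions.
type validators struct {
	ETag         string
	LastModified string
}

// applyValidators sets at most one conditional header, preferring the ETag.
func applyValidators(req *http.Request, v validators) {
	switch {
	case v.ETag != "":
		req.Header.Set("If-None-Match", v.ETag)
	case v.LastModified != "":
		req.Header.Set("If-Modified-Since", v.LastModified)
	}
}

// fetchIfChanged performs one GET. A 304 response is a successful poll that
// reports changed=false and keeps the previous validators.
func fetchIfChanged(client *http.Client, url string, v validators) (bool, validators, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return false, v, err
	}
	applyValidators(req, v)
	resp, err := client.Do(req)
	if err != nil {
		return false, v, err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusNotModified {
		return false, v, nil
	}
	if resp.StatusCode != http.StatusOK {
		return false, v, fmt.Errorf("unexpected status %s", resp.Status)
	}
	next := validators{
		ETag:         resp.Header.Get("ETag"),
		LastModified: resp.Header.Get("Last-Modified"),
	}
	return true, next, nil
}

// newTestServer serves a fixed body with ETag "v1" and answers 304 when the
// client presents that ETag.
func newTestServer() *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("If-None-Match") == `"v1"` {
			w.WriteHeader(http.StatusNotModified)
			return
		}
		w.Header().Set("ETag", `"v1"`)
		_, _ = w.Write([]byte(`{"ok":true}`))
	}))
}

func main() {
	srv := newTestServer()
	defer srv.Close()

	changed, v, err := fetchIfChanged(srv.Client(), srv.URL, validators{})
	fmt.Println(changed, v.ETag, err)

	changed, _, err = fetchIfChanged(srv.Client(), srv.URL, v)
	fmt.Println(changed, err)
}
```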

sources/helpers.go Normal file

@@ -0,0 +1,140 @@
package sources
import (
"fmt"
"sort"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// DefaultEventID applies feedkit's default Event.ID policy:
//
// - If upstream provides an ID, use it (trimmed).
// - Otherwise, ID is "<Source>:<EffectiveAt>" when EffectiveAt is non-nil and
// non-zero.
// - If EffectiveAt is unavailable, fall back to "<Source>:<EmittedAt>".
//
// An empty source name is replaced with "UNKNOWN_SOURCE", and a zero EmittedAt
// falls back to the current time. Timestamps are encoded as RFC3339Nano in UTC.
func DefaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
if id := strings.TrimSpace(upstreamID); id != "" {
return id
}
src := strings.TrimSpace(sourceName)
if src == "" {
src = "UNKNOWN_SOURCE"
}
if effectiveAt != nil && !effectiveAt.IsZero() {
return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
}
t := emittedAt.UTC()
if t.IsZero() {
t = time.Now().UTC()
}
return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
}
// SingleEvent constructs, validates, and returns a slice containing exactly one event.
// EmittedAt is normalized to UTC; a zero EmittedAt defaults to the current time.
func SingleEvent(
kind event.Kind,
sourceName string,
schema string,
id string,
emittedAt time.Time,
effectiveAt *time.Time,
payload any,
) ([]event.Event, error) {
if emittedAt.IsZero() {
emittedAt = time.Now().UTC()
} else {
emittedAt = emittedAt.UTC()
}
e := event.Event{
ID: id,
Kind: kind,
Source: sourceName,
EmittedAt: emittedAt,
EffectiveAt: effectiveAt,
Schema: schema,
Payload: payload,
}
if err := e.Validate(); err != nil {
return nil, err
}
return []event.Event{e}, nil
}
// ValidateExpectedKinds checks that the expected kinds configured for a source
// are a subset of the kinds advertised by the built source. The check is
// skipped when no expected kinds are configured or when the source does not
// expose kind metadata.
func ValidateExpectedKinds(cfg config.SourceConfig, in Input) error {
expectedKinds, err := parseExpectedKinds(cfg.ExpectedKinds())
if err != nil {
return err
}
if len(expectedKinds) == 0 {
return nil
}
advertisedKinds := advertisedSourceKinds(in)
if len(advertisedKinds) == 0 {
return nil
}
for kind := range expectedKinds {
if !advertisedKinds[kind] {
return fmt.Errorf(
"configured expected kind %q not advertised by source (configured=%v advertised=%v)",
kind,
sortedKinds(expectedKinds),
sortedKinds(advertisedKinds),
)
}
}
return nil
}
func parseExpectedKinds(raw []string) (map[event.Kind]bool, error) {
kinds := map[event.Kind]bool{}
for i, k := range raw {
kind, err := event.ParseKind(k)
if err != nil {
return nil, fmt.Errorf("invalid expected kind at index %d (%q): %w", i, k, err)
}
kinds[kind] = true
}
return kinds, nil
}
func advertisedSourceKinds(in Input) map[event.Kind]bool {
if in == nil {
return nil
}
kinds := map[event.Kind]bool{}
if ks, ok := in.(KindsSource); ok {
for _, kind := range ks.Kinds() {
kinds[kind] = true
}
return kinds
}
return nil
}
func sortedKinds(kindSet map[event.Kind]bool) []string {
out := make([]string, 0, len(kindSet))
for kind := range kindSet {
out = append(out, string(kind))
}
sort.Strings(out)
return out
}
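
To make the DefaultEventID policy concrete, here is a small sketch that reproduces the three cases (upstream ID, EffectiveAt, EmittedAt fallback). `defaultEventID` is a local copy of the logic shown in this file so the example runs without importing feedkit; `exampleIDs` and the sample timestamps are assumptions made for illustration.

```go
package main

import (
	"fmt"
	"strings"
	"time"
)

// defaultEventID is a local copy of the DefaultEventID policy from this diff.
func defaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
	if id := strings.TrimSpace(upstreamID); id != "" {
		return id
	}
	src := strings.TrimSpace(sourceName)
	if src == "" {
		src = "UNKNOWN_SOURCE"
	}
	if effectiveAt != nil && !effectiveAt.IsZero() {
		return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
	}
	t := emittedAt.UTC()
	if t.IsZero() {
		t = time.Now().UTC()
	}
	return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
}

// exampleIDs returns one ID per branch of the policy, using sample timestamps.
func exampleIDs() []string {
	effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 0, time.FixedZone("x", -6*3600))
	emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 0, time.UTC)
	return []string{
		defaultEventID(" upstream-id ", "source", &effectiveAt, emittedAt), // upstream ID wins
		defaultEventID("", "source", &effectiveAt, emittedAt),              // EffectiveAt, converted to UTC
		defaultEventID("", "source", nil, emittedAt),                       // EmittedAt fallback
	}
}

func main() {
	for _, id := range exampleIDs() {
		fmt.Println(id)
	}
	// Prints:
	// upstream-id
	// source:2026-03-28T22:04:05Z
	// source:2026-03-28T15:04:05Z
}
```

Note that the EffectiveAt value given in a -06:00 zone is rendered in UTC, so IDs stay stable regardless of the zone the upstream timestamp was parsed in.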

sources/helpers_test.go Normal file

@@ -0,0 +1,112 @@
package sources
import (
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testInput struct {
name string
}
func (s testInput) Name() string { return s.name }
type testKindsSource struct {
testInput
kinds []event.Kind
}
func (s testKindsSource) Kinds() []event.Kind { return s.kinds }
func TestValidateExpectedKindsSubsetAllowed(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"observation"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestValidateExpectedKindsMismatchFails(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
err := ValidateExpectedKinds(cfg, in)
if err == nil {
t.Fatalf("ValidateExpectedKinds() expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "configured expected kind") {
t.Fatalf("ValidateExpectedKinds() error %q does not include expected message", err)
}
}
func TestValidateExpectedKindsNoMetadataSkipsCheck(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testInput{name: "test"}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestDefaultEventIDUsesUpstreamID(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID(" upstream-id ", "source", nil, emittedAt)
if got != "upstream-id" {
t.Fatalf("DefaultEventID() = %q, want upstream-id", got)
}
}
func TestDefaultEventIDPrefersEffectiveAt(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 987654321, time.FixedZone("x", -6*3600))
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID("", "source", &effectiveAt, emittedAt)
want := "source:" + effectiveAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestDefaultEventIDFallsBackToEmittedAt(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123456789, time.FixedZone("y", 3*3600))
got := DefaultEventID("", "source", nil, emittedAt)
want := "source:" + emittedAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestSingleEventBuildsValidatedSlice(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 0, 0, 0, time.UTC)
emittedAt := time.Date(2026, 3, 28, 15, 0, 0, 0, time.FixedZone("z", -5*3600))
got, err := SingleEvent(
event.Kind("observation"),
"source-a",
"raw.example.v1",
"evt-1",
emittedAt,
&effectiveAt,
map[string]any{"ok": true},
)
if err != nil {
t.Fatalf("SingleEvent() unexpected error: %v", err)
}
if len(got) != 1 {
t.Fatalf("SingleEvent() len = %d, want 1", len(got))
}
if got[0].EmittedAt != emittedAt.UTC() {
t.Fatalf("SingleEvent() emittedAt = %s, want %s", got[0].EmittedAt, emittedAt.UTC())
}
}

sources/http.go Normal file

@@ -0,0 +1,152 @@
package sources
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/transport"
)
// HTTPSource is a reusable helper for polling HTTP-backed sources.
//
// It centralizes generic source config parsing (`params.url`,
// `params.user_agent`, and the optional `params.conditional`,
// `params.http_timeout`, and `params.http_response_body_limit_bytes`), default
// HTTP client setup, and conditional GET validator handling. Concrete daemon
// sources remain responsible for decoding the response body and constructing
// events.
type HTTPSource struct {
Driver string
Name string
URL string
UserAgent string
Accept string
Conditional bool
ResponseBodyLimitBytes int64
Client *http.Client
mu sync.Mutex
validators transport.HTTPValidators
}
// NewHTTPSource builds a generic HTTP polling helper from SourceConfig.
//
// Required params:
// - params.url
// - params.user_agent
//
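// Optional params:
// - params.conditional (default true): enable conditional GET using cached
// ETag / Last-Modified validators
// - params.http_timeout (default transport.DefaultHTTPTimeout): positive
// duration used as the HTTP client timeout
// - params.http_response_body_limit_bytes (default
// transport.DefaultHTTPResponseBodyLimitBytes): positive integer cap on the
// response body size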
func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTPSource, error) {
name := strings.TrimSpace(cfg.Name)
if name == "" {
return nil, fmt.Errorf("%s: name is required", driver)
}
if cfg.Params == nil {
return nil, fmt.Errorf("%s %q: params are required (need params.url and params.user_agent)", driver, cfg.Name)
}
url, ok := cfg.ParamString("url", "URL")
if !ok {
return nil, fmt.Errorf("%s %q: params.url is required", driver, cfg.Name)
}
userAgent, ok := cfg.ParamString("user_agent", "userAgent")
if !ok {
return nil, fmt.Errorf("%s %q: params.user_agent is required", driver, cfg.Name)
}
conditional := true
if _, exists := cfg.Params["conditional"]; exists {
var ok bool
conditional, ok = cfg.ParamBool("conditional")
if !ok {
return nil, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
}
}
timeout := transport.DefaultHTTPTimeout
if _, exists := cfg.Params["http_timeout"]; exists {
var ok bool
timeout, ok = cfg.ParamDuration("http_timeout")
if !ok || timeout <= 0 {
return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name)
}
}
bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes
if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists {
rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes")
if !ok || rawLimit <= 0 {
return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name)
}
bodyLimit = int64(rawLimit)
}
return &HTTPSource{
Driver: driver,
Name: name,
URL: url,
UserAgent: userAgent,
Accept: accept,
Conditional: conditional,
ResponseBodyLimitBytes: bodyLimit,
Client: transport.NewHTTPClient(timeout),
}, nil
}
// FetchBytesIfChanged fetches the configured URL and reports whether the
// upstream content changed. An unchanged 304 response returns changed=false
// with no body and no error.
func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
client := s.Client
if client == nil {
client = transport.NewHTTPClient(transport.DefaultHTTPTimeout)
}
s.mu.Lock()
validators := s.validators
s.mu.Unlock()
bodyLimit := s.ResponseBodyLimitBytes
if bodyLimit <= 0 {
bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes
}
body, changed, next, err := transport.FetchBodyIfChangedWithLimit(
ctx,
client,
s.URL,
s.UserAgent,
s.Accept,
s.Conditional,
validators,
bodyLimit,
)
if err != nil {
return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err)
}
if s.Conditional {
s.mu.Lock()
s.validators = next
s.mu.Unlock()
}
return body, changed, nil
}
// FetchJSONIfChanged fetches the configured URL and returns the raw response
// body as json.RawMessage when content changed. An unchanged 304 response
// returns changed=false with a nil body and no error.
func (s *HTTPSource) FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error) {
body, changed, err := s.FetchBytesIfChanged(ctx)
if err != nil || !changed {
return nil, changed, err
}
return json.RawMessage(body), true, nil
}
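
A downstream poll source built on this helper mostly needs to handle the changed=false case and decode the body. The sketch below assumes a hypothetical `observation` payload and depends only on a small `jsonFetcher` interface whose method matches the `FetchJSONIfChanged` signature in this file; `fakeFetcher`, `pollOnce`, and `demo` are illustrative names, not part of feedkit.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
)

// jsonFetcher captures the single method this sketch needs. Per the diff,
// *HTTPSource has a method with this signature.
type jsonFetcher interface {
	FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error)
}

// observation is a hypothetical payload shape used only for illustration.
type observation struct {
	Station string  `json:"station"`
	TempC   float64 `json:"temp_c"`
}

// pollOnce returns (nil, nil) when the upstream content is unchanged, so the
// caller can emit zero events for that poll.
func pollOnce(ctx context.Context, f jsonFetcher) (*observation, error) {
	raw, changed, err := f.FetchJSONIfChanged(ctx)
	if err != nil {
		return nil, err
	}
	if !changed {
		return nil, nil
	}
	var obs observation
	if err := json.Unmarshal(raw, &obs); err != nil {
		return nil, fmt.Errorf("decode observation: %w", err)
	}
	return &obs, nil
}

// fakeFetcher returns its body on the first call and "unchanged" afterwards,
// imitating a 200 followed by a 304.
type fakeFetcher struct {
	body  string
	calls int
}

func (f *fakeFetcher) FetchJSONIfChanged(context.Context) (json.RawMessage, bool, error) {
	f.calls++
	if f.calls > 1 {
		return nil, false, nil
	}
	return json.RawMessage(f.body), true, nil
}

// demo polls twice and returns both results.
func demo() (first, second *observation, err error) {
	f := &fakeFetcher{body: `{"station":"KSTL","temp_c":21.5}`}
	if first, err = pollOnce(context.Background(), f); err != nil {
		return nil, nil, err
	}
	second, err = pollOnce(context.Background(), f)
	return first, second, err
}

func main() {
	first, second, err := demo()
	fmt.Println(first != nil, second != nil, err)
}
```

Depending on an interface rather than the concrete helper keeps the decoding logic testable without an HTTP server.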

sources/http_test.go Normal file

@@ -0,0 +1,260 @@
package sources
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/transport"
)
func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if !src.Conditional {
t.Fatalf("Conditional = false, want true")
}
}
func TestNewHTTPSourceRejectsInvalidConditional(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"conditional": "sometimes",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.conditional must be a boolean") {
t.Fatalf("NewHTTPSource() error = %q, want params.conditional must be a boolean", err)
}
}
func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"conditional": false,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Conditional {
t.Fatalf("Conditional = true, want false")
}
}
func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "250ms",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != 250*time.Millisecond {
t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout)
}
}
func TestNewHTTPSourceBodyLimitOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": 12345,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != 12345 {
t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "soon",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": "abc",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("ETag", `"v1"`)
_, _ = w.Write([]byte(`{"ok":true}`))
case 2:
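if got := r.Header.Get("If-None-Match"); got != `"v1"` {
// Use Errorf here: Fatalf must only be called from the test goroutine,
// and this handler runs on the server's goroutine.
t.Errorf("second request If-None-Match = %q", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Errorf("unexpected call count %d", call)
}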
}))
defer srv.Close()
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": srv.URL,
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
raw, changed, err := src.FetchJSONIfChanged(context.Background())
if err != nil {
t.Fatalf("first FetchJSONIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchJSONIfChanged() changed = false, want true")
}
if got := string(raw); got != `{"ok":true}` {
t.Fatalf("first FetchJSONIfChanged() body = %q", got)
}
raw, changed, err = src.FetchJSONIfChanged(context.Background())
if err != nil {
t.Fatalf("second FetchJSONIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchJSONIfChanged() changed = true, want false")
}
if raw != nil {
t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw))
}
}
func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": srv.URL,
"user_agent": "test-agent",
"http_response_body_limit_bytes": 4,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
_, _, err = src.FetchJSONIfChanged(context.Background())
if err == nil {
t.Fatalf("FetchJSONIfChanged() error = nil, want limit error")
}
if !strings.Contains(err.Error(), "response body too large") {
t.Fatalf("FetchJSONIfChanged() error = %q", err)
}
}
func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes {
t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != transport.DefaultHTTPTimeout {
t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout)
}
}

sources/postgres.go Normal file

@@ -0,0 +1,117 @@
package sources
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
const defaultPostgresQueryTimeout = 30 * time.Second
type postgresQueryDB interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
return pgconn.Open(ctx, cfg)
}
// PostgresQuerySource is a reusable helper for polling Postgres-backed sources.
//
// It centralizes generic source config parsing and query execution. Concrete
// daemon sources remain responsible for SQL semantics, row scanning, cursoring,
// and event construction.
type PostgresQuerySource struct {
Driver string
Name string
SQL string
QueryTimeout time.Duration
db postgresQueryDB
}
// NewPostgresQuerySource builds a generic Postgres polling helper from
// SourceConfig.
//
// Required params:
// - params.uri
// - params.username
// - params.password
// - params.query
//
// Optional params:
// - params.query_timeout (default 30s)
func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) {
name := strings.TrimSpace(cfg.Name)
if name == "" {
return nil, fmt.Errorf("%s: name is required", driver)
}
if cfg.Params == nil {
return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name)
}
uri, ok := cfg.ParamString("uri")
if !ok {
return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name)
}
username, ok := cfg.ParamString("username")
if !ok {
return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name)
}
password, ok := cfg.ParamString("password")
if !ok {
return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name)
}
query, ok := cfg.ParamString("query")
if !ok {
return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name)
}
queryTimeout := defaultPostgresQueryTimeout
if _, exists := cfg.Params["query_timeout"]; exists {
var ok bool
queryTimeout, ok = cfg.ParamDuration("query_timeout")
if !ok || queryTimeout <= 0 {
return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name)
}
}
db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err)
}
return &PostgresQuerySource{
Driver: driver,
Name: name,
SQL: query,
QueryTimeout: queryTimeout,
db: db,
}, nil
}
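// Query runs the configured SQL with args. When QueryTimeout is positive and
// the caller's context has no deadline, or a deadline later than QueryTimeout,
// the query runs under a derived context bounded by QueryTimeout; otherwise the
// caller's context is used as-is. The caller must close the returned rows.
func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) {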
queryCtx := ctx
if s.QueryTimeout > 0 {
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout {
// We intentionally do not cancel this derived context here because the
// returned rows may still be reading from the database.
queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout)
}
}
rows, err := s.db.QueryContext(queryCtx, s.SQL, args...)
if err != nil {
return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err)
}
return rows, nil
}
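
The deadline selection in Query (keep the caller's context when its deadline is already earlier than QueryTimeout, otherwise derive a bounded one) can be shown on its own. The sketch below is a variation, not the commit's code: the hypothetical `withQueryTimeout` returns the cancel function so a caller could release the timer after closing the rows, whereas Query above deliberately discards it.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// withQueryTimeout returns ctx unchanged when timeout is not positive or when
// ctx already expires within timeout. Otherwise it derives a context bounded
// by timeout. The returned cancel function is always safe to call.
func withQueryTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	if timeout <= 0 {
		return ctx, func() {}
	}
	if deadline, ok := ctx.Deadline(); ok && time.Until(deadline) <= timeout {
		return ctx, func() {}
	}
	return context.WithTimeout(ctx, timeout)
}

// demo reports whether a deadline was derived for a background context, and
// whether a caller context with an earlier deadline was reused unchanged.
func demo() (derivedHasDeadline, reusedCaller bool) {
	queryCtx, cancel := withQueryTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_, derivedHasDeadline = queryCtx.Deadline()

	caller, cancelCaller := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancelCaller()
	queryCtx2, cancel2 := withQueryTimeout(caller, 30*time.Second)
	defer cancel2()

	return derivedHasDeadline, queryCtx2 == caller
}

func main() {
	fmt.Println(demo()) // prints "true true"
}
```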

sources/postgres_test.go Normal file

@@ -0,0 +1,352 @@
package sources
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakePostgresQueryDB struct {
queryErr error
lastCtx context.Context
lastQuery string
lastArgs []any
returnRows *sql.Rows
}
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
db.lastCtx = ctx
db.lastQuery = query
db.lastArgs = append([]any(nil), args...)
if db.queryErr != nil {
return nil, db.queryErr
}
return db.returnRows, nil
}
func withPostgresQuerySourceTestState(t *testing.T) {
t.Helper()
oldOpen := openPostgresQueryDB
t.Cleanup(func() {
openPostgresQueryDB = oldOpen
})
}
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
withPostgresQuerySourceTestState(t)
tests := []struct {
name string
params map[string]any
want string
}{
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: tc.params,
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
"query_timeout": "soon",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
withPostgresQuerySourceTestState(t)
db := &fakePostgresQueryDB{}
var gotCfg pgconn.ConnConfig
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
gotCfg = cfg
return db, nil
}
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://db.example.local/feedkit",
"username": "app_user",
"password": "app_pass",
"query": "SELECT * FROM observations",
"query_timeout": "45s",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
if src.Name != "pg-source" {
t.Fatalf("Name = %q, want pg-source", src.Name)
}
if src.QueryTimeout != 45*time.Second {
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
}
if src.SQL != "SELECT * FROM observations" {
t.Fatalf("SQL = %q", src.SQL)
}
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
t.Fatalf("ConnConfig = %+v", gotCfg)
}
}
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
withPostgresQuerySourceTestState(t)
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return nil, errors.New("db unavailable")
}
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx := context.Background()
_, err := src.Query(ctx, "arg1")
if err == nil {
t.Fatalf("Query() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
t.Fatalf("Query() error = %q", err)
}
if db.lastCtx == nil {
t.Fatalf("lastCtx = nil")
}
if _, ok := db.lastCtx.Deadline(); !ok {
t.Fatalf("expected derived deadline on query context")
}
if db.lastQuery != "SELECT 1" {
t.Fatalf("lastQuery = %q", db.lastQuery)
}
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
t.Fatalf("lastArgs = %#v", db.lastArgs)
}
}
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, _ = src.Query(ctx)
if db.lastCtx != ctx {
t.Fatalf("expected source to reuse earlier caller deadline")
}
}
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
withPostgresQuerySourceTestState(t)
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
defer cleanup()
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return db, nil
}
type fakeDownstreamSource struct {
pg *PostgresQuerySource
}
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
rows, err := s.pg.Query(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
var out []event.Event
for rows.Next() {
var eventID string
if err := rows.Scan(&eventID); err != nil {
return nil, err
}
out = append(out, event.Event{
ID: eventID,
Kind: event.Kind("observation"),
Source: s.pg.Name,
Schema: "raw.test.v1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"event_id": eventID},
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT event_id FROM events",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
if err != nil {
t.Fatalf("poll() error = %v", err)
}
if len(events) != 1 {
t.Fatalf("len(events) = %d, want 1", len(events))
}
if events[0].ID != "evt-1" {
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
}
}
var (
rowsDriverMu sync.Mutex
rowsDriverSeen = map[string]bool{}
)
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
t.Helper()
rowsDriverMu.Lock()
if !rowsDriverSeen[driverName] {
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
rowsDriverSeen[driverName] = true
}
rowsDriverMu.Unlock()
db, err := sql.Open(driverName, "")
if err != nil {
t.Fatalf("sql.Open() error = %v", err)
}
return db, func() {
_ = db.Close()
}
}
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
out := make([][]driver.Value, 0, len(in))
for _, row := range in {
copied := append([]driver.Value(nil), row...)
out = append(out, copied)
}
return out
}
type rowsTestDriver struct {
columns []string
rows [][]driver.Value
}
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
}
type rowsTestConn struct {
columns []string
rows [][]driver.Value
}
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *rowsTestConn) Close() error { return nil }
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
}
type rowsTestRows struct {
columns []string
rows [][]driver.Value
idx int
}
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
func (r *rowsTestRows) Close() error { return nil }
func (r *rowsTestRows) Next(dest []driver.Value) error {
if r.idx >= len(r.rows) {
return io.EOF
}
copy(dest, r.rows[r.idx])
r.idx++
return nil
}


@@ -15,9 +15,6 @@ import (
type PollFactory func(cfg config.SourceConfig) (PollSource, error)
type StreamFactory func(cfg config.SourceConfig) (StreamSource, error)
-// Factory is the legacy alias for poll source factories.
-type Factory = PollFactory
type Registry struct {
byPollDriver map[string]PollFactory
byStreamDriver map[string]StreamFactory
@@ -30,13 +27,6 @@ func NewRegistry() *Registry {
}
}
-// Register associates a driver name (e.g. "openmeteo_observation") with a factory.
-//
-// The driver string is the "lookup key" used by config.sources[].driver.
-func (r *Registry) Register(driver string, f PollFactory) {
-r.RegisterPoll(driver, f)
-}
// RegisterPoll associates a driver name with a polling-source factory.
func (r *Registry) RegisterPoll(driver string, f PollFactory) {
driver = strings.TrimSpace(driver)
@@ -75,11 +65,6 @@ func (r *Registry) RegisterStream(driver string, f StreamFactory) {
r.byStreamDriver[driver] = f
}
-// Build constructs a polling source from a SourceConfig by looking up cfg.Driver.
-func (r *Registry) Build(cfg config.SourceConfig) (PollSource, error) {
-return r.BuildPoll(cfg)
-}
// BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver.
func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) {
driver := strings.TrimSpace(cfg.Driver)


@@ -31,24 +31,18 @@ type PollSource interface {
Poll(ctx context.Context) ([]event.Event, error)
}
-// Source is a compatibility alias for the legacy polling-source name.
-type Source = PollSource
// StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc).
//
// Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs.
// It MUST NOT close out (the scheduler/daemon owns the bus).
+//
+// Stream sources can classify exits by wrapping errors with StreamRetryable or
+// StreamFatal. Plain non-nil errors are treated as retryable by the scheduler.
type StreamSource interface {
Input
Run(ctx context.Context, out chan<- event.Event) error
}
-// KindSource is an optional interface for sources that advertise one "primary" kind.
-// This is legacy-friendly but no longer required.
-type KindSource interface {
-Kind() event.Kind
-}
// KindsSource is an optional interface for sources that advertise multiple kinds.
type KindsSource interface {
Kinds() []event.Kind

63
sources/stream_errors.go Normal file

@@ -0,0 +1,63 @@
package sources
import "errors"
type streamRetryableError struct {
err error
}
func (e *streamRetryableError) Error() string {
if e.err == nil {
return "retryable stream error"
}
return e.err.Error()
}
func (e *streamRetryableError) Unwrap() error { return e.err }
type streamFatalError struct {
err error
}
func (e *streamFatalError) Error() string {
if e.err == nil {
return "fatal stream error"
}
return e.err.Error()
}
func (e *streamFatalError) Unwrap() error { return e.err }
// StreamRetryable marks a stream-source exit as retryable.
func StreamRetryable(err error) error {
if err == nil {
return nil
}
return &streamRetryableError{err: err}
}
// StreamFatal marks a stream-source exit as fatal.
func StreamFatal(err error) error {
if err == nil {
return nil
}
return &streamFatalError{err: err}
}
// IsStreamRetryable reports whether err contains a retryable stream marker.
func IsStreamRetryable(err error) bool {
if err == nil {
return false
}
var target *streamRetryableError
return errors.As(err, &target)
}
// IsStreamFatal reports whether err contains a fatal stream marker.
func IsStreamFatal(err error) bool {
if err == nil {
return false
}
var target *streamFatalError
return errors.As(err, &target)
}


@@ -0,0 +1,52 @@
package sources
import (
"errors"
"fmt"
"testing"
)
func TestStreamRetryableWrapsThroughErrorChains(t *testing.T) {
base := errors.New("retry me")
err := fmt.Errorf("outer: %w", StreamRetryable(base))
if !IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = false, want true")
}
if IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamFatalWrapsThroughErrorChains(t *testing.T) {
base := errors.New("fatal")
err := fmt.Errorf("outer: %w", StreamFatal(base))
if !IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = false, want true")
}
if IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamErrorHelpersNil(t *testing.T) {
if StreamRetryable(nil) != nil {
t.Fatalf("StreamRetryable(nil) != nil")
}
if StreamFatal(nil) != nil {
t.Fatalf("StreamFatal(nil) != nil")
}
if IsStreamRetryable(nil) {
t.Fatalf("IsStreamRetryable(nil) = true")
}
if IsStreamFatal(nil) {
t.Fatalf("IsStreamFatal(nil) = true")
}
}


@@ -6,13 +6,14 @@ import (
"fmt"
"io"
"net/http"
+"strings"
"time"
)
-// maxResponseBodyBytes is a hard safety limit on HTTP response bodies.
+// DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies.
// API responses should be small, so this protects us from accidental
// or malicious large responses.
-const maxResponseBodyBytes = 2 << 21 // 4 MiB
+const DefaultHTTPResponseBodyLimitBytes int64 = 2 << 21 // 4 MiB
// DefaultHTTPTimeout is the standard timeout used by HTTP sources.
// Individual drivers may override this if they have a specific need.
@@ -28,7 +29,95 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
}
func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) {
-req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes)
+}
+func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) {
+res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "")
+if err != nil {
+return nil, err
+}
+defer res.Body.Close()
+if res.StatusCode < 200 || res.StatusCode >= 300 {
+return nil, fmt.Errorf("HTTP %s", res.Status)
+}
+return readValidatedBody(res.Body, bodyLimitBytes)
+}
+// HTTPValidators are cache validators learned from prior successful GET responses.
+//
+// ETag is preferred when present. LastModified is used as a fallback validator
+// when ETag is unavailable.
+type HTTPValidators struct {
+ETag string
+LastModified string
+}
+// FetchBodyIfChanged performs an HTTP GET and opportunistically uses conditional
+// request headers based on the provided validators.
+//
+// Behavior:
+// - if conditional is false, this behaves like a normal GET and leaves validators unchanged
+// - if validators.ETag is set, sends If-None-Match
+// - else if validators.LastModified is set, sends If-Modified-Since
+// - 304 Not Modified is treated as success with changed=false and no body
+// - 200 responses are treated as changed=true and still enforce the normal body checks
+//
+// Returned validators reflect any updates learned from the response headers.
+func FetchBodyIfChanged(
+ctx context.Context,
+client *http.Client,
+url, userAgent, accept string,
+conditional bool,
+validators HTTPValidators,
+) ([]byte, bool, HTTPValidators, error) {
+return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes)
+}
+func FetchBodyIfChangedWithLimit(
+ctx context.Context,
+client *http.Client,
+url, userAgent, accept string,
+conditional bool,
+validators HTTPValidators,
+bodyLimitBytes int64,
+) ([]byte, bool, HTTPValidators, error) {
+headerName, headerValue := conditionalHeader(conditional, validators)
+res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, headerName, headerValue)
+if err != nil {
+return nil, false, validators, err
+}
+defer res.Body.Close()
+switch res.StatusCode {
+case http.StatusNotModified:
+if conditional {
+validators = refreshValidators(validators, res.Header)
+}
+return nil, false, validators, nil
+default:
+if res.StatusCode < 200 || res.StatusCode >= 300 {
+return nil, false, validators, fmt.Errorf("HTTP %s", res.Status)
+}
+}
+b, err := readValidatedBody(res.Body, bodyLimitBytes)
+if err != nil {
+return nil, false, validators, err
+}
+if conditional {
+validators = replaceValidators(res.Header)
+}
+return b, true, validators, nil
+}
+func doRequest(ctx context.Context, client *http.Client, method, url, userAgent, accept, headerName, headerValue string) (*http.Response, error) {
+req, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}
@@ -39,19 +128,50 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept
if accept != "" {
req.Header.Set("Accept", accept)
}
-res, err := client.Do(req)
-if err != nil {
-return nil, err
-}
-defer res.Body.Close()
-if res.StatusCode < 200 || res.StatusCode >= 300 {
-return nil, fmt.Errorf("HTTP %s", res.Status)
+if headerName != "" && headerValue != "" {
+req.Header.Set(headerName, headerValue)
}
-// Read at most maxResponseBodyBytes + 1 so we can detect overflow.
-limited := io.LimitReader(res.Body, maxResponseBodyBytes+1)
+return client.Do(req)
+}
+func conditionalHeader(enabled bool, validators HTTPValidators) (string, string) {
+if !enabled {
+return "", ""
+}
+if etag := strings.TrimSpace(validators.ETag); etag != "" {
+return "If-None-Match", etag
+}
+if lastModified := strings.TrimSpace(validators.LastModified); lastModified != "" {
+return "If-Modified-Since", lastModified
+}
+return "", ""
+}
+func replaceValidators(header http.Header) HTTPValidators {
+return HTTPValidators{
+ETag: strings.TrimSpace(header.Get("ETag")),
+LastModified: strings.TrimSpace(header.Get("Last-Modified")),
+}
+}
+func refreshValidators(current HTTPValidators, header http.Header) HTTPValidators {
+if etag := strings.TrimSpace(header.Get("ETag")); etag != "" {
+current.ETag = etag
+}
+if lastModified := strings.TrimSpace(header.Get("Last-Modified")); lastModified != "" {
+current.LastModified = lastModified
+}
+return current
+}
+func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) {
+if bodyLimitBytes <= 0 {
+bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes
+}
+// Read at most bodyLimitBytes + 1 so we can detect overflow.
+limited := io.LimitReader(r, bodyLimitBytes+1)
b, err := io.ReadAll(limited)
if err != nil {
@@ -62,8 +182,8 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept
return nil, fmt.Errorf("empty response body")
}
-if len(b) > maxResponseBodyBytes {
-return nil, fmt.Errorf("response body too large (>%d bytes)", maxResponseBodyBytes)
+if int64(len(b)) > bodyLimitBytes {
+return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes)
}
return b, nil

232
transport/http_test.go Normal file

@@ -0,0 +1,232 @@
package transport
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestFetchBodyIfChangedPrefersETagAndTreats304AsUnchanged(t *testing.T) {
t.Helper()
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
// Use t.Errorf in the handler: it runs on the server's goroutine, where t.Fatalf must not be called.
if got := r.Header.Get("If-None-Match"); got != "" {
t.Errorf("first request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Errorf("first request If-Modified-Since = %q, want empty", got)
}
w.Header().Set("ETag", `"v1"`)
w.Header().Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT")
_, _ = w.Write([]byte(`{"ok":true}`))
case 2:
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
t.Errorf("second request If-None-Match = %q, want %q", got, `"v1"`)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Errorf("second request If-Modified-Since = %q, want empty when ETag is cached", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Errorf("unexpected call count %d", call)
}
}))
defer srv.Close()
validators := HTTPValidators{}
body, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, validators)
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
}
if got := string(body); got != `{"ok":true}` {
t.Fatalf("first FetchBodyIfChanged() body = %q", got)
}
if got := next.ETag; got != `"v1"` {
t.Fatalf("cached ETag = %q, want %q", got, `"v1"`)
}
if got := next.LastModified; got != "Mon, 02 Jan 2006 15:04:05 GMT" {
t.Fatalf("cached Last-Modified = %q", got)
}
body, changed, next, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, next)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
}
if body != nil {
t.Fatalf("second FetchBodyIfChanged() body = %q, want nil", string(body))
}
if got := next.ETag; got != `"v1"` {
t.Fatalf("cached ETag after 304 = %q, want %q", got, `"v1"`)
}
}
func TestFetchBodyIfChangedFallsBackToIfModifiedSince(t *testing.T) {
t.Helper()
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("Last-Modified", "Tue, 03 Jan 2006 15:04:05 GMT")
_, _ = w.Write([]byte(`first`))
case 2:
// Use t.Errorf in the handler: it runs on the server's goroutine, where t.Fatalf must not be called.
if got := r.Header.Get("If-None-Match"); got != "" {
t.Errorf("second request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "Tue, 03 Jan 2006 15:04:05 GMT" {
t.Errorf("second request If-Modified-Since = %q", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Errorf("unexpected call count %d", call)
}
}))
defer srv.Close()
_, changed, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
}
if got := validators.LastModified; got != "Tue, 03 Jan 2006 15:04:05 GMT" {
t.Fatalf("cached Last-Modified = %q", got)
}
_, changed, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
}
}
func TestFetchBodyIfChangedClearsValidatorsOn200WithoutValidators(t *testing.T) {
t.Helper()
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("ETag", `"v1"`)
_, _ = w.Write([]byte(`first`))
case 2:
// Use t.Errorf in the handler: it runs on the server's goroutine, where t.Fatalf must not be called.
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
t.Errorf("second request If-None-Match = %q", got)
}
_, _ = w.Write([]byte(`second`))
case 3:
if got := r.Header.Get("If-None-Match"); got != "" {
t.Errorf("third request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Errorf("third request If-Modified-Since = %q, want empty", got)
}
_, _ = w.Write([]byte(`third`))
default:
t.Errorf("unexpected call count %d", call)
}
}))
defer srv.Close()
_, _, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
_, _, validators, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if validators.ETag != "" || validators.LastModified != "" {
t.Fatalf("validators after 200 without validators = %+v, want cleared", validators)
}
_, _, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("third FetchBodyIfChanged() error = %v", err)
}
}
func TestFetchBodyIfChangedConditionalDisabledSkipsConditionalHeaders(t *testing.T) {
t.Helper()
var calls int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
// Use t.Errorf in the handler: it runs on the server's goroutine, where t.Fatalf must not be called.
if got := r.Header.Get("If-None-Match"); got != "" {
t.Errorf("request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Errorf("request If-Modified-Since = %q, want empty", got)
}
_, _ = w.Write([]byte(`body`))
}))
defer srv.Close()
validators := HTTPValidators{ETag: `"v1"`, LastModified: "Wed, 04 Jan 2006 15:04:05 GMT"}
_, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", false, validators)
if err != nil {
t.Fatalf("FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("FetchBodyIfChanged() changed = false, want true")
}
if next != validators {
t.Fatalf("validators changed when conditional disabled: got %+v want %+v", next, validators)
}
if calls != 1 {
t.Fatalf("calls = %d, want 1", calls)
}
}
func TestFetchBodyIfChangedAllowsEmpty304ButRejectsEmpty200(t *testing.T) {
t.Helper()
notModified := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotModified)
}))
defer notModified.Close()
_, changed, _, err := FetchBodyIfChanged(
context.Background(),
notModified.Client(),
notModified.URL,
"test-agent",
"",
true,
HTTPValidators{ETag: `"v1"`},
)
if err != nil {
t.Fatalf("304 FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("304 FetchBodyIfChanged() changed = true, want false")
}
emptyBody := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer emptyBody.Close()
_, _, _, err = FetchBodyIfChanged(context.Background(), emptyBody.Client(), emptyBody.URL, "test-agent", "", true, HTTPValidators{})
if err == nil {
t.Fatalf("empty 200 FetchBodyIfChanged() error = nil, want error")
}
if err.Error() != "empty response body" {
t.Fatalf("empty 200 FetchBodyIfChanged() error = %q", err)
}
}