18 Commits

Author SHA1 Message Date
5c1b28ee0a Added support for Postgres polling sources 2026-03-29 10:53:13 -05:00
247937b65e Upgraded feedkit's handling of stream sources 2026-03-29 08:34:35 -05:00
eb9a7cb349 Refactor feedkit boundaries ahead of v1
Remove global Postgres schema registration in favor of explicit schema-aware sink factory wiring, and update weatherfeeder to register the Postgres sink explicitly. Add optional per-source HTTP timeout and response body limit overrides while keeping feedkit defaults. Remove remaining legacy source/config compatibility surfaces, including singular kind support and old source registry/type aliases, and migrate weatherfeeder sources to plural `Kinds()` metadata. Clean up related docs, tests, and sample config to match the new Postgres, HTTP, and NATS configuration model.
2026-03-28 13:52:48 -05:00
3281368922 Cleaned up documentation and removed stubs and TODOs throughout the application 2026-03-28 13:02:37 -05:00
3ef93faf69 Moved broadly useful helper functions upstream into feedkit 2026-03-28 11:29:09 -05:00
4910440756 Moved common HTTP polling helpers into feedkit and implemented support for ETag and Last-Modified 2026-03-28 09:59:58 -05:00
3b92c2284d Added automatic pruning for configured Postgres sinks 2026-03-28 08:04:15 -05:00
215afe1acf Added a dedupe processor, and moved processor packages under processors/* 2026-03-16 18:17:53 -05:00
4572c53580 Updated sinks to add a functional postgres sink API. 2026-03-16 14:54:57 -05:00
96039f6530 refactor!: introduce generic processors registry and remove normalize registry adapter
- add new `processors` package with canonical `Processor` interface
- add `processors.Registry` with Register/Build/BuildChain factory model
- switch `pipeline.Pipeline` to `[]processors.Processor`
- replace `normalize.Registry` + registry adapter with direct `normalize.Processor`
- remove `normalize/registry.go`
- update root docs to position normalize as one optional processing stage
- add tests for processors registry, normalize processor behavior, and pipeline flow

BREAKING CHANGE:
- `pipeline.Processor` removed; use `processors.Processor`
- `normalize.Registry` and old normalize processor adapter APIs removed
- downstream daemons must update processor wiring to new `processors.Registry`
  and `normalize.NewProcessor(...)`
2026-03-16 13:14:24 -05:00
6c5f95ad26 feat(sources)!: split source contracts into PollSource/StreamSource and add mode-aware source config
- Introduce explicit source interfaces: sources.PollSource and sources.StreamSource, with shared sources.Input (Name() only).
- Remove mandatory Kind() from the base source contract to support sources that emit multiple kinds.
- Add config.SourceMode (poll, stream, or omitted/auto) and SourceConfig.Kinds (plural expected kinds), while keeping legacy SourceConfig.Kind for compatibility.
- Enforce mode semantics in config validation (poll requires every, stream forbids every) and detect mode/driver mismatches in sources.Registry.
- Update docs and tests for the new source model and config behavior.
2026-03-15 19:19:19 -05:00
fafba0f01b Refactored the scheduler and source interfaces to accommodate both polling (e.g., HTTP) sources and streaming (e.g., message queue) sources. 2026-02-08 15:03:46 -06:00
3c95fa97cd Updated the README to reflect recent updates to default sinks. 2026-02-07 19:45:26 -06:00
dbca0548b1 Implemented a builtin NATS sink. 2026-02-07 11:35:55 -06:00
9b2c1e5ceb transport: Moved transport/http.go upstream to feedkit (was previously in weatherfeeder). 2026-01-15 19:08:28 -06:00
1d43adcfa0 dispatch: allow empty route kinds (match all) + add routing tests
- config: permit routes[].kinds to be omitted/empty; treat as "all kinds"
- dispatch: compile empty kinds to Route{Kinds:nil} (match all kinds)
- tests: add coverage for route compilation + config validation edge cases

Files:
- config/load.go
- config/config.go
- dispatch/routes.go
- config/validate_test.go
- dispatch/routes_test.go
2026-01-15 18:26:45 -06:00
a6c133319a feat(feedkit): add optional normalization hook and document external API
Introduce an optional normalization stage for feedkit pipelines via the new
normalize package. This adds:

- normalize.Normalizer interface with flexible Match() semantics
- normalize.Registry for ordered normalizer selection (first match wins)
- normalize.Processor adapter implementing pipeline.Processor
- Pass-through behavior when no normalizer matches (normalization is optional)
- Func helper for ergonomic normalizer definitions

Update root doc.go to fully document the normalization model, its role in the
pipeline, recommended conventions (Schema-based matching, raw vs normalized
events), and concrete wiring examples. The documentation now serves as a
complete external-facing API specification for downstream daemons such as
weatherfeeder.

This change preserves feedkit’s non-framework philosophy while enabling a
clean separation between data collection and domain normalization.
2026-01-13 18:23:43 -06:00
09bc65e947 feedkit: ergonomics pass (shared logger, route compiler, param helpers)
- Add logging.Logf as the canonical printf-style logger type used across feedkit.
  - Update scheduler and dispatch to alias their Logger types to logging.Logf.
  - Eliminates type-mismatch friction when wiring one log function through the system.

- Add dispatch.CompileRoutes(*config.Config) ([]dispatch.Route, error)
  - Compiles config routes into dispatch routes with event.ParseKind normalization.
  - If routes: is omitted, defaults to “all sinks receive all kinds”.

- Expand config param helpers for both SourceConfig and SinkConfig
  - Add ParamBool/ParamInt/ParamDuration/ParamStringSlice (+ Default variants).
  - Supports common YAML-decoded types (bool/int/float/string, []any, etc.)
  - Keeps driver code cleaner and reduces repeated type assertions.

- Fix Postgres sink validation error prefix ("postgres sink", not "rabbitmq sink").
2026-01-13 14:40:29 -06:00
64 changed files with 6997 additions and 266 deletions


@@ -1,3 +1,92 @@
# feedkit
`feedkit` is a small Go toolkit for building feed-processing daemons.
It gives you the reusable plumbing around collection, processing, routing, and
emission, while leaving domain concepts, schemas, and application wiring in
your daemon. The intended shape is a family of sibling applications such as
`weatherfeeder`, `newsfeeder`, or `earthquakefeeder` that all share the same
infrastructure patterns without sharing domain logic.
## What It Does
A daemon built on `feedkit` typically:
- ingests upstream input by polling HTTP APIs or consuming streams
- emits domain-agnostic `event.Event` values
- optionally processes those events with stages like dedupe or normalization
- routes events to one or more sinks such as stdout, NATS, or Postgres
Conceptually, the pipeline is:
`Collect -> Process -> Route -> Emit`
## Philosophy
`feedkit` is intentionally not a framework.
It does not try to own:
- your domain payload schemas
- your domain event kinds
- your daemon lifecycle or `main.go`
- your observability stack or deployment model
Instead, it provides small composable packages that are easy to wire together in
different daemons.
## When To Use It
`feedkit` is a good fit when you want:
- multiple small ingestion daemons with shared infrastructure patterns
- clear separation between raw upstream payloads and normalized canonical models
- reusable routing and sink behavior across domains
- strong config and event-envelope conventions without centralizing domain rules
It is a poor fit if you want a monolithic framework that dictates application
structure end-to-end.
## Built-In Capabilities
`feedkit` currently includes:
- strict YAML config loading and validation
- polling and streaming source abstractions
- scheduler orchestration for configured sources and supervised stream workers
- optional pipeline processors
- built-in dedupe and normalization processors
- route compilation and sink fanout
- built-in sinks for `stdout`, `nats`, and `postgres`
The Postgres sink is intentionally split between feedkit-owned infrastructure
and daemon-owned schema mapping. `feedkit` manages connection setup, DDL,
writes, and pruning; downstream applications define the schema and event mapper.
## Typical Wiring
At a high level, a daemon built on `feedkit` does this:
1. Load config.
2. Register domain-specific source drivers.
3. Register built-in and/or custom sinks.
4. Build sources, sinks, and optional processor chain from config.
5. Compile routes.
6. Start the scheduler and dispatcher.
The package docs are the better source of truth for code-level details. In
particular, each subpackage `doc.go` describes its external API surface and any
optional helper APIs in `helpers.go`.
## Package Layout
The major packages are:
- `config`: config loading and validation
- `event`: the domain-agnostic event envelope
- `sources`: source interfaces and reusable source helpers
- `scheduler`: source execution and cadence management
- `processors`: processor interfaces and registry
- `processors/dedupe`: built-in in-memory dedupe processor
- `processors/normalize`: built-in normalization processor and helpers
- `pipeline`: optional processor chain
- `dispatch`: route compilation and fanout
- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers
The root package docs in `doc.go` provide a concise package-by-package map for
Go documentation consumers.


@@ -21,31 +21,79 @@ type Config struct {
Routes []RouteConfig `yaml:"routes"`
}
// SourceMode selects how a source receives upstream input.
//
// Empty mode means "auto": feedkit infers mode from the registered driver type.
type SourceMode string
const (
SourceModeAuto SourceMode = ""
SourceModePoll SourceMode = "poll"
SourceModeStream SourceMode = "stream"
)
// Normalize lowercases and trims the mode.
func (m SourceMode) Normalize() SourceMode {
switch strings.ToLower(strings.TrimSpace(string(m))) {
case "":
return SourceModeAuto
case string(SourceModePoll):
return SourceModePoll
case string(SourceModeStream):
return SourceModeStream
default:
return SourceMode(strings.ToLower(strings.TrimSpace(string(m))))
}
}
// SourceConfig describes one input source.
//
// This is intentionally generic:
// - driver-specific knobs belong in Params.
// - "kind" is allowed (useful for safety checks / routing), but feedkit does not
// restrict the allowed values.
// - mode controls polling vs streaming behavior.
// - expected emitted kinds are optional and domain-defined.
type SourceConfig struct {
Name string `yaml:"name"`
Driver string `yaml:"driver"` // e.g. "openmeteo_observation", "rss_feed", etc.
// Mode is optional:
// - "poll": Every must be set (>0)
// - "stream": Every must be omitted/zero
// - empty: infer from driver registration type (poll vs stream)
Mode SourceMode `yaml:"mode"`
// Kind is optional and domain-defined. If set, it should be a non-empty string.
// Domains commonly use it to enforce "this source should only emit kind X".
Kind string `yaml:"kind"`
// Every is the poll cadence for poll-mode sources ("15m", "1m", etc.).
Every Duration `yaml:"every"`
// Kinds is optional and domain-defined.
// If set, it describes the expected emitted event kinds for this source.
Kinds []string `yaml:"kinds"`
// Params are driver-specific settings (URL, headers, station IDs, API keys, etc.).
// The driver implementation is responsible for reading/validating these.
Params map[string]any `yaml:"params"`
}
// ExpectedKinds returns normalized expected kinds from config.
func (cfg SourceConfig) ExpectedKinds() []string {
out := make([]string, 0, len(cfg.Kinds))
for _, k := range cfg.Kinds {
k = strings.TrimSpace(k)
if k == "" {
continue
}
out = append(out, k)
}
if len(out) == 0 {
return nil
}
return out
}
// SinkConfig describes one output sink adapter.
type SinkConfig struct {
Name string `yaml:"name"`
Driver string `yaml:"driver"` // "stdout", "file", "postgres", "rabbitmq", ...
Driver string `yaml:"driver"` // "stdout", "nats", "postgres", ...
Params map[string]any `yaml:"params"` // sink-specific settings
}
@@ -54,7 +102,10 @@ type RouteConfig struct {
Sink string `yaml:"sink"` // sink name
// Kinds is domain-defined. feedkit only enforces that each entry is non-empty.
// Whether a given daemon "recognizes" a kind is domain-specific validation.
//
// If Kinds is omitted or empty, the route matches ALL kinds.
// This is useful when you want explicit per-sink routing rules even when a
// particular sink should receive everything.
Kinds []string `yaml:"kinds"`
}
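Taken together, the source, sink, and route schemas above might be configured like this. The driver names and params are illustrative assumptions, not feedkit built-ins; only `stdout` is a built-in sink driver:

```yaml
sources:
  - name: openmeteo
    driver: openmeteo_observation   # hypothetical domain poll driver
    mode: poll
    every: 15m                      # required when mode: poll
    kinds: [observation]            # optional expected emitted kinds
    params:
      station_id: "KMSP"            # driver-specific knobs live in params
  - name: alerts
    driver: nws_alert_stream        # hypothetical stream driver
    mode: stream                    # "every" must be omitted for streams
    kinds: [alert]

sinks:
  - name: console
    driver: stdout

routes:
  - sink: console                   # kinds omitted => route matches ALL kinds
```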
@@ -128,12 +179,3 @@ func (d *Duration) UnmarshalYAML(value *yaml.Node) error {
// Anything else: reject.
return fmt.Errorf("duration must be a string like 15m or an integer minutes, got tag %s", value.Tag)
}

config/config_test.go Normal file

@@ -0,0 +1,48 @@
package config
import (
"reflect"
"testing"
)
func TestSourceConfigExpectedKinds(t *testing.T) {
tests := []struct {
name string
cfg SourceConfig
want []string
}{
{
name: "plural kinds normalized",
cfg: SourceConfig{
Kinds: []string{" observation ", "forecast"},
},
want: []string{"observation", "forecast"},
},
{
name: "empty kinds",
cfg: SourceConfig{},
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.cfg.ExpectedKinds()
if !reflect.DeepEqual(got, tt.want) {
t.Fatalf("ExpectedKinds() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestSourceModeNormalize(t *testing.T) {
if got := SourceMode(" Poll ").Normalize(); got != SourceModePoll {
t.Fatalf("Normalize poll = %q, want %q", got, SourceModePoll)
}
if got := SourceMode("STREAM").Normalize(); got != SourceModeStream {
t.Fatalf("Normalize stream = %q, want %q", got, SourceModeStream)
}
if got := SourceMode("").Normalize(); got != SourceModeAuto {
t.Fatalf("Normalize auto = %q, want %q", got, SourceModeAuto)
}
}


@@ -83,14 +83,34 @@ func (c *Config) Validate() error {
m.Add(fieldErr(path+".driver", "is required (e.g. openmeteo_observation, rss_feed, ...)"))
}
// Mode
mode := s.Mode.Normalize()
if s.Mode != SourceModeAuto && mode != SourceModePoll && mode != SourceModeStream {
m.Add(fieldErr(path+".mode", `must be one of: "poll", "stream" (or omit for auto)`))
}
// Every
if s.Every.Duration < 0 {
m.Add(fieldErr(path+".every", "is optional, but must be a positive duration (e.g. 15m, 1m, 30s) if provided"))
} else {
switch mode {
case SourceModePoll:
if s.Every.Duration <= 0 {
m.Add(fieldErr(path+".every", `is required when mode="poll" (e.g. 15m, 1m, 30s)`))
}
case SourceModeStream:
if s.Every.Duration > 0 {
m.Add(fieldErr(path+".every", `must be omitted when mode="stream"`))
}
}
}
// Kinds (optional)
for j, k := range s.Kinds {
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
if strings.TrimSpace(k) == "" {
m.Add(fieldErr(kpath, "kind cannot be empty"))
}
}
// Params can be nil; that's fine.
@@ -115,7 +135,7 @@ func (c *Config) Validate() error {
}
if strings.TrimSpace(s.Driver) == "" {
m.Add(fieldErr(path+".driver", "is required (stdout|nats|postgres|...)"))
}
// Params can be nil; that's fine.
@@ -133,11 +153,8 @@ func (c *Config) Validate() error {
m.Add(fieldErr(path+".sink", fmt.Sprintf("references unknown sink %q (define it under sinks:)", r.Sink)))
}
// Kinds is optional. If omitted or empty, the route matches ALL kinds.
// If provided, each entry must be non-empty.
for j, k := range r.Kinds {
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
if strings.TrimSpace(k) == "" {
@@ -145,7 +162,6 @@ func (c *Config) Validate() error {
}
}
}
return m.Err()
}


@@ -1,32 +1,21 @@
// feedkit/config/params.go
package config
import "strings"
import (
"math"
"strconv"
"strings"
"time"
)
// ---- SourceConfig param helpers ----
// ParamString returns the first non-empty string found for any of the provided keys.
// Values must actually be strings in the decoded config; other types are ignored.
//
// This keeps cfg.Params flexible (map[string]any) while letting callers stay type-safe.
func (cfg SourceConfig) ParamString(keys ...string) (string, bool) {
return paramString(cfg.Params, keys...)
}
// ParamStringDefault returns ParamString(keys...) if present; otherwise it returns def.
@@ -38,14 +27,150 @@ func (cfg SourceConfig) ParamStringDefault(def string, keys ...string) string {
return strings.TrimSpace(def)
}
// ParamBool returns the first boolean found for any of the provided keys.
//
// Accepted types in Params:
// - bool
// - string: parsed via strconv.ParseBool ("true"/"false"/"1"/"0", etc.)
func (cfg SourceConfig) ParamBool(keys ...string) (bool, bool) {
return paramBool(cfg.Params, keys...)
}
func (cfg SourceConfig) ParamBoolDefault(def bool, keys ...string) bool {
if v, ok := cfg.ParamBool(keys...); ok {
return v
}
return def
}
// ParamInt returns the first integer-like value found for any of the provided keys.
//
// Accepted types in Params:
// - any integer type (int, int64, uint32, ...)
// - float32/float64 ONLY if it is an exact integer (e.g. 15.0)
// - string: parsed via strconv.Atoi (e.g. "42")
func (cfg SourceConfig) ParamInt(keys ...string) (int, bool) {
return paramInt(cfg.Params, keys...)
}
func (cfg SourceConfig) ParamIntDefault(def int, keys ...string) int {
if v, ok := cfg.ParamInt(keys...); ok {
return v
}
return def
}
// ParamDuration returns the first duration-like value found for any of the provided keys.
//
// Accepted types in Params:
// - time.Duration
// - string: parsed via time.ParseDuration (e.g. "250ms", "30s", "5m")
// - if the string is all digits (e.g. "30"), it is interpreted as SECONDS
// - numeric: interpreted as SECONDS (e.g. 30 => 30s)
//
// Rationale: Param durations are usually timeouts/backoffs; seconds are a sane numeric default.
// If you want minutes/hours, prefer a duration string like "5m" or "1h".
func (cfg SourceConfig) ParamDuration(keys ...string) (time.Duration, bool) {
return paramDuration(cfg.Params, keys...)
}
func (cfg SourceConfig) ParamDurationDefault(def time.Duration, keys ...string) time.Duration {
if v, ok := cfg.ParamDuration(keys...); ok {
return v
}
return def
}
// ParamStringSlice returns the first string-slice-like value found for any of the provided keys.
//
// Accepted types in Params:
// - []string
// - []any where each element is a string
// - string:
// - if it contains commas, split on commas (",") and trim each item
// - otherwise treat as a single-item list
//
// Empty/blank items are removed.
func (cfg SourceConfig) ParamStringSlice(keys ...string) ([]string, bool) {
return paramStringSlice(cfg.Params, keys...)
}
// ---- SinkConfig param helpers ----
// ParamString returns the first non-empty string found for any of the provided keys
// in SinkConfig.Params. (Same rationale as SourceConfig.ParamString.)
func (cfg SinkConfig) ParamString(keys ...string) (string, bool) {
return paramString(cfg.Params, keys...)
}
// ParamStringDefault returns ParamString(keys...) if present; otherwise it returns def.
// Symmetric helper for sink implementations.
func (cfg SinkConfig) ParamStringDefault(def string, keys ...string) string {
if s, ok := cfg.ParamString(keys...); ok {
return s
}
return strings.TrimSpace(def)
}
func (cfg SinkConfig) ParamBool(keys ...string) (bool, bool) {
return paramBool(cfg.Params, keys...)
}
func (cfg SinkConfig) ParamBoolDefault(def bool, keys ...string) bool {
if v, ok := cfg.ParamBool(keys...); ok {
return v
}
return def
}
func (cfg SinkConfig) ParamInt(keys ...string) (int, bool) {
return paramInt(cfg.Params, keys...)
}
func (cfg SinkConfig) ParamIntDefault(def int, keys ...string) int {
if v, ok := cfg.ParamInt(keys...); ok {
return v
}
return def
}
func (cfg SinkConfig) ParamDuration(keys ...string) (time.Duration, bool) {
return paramDuration(cfg.Params, keys...)
}
func (cfg SinkConfig) ParamDurationDefault(def time.Duration, keys ...string) time.Duration {
if v, ok := cfg.ParamDuration(keys...); ok {
return v
}
return def
}
func (cfg SinkConfig) ParamStringSlice(keys ...string) ([]string, bool) {
return paramStringSlice(cfg.Params, keys...)
}
// ---- shared implementations (package-private) ----
func paramAny(params map[string]any, keys ...string) (any, bool) {
if params == nil {
return nil, false
}
for _, k := range keys {
v, ok := params[k]
if !ok || v == nil {
continue
}
return v, true
}
return nil, false
}
func paramString(params map[string]any, keys ...string) (string, bool) {
if params == nil {
return "", false
}
for _, k := range keys {
v, ok := params[k]
if !ok || v == nil {
continue
}
@@ -62,11 +187,213 @@ func (cfg SinkConfig) ParamString(keys ...string) (string, bool) {
return "", false
}
func paramBool(params map[string]any, keys ...string) (bool, bool) {
v, ok := paramAny(params, keys...)
if !ok {
return false, false
}
switch t := v.(type) {
case bool:
return t, true
case string:
s := strings.TrimSpace(t)
if s == "" {
return false, false
}
parsed, err := strconv.ParseBool(s)
if err != nil {
return false, false
}
return parsed, true
default:
return false, false
}
}
func paramInt(params map[string]any, keys ...string) (int, bool) {
v, ok := paramAny(params, keys...)
if !ok {
return 0, false
}
switch t := v.(type) {
case int:
return t, true
case int8:
return int(t), true
case int16:
return int(t), true
case int32:
return int(t), true
case int64:
return int(t), true
case uint:
return int(t), true
case uint8:
return int(t), true
case uint16:
return int(t), true
case uint32:
return int(t), true
case uint64:
return int(t), true
case float32:
f := float64(t)
if math.IsNaN(f) || math.IsInf(f, 0) {
return 0, false
}
if math.Trunc(f) != f {
return 0, false
}
return int(f), true
case float64:
if math.IsNaN(t) || math.IsInf(t, 0) {
return 0, false
}
if math.Trunc(t) != t {
return 0, false
}
return int(t), true
case string:
s := strings.TrimSpace(t)
if s == "" {
return 0, false
}
n, err := strconv.Atoi(s)
if err != nil {
return 0, false
}
return n, true
default:
return 0, false
}
}
func paramDuration(params map[string]any, keys ...string) (time.Duration, bool) {
v, ok := paramAny(params, keys...)
if !ok {
return 0, false
}
switch t := v.(type) {
case time.Duration:
if t <= 0 {
return 0, false
}
return t, true
case string:
s := strings.TrimSpace(t)
if s == "" {
return 0, false
}
// Numeric strings are interpreted as seconds (see doc comment).
if isAllDigits(s) {
n, err := strconv.Atoi(s)
if err != nil || n <= 0 {
return 0, false
}
return time.Duration(n) * time.Second, true
}
d, err := time.ParseDuration(s)
if err != nil || d <= 0 {
return 0, false
}
return d, true
case int:
if t <= 0 {
return 0, false
}
return time.Duration(t) * time.Second, true
case int64:
if t <= 0 {
return 0, false
}
return time.Duration(t) * time.Second, true
case float64:
if math.IsNaN(t) || math.IsInf(t, 0) || t <= 0 {
return 0, false
}
// Allow fractional seconds.
secs := t * float64(time.Second)
return time.Duration(secs), true
case float32:
f := float64(t)
if math.IsNaN(f) || math.IsInf(f, 0) || f <= 0 {
return 0, false
}
secs := f * float64(time.Second)
return time.Duration(secs), true
default:
return 0, false
}
}
func paramStringSlice(params map[string]any, keys ...string) ([]string, bool) {
v, ok := paramAny(params, keys...)
if !ok {
return nil, false
}
clean := func(items []string) ([]string, bool) {
out := make([]string, 0, len(items))
for _, it := range items {
it = strings.TrimSpace(it)
if it == "" {
continue
}
out = append(out, it)
}
if len(out) == 0 {
return nil, false
}
return out, true
}
switch t := v.(type) {
case []string:
return clean(t)
case []any:
tmp := make([]string, 0, len(t))
for _, it := range t {
s, ok := it.(string)
if !ok {
continue
}
tmp = append(tmp, s)
}
return clean(tmp)
case string:
s := strings.TrimSpace(t)
if s == "" {
return nil, false
}
if strings.Contains(s, ",") {
parts := strings.Split(s, ",")
return clean(parts)
}
return clean([]string{s})
default:
return nil, false
}
}
func isAllDigits(s string) bool {
for _, r := range s {
if r < '0' || r > '9' {
return false
}
}
return len(s) > 0
}

config/validate_test.go Normal file

@@ -0,0 +1,139 @@
package config
import (
"strings"
"testing"
"time"
)
func TestValidate_RouteKindsEmptyIsAllowed(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{Name: "src1", Driver: "driver1", Every: Duration{Duration: time.Minute}},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
Routes: []RouteConfig{
{Sink: "sink1", Kinds: nil}, // omitted
{Sink: "sink1", Kinds: []string{}}, // explicit empty
},
}
if err := cfg.Validate(); err != nil {
t.Fatalf("expected no error, got: %v", err)
}
}
func TestValidate_RouteKindsRejectsBlankEntries(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{Name: "src1", Driver: "driver1", Every: Duration{Duration: time.Minute}},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
Routes: []RouteConfig{
{Sink: "sink1", Kinds: []string{"observation", " ", "alert"}},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), "routes[0].kinds[1]") {
t.Fatalf("expected error to mention blank kind entry, got: %v", err)
}
}
func TestValidate_SourceModePollRequiresEvery(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{Name: "src1", Driver: "driver1", Mode: SourceModePoll},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].every`) {
t.Fatalf("expected error to mention sources[0].every, got: %v", err)
}
}
func TestValidate_SourceModeStreamRejectsEvery(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{
Name: "src1",
Driver: "driver1",
Mode: SourceModeStream,
Every: Duration{Duration: time.Minute},
},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].every`) {
t.Fatalf("expected error to mention sources[0].every, got: %v", err)
}
}
func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{
Name: "src1",
Driver: "driver1",
Mode: SourceMode("batch"),
Every: Duration{Duration: time.Minute},
},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].mode`) {
t.Fatalf("expected error to mention sources[0].mode, got: %v", err)
}
}
func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) {
cfg := &Config{
Sources: []SourceConfig{
{
Name: "src1",
Driver: "driver1",
Every: Duration{Duration: time.Minute},
Kinds: []string{"observation", " "},
},
},
Sinks: []SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatalf("expected error, got nil")
}
if !strings.Contains(err.Error(), `sources[0].kinds[1]`) {
t.Fatalf("expected error to mention sources[0].kinds[1], got: %v", err)
}
}


@@ -6,10 +6,16 @@ import (
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/logging"
"gitea.maximumdirect.net/ejr/feedkit/pipeline"
"gitea.maximumdirect.net/ejr/feedkit/sinks"
)
// Logger is a printf-style logger used throughout dispatch.
// It is an alias to the shared feedkit logging type so callers can pass
// one function everywhere without type mismatch friction.
type Logger = logging.Logf
type Dispatcher struct {
In <-chan event.Event
@@ -35,8 +41,6 @@ type Route struct {
Kinds map[event.Kind]bool
}
func (d *Dispatcher) Run(ctx context.Context, logf Logger) error {
if d.In == nil {
return fmt.Errorf("dispatcher.Run: In channel is nil")

dispatch/routes.go Normal file

@@ -0,0 +1,89 @@
package dispatch
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// CompileRoutes converts config.Config routes into dispatch.Route rules.
//
// Behavior:
// - If cfg.Routes is empty, we default to "all sinks receive all kinds".
// (Implemented as one Route per sink with Kinds == nil.)
// - If a specific route's kinds: is omitted or empty, that route matches ALL kinds.
// (Also compiled as Kinds == nil.)
// - Kind strings are normalized via event.ParseKind (lowercase + trim).
//
// Note: config.Validate() ensures route.sink references a known sink and rejects
// blank kind entries. We re-check a few invariants here anyway so CompileRoutes
// is safe to call even if a daemon chooses not to call Validate().
func CompileRoutes(cfg *config.Config) ([]Route, error) {
if cfg == nil {
return nil, fmt.Errorf("dispatch.CompileRoutes: cfg is nil")
}
if len(cfg.Sinks) == 0 {
return nil, fmt.Errorf("dispatch.CompileRoutes: cfg has no sinks")
}
// Build a quick lookup of sink names (exact match; no normalization).
sinkNames := make(map[string]bool, len(cfg.Sinks))
for i, s := range cfg.Sinks {
if strings.TrimSpace(s.Name) == "" {
return nil, fmt.Errorf("dispatch.CompileRoutes: sinks[%d].name is empty", i)
}
sinkNames[s.Name] = true
}
// Default routing: everything to every sink.
if len(cfg.Routes) == 0 {
out := make([]Route, 0, len(cfg.Sinks))
for _, s := range cfg.Sinks {
out = append(out, Route{
SinkName: s.Name,
Kinds: nil, // nil/empty map means "all kinds"
})
}
return out, nil
}
out := make([]Route, 0, len(cfg.Routes))
for i, r := range cfg.Routes {
sink := r.Sink
if strings.TrimSpace(sink) == "" {
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].sink is required", i)
}
if !sinkNames[sink] {
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].sink references unknown sink %q", i, sink)
}
// If kinds is omitted/empty, this route matches all kinds.
if len(r.Kinds) == 0 {
out = append(out, Route{
SinkName: sink,
Kinds: nil,
})
continue
}
kinds := make(map[event.Kind]bool, len(r.Kinds))
for j, raw := range r.Kinds {
k, err := event.ParseKind(raw)
if err != nil {
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].kinds[%d]: %w", i, j, err)
}
kinds[k] = true
}
out = append(out, Route{
SinkName: sink,
Kinds: kinds,
})
}
return out, nil
}

dispatch/routes_test.go Normal file

@@ -0,0 +1,67 @@
package dispatch
import (
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
func TestCompileRoutes_DefaultIsAllSinksAllKinds(t *testing.T) {
cfg := &config.Config{
Sinks: []config.SinkConfig{
{Name: "a", Driver: "stdout"},
{Name: "b", Driver: "stdout"},
},
// Routes omitted => default
}
routes, err := CompileRoutes(cfg)
if err != nil {
t.Fatalf("CompileRoutes error: %v", err)
}
if len(routes) != 2 {
t.Fatalf("expected 2 routes, got %d", len(routes))
}
// Order should match cfg.Sinks order (deterministic).
if routes[0].SinkName != "a" || routes[1].SinkName != "b" {
t.Fatalf("unexpected route order: %+v", routes)
}
for _, r := range routes {
if len(r.Kinds) != 0 {
t.Fatalf("expected nil/empty kinds for default routes, got: %+v", r.Kinds)
}
}
}
func TestCompileRoutes_EmptyKindsMeansAllKinds(t *testing.T) {
cfg := &config.Config{
Sinks: []config.SinkConfig{
{Name: "sink1", Driver: "stdout"},
},
Routes: []config.RouteConfig{
{Sink: "sink1"}, // omitted kinds
{Sink: "sink1", Kinds: nil}, // explicit nil
{Sink: "sink1", Kinds: []string{}}, // explicit empty
},
}
routes, err := CompileRoutes(cfg)
if err != nil {
t.Fatalf("CompileRoutes error: %v", err)
}
if len(routes) != 3 {
t.Fatalf("expected 3 routes, got %d", len(routes))
}
for i, r := range routes {
if r.SinkName != "sink1" {
t.Fatalf("route[%d] unexpected sink: %q", i, r.SinkName)
}
if len(r.Kinds) != 0 {
t.Fatalf("route[%d] expected nil/empty kinds (match all), got: %+v", i, r.Kinds)
}
}
}

doc.go Normal file

@@ -0,0 +1,57 @@
// Package feedkit provides a high-level map of the feedkit package set.
//
// Most real applications do not import the root package directly. Instead, they
// compose the subpackages that handle configuration, collection, processing,
// routing, and sinks.
//
// The usual flow through feedkit is:
//
// Collect -> Process -> Route -> Emit
//
// That flow maps to packages like this:
//
// - config
// Loads and validates daemon config. This package owns domain-agnostic
// config shape and consistency checks.
//
// - event
// Defines the event.Event envelope shared across sources, processors,
// dispatch, and sinks.
//
// - sources
// Defines polling and streaming source interfaces, the source registry, and
// reusable source helpers.
//
// - scheduler
// Runs configured sources on a cadence and supervises long-lived stream
// workers with restart/fatal handling.
//
// - processors
// Defines the generic processor interface and registry used to build
// ordered processor chains.
//
// - processors/dedupe
// Built-in in-memory dedupe processor keyed by Event.ID.
//
// - processors/normalize
// Built-in normalization processor plus helper APIs for raw-to-canonical
// event mapping.
//
// - pipeline
// Applies an ordered processor chain between collection and dispatch.
//
// - dispatch
// Compiles routes and fans events out to sinks with per-sink isolation.
//
// - sinks
// Defines sink interfaces, the sink registry, schema-free built-in sinks,
// and explicit Postgres factory helpers.
//
// feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds,
// upstream-specific parsing, and daemon lifecycle remain the responsibility of
// each concrete application.
//
// For repository-level overview and usage narrative, see README.md. For
// code-level details, each subpackage doc.go is the source of truth for that
// package's public API surface and optional helpers.
package feedkit

go.mod

@@ -2,4 +2,16 @@ module gitea.maximumdirect.net/ejr/feedkit
go 1.22
require gopkg.in/yaml.v3 v3.0.1
require (
github.com/lib/pq v1.10.9
github.com/nats-io/nats.go v1.34.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/klauspost/compress v1.17.2 // indirect
github.com/nats-io/nkeys v0.4.7 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
golang.org/x/crypto v0.18.0 // indirect
golang.org/x/sys v0.16.0 // indirect
)

go.sum

@@ -1,3 +1,18 @@
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/nats-io/nats.go v1.34.0 h1:fnxnPCNiwIG5w08rlMcEKTUw4AV/nKyGCOJE8TdhSPk=
github.com/nats-io/nats.go v1.34.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -0,0 +1,67 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"time"
_ "github.com/lib/pq"
)
const initTimeout = 5 * time.Second
var (
sqlOpen = sql.Open
pingDB = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) }
)
// ConnConfig describes the minimal connection settings shared by feedkit's
// Postgres readers and writers.
type ConnConfig struct {
URI string
Username string
Password string
}
// BuildDSN validates a Postgres URI and injects credentials into it.
func BuildDSN(cfg ConnConfig) (string, error) {
u, err := url.Parse(strings.TrimSpace(cfg.URI))
if err != nil {
return "", fmt.Errorf("invalid uri: %w", err)
}
if u.Scheme == "" {
return "", fmt.Errorf("invalid uri: missing scheme")
}
if u.Host == "" {
return "", fmt.Errorf("invalid uri: missing host")
}
u.User = url.UserPassword(cfg.Username, cfg.Password)
return u.String(), nil
}
// Open builds a DSN, opens a database handle, and verifies connectivity with a
// bounded ping before returning the handle.
func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) {
dsn, err := BuildDSN(cfg)
if err != nil {
return nil, err
}
db, err := sqlOpen("postgres", dsn)
if err != nil {
return nil, err
}
pingCtx, cancel := context.WithTimeout(ctx, initTimeout)
defer cancel()
if err := pingDB(pingCtx, db); err != nil {
_ = db.Close()
return nil, err
}
return db, nil
}


@@ -0,0 +1,165 @@
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"net/url"
"strings"
"sync"
"testing"
)
func withPostgresPackageTestState(t *testing.T) {
t.Helper()
oldSQLOpen := sqlOpen
oldPingDB := pingDB
t.Cleanup(func() {
sqlOpen = oldSQLOpen
pingDB = oldPingDB
})
}
func TestBuildDSNInjectsCredentials(t *testing.T) {
dsn, err := BuildDSN(ConnConfig{
URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ",
Username: "app_user",
Password: "app_pass",
})
if err != nil {
t.Fatalf("BuildDSN() error = %v", err)
}
u, err := url.Parse(dsn)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
if u.User == nil || u.User.Username() != "app_user" {
t.Fatalf("username = %q, want app_user", u.User.Username())
}
pass, ok := u.User.Password()
if !ok || pass != "app_pass" {
t.Fatalf("password = %q, want app_pass", pass)
}
}
func TestBuildDSNRejectsInvalidURI(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "invalid uri") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingScheme(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing scheme") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestBuildDSNRejectsMissingHost(t *testing.T) {
_, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"})
if err == nil {
t.Fatalf("BuildDSN() error = nil, want error")
}
if !strings.Contains(err.Error(), "missing host") {
t.Fatalf("BuildDSN() error = %q", err)
}
}
func TestOpenPropagatesOpenFailure(t *testing.T) {
withPostgresPackageTestState(t)
sqlOpen = func(_, _ string) (*sql.DB, error) {
return nil, errors.New("open failed")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "open failed") {
t.Fatalf("Open() error = %q", err)
}
}
func TestOpenPropagatesPingFailure(t *testing.T) {
withPostgresPackageTestState(t)
const driverName = "feedkit_internal_postgres_ping_fail"
registerPingTestDriver(driverName, errors.New("ping failed"))
sqlOpen = func(_, _ string) (*sql.DB, error) {
return sql.Open(driverName, "")
}
_, err := Open(context.Background(), ConnConfig{
URI: "postgres://db.example.local/feedkit",
Username: "u",
Password: "p",
})
if err == nil {
t.Fatalf("Open() error = nil, want error")
}
if !strings.Contains(err.Error(), "ping failed") {
t.Fatalf("Open() error = %q", err)
}
}
var (
pingDriverMu sync.Mutex
pingDriverSeen = map[string]bool{}
)
func registerPingTestDriver(name string, pingErr error) {
pingDriverMu.Lock()
defer pingDriverMu.Unlock()
if pingDriverSeen[name] {
return
}
sql.Register(name, &pingTestDriver{pingErr: pingErr})
pingDriverSeen[name] = true
}
type pingTestDriver struct {
pingErr error
}
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
return &pingTestConn{pingErr: d.pingErr}, nil
}
type pingTestConn struct {
pingErr error
}
func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *pingTestConn) Close() error { return nil }
func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }
func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
return &pingTestRows{}, nil
}
type pingTestRows struct{}
func (r *pingTestRows) Columns() []string { return []string{"ok"} }
func (r *pingTestRows) Close() error { return nil }
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }

logging/logging.go Normal file

@@ -0,0 +1,8 @@
package logging
// Logf is the shared printf-style logger signature used across feedkit.
//
// Keeping this in one place avoids the "scheduler.Logger vs dispatch.Logger"
// friction and makes it trivial for downstream apps to pass a single log
// function throughout the system.
type Logf func(format string, args ...any)


@@ -1,5 +0,0 @@
package pipeline
// Placeholder for dedupe processor:
// - key by Event.ID or computed key
// - in-memory store first; later optional Postgres-backed


@@ -5,15 +5,11 @@ import (
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
// Processor can mutate/drop events (dedupe, rate-limit, normalization tweaks).
type Processor interface {
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
}
type Pipeline struct {
Processors []Processor
Processors []processors.Processor
}
func (p *Pipeline) Process(ctx context.Context, e event.Event) (*event.Event, error) {

pipeline/pipeline_test.go Normal file

@@ -0,0 +1,115 @@
package pipeline
import (
"context"
"fmt"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
type procFunc func(context.Context, event.Event) (*event.Event, error)
func (f procFunc) Process(ctx context.Context, in event.Event) (*event.Event, error) {
return f(ctx, in)
}
func TestPipelineProcessSequentialOrder(t *testing.T) {
var gotOrder []string
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
gotOrder = append(gotOrder, "first")
out := in
out.Schema = "stage.one.v1"
return &out, nil
}),
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
gotOrder = append(gotOrder, "second")
if in.Schema != "stage.one.v1" {
return nil, fmt.Errorf("expected schema from first stage, got %q", in.Schema)
}
out := in
out.Schema = "stage.two.v1"
return &out, nil
}),
},
}
out, err := p.Process(context.Background(), validEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected output event, got nil")
}
if out.Schema != "stage.two.v1" {
t.Fatalf("unexpected output schema: %q", out.Schema)
}
if strings.Join(gotOrder, ",") != "first,second" {
t.Fatalf("unexpected processor order: %v", gotOrder)
}
}
func TestPipelineProcessInvalidInput(t *testing.T) {
p := &Pipeline{}
_, err := p.Process(context.Background(), event.Event{})
if err == nil {
t.Fatalf("expected input validation error")
}
if !strings.Contains(err.Error(), "invalid input event") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestPipelineProcessDrop(t *testing.T) {
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(context.Context, event.Event) (*event.Event, error) {
return nil, nil
}),
},
}
out, err := p.Process(context.Background(), validEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out != nil {
t.Fatalf("expected nil output for dropped event, got %#v", out)
}
}
func TestPipelineProcessInvalidOutput(t *testing.T) {
p := &Pipeline{
Processors: []processors.Processor{
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
out := in
out.Payload = nil
return &out, nil
}),
},
}
_, err := p.Process(context.Background(), validEvent())
if err == nil {
t.Fatalf("expected output validation error")
}
if !strings.Contains(err.Error(), "invalid output event") {
t.Fatalf("unexpected error: %v", err)
}
}
func validEvent() event.Event {
return event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"ok": true},
}
}


@@ -1,5 +0,0 @@
package pipeline
// Placeholder for rate limit processor:
// - per source/kind sink routing limits
// - cooldown windows

processors/dedupe/doc.go Normal file

@@ -0,0 +1,28 @@
// Package dedupe provides a default in-memory LRU deduplication processor.
//
// The processor keys strictly by event.Event.ID:
// - first-seen IDs pass through
// - repeated IDs are dropped
//
// The in-memory seen-ID set is bounded by a required maxEntries capacity.
// When capacity is exceeded, the least recently used ID is evicted.
//
// Typical registry wiring:
//
// reg := processors.NewRegistry()
// reg.Register("dedupe", dedupe.Factory(10_000))

// reg.Register("normalize", func() (processors.Processor, error) {
// return normalize.NewProcessor(myNormalizers, false), nil
// })

// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})

// if err != nil {
// // handle wiring error
// }

// p := &pipeline.Pipeline{Processors: chain}
package dedupe


@@ -0,0 +1,89 @@
package dedupe
import (
"container/list"
"context"
"fmt"
"strings"
"sync"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
// Processor drops duplicate events by Event.ID using an in-memory LRU.
type Processor struct {
maxEntries int
mu sync.Mutex
order *list.List // most-recent at front, least-recent at back
byID map[string]*list.Element // id -> list element (element.Value is string id)
}
var _ processors.Processor = (*Processor)(nil)
// NewProcessor constructs a dedupe processor with a required max entry count.
func NewProcessor(maxEntries int) (*Processor, error) {
if maxEntries <= 0 {
return nil, fmt.Errorf("dedupe: maxEntries must be > 0, got %d", maxEntries)
}
return &Processor{
maxEntries: maxEntries,
order: list.New(),
byID: make(map[string]*list.Element, maxEntries),
}, nil
}
// Factory returns a processors.Factory that constructs Processor instances.
func Factory(maxEntries int) processors.Factory {
return func() (processors.Processor, error) {
return NewProcessor(maxEntries)
}
}
// Process implements processors.Processor.
func (p *Processor) Process(_ context.Context, in event.Event) (*event.Event, error) {
if p == nil {
return nil, fmt.Errorf("dedupe: processor is nil")
}
if p.maxEntries <= 0 {
return nil, fmt.Errorf("dedupe: processor maxEntries must be > 0")
}
id := strings.TrimSpace(in.ID)
if id == "" {
return nil, fmt.Errorf("dedupe: event ID is required")
}
p.mu.Lock()
if p.order == nil || p.byID == nil {
p.mu.Unlock()
return nil, fmt.Errorf("dedupe: processor is not initialized")
}
if elem, exists := p.byID[id]; exists {
p.order.MoveToFront(elem)
p.mu.Unlock()
return nil, nil
}
elem := p.order.PushFront(id)
p.byID[id] = elem
if p.order.Len() > p.maxEntries {
oldest := p.order.Back()
if oldest != nil {
p.order.Remove(oldest)
if oldestID, ok := oldest.Value.(string); ok {
delete(p.byID, oldestID)
}
}
}
p.mu.Unlock()
out := in
return &out, nil
}


@@ -0,0 +1,163 @@
package dedupe
import (
"context"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/processors"
)
func TestNewProcessorValidation(t *testing.T) {
t.Run("rejects non-positive maxEntries", func(t *testing.T) {
for _, maxEntries := range []int{0, -1} {
p, err := NewProcessor(maxEntries)
if err == nil {
t.Fatalf("expected error for maxEntries=%d, got nil", maxEntries)
}
if p != nil {
t.Fatalf("expected nil processor for maxEntries=%d", maxEntries)
}
if !strings.Contains(err.Error(), "maxEntries") {
t.Fatalf("unexpected error: %v", err)
}
}
})
t.Run("accepts positive maxEntries", func(t *testing.T) {
p, err := NewProcessor(1)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
if p == nil {
t.Fatalf("expected processor, got nil")
}
})
}
func TestProcessorFirstSeenAndDuplicate(t *testing.T) {
p, err := NewProcessor(8)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
ctx := context.Background()
first := testEvent("evt-1")
out, err := p.Process(ctx, first)
if err != nil {
t.Fatalf("Process first error: %v", err)
}
if out == nil {
t.Fatalf("expected first event to pass through")
}
if out.ID != first.ID {
t.Fatalf("expected unchanged ID %q, got %q", first.ID, out.ID)
}
out, err = p.Process(ctx, first)
if err != nil {
t.Fatalf("Process duplicate error: %v", err)
}
if out != nil {
t.Fatalf("expected duplicate to be dropped, got %#v", out)
}
out, err = p.Process(ctx, testEvent("evt-2"))
if err != nil {
t.Fatalf("Process second unique error: %v", err)
}
if out == nil {
t.Fatalf("expected second unique event to pass through")
}
}
func TestProcessorLRUEvictionAndPromotion(t *testing.T) {
p, err := NewProcessor(2)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
ctx := context.Background()
mustPass(t, p, ctx, "a")
mustPass(t, p, ctx, "b")
mustDrop(t, p, ctx, "a") // promote "a" so "b" becomes least-recently-used
mustPass(t, p, ctx, "c") // evicts "b"
mustDrop(t, p, ctx, "a") // "a" should still be tracked after promotion
mustPass(t, p, ctx, "b") // "b" was evicted, so now it passes again
}
func TestProcessorRejectsBlankID(t *testing.T) {
p, err := NewProcessor(4)
if err != nil {
t.Fatalf("NewProcessor error: %v", err)
}
in := testEvent(" ")
out, err := p.Process(context.Background(), in)
if err == nil {
t.Fatalf("expected error for blank ID")
}
if out != nil {
t.Fatalf("expected nil output on error, got %#v", out)
}
if !strings.Contains(err.Error(), "event ID is required") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestFactoryWithRegistry(t *testing.T) {
r := processors.NewRegistry()
r.Register("dedupe", Factory(3))
p, err := r.Build("dedupe")
if err != nil {
t.Fatalf("Build error: %v", err)
}
if p == nil {
t.Fatalf("expected processor, got nil")
}
out, err := p.Process(context.Background(), testEvent("evt-factory-1"))
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected first event to pass through")
}
}
func mustPass(t *testing.T, p *Processor, ctx context.Context, id string) {
t.Helper()
out, err := p.Process(ctx, testEvent(id))
if err != nil {
t.Fatalf("expected pass for id=%q, got error: %v", id, err)
}
if out == nil {
t.Fatalf("expected pass for id=%q, got drop", id)
}
}
func mustDrop(t *testing.T, p *Processor, ctx context.Context, id string) {
t.Helper()
out, err := p.Process(ctx, testEvent(id))
if err != nil {
t.Fatalf("expected drop for id=%q, got error: %v", id, err)
}
if out != nil {
t.Fatalf("expected drop for id=%q, got output", id)
}
}
func testEvent(id string) event.Event {
return event.Event{
ID: id,
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"ok": true},
}
}

processors/doc.go Normal file

@@ -0,0 +1,24 @@
// Package processors defines feedkit's generic processor abstraction and registry.
//
// Processors are optional pipeline stages that can transform, drop, or reject
// events before dispatch to sinks.
//
// Registry provides name-based construction so daemons can assemble processor
// chains without embedding switch statements in wiring code.
//
// Example:
//
// reg := processors.NewRegistry()
// reg.Register("dedupe", dedupe.Factory(10_000))
// reg.Register("normalize", func() (processors.Processor, error) {
// // import "gitea.maximumdirect.net/ejr/feedkit/processors/normalize"
// return normalize.NewProcessor(myNormalizers, false), nil
// })
//
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
// if err != nil {
// // handle wiring error
// }
//
// p := &pipeline.Pipeline{Processors: chain}
package processors


@@ -0,0 +1,19 @@
// Package normalize provides the feedkit normalization processor and related
// helper APIs for raw-to-canonical event mapping.
//
// External API surface:
// - Processor: concrete processors.Processor implementation
// - Normalizer / Func: normalization interface and ergonomic function adapter
//
// Optional helpers from helpers.go:
// - PayloadJSONBytes: extract supported JSON-shaped payloads into bytes
// - DecodeJSONPayload: decode an event payload into a typed struct
// - FinalizeEvent: copy the input event envelope onto a normalized output
//
// Typical usage:
// sources emit raw events (often with json.RawMessage payloads), then a
// normalize.Processor converts matching raw schemas into canonical payloads.
//
// Key property: normalization is optional.
// If no Normalizer matches an event, Processor passes it through unchanged by default.
package normalize


@@ -0,0 +1,84 @@
package normalize
import (
"encoding/json"
"fmt"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// PayloadJSONBytes extracts a JSON payload into bytes suitable for json.Unmarshal.
//
// Supported payload shapes:
// - json.RawMessage
// - []byte
// - string
// - map[string]any
func PayloadJSONBytes(e event.Event) ([]byte, error) {
if e.Payload == nil {
return nil, fmt.Errorf("payload is nil")
}
switch v := e.Payload.(type) {
case json.RawMessage:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty json.RawMessage")
}
return []byte(v), nil
case []byte:
if len(v) == 0 {
return nil, fmt.Errorf("payload is empty []byte")
}
return v, nil
case string:
if v == "" {
return nil, fmt.Errorf("payload is empty string")
}
return []byte(v), nil
case map[string]any:
b, err := json.Marshal(v)
if err != nil {
return nil, fmt.Errorf("marshal map payload: %w", err)
}
return b, nil
default:
return nil, fmt.Errorf("unsupported payload type %T", e.Payload)
}
}
// DecodeJSONPayload extracts the event payload as bytes and unmarshals it into T.
func DecodeJSONPayload[T any](in event.Event) (T, error) {
var zero T
b, err := PayloadJSONBytes(in)
if err != nil {
return zero, fmt.Errorf("extract payload: %w", err)
}
var parsed T
if err := json.Unmarshal(b, &parsed); err != nil {
return zero, fmt.Errorf("decode raw payload: %w", err)
}
return parsed, nil
}
// FinalizeEvent builds the output event envelope by copying the input and applying
// the new schema/payload, plus optional EffectiveAt.
func FinalizeEvent(in event.Event, outSchema string, outPayload any, effectiveAt time.Time) (*event.Event, error) {
out := in
out.Schema = outSchema
out.Payload = outPayload
if !effectiveAt.IsZero() {
t := effectiveAt.UTC()
out.EffectiveAt = &t
}
if err := out.Validate(); err != nil {
return nil, err
}
return &out, nil
}


@@ -0,0 +1,118 @@
package normalize
import (
"encoding/json"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPayloadJSONBytesSupportedShapes(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "rawmessage", payload: json.RawMessage(`{"a":1}`), want: `{"a":1}`},
{name: "bytes", payload: []byte(`{"a":2}`), want: `{"a":2}`},
{name: "string", payload: `{"a":3}`, want: `{"a":3}`},
{name: "map", payload: map[string]any{"a": 4}, want: `{"a":4}`},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err != nil {
t.Fatalf("PayloadJSONBytes() unexpected error: %v", err)
}
if string(got) != tc.want {
t.Fatalf("PayloadJSONBytes() = %s, want %s", string(got), tc.want)
}
})
}
}
func TestPayloadJSONBytesRejectsInvalidPayloads(t *testing.T) {
cases := []struct {
name string
payload any
want string
}{
{name: "nil", payload: nil, want: "payload is nil"},
{name: "empty rawmessage", payload: json.RawMessage{}, want: "payload is empty json.RawMessage"},
{name: "empty bytes", payload: []byte{}, want: "payload is empty []byte"},
{name: "empty string", payload: "", want: "payload is empty string"},
{name: "unsupported", payload: 123, want: "unsupported payload type"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
if err == nil {
t.Fatalf("PayloadJSONBytes() expected error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("PayloadJSONBytes() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestDecodeJSONPayload(t *testing.T) {
type payload struct {
Name string `json:"name"`
}
got, err := DecodeJSONPayload[payload](event.Event{
Payload: json.RawMessage(`{"name":"alice"}`),
})
if err != nil {
t.Fatalf("DecodeJSONPayload() unexpected error: %v", err)
}
if got.Name != "alice" {
t.Fatalf("DecodeJSONPayload() = %#v, want name alice", got)
}
}
func TestFinalizeEventPreservesEnvelopeAndEffectiveAtBehavior(t *testing.T) {
existingEffectiveAt := time.Date(2026, 3, 28, 11, 0, 0, 0, time.UTC)
in := event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-a",
EmittedAt: time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC),
EffectiveAt: &existingEffectiveAt,
Schema: "raw.example.v1",
Payload: map[string]any{"old": true},
}
out, err := FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, time.Time{})
if err != nil {
t.Fatalf("FinalizeEvent() unexpected error: %v", err)
}
if out.ID != in.ID || out.Kind != in.Kind || out.Source != in.Source || out.EmittedAt != in.EmittedAt {
t.Fatalf("FinalizeEvent() changed preserved envelope fields: %#v", out)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(existingEffectiveAt) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want preserved existing value", out.EffectiveAt)
}
nextEffectiveAt := time.Date(2026, 3, 28, 13, 0, 0, 0, time.FixedZone("x", -4*3600))
out, err = FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, nextEffectiveAt)
if err != nil {
t.Fatalf("FinalizeEvent() unexpected overwrite error: %v", err)
}
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(nextEffectiveAt.UTC()) {
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want %s", out.EffectiveAt, nextEffectiveAt.UTC())
}
payloadMap, ok := out.Payload.(map[string]any)
if !ok {
t.Fatalf("FinalizeEvent() payload type = %T, want map[string]any", out.Payload)
}
if payloadMap["value"] != 1.234567 {
t.Fatalf("FinalizeEvent() payload value = %#v, want unrounded 1.234567", payloadMap["value"])
}
}


@@ -0,0 +1,76 @@
package normalize
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Normalizer converts one event shape into another.
//
// A Normalizer is typically domain-owned code (weatherfeeder/newsfeeder/...)
// that knows how to interpret a specific upstream payload and produce a
// normalized payload.
//
// Normalizers are selected via Match(). The matching strategy is intentionally
// flexible: implementations may match on Schema, Kind, Source, or any other
// Event fields.
type Normalizer interface {
// Match reports whether this normalizer applies to the given event.
//
// Common patterns:
// - match on e.Schema (recommended for versioning)
// - match on e.Source (useful if Schema is empty)
// - match on (e.Kind + e.Source), etc.
Match(e event.Event) bool
// Normalize transforms the incoming event into a new (or modified) event.
//
// Return values:
// - (out, nil) where out != nil: emit the normalized event
// - (nil, nil): drop the event (treat as policy drop)
// - (nil, err): fail the pipeline
//
// Note: If you simply want to pass the event through unchanged, return &in.
Normalize(ctx context.Context, in event.Event) (*event.Event, error)
}
// Func is an ergonomic adapter that lets you define a Normalizer with functions.
//
// Example:
//
// n := normalize.Func{
// MatchFn: func(e event.Event) bool { return e.Schema == "raw.openweather.current.v1" },
// NormalizeFn: func(ctx context.Context, in event.Event) (*event.Event, error) {
// // ... map in.Payload -> normalized payload ...
// },
// }
type Func struct {
MatchFn func(e event.Event) bool
NormalizeFn func(ctx context.Context, in event.Event) (*event.Event, error)
// Optional: helps produce nicer panic/error messages if something goes wrong.
Name string
}
func (f Func) Match(e event.Event) bool {
if f.MatchFn == nil {
return false
}
return f.MatchFn(e)
}
func (f Func) Normalize(ctx context.Context, in event.Event) (*event.Event, error) {
if f.NormalizeFn == nil {
return nil, fmt.Errorf("normalize.Func(%s): NormalizeFn is nil", f.safeName())
}
return f.NormalizeFn(ctx, in)
}
func (f Func) safeName() string {
if f.Name == "" {
return "<unnamed>"
}
return f.Name
}


@@ -0,0 +1,57 @@
package normalize
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Processor applies ordered normalization rules to pipeline events.
//
// Selection rule:
// - iterate in Normalizers order
// - the first Normalizer whose Match returns true is applied
//
// If no normalizer matches, the default behavior is pass-through.
type Processor struct {
Normalizers []Normalizer
// If true, events that do not match any normalizer cause an error.
// Default is false (pass-through).
RequireMatch bool
}
// NewProcessor constructs a normalization processor from an ordered normalizer list.
func NewProcessor(normalizers []Normalizer, requireMatch bool) Processor {
return Processor{
Normalizers: append([]Normalizer(nil), normalizers...),
RequireMatch: requireMatch,
}
}
// Process implements processors.Processor.
func (p Processor) Process(ctx context.Context, in event.Event) (*event.Event, error) {
for _, n := range p.Normalizers {
if n == nil {
continue
}
if !n.Match(in) {
continue
}
out, err := n.Normalize(ctx, in)
if err != nil {
return nil, fmt.Errorf("normalize: normalizer failed: %w", err)
}
return out, nil
}
if p.RequireMatch {
return nil, fmt.Errorf("normalize: no normalizer matched event (id=%s kind=%s source=%s schema=%q)",
in.ID, in.Kind, in.Source, in.Schema)
}
out := in
return &out, nil
}


@@ -0,0 +1,139 @@
package normalize
import (
"context"
"errors"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestProcessorFirstMatchWins(t *testing.T) {
var firstCalls, secondCalls int
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
firstCalls++
out := in
out.Schema = "normalized.first.v1"
return &out, nil
},
},
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
secondCalls++
out := in
out.Schema = "normalized.second.v1"
return &out, nil
},
},
}, false)
out, err := p.Process(context.Background(), testEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out == nil {
t.Fatalf("expected output event, got nil")
}
if out.Schema != "normalized.first.v1" {
t.Fatalf("unexpected schema: %q", out.Schema)
}
if firstCalls != 1 {
t.Fatalf("expected first normalizer called once, got %d", firstCalls)
}
if secondCalls != 0 {
t.Fatalf("expected second normalizer skipped, got %d calls", secondCalls)
}
}
func TestProcessorNoMatchPassThroughAndRequireMatch(t *testing.T) {
in := testEvent()
in.Schema = "raw.schema.v1"
passThrough := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return false },
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
out := in
out.Schema = "should.not.run"
return &out, nil
},
},
}, false)
out, err := passThrough.Process(context.Background(), in)
if err != nil {
t.Fatalf("pass-through Process error: %v", err)
}
if out == nil {
t.Fatalf("expected pass-through output event, got nil")
}
if out.Schema != "raw.schema.v1" {
t.Fatalf("expected unchanged schema, got %q", out.Schema)
}
required := NewProcessor(nil, true)
_, err = required.Process(context.Background(), in)
if err == nil {
t.Fatalf("expected require-match error")
}
if !strings.Contains(err.Error(), "no normalizer matched") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestProcessorDropAndErrorPropagation(t *testing.T) {
t.Run("drop", func(t *testing.T) {
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
return nil, nil
},
},
}, false)
out, err := p.Process(context.Background(), testEvent())
if err != nil {
t.Fatalf("Process error: %v", err)
}
if out != nil {
t.Fatalf("expected nil output for dropped event, got %#v", out)
}
})
t.Run("error", func(t *testing.T) {
p := NewProcessor([]Normalizer{
Func{
MatchFn: func(event.Event) bool { return true },
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
return nil, errors.New("map failed")
},
},
}, false)
_, err := p.Process(context.Background(), testEvent())
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "normalizer failed") {
t.Fatalf("unexpected error: %v", err)
}
})
}
func testEvent() event.Event {
return event.Event{
ID: "evt-normalize-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"x": 1},
}
}

15
processors/processor.go Normal file
View File

@@ -0,0 +1,15 @@
package processors
import (
"context"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Processor can mutate/drop events (dedupe, rate-limit, normalization tweaks).
type Processor interface {
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
}
// Factory constructs a configured Processor instance.
type Factory func() (Processor, error)

71
processors/registry.go Normal file
View File

@@ -0,0 +1,71 @@
package processors
import (
"fmt"
"strings"
)
type Registry struct {
byDriver map[string]Factory
}
func NewRegistry() *Registry {
return &Registry{byDriver: map[string]Factory{}}
}
// Register associates a processor driver name with a factory.
//
// Register panics for empty driver names, nil factories, and duplicates.
func (r *Registry) Register(driver string, f Factory) {
if r == nil {
panic("processors.Registry.Register: registry cannot be nil")
}
driver = strings.TrimSpace(driver)
if driver == "" {
panic("processors.Registry.Register: driver cannot be empty")
}
if f == nil {
panic(fmt.Sprintf("processors.Registry.Register: factory cannot be nil (driver=%q)", driver))
}
if r.byDriver == nil {
r.byDriver = map[string]Factory{}
}
if _, exists := r.byDriver[driver]; exists {
panic(fmt.Sprintf("processors.Registry.Register: driver %q already registered", driver))
}
r.byDriver[driver] = f
}
// Build constructs a Processor by driver name.
func (r *Registry) Build(driver string) (Processor, error) {
if r == nil {
return nil, fmt.Errorf("processors registry is nil")
}
driver = strings.TrimSpace(driver)
f, ok := r.byDriver[driver]
if !ok {
return nil, fmt.Errorf("unknown processor driver: %q", driver)
}
p, err := f()
if err != nil {
return nil, fmt.Errorf("build processor %q: %w", driver, err)
}
if p == nil {
return nil, fmt.Errorf("build processor %q: factory returned nil processor", driver)
}
return p, nil
}
// BuildChain constructs an ordered processor chain from a driver list.
func (r *Registry) BuildChain(drivers []string) ([]Processor, error) {
out := make([]Processor, 0, len(drivers))
for i, driver := range drivers {
p, err := r.Build(driver)
if err != nil {
return nil, fmt.Errorf("build processor chain[%d] (%q): %w", i, strings.TrimSpace(driver), err)
}
out = append(out, p)
}
return out, nil
}
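The Register/Build contract can be sketched with a map-backed registry. This is a simplified stand-in for processors.Registry: registration mistakes panic (programmer error at init time), while unknown drivers return an error (config-time lookup).

```go
package main

import (
	"fmt"
	"strings"
)

type Processor interface{ Name() string }

type Factory func() (Processor, error)

type noop struct{ name string }

func (n noop) Name() string { return n.name }

type Registry struct{ byDriver map[string]Factory }

// Register panics on empty names, nil factories, and duplicates.
func (r *Registry) Register(driver string, f Factory) {
	driver = strings.TrimSpace(driver)
	if driver == "" || f == nil {
		panic("registry: invalid registration")
	}
	if _, dup := r.byDriver[driver]; dup {
		panic("registry: duplicate driver " + driver)
	}
	r.byDriver[driver] = f
}

// Build returns an error, not a panic, for unknown drivers.
func (r *Registry) Build(driver string) (Processor, error) {
	f, ok := r.byDriver[strings.TrimSpace(driver)]
	if !ok {
		return nil, fmt.Errorf("unknown processor driver: %q", driver)
	}
	return f()
}

func main() {
	r := &Registry{byDriver: map[string]Factory{}}
	r.Register("dedupe", func() (Processor, error) { return noop{name: "dedupe"}, nil })
	p, _ := r.Build("dedupe")
	fmt.Println(p.Name())
	_, err := r.Build("missing")
	fmt.Println(err != nil)
}
```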

100
processors/registry_test.go Normal file
View File

@@ -0,0 +1,100 @@
package processors
import (
"context"
"errors"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testProcessor struct {
name string
}
func (p testProcessor) Process(context.Context, event.Event) (*event.Event, error) {
return nil, nil
}
func TestRegistryRegisterValidation(t *testing.T) {
t.Run("empty driver panics", func(t *testing.T) {
r := NewRegistry()
assertPanics(t, func() {
r.Register(" ", func() (Processor, error) { return testProcessor{name: "x"}, nil })
})
})
t.Run("nil factory panics", func(t *testing.T) {
r := NewRegistry()
assertPanics(t, func() {
r.Register("normalize", nil)
})
})
t.Run("duplicate driver panics", func(t *testing.T) {
r := NewRegistry()
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "a"}, nil })
assertPanics(t, func() {
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "b"}, nil })
})
})
}
func TestRegistryBuildUnknownDriver(t *testing.T) {
r := NewRegistry()
_, err := r.Build("does_not_exist")
if err == nil {
t.Fatalf("expected error for unknown driver")
}
if !strings.Contains(err.Error(), "unknown processor driver") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRegistryBuildChainPreservesOrder(t *testing.T) {
r := NewRegistry()
r.Register("first", func() (Processor, error) { return testProcessor{name: "first"}, nil })
r.Register("second", func() (Processor, error) { return testProcessor{name: "second"}, nil })
chain, err := r.BuildChain([]string{"first", "second"})
if err != nil {
t.Fatalf("BuildChain error: %v", err)
}
if len(chain) != 2 {
t.Fatalf("expected 2 processors, got %d", len(chain))
}
p0, ok := chain[0].(testProcessor)
if !ok || p0.name != "first" {
t.Fatalf("unexpected chain[0]: %#v", chain[0])
}
p1, ok := chain[1].(testProcessor)
if !ok || p1.name != "second" {
t.Fatalf("unexpected chain[1]: %#v", chain[1])
}
}
func TestRegistryBuildChainIndexedFailure(t *testing.T) {
r := NewRegistry()
r.Register("ok", func() (Processor, error) { return testProcessor{name: "ok"}, nil })
r.Register("broken", func() (Processor, error) { return nil, errors.New("boom") })
_, err := r.BuildChain([]string{"ok", "broken"})
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "chain[1]") {
t.Fatalf("expected indexed error, got: %v", err)
}
}
func assertPanics(t *testing.T, fn func()) {
t.Helper()
defer func() {
if recover() == nil {
t.Fatalf("expected panic")
}
}()
fn()
}

25
scheduler/doc.go Normal file
View File

@@ -0,0 +1,25 @@
// Package scheduler runs feedkit sources and forwards their events to the
// daemon event bus.
//
// External API surface:
// - Scheduler: runs configured polling and streaming jobs
// - Job: one scheduler task bound to a source
// - StreamExitPolicy: stream supervision policy for non-fatal exits
// - StreamBackoff: restart pacing for supervised stream sources
//
// Optional helpers from helpers.go:
// - JobFromSourceConfig: build a scheduler job from a configured source and
// feedkit-owned scheduling params
//
// Poll sources are run on a fixed cadence with optional jitter. Stream sources
// are supervised long-lived workers. Their generic feedkit controls live under
// sources[].params:
// - stream_exit_policy: restart|stop|fatal (default restart)
// - stream_backoff_initial: positive duration (default 1s)
// - stream_backoff_max: positive duration (default 1m)
// - stream_backoff_jitter: non-negative duration (default 250ms)
//
// Stream sources can classify exits with sources.StreamRetryable and
// sources.StreamFatal. Plain errors are treated as retryable by default, while
// fatal exits are propagated from Scheduler.Run so the daemon can shut down.
package scheduler

138
scheduler/helpers.go Normal file
View File

@@ -0,0 +1,138 @@
package scheduler
import (
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
// JobFromSourceConfig builds a scheduler Job from a configured source and its
// generic feedkit config.
func JobFromSourceConfig(src sources.Input, cfg config.SourceConfig) (Job, error) {
if src == nil {
return Job{}, fmt.Errorf("scheduler: source %q is nil", cfg.Name)
}
job := Job{
Source: src,
Every: cfg.Every.Duration,
}
if _, ok := src.(sources.StreamSource); ok {
if cfg.Every.Duration > 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be omitted for stream sources", cfg.Name)
}
policy, err := parseStreamExitPolicy(cfg)
if err != nil {
return Job{}, err
}
backoff, err := parseStreamBackoff(cfg)
if err != nil {
return Job{}, err
}
job.StreamExitPolicy = policy
job.StreamBackoff = backoff
return job, nil
}
if _, ok := src.(sources.PollSource); ok {
if cfg.Every.Duration <= 0 {
return Job{}, fmt.Errorf("source %q: sources[].every must be > 0 for polling sources", cfg.Name)
}
if err := rejectStreamParams(cfg); err != nil {
return Job{}, err
}
return job, nil
}
return Job{}, fmt.Errorf("scheduler: source %q implements neither PollSource nor StreamSource", cfg.Name)
}
func parseStreamExitPolicy(cfg config.SourceConfig) (StreamExitPolicy, error) {
const key = "stream_exit_policy"
raw, exists := cfg.Params[key]
if !exists || raw == nil {
return StreamExitPolicyRestart, nil
}
s, ok := raw.(string)
if !ok {
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
switch StreamExitPolicy(strings.ToLower(strings.TrimSpace(s))) {
case StreamExitPolicyRestart:
return StreamExitPolicyRestart, nil
case StreamExitPolicyStop:
return StreamExitPolicyStop, nil
case StreamExitPolicyFatal:
return StreamExitPolicyFatal, nil
default:
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
}
}
func parseStreamBackoff(cfg config.SourceConfig) (StreamBackoff, error) {
initial, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_initial", defaultStreamBackoffInitial)
if err != nil {
return StreamBackoff{}, err
}
max, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_max", defaultStreamBackoffMax)
if err != nil {
return StreamBackoff{}, err
}
jitter, err := parseNonNegativeOrDefaultDuration(cfg, "stream_backoff_jitter", defaultStreamBackoffJitter)
if err != nil {
return StreamBackoff{}, err
}
if max < initial {
return StreamBackoff{}, fmt.Errorf("source %q: params.stream_backoff_max must be >= params.stream_backoff_initial", cfg.Name)
}
return StreamBackoff{
Initial: initial,
Max: max,
Jitter: jitter,
}, nil
}
func rejectStreamParams(cfg config.SourceConfig) error {
streamKeys := []string{
"stream_exit_policy",
"stream_backoff_initial",
"stream_backoff_max",
"stream_backoff_jitter",
}
for _, key := range streamKeys {
if _, ok := cfg.Params[key]; ok {
return fmt.Errorf("source %q: params.%s is only valid for stream sources", cfg.Name, key)
}
}
return nil
}
func parsePositiveOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v <= 0 {
return 0, fmt.Errorf("source %q: params.%s must be a positive duration", cfg.Name, key)
}
return v, nil
}
func parseNonNegativeOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
if _, exists := cfg.Params[key]; !exists {
return def, nil
}
v, ok := cfg.ParamDuration(key)
if !ok || v < 0 {
return 0, fmt.Errorf("source %q: params.%s must be a non-negative duration", cfg.Name, key)
}
return v, nil
}

View File

@@ -5,25 +5,60 @@ import (
"fmt"
"hash/fnv"
"math/rand"
"sync"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/logging"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
// Logger is a printf-style logger used throughout scheduler.
// It is an alias to the shared feedkit logging type so callers can pass
// one function everywhere without type mismatch friction.
type Logger = logging.Logf
// Job describes one scheduler task.
//
// A Job may be backed by either:
// - a polling source (sources.PollSource): uses Every + jitter and calls Poll()
// - a stream source (sources.StreamSource): ignores Every and calls Run()
//
// Jitter behavior:
// - For polling sources: Jitter is applied at startup and before each poll tick.
// - For stream sources: Jitter is applied once at startup only (optional; useful to avoid
// reconnect storms when many instances start together).
type Job struct {
Source sources.Input
Every time.Duration
StreamExitPolicy StreamExitPolicy
StreamBackoff StreamBackoff
// Jitter is the maximum additional delay added before each poll.
// Example: if Every=15m and Jitter=30s, each poll will occur at:
// tick time + random(0..30s)
//
// If Jitter == 0 for polling sources, we compute a default jitter based on Every.
//
// For stream sources, Jitter is treated as *startup jitter only*.
Jitter time.Duration
}
// StreamExitPolicy controls how the scheduler handles non-fatal stream exits.
type StreamExitPolicy string
const (
StreamExitPolicyRestart StreamExitPolicy = "restart"
StreamExitPolicyStop StreamExitPolicy = "stop"
StreamExitPolicyFatal StreamExitPolicy = "fatal"
)
// StreamBackoff controls restart pacing for stream supervision.
type StreamBackoff struct {
Initial time.Duration
Max time.Duration
Jitter time.Duration
}
type Scheduler struct {
Jobs []Job
@@ -31,8 +66,18 @@ type Scheduler struct {
Logf Logger
}
const (
defaultStreamBackoffInitial = 1 * time.Second
defaultStreamBackoffMax = 1 * time.Minute
defaultStreamBackoffJitter = 250 * time.Millisecond
streamBackoffResetAfter = 5 * time.Minute
)
var timeNow = time.Now
// Run starts one goroutine per job.
// Poll jobs run on their own interval and emit 0..N events per poll.
// Stream jobs run continuously and emit events as they arrive.
func (s *Scheduler) Run(ctx context.Context) error {
if s.Out == nil {
return fmt.Errorf("scheduler.Run: Out channel is nil")
@@ -41,31 +86,117 @@ func (s *Scheduler) Run(ctx context.Context) error {
return fmt.Errorf("scheduler.Run: no jobs configured")
}
runCtx, cancel := context.WithCancel(ctx)
defer cancel()
fatalErrCh := make(chan error, 1)
var wg sync.WaitGroup
for _, job := range s.Jobs {
job := job // capture loop variable
wg.Add(1)
go func() {
defer wg.Done()
s.runJob(runCtx, job, fatalErrCh)
}()
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case err := <-fatalErrCh:
cancel()
<-done
return err
case <-runCtx.Done():
<-done
return runCtx.Err()
}
}
func (s *Scheduler) runJob(ctx context.Context, job Job, fatalErrCh chan<- error) {
if job.Source == nil {
s.logf("scheduler: job has nil source")
return
}
// Stream sources: event-driven.
if ss, ok := job.Source.(sources.StreamSource); ok {
s.runStream(ctx, job, ss, fatalErrCh)
return
}
// Poll sources: time-based.
ps, ok := job.Source.(sources.PollSource)
if !ok {
s.logf("scheduler: source %T (%s) implements neither Poll() nor Run()", job.Source, job.Source.Name())
return
}
if job.Every <= 0 {
s.logf("scheduler: polling job %q missing/invalid interval (sources[].every)", ps.Name())
return
}
s.runPoller(ctx, job, ps)
}
func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource, fatalErrCh chan<- error) {
policy := effectiveStreamExitPolicy(job.StreamExitPolicy)
backoff := effectiveStreamBackoff(job.StreamBackoff)
rng := seededRNG(src.Name())
// Optional startup jitter: helps avoid reconnect storms if many daemons start at once.
if job.Jitter > 0 {
if !sleepJitter(ctx, rng, job.Jitter) {
return
}
}
nextDelay := backoff.Initial
for {
startedAt := timeNow()
err := src.Run(ctx, s.Out)
if ctx.Err() != nil {
return
}
normalizedErr := normalizeStreamExitError(src.Name(), err)
if sources.IsStreamFatal(normalizedErr) {
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited fatally: %w", src.Name(), normalizedErr))
return
}
switch policy {
case StreamExitPolicyStop:
s.logf("scheduler: stream source %q stopped after exit: %v", src.Name(), normalizedErr)
return
case StreamExitPolicyFatal:
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited under fatal policy: %w", src.Name(), normalizedErr))
return
}
if streamRunWasStable(startedAt, timeNow()) {
nextDelay = backoff.Initial
}
delay := nextDelay + randomDuration(rng, backoff.Jitter)
s.logf("scheduler: stream source %q exited; restarting in %s: %v", src.Name(), delay, normalizedErr)
if !sleepDuration(ctx, delay) {
return
}
nextDelay = nextStreamBackoff(nextDelay, backoff.Max)
}
}
func (s *Scheduler) runPoller(ctx context.Context, job Job, src sources.PollSource) {
// Compute jitter: either configured per job, or a sensible default.
jitter := effectiveJitter(job.Every, job.Jitter)
// Each worker gets its own RNG (safe + no lock contention).
rng := seededRNG(src.Name())
// Optional startup jitter: avoids all jobs firing at the exact moment the daemon starts.
if !sleepJitter(ctx, rng, jitter) {
@@ -73,7 +204,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
}
// Immediate poll at startup (after startup jitter).
s.pollOnce(ctx, src)
t := time.NewTicker(job.Every)
defer t.Stop()
@@ -85,7 +216,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
if !sleepJitter(ctx, rng, jitter) {
return
}
s.pollOnce(ctx, src)
case <-ctx.Done():
return
@@ -93,10 +224,10 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
}
}
func (s *Scheduler) pollOnce(ctx context.Context, src sources.PollSource) {
events, err := src.Poll(ctx)
if err != nil {
s.logf("scheduler: poll failed (%s): %v", src.Name(), err)
return
}
@@ -116,6 +247,80 @@ func (s *Scheduler) logf(format string, args ...any) {
s.Logf(format, args...)
}
func (s *Scheduler) reportFatal(ch chan<- error, err error) {
if err == nil {
return
}
select {
case ch <- err:
default:
}
}
// ---- helpers ----
func effectiveStreamExitPolicy(policy StreamExitPolicy) StreamExitPolicy {
switch policy {
case StreamExitPolicyStop, StreamExitPolicyFatal:
return policy
default:
return StreamExitPolicyRestart
}
}
func effectiveStreamBackoff(cfg StreamBackoff) StreamBackoff {
out := cfg
if out.Initial <= 0 {
out.Initial = defaultStreamBackoffInitial
}
if out.Max <= 0 {
out.Max = defaultStreamBackoffMax
}
if out.Max < out.Initial {
out.Max = out.Initial
}
if out.Jitter < 0 {
out.Jitter = 0
}
return out
}
func normalizeStreamExitError(sourceName string, err error) error {
if err != nil {
return err
}
return sources.StreamRetryable(fmt.Errorf("stream source %q exited unexpectedly without error", sourceName))
}
func nextStreamBackoff(current, max time.Duration) time.Duration {
if current <= 0 {
current = defaultStreamBackoffInitial
}
if max <= 0 {
max = defaultStreamBackoffMax
}
if current >= max {
return max
}
next := current * 2
if next < current || next > max {
return max
}
return next
}
func streamRunWasStable(startedAt, endedAt time.Time) bool {
if startedAt.IsZero() || endedAt.IsZero() {
return false
}
return endedAt.Sub(startedAt) >= streamBackoffResetAfter
}
func seededRNG(name string) *rand.Rand {
seed := timeNow().UnixNano() ^ int64(hashStringFNV32a(name))
return rand.New(rand.NewSource(seed))
}
// effectiveJitter chooses a jitter value.
// - If configuredMax > 0, use it (but clamp).
// - Else default to min(every/10, 30s).
@@ -151,11 +356,23 @@ func sleepJitter(ctx context.Context, rng *rand.Rand, max time.Duration) bool {
return true
}
return sleepDuration(ctx, randomDuration(rng, max))
}
func randomDuration(rng *rand.Rand, max time.Duration) time.Duration {
if max <= 0 {
return 0
}
// Int63n requires a positive argument.
// We add 1 so max itself is attainable.
n := rng.Int63n(int64(max) + 1)
return time.Duration(n)
}
func sleepDuration(ctx context.Context, d time.Duration) bool {
if d <= 0 {
return true
}
timer := time.NewTimer(d)
defer timer.Stop()

472
scheduler/scheduler_test.go Normal file
View File

@@ -0,0 +1,472 @@
package scheduler
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
"gitea.maximumdirect.net/ejr/feedkit/sources"
)
type testPollSource struct {
name string
}
func (s testPollSource) Name() string { return s.name }
func (s testPollSource) Poll(context.Context) ([]event.Event, error) { return nil, nil }
type scriptedStreamSource struct {
name string
mu sync.Mutex
calls int
runs []func(context.Context, chan<- event.Event) error
}
func (s *scriptedStreamSource) Name() string { return s.name }
func (s *scriptedStreamSource) Run(ctx context.Context, out chan<- event.Event) error {
s.mu.Lock()
call := s.calls
s.calls++
var run func(context.Context, chan<- event.Event) error
if call < len(s.runs) {
run = s.runs[call]
}
s.mu.Unlock()
if run != nil {
return run(ctx, out)
}
<-ctx.Done()
return ctx.Err()
}
func (s *scriptedStreamSource) CallCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.calls
}
type capturingLogger struct {
mu sync.Mutex
lines []string
}
func (l *capturingLogger) Logf(format string, args ...any) {
l.mu.Lock()
defer l.mu.Unlock()
l.lines = append(l.lines, fmt.Sprintf(format, args...))
}
func (l *capturingLogger) Contains(substr string) bool {
l.mu.Lock()
defer l.mu.Unlock()
for _, line := range l.lines {
if strings.Contains(line, substr) {
return true
}
}
return false
}
func TestSchedulerRunRestartsPlainStreamErrors(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-a",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 3 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if src.CallCount() < 3 {
t.Fatalf("stream call count = %d, want at least 3", src.CallCount())
}
}
func TestSchedulerRunFatalStreamErrorReturns(t *testing.T) {
base := errors.New("fatal failure")
src := &scriptedStreamSource{
name: "stream-fatal",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return sources.StreamFatal(base) },
},
}
s := &Scheduler{
Jobs: []Job{{Source: src}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal error")
}
if !sources.IsStreamFatal(err) {
t.Fatalf("Scheduler.Run() error = %v, want fatal classification", err)
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base fatal error: %v", err)
}
}
func TestSchedulerRunStopPolicyStopsOnlyThatSource(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-stop",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("stop now") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyStop,
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
time.Sleep(20 * time.Millisecond)
select {
case err := <-errCh:
t.Fatalf("Scheduler.Run() returned early: %v", err)
default:
}
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
}
func TestSchedulerRunFatalPolicyTreatsPlainErrorAsFatal(t *testing.T) {
base := errors.New("plain failure")
src := &scriptedStreamSource{
name: "stream-fatal-policy",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return base },
},
}
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamExitPolicy: StreamExitPolicyFatal,
}},
Out: make(chan event.Event, 1),
}
err := s.Run(context.Background())
if err == nil {
t.Fatalf("Scheduler.Run() error = nil, want fatal-policy error")
}
if !errors.Is(err, base) {
t.Fatalf("Scheduler.Run() error does not wrap base error: %v", err)
}
}
func TestSchedulerRunNilExitRestartsAsUnexpected(t *testing.T) {
logger := &capturingLogger{}
src := &scriptedStreamSource{
name: "stream-nil-exit",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return nil },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Millisecond,
Max: time.Millisecond,
},
}},
Out: make(chan event.Event, 1),
Logf: logger.Logf,
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 2 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
if !logger.Contains("exited unexpectedly without error") {
t.Fatalf("expected log to mention unexpected nil stream exit")
}
}
func TestSchedulerRunContextCancelDuringBackoff(t *testing.T) {
src := &scriptedStreamSource{
name: "stream-backoff-cancel",
runs: []func(context.Context, chan<- event.Event) error{
func(context.Context, chan<- event.Event) error { return errors.New("retry me") },
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := &Scheduler{
Jobs: []Job{{
Source: src,
StreamBackoff: StreamBackoff{
Initial: time.Second,
Max: time.Second,
},
}},
Out: make(chan event.Event, 1),
}
errCh := make(chan error, 1)
go func() { errCh <- s.Run(ctx) }()
waitFor(t, func() bool { return src.CallCount() >= 1 })
cancel()
err := <-errCh
if !errors.Is(err, context.Canceled) {
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
}
time.Sleep(20 * time.Millisecond)
if src.CallCount() != 1 {
t.Fatalf("stream call count = %d, want 1", src.CallCount())
}
}
func TestNextStreamBackoffCapsAtMax(t *testing.T) {
if got := nextStreamBackoff(500*time.Millisecond, 2*time.Second); got != time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 1s", got)
}
if got := nextStreamBackoff(time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
if got := nextStreamBackoff(2*time.Second, 2*time.Second); got != 2*time.Second {
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
}
}
func TestStreamRunWasStableAfterFiveMinutes(t *testing.T) {
start := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
if streamRunWasStable(start, start.Add(4*time.Minute+59*time.Second)) {
t.Fatalf("streamRunWasStable() = true, want false")
}
if !streamRunWasStable(start, start.Add(5*time.Minute)) {
t.Fatalf("streamRunWasStable() = false, want true")
}
}
func TestJobFromSourceConfigPollSource(t *testing.T) {
job, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.Every != time.Minute {
t.Fatalf("Job.Every = %s, want 1m", job.Every)
}
}
func TestJobFromSourceConfigPollSourceRejectsStreamParams(t *testing.T) {
_, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
Name: "poll-a",
Driver: "poll_driver",
Every: config.Duration{Duration: time.Minute},
Params: map[string]any{
"stream_exit_policy": "restart",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want rejection")
}
if !strings.Contains(err.Error(), "only valid for stream sources") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceParsesDefaultsAndOverrides(t *testing.T) {
src := &scriptedStreamSource{name: "stream-a"}
job, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-a",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "stop",
"stream_backoff_initial": "2s",
"stream_backoff_max": "10s",
"stream_backoff_jitter": "500ms",
},
})
if err != nil {
t.Fatalf("JobFromSourceConfig() error = %v", err)
}
if job.StreamExitPolicy != StreamExitPolicyStop {
t.Fatalf("Job.StreamExitPolicy = %q, want %q", job.StreamExitPolicy, StreamExitPolicyStop)
}
if job.StreamBackoff.Initial != 2*time.Second {
t.Fatalf("Job.StreamBackoff.Initial = %s, want 2s", job.StreamBackoff.Initial)
}
if job.StreamBackoff.Max != 10*time.Second {
t.Fatalf("Job.StreamBackoff.Max = %s, want 10s", job.StreamBackoff.Max)
}
if job.StreamBackoff.Jitter != 500*time.Millisecond {
t.Fatalf("Job.StreamBackoff.Jitter = %s, want 500ms", job.StreamBackoff.Jitter)
}
defaultJob, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-default",
Driver: "stream_driver",
Mode: config.SourceModeStream,
})
if err != nil {
t.Fatalf("JobFromSourceConfig() default error = %v", err)
}
if defaultJob.StreamExitPolicy != StreamExitPolicyRestart {
t.Fatalf("default Job.StreamExitPolicy = %q, want restart", defaultJob.StreamExitPolicy)
}
if defaultJob.StreamBackoff.Initial != defaultStreamBackoffInitial {
t.Fatalf("default Job.StreamBackoff.Initial = %s, want %s", defaultJob.StreamBackoff.Initial, defaultStreamBackoffInitial)
}
if defaultJob.StreamBackoff.Max != defaultStreamBackoffMax {
t.Fatalf("default Job.StreamBackoff.Max = %s, want %s", defaultJob.StreamBackoff.Max, defaultStreamBackoffMax)
}
if defaultJob.StreamBackoff.Jitter != defaultStreamBackoffJitter {
t.Fatalf("default Job.StreamBackoff.Jitter = %s, want %s", defaultJob.StreamBackoff.Jitter, defaultStreamBackoffJitter)
}
}
func TestJobFromSourceConfigStreamSourceRejectsInvalidSettings(t *testing.T) {
src := &scriptedStreamSource{name: "stream-b"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_exit_policy": "sometimes",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid policy error")
}
if !strings.Contains(err.Error(), "stream_exit_policy") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "0s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid initial backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_initial") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
_, err = JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-b",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Params: map[string]any{
"stream_backoff_initial": "2s",
"stream_backoff_max": "1s",
},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want invalid max backoff error")
}
if !strings.Contains(err.Error(), "stream_backoff_max") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func TestJobFromSourceConfigStreamSourceRejectsEvery(t *testing.T) {
src := &scriptedStreamSource{name: "stream-c"}
_, err := JobFromSourceConfig(src, config.SourceConfig{
Name: "stream-c",
Driver: "stream_driver",
Mode: config.SourceModeStream,
Every: config.Duration{Duration: time.Minute},
})
if err == nil {
t.Fatalf("JobFromSourceConfig() error = nil, want every rejection")
}
if !strings.Contains(err.Error(), "sources[].every must be omitted") {
t.Fatalf("JobFromSourceConfig() error = %q", err)
}
}
func waitFor(t *testing.T, cond func() bool) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if cond() {
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("condition not satisfied before timeout")
}


@@ -1,7 +0,0 @@
package scheduler
// Placeholder for per-source worker logic:
// - ticker loop
// - jitter
// - backoff on errors
// - emits events into scheduler.Out


@@ -1,11 +1,6 @@
package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
import "gitea.maximumdirect.net/ejr/feedkit/config"
// RegisterBuiltins registers sink drivers included in this binary.
//
@@ -17,39 +12,8 @@ func RegisterBuiltins(r *Registry) {
return NewStdoutSink(cfg.Name), nil
})
// File sink: writes/archives events somewhere on disk.
r.Register("file", func(cfg config.SinkConfig) (Sink, error) {
return NewFileSinkFromConfig(cfg)
})
// Postgres sink: persists events durably.
r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) {
return NewPostgresSinkFromConfig(cfg)
})
// RabbitMQ sink: publishes events to a broker for downstream consumers.
r.Register("rabbitmq", func(cfg config.SinkConfig) (Sink, error) {
return NewRabbitMQSinkFromConfig(cfg)
// NATS sink: publishes events to a broker for downstream consumers.
r.Register("nats", func(cfg config.SinkConfig) (Sink, error) {
return NewNATSSinkFromConfig(cfg)
})
}
// ---- helpers for validating sink params ----
//
// These helpers live in sinks (not config) on purpose:
// - config is domain-agnostic and should not embed driver-specific validation helpers.
// - sinks are adapters; validating their own params here keeps the logic near the driver.
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
v, ok := cfg.Params[key]
if !ok {
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
}
s, ok := v.(string)
if !ok {
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
}
if strings.TrimSpace(s) == "" {
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
}
return s, nil
}

sinks/doc.go Normal file

@@ -0,0 +1,101 @@
// Package sinks defines the feedkit sink interface, sink driver registry, and
// built-in infrastructure sinks.
//
// External API surface:
// - Sink: adapter interface that consumes event.Event values
// - Registry / NewRegistry: named sink factory registry
// - RegisterBuiltins: registers the schema-free built-in sink drivers
//
// Built-in sink implementations:
// - stdout
// - nats
// - postgres
//
// Optional helpers from helpers.go:
// - PostgresFactory: returns a sink factory for the built-in Postgres sink
// using a provided downstream schema
//
// # NATS built-in overview
//
// The NATS sink publishes each event as JSON to a configured subject.
//
// Required params:
// - url: NATS server URL (for example, nats://localhost:4222)
// - subject: NATS subject to publish to
//
// Example config:
//
// sinks:
// - name: nats_main
// driver: nats
// params:
// url: nats://localhost:4222
// subject: feedkit.events
//
// # Postgres built-in overview
//
// The postgres sink is intentionally split between downstream daemon ownership
// and feedkit ownership:
// - downstream code registers table schema + event mapping functions
// - feedkit manages DB connection, create-if-missing DDL, transactional
// inserts, optional automatic retention pruning, and manual prune helpers
//
// Example config:
//
// sinks:
// - name: pg_main
// driver: postgres
// params:
// uri: postgres://localhost:5432/feedkit?sslmode=disable
// username: feedkit_user
// password: feedkit_pass
// prune: 3d # optional: prune rows older than now-3d on each write tx
//
// params.prune supports:
// - Go duration strings (72h, 90m, 30s, ...)
// - day/week suffixes (3d, 2w)
//
// If params.prune is omitted, automatic pruning is disabled.
//
// Example downstream wiring:
//
// sinkReg := sinks.NewRegistry()
// sinks.RegisterBuiltins(sinkReg)
// sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
// Tables: []sinks.PostgresTable{
// {
// Name: "events",
// Columns: []sinks.PostgresColumn{
// {Name: "event_id", Type: "TEXT", Nullable: false},
// {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
// {Name: "payload_json", Type: "JSONB", Nullable: false},
// },
// PrimaryKey: []string{"event_id"},
// PruneColumn: "emitted_at", // required for retention pruning
// },
// },
// MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) {
// b, err := json.Marshal(e.Payload)
// if err != nil {
// return nil, err
// }
// return []sinks.PostgresWrite{
// {
// Table: "events",
// Values: map[string]any{
// "event_id": e.ID,
// "emitted_at": e.EmittedAt,
// "payload_json": string(b),
// },
// },
// }, nil
// },
// }))
//
// Manual pruning via type assertion (administrative helpers):
//
// if p, ok := sink.(sinks.PostgresPruner); ok {
// _, _ = p.PruneKeepLatest(ctx, "events", 10000)
// _, _ = p.PruneOlderThan(ctx, "events", time.Now().UTC().AddDate(0, -1, 0))
// }
package sinks


@@ -1,30 +0,0 @@
package sinks
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type FileSink struct {
name string
path string
}
func NewFileSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
path, err := requireStringParam(cfg, "path")
if err != nil {
return nil, err
}
return &FileSink{name: cfg.Name, path: path}, nil
}
func (s *FileSink) Name() string { return s.name }
func (s *FileSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
_ = e
return fmt.Errorf("file sink: TODO implement (path=%s)", s.path)
}

sinks/helpers.go Normal file

@@ -0,0 +1,35 @@
package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
// requireStringParam returns a non-empty string sink param.
//
// This helper is intentionally local to sinks rather than config so
// driver-specific validation stays close to the adapters that use it.
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
v, ok := cfg.Params[key]
if !ok {
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
}
s, ok := v.(string)
if !ok {
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
}
if strings.TrimSpace(s) == "" {
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
}
return s, nil
}
// PostgresFactory returns a sink factory that builds the built-in Postgres sink
// using the provided downstream schema definition.
func PostgresFactory(schema PostgresSchema) Factory {
return func(cfg config.SinkConfig) (Sink, error) {
return NewPostgresSinkFromConfig(cfg, schema)
}
}

sinks/helpers_test.go Normal file

@@ -0,0 +1,46 @@
package sinks
import (
"context"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) {
factory := PostgresFactory(testPostgresSchema())
if factory == nil {
t.Fatalf("PostgresFactory() returned nil")
}
if _, err := factory(config.SinkConfig{}); err == nil {
t.Fatalf("factory(config) expected parameter validation error")
}
}
func testPostgresSchema() PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
},
MapEvent: func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{
Table: "events",
Values: map[string]any{
"event_id": e.ID,
"emitted_at": e.EmittedAt,
},
},
}, nil
},
}
}

sinks/nats.go Normal file

@@ -0,0 +1,97 @@
package sinks
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
"github.com/nats-io/nats.go"
)
type NATSSink struct {
name string
url string
subject string
mu sync.Mutex
conn *nats.Conn
}
func NewNATSSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
url, err := requireStringParam(cfg, "url")
if err != nil {
return nil, err
}
subject, err := requireStringParam(cfg, "subject")
if err != nil {
return nil, err
}
return &NATSSink{name: cfg.Name, url: url, subject: subject}, nil
}
func (r *NATSSink) Name() string { return r.name }
func (r *NATSSink) Consume(ctx context.Context, e event.Event) error {
// Boundary validation: if something upstream violated invariants,
// surface it loudly rather than publishing a malformed message.
if err := e.Validate(); err != nil {
return fmt.Errorf("NATS sink: invalid event: %w", err)
}
if err := ctx.Err(); err != nil {
return err
}
conn, err := r.connect(ctx)
if err != nil {
return fmt.Errorf("NATS sink: connect: %w", err)
}
b, err := json.Marshal(e)
if err != nil {
return fmt.Errorf("NATS sink: marshal event: %w", err)
}
if err := ctx.Err(); err != nil {
return err
}
if err := conn.Publish(r.subject, b); err != nil {
return fmt.Errorf("NATS sink: publish: %w", err)
}
return nil
}
func (r *NATSSink) connect(ctx context.Context) (*nats.Conn, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
r.mu.Lock()
defer r.mu.Unlock()
if r.conn != nil && r.conn.Status() != nats.CLOSED {
return r.conn, nil
}
opts := []nats.Option{
nats.Name(fmt.Sprintf("feedkit sink %s", r.name)),
}
if deadline, ok := ctx.Deadline(); ok {
timeout := time.Until(deadline)
if timeout <= 0 {
return nil, ctx.Err()
}
opts = append(opts, nats.Timeout(timeout))
}
conn, err := nats.Connect(r.url, opts...)
if err != nil {
return nil, err
}
r.conn = conn
return conn, nil
}

sinks/nats_test.go Normal file

@@ -0,0 +1,47 @@
package sinks
import (
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
func TestNewNATSSinkFromConfigRequiresSubject(t *testing.T) {
sink, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"subject": "feedkit.events",
},
})
if err != nil {
t.Fatalf("NewNATSSinkFromConfig() error = %v", err)
}
natsSink, ok := sink.(*NATSSink)
if !ok {
t.Fatalf("sink type = %T, want *NATSSink", sink)
}
if natsSink.subject != "feedkit.events" {
t.Fatalf("subject = %q, want feedkit.events", natsSink.subject)
}
}
func TestNewNATSSinkFromConfigRejectsLegacyExchange(t *testing.T) {
_, err := NewNATSSinkFromConfig(config.SinkConfig{
Name: "nats-main",
Driver: "nats",
Params: map[string]any{
"url": "nats://localhost:4222",
"exchange": "feedkit.events",
},
})
if err == nil {
t.Fatalf("NewNATSSinkFromConfig() expected error")
}
if !strings.Contains(err.Error(), "params.subject is required") {
t.Fatalf("error = %q, want params.subject is required", err)
}
}


@@ -2,36 +2,437 @@ package sinks
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type PostgresSink struct {
name string
dsn string
type postgresTx interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Commit() error
Rollback() error
}
func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
dsn, err := requireStringParam(cfg, "dsn")
type postgresExecer interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type postgresDB interface {
PingContext(ctx context.Context) error
BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Close() error
}
type sqlDBWrapper struct {
db *sql.DB
}
func (w *sqlDBWrapper) PingContext(ctx context.Context) error {
return w.db.PingContext(ctx)
}
func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
tx, err := w.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &PostgresSink{name: cfg.Name, dsn: dsn}, nil
return &sqlTxWrapper{tx: tx}, nil
}
func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return w.db.ExecContext(ctx, query, args...)
}
func (w *sqlDBWrapper) Close() error {
return w.db.Close()
}
type sqlTxWrapper struct {
tx *sql.Tx
}
func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
return w.tx.ExecContext(ctx, query, args...)
}
func (w *sqlTxWrapper) Commit() error {
return w.tx.Commit()
}
func (w *sqlTxWrapper) Rollback() error {
return w.tx.Rollback()
}
var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
db, err := pgconn.Open(ctx, cfg)
if err != nil {
return nil, err
}
return &sqlDBWrapper{db: db}, nil
}
type PostgresSink struct {
name string
db postgresDB
schema postgresSchemaCompiled
pruneWindow time.Duration
}
func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
uri, err := requireStringParam(cfg, "uri")
if err != nil {
return nil, err
}
username, err := requireStringParam(cfg, "username")
if err != nil {
return nil, err
}
password, err := requireStringParam(cfg, "password")
if err != nil {
return nil, err
}
pruneWindow, err := parsePostgresPruneWindow(cfg)
if err != nil {
return nil, err
}
schema, err := compilePostgresSchema(schemaDef)
if err != nil {
return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
}
db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
}
s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
if err := s.initialize(); err != nil {
_ = db.Close()
return nil, err
}
return s, nil
}
func (p *PostgresSink) Name() string { return p.name }
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
// Boundary validation: if something upstream violated invariants,
// surface it loudly rather than printing partial nonsense.
// surface it loudly rather than writing corrupt rows.
if err := e.Validate(); err != nil {
return fmt.Errorf("rabbitmq sink: invalid event: %w", err)
return fmt.Errorf("postgres sink: invalid event: %w", err)
}
// TODO implement Postgres transaction
if err := ctx.Err(); err != nil {
return err
}
writes, err := p.schema.mapEvent(ctx, e)
if err != nil {
return fmt.Errorf("postgres sink: map event: %w", err)
}
if len(writes) == 0 {
return nil
}
tx, err := p.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("postgres sink: begin tx: %w", err)
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
for _, w := range writes {
tbl, err := p.schema.validateWrite(w)
if err != nil {
return fmt.Errorf("postgres sink: %w", err)
}
query, args, err := buildInsertSQL(tbl, w)
if err != nil {
return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
}
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
}
}
if p.pruneWindow > 0 {
cutoff := time.Now().UTC().Add(-p.pruneWindow)
for _, tableName := range p.schema.tableOrder {
tbl := p.schema.tables[tableName]
if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
return fmt.Errorf("postgres sink: %w", err)
}
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("postgres sink: commit tx: %w", err)
}
committed = true
return nil
}
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
if keep < 0 {
return 0, fmt.Errorf("postgres sink: keep must be >= 0")
}
tbl, err := p.lookupTable(table)
if err != nil {
return 0, err
}
query := fmt.Sprintf(
`DELETE FROM %s WHERE ctid IN (
SELECT ctid FROM %s
ORDER BY %s DESC
OFFSET $1
)`,
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.pruneColumn),
)
res, err := p.db.ExecContext(ctx, query, keep)
if err != nil {
return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
}
rows, err := res.RowsAffected()
if err != nil {
return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
}
return rows, nil
}
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
tbl, err := p.lookupTable(table)
if err != nil {
return 0, err
}
rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
if err != nil {
return 0, fmt.Errorf("postgres sink: %w", err)
}
return rows, nil
}
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
counts := make(map[string]int64, len(p.schema.tableOrder))
for _, table := range p.schema.tableOrder {
n, err := p.PruneKeepLatest(ctx, table, keep)
if err != nil {
return counts, err
}
counts[table] = n
}
return counts, nil
}
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
counts := make(map[string]int64, len(p.schema.tableOrder))
for _, table := range p.schema.tableOrder {
n, err := p.PruneOlderThan(ctx, table, cutoff)
if err != nil {
return counts, err
}
counts[table] = n
}
return counts, nil
}
func (p *PostgresSink) initialize() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
for _, tableName := range p.schema.tableOrder {
tbl := p.schema.tables[tableName]
createTableSQL := buildCreateTableSQL(tbl)
if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
}
for _, idx := range tbl.indexes {
createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
}
}
}
return nil
}
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
table = strings.TrimSpace(table)
if table == "" {
return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
}
tbl, ok := p.schema.tables[table]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
}
return tbl, nil
}
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
raw, ok := cfg.Params["prune"]
if !ok || raw == nil {
return 0, nil
}
s, ok := raw.(string)
if !ok {
return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
}
d, err := parsePostgresPruneDuration(s)
if err != nil {
return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
}
return d, nil
}
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
s := strings.TrimSpace(raw)
if s == "" {
return 0, fmt.Errorf("must not be empty")
}
lower := strings.ToLower(s)
if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
unit := lower[len(lower)-1]
n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
if err != nil {
return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
}
if n <= 0 {
return 0, fmt.Errorf("must be > 0")
}
if unit == 'd' {
return time.Duration(n) * 24 * time.Hour, nil
}
return time.Duration(n) * 7 * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
}
if d <= 0 {
return 0, fmt.Errorf("must be > 0")
}
return d, nil
}
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
return fmt.Sprintf(
`DELETE FROM %s WHERE %s < $1`,
quotePostgresIdent(tbl.name),
quotePostgresIdent(tbl.pruneColumn),
)
}
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
query := buildPruneOlderThanSQL(tbl)
res, err := execer.ExecContext(ctx, query, cutoff)
if err != nil {
return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
}
rows, err := res.RowsAffected()
if err != nil {
return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
}
return rows, nil
}
func buildCreateTableSQL(tbl postgresTableCompiled) string {
defs := make([]string, 0, len(tbl.columnOrder)+1)
for _, colName := range tbl.columnOrder {
col := tbl.columns[colName]
def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
if !col.Nullable {
def += " NOT NULL"
}
defs = append(defs, def)
}
if len(tbl.primaryKey) > 0 {
defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
}
return fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s (%s)",
quotePostgresIdent(tbl.name),
strings.Join(defs, ", "),
)
}
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
unique := ""
if idx.Unique {
unique = "UNIQUE "
}
return fmt.Sprintf(
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
unique,
quotePostgresIdent(idx.Name),
quotePostgresIdent(tableName),
joinQuotedIdents(idx.Columns),
)
}
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
cols := make([]string, 0, len(tbl.columnOrder))
args := make([]any, 0, len(tbl.columnOrder))
placeholders := make([]string, 0, len(tbl.columnOrder))
for i, colName := range tbl.columnOrder {
v, ok := w.Values[colName]
if !ok {
return "", nil, fmt.Errorf("missing value for column %q", colName)
}
cols = append(cols, quotePostgresIdent(colName))
args = append(args, v)
placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
}
q := fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
quotePostgresIdent(tbl.name),
strings.Join(cols, ", "),
strings.Join(placeholders, ", "),
)
return q, args, nil
}
func joinQuotedIdents(idents []string) string {
quoted := make([]string, 0, len(idents))
for _, s := range idents {
quoted = append(quoted, quotePostgresIdent(s))
}
return strings.Join(quoted, ", ")
}
func quotePostgresIdent(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}

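The SQL builders above rest on two small invariants: identifiers are double-quoted with embedded double quotes doubled, and placeholders are numbered `$1..$n` in declared column order. A standalone sketch of that composition — same technique as `quotePostgresIdent` and `buildInsertSQL`, with the types trimmed away:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// quoteIdent double-quotes a Postgres identifier, doubling any
// embedded double quotes so the identifier cannot break out.
func quoteIdent(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}

// insertSQL builds INSERT INTO t (cols...) VALUES ($1..$n), with
// placeholder numbers tracking column positions.
func insertSQL(table string, cols []string) string {
	qc := make([]string, len(cols))
	ph := make([]string, len(cols))
	for i, c := range cols {
		qc[i] = quoteIdent(c)
		ph[i] = "$" + strconv.Itoa(i+1)
	}
	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		quoteIdent(table), strings.Join(qc, ", "), strings.Join(ph, ", "))
}

func main() {
	fmt.Println(insertSQL("events", []string{"event_id", "emitted_at"}))
	// → INSERT INTO "events" ("event_id", "emitted_at") VALUES ($1, $2)
}
```

Because values travel only through placeholders and identifiers only through the quoting helper, the generated statements stay injection-safe even with schema-provided names.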
sinks/postgres_schema.go Normal file

@@ -0,0 +1,239 @@
package sinks
import (
"context"
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// PostgresMapFunc maps one event into zero or more table writes.
//
// Returning zero writes means "this event is not mapped for this sink" and is
// treated as a non-error no-op.
type PostgresMapFunc func(ctx context.Context, e event.Event) ([]PostgresWrite, error)
// PostgresSchema describes the downstream-provided relational model and mapper
// for one configured postgres sink.
type PostgresSchema struct {
Tables []PostgresTable
MapEvent PostgresMapFunc
}
type PostgresWrite struct {
Table string
Values map[string]any
}
type PostgresTable struct {
Name string
Columns []PostgresColumn
PrimaryKey []string
PruneColumn string
Indexes []PostgresIndex
}
type PostgresColumn struct {
Name string
Type string
Nullable bool
}
type PostgresIndex struct {
Name string
Columns []string
Unique bool
}
// PostgresPruner is an optional interface exposed by PostgresSink so downstream
// applications can call retention helpers via type assertion.
type PostgresPruner interface {
PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error)
PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error)
PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error)
PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error)
}
type postgresSchemaCompiled struct {
tableOrder []string
tables map[string]postgresTableCompiled
mapEvent PostgresMapFunc
}
type postgresTableCompiled struct {
name string
columns map[string]PostgresColumn
columnOrder []string
primaryKey []string
pruneColumn string
indexes []PostgresIndex
}
func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) {
if schema.MapEvent == nil {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required")
}
if len(schema.Tables) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: at least one table is required")
}
compiled := postgresSchemaCompiled{
tableOrder: make([]string, 0, len(schema.Tables)),
tables: make(map[string]postgresTableCompiled, len(schema.Tables)),
mapEvent: schema.MapEvent,
}
seenTables := map[string]bool{}
seenIndexes := map[string]bool{}
for i, tbl := range schema.Tables {
tableName := strings.TrimSpace(tbl.Name)
if tableName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: tables[%d].name is required", i)
}
if seenTables[tableName] {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate table name %q", tableName)
}
seenTables[tableName] = true
if len(tbl.Columns) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q must define at least one column", tableName)
}
colOrder := make([]string, 0, len(tbl.Columns))
colMap := make(map[string]PostgresColumn, len(tbl.Columns))
for j, col := range tbl.Columns {
colName := strings.TrimSpace(col.Name)
if colName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q columns[%d].name is required", tableName, j)
}
if _, exists := colMap[colName]; exists {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q duplicate column %q", tableName, colName)
}
if strings.TrimSpace(col.Type) == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q column %q type is required", tableName, colName)
}
colOrder = append(colOrder, colName)
colMap[colName] = PostgresColumn{
Name: colName,
Type: strings.TrimSpace(col.Type),
Nullable: col.Nullable,
}
}
pk, err := validatePostgresColumnSet(tableName, "primary key", tbl.PrimaryKey, colMap)
if err != nil {
return postgresSchemaCompiled{}, err
}
pruneCol := strings.TrimSpace(tbl.PruneColumn)
if pruneCol == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column is required", tableName)
}
if _, ok := colMap[pruneCol]; !ok {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column %q not found in columns", tableName, pruneCol)
}
indexes := make([]PostgresIndex, 0, len(tbl.Indexes))
for j, idx := range tbl.Indexes {
idxName := strings.TrimSpace(idx.Name)
if idxName == "" {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q indexes[%d].name is required", tableName, j)
}
if len(idx.Columns) == 0 {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q index %q must include at least one column", tableName, idxName)
}
if seenIndexes[idxName] {
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate index name %q", idxName)
}
seenIndexes[idxName] = true
idxCols, err := validatePostgresColumnSet(tableName, fmt.Sprintf("index %q columns", idxName), idx.Columns, colMap)
if err != nil {
return postgresSchemaCompiled{}, err
}
indexes = append(indexes, PostgresIndex{
Name: idxName,
Columns: idxCols,
Unique: idx.Unique,
})
}
compiled.tableOrder = append(compiled.tableOrder, tableName)
compiled.tables[tableName] = postgresTableCompiled{
name: tableName,
columns: colMap,
columnOrder: colOrder,
primaryKey: pk,
pruneColumn: pruneCol,
indexes: indexes,
}
}
return compiled, nil
}
func validatePostgresColumnSet(tableName, label string, cols []string, colMap map[string]PostgresColumn) ([]string, error) {
if len(cols) == 0 {
return nil, nil
}
out := make([]string, 0, len(cols))
seen := map[string]bool{}
for _, c := range cols {
name := strings.TrimSpace(c)
if name == "" {
return nil, fmt.Errorf("postgres schema: table %q %s contains empty column name", tableName, label)
}
if seen[name] {
return nil, fmt.Errorf("postgres schema: table %q %s contains duplicate column %q", tableName, label, name)
}
if _, ok := colMap[name]; !ok {
return nil, fmt.Errorf("postgres schema: table %q %s references unknown column %q", tableName, label, name)
}
seen[name] = true
out = append(out, name)
}
return out, nil
}
func (s postgresSchemaCompiled) validateWrite(w PostgresWrite) (postgresTableCompiled, error) {
tableName := strings.TrimSpace(w.Table)
if tableName == "" {
return postgresTableCompiled{}, fmt.Errorf("write table is required")
}
t, ok := s.tables[tableName]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("table %q is not defined in postgres schema", tableName)
}
if len(w.Values) == 0 {
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include values", tableName)
}
for k := range w.Values {
if _, ok := t.columns[k]; !ok {
return postgresTableCompiled{}, fmt.Errorf("write for table %q includes unknown column %q", tableName, k)
}
}
if len(w.Values) != len(t.columnOrder) {
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include all declared columns", tableName)
}
for _, col := range t.columnOrder {
v, ok := w.Values[col]
if !ok {
return postgresTableCompiled{}, fmt.Errorf("write for table %q is missing column %q", tableName, col)
}
if v == nil {
if c := t.columns[col]; !c.Nullable {
return postgresTableCompiled{}, fmt.Errorf("write for table %q has nil value for non-null column %q", tableName, col)
}
}
}
return t, nil
}

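The `validateWrite` contract above is: every declared column must be present, unknown columns are rejected, and nil is allowed only for nullable columns. A tiny sketch of the same rule over a flat nullability map (hypothetical helper, simplified from the compiled-schema form):

```go
package main

import "fmt"

// checkWrite enforces the validateWrite contract: all declared
// columns present, no extras, nil only where the column is nullable.
func checkWrite(nullable map[string]bool, values map[string]any) error {
	for col := range values {
		if _, declared := nullable[col]; !declared {
			return fmt.Errorf("unknown column %q", col)
		}
	}
	for col, isNullable := range nullable {
		v, present := values[col]
		if !present {
			return fmt.Errorf("missing column %q", col)
		}
		if v == nil && !isNullable {
			return fmt.Errorf("nil value for non-null column %q", col)
		}
	}
	return nil
}

func main() {
	schema := map[string]bool{"event_id": false, "note": true}
	fmt.Println(checkWrite(schema, map[string]any{"event_id": "e1", "note": nil}))
	fmt.Println(checkWrite(schema, map[string]any{"event_id": nil, "note": "x"}))
}
```

Failing before the transaction starts keeps malformed writes from ever reaching `ExecContext`, which is why `Consume` calls `validateWrite` ahead of building any SQL.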
sinks/postgres_test.go Normal file

@@ -0,0 +1,748 @@
package sinks
import (
"context"
"database/sql"
"errors"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakeResult struct {
rows int64
}
func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") }
func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil }
type execCall struct {
query string
args []any
}
type fakeTx struct {
execCalls []execCall
execErr error
execErrOnCall int
execRows int64
commitErr error
rollbackErr error
commitCalls int
rollbackCalls int
}
func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)})
if t.execErr != nil && t.execErrOnCall == len(t.execCalls) {
return nil, t.execErr
}
return fakeResult{rows: t.execRows}, nil
}
func (t *fakeTx) Commit() error {
t.commitCalls++
return t.commitErr
}
func (t *fakeTx) Rollback() error {
t.rollbackCalls++
return t.rollbackErr
}
type fakeDB struct {
pingErr error
beginErr error
execErr error
execErrOnCall int
execRows int64
pingCalls int
beginCalls int
execCalls []execCall
closeCalls int
tx *fakeTx
}
func (d *fakeDB) PingContext(_ context.Context) error {
d.pingCalls++
return d.pingErr
}
func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) {
d.beginCalls++
if d.beginErr != nil {
return nil, d.beginErr
}
if d.tx == nil {
d.tx = &fakeTx{}
}
return d.tx, nil
}
func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)})
if d.execErr != nil && d.execErrOnCall == len(d.execCalls) {
return nil, d.execErr
}
return fakeResult{rows: d.execRows}, nil
}
func (d *fakeDB) Close() error {
d.closeCalls++
return nil
}
func withPostgresTestState(t *testing.T) {
t.Helper()
oldOpen := openPostgresDB
t.Cleanup(func() {
openPostgresDB = oldOpen
})
}
func validTestEvent() event.Event {
now := time.Now().UTC()
return event.Event{
ID: "evt-1",
Kind: event.Kind("observation"),
Source: "source-1",
EmittedAt: now,
Payload: map[string]any{
"x": 1,
},
}
}
func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
{Name: "payload_json", Type: "JSONB", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
Indexes: []PostgresIndex{
{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
},
},
},
MapEvent: mapFn,
}
}
func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
return PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
{
Name: "event_payloads",
Columns: []PostgresColumn{
{Name: "event_id", Type: "TEXT", Nullable: false},
{Name: "payload_json", Type: "JSONB", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PrimaryKey: []string{"event_id"},
PruneColumn: "emitted_at",
},
},
MapEvent: mapFn,
}
}
func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
t.Helper()
compiled, err := compilePostgresSchema(s)
if err != nil {
t.Fatalf("compile schema: %v", err)
}
return compiled
}
func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
_, err := compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
},
PruneColumn: "missing_col",
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid schema error")
}
if !strings.Contains(err.Error(), "prune column") {
t.Fatalf("unexpected schema validation error: %v", err)
}
_, err = compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PruneColumn: "emitted_at",
Indexes: []PostgresIndex{
{Name: "idx_events_empty", Columns: nil},
},
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid index schema error")
}
if !strings.Contains(err.Error(), "at least one column") {
t.Fatalf("unexpected index validation error: %v", err)
}
}
func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
withPostgresTestState(t)
dbs := []*fakeDB{{}, {}}
var gotCfgs []pgconn.ConnConfig
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
gotCfgs = append(gotCfgs, cfg)
db := dbs[len(gotCfgs)-1]
if err := db.PingContext(ctx); err != nil {
return nil, err
}
return db, nil
}
factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
}))
for _, name := range []string{"pg_a", "pg_b"} {
sink, err := factory(config.SinkConfig{
Name: name,
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
})
if err != nil {
t.Fatalf("factory(%q) error = %v", name, err)
}
if sink == nil {
t.Fatalf("factory(%q) returned nil sink", name)
}
}
if len(gotCfgs) != 2 {
t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
}
if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" {
t.Fatalf("first ConnConfig = %+v", gotCfgs[0])
}
for i, db := range dbs {
if db.pingCalls != 1 {
t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
}
}
}
func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
name string
params map[string]any
want string
}{
{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: tc.params,
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("expected %q in error, got: %v", tc.want, err)
}
})
}
}
func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
withPostgresTestState(t)
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
}, PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
},
PruneColumn: "missing_col",
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid schema error")
}
if !strings.Contains(err.Error(), "compile schema") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
withPostgresTestState(t)
db := &fakeDB{}
var gotCfg pgconn.ConnConfig
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
gotCfg = cfg
if err := db.PingContext(ctx); err != nil {
return nil, err
}
return db, nil
}
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://db.example.local:5432/feedkit?sslmode=disable",
"username": "app_user",
"password": "app_pass",
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
if s == nil {
t.Fatalf("expected sink")
}
if db.pingCalls != 1 {
t.Fatalf("expected one ping, got %d", db.pingCalls)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls))
}
if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) {
t.Fatalf("unexpected create table query: %s", db.execCalls[0].query)
}
if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
}
if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
t.Fatalf("URI = %q", gotCfg.URI)
}
if gotCfg.Username != "app_user" {
t.Fatalf("Username = %q, want app_user", gotCfg.Username)
}
if gotCfg.Password != "app_pass" {
t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
}
}
func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
withPostgresTestState(t)
db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return db, nil
}
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected init error")
}
if db.closeCalls != 1 {
t.Fatalf("expected db close on init failure")
}
}
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
tests := []struct {
name string
in string
want time.Duration
}{
{name: "go duration", in: "72h", want: 72 * time.Hour},
{name: "days suffix", in: "3d", want: 72 * time.Hour},
{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
withPostgresTestState(t)
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
return &fakeDB{}, nil
}
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
"prune": tc.in,
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
pg, ok := s.(*PostgresSink)
if !ok {
t.Fatalf("expected *PostgresSink, got %T", s)
}
if pg.pruneWindow != tc.want {
t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
}
})
}
}
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
withPostgresTestState(t)
tests := []struct {
name string
in any
}{
{name: "empty", in: ""},
{name: "zero", in: "0"},
{name: "negative", in: "-1h"},
{name: "malformed", in: "abc"},
{name: "fractional day", in: "1.5d"},
{name: "wrong type", in: 5},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
"prune": tc.in,
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err == nil {
t.Fatalf("expected error")
}
if !strings.Contains(err.Error(), "params.prune") {
t.Fatalf("expected params.prune error, got %v", err)
}
})
}
}
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
db := &fakeDB{}
called := 0
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
called++
return nil, nil
})),
}
err := sink.Consume(context.Background(), event.Event{})
if err == nil {
t.Fatalf("expected invalid event error")
}
if !strings.Contains(err.Error(), "invalid event") {
t.Fatalf("unexpected error: %v", err)
}
if called != 0 {
t.Fatalf("expected mapper not called for invalid events")
}
}
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 0 {
t.Fatalf("expected no transaction for unmapped events")
}
}
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 1 {
t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
}
if len(tx.execCalls) != 2 {
t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
}
err := sink.Consume(context.Background(), validTestEvent())
if err == nil {
t.Fatalf("expected insert error")
}
if !strings.Contains(err.Error(), "insert into") {
t.Fatalf("unexpected error: %v", err)
}
if tx.commitCalls != 0 {
t.Fatalf("expected no commit")
}
if tx.rollbackCalls != 1 {
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
pruneWindow: 24 * time.Hour,
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if len(tx.execCalls) != 4 {
t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
}
if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
}
if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
pruneWindow: 24 * time.Hour,
}
err := sink.Consume(context.Background(), validTestEvent())
if err == nil {
t.Fatalf("expected prune error")
}
if !strings.Contains(err.Error(), "prune older than") {
t.Fatalf("unexpected error: %v", err)
}
if tx.commitCalls != 0 {
t.Fatalf("expected no commit")
}
if tx.rollbackCalls != 1 {
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
}
}
func TestPostgresSinkPrunePerTable(t *testing.T) {
db := &fakeDB{execRows: 7}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
rows, err := sink.PruneKeepLatest(context.Background(), "events", 10)
if err != nil {
t.Fatalf("prune keep latest: %v", err)
}
if rows != 7 {
t.Fatalf("unexpected rows affected: %d", rows)
}
if len(db.execCalls) != 1 {
t.Fatalf("expected one prune query")
}
if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) {
t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query)
}
if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 {
t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args)
}
cutoff := time.Now().UTC().Add(-24 * time.Hour)
rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff)
if err != nil {
t.Fatalf("prune older than: %v", err)
}
if rows != 7 {
t.Fatalf("unexpected rows affected: %d", rows)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected two prune queries")
}
if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) {
t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query)
}
}
func TestPostgresSinkPruneAllTables(t *testing.T) {
db := &fakeDB{execRows: 3}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
if err != nil {
t.Fatalf("prune all keep latest: %v", err)
}
if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
t.Fatalf("unexpected keep counts: %#v", keepCounts)
}
db.execCalls = nil
olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
if err != nil {
t.Fatalf("prune all older than: %v", err)
}
if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
t.Fatalf("unexpected older-than counts: %#v", olderCounts)
}
if len(db.execCalls) != 2 {
t.Fatalf("expected one prune call per table")
}
}
func TestPostgresSinkPruneErrors(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
t.Fatalf("expected negative keep error")
}
if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
t.Fatalf("expected unknown table error")
}
}


@@ -1,42 +0,0 @@
package sinks
import (
"context"
"fmt"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type RabbitMQSink struct {
name string
url string
exchange string
}
func NewRabbitMQSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
url, err := requireStringParam(cfg, "url")
if err != nil {
return nil, err
}
ex, err := requireStringParam(cfg, "exchange")
if err != nil {
return nil, err
}
return &RabbitMQSink{name: cfg.Name, url: url, exchange: ex}, nil
}
func (r *RabbitMQSink) Name() string { return r.name }
func (r *RabbitMQSink) Consume(ctx context.Context, e event.Event) error {
_ = ctx
// Boundary validation: if something upstream violated invariants,
// surface it loudly rather than printing partial nonsense.
if err := e.Validate(); err != nil {
return fmt.Errorf("rabbitmq sink: invalid event: %w", err)
}
// TODO implement RabbitMQ publishing
return nil
}


@@ -2,6 +2,7 @@ package sinks
import (
"fmt"
"strings"
"gitea.maximumdirect.net/ejr/feedkit/config"
)
@@ -21,13 +22,40 @@ func NewRegistry() *Registry {
}
func (r *Registry) Register(driver string, f Factory) {
if r == nil {
panic("sinks.Registry.Register: registry cannot be nil")
}
driver = strings.TrimSpace(driver)
if driver == "" {
panic("sinks.Registry.Register: driver cannot be empty")
}
if f == nil {
panic(fmt.Sprintf("sinks.Registry.Register: factory cannot be nil (driver=%q)", driver))
}
if r.byDriver == nil {
r.byDriver = map[string]Factory{}
}
if _, exists := r.byDriver[driver]; exists {
panic(fmt.Sprintf("sinks.Registry.Register: driver %q already registered", driver))
}
r.byDriver[driver] = f
}
func (r *Registry) Build(cfg config.SinkConfig) (Sink, error) {
if r == nil {
return nil, fmt.Errorf("sinks registry is nil")
}
driver := strings.TrimSpace(cfg.Driver)
f, ok := r.byDriver[driver]
if !ok {
return nil, fmt.Errorf("unknown sink driver: %q", driver)
}
sink, err := f(cfg)
if err != nil {
return nil, fmt.Errorf("build sink %q: %w", driver, err)
}
if sink == nil {
return nil, fmt.Errorf("build sink %q: factory returned nil sink", driver)
}
return sink, nil
}

sinks/registry_test.go Normal file

@@ -0,0 +1,126 @@
package sinks
import (
"context"
"errors"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testSink struct{ name string }
func (s testSink) Name() string { return s.name }
func (s testSink) Consume(context.Context, event.Event) error { return nil }
func TestRegistryRegisterPanicsOnNilRegistry(t *testing.T) {
var r *Registry
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil registry")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
}
func TestRegistryRegisterPanicsOnEmptyDriver(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on empty driver")
}
}()
r.Register(" ", func(config.SinkConfig) (Sink, error) { return testSink{name: "x"}, nil })
}
func TestRegistryRegisterPanicsOnNilFactory(t *testing.T) {
r := NewRegistry()
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on nil factory")
}
}()
r.Register("stdout", nil)
}
func TestRegistryRegisterPanicsOnDuplicateDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "a"}, nil })
defer func() {
if recover() == nil {
t.Fatalf("Register() expected panic on duplicate driver")
}
}()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "b"}, nil })
}
func TestRegistryBuildNilRegistryFails(t *testing.T) {
var r *Registry
_, err := r.Build(config.SinkConfig{Driver: "stdout"})
if err == nil {
t.Fatalf("Build() expected error for nil registry")
}
if !strings.Contains(err.Error(), "registry is nil") {
t.Fatalf("Build() error = %q, want registry is nil", err)
}
}
func TestRegistryBuildTrimsDriver(t *testing.T) {
r := NewRegistry()
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
sink, err := r.Build(config.SinkConfig{Name: "sink1", Driver: " stdout "})
if err != nil {
t.Fatalf("Build() error = %v", err)
}
if sink.Name() != "stdout" {
t.Fatalf("Build() sink name = %q, want stdout", sink.Name())
}
}
func TestRegistryBuildWrapsFactoryError(t *testing.T) {
r := NewRegistry()
r.Register("broken", func(config.SinkConfig) (Sink, error) { return nil, errors.New("boom") })
_, err := r.Build(config.SinkConfig{Driver: "broken"})
if err == nil {
t.Fatalf("Build() expected error")
}
if !strings.Contains(err.Error(), `build sink "broken": boom`) {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegistryBuildRejectsNilSink(t *testing.T) {
r := NewRegistry()
r.Register("nil_sink", func(config.SinkConfig) (Sink, error) { return nil, nil })
_, err := r.Build(config.SinkConfig{Driver: "nil_sink"})
if err == nil {
t.Fatalf("Build() expected nil sink error")
}
if !strings.Contains(err.Error(), "factory returned nil sink") {
t.Fatalf("Build() error = %q", err)
}
}
func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) {
r := NewRegistry()
RegisterBuiltins(r)
if len(r.byDriver) != 2 {
t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver))
}
for _, driver := range []string{"stdout", "nats"} {
if _, ok := r.byDriver[driver]; !ok {
t.Fatalf("builtins missing driver %q", driver)
}
}
if _, ok := r.byDriver["postgres"]; ok {
t.Fatalf("builtins unexpectedly registered postgres driver")
}
}

sources/doc.go Normal file

@@ -0,0 +1,52 @@
// Package sources defines feedkit's input-source abstractions and source
// registry.
//
// External API surface:
// - Input: common source identity surface
// - PollSource: polling source interface
// - StreamSource: streaming source interface
// - StreamRetryable / StreamFatal / IsStreamRetryable / IsStreamFatal:
// stream exit classification helpers
// - Registry / NewRegistry: source driver registry and builders
// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling
// helper
//
// Source drivers are domain-specific and registered into Registry by driver name.
// Registry can then build configured sources from config.SourceConfig.
//
// A single source may emit 0..N events per poll or stream iteration, and those
// events may span multiple event kinds.
//
// Optional helpers from helpers.go:
// - DefaultEventID: default event ID policy for source implementations
// - SingleEvent: construct and validate a one-element event slice
// - ValidateExpectedKinds: compare configured expected kinds against source
// advertised kinds when metadata is available
//
// HTTP-backed polling sources can share NewHTTPSource for generic HTTP config
// parsing and conditional GET behavior. The helper understands:
// - params.url
// - params.user_agent
// - params.conditional (optional, default true)
// - params.http_timeout (optional, default transport.DefaultHTTPTimeout)
// - params.http_response_body_limit_bytes (optional, default
// transport.DefaultHTTPResponseBodyLimitBytes)
//
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
// treated as a successful unchanged poll.
//
// Postgres-backed polling sources can share NewPostgresQuerySource for generic
// DB config parsing and query execution. The helper understands:
// - params.uri
// - params.username
// - params.password
// - params.query
// - params.query_timeout (optional, default 30s)
//
// feedkit does not register a built-in postgres poll driver. Downstream daemons
// should register domain-specific driver names that call
// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering,
// watermark policy, and event construction in their own source types.
package sources
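As a hedged illustration of the parameter model documented above, a downstream daemon's HTTP polling source entry might carry a params map like the one below. The keys are the ones NewHTTPSource is documented to understand; the URL, user agent, and every value shown are invented examples, not feedkit defaults:

```go
package main

import "fmt"

// httpSourceParams sketches a hypothetical params map for an HTTP-backed
// polling source. "url" and "user_agent" are required; the remaining keys
// are optional and fall back to transport defaults when omitted.
func httpSourceParams() map[string]any {
	return map[string]any{
		"url":          "https://api.example.com/observations.json", // required (example URL)
		"user_agent":   "weatherfeeder/1.0",                         // required (example UA)
		"conditional":  true,                                        // optional: ETag / Last-Modified handling
		"http_timeout": "15s",                                       // optional per-source timeout override
		"http_response_body_limit_bytes": 1 << 20,                   // optional body-size cap override
	}
}

func main() {
	fmt.Println(len(httpSourceParams()))
}
```

The value types here (e.g. a duration string for `http_timeout`) are assumptions for illustration; the authoritative parsing lives in NewHTTPSource itself.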

sources/helpers.go Normal file

@@ -0,0 +1,140 @@
package sources
import (
"fmt"
"sort"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// DefaultEventID applies feedkit's default Event.ID policy:
//
// - If upstream provides an ID, use it (trimmed).
// - Otherwise, ID is "<Source>:<EffectiveAt>" when available.
// - If EffectiveAt is unavailable, fall back to "<Source>:<EmittedAt>".
//
// Timestamps are encoded as RFC3339Nano in UTC.
func DefaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
if id := strings.TrimSpace(upstreamID); id != "" {
return id
}
src := strings.TrimSpace(sourceName)
if src == "" {
src = "UNKNOWN_SOURCE"
}
if effectiveAt != nil && !effectiveAt.IsZero() {
return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
}
t := emittedAt.UTC()
if t.IsZero() {
t = time.Now().UTC()
}
return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
}
// SingleEvent constructs, validates, and returns a slice containing exactly one event.
func SingleEvent(
kind event.Kind,
sourceName string,
schema string,
id string,
emittedAt time.Time,
effectiveAt *time.Time,
payload any,
) ([]event.Event, error) {
if emittedAt.IsZero() {
emittedAt = time.Now().UTC()
} else {
emittedAt = emittedAt.UTC()
}
e := event.Event{
ID: id,
Kind: kind,
Source: sourceName,
EmittedAt: emittedAt,
EffectiveAt: effectiveAt,
Schema: schema,
Payload: payload,
}
if err := e.Validate(); err != nil {
return nil, err
}
return []event.Event{e}, nil
}
// ValidateExpectedKinds checks that configured source expected kinds are a subset
// of the kinds advertised by the built source, when the source exposes kind
// metadata. If the source does not advertise kinds, the check is skipped.
func ValidateExpectedKinds(cfg config.SourceConfig, in Input) error {
expectedKinds, err := parseExpectedKinds(cfg.ExpectedKinds())
if err != nil {
return err
}
if len(expectedKinds) == 0 {
return nil
}
advertisedKinds := advertisedSourceKinds(in)
if len(advertisedKinds) == 0 {
return nil
}
for kind := range expectedKinds {
if !advertisedKinds[kind] {
return fmt.Errorf(
"configured expected kind %q not advertised by source (configured=%v advertised=%v)",
kind,
sortedKinds(expectedKinds),
sortedKinds(advertisedKinds),
)
}
}
return nil
}
func parseExpectedKinds(raw []string) (map[event.Kind]bool, error) {
kinds := map[event.Kind]bool{}
for i, k := range raw {
kind, err := event.ParseKind(k)
if err != nil {
return nil, fmt.Errorf("invalid expected kind at index %d (%q): %w", i, k, err)
}
kinds[kind] = true
}
return kinds, nil
}
func advertisedSourceKinds(in Input) map[event.Kind]bool {
if in == nil {
return nil
}
kinds := map[event.Kind]bool{}
if ks, ok := in.(KindsSource); ok {
for _, kind := range ks.Kinds() {
kinds[kind] = true
}
return kinds
}
return nil
}
func sortedKinds(kindSet map[event.Kind]bool) []string {
out := make([]string, 0, len(kindSet))
for kind := range kindSet {
out = append(out, string(kind))
}
sort.Strings(out)
return out
}

sources/helpers_test.go Normal file

@@ -0,0 +1,112 @@
package sources
import (
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testInput struct {
name string
}
func (s testInput) Name() string { return s.name }
type testKindsSource struct {
testInput
kinds []event.Kind
}
func (s testKindsSource) Kinds() []event.Kind { return s.kinds }
func TestValidateExpectedKindsSubsetAllowed(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"observation"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestValidateExpectedKindsMismatchFails(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testKindsSource{
testInput: testInput{name: "test"},
kinds: []event.Kind{"observation", "forecast"},
}
err := ValidateExpectedKinds(cfg, in)
if err == nil {
t.Fatalf("ValidateExpectedKinds() expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "configured expected kind") {
t.Fatalf("ValidateExpectedKinds() error %q does not include expected message", err)
}
}
func TestValidateExpectedKindsNoMetadataSkipsCheck(t *testing.T) {
cfg := config.SourceConfig{Kinds: []string{"alert"}}
in := testInput{name: "test"}
if err := ValidateExpectedKinds(cfg, in); err != nil {
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
}
}
func TestDefaultEventIDUsesUpstreamID(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID(" upstream-id ", "source", nil, emittedAt)
if got != "upstream-id" {
t.Fatalf("DefaultEventID() = %q, want upstream-id", got)
}
}
func TestDefaultEventIDPrefersEffectiveAt(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 987654321, time.FixedZone("x", -6*3600))
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
got := DefaultEventID("", "source", &effectiveAt, emittedAt)
want := "source:" + effectiveAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestDefaultEventIDFallsBackToEmittedAt(t *testing.T) {
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123456789, time.FixedZone("y", 3*3600))
got := DefaultEventID("", "source", nil, emittedAt)
want := "source:" + emittedAt.UTC().Format(time.RFC3339Nano)
if got != want {
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
}
}
func TestSingleEventBuildsValidatedSlice(t *testing.T) {
effectiveAt := time.Date(2026, 3, 28, 16, 0, 0, 0, time.UTC)
emittedAt := time.Date(2026, 3, 28, 15, 0, 0, 0, time.FixedZone("z", -5*3600))
got, err := SingleEvent(
event.Kind("observation"),
"source-a",
"raw.example.v1",
"evt-1",
emittedAt,
&effectiveAt,
map[string]any{"ok": true},
)
if err != nil {
t.Fatalf("SingleEvent() unexpected error: %v", err)
}
if len(got) != 1 {
t.Fatalf("SingleEvent() len = %d, want 1", len(got))
}
if got[0].EmittedAt != emittedAt.UTC() {
t.Fatalf("SingleEvent() emittedAt = %s, want %s", got[0].EmittedAt, emittedAt.UTC())
}
}

152
sources/http.go Normal file
View File

@@ -0,0 +1,152 @@
package sources
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/transport"
)
// HTTPSource is a reusable helper for polling HTTP-backed sources.
//
// It centralizes generic source config parsing (`params.url`,
// `params.user_agent`, and the optional `params.conditional`,
// `params.http_timeout`, and `params.http_response_body_limit_bytes`
// overrides), default HTTP client setup, and conditional GET validator
// handling. Concrete daemon sources remain responsible for decoding the
// response body and constructing events.
type HTTPSource struct {
Driver string
Name string
URL string
UserAgent string
Accept string
Conditional bool
ResponseBodyLimitBytes int64
Client *http.Client
mu sync.Mutex
validators transport.HTTPValidators
}
// NewHTTPSource builds a generic HTTP polling helper from SourceConfig.
//
// Required params:
// - params.url
// - params.user_agent
//
// Optional params:
//   - params.conditional (default true): enable conditional GET using cached
//     ETag / Last-Modified validators
//   - params.http_timeout (default transport.DefaultHTTPTimeout): per-source
//     HTTP client timeout
//   - params.http_response_body_limit_bytes (default
//     transport.DefaultHTTPResponseBodyLimitBytes): per-source response body
//     size cap
func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTPSource, error) {
name := strings.TrimSpace(cfg.Name)
if name == "" {
return nil, fmt.Errorf("%s: name is required", driver)
}
if cfg.Params == nil {
return nil, fmt.Errorf("%s %q: params are required (need params.url and params.user_agent)", driver, cfg.Name)
}
url, ok := cfg.ParamString("url", "URL")
if !ok {
return nil, fmt.Errorf("%s %q: params.url is required", driver, cfg.Name)
}
userAgent, ok := cfg.ParamString("user_agent", "userAgent")
if !ok {
return nil, fmt.Errorf("%s %q: params.user_agent is required", driver, cfg.Name)
}
conditional := true
if _, exists := cfg.Params["conditional"]; exists {
var ok bool
conditional, ok = cfg.ParamBool("conditional")
if !ok {
return nil, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
}
}
timeout := transport.DefaultHTTPTimeout
if _, exists := cfg.Params["http_timeout"]; exists {
var ok bool
timeout, ok = cfg.ParamDuration("http_timeout")
if !ok || timeout <= 0 {
return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name)
}
}
bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes
if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists {
rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes")
if !ok || rawLimit <= 0 {
return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name)
}
bodyLimit = int64(rawLimit)
}
return &HTTPSource{
Driver: driver,
Name: name,
URL: url,
UserAgent: userAgent,
Accept: accept,
Conditional: conditional,
ResponseBodyLimitBytes: bodyLimit,
Client: transport.NewHTTPClient(timeout),
}, nil
}
// FetchBytesIfChanged fetches the configured URL and reports whether the
// upstream content changed. An unchanged 304 response returns changed=false
// with no body and no error.
func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
client := s.Client
if client == nil {
client = transport.NewHTTPClient(transport.DefaultHTTPTimeout)
}
s.mu.Lock()
validators := s.validators
s.mu.Unlock()
bodyLimit := s.ResponseBodyLimitBytes
if bodyLimit <= 0 {
bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes
}
body, changed, next, err := transport.FetchBodyIfChangedWithLimit(
ctx,
client,
s.URL,
s.UserAgent,
s.Accept,
s.Conditional,
validators,
bodyLimit,
)
if err != nil {
return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err)
}
if s.Conditional {
s.mu.Lock()
s.validators = next
s.mu.Unlock()
}
return body, changed, nil
}
// FetchJSONIfChanged fetches the configured URL and returns the raw response
// body as json.RawMessage when content changed. An unchanged 304 response
// returns changed=false with a nil body and no error.
func (s *HTTPSource) FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error) {
body, changed, err := s.FetchBytesIfChanged(ctx)
if err != nil || !changed {
return nil, changed, err
}
return json.RawMessage(body), true, nil
}

260
sources/http_test.go Normal file
View File

@@ -0,0 +1,260 @@
package sources
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/transport"
)
func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if !src.Conditional {
t.Fatalf("Conditional = false, want true")
}
}
func TestNewHTTPSourceRejectsInvalidConditional(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"conditional": "sometimes",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.conditional must be a boolean") {
t.Fatalf("NewHTTPSource() error = %q, want params.conditional must be a boolean", err)
}
}
func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"conditional": false,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Conditional {
t.Fatalf("Conditional = true, want false")
}
}
func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "250ms",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != 250*time.Millisecond {
t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout)
}
}
func TestNewHTTPSourceBodyLimitOverride(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": 12345,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != 12345 {
t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_timeout": "soon",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) {
_, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
"http_response_body_limit_bytes": "abc",
},
}, "application/json")
if err == nil {
t.Fatalf("NewHTTPSource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") {
t.Fatalf("NewHTTPSource() error = %q", err)
}
}
func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("ETag", `"v1"`)
_, _ = w.Write([]byte(`{"ok":true}`))
case 2:
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
t.Fatalf("second request If-None-Match = %q", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Fatalf("unexpected call count %d", call)
}
}))
defer srv.Close()
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": srv.URL,
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
raw, changed, err := src.FetchJSONIfChanged(context.Background())
if err != nil {
t.Fatalf("first FetchJSONIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchJSONIfChanged() changed = false, want true")
}
if got := string(raw); got != `{"ok":true}` {
t.Fatalf("first FetchJSONIfChanged() body = %q", got)
}
raw, changed, err = src.FetchJSONIfChanged(context.Background())
if err != nil {
t.Fatalf("second FetchJSONIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchJSONIfChanged() changed = true, want false")
}
if raw != nil {
t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw))
}
}
func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": srv.URL,
"user_agent": "test-agent",
"http_response_body_limit_bytes": 4,
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
_, _, err = src.FetchJSONIfChanged(context.Background())
if err == nil {
t.Fatalf("FetchJSONIfChanged() error = nil, want limit error")
}
if !strings.Contains(err.Error(), "response body too large") {
t.Fatalf("FetchJSONIfChanged() error = %q", err)
}
}
func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes {
t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes)
}
}
func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) {
src, err := NewHTTPSource("test_driver", config.SourceConfig{
Name: "test-source",
Driver: "test_driver",
Params: map[string]any{
"url": "https://example.invalid",
"user_agent": "test-agent",
},
}, "application/json")
if err != nil {
t.Fatalf("NewHTTPSource() error = %v", err)
}
if src.Client == nil {
t.Fatalf("Client = nil")
}
if src.Client.Timeout != transport.DefaultHTTPTimeout {
t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout)
}
}

117
sources/postgres.go Normal file
View File

@@ -0,0 +1,117 @@
package sources
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
const defaultPostgresQueryTimeout = 30 * time.Second
type postgresQueryDB interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
return pgconn.Open(ctx, cfg)
}
// PostgresQuerySource is a reusable helper for polling Postgres-backed sources.
//
// It centralizes generic source config parsing and query execution. Concrete
// daemon sources remain responsible for SQL semantics, row scanning, cursoring,
// and event construction.
type PostgresQuerySource struct {
Driver string
Name string
SQL string
QueryTimeout time.Duration
db postgresQueryDB
}
// NewPostgresQuerySource builds a generic Postgres polling helper from
// SourceConfig.
//
// Required params:
// - params.uri
// - params.username
// - params.password
// - params.query
//
// Optional params:
// - params.query_timeout (default 30s)
func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) {
name := strings.TrimSpace(cfg.Name)
if name == "" {
return nil, fmt.Errorf("%s: name is required", driver)
}
if cfg.Params == nil {
return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name)
}
uri, ok := cfg.ParamString("uri")
if !ok {
return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name)
}
username, ok := cfg.ParamString("username")
if !ok {
return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name)
}
password, ok := cfg.ParamString("password")
if !ok {
return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name)
}
query, ok := cfg.ParamString("query")
if !ok {
return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name)
}
queryTimeout := defaultPostgresQueryTimeout
if _, exists := cfg.Params["query_timeout"]; exists {
var ok bool
queryTimeout, ok = cfg.ParamDuration("query_timeout")
if !ok || queryTimeout <= 0 {
return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name)
}
}
db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{
URI: uri,
Username: username,
Password: password,
})
if err != nil {
return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err)
}
return &PostgresQuerySource{
Driver: driver,
Name: name,
SQL: query,
QueryTimeout: queryTimeout,
db: db,
}, nil
}
func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) {
queryCtx := ctx
if s.QueryTimeout > 0 {
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout {
// We intentionally do not cancel this derived context here because the
// returned rows may still be reading from the database.
queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout)
}
}
rows, err := s.db.QueryContext(queryCtx, s.SQL, args...)
if err != nil {
return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err)
}
return rows, nil
}

352
sources/postgres_test.go Normal file
View File

@@ -0,0 +1,352 @@
package sources
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"sync"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
type fakePostgresQueryDB struct {
queryErr error
lastCtx context.Context
lastQuery string
lastArgs []any
returnRows *sql.Rows
}
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
db.lastCtx = ctx
db.lastQuery = query
db.lastArgs = append([]any(nil), args...)
if db.queryErr != nil {
return nil, db.queryErr
}
return db.returnRows, nil
}
func withPostgresQuerySourceTestState(t *testing.T) {
t.Helper()
oldOpen := openPostgresQueryDB
t.Cleanup(func() {
openPostgresQueryDB = oldOpen
})
}
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
withPostgresQuerySourceTestState(t)
tests := []struct {
name string
params map[string]any
want string
}{
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: tc.params,
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
}
})
}
}
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
"query_timeout": "soon",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
withPostgresQuerySourceTestState(t)
db := &fakePostgresQueryDB{}
var gotCfg pgconn.ConnConfig
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
gotCfg = cfg
return db, nil
}
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://db.example.local/feedkit",
"username": "app_user",
"password": "app_pass",
"query": "SELECT * FROM observations",
"query_timeout": "45s",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
if src.Name != "pg-source" {
t.Fatalf("Name = %q, want pg-source", src.Name)
}
if src.QueryTimeout != 45*time.Second {
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
}
if src.SQL != "SELECT * FROM observations" {
t.Fatalf("SQL = %q", src.SQL)
}
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
t.Fatalf("ConnConfig = %+v", gotCfg)
}
}
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
withPostgresQuerySourceTestState(t)
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return nil, errors.New("db unavailable")
}
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT 1",
},
})
if err == nil {
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
t.Fatalf("NewPostgresQuerySource() error = %q", err)
}
}
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx := context.Background()
_, err := src.Query(ctx, "arg1")
if err == nil {
t.Fatalf("Query() error = nil, want error")
}
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
t.Fatalf("Query() error = %q", err)
}
if db.lastCtx == nil {
t.Fatalf("lastCtx = nil")
}
if _, ok := db.lastCtx.Deadline(); !ok {
t.Fatalf("expected derived deadline on query context")
}
if db.lastQuery != "SELECT 1" {
t.Fatalf("lastQuery = %q", db.lastQuery)
}
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
t.Fatalf("lastArgs = %#v", db.lastArgs)
}
}
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
src := &PostgresQuerySource{
Driver: "test_driver",
Name: "pg-source",
SQL: "SELECT 1",
QueryTimeout: 30 * time.Second,
db: db,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, _ = src.Query(ctx)
if db.lastCtx != ctx {
t.Fatalf("expected source to reuse earlier caller deadline")
}
}
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
withPostgresQuerySourceTestState(t)
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
defer cleanup()
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
return db, nil
}
type fakeDownstreamSource struct {
pg *PostgresQuerySource
}
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
rows, err := s.pg.Query(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
var out []event.Event
for rows.Next() {
var eventID string
if err := rows.Scan(&eventID); err != nil {
return nil, err
}
out = append(out, event.Event{
ID: eventID,
Kind: event.Kind("observation"),
Source: s.pg.Name,
Schema: "raw.test.v1",
EmittedAt: time.Now().UTC(),
Payload: map[string]any{"event_id": eventID},
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
Name: "pg-source",
Driver: "test_driver",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "u",
"password": "p",
"query": "SELECT event_id FROM events",
},
})
if err != nil {
t.Fatalf("NewPostgresQuerySource() error = %v", err)
}
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
if err != nil {
t.Fatalf("poll() error = %v", err)
}
if len(events) != 1 {
t.Fatalf("len(events) = %d, want 1", len(events))
}
if events[0].ID != "evt-1" {
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
}
}
var (
rowsDriverMu sync.Mutex
rowsDriverSeen = map[string]bool{}
)
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
t.Helper()
rowsDriverMu.Lock()
if !rowsDriverSeen[driverName] {
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
rowsDriverSeen[driverName] = true
}
rowsDriverMu.Unlock()
db, err := sql.Open(driverName, "")
if err != nil {
t.Fatalf("sql.Open() error = %v", err)
}
return db, func() {
_ = db.Close()
}
}
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
out := make([][]driver.Value, 0, len(in))
for _, row := range in {
copied := append([]driver.Value(nil), row...)
out = append(out, copied)
}
return out
}
type rowsTestDriver struct {
columns []string
rows [][]driver.Value
}
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
}
type rowsTestConn struct {
columns []string
rows [][]driver.Value
}
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
return nil, errors.New("not implemented")
}
func (c *rowsTestConn) Close() error { return nil }
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
}
type rowsTestRows struct {
columns []string
rows [][]driver.Value
idx int
}
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
func (r *rowsTestRows) Close() error { return nil }
func (r *rowsTestRows) Next(dest []driver.Value) error {
if r.idx >= len(r.rows) {
return io.EOF
}
copy(dest, r.rows[r.idx])
r.idx++
return nil
}

View File

@@ -7,43 +7,115 @@ import (
"gitea.maximumdirect.net/ejr/feedkit/config"
)
// PollFactory constructs a configured PollSource instance from config.
//
// This is how concrete daemons (weatherfeeder/newsfeeder/...) register their
// domain-specific source drivers (Open-Meteo, NWS, RSS, etc.) while feedkit
// remains domain-agnostic.
type PollFactory func(cfg config.SourceConfig) (PollSource, error)
type StreamFactory func(cfg config.SourceConfig) (StreamSource, error)
type Registry struct {
byPollDriver map[string]PollFactory
byStreamDriver map[string]StreamFactory
}
func NewRegistry() *Registry {
return &Registry{
byPollDriver: map[string]PollFactory{},
byStreamDriver: map[string]StreamFactory{},
}
}
// RegisterPoll associates a driver name with a polling-source factory.
//
// The driver string is the lookup key used by config.sources[].driver.
func (r *Registry) RegisterPoll(driver string, f PollFactory) {
driver = strings.TrimSpace(driver)
if driver == "" {
// Panic is appropriate here: registering an empty driver is always a programmer error,
// and it will lead to extremely confusing runtime behavior if allowed.
panic("sources.Registry.RegisterPoll: driver cannot be empty")
}
if f == nil {
panic(fmt.Sprintf("sources.Registry.RegisterPoll: factory cannot be nil (driver=%q)", driver))
}
if _, exists := r.byStreamDriver[driver]; exists {
panic(fmt.Sprintf("sources.Registry.RegisterPoll: driver %q already registered as a stream source", driver))
}
if _, exists := r.byPollDriver[driver]; exists {
panic(fmt.Sprintf("sources.Registry.RegisterPoll: driver %q already registered as a polling source", driver))
}
r.byPollDriver[driver] = f
}
// RegisterStream is the StreamSource equivalent of RegisterPoll.
func (r *Registry) RegisterStream(driver string, f StreamFactory) {
driver = strings.TrimSpace(driver)
if driver == "" {
panic("sources.Registry.RegisterStream: driver cannot be empty")
}
if f == nil {
panic(fmt.Sprintf("sources.Registry.RegisterStream: factory cannot be nil (driver=%q)", driver))
}
if _, exists := r.byPollDriver[driver]; exists {
panic(fmt.Sprintf("sources.Registry.RegisterStream: driver %q already registered as a polling source", driver))
}
if _, exists := r.byStreamDriver[driver]; exists {
panic(fmt.Sprintf("sources.Registry.RegisterStream: driver %q already registered as a stream source", driver))
}
r.byStreamDriver[driver] = f
}
// BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver.
func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) {
driver := strings.TrimSpace(cfg.Driver)
if cfg.Mode.Normalize() == config.SourceModeStream {
return nil, fmt.Errorf("source %q mode=stream cannot be built as polling source", cfg.Name)
}
f, ok := r.byPollDriver[driver]
if !ok {
if _, streamExists := r.byStreamDriver[driver]; streamExists {
return nil, fmt.Errorf("source driver %q is stream-only; cannot build as polling source", driver)
}
return nil, fmt.Errorf("unknown source driver: %q", driver)
}
return f(cfg)
}
// BuildInput can return either a polling Source or a StreamSource.
func (r *Registry) BuildInput(cfg config.SourceConfig) (Input, error) {
driver := strings.TrimSpace(cfg.Driver)
mode := cfg.Mode.Normalize()
if mode != config.SourceModeAuto && mode != config.SourceModePoll && mode != config.SourceModeStream {
return nil, fmt.Errorf("source %q has invalid mode %q (expected \"poll\" or \"stream\")", cfg.Name, cfg.Mode)
}
switch mode {
case config.SourceModePoll:
f, ok := r.byPollDriver[driver]
if !ok {
if _, streamExists := r.byStreamDriver[driver]; streamExists {
return nil, fmt.Errorf("source %q mode=poll conflicts with stream-only driver %q", cfg.Name, driver)
}
return nil, fmt.Errorf("unknown source driver: %q", driver)
}
return f(cfg)
case config.SourceModeStream:
f, ok := r.byStreamDriver[driver]
if !ok {
if _, pollExists := r.byPollDriver[driver]; pollExists {
return nil, fmt.Errorf("source %q mode=stream conflicts with polling driver %q", cfg.Name, driver)
}
return nil, fmt.Errorf("unknown source driver: %q", driver)
}
return f(cfg)
}
if f, ok := r.byStreamDriver[driver]; ok {
return f(cfg)
}
if f, ok := r.byPollDriver[driver]; ok {
return f(cfg)
}
return nil, fmt.Errorf("unknown source driver: %q", driver)
}

84
sources/registry_test.go Normal file
View File

@@ -0,0 +1,84 @@
package sources
import (
"context"
"strings"
"testing"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
type testPollSource struct{ name string }
func (s testPollSource) Name() string { return s.name }
func (s testPollSource) Poll(context.Context) ([]event.Event, error) {
return nil, nil
}
type testStreamSource struct{ name string }
func (s testStreamSource) Name() string { return s.name }
func (s testStreamSource) Run(context.Context, chan<- event.Event) error {
return nil
}
func TestRegistryBuildInputModeConflicts(t *testing.T) {
r := NewRegistry()
r.RegisterPoll("poll_driver", func(cfg config.SourceConfig) (PollSource, error) {
return testPollSource{name: cfg.Name}, nil
})
r.RegisterStream("stream_driver", func(cfg config.SourceConfig) (StreamSource, error) {
return testStreamSource{name: cfg.Name}, nil
})
_, err := r.BuildInput(config.SourceConfig{
Name: "s1",
Driver: "stream_driver",
Mode: config.SourceModePoll,
})
if err == nil {
t.Fatalf("expected mode conflict error, got nil")
}
if !strings.Contains(err.Error(), "mode=poll") {
t.Fatalf("expected poll conflict error, got: %v", err)
}
_, err = r.BuildInput(config.SourceConfig{
Name: "s2",
Driver: "poll_driver",
Mode: config.SourceModeStream,
})
if err == nil {
t.Fatalf("expected mode conflict error, got nil")
}
if !strings.Contains(err.Error(), "mode=stream") {
t.Fatalf("expected stream conflict error, got: %v", err)
}
}
func TestRegistryBuildInputAutoByDriverType(t *testing.T) {
r := NewRegistry()
r.RegisterPoll("poll_driver", func(cfg config.SourceConfig) (PollSource, error) {
return testPollSource{name: cfg.Name}, nil
})
r.RegisterStream("stream_driver", func(cfg config.SourceConfig) (StreamSource, error) {
return testStreamSource{name: cfg.Name}, nil
})
src, err := r.BuildInput(config.SourceConfig{Name: "p", Driver: "poll_driver"})
if err != nil {
t.Fatalf("BuildInput poll auto failed: %v", err)
}
if _, ok := src.(PollSource); !ok {
t.Fatalf("expected PollSource, got %T", src)
}
src, err = r.BuildInput(config.SourceConfig{Name: "s", Driver: "stream_driver"})
if err != nil {
t.Fatalf("BuildInput stream auto failed: %v", err)
}
if _, ok := src.(StreamSource); !ok {
t.Fatalf("expected StreamSource, got %T", src)
}
}

View File

@@ -6,25 +6,44 @@ import (
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// Input is the common surface shared by all source types.
//
// Source implementations live in domain modules (weatherfeeder/newsfeeder/...)
// A source may be polling (PollSource) or event-driven (StreamSource).
// Both source types emit domain-agnostic event.Event values.
type Input interface {
Name() string
}
// PollSource is a configured polling source that emits 0..N events per poll.
//
// PollSource implementations live in domain modules (weatherfeeder/newsfeeder/...)
// and are registered into a feedkit sources.Registry.
//
// feedkit infrastructure treats PollSource as opaque; it just calls Poll()
// on the configured cadence and publishes the resulting events.
type PollSource interface {
// Name is the configured source name (used for logs and included in emitted events).
Name() string
// Kind is the "primary kind" emitted by this source.
//
// This is mainly useful as a *safety check* (e.g. config says kind=forecast but
// driver emits observation). Some future sources may emit multiple kinds; if/when
// that happens, we can evolve this interface (e.g., make Kind optional, or remove it).
Kind() event.Kind
// Poll fetches/processes one input batch and returns 0..N events.
// A single poll can emit multiple event kinds.
// Implementations should honor ctx.Done() for network calls and other I/O.
Poll(ctx context.Context) ([]event.Event, error)
}
// StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc).
//
// Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs.
// It MUST NOT close out (the scheduler/daemon owns the bus).
//
// Stream sources can classify exits by wrapping errors with StreamRetryable or
// StreamFatal. Plain non-nil errors are treated as retryable by the scheduler.
type StreamSource interface {
Input
Run(ctx context.Context, out chan<- event.Event) error
}
// KindsSource is an optional interface for sources that advertise multiple kinds.
type KindsSource interface {
Kinds() []event.Kind
}

sources/stream_errors.go Normal file

@@ -0,0 +1,63 @@
package sources
import "errors"
type streamRetryableError struct {
err error
}
func (e *streamRetryableError) Error() string {
if e.err == nil {
return "retryable stream error"
}
return e.err.Error()
}
func (e *streamRetryableError) Unwrap() error { return e.err }
type streamFatalError struct {
err error
}
func (e *streamFatalError) Error() string {
if e.err == nil {
return "fatal stream error"
}
return e.err.Error()
}
func (e *streamFatalError) Unwrap() error { return e.err }
// StreamRetryable marks a stream-source exit as retryable.
func StreamRetryable(err error) error {
if err == nil {
return nil
}
return &streamRetryableError{err: err}
}
// StreamFatal marks a stream-source exit as fatal.
func StreamFatal(err error) error {
if err == nil {
return nil
}
return &streamFatalError{err: err}
}
// IsStreamRetryable reports whether err contains a retryable stream marker.
func IsStreamRetryable(err error) bool {
if err == nil {
return false
}
var target *streamRetryableError
return errors.As(err, &target)
}
// IsStreamFatal reports whether err contains a fatal stream marker.
func IsStreamFatal(err error) bool {
if err == nil {
return false
}
var target *streamFatalError
return errors.As(err, &target)
}


@@ -0,0 +1,52 @@
package sources
import (
"errors"
"fmt"
"testing"
)
func TestStreamRetryableWrapsThroughErrorChains(t *testing.T) {
base := errors.New("retry me")
err := fmt.Errorf("outer: %w", StreamRetryable(base))
if !IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = false, want true")
}
if IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamFatalWrapsThroughErrorChains(t *testing.T) {
base := errors.New("fatal")
err := fmt.Errorf("outer: %w", StreamFatal(base))
if !IsStreamFatal(err) {
t.Fatalf("IsStreamFatal() = false, want true")
}
if IsStreamRetryable(err) {
t.Fatalf("IsStreamRetryable() = true, want false")
}
if !errors.Is(err, base) {
t.Fatalf("errors.Is(err, base) = false, want true")
}
}
func TestStreamErrorHelpersNil(t *testing.T) {
if StreamRetryable(nil) != nil {
t.Fatalf("StreamRetryable(nil) != nil")
}
if StreamFatal(nil) != nil {
t.Fatalf("StreamFatal(nil) != nil")
}
if IsStreamRetryable(nil) {
t.Fatalf("IsStreamRetryable(nil) = true")
}
if IsStreamFatal(nil) {
t.Fatalf("IsStreamFatal(nil) = true")
}
}

transport/http.go Normal file

@@ -0,0 +1,190 @@
// FILE: ./transport/http.go
package transport
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies.
// API responses should be small, so this protects us from accidental
// or malicious large responses.
const DefaultHTTPResponseBodyLimitBytes int64 = 4 << 20 // 4 MiB
// DefaultHTTPTimeout is the standard timeout used by HTTP sources.
// Individual drivers may override this if they have a specific need.
const DefaultHTTPTimeout = 10 * time.Second
// NewHTTPClient returns a simple http.Client configured with a timeout.
// If timeout <= 0, DefaultHTTPTimeout is used.
func NewHTTPClient(timeout time.Duration) *http.Client {
if timeout <= 0 {
timeout = DefaultHTTPTimeout
}
return &http.Client{Timeout: timeout}
}
// FetchBody performs a GET and returns the response body, enforcing
// DefaultHTTPResponseBodyLimitBytes.
func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) {
return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes)
}
// FetchBodyWithLimit is FetchBody with a caller-supplied body limit;
// limits <= 0 fall back to DefaultHTTPResponseBodyLimitBytes.
func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) {
res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "")
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, fmt.Errorf("HTTP %s", res.Status)
}
return readValidatedBody(res.Body, bodyLimitBytes)
}
// HTTPValidators are cache validators learned from prior successful GET responses.
//
// ETag is preferred when present. LastModified is used as a fallback validator
// when ETag is unavailable.
type HTTPValidators struct {
ETag string
LastModified string
}
// FetchBodyIfChanged performs an HTTP GET and opportunistically uses conditional
// request headers based on the provided validators.
//
// Behavior:
// - if conditional is false, this behaves like a normal GET and leaves validators unchanged
// - if validators.ETag is set, sends If-None-Match
// - else if validators.LastModified is set, sends If-Modified-Since
// - 304 Not Modified is treated as success with changed=false and no body
// - 200 responses are treated as changed=true and still enforce the normal body checks
//
// Returned validators reflect any updates learned from the response headers.
func FetchBodyIfChanged(
ctx context.Context,
client *http.Client,
url, userAgent, accept string,
conditional bool,
validators HTTPValidators,
) ([]byte, bool, HTTPValidators, error) {
return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes)
}
// FetchBodyIfChangedWithLimit is FetchBodyIfChanged with a caller-supplied
// response body limit.
func FetchBodyIfChangedWithLimit(
ctx context.Context,
client *http.Client,
url, userAgent, accept string,
conditional bool,
validators HTTPValidators,
bodyLimitBytes int64,
) ([]byte, bool, HTTPValidators, error) {
headerName, headerValue := conditionalHeader(conditional, validators)
res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, headerName, headerValue)
if err != nil {
return nil, false, validators, err
}
defer res.Body.Close()
switch res.StatusCode {
case http.StatusNotModified:
if conditional {
validators = refreshValidators(validators, res.Header)
}
return nil, false, validators, nil
default:
if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, false, validators, fmt.Errorf("HTTP %s", res.Status)
}
}
b, err := readValidatedBody(res.Body, bodyLimitBytes)
if err != nil {
return nil, false, validators, err
}
if conditional {
validators = replaceValidators(res.Header)
}
return b, true, validators, nil
}
func doRequest(ctx context.Context, client *http.Client, method, url, userAgent, accept, headerName, headerValue string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}
if userAgent != "" {
req.Header.Set("User-Agent", userAgent)
}
if accept != "" {
req.Header.Set("Accept", accept)
}
if headerName != "" && headerValue != "" {
req.Header.Set(headerName, headerValue)
}
return client.Do(req)
}
func conditionalHeader(enabled bool, validators HTTPValidators) (string, string) {
if !enabled {
return "", ""
}
if etag := strings.TrimSpace(validators.ETag); etag != "" {
return "If-None-Match", etag
}
if lastModified := strings.TrimSpace(validators.LastModified); lastModified != "" {
return "If-Modified-Since", lastModified
}
return "", ""
}
func replaceValidators(header http.Header) HTTPValidators {
return HTTPValidators{
ETag: strings.TrimSpace(header.Get("ETag")),
LastModified: strings.TrimSpace(header.Get("Last-Modified")),
}
}
func refreshValidators(current HTTPValidators, header http.Header) HTTPValidators {
if etag := strings.TrimSpace(header.Get("ETag")); etag != "" {
current.ETag = etag
}
if lastModified := strings.TrimSpace(header.Get("Last-Modified")); lastModified != "" {
current.LastModified = lastModified
}
return current
}
func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) {
if bodyLimitBytes <= 0 {
bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes
}
// Read at most bodyLimitBytes + 1 so we can detect overflow.
limited := io.LimitReader(r, bodyLimitBytes+1)
b, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if len(b) == 0 {
return nil, fmt.Errorf("empty response body")
}
if int64(len(b)) > bodyLimitBytes {
return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes)
}
return b, nil
}
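The validator flow above amounts to a standard conditional GET. This self-contained sketch shows the same pattern reduced to ETag handling only, using nothing but the standard library; `fetchIfChanged` and `demo` are illustrative names, not the feedkit API:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// fetchIfChanged sends If-None-Match when an ETag is cached, treats 304 as
// "unchanged", and learns a new ETag from 200 responses.
func fetchIfChanged(client *http.Client, url, etag string) (body []byte, changed bool, nextETag string, err error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, false, etag, err
	}
	if etag != "" {
		req.Header.Set("If-None-Match", etag)
	}
	res, err := client.Do(req)
	if err != nil {
		return nil, false, etag, err
	}
	defer res.Body.Close()
	if res.StatusCode == http.StatusNotModified {
		// Unchanged: keep the validator we already have.
		return nil, false, etag, nil
	}
	if res.StatusCode != http.StatusOK {
		return nil, false, etag, fmt.Errorf("HTTP %s", res.Status)
	}
	b, err := io.ReadAll(io.LimitReader(res.Body, 1<<20)) // cap reads, as transport/http.go does
	if err != nil {
		return nil, false, etag, err
	}
	return b, true, res.Header.Get("ETag"), nil
}

// demo runs two polls against a validator-aware test server and reports
// whether each poll saw a change.
func demo() (bool, bool) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("If-None-Match") == `"v1"` {
			w.WriteHeader(http.StatusNotModified)
			return
		}
		w.Header().Set("ETag", `"v1"`)
		_, _ = w.Write([]byte(`{"ok":true}`))
	}))
	defer srv.Close()

	_, first, etag, _ := fetchIfChanged(srv.Client(), srv.URL, "")
	_, second, _, _ := fetchIfChanged(srv.Client(), srv.URL, etag)
	return first, second
}

func main() {
	first, second := demo()
	fmt.Println("first poll changed:", first, "second poll changed:", second)
}
```

On a steady-state polling cadence this turns most fetches into header-only 304 round trips, which is why FetchBodyIfChanged keeps validators across polls.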

transport/http_test.go Normal file

@@ -0,0 +1,232 @@
package transport
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func TestFetchBodyIfChangedPrefersETagAndTreats304AsUnchanged(t *testing.T) {
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
if got := r.Header.Get("If-None-Match"); got != "" {
t.Fatalf("first request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Fatalf("first request If-Modified-Since = %q, want empty", got)
}
w.Header().Set("ETag", `"v1"`)
w.Header().Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT")
_, _ = w.Write([]byte(`{"ok":true}`))
case 2:
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
t.Fatalf("second request If-None-Match = %q, want %q", got, `"v1"`)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Fatalf("second request If-Modified-Since = %q, want empty when ETag is cached", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Fatalf("unexpected call count %d", call)
}
}))
defer srv.Close()
validators := HTTPValidators{}
body, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, validators)
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
}
if got := string(body); got != `{"ok":true}` {
t.Fatalf("first FetchBodyIfChanged() body = %q", got)
}
if got := next.ETag; got != `"v1"` {
t.Fatalf("cached ETag = %q, want %q", got, `"v1"`)
}
if got := next.LastModified; got != "Mon, 02 Jan 2006 15:04:05 GMT" {
t.Fatalf("cached Last-Modified = %q", got)
}
body, changed, next, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, next)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
}
if body != nil {
t.Fatalf("second FetchBodyIfChanged() body = %q, want nil", string(body))
}
if got := next.ETag; got != `"v1"` {
t.Fatalf("cached ETag after 304 = %q, want %q", got, `"v1"`)
}
}
func TestFetchBodyIfChangedFallsBackToIfModifiedSince(t *testing.T) {
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("Last-Modified", "Tue, 03 Jan 2006 15:04:05 GMT")
_, _ = w.Write([]byte(`first`))
case 2:
if got := r.Header.Get("If-None-Match"); got != "" {
t.Fatalf("second request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "Tue, 03 Jan 2006 15:04:05 GMT" {
t.Fatalf("second request If-Modified-Since = %q", got)
}
w.WriteHeader(http.StatusNotModified)
default:
t.Fatalf("unexpected call count %d", call)
}
}))
defer srv.Close()
_, changed, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
}
if got := validators.LastModified; got != "Tue, 03 Jan 2006 15:04:05 GMT" {
t.Fatalf("cached Last-Modified = %q", got)
}
_, changed, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
}
}
func TestFetchBodyIfChangedClearsValidatorsOn200WithoutValidators(t *testing.T) {
var call int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
call++
switch call {
case 1:
w.Header().Set("ETag", `"v1"`)
_, _ = w.Write([]byte(`first`))
case 2:
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
t.Fatalf("second request If-None-Match = %q", got)
}
_, _ = w.Write([]byte(`second`))
case 3:
if got := r.Header.Get("If-None-Match"); got != "" {
t.Fatalf("third request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Fatalf("third request If-Modified-Since = %q, want empty", got)
}
_, _ = w.Write([]byte(`third`))
default:
t.Fatalf("unexpected call count %d", call)
}
}))
defer srv.Close()
_, _, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
if err != nil {
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
}
_, _, validators, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
}
if validators.ETag != "" || validators.LastModified != "" {
t.Fatalf("validators after 200 without validators = %+v, want cleared", validators)
}
_, _, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
if err != nil {
t.Fatalf("third FetchBodyIfChanged() error = %v", err)
}
}
func TestFetchBodyIfChangedConditionalDisabledSkipsConditionalHeaders(t *testing.T) {
var calls int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
if got := r.Header.Get("If-None-Match"); got != "" {
t.Fatalf("request If-None-Match = %q, want empty", got)
}
if got := r.Header.Get("If-Modified-Since"); got != "" {
t.Fatalf("request If-Modified-Since = %q, want empty", got)
}
_, _ = w.Write([]byte(`body`))
}))
defer srv.Close()
validators := HTTPValidators{ETag: `"v1"`, LastModified: "Wed, 04 Jan 2006 15:04:05 GMT"}
_, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", false, validators)
if err != nil {
t.Fatalf("FetchBodyIfChanged() error = %v", err)
}
if !changed {
t.Fatalf("FetchBodyIfChanged() changed = false, want true")
}
if next != validators {
t.Fatalf("validators changed when conditional disabled: got %+v want %+v", next, validators)
}
if calls != 1 {
t.Fatalf("calls = %d, want 1", calls)
}
}
func TestFetchBodyIfChangedAllowsEmpty304ButRejectsEmpty200(t *testing.T) {
notModified := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotModified)
}))
defer notModified.Close()
_, changed, _, err := FetchBodyIfChanged(
context.Background(),
notModified.Client(),
notModified.URL,
"test-agent",
"",
true,
HTTPValidators{ETag: `"v1"`},
)
if err != nil {
t.Fatalf("304 FetchBodyIfChanged() error = %v", err)
}
if changed {
t.Fatalf("304 FetchBodyIfChanged() changed = true, want false")
}
emptyBody := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer emptyBody.Close()
_, _, _, err = FetchBodyIfChanged(context.Background(), emptyBody.Client(), emptyBody.URL, "test-agent", "", true, HTTPValidators{})
if err == nil {
t.Fatalf("empty 200 FetchBodyIfChanged() error = nil, want error")
}
if err.Error() != "empty response body" {
t.Fatalf("empty 200 FetchBodyIfChanged() error = %q", err)
}
}