Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5c1b28ee0a | |||
| 247937b65e | |||
| eb9a7cb349 | |||
| 3281368922 | |||
| 3ef93faf69 | |||
| 4910440756 | |||
| 3b92c2284d | |||
| 215afe1acf | |||
| 4572c53580 | |||
| 96039f6530 | |||
| 6c5f95ad26 | |||
| fafba0f01b | |||
| 3c95fa97cd | |||
| dbca0548b1 | |||
| 9b2c1e5ceb | |||
| 1d43adcfa0 | |||
| a6c133319a | |||
| 09bc65e947 |
91
README.md
91
README.md
@@ -1,3 +1,92 @@
|
||||
# feedkit
|
||||
|
||||
Feedkit provides core interfaces and plumbing for applications that ingest and process feeds.
|
||||
`feedkit` is a small Go toolkit for building feed-processing daemons.
|
||||
|
||||
It gives you the reusable plumbing around collection, processing, routing, and
|
||||
emission, while leaving domain concepts, schemas, and application wiring in
|
||||
your daemon. The intended shape is a family of sibling applications such as
|
||||
`weatherfeeder`, `newsfeeder`, or `earthquakefeeder` that all share the same
|
||||
infrastructure patterns without sharing domain logic.
|
||||
|
||||
## What It Does
|
||||
|
||||
A daemon built on `feedkit` typically:
|
||||
- ingests upstream input by polling HTTP APIs or consuming streams
|
||||
- emits domain-agnostic `event.Event` values
|
||||
- optionally processes those events with stages like dedupe or normalization
|
||||
- routes events to one or more sinks such as stdout, NATS, or Postgres
|
||||
|
||||
Conceptually, the pipeline is:
|
||||
|
||||
`Collect -> Process -> Route -> Emit`
|
||||
|
||||
## Philosophy
|
||||
|
||||
`feedkit` is intentionally not a framework.
|
||||
|
||||
It does not try to own:
|
||||
- your domain payload schemas
|
||||
- your domain event kinds
|
||||
- your daemon lifecycle or `main.go`
|
||||
- your observability stack or deployment model
|
||||
|
||||
Instead, it provides small composable packages that are easy to wire together in
|
||||
different daemons.
|
||||
|
||||
## When To Use It
|
||||
|
||||
`feedkit` is a good fit when you want:
|
||||
- multiple small ingestion daemons with shared infrastructure patterns
|
||||
- clear separation between raw upstream payloads and normalized canonical models
|
||||
- reusable routing and sink behavior across domains
|
||||
- strong config and event-envelope conventions without centralizing domain rules
|
||||
|
||||
It is a poor fit if you want a monolithic framework that dictates application
|
||||
structure end-to-end.
|
||||
|
||||
## Built-In Capabilities
|
||||
|
||||
`feedkit` currently includes:
|
||||
- strict YAML config loading and validation
|
||||
- polling and streaming source abstractions
|
||||
- scheduler orchestration for configured sources and supervised stream workers
|
||||
- optional pipeline processors
|
||||
- built-in dedupe and normalization processors
|
||||
- route compilation and sink fanout
|
||||
- built-in sinks for `stdout`, `nats`, and `postgres`
|
||||
|
||||
The Postgres sink is intentionally split between feedkit-owned infrastructure
|
||||
and daemon-owned schema mapping. `feedkit` manages connection setup, DDL,
|
||||
writes, and pruning; downstream applications define the schema and event mapper.
|
||||
|
||||
## Typical Wiring
|
||||
|
||||
At a high level, a daemon built on `feedkit` does this:
|
||||
|
||||
1. Load config.
|
||||
2. Register domain-specific source drivers.
|
||||
3. Register built-in and/or custom sinks.
|
||||
4. Build sources, sinks, and optional processor chain from config.
|
||||
5. Compile routes.
|
||||
6. Start the scheduler and dispatcher.
|
||||
|
||||
The package docs are the better source of truth for code-level details. In
|
||||
particular, each subpackage `doc.go` describes its external API surface and any
|
||||
optional helper APIs in `helpers.go`.
|
||||
|
||||
## Package Layout
|
||||
|
||||
The major packages are:
|
||||
- `config`: config loading and validation
|
||||
- `event`: the domain-agnostic event envelope
|
||||
- `sources`: source interfaces and reusable source helpers
|
||||
- `scheduler`: source execution and cadence management
|
||||
- `processors`: processor interfaces and registry
|
||||
- `processors/dedupe`: built-in in-memory dedupe processor
|
||||
- `processors/normalize`: built-in normalization processor and helpers
|
||||
- `pipeline`: optional processor chain
|
||||
- `dispatch`: route compilation and fanout
|
||||
- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers
|
||||
|
||||
The root package docs in `doc.go` provide a concise package-by-package map for
|
||||
Go documentation consumers.
|
||||
|
||||
@@ -21,31 +21,79 @@ type Config struct {
|
||||
Routes []RouteConfig `yaml:"routes"`
|
||||
}
|
||||
|
||||
// SourceConfig describes one polling job.
|
||||
// SourceMode selects how a source receives upstream input.
|
||||
//
|
||||
// Empty mode means "auto": feedkit infers mode from the registered driver type.
|
||||
type SourceMode string
|
||||
|
||||
const (
|
||||
SourceModeAuto SourceMode = ""
|
||||
SourceModePoll SourceMode = "poll"
|
||||
SourceModeStream SourceMode = "stream"
|
||||
)
|
||||
|
||||
// Normalize lowercases and trims the mode.
|
||||
func (m SourceMode) Normalize() SourceMode {
|
||||
switch strings.ToLower(strings.TrimSpace(string(m))) {
|
||||
case "":
|
||||
return SourceModeAuto
|
||||
case string(SourceModePoll):
|
||||
return SourceModePoll
|
||||
case string(SourceModeStream):
|
||||
return SourceModeStream
|
||||
default:
|
||||
return SourceMode(strings.ToLower(strings.TrimSpace(string(m))))
|
||||
}
|
||||
}
|
||||
|
||||
// SourceConfig describes one input source.
|
||||
//
|
||||
// This is intentionally generic:
|
||||
// - driver-specific knobs belong in Params.
|
||||
// - "kind" is allowed (useful for safety checks / routing), but feedkit does not
|
||||
// restrict the allowed values.
|
||||
// - mode controls polling vs streaming behavior.
|
||||
// - expected emitted kinds are optional and domain-defined.
|
||||
type SourceConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Driver string `yaml:"driver"` // e.g. "openmeteo_observation", "rss_feed", etc.
|
||||
|
||||
Every Duration `yaml:"every"` // "15m", "1m", etc.
|
||||
// Mode is optional:
|
||||
// - "poll": Every must be set (>0)
|
||||
// - "stream": Every must be omitted/zero
|
||||
// - empty: infer from driver registration type (poll vs stream)
|
||||
Mode SourceMode `yaml:"mode"`
|
||||
|
||||
// Kind is optional and domain-defined. If set, it should be a non-empty string.
|
||||
// Domains commonly use it to enforce "this source should only emit kind X".
|
||||
Kind string `yaml:"kind"`
|
||||
// Every is the poll cadence for poll-mode sources ("15m", "1m", etc.).
|
||||
Every Duration `yaml:"every"`
|
||||
|
||||
// Kinds is optional and domain-defined.
|
||||
// If set, it describes the expected emitted event kinds for this source.
|
||||
Kinds []string `yaml:"kinds"`
|
||||
|
||||
// Params are driver-specific settings (URL, headers, station IDs, API keys, etc.).
|
||||
// The driver implementation is responsible for reading/validating these.
|
||||
Params map[string]any `yaml:"params"`
|
||||
}
|
||||
|
||||
// ExpectedKinds returns normalized expected kinds from config.
|
||||
func (cfg SourceConfig) ExpectedKinds() []string {
|
||||
out := make([]string, 0, len(cfg.Kinds))
|
||||
for _, k := range cfg.Kinds {
|
||||
k = strings.TrimSpace(k)
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, k)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SinkConfig describes one output sink adapter.
|
||||
type SinkConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Driver string `yaml:"driver"` // "stdout", "file", "postgres", "rabbitmq", ...
|
||||
Driver string `yaml:"driver"` // "stdout", "nats", "postgres", ...
|
||||
Params map[string]any `yaml:"params"` // sink-specific settings
|
||||
}
|
||||
|
||||
@@ -54,7 +102,10 @@ type RouteConfig struct {
|
||||
Sink string `yaml:"sink"` // sink name
|
||||
|
||||
// Kinds is domain-defined. feedkit only enforces that each entry is non-empty.
|
||||
// Whether a given daemon "recognizes" a kind is domain-specific validation.
|
||||
//
|
||||
// If Kinds is omitted or empty, the route matches ALL kinds.
|
||||
// This is useful when you want explicit per-sink routing rules even when a
|
||||
// particular sink should receive everything.
|
||||
Kinds []string `yaml:"kinds"`
|
||||
}
|
||||
|
||||
@@ -128,12 +179,3 @@ func (d *Duration) UnmarshalYAML(value *yaml.Node) error {
|
||||
// Anything else: reject.
|
||||
return fmt.Errorf("duration must be a string like 15m or an integer minutes, got tag %s", value.Tag)
|
||||
}
|
||||
|
||||
func isAllDigits(s string) bool {
|
||||
for _, r := range s {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
48
config/config_test.go
Normal file
48
config/config_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSourceConfigExpectedKinds(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg SourceConfig
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "plural kinds normalized",
|
||||
cfg: SourceConfig{
|
||||
Kinds: []string{" observation ", "forecast"},
|
||||
},
|
||||
want: []string{"observation", "forecast"},
|
||||
},
|
||||
{
|
||||
name: "empty kinds",
|
||||
cfg: SourceConfig{},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.cfg.ExpectedKinds()
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Fatalf("ExpectedKinds() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSourceModeNormalize(t *testing.T) {
|
||||
if got := SourceMode(" Poll ").Normalize(); got != SourceModePoll {
|
||||
t.Fatalf("Normalize poll = %q, want %q", got, SourceModePoll)
|
||||
}
|
||||
if got := SourceMode("STREAM").Normalize(); got != SourceModeStream {
|
||||
t.Fatalf("Normalize stream = %q, want %q", got, SourceModeStream)
|
||||
}
|
||||
if got := SourceMode("").Normalize(); got != SourceModeAuto {
|
||||
t.Fatalf("Normalize auto = %q, want %q", got, SourceModeAuto)
|
||||
}
|
||||
}
|
||||
@@ -83,14 +83,34 @@ func (c *Config) Validate() error {
|
||||
m.Add(fieldErr(path+".driver", "is required (e.g. openmeteo_observation, rss_feed, ...)"))
|
||||
}
|
||||
|
||||
// Every
|
||||
if s.Every.Duration <= 0 {
|
||||
m.Add(fieldErr(path+".every", "must be a positive duration (e.g. 15m, 1m, 30s)"))
|
||||
// Mode
|
||||
mode := s.Mode.Normalize()
|
||||
if s.Mode != SourceModeAuto && mode != SourceModePoll && mode != SourceModeStream {
|
||||
m.Add(fieldErr(path+".mode", `must be one of: "poll", "stream" (or omit for auto)`))
|
||||
}
|
||||
|
||||
// Kind (optional but if present must be non-empty after trimming)
|
||||
if s.Kind != "" && strings.TrimSpace(s.Kind) == "" {
|
||||
m.Add(fieldErr(path+".kind", "cannot be blank (omit it entirely, or provide a non-empty string)"))
|
||||
// Every
|
||||
if s.Every.Duration < 0 {
|
||||
m.Add(fieldErr(path+".every", "is optional, but must be a positive duration (e.g. 15m, 1m, 30s) if provided"))
|
||||
} else {
|
||||
switch mode {
|
||||
case SourceModePoll:
|
||||
if s.Every.Duration <= 0 {
|
||||
m.Add(fieldErr(path+".every", `is required when mode="poll" (e.g. 15m, 1m, 30s)`))
|
||||
}
|
||||
case SourceModeStream:
|
||||
if s.Every.Duration > 0 {
|
||||
m.Add(fieldErr(path+".every", `must be omitted when mode="stream"`))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kinds (optional)
|
||||
for j, k := range s.Kinds {
|
||||
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
|
||||
if strings.TrimSpace(k) == "" {
|
||||
m.Add(fieldErr(kpath, "kind cannot be empty"))
|
||||
}
|
||||
}
|
||||
|
||||
// Params can be nil; that's fine.
|
||||
@@ -115,7 +135,7 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
|
||||
if strings.TrimSpace(s.Driver) == "" {
|
||||
m.Add(fieldErr(path+".driver", "is required (stdout|file|postgres|rabbitmq|...)"))
|
||||
m.Add(fieldErr(path+".driver", "is required (stdout|nats|postgres|...)"))
|
||||
}
|
||||
|
||||
// Params can be nil; that's fine.
|
||||
@@ -133,16 +153,12 @@ func (c *Config) Validate() error {
|
||||
m.Add(fieldErr(path+".sink", fmt.Sprintf("references unknown sink %q (define it under sinks:)", r.Sink)))
|
||||
}
|
||||
|
||||
if len(r.Kinds) == 0 {
|
||||
// You could relax this later (e.g. empty == "all kinds"), but for now
|
||||
// keeping it strict prevents accidental "route does nothing".
|
||||
m.Add(fieldErr(path+".kinds", "must contain at least one kind"))
|
||||
} else {
|
||||
for j, k := range r.Kinds {
|
||||
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
|
||||
if strings.TrimSpace(k) == "" {
|
||||
m.Add(fieldErr(kpath, "kind cannot be empty"))
|
||||
}
|
||||
// Kinds is optional. If omitted or empty, the route matches ALL kinds.
|
||||
// If provided, each entry must be non-empty.
|
||||
for j, k := range r.Kinds {
|
||||
kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
|
||||
if strings.TrimSpace(k) == "" {
|
||||
m.Add(fieldErr(kpath, "kind cannot be empty"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
385
config/params.go
385
config/params.go
@@ -1,32 +1,21 @@
|
||||
// feedkit/config/params.go
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- SourceConfig param helpers ----
|
||||
|
||||
// ParamString returns the first non-empty string found for any of the provided keys.
|
||||
// Values must actually be strings in the decoded config; other types are ignored.
|
||||
//
|
||||
// This keeps cfg.Params flexible (map[string]any) while letting callers stay type-safe.
|
||||
func (cfg SourceConfig) ParamString(keys ...string) (string, bool) {
|
||||
if cfg.Params == nil {
|
||||
return "", false
|
||||
}
|
||||
for _, k := range keys {
|
||||
v, ok := cfg.Params[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
return "", false
|
||||
return paramString(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
// ParamStringDefault returns ParamString(keys...) if present; otherwise it returns def.
|
||||
@@ -38,14 +27,150 @@ func (cfg SourceConfig) ParamStringDefault(def string, keys ...string) string {
|
||||
return strings.TrimSpace(def)
|
||||
}
|
||||
|
||||
// ParamBool returns the first boolean found for any of the provided keys.
|
||||
//
|
||||
// Accepted types in Params:
|
||||
// - bool
|
||||
// - string: parsed via strconv.ParseBool ("true"/"false"/"1"/"0", etc.)
|
||||
func (cfg SourceConfig) ParamBool(keys ...string) (bool, bool) {
|
||||
return paramBool(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SourceConfig) ParamBoolDefault(def bool, keys ...string) bool {
|
||||
if v, ok := cfg.ParamBool(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ParamInt returns the first integer-like value found for any of the provided keys.
|
||||
//
|
||||
// Accepted types in Params:
|
||||
// - any integer type (int, int64, uint32, ...)
|
||||
// - float32/float64 ONLY if it is an exact integer (e.g. 15.0)
|
||||
// - string: parsed via strconv.Atoi (e.g. "42")
|
||||
func (cfg SourceConfig) ParamInt(keys ...string) (int, bool) {
|
||||
return paramInt(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SourceConfig) ParamIntDefault(def int, keys ...string) int {
|
||||
if v, ok := cfg.ParamInt(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ParamDuration returns the first duration-like value found for any of the provided keys.
|
||||
//
|
||||
// Accepted types in Params:
|
||||
// - time.Duration
|
||||
// - string: parsed via time.ParseDuration (e.g. "250ms", "30s", "5m")
|
||||
// - if the string is all digits (e.g. "30"), it is interpreted as SECONDS
|
||||
// - numeric: interpreted as SECONDS (e.g. 30 => 30s)
|
||||
//
|
||||
// Rationale: Param durations are usually timeouts/backoffs; seconds are a sane numeric default.
|
||||
// If you want minutes/hours, prefer a duration string like "5m" or "1h".
|
||||
func (cfg SourceConfig) ParamDuration(keys ...string) (time.Duration, bool) {
|
||||
return paramDuration(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SourceConfig) ParamDurationDefault(def time.Duration, keys ...string) time.Duration {
|
||||
if v, ok := cfg.ParamDuration(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
// ParamStringSlice returns the first string-slice-like value found for any of the provided keys.
|
||||
//
|
||||
// Accepted types in Params:
|
||||
// - []string
|
||||
// - []any where each element is a string
|
||||
// - string:
|
||||
// - if it contains commas, split on commas (",") and trim each item
|
||||
// - otherwise treat as a single-item list
|
||||
//
|
||||
// Empty/blank items are removed.
|
||||
func (cfg SourceConfig) ParamStringSlice(keys ...string) ([]string, bool) {
|
||||
return paramStringSlice(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
// ---- SinkConfig param helpers ----
|
||||
|
||||
// ParamString returns the first non-empty string found for any of the provided keys
|
||||
// in SinkConfig.Params. (Same rationale as SourceConfig.ParamString.)
|
||||
func (cfg SinkConfig) ParamString(keys ...string) (string, bool) {
|
||||
if cfg.Params == nil {
|
||||
return "", false
|
||||
return paramString(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
// ParamStringDefault returns ParamString(keys...) if present; otherwise it returns def.
|
||||
// Symmetric helper for sink implementations.
|
||||
func (cfg SinkConfig) ParamStringDefault(def string, keys ...string) string {
|
||||
if s, ok := cfg.ParamString(keys...); ok {
|
||||
return s
|
||||
}
|
||||
return strings.TrimSpace(def)
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamBool(keys ...string) (bool, bool) {
|
||||
return paramBool(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamBoolDefault(def bool, keys ...string) bool {
|
||||
if v, ok := cfg.ParamBool(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamInt(keys ...string) (int, bool) {
|
||||
return paramInt(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamIntDefault(def int, keys ...string) int {
|
||||
if v, ok := cfg.ParamInt(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamDuration(keys ...string) (time.Duration, bool) {
|
||||
return paramDuration(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamDurationDefault(def time.Duration, keys ...string) time.Duration {
|
||||
if v, ok := cfg.ParamDuration(keys...); ok {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func (cfg SinkConfig) ParamStringSlice(keys ...string) ([]string, bool) {
|
||||
return paramStringSlice(cfg.Params, keys...)
|
||||
}
|
||||
|
||||
// ---- shared implementations (package-private) ----
|
||||
|
||||
func paramAny(params map[string]any, keys ...string) (any, bool) {
|
||||
if params == nil {
|
||||
return nil, false
|
||||
}
|
||||
for _, k := range keys {
|
||||
v, ok := cfg.Params[k]
|
||||
v, ok := params[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func paramString(params map[string]any, keys ...string) (string, bool) {
|
||||
for _, k := range keys {
|
||||
if params == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := params[k]
|
||||
if !ok || v == nil {
|
||||
continue
|
||||
}
|
||||
@@ -62,11 +187,213 @@ func (cfg SinkConfig) ParamString(keys ...string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ParamStringDefault returns ParamString(keys...) if present; otherwise it returns def.
|
||||
// Symmetric helper for sink implementations.
|
||||
func (cfg SinkConfig) ParamStringDefault(def string, keys ...string) string {
|
||||
if s, ok := cfg.ParamString(keys...); ok {
|
||||
return s
|
||||
func paramBool(params map[string]any, keys ...string) (bool, bool) {
|
||||
v, ok := paramAny(params, keys...)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return false, false
|
||||
}
|
||||
parsed, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
return parsed, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
return strings.TrimSpace(def)
|
||||
}
|
||||
|
||||
func paramInt(params map[string]any, keys ...string) (int, bool) {
|
||||
v, ok := paramAny(params, keys...)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
return t, true
|
||||
case int8:
|
||||
return int(t), true
|
||||
case int16:
|
||||
return int(t), true
|
||||
case int32:
|
||||
return int(t), true
|
||||
case int64:
|
||||
return int(t), true
|
||||
|
||||
case uint:
|
||||
return int(t), true
|
||||
case uint8:
|
||||
return int(t), true
|
||||
case uint16:
|
||||
return int(t), true
|
||||
case uint32:
|
||||
return int(t), true
|
||||
case uint64:
|
||||
return int(t), true
|
||||
|
||||
case float32:
|
||||
f := float64(t)
|
||||
if math.IsNaN(f) || math.IsInf(f, 0) {
|
||||
return 0, false
|
||||
}
|
||||
if math.Trunc(f) != f {
|
||||
return 0, false
|
||||
}
|
||||
return int(f), true
|
||||
|
||||
case float64:
|
||||
if math.IsNaN(t) || math.IsInf(t, 0) {
|
||||
return 0, false
|
||||
}
|
||||
if math.Trunc(t) != t {
|
||||
return 0, false
|
||||
}
|
||||
return int(t), true
|
||||
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func paramDuration(params map[string]any, keys ...string) (time.Duration, bool) {
|
||||
v, ok := paramAny(params, keys...)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch t := v.(type) {
|
||||
case time.Duration:
|
||||
if t <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return t, true
|
||||
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
// Numeric strings are interpreted as seconds (see doc comment).
|
||||
if isAllDigits(s) {
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil || n <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return time.Duration(n) * time.Second, true
|
||||
}
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil || d <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return d, true
|
||||
|
||||
case int:
|
||||
if t <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return time.Duration(t) * time.Second, true
|
||||
case int64:
|
||||
if t <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return time.Duration(t) * time.Second, true
|
||||
case float64:
|
||||
if math.IsNaN(t) || math.IsInf(t, 0) || t <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
// Allow fractional seconds.
|
||||
secs := t * float64(time.Second)
|
||||
return time.Duration(secs), true
|
||||
case float32:
|
||||
f := float64(t)
|
||||
if math.IsNaN(f) || math.IsInf(f, 0) || f <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
secs := f * float64(time.Second)
|
||||
return time.Duration(secs), true
|
||||
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func paramStringSlice(params map[string]any, keys ...string) ([]string, bool) {
|
||||
v, ok := paramAny(params, keys...)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
clean := func(items []string) ([]string, bool) {
|
||||
out := make([]string, 0, len(items))
|
||||
for _, it := range items {
|
||||
it = strings.TrimSpace(it)
|
||||
if it == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, it)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return clean(t)
|
||||
|
||||
case []any:
|
||||
tmp := make([]string, 0, len(t))
|
||||
for _, it := range t {
|
||||
s, ok := it.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tmp = append(tmp, s)
|
||||
}
|
||||
return clean(tmp)
|
||||
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return nil, false
|
||||
}
|
||||
if strings.Contains(s, ",") {
|
||||
parts := strings.Split(s, ",")
|
||||
return clean(parts)
|
||||
}
|
||||
return clean([]string{s})
|
||||
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func isAllDigits(s string) bool {
|
||||
for _, r := range s {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
139
config/validate_test.go
Normal file
139
config/validate_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValidate_RouteKindsEmptyIsAllowed(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{Name: "src1", Driver: "driver1", Every: Duration{Duration: time.Minute}},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
Routes: []RouteConfig{
|
||||
{Sink: "sink1", Kinds: nil}, // omitted
|
||||
{Sink: "sink1", Kinds: []string{}}, // explicit empty
|
||||
},
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_RouteKindsRejectsBlankEntries(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{Name: "src1", Driver: "driver1", Every: Duration{Duration: time.Minute}},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
Routes: []RouteConfig{
|
||||
{Sink: "sink1", Kinds: []string{"observation", " ", "alert"}},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "routes[0].kinds[1]") {
|
||||
t.Fatalf("expected error to mention blank kind entry, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_SourceModePollRequiresEvery(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{Name: "src1", Driver: "driver1", Mode: SourceModePoll},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `sources[0].every`) {
|
||||
t.Fatalf("expected error to mention sources[0].every, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_SourceModeStreamRejectsEvery(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{
|
||||
Name: "src1",
|
||||
Driver: "driver1",
|
||||
Mode: SourceModeStream,
|
||||
Every: Duration{Duration: time.Minute},
|
||||
},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `sources[0].every`) {
|
||||
t.Fatalf("expected error to mention sources[0].every, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{
|
||||
Name: "src1",
|
||||
Driver: "driver1",
|
||||
Mode: SourceMode("batch"),
|
||||
Every: Duration{Duration: time.Minute},
|
||||
},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `sources[0].mode`) {
|
||||
t.Fatalf("expected error to mention sources[0].mode, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Sources: []SourceConfig{
|
||||
{
|
||||
Name: "src1",
|
||||
Driver: "driver1",
|
||||
Every: Duration{Duration: time.Minute},
|
||||
Kinds: []string{"observation", " "},
|
||||
},
|
||||
},
|
||||
Sinks: []SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
}
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `sources[0].kinds[1]`) {
|
||||
t.Fatalf("expected error to mention sources[0].kinds[1], got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -6,10 +6,16 @@ import (
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/logging"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/pipeline"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/sinks"
|
||||
)
|
||||
|
||||
// Logger is a printf-style logger used throughout dispatch.
|
||||
// It is an alias to the shared feedkit logging type so callers can pass
|
||||
// one function everywhere without type mismatch friction.
|
||||
type Logger = logging.Logf
|
||||
|
||||
type Dispatcher struct {
|
||||
In <-chan event.Event
|
||||
|
||||
@@ -35,8 +41,6 @@ type Route struct {
|
||||
Kinds map[event.Kind]bool
|
||||
}
|
||||
|
||||
type Logger func(format string, args ...any)
|
||||
|
||||
func (d *Dispatcher) Run(ctx context.Context, logf Logger) error {
|
||||
if d.In == nil {
|
||||
return fmt.Errorf("dispatcher.Run: In channel is nil")
|
||||
|
||||
89
dispatch/routes.go
Normal file
89
dispatch/routes.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// CompileRoutes converts config.Config routes into dispatch.Route rules.
|
||||
//
|
||||
// Behavior:
|
||||
// - If cfg.Routes is empty, we default to "all sinks receive all kinds".
|
||||
// (Implemented as one Route per sink with Kinds == nil.)
|
||||
// - If a specific route's kinds: is omitted or empty, that route matches ALL kinds.
|
||||
// (Also compiled as Kinds == nil.)
|
||||
// - Kind strings are normalized via event.ParseKind (lowercase + trim).
|
||||
//
|
||||
// Note: config.Validate() ensures route.sink references a known sink and rejects
|
||||
// blank kind entries. We re-check a few invariants here anyway so CompileRoutes
|
||||
// is safe to call even if a daemon chooses not to call Validate().
|
||||
func CompileRoutes(cfg *config.Config) ([]Route, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: cfg is nil")
|
||||
}
|
||||
|
||||
if len(cfg.Sinks) == 0 {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: cfg has no sinks")
|
||||
}
|
||||
|
||||
// Build a quick lookup of sink names (exact match; no normalization).
|
||||
sinkNames := make(map[string]bool, len(cfg.Sinks))
|
||||
for i, s := range cfg.Sinks {
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: sinks[%d].name is empty", i)
|
||||
}
|
||||
sinkNames[s.Name] = true
|
||||
}
|
||||
|
||||
// Default routing: everything to every sink.
|
||||
if len(cfg.Routes) == 0 {
|
||||
out := make([]Route, 0, len(cfg.Sinks))
|
||||
for _, s := range cfg.Sinks {
|
||||
out = append(out, Route{
|
||||
SinkName: s.Name,
|
||||
Kinds: nil, // nil/empty map means "all kinds"
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
out := make([]Route, 0, len(cfg.Routes))
|
||||
|
||||
for i, r := range cfg.Routes {
|
||||
sink := r.Sink
|
||||
if strings.TrimSpace(sink) == "" {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].sink is required", i)
|
||||
}
|
||||
if !sinkNames[sink] {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].sink references unknown sink %q", i, sink)
|
||||
}
|
||||
|
||||
// If kinds is omitted/empty, this route matches all kinds.
|
||||
if len(r.Kinds) == 0 {
|
||||
out = append(out, Route{
|
||||
SinkName: sink,
|
||||
Kinds: nil,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
kinds := make(map[event.Kind]bool, len(r.Kinds))
|
||||
for j, raw := range r.Kinds {
|
||||
k, err := event.ParseKind(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dispatch.CompileRoutes: routes[%d].kinds[%d]: %w", i, j, err)
|
||||
}
|
||||
kinds[k] = true
|
||||
}
|
||||
|
||||
out = append(out, Route{
|
||||
SinkName: sink,
|
||||
Kinds: kinds,
|
||||
})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
67
dispatch/routes_test.go
Normal file
67
dispatch/routes_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package dispatch
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
|
||||
func TestCompileRoutes_DefaultIsAllSinksAllKinds(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sinks: []config.SinkConfig{
|
||||
{Name: "a", Driver: "stdout"},
|
||||
{Name: "b", Driver: "stdout"},
|
||||
},
|
||||
// Routes omitted => default
|
||||
}
|
||||
|
||||
routes, err := CompileRoutes(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CompileRoutes error: %v", err)
|
||||
}
|
||||
if len(routes) != 2 {
|
||||
t.Fatalf("expected 2 routes, got %d", len(routes))
|
||||
}
|
||||
|
||||
// Order should match cfg.Sinks order (deterministic).
|
||||
if routes[0].SinkName != "a" || routes[1].SinkName != "b" {
|
||||
t.Fatalf("unexpected route order: %+v", routes)
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if len(r.Kinds) != 0 {
|
||||
t.Fatalf("expected nil/empty kinds for default routes, got: %+v", r.Kinds)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileRoutes_EmptyKindsMeansAllKinds(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Sinks: []config.SinkConfig{
|
||||
{Name: "sink1", Driver: "stdout"},
|
||||
},
|
||||
Routes: []config.RouteConfig{
|
||||
{Sink: "sink1"}, // omitted kinds
|
||||
{Sink: "sink1", Kinds: nil}, // explicit nil
|
||||
{Sink: "sink1", Kinds: []string{}}, // explicit empty
|
||||
},
|
||||
}
|
||||
|
||||
routes, err := CompileRoutes(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CompileRoutes error: %v", err)
|
||||
}
|
||||
|
||||
if len(routes) != 3 {
|
||||
t.Fatalf("expected 3 routes, got %d", len(routes))
|
||||
}
|
||||
|
||||
for i, r := range routes {
|
||||
if r.SinkName != "sink1" {
|
||||
t.Fatalf("route[%d] unexpected sink: %q", i, r.SinkName)
|
||||
}
|
||||
if len(r.Kinds) != 0 {
|
||||
t.Fatalf("route[%d] expected nil/empty kinds (match all), got: %+v", i, r.Kinds)
|
||||
}
|
||||
}
|
||||
}
|
||||
57
doc.go
Normal file
57
doc.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Package feedkit provides a high-level map of the feedkit package set.
|
||||
//
|
||||
// Most real applications do not import the root package directly. Instead, they
|
||||
// compose the subpackages that handle configuration, collection, processing,
|
||||
// routing, and sinks.
|
||||
//
|
||||
// The usual flow through feedkit is:
|
||||
//
|
||||
// Collect -> Process -> Route -> Emit
|
||||
//
|
||||
// That flow maps to packages like this:
|
||||
//
|
||||
// - config
|
||||
// Loads and validates daemon config. This package owns domain-agnostic
|
||||
// config shape and consistency checks.
|
||||
//
|
||||
// - event
|
||||
// Defines the event.Event envelope shared across sources, processors,
|
||||
// dispatch, and sinks.
|
||||
//
|
||||
// - sources
|
||||
// Defines polling and streaming source interfaces, the source registry, and
|
||||
// reusable source helpers.
|
||||
//
|
||||
// - scheduler
|
||||
// Runs configured sources on a cadence and supervises long-lived stream
|
||||
// workers with restart/fatal handling.
|
||||
//
|
||||
// - processors
|
||||
// Defines the generic processor interface and registry used to build
|
||||
// ordered processor chains.
|
||||
//
|
||||
// - processors/dedupe
|
||||
// Built-in in-memory dedupe processor keyed by Event.ID.
|
||||
//
|
||||
// - processors/normalize
|
||||
// Built-in normalization processor plus helper APIs for raw-to-canonical
|
||||
// event mapping.
|
||||
//
|
||||
// - pipeline
|
||||
// Applies an ordered processor chain between collection and dispatch.
|
||||
//
|
||||
// - dispatch
|
||||
// Compiles routes and fans events out to sinks with per-sink isolation.
|
||||
//
|
||||
// - sinks
|
||||
// Defines sink interfaces, the sink registry, schema-free built-in sinks,
|
||||
// and explicit Postgres factory helpers.
|
||||
//
|
||||
// feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds,
|
||||
// upstream-specific parsing, and daemon lifecycle remain the responsibility of
|
||||
// each concrete application.
|
||||
//
|
||||
// For repository-level overview and usage narrative, see README.md. For
|
||||
// code-level details, each subpackage doc.go is the source of truth for that
|
||||
// package's public API surface and optional helpers.
|
||||
package feedkit
|
||||
14
go.mod
14
go.mod
@@ -2,4 +2,16 @@ module gitea.maximumdirect.net/ejr/feedkit
|
||||
|
||||
go 1.22
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1
|
||||
require (
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/nats-io/nats.go v1.34.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.17.2 // indirect
|
||||
github.com/nats-io/nkeys v0.4.7 // indirect
|
||||
github.com/nats-io/nuid v1.0.1 // indirect
|
||||
golang.org/x/crypto v0.18.0 // indirect
|
||||
golang.org/x/sys v0.16.0 // indirect
|
||||
)
|
||||
|
||||
15
go.sum
15
go.sum
@@ -1,3 +1,18 @@
|
||||
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
|
||||
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/nats-io/nats.go v1.34.0 h1:fnxnPCNiwIG5w08rlMcEKTUw4AV/nKyGCOJE8TdhSPk=
|
||||
github.com/nats-io/nats.go v1.34.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
|
||||
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
|
||||
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
|
||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
67
internal/postgres/postgres.go
Normal file
67
internal/postgres/postgres.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const initTimeout = 5 * time.Second
|
||||
|
||||
var (
|
||||
sqlOpen = sql.Open
|
||||
pingDB = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) }
|
||||
)
|
||||
|
||||
// ConnConfig describes the minimal connection settings shared by feedkit's
|
||||
// Postgres readers and writers.
|
||||
type ConnConfig struct {
|
||||
URI string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// BuildDSN validates a Postgres URI and injects credentials into it.
|
||||
func BuildDSN(cfg ConnConfig) (string, error) {
|
||||
u, err := url.Parse(strings.TrimSpace(cfg.URI))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uri: %w", err)
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing scheme")
|
||||
}
|
||||
if u.Host == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing host")
|
||||
}
|
||||
u.User = url.UserPassword(cfg.Username, cfg.Password)
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// Open builds a DSN, opens a database handle, and verifies connectivity with a
|
||||
// bounded ping before returning the handle.
|
||||
func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) {
|
||||
dsn, err := BuildDSN(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := sqlOpen("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, initTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := pingDB(pingCtx, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
165
internal/postgres/postgres_test.go
Normal file
165
internal/postgres/postgres_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func withPostgresPackageTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldSQLOpen := sqlOpen
|
||||
oldPingDB := pingDB
|
||||
t.Cleanup(func() {
|
||||
sqlOpen = oldSQLOpen
|
||||
pingDB = oldPingDB
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildDSNInjectsCredentials(t *testing.T) {
|
||||
dsn, err := BuildDSN(ConnConfig{
|
||||
URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ",
|
||||
Username: "app_user",
|
||||
Password: "app_pass",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildDSN() error = %v", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
if u.User == nil || u.User.Username() != "app_user" {
|
||||
t.Fatalf("username = %q, want app_user", u.User.Username())
|
||||
}
|
||||
pass, ok := u.User.Password()
|
||||
if !ok || pass != "app_pass" {
|
||||
t.Fatalf("password = %q, want app_pass", pass)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsInvalidURI(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid uri") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsMissingScheme(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing scheme") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsMissingHost(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing host") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPropagatesOpenFailure(t *testing.T) {
|
||||
withPostgresPackageTestState(t)
|
||||
|
||||
sqlOpen = func(_, _ string) (*sql.DB, error) {
|
||||
return nil, errors.New("open failed")
|
||||
}
|
||||
|
||||
_, err := Open(context.Background(), ConnConfig{
|
||||
URI: "postgres://db.example.local/feedkit",
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Open() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "open failed") {
|
||||
t.Fatalf("Open() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPropagatesPingFailure(t *testing.T) {
|
||||
withPostgresPackageTestState(t)
|
||||
|
||||
const driverName = "feedkit_internal_postgres_ping_fail"
|
||||
registerPingTestDriver(driverName, errors.New("ping failed"))
|
||||
|
||||
sqlOpen = func(_, _ string) (*sql.DB, error) {
|
||||
return sql.Open(driverName, "")
|
||||
}
|
||||
|
||||
_, err := Open(context.Background(), ConnConfig{
|
||||
URI: "postgres://db.example.local/feedkit",
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Open() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ping failed") {
|
||||
t.Fatalf("Open() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
pingDriverMu sync.Mutex
|
||||
pingDriverSeen = map[string]bool{}
|
||||
)
|
||||
|
||||
func registerPingTestDriver(name string, pingErr error) {
|
||||
pingDriverMu.Lock()
|
||||
defer pingDriverMu.Unlock()
|
||||
|
||||
if pingDriverSeen[name] {
|
||||
return
|
||||
}
|
||||
sql.Register(name, &pingTestDriver{pingErr: pingErr})
|
||||
pingDriverSeen[name] = true
|
||||
}
|
||||
|
||||
type pingTestDriver struct {
|
||||
pingErr error
|
||||
}
|
||||
|
||||
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
|
||||
return &pingTestConn{pingErr: d.pingErr}, nil
|
||||
}
|
||||
|
||||
type pingTestConn struct {
|
||||
pingErr error
|
||||
}
|
||||
|
||||
func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (c *pingTestConn) Close() error { return nil }
|
||||
func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
|
||||
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }
|
||||
|
||||
func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
|
||||
return &pingTestRows{}, nil
|
||||
}
|
||||
|
||||
type pingTestRows struct{}
|
||||
|
||||
func (r *pingTestRows) Columns() []string { return []string{"ok"} }
|
||||
func (r *pingTestRows) Close() error { return nil }
|
||||
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }
|
||||
8
logging/logging.go
Normal file
8
logging/logging.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package logging
|
||||
|
||||
// Logf is the shared printf-style logger signature used across feedkit.
|
||||
//
|
||||
// Keeping this in one place avoids the "scheduler.Logger vs dispatch.Logger"
|
||||
// friction and makes it trivial for downstream apps to pass a single log
|
||||
// function throughout the system.
|
||||
type Logf func(format string, args ...any)
|
||||
@@ -1,5 +0,0 @@
|
||||
package pipeline
|
||||
|
||||
// Placeholder for dedupe processor:
|
||||
// - key by Event.ID or computed key
|
||||
// - in-memory store first; later optional Postgres-backed
|
||||
@@ -5,15 +5,11 @@ import (
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/processors"
|
||||
)
|
||||
|
||||
// Processor can mutate/drop events (dedupe, rate-limit, normalization tweaks).
|
||||
type Processor interface {
|
||||
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
|
||||
}
|
||||
|
||||
type Pipeline struct {
|
||||
Processors []Processor
|
||||
Processors []processors.Processor
|
||||
}
|
||||
|
||||
func (p *Pipeline) Process(ctx context.Context, e event.Event) (*event.Event, error) {
|
||||
|
||||
115
pipeline/pipeline_test.go
Normal file
115
pipeline/pipeline_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package pipeline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/processors"
|
||||
)
|
||||
|
||||
type procFunc func(context.Context, event.Event) (*event.Event, error)
|
||||
|
||||
func (f procFunc) Process(ctx context.Context, in event.Event) (*event.Event, error) {
|
||||
return f(ctx, in)
|
||||
}
|
||||
|
||||
func TestPipelineProcessSequentialOrder(t *testing.T) {
|
||||
var gotOrder []string
|
||||
|
||||
p := &Pipeline{
|
||||
Processors: []processors.Processor{
|
||||
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
gotOrder = append(gotOrder, "first")
|
||||
out := in
|
||||
out.Schema = "stage.one.v1"
|
||||
return &out, nil
|
||||
}),
|
||||
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
gotOrder = append(gotOrder, "second")
|
||||
if in.Schema != "stage.one.v1" {
|
||||
return nil, fmt.Errorf("expected schema from first stage, got %q", in.Schema)
|
||||
}
|
||||
out := in
|
||||
out.Schema = "stage.two.v1"
|
||||
return &out, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
out, err := p.Process(context.Background(), validEvent())
|
||||
if err != nil {
|
||||
t.Fatalf("Process error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected output event, got nil")
|
||||
}
|
||||
if out.Schema != "stage.two.v1" {
|
||||
t.Fatalf("unexpected output schema: %q", out.Schema)
|
||||
}
|
||||
if strings.Join(gotOrder, ",") != "first,second" {
|
||||
t.Fatalf("unexpected processor order: %v", gotOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineProcessInvalidInput(t *testing.T) {
|
||||
p := &Pipeline{}
|
||||
_, err := p.Process(context.Background(), event.Event{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected input validation error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid input event") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineProcessDrop(t *testing.T) {
|
||||
p := &Pipeline{
|
||||
Processors: []processors.Processor{
|
||||
procFunc(func(context.Context, event.Event) (*event.Event, error) {
|
||||
return nil, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
out, err := p.Process(context.Background(), validEvent())
|
||||
if err != nil {
|
||||
t.Fatalf("Process error: %v", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected nil output for dropped event, got %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineProcessInvalidOutput(t *testing.T) {
|
||||
p := &Pipeline{
|
||||
Processors: []processors.Processor{
|
||||
procFunc(func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
out := in
|
||||
out.Payload = nil
|
||||
return &out, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := p.Process(context.Background(), validEvent())
|
||||
if err == nil {
|
||||
t.Fatalf("expected output validation error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid output event") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func validEvent() event.Event {
|
||||
return event.Event{
|
||||
ID: "evt-1",
|
||||
Kind: event.Kind("observation"),
|
||||
Source: "source-1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"ok": true},
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package pipeline
|
||||
|
||||
// Placeholder for rate limit processor:
|
||||
// - per source/kind sink routing limits
|
||||
// - cooldown windows
|
||||
28
processors/dedupe/doc.go
Normal file
28
processors/dedupe/doc.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Package dedupe provides a default in-memory LRU deduplication processor.
|
||||
//
|
||||
// The processor keys strictly by event.Event.ID:
|
||||
// - first-seen IDs pass through
|
||||
// - repeated IDs are dropped
|
||||
//
|
||||
// The in-memory seen-ID set is bounded by a required maxEntries capacity.
|
||||
// When capacity is exceeded, the least recently used ID is evicted.
|
||||
//
|
||||
// Typical registry wiring:
|
||||
//
|
||||
// ```go
|
||||
// reg := processors.NewRegistry()
|
||||
// reg.Register("dedupe", dedupe.Factory(10_000))
|
||||
//
|
||||
// reg.Register("normalize", func() (processors.Processor, error) {
|
||||
// return normalize.NewProcessor(myNormalizers, false), nil
|
||||
// })
|
||||
//
|
||||
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
|
||||
//
|
||||
// if err != nil {
|
||||
// // handle wiring error
|
||||
// }
|
||||
//
|
||||
// p := &pipeline.Pipeline{Processors: chain}
|
||||
// ```
|
||||
package dedupe
|
||||
89
processors/dedupe/processor.go
Normal file
89
processors/dedupe/processor.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dedupe
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/processors"
|
||||
)
|
||||
|
||||
// Processor drops duplicate events by Event.ID using an in-memory LRU.
|
||||
type Processor struct {
|
||||
maxEntries int
|
||||
|
||||
mu sync.Mutex
|
||||
order *list.List // most-recent at front, least-recent at back
|
||||
byID map[string]*list.Element // id -> list element (element.Value is string id)
|
||||
}
|
||||
|
||||
var _ processors.Processor = (*Processor)(nil)
|
||||
|
||||
// NewProcessor constructs a dedupe processor with a required max entry count.
|
||||
func NewProcessor(maxEntries int) (*Processor, error) {
|
||||
if maxEntries <= 0 {
|
||||
return nil, fmt.Errorf("dedupe: maxEntries must be > 0, got %d", maxEntries)
|
||||
}
|
||||
|
||||
return &Processor{
|
||||
maxEntries: maxEntries,
|
||||
order: list.New(),
|
||||
byID: make(map[string]*list.Element, maxEntries),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Factory returns a processors.Factory that constructs Processor instances.
|
||||
func Factory(maxEntries int) processors.Factory {
|
||||
return func() (processors.Processor, error) {
|
||||
return NewProcessor(maxEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// Process implements processors.Processor.
|
||||
func (p *Processor) Process(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("dedupe: processor is nil")
|
||||
}
|
||||
if p.maxEntries <= 0 {
|
||||
return nil, fmt.Errorf("dedupe: processor maxEntries must be > 0")
|
||||
}
|
||||
|
||||
id := strings.TrimSpace(in.ID)
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("dedupe: event ID is required")
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
if p.order == nil || p.byID == nil {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("dedupe: processor is not initialized")
|
||||
}
|
||||
|
||||
if elem, exists := p.byID[id]; exists {
|
||||
p.order.MoveToFront(elem)
|
||||
p.mu.Unlock()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
elem := p.order.PushFront(id)
|
||||
p.byID[id] = elem
|
||||
|
||||
if p.order.Len() > p.maxEntries {
|
||||
oldest := p.order.Back()
|
||||
if oldest != nil {
|
||||
p.order.Remove(oldest)
|
||||
if oldestID, ok := oldest.Value.(string); ok {
|
||||
delete(p.byID, oldestID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
out := in
|
||||
return &out, nil
|
||||
}
|
||||
163
processors/dedupe/processor_test.go
Normal file
163
processors/dedupe/processor_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package dedupe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/processors"
|
||||
)
|
||||
|
||||
func TestNewProcessorValidation(t *testing.T) {
|
||||
t.Run("rejects non-positive maxEntries", func(t *testing.T) {
|
||||
for _, maxEntries := range []int{0, -1} {
|
||||
p, err := NewProcessor(maxEntries)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for maxEntries=%d, got nil", maxEntries)
|
||||
}
|
||||
if p != nil {
|
||||
t.Fatalf("expected nil processor for maxEntries=%d", maxEntries)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "maxEntries") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accepts positive maxEntries", func(t *testing.T) {
|
||||
p, err := NewProcessor(1)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProcessor error: %v", err)
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatalf("expected processor, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessorFirstSeenAndDuplicate(t *testing.T) {
|
||||
p, err := NewProcessor(8)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProcessor error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
first := testEvent("evt-1")
|
||||
|
||||
out, err := p.Process(ctx, first)
|
||||
if err != nil {
|
||||
t.Fatalf("Process first error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected first event to pass through")
|
||||
}
|
||||
if out.ID != first.ID {
|
||||
t.Fatalf("expected unchanged ID %q, got %q", first.ID, out.ID)
|
||||
}
|
||||
|
||||
out, err = p.Process(ctx, first)
|
||||
if err != nil {
|
||||
t.Fatalf("Process duplicate error: %v", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected duplicate to be dropped, got %#v", out)
|
||||
}
|
||||
|
||||
out, err = p.Process(ctx, testEvent("evt-2"))
|
||||
if err != nil {
|
||||
t.Fatalf("Process second unique error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected second unique event to pass through")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessorLRUEvictionAndPromotion(t *testing.T) {
|
||||
p, err := NewProcessor(2)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProcessor error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
mustPass(t, p, ctx, "a")
|
||||
mustPass(t, p, ctx, "b")
|
||||
mustDrop(t, p, ctx, "a") // promote "a" so "b" becomes least-recently-used
|
||||
mustPass(t, p, ctx, "c") // evicts "b"
|
||||
mustDrop(t, p, ctx, "a") // "a" should still be tracked after promotion
|
||||
mustPass(t, p, ctx, "b") // "b" was evicted, so now it passes again
|
||||
}
|
||||
|
||||
func TestProcessorRejectsBlankID(t *testing.T) {
|
||||
p, err := NewProcessor(4)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProcessor error: %v", err)
|
||||
}
|
||||
|
||||
in := testEvent(" ")
|
||||
out, err := p.Process(context.Background(), in)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for blank ID")
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected nil output on error, got %#v", out)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "event ID is required") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactoryWithRegistry(t *testing.T) {
|
||||
r := processors.NewRegistry()
|
||||
r.Register("dedupe", Factory(3))
|
||||
|
||||
p, err := r.Build("dedupe")
|
||||
if err != nil {
|
||||
t.Fatalf("Build error: %v", err)
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatalf("expected processor, got nil")
|
||||
}
|
||||
|
||||
out, err := p.Process(context.Background(), testEvent("evt-factory-1"))
|
||||
if err != nil {
|
||||
t.Fatalf("Process error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected first event to pass through")
|
||||
}
|
||||
}
|
||||
|
||||
func mustPass(t *testing.T, p *Processor, ctx context.Context, id string) {
|
||||
t.Helper()
|
||||
out, err := p.Process(ctx, testEvent(id))
|
||||
if err != nil {
|
||||
t.Fatalf("expected pass for id=%q, got error: %v", id, err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected pass for id=%q, got drop", id)
|
||||
}
|
||||
}
|
||||
|
||||
func mustDrop(t *testing.T, p *Processor, ctx context.Context, id string) {
|
||||
t.Helper()
|
||||
out, err := p.Process(ctx, testEvent(id))
|
||||
if err != nil {
|
||||
t.Fatalf("expected drop for id=%q, got error: %v", id, err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected drop for id=%q, got output", id)
|
||||
}
|
||||
}
|
||||
|
||||
func testEvent(id string) event.Event {
|
||||
return event.Event{
|
||||
ID: id,
|
||||
Kind: event.Kind("observation"),
|
||||
Source: "source-1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"ok": true},
|
||||
}
|
||||
}
|
||||
24
processors/doc.go
Normal file
24
processors/doc.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Package processors defines feedkit's generic processor abstraction and registry.
|
||||
//
|
||||
// Processors are optional pipeline stages that can transform, drop, or reject
|
||||
// events before dispatch to sinks.
|
||||
//
|
||||
// Registry provides name-based construction so daemons can assemble processor
|
||||
// chains without embedding switch statements in wiring code.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// reg := processors.NewRegistry()
|
||||
// reg.Register("dedupe", dedupe.Factory(10_000))
|
||||
// reg.Register("normalize", func() (processors.Processor, error) {
|
||||
// // import "gitea.maximumdirect.net/ejr/feedkit/processors/normalize"
|
||||
// return normalize.NewProcessor(myNormalizers, false), nil
|
||||
// })
|
||||
//
|
||||
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
|
||||
// if err != nil {
|
||||
// // handle wiring error
|
||||
// }
|
||||
//
|
||||
// p := &pipeline.Pipeline{Processors: chain}
|
||||
package processors
|
||||
19
processors/normalize/doc.go
Normal file
19
processors/normalize/doc.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Package normalize provides the feedkit normalization processor and related
|
||||
// helper APIs for raw-to-canonical event mapping.
|
||||
//
|
||||
// External API surface:
|
||||
// - Processor: concrete processors.Processor implementation
|
||||
// - Normalizer / Func: normalization interface and ergonomic function adapter
|
||||
//
|
||||
// Optional helpers from helpers.go:
|
||||
// - PayloadJSONBytes: extract supported JSON-shaped payloads into bytes
|
||||
// - DecodeJSONPayload: decode an event payload into a typed struct
|
||||
// - FinalizeEvent: copy the input event envelope onto a normalized output
|
||||
//
|
||||
// Typical usage:
|
||||
// sources emit raw events (often with json.RawMessage payloads), then a
|
||||
// normalize.Processor converts matching raw schemas into canonical payloads.
|
||||
//
|
||||
// Key property: normalization is optional.
|
||||
// If no Normalizer matches an event, Processor passes it through unchanged by default.
|
||||
package normalize
|
||||
84
processors/normalize/helpers.go
Normal file
84
processors/normalize/helpers.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package normalize
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// PayloadJSONBytes extracts a JSON payload into bytes suitable for json.Unmarshal.
|
||||
//
|
||||
// Supported payload shapes:
|
||||
// - json.RawMessage
|
||||
// - []byte
|
||||
// - string
|
||||
// - map[string]any
|
||||
func PayloadJSONBytes(e event.Event) ([]byte, error) {
|
||||
if e.Payload == nil {
|
||||
return nil, fmt.Errorf("payload is nil")
|
||||
}
|
||||
|
||||
switch v := e.Payload.(type) {
|
||||
case json.RawMessage:
|
||||
if len(v) == 0 {
|
||||
return nil, fmt.Errorf("payload is empty json.RawMessage")
|
||||
}
|
||||
return []byte(v), nil
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, fmt.Errorf("payload is empty []byte")
|
||||
}
|
||||
return v, nil
|
||||
case string:
|
||||
if v == "" {
|
||||
return nil, fmt.Errorf("payload is empty string")
|
||||
}
|
||||
return []byte(v), nil
|
||||
case map[string]any:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal map payload: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported payload type %T", e.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeJSONPayload extracts the event payload as bytes and unmarshals it into T.
|
||||
func DecodeJSONPayload[T any](in event.Event) (T, error) {
|
||||
var zero T
|
||||
|
||||
b, err := PayloadJSONBytes(in)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("extract payload: %w", err)
|
||||
}
|
||||
|
||||
var parsed T
|
||||
if err := json.Unmarshal(b, &parsed); err != nil {
|
||||
return zero, fmt.Errorf("decode raw payload: %w", err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// FinalizeEvent builds the output event envelope by copying the input and applying
|
||||
// the new schema/payload, plus optional EffectiveAt.
|
||||
func FinalizeEvent(in event.Event, outSchema string, outPayload any, effectiveAt time.Time) (*event.Event, error) {
|
||||
out := in
|
||||
out.Schema = outSchema
|
||||
out.Payload = outPayload
|
||||
|
||||
if !effectiveAt.IsZero() {
|
||||
t := effectiveAt.UTC()
|
||||
out.EffectiveAt = &t
|
||||
}
|
||||
|
||||
if err := out.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
118
processors/normalize/helpers_test.go
Normal file
118
processors/normalize/helpers_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package normalize
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
func TestPayloadJSONBytesSupportedShapes(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payload any
|
||||
want string
|
||||
}{
|
||||
{name: "rawmessage", payload: json.RawMessage(`{"a":1}`), want: `{"a":1}`},
|
||||
{name: "bytes", payload: []byte(`{"a":2}`), want: `{"a":2}`},
|
||||
{name: "string", payload: `{"a":3}`, want: `{"a":3}`},
|
||||
{name: "map", payload: map[string]any{"a": 4}, want: `{"a":4}`},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
|
||||
if err != nil {
|
||||
t.Fatalf("PayloadJSONBytes() unexpected error: %v", err)
|
||||
}
|
||||
if string(got) != tc.want {
|
||||
t.Fatalf("PayloadJSONBytes() = %s, want %s", string(got), tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPayloadJSONBytesRejectsInvalidPayloads(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payload any
|
||||
want string
|
||||
}{
|
||||
{name: "nil", payload: nil, want: "payload is nil"},
|
||||
{name: "empty rawmessage", payload: json.RawMessage{}, want: "payload is empty json.RawMessage"},
|
||||
{name: "empty bytes", payload: []byte{}, want: "payload is empty []byte"},
|
||||
{name: "empty string", payload: "", want: "payload is empty string"},
|
||||
{name: "unsupported", payload: 123, want: "unsupported payload type"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
|
||||
if err == nil {
|
||||
t.Fatalf("PayloadJSONBytes() expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("PayloadJSONBytes() error = %q, want substring %q", err, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeJSONPayload(t *testing.T) {
|
||||
type payload struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
got, err := DecodeJSONPayload[payload](event.Event{
|
||||
Payload: json.RawMessage(`{"name":"alice"}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("DecodeJSONPayload() unexpected error: %v", err)
|
||||
}
|
||||
if got.Name != "alice" {
|
||||
t.Fatalf("DecodeJSONPayload() = %#v, want name alice", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeEventPreservesEnvelopeAndEffectiveAtBehavior(t *testing.T) {
|
||||
existingEffectiveAt := time.Date(2026, 3, 28, 11, 0, 0, 0, time.UTC)
|
||||
in := event.Event{
|
||||
ID: "evt-1",
|
||||
Kind: event.Kind("observation"),
|
||||
Source: "source-a",
|
||||
EmittedAt: time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC),
|
||||
EffectiveAt: &existingEffectiveAt,
|
||||
Schema: "raw.example.v1",
|
||||
Payload: map[string]any{"old": true},
|
||||
}
|
||||
|
||||
out, err := FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, time.Time{})
|
||||
if err != nil {
|
||||
t.Fatalf("FinalizeEvent() unexpected error: %v", err)
|
||||
}
|
||||
if out.ID != in.ID || out.Kind != in.Kind || out.Source != in.Source || out.EmittedAt != in.EmittedAt {
|
||||
t.Fatalf("FinalizeEvent() changed preserved envelope fields: %#v", out)
|
||||
}
|
||||
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(existingEffectiveAt) {
|
||||
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want preserved existing value", out.EffectiveAt)
|
||||
}
|
||||
|
||||
nextEffectiveAt := time.Date(2026, 3, 28, 13, 0, 0, 0, time.FixedZone("x", -4*3600))
|
||||
out, err = FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, nextEffectiveAt)
|
||||
if err != nil {
|
||||
t.Fatalf("FinalizeEvent() unexpected overwrite error: %v", err)
|
||||
}
|
||||
if out.EffectiveAt == nil || !out.EffectiveAt.Equal(nextEffectiveAt.UTC()) {
|
||||
t.Fatalf("FinalizeEvent() effectiveAt = %#v, want %s", out.EffectiveAt, nextEffectiveAt.UTC())
|
||||
}
|
||||
|
||||
payloadMap, ok := out.Payload.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("FinalizeEvent() payload type = %T, want map[string]any", out.Payload)
|
||||
}
|
||||
if payloadMap["value"] != 1.234567 {
|
||||
t.Fatalf("FinalizeEvent() payload value = %#v, want unrounded 1.234567", payloadMap["value"])
|
||||
}
|
||||
}
|
||||
76
processors/normalize/normalize.go
Normal file
76
processors/normalize/normalize.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package normalize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// Normalizer converts one event shape into another.
|
||||
//
|
||||
// A Normalizer is typically domain-owned code (weatherfeeder/newsfeeder/...)
|
||||
// that knows how to interpret a specific upstream payload and produce a
|
||||
// normalized payload.
|
||||
//
|
||||
// Normalizers are selected via Match(). The matching strategy is intentionally
|
||||
// flexible: implementations may match on Schema, Kind, Source, or any other
|
||||
// Event fields.
|
||||
type Normalizer interface {
|
||||
// Match reports whether this normalizer applies to the given event.
|
||||
//
|
||||
// Common patterns:
|
||||
// - match on e.Schema (recommended for versioning)
|
||||
// - match on e.Source (useful if Schema is empty)
|
||||
// - match on (e.Kind + e.Source), etc.
|
||||
Match(e event.Event) bool
|
||||
|
||||
// Normalize transforms the incoming event into a new (or modified) event.
|
||||
//
|
||||
// Return values:
|
||||
// - (out, nil) where out != nil: emit the normalized event
|
||||
// - (nil, nil): drop the event (treat as policy drop)
|
||||
// - (nil, err): fail the pipeline
|
||||
//
|
||||
// Note: If you simply want to pass the event through unchanged, return &in.
|
||||
Normalize(ctx context.Context, in event.Event) (*event.Event, error)
|
||||
}
|
||||
|
||||
// Func is an ergonomic adapter that lets you define a Normalizer with functions.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// n := normalize.Func{
|
||||
// MatchFn: func(e event.Event) bool { return e.Schema == "raw.openweather.current.v1" },
|
||||
// NormalizeFn: func(ctx context.Context, in event.Event) (*event.Event, error) {
|
||||
// // ... map in.Payload -> normalized payload ...
|
||||
// },
|
||||
// }
|
||||
type Func struct {
|
||||
MatchFn func(e event.Event) bool
|
||||
NormalizeFn func(ctx context.Context, in event.Event) (*event.Event, error)
|
||||
|
||||
// Optional: helps produce nicer panic/error messages if something goes wrong.
|
||||
Name string
|
||||
}
|
||||
|
||||
func (f Func) Match(e event.Event) bool {
|
||||
if f.MatchFn == nil {
|
||||
return false
|
||||
}
|
||||
return f.MatchFn(e)
|
||||
}
|
||||
|
||||
func (f Func) Normalize(ctx context.Context, in event.Event) (*event.Event, error) {
|
||||
if f.NormalizeFn == nil {
|
||||
return nil, fmt.Errorf("normalize.Func(%s): NormalizeFn is nil", f.safeName())
|
||||
}
|
||||
return f.NormalizeFn(ctx, in)
|
||||
}
|
||||
|
||||
func (f Func) safeName() string {
|
||||
if f.Name == "" {
|
||||
return "<unnamed>"
|
||||
}
|
||||
return f.Name
|
||||
}
|
||||
57
processors/normalize/processor.go
Normal file
57
processors/normalize/processor.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package normalize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// Processor applies ordered normalization rules to pipeline events.
|
||||
//
|
||||
// Selection rule:
|
||||
// - iterate in Normalizers order
|
||||
// - the first Normalizer whose Match returns true is applied
|
||||
//
|
||||
// If no normalizer matches, the default behavior is pass-through.
|
||||
type Processor struct {
|
||||
Normalizers []Normalizer
|
||||
|
||||
// If true, events that do not match any normalizer cause an error.
|
||||
// Default is false (pass-through).
|
||||
RequireMatch bool
|
||||
}
|
||||
|
||||
// NewProcessor constructs a normalization processor from an ordered normalizer list.
|
||||
func NewProcessor(normalizers []Normalizer, requireMatch bool) Processor {
|
||||
return Processor{
|
||||
Normalizers: append([]Normalizer(nil), normalizers...),
|
||||
RequireMatch: requireMatch,
|
||||
}
|
||||
}
|
||||
|
||||
// Process implements processors.Processor.
|
||||
func (p Processor) Process(ctx context.Context, in event.Event) (*event.Event, error) {
|
||||
for _, n := range p.Normalizers {
|
||||
if n == nil {
|
||||
continue
|
||||
}
|
||||
if !n.Match(in) {
|
||||
continue
|
||||
}
|
||||
|
||||
out, err := n.Normalize(ctx, in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("normalize: normalizer failed: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
if p.RequireMatch {
|
||||
return nil, fmt.Errorf("normalize: no normalizer matched event (id=%s kind=%s source=%s schema=%q)",
|
||||
in.ID, in.Kind, in.Source, in.Schema)
|
||||
}
|
||||
|
||||
out := in
|
||||
return &out, nil
|
||||
}
|
||||
139
processors/normalize/processor_test.go
Normal file
139
processors/normalize/processor_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package normalize
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
func TestProcessorFirstMatchWins(t *testing.T) {
|
||||
var firstCalls, secondCalls int
|
||||
|
||||
p := NewProcessor([]Normalizer{
|
||||
Func{
|
||||
MatchFn: func(event.Event) bool { return true },
|
||||
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
firstCalls++
|
||||
out := in
|
||||
out.Schema = "normalized.first.v1"
|
||||
return &out, nil
|
||||
},
|
||||
},
|
||||
Func{
|
||||
MatchFn: func(event.Event) bool { return true },
|
||||
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
secondCalls++
|
||||
out := in
|
||||
out.Schema = "normalized.second.v1"
|
||||
return &out, nil
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
|
||||
out, err := p.Process(context.Background(), testEvent())
|
||||
if err != nil {
|
||||
t.Fatalf("Process error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected output event, got nil")
|
||||
}
|
||||
if out.Schema != "normalized.first.v1" {
|
||||
t.Fatalf("unexpected schema: %q", out.Schema)
|
||||
}
|
||||
if firstCalls != 1 {
|
||||
t.Fatalf("expected first normalizer called once, got %d", firstCalls)
|
||||
}
|
||||
if secondCalls != 0 {
|
||||
t.Fatalf("expected second normalizer skipped, got %d calls", secondCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessorNoMatchPassThroughAndRequireMatch(t *testing.T) {
|
||||
in := testEvent()
|
||||
in.Schema = "raw.schema.v1"
|
||||
|
||||
passThrough := NewProcessor([]Normalizer{
|
||||
Func{
|
||||
MatchFn: func(event.Event) bool { return false },
|
||||
NormalizeFn: func(_ context.Context, in event.Event) (*event.Event, error) {
|
||||
out := in
|
||||
out.Schema = "should.not.run"
|
||||
return &out, nil
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
|
||||
out, err := passThrough.Process(context.Background(), in)
|
||||
if err != nil {
|
||||
t.Fatalf("pass-through Process error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatalf("expected pass-through output event, got nil")
|
||||
}
|
||||
if out.Schema != "raw.schema.v1" {
|
||||
t.Fatalf("expected unchanged schema, got %q", out.Schema)
|
||||
}
|
||||
|
||||
required := NewProcessor(nil, true)
|
||||
_, err = required.Process(context.Background(), in)
|
||||
if err == nil {
|
||||
t.Fatalf("expected require-match error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no normalizer matched") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessorDropAndErrorPropagation(t *testing.T) {
|
||||
t.Run("drop", func(t *testing.T) {
|
||||
p := NewProcessor([]Normalizer{
|
||||
Func{
|
||||
MatchFn: func(event.Event) bool { return true },
|
||||
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
|
||||
return nil, nil
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
|
||||
out, err := p.Process(context.Background(), testEvent())
|
||||
if err != nil {
|
||||
t.Fatalf("Process error: %v", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatalf("expected nil output for dropped event, got %#v", out)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
p := NewProcessor([]Normalizer{
|
||||
Func{
|
||||
MatchFn: func(event.Event) bool { return true },
|
||||
NormalizeFn: func(context.Context, event.Event) (*event.Event, error) {
|
||||
return nil, errors.New("map failed")
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
|
||||
_, err := p.Process(context.Background(), testEvent())
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "normalizer failed") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testEvent() event.Event {
|
||||
return event.Event{
|
||||
ID: "evt-normalize-1",
|
||||
Kind: event.Kind("observation"),
|
||||
Source: "source-1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"x": 1},
|
||||
}
|
||||
}
|
||||
15
processors/processor.go
Normal file
15
processors/processor.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package processors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// Processor can mutate/drop events (dedupe, rate-limit, normalization tweaks).
|
||||
type Processor interface {
|
||||
Process(ctx context.Context, in event.Event) (out *event.Event, err error)
|
||||
}
|
||||
|
||||
// Factory constructs a configured Processor instance.
|
||||
type Factory func() (Processor, error)
|
||||
71
processors/registry.go
Normal file
71
processors/registry.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package processors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Registry struct {
|
||||
byDriver map[string]Factory
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{byDriver: map[string]Factory{}}
|
||||
}
|
||||
|
||||
// Register associates a processor driver name with a factory.
|
||||
//
|
||||
// Register panics for empty driver names, nil factories, and duplicates.
|
||||
func (r *Registry) Register(driver string, f Factory) {
|
||||
if r == nil {
|
||||
panic("processors.Registry.Register: registry cannot be nil")
|
||||
}
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "" {
|
||||
panic("processors.Registry.Register: driver cannot be empty")
|
||||
}
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("processors.Registry.Register: factory cannot be nil (driver=%q)", driver))
|
||||
}
|
||||
if r.byDriver == nil {
|
||||
r.byDriver = map[string]Factory{}
|
||||
}
|
||||
if _, exists := r.byDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("processors.Registry.Register: driver %q already registered", driver))
|
||||
}
|
||||
r.byDriver[driver] = f
|
||||
}
|
||||
|
||||
// Build constructs a Processor by driver name.
|
||||
func (r *Registry) Build(driver string) (Processor, error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("processors registry is nil")
|
||||
}
|
||||
driver = strings.TrimSpace(driver)
|
||||
f, ok := r.byDriver[driver]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown processor driver: %q", driver)
|
||||
}
|
||||
|
||||
p, err := f()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build processor %q: %w", driver, err)
|
||||
}
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("build processor %q: factory returned nil processor", driver)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// BuildChain constructs an ordered processor chain from a driver list.
|
||||
func (r *Registry) BuildChain(drivers []string) ([]Processor, error) {
|
||||
out := make([]Processor, 0, len(drivers))
|
||||
for i, driver := range drivers {
|
||||
p, err := r.Build(driver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build processor chain[%d] (%q): %w", i, strings.TrimSpace(driver), err)
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
100
processors/registry_test.go
Normal file
100
processors/registry_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package processors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type testProcessor struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (p testProcessor) Process(context.Context, event.Event) (*event.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestRegistryRegisterValidation(t *testing.T) {
|
||||
t.Run("empty driver panics", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
assertPanics(t, func() {
|
||||
r.Register(" ", func() (Processor, error) { return testProcessor{name: "x"}, nil })
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("nil factory panics", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
assertPanics(t, func() {
|
||||
r.Register("normalize", nil)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("duplicate driver panics", func(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "a"}, nil })
|
||||
assertPanics(t, func() {
|
||||
r.Register("normalize", func() (Processor, error) { return testProcessor{name: "b"}, nil })
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistryBuildUnknownDriver(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, err := r.Build("does_not_exist")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unknown driver")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown processor driver") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildChainPreservesOrder(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("first", func() (Processor, error) { return testProcessor{name: "first"}, nil })
|
||||
r.Register("second", func() (Processor, error) { return testProcessor{name: "second"}, nil })
|
||||
|
||||
chain, err := r.BuildChain([]string{"first", "second"})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildChain error: %v", err)
|
||||
}
|
||||
if len(chain) != 2 {
|
||||
t.Fatalf("expected 2 processors, got %d", len(chain))
|
||||
}
|
||||
|
||||
p0, ok := chain[0].(testProcessor)
|
||||
if !ok || p0.name != "first" {
|
||||
t.Fatalf("unexpected chain[0]: %#v", chain[0])
|
||||
}
|
||||
p1, ok := chain[1].(testProcessor)
|
||||
if !ok || p1.name != "second" {
|
||||
t.Fatalf("unexpected chain[1]: %#v", chain[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildChainIndexedFailure(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("ok", func() (Processor, error) { return testProcessor{name: "ok"}, nil })
|
||||
r.Register("broken", func() (Processor, error) { return nil, errors.New("boom") })
|
||||
|
||||
_, err := r.BuildChain([]string{"ok", "broken"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "chain[1]") {
|
||||
t.Fatalf("expected indexed error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertPanics(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("expected panic")
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
}
|
||||
25
scheduler/doc.go
Normal file
25
scheduler/doc.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Package scheduler runs feedkit sources and forwards their events to the
|
||||
// daemon event bus.
|
||||
//
|
||||
// External API surface:
|
||||
// - Scheduler: runs configured polling and streaming jobs
|
||||
// - Job: one scheduler task bound to a source
|
||||
// - StreamExitPolicy: stream supervision policy for non-fatal exits
|
||||
// - StreamBackoff: restart pacing for supervised stream sources
|
||||
//
|
||||
// Optional helpers from helpers.go:
|
||||
// - JobFromSourceConfig: build a scheduler job from a configured source and
|
||||
// feedkit-owned scheduling params
|
||||
//
|
||||
// Poll sources are run on a fixed cadence with optional jitter. Stream sources
|
||||
// are supervised long-lived workers. Their generic feedkit controls live under
|
||||
// sources[].params:
|
||||
// - stream_exit_policy: restart|stop|fatal (default restart)
|
||||
// - stream_backoff_initial: positive duration (default 1s)
|
||||
// - stream_backoff_max: positive duration (default 1m)
|
||||
// - stream_backoff_jitter: non-negative duration (default 250ms)
|
||||
//
|
||||
// Stream sources can classify exits with sources.StreamRetryable and
|
||||
// sources.StreamFatal. Plain errors are treated as retryable by default, while
|
||||
// fatal exits are propagated from Scheduler.Run so the daemon can shut down.
|
||||
package scheduler
|
||||
138
scheduler/helpers.go
Normal file
138
scheduler/helpers.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/sources"
|
||||
)
|
||||
|
||||
// JobFromSourceConfig builds a scheduler Job from a configured source and its
|
||||
// generic feedkit config.
|
||||
func JobFromSourceConfig(src sources.Input, cfg config.SourceConfig) (Job, error) {
|
||||
if src == nil {
|
||||
return Job{}, fmt.Errorf("scheduler: source %q is nil", cfg.Name)
|
||||
}
|
||||
|
||||
job := Job{
|
||||
Source: src,
|
||||
Every: cfg.Every.Duration,
|
||||
}
|
||||
|
||||
if _, ok := src.(sources.StreamSource); ok {
|
||||
if cfg.Every.Duration > 0 {
|
||||
return Job{}, fmt.Errorf("source %q: sources[].every must be omitted for stream sources", cfg.Name)
|
||||
}
|
||||
|
||||
policy, err := parseStreamExitPolicy(cfg)
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
backoff, err := parseStreamBackoff(cfg)
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
|
||||
job.StreamExitPolicy = policy
|
||||
job.StreamBackoff = backoff
|
||||
return job, nil
|
||||
}
|
||||
|
||||
if _, ok := src.(sources.PollSource); ok {
|
||||
if cfg.Every.Duration <= 0 {
|
||||
return Job{}, fmt.Errorf("source %q: sources[].every must be > 0 for polling sources", cfg.Name)
|
||||
}
|
||||
if err := rejectStreamParams(cfg); err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
return Job{}, fmt.Errorf("scheduler: source %q implements neither PollSource nor StreamSource", cfg.Name)
|
||||
}
|
||||
|
||||
func parseStreamExitPolicy(cfg config.SourceConfig) (StreamExitPolicy, error) {
|
||||
const key = "stream_exit_policy"
|
||||
raw, exists := cfg.Params[key]
|
||||
if !exists || raw == nil {
|
||||
return StreamExitPolicyRestart, nil
|
||||
}
|
||||
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
|
||||
}
|
||||
|
||||
switch StreamExitPolicy(strings.ToLower(strings.TrimSpace(s))) {
|
||||
case StreamExitPolicyRestart:
|
||||
return StreamExitPolicyRestart, nil
|
||||
case StreamExitPolicyStop:
|
||||
return StreamExitPolicyStop, nil
|
||||
case StreamExitPolicyFatal:
|
||||
return StreamExitPolicyFatal, nil
|
||||
default:
|
||||
return "", fmt.Errorf("source %q: params.%s must be one of: restart, stop, fatal", cfg.Name, key)
|
||||
}
|
||||
}
|
||||
|
||||
func parseStreamBackoff(cfg config.SourceConfig) (StreamBackoff, error) {
|
||||
initial, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_initial", defaultStreamBackoffInitial)
|
||||
if err != nil {
|
||||
return StreamBackoff{}, err
|
||||
}
|
||||
max, err := parsePositiveOrDefaultDuration(cfg, "stream_backoff_max", defaultStreamBackoffMax)
|
||||
if err != nil {
|
||||
return StreamBackoff{}, err
|
||||
}
|
||||
jitter, err := parseNonNegativeOrDefaultDuration(cfg, "stream_backoff_jitter", defaultStreamBackoffJitter)
|
||||
if err != nil {
|
||||
return StreamBackoff{}, err
|
||||
}
|
||||
if max < initial {
|
||||
return StreamBackoff{}, fmt.Errorf("source %q: params.stream_backoff_max must be >= params.stream_backoff_initial", cfg.Name)
|
||||
}
|
||||
return StreamBackoff{
|
||||
Initial: initial,
|
||||
Max: max,
|
||||
Jitter: jitter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func rejectStreamParams(cfg config.SourceConfig) error {
|
||||
streamKeys := []string{
|
||||
"stream_exit_policy",
|
||||
"stream_backoff_initial",
|
||||
"stream_backoff_max",
|
||||
"stream_backoff_jitter",
|
||||
}
|
||||
for _, key := range streamKeys {
|
||||
if _, ok := cfg.Params[key]; ok {
|
||||
return fmt.Errorf("source %q: params.%s is only valid for stream sources", cfg.Name, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePositiveOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
|
||||
if _, exists := cfg.Params[key]; !exists {
|
||||
return def, nil
|
||||
}
|
||||
v, ok := cfg.ParamDuration(key)
|
||||
if !ok || v <= 0 {
|
||||
return 0, fmt.Errorf("source %q: params.%s must be a positive duration", cfg.Name, key)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func parseNonNegativeOrDefaultDuration(cfg config.SourceConfig, key string, def time.Duration) (time.Duration, error) {
|
||||
if _, exists := cfg.Params[key]; !exists {
|
||||
return def, nil
|
||||
}
|
||||
v, ok := cfg.ParamDuration(key)
|
||||
if !ok || v < 0 {
|
||||
return 0, fmt.Errorf("source %q: params.%s must be a non-negative duration", cfg.Name, key)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
@@ -5,25 +5,60 @@ import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/logging"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/sources"
|
||||
)
|
||||
|
||||
// Logger is a printf-style logger used throughout scheduler.
|
||||
// It is an alias to the shared feedkit logging type so callers can pass
|
||||
// one function everywhere without type mismatch friction.
|
||||
type Logger = logging.Logf
|
||||
|
||||
// Job describes one scheduler task.
|
||||
//
|
||||
// A Job may be backed by either:
|
||||
// - a polling source (sources.PollSource): uses Every + jitter and calls Poll()
|
||||
// - a stream source (sources.StreamSource): ignores Every and calls Run()
|
||||
//
|
||||
// Jitter behavior:
|
||||
// - For polling sources: Jitter is applied at startup and before each poll tick.
|
||||
// - For stream sources: Jitter is applied once at startup only (optional; useful to avoid
|
||||
// reconnect storms when many instances start together).
|
||||
type Job struct {
|
||||
Source sources.Source
|
||||
Every time.Duration
|
||||
Source sources.Input
|
||||
Every time.Duration
|
||||
StreamExitPolicy StreamExitPolicy
|
||||
StreamBackoff StreamBackoff
|
||||
|
||||
// Jitter is the maximum additional delay added before each poll.
|
||||
// Example: if Every=15m and Jitter=30s, each poll will occur at:
|
||||
// tick time + random(0..30s)
|
||||
//
|
||||
// If Jitter == 0, we compute a default jitter based on Every.
|
||||
// If Jitter == 0 for polling sources, we compute a default jitter based on Every.
|
||||
//
|
||||
// For stream sources, Jitter is treated as *startup jitter only*.
|
||||
Jitter time.Duration
|
||||
}
|
||||
|
||||
type Logger func(format string, args ...any)
|
||||
// StreamExitPolicy controls how the scheduler handles non-fatal stream exits.
|
||||
type StreamExitPolicy string
|
||||
|
||||
const (
|
||||
StreamExitPolicyRestart StreamExitPolicy = "restart"
|
||||
StreamExitPolicyStop StreamExitPolicy = "stop"
|
||||
StreamExitPolicyFatal StreamExitPolicy = "fatal"
|
||||
)
|
||||
|
||||
// StreamBackoff controls restart pacing for stream supervision.
|
||||
type StreamBackoff struct {
|
||||
Initial time.Duration
|
||||
Max time.Duration
|
||||
Jitter time.Duration
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
Jobs []Job
|
||||
@@ -31,8 +66,18 @@ type Scheduler struct {
|
||||
Logf Logger
|
||||
}
|
||||
|
||||
// Run starts one polling goroutine per job.
|
||||
// Each job runs on its own interval and emits 0..N events per poll.
|
||||
const (
|
||||
defaultStreamBackoffInitial = 1 * time.Second
|
||||
defaultStreamBackoffMax = 1 * time.Minute
|
||||
defaultStreamBackoffJitter = 250 * time.Millisecond
|
||||
streamBackoffResetAfter = 5 * time.Minute
|
||||
)
|
||||
|
||||
var timeNow = time.Now
|
||||
|
||||
// Run starts one goroutine per job.
|
||||
// Poll jobs run on their own interval and emit 0..N events per poll.
|
||||
// Stream jobs run continuously and emit events as they arrive.
|
||||
func (s *Scheduler) Run(ctx context.Context) error {
|
||||
if s.Out == nil {
|
||||
return fmt.Errorf("scheduler.Run: Out channel is nil")
|
||||
@@ -41,31 +86,117 @@ func (s *Scheduler) Run(ctx context.Context) error {
|
||||
return fmt.Errorf("scheduler.Run: no jobs configured")
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
fatalErrCh := make(chan error, 1)
|
||||
var wg sync.WaitGroup
|
||||
for _, job := range s.Jobs {
|
||||
job := job // capture loop variable
|
||||
go s.runJob(ctx, job)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.runJob(runCtx, job, fatalErrCh)
|
||||
}()
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-fatalErrCh:
|
||||
cancel()
|
||||
<-done
|
||||
return err
|
||||
case <-runCtx.Done():
|
||||
<-done
|
||||
return runCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) runJob(ctx context.Context, job Job) {
|
||||
func (s *Scheduler) runJob(ctx context.Context, job Job, fatalErrCh chan<- error) {
|
||||
if job.Source == nil {
|
||||
s.logf("scheduler: job has nil source")
|
||||
return
|
||||
}
|
||||
if job.Every <= 0 {
|
||||
s.logf("scheduler: job %s has invalid interval", job.Source.Name())
|
||||
|
||||
// Stream sources: event-driven.
|
||||
if ss, ok := job.Source.(sources.StreamSource); ok {
|
||||
s.runStream(ctx, job, ss, fatalErrCh)
|
||||
return
|
||||
}
|
||||
|
||||
// Poll sources: time-based.
|
||||
ps, ok := job.Source.(sources.PollSource)
|
||||
if !ok {
|
||||
s.logf("scheduler: source %T (%s) implements neither Poll() nor Run()", job.Source, job.Source.Name())
|
||||
return
|
||||
}
|
||||
if job.Every <= 0 {
|
||||
s.logf("scheduler: polling job %q missing/invalid interval (sources[].every)", ps.Name())
|
||||
return
|
||||
}
|
||||
|
||||
s.runPoller(ctx, job, ps)
|
||||
}
|
||||
|
||||
func (s *Scheduler) runStream(ctx context.Context, job Job, src sources.StreamSource, fatalErrCh chan<- error) {
|
||||
policy := effectiveStreamExitPolicy(job.StreamExitPolicy)
|
||||
backoff := effectiveStreamBackoff(job.StreamBackoff)
|
||||
rng := seededRNG(src.Name())
|
||||
|
||||
// Optional startup jitter: helps avoid reconnect storms if many daemons start at once.
|
||||
if job.Jitter > 0 {
|
||||
if !sleepJitter(ctx, rng, job.Jitter) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
nextDelay := backoff.Initial
|
||||
for {
|
||||
startedAt := timeNow()
|
||||
err := src.Run(ctx, s.Out)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
normalizedErr := normalizeStreamExitError(src.Name(), err)
|
||||
if sources.IsStreamFatal(normalizedErr) {
|
||||
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited fatally: %w", src.Name(), normalizedErr))
|
||||
return
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case StreamExitPolicyStop:
|
||||
s.logf("scheduler: stream source %q stopped after exit: %v", src.Name(), normalizedErr)
|
||||
return
|
||||
case StreamExitPolicyFatal:
|
||||
s.reportFatal(fatalErrCh, fmt.Errorf("scheduler: stream source %q exited under fatal policy: %w", src.Name(), normalizedErr))
|
||||
return
|
||||
}
|
||||
|
||||
if streamRunWasStable(startedAt, timeNow()) {
|
||||
nextDelay = backoff.Initial
|
||||
}
|
||||
|
||||
delay := nextDelay + randomDuration(rng, backoff.Jitter)
|
||||
s.logf("scheduler: stream source %q exited; restarting in %s: %v", src.Name(), delay, normalizedErr)
|
||||
if !sleepDuration(ctx, delay) {
|
||||
return
|
||||
}
|
||||
nextDelay = nextStreamBackoff(nextDelay, backoff.Max)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) runPoller(ctx context.Context, job Job, src sources.PollSource) {
|
||||
// Compute jitter: either configured per job, or a sensible default.
|
||||
jitter := effectiveJitter(job.Every, job.Jitter)
|
||||
|
||||
// Each worker gets its own RNG (safe + no lock contention).
|
||||
seed := time.Now().UnixNano() ^ int64(hashStringFNV32a(job.Source.Name()))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
rng := seededRNG(src.Name())
|
||||
|
||||
// Optional startup jitter: avoids all jobs firing at the exact moment the daemon starts.
|
||||
if !sleepJitter(ctx, rng, jitter) {
|
||||
@@ -73,7 +204,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
|
||||
}
|
||||
|
||||
// Immediate poll at startup (after startup jitter).
|
||||
s.pollOnce(ctx, job)
|
||||
s.pollOnce(ctx, src)
|
||||
|
||||
t := time.NewTicker(job.Every)
|
||||
defer t.Stop()
|
||||
@@ -85,7 +216,7 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
|
||||
if !sleepJitter(ctx, rng, jitter) {
|
||||
return
|
||||
}
|
||||
s.pollOnce(ctx, job)
|
||||
s.pollOnce(ctx, src)
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -93,10 +224,10 @@ func (s *Scheduler) runJob(ctx context.Context, job Job) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scheduler) pollOnce(ctx context.Context, job Job) {
|
||||
events, err := job.Source.Poll(ctx)
|
||||
func (s *Scheduler) pollOnce(ctx context.Context, src sources.PollSource) {
|
||||
events, err := src.Poll(ctx)
|
||||
if err != nil {
|
||||
s.logf("scheduler: poll failed (%s): %v", job.Source.Name(), err)
|
||||
s.logf("scheduler: poll failed (%s): %v", src.Name(), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,6 +247,80 @@ func (s *Scheduler) logf(format string, args ...any) {
|
||||
s.Logf(format, args...)
|
||||
}
|
||||
|
||||
func (s *Scheduler) reportFatal(ch chan<- error, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func effectiveStreamExitPolicy(policy StreamExitPolicy) StreamExitPolicy {
|
||||
switch policy {
|
||||
case StreamExitPolicyStop, StreamExitPolicyFatal:
|
||||
return policy
|
||||
default:
|
||||
return StreamExitPolicyRestart
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveStreamBackoff(cfg StreamBackoff) StreamBackoff {
|
||||
out := cfg
|
||||
if out.Initial <= 0 {
|
||||
out.Initial = defaultStreamBackoffInitial
|
||||
}
|
||||
if out.Max <= 0 {
|
||||
out.Max = defaultStreamBackoffMax
|
||||
}
|
||||
if out.Max < out.Initial {
|
||||
out.Max = out.Initial
|
||||
}
|
||||
if out.Jitter < 0 {
|
||||
out.Jitter = 0
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeStreamExitError(sourceName string, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sources.StreamRetryable(fmt.Errorf("stream source %q exited unexpectedly without error", sourceName))
|
||||
}
|
||||
|
||||
func nextStreamBackoff(current, max time.Duration) time.Duration {
|
||||
if current <= 0 {
|
||||
current = defaultStreamBackoffInitial
|
||||
}
|
||||
if max <= 0 {
|
||||
max = defaultStreamBackoffMax
|
||||
}
|
||||
if current >= max {
|
||||
return max
|
||||
}
|
||||
next := current * 2
|
||||
if next < current || next > max {
|
||||
return max
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func streamRunWasStable(startedAt, endedAt time.Time) bool {
|
||||
if startedAt.IsZero() || endedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return endedAt.Sub(startedAt) >= streamBackoffResetAfter
|
||||
}
|
||||
|
||||
func seededRNG(name string) *rand.Rand {
|
||||
seed := timeNow().UnixNano() ^ int64(hashStringFNV32a(name))
|
||||
return rand.New(rand.NewSource(seed))
|
||||
}
|
||||
|
||||
// effectiveJitter chooses a jitter value.
|
||||
// - If configuredMax > 0, use it (but clamp).
|
||||
// - Else default to min(every/10, 30s).
|
||||
@@ -151,11 +356,23 @@ func sleepJitter(ctx context.Context, rng *rand.Rand, max time.Duration) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return sleepDuration(ctx, randomDuration(rng, max))
|
||||
}
|
||||
|
||||
func randomDuration(rng *rand.Rand, max time.Duration) time.Duration {
|
||||
if max <= 0 {
|
||||
return 0
|
||||
}
|
||||
// Int63n requires a positive argument.
|
||||
// We add 1 so max itself is attainable.
|
||||
n := rng.Int63n(int64(max) + 1)
|
||||
d := time.Duration(n)
|
||||
return time.Duration(n)
|
||||
}
|
||||
|
||||
func sleepDuration(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
return true
|
||||
}
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
|
||||
472
scheduler/scheduler_test.go
Normal file
472
scheduler/scheduler_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/sources"
|
||||
)
|
||||
|
||||
type testPollSource struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (s testPollSource) Name() string { return s.name }
|
||||
|
||||
func (s testPollSource) Poll(context.Context) ([]event.Event, error) { return nil, nil }
|
||||
|
||||
type scriptedStreamSource struct {
|
||||
name string
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
runs []func(context.Context, chan<- event.Event) error
|
||||
}
|
||||
|
||||
func (s *scriptedStreamSource) Name() string { return s.name }
|
||||
|
||||
func (s *scriptedStreamSource) Run(ctx context.Context, out chan<- event.Event) error {
|
||||
s.mu.Lock()
|
||||
call := s.calls
|
||||
s.calls++
|
||||
var run func(context.Context, chan<- event.Event) error
|
||||
if call < len(s.runs) {
|
||||
run = s.runs[call]
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if run != nil {
|
||||
return run(ctx, out)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (s *scriptedStreamSource) CallCount() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.calls
|
||||
}
|
||||
|
||||
type capturingLogger struct {
|
||||
mu sync.Mutex
|
||||
lines []string
|
||||
}
|
||||
|
||||
func (l *capturingLogger) Logf(format string, args ...any) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.lines = append(l.lines, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (l *capturingLogger) Contains(substr string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
for _, line := range l.lines {
|
||||
if strings.Contains(line, substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestSchedulerRunRestartsPlainStreamErrors(t *testing.T) {
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-a",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
|
||||
func(context.Context, chan<- event.Event) error { return errors.New("temporary failure") },
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{
|
||||
Source: src,
|
||||
StreamBackoff: StreamBackoff{
|
||||
Initial: time.Millisecond,
|
||||
Max: time.Millisecond,
|
||||
},
|
||||
}},
|
||||
Out: make(chan event.Event, 1),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- s.Run(ctx) }()
|
||||
|
||||
waitFor(t, func() bool { return src.CallCount() >= 3 })
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
|
||||
}
|
||||
if src.CallCount() < 3 {
|
||||
t.Fatalf("stream call count = %d, want at least 3", src.CallCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerRunFatalStreamErrorReturns(t *testing.T) {
|
||||
base := errors.New("fatal failure")
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-fatal",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return sources.StreamFatal(base) },
|
||||
},
|
||||
}
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{Source: src}},
|
||||
Out: make(chan event.Event, 1),
|
||||
}
|
||||
|
||||
err := s.Run(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("Scheduler.Run() error = nil, want fatal error")
|
||||
}
|
||||
if !sources.IsStreamFatal(err) {
|
||||
t.Fatalf("Scheduler.Run() error = %v, want fatal classification", err)
|
||||
}
|
||||
if !errors.Is(err, base) {
|
||||
t.Fatalf("Scheduler.Run() error does not wrap base fatal error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerRunStopPolicyStopsOnlyThatSource(t *testing.T) {
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-stop",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return errors.New("stop now") },
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{
|
||||
Source: src,
|
||||
StreamExitPolicy: StreamExitPolicyStop,
|
||||
}},
|
||||
Out: make(chan event.Event, 1),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- s.Run(ctx) }()
|
||||
|
||||
waitFor(t, func() bool { return src.CallCount() >= 1 })
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("Scheduler.Run() returned early: %v", err)
|
||||
default:
|
||||
}
|
||||
|
||||
if src.CallCount() != 1 {
|
||||
t.Fatalf("stream call count = %d, want 1", src.CallCount())
|
||||
}
|
||||
|
||||
cancel()
|
||||
err := <-errCh
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerRunFatalPolicyTreatsPlainErrorAsFatal(t *testing.T) {
|
||||
base := errors.New("plain failure")
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-fatal-policy",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return base },
|
||||
},
|
||||
}
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{
|
||||
Source: src,
|
||||
StreamExitPolicy: StreamExitPolicyFatal,
|
||||
}},
|
||||
Out: make(chan event.Event, 1),
|
||||
}
|
||||
|
||||
err := s.Run(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("Scheduler.Run() error = nil, want fatal-policy error")
|
||||
}
|
||||
if !errors.Is(err, base) {
|
||||
t.Fatalf("Scheduler.Run() error does not wrap base error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerRunNilExitRestartsAsUnexpected(t *testing.T) {
|
||||
logger := &capturingLogger{}
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-nil-exit",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return nil },
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{
|
||||
Source: src,
|
||||
StreamBackoff: StreamBackoff{
|
||||
Initial: time.Millisecond,
|
||||
Max: time.Millisecond,
|
||||
},
|
||||
}},
|
||||
Out: make(chan event.Event, 1),
|
||||
Logf: logger.Logf,
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- s.Run(ctx) }()
|
||||
|
||||
waitFor(t, func() bool { return src.CallCount() >= 2 })
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
|
||||
}
|
||||
if !logger.Contains("exited unexpectedly without error") {
|
||||
t.Fatalf("expected log to mention unexpected nil stream exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerRunContextCancelDuringBackoff(t *testing.T) {
|
||||
src := &scriptedStreamSource{
|
||||
name: "stream-backoff-cancel",
|
||||
runs: []func(context.Context, chan<- event.Event) error{
|
||||
func(context.Context, chan<- event.Event) error { return errors.New("retry me") },
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s := &Scheduler{
|
||||
Jobs: []Job{{
|
||||
Source: src,
|
||||
StreamBackoff: StreamBackoff{
|
||||
Initial: time.Second,
|
||||
Max: time.Second,
|
||||
},
|
||||
}},
|
||||
Out: make(chan event.Event, 1),
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- s.Run(ctx) }()
|
||||
|
||||
waitFor(t, func() bool { return src.CallCount() >= 1 })
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("Scheduler.Run() error = %v, want context canceled", err)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if src.CallCount() != 1 {
|
||||
t.Fatalf("stream call count = %d, want 1", src.CallCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextStreamBackoffCapsAtMax(t *testing.T) {
|
||||
if got := nextStreamBackoff(500*time.Millisecond, 2*time.Second); got != time.Second {
|
||||
t.Fatalf("nextStreamBackoff() = %s, want 1s", got)
|
||||
}
|
||||
if got := nextStreamBackoff(time.Second, 2*time.Second); got != 2*time.Second {
|
||||
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
|
||||
}
|
||||
if got := nextStreamBackoff(2*time.Second, 2*time.Second); got != 2*time.Second {
|
||||
t.Fatalf("nextStreamBackoff() = %s, want 2s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamRunWasStableAfterFiveMinutes(t *testing.T) {
|
||||
start := time.Date(2026, 3, 29, 12, 0, 0, 0, time.UTC)
|
||||
if streamRunWasStable(start, start.Add(4*time.Minute+59*time.Second)) {
|
||||
t.Fatalf("streamRunWasStable() = true, want false")
|
||||
}
|
||||
if !streamRunWasStable(start, start.Add(5*time.Minute)) {
|
||||
t.Fatalf("streamRunWasStable() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobFromSourceConfigPollSource(t *testing.T) {
|
||||
job, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
|
||||
Name: "poll-a",
|
||||
Driver: "poll_driver",
|
||||
Every: config.Duration{Duration: time.Minute},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = %v", err)
|
||||
}
|
||||
if job.Every != time.Minute {
|
||||
t.Fatalf("Job.Every = %s, want 1m", job.Every)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobFromSourceConfigPollSourceRejectsStreamParams(t *testing.T) {
|
||||
_, err := JobFromSourceConfig(testPollSource{name: "poll-a"}, config.SourceConfig{
|
||||
Name: "poll-a",
|
||||
Driver: "poll_driver",
|
||||
Every: config.Duration{Duration: time.Minute},
|
||||
Params: map[string]any{
|
||||
"stream_exit_policy": "restart",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = nil, want rejection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "only valid for stream sources") {
|
||||
t.Fatalf("JobFromSourceConfig() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobFromSourceConfigStreamSourceParsesDefaultsAndOverrides(t *testing.T) {
|
||||
src := &scriptedStreamSource{name: "stream-a"}
|
||||
|
||||
job, err := JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-a",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
Params: map[string]any{
|
||||
"stream_exit_policy": "stop",
|
||||
"stream_backoff_initial": "2s",
|
||||
"stream_backoff_max": "10s",
|
||||
"stream_backoff_jitter": "500ms",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = %v", err)
|
||||
}
|
||||
if job.StreamExitPolicy != StreamExitPolicyStop {
|
||||
t.Fatalf("Job.StreamExitPolicy = %q, want %q", job.StreamExitPolicy, StreamExitPolicyStop)
|
||||
}
|
||||
if job.StreamBackoff.Initial != 2*time.Second {
|
||||
t.Fatalf("Job.StreamBackoff.Initial = %s, want 2s", job.StreamBackoff.Initial)
|
||||
}
|
||||
if job.StreamBackoff.Max != 10*time.Second {
|
||||
t.Fatalf("Job.StreamBackoff.Max = %s, want 10s", job.StreamBackoff.Max)
|
||||
}
|
||||
if job.StreamBackoff.Jitter != 500*time.Millisecond {
|
||||
t.Fatalf("Job.StreamBackoff.Jitter = %s, want 500ms", job.StreamBackoff.Jitter)
|
||||
}
|
||||
|
||||
defaultJob, err := JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-default",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("JobFromSourceConfig() default error = %v", err)
|
||||
}
|
||||
if defaultJob.StreamExitPolicy != StreamExitPolicyRestart {
|
||||
t.Fatalf("default Job.StreamExitPolicy = %q, want restart", defaultJob.StreamExitPolicy)
|
||||
}
|
||||
if defaultJob.StreamBackoff.Initial != defaultStreamBackoffInitial {
|
||||
t.Fatalf("default Job.StreamBackoff.Initial = %s, want %s", defaultJob.StreamBackoff.Initial, defaultStreamBackoffInitial)
|
||||
}
|
||||
if defaultJob.StreamBackoff.Max != defaultStreamBackoffMax {
|
||||
t.Fatalf("default Job.StreamBackoff.Max = %s, want %s", defaultJob.StreamBackoff.Max, defaultStreamBackoffMax)
|
||||
}
|
||||
if defaultJob.StreamBackoff.Jitter != defaultStreamBackoffJitter {
|
||||
t.Fatalf("default Job.StreamBackoff.Jitter = %s, want %s", defaultJob.StreamBackoff.Jitter, defaultStreamBackoffJitter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobFromSourceConfigStreamSourceRejectsInvalidSettings(t *testing.T) {
|
||||
src := &scriptedStreamSource{name: "stream-b"}
|
||||
|
||||
_, err := JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-b",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
Params: map[string]any{
|
||||
"stream_exit_policy": "sometimes",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = nil, want invalid policy error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "stream_exit_policy") {
|
||||
t.Fatalf("JobFromSourceConfig() error = %q", err)
|
||||
}
|
||||
|
||||
_, err = JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-b",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
Params: map[string]any{
|
||||
"stream_backoff_initial": "0s",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = nil, want invalid initial backoff error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "stream_backoff_initial") {
|
||||
t.Fatalf("JobFromSourceConfig() error = %q", err)
|
||||
}
|
||||
|
||||
_, err = JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-b",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
Params: map[string]any{
|
||||
"stream_backoff_initial": "2s",
|
||||
"stream_backoff_max": "1s",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = nil, want invalid max backoff error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "stream_backoff_max") {
|
||||
t.Fatalf("JobFromSourceConfig() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJobFromSourceConfigStreamSourceRejectsEvery(t *testing.T) {
|
||||
src := &scriptedStreamSource{name: "stream-c"}
|
||||
|
||||
_, err := JobFromSourceConfig(src, config.SourceConfig{
|
||||
Name: "stream-c",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
Every: config.Duration{Duration: time.Minute},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("JobFromSourceConfig() error = nil, want every rejection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "sources[].every must be omitted") {
|
||||
t.Fatalf("JobFromSourceConfig() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitFor(t *testing.T, cond func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if cond() {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not satisfied before timeout")
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package scheduler
|
||||
|
||||
// Placeholder for per-source worker logic:
|
||||
// - ticker loop
|
||||
// - jitter
|
||||
// - backoff on errors
|
||||
// - emits events into scheduler.Out
|
||||
@@ -1,11 +1,6 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
import "gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
|
||||
// RegisterBuiltins registers sink drivers included in this binary.
|
||||
//
|
||||
@@ -17,39 +12,8 @@ func RegisterBuiltins(r *Registry) {
|
||||
return NewStdoutSink(cfg.Name), nil
|
||||
})
|
||||
|
||||
// File sink: writes/archives events somewhere on disk.
|
||||
r.Register("file", func(cfg config.SinkConfig) (Sink, error) {
|
||||
return NewFileSinkFromConfig(cfg)
|
||||
})
|
||||
|
||||
// Postgres sink: persists events durably.
|
||||
r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) {
|
||||
return NewPostgresSinkFromConfig(cfg)
|
||||
})
|
||||
|
||||
// RabbitMQ sink: publishes events to a broker for downstream consumers.
|
||||
r.Register("rabbitmq", func(cfg config.SinkConfig) (Sink, error) {
|
||||
return NewRabbitMQSinkFromConfig(cfg)
|
||||
// NATS sink: publishes events to a broker for downstream consumers.
|
||||
r.Register("nats", func(cfg config.SinkConfig) (Sink, error) {
|
||||
return NewNATSSinkFromConfig(cfg)
|
||||
})
|
||||
}
|
||||
|
||||
// ---- helpers for validating sink params ----
|
||||
//
|
||||
// These helpers live in sinks (not config) on purpose:
|
||||
// - config is domain-agnostic and should not embed driver-specific validation helpers.
|
||||
// - sinks are adapters; validating their own params here keeps the logic near the driver.
|
||||
|
||||
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
|
||||
v, ok := cfg.Params[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
|
||||
}
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
101
sinks/doc.go
Normal file
101
sinks/doc.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Package sinks defines the feedkit sink interface, sink driver registry, and
|
||||
// built-in infrastructure sinks.
|
||||
//
|
||||
// External API surface:
|
||||
// - Sink: adapter interface that consumes event.Event values
|
||||
// - Registry / NewRegistry: named sink factory registry
|
||||
// - RegisterBuiltins: registers the schema-free built-in sink drivers
|
||||
//
|
||||
// Built-in sink implementations:
|
||||
// - stdout
|
||||
// - nats
|
||||
// - postgres
|
||||
//
|
||||
// Optional helpers from helpers.go:
|
||||
// - PostgresFactory: returns a sink factory for the built-in Postgres sink
|
||||
// using a provided downstream schema
|
||||
//
|
||||
// # NATS built-in overview
|
||||
//
|
||||
// The NATS sink publishes each event as JSON to a configured subject.
|
||||
//
|
||||
// Required params:
|
||||
// - url: NATS server URL (for example, nats://localhost:4222)
|
||||
// - subject: NATS subject to publish to
|
||||
//
|
||||
// Example config:
|
||||
//
|
||||
// sinks:
|
||||
// - name: nats_main
|
||||
// driver: nats
|
||||
// params:
|
||||
// url: nats://localhost:4222
|
||||
// subject: feedkit.events
|
||||
//
|
||||
// # Postgres built-in overview
|
||||
//
|
||||
// The postgres sink is intentionally split between downstream daemon ownership
|
||||
// and feedkit ownership:
|
||||
// - downstream code registers table schema + event mapping functions
|
||||
// - feedkit manages DB connection, create-if-missing DDL, transactional
|
||||
// inserts, optional automatic retention pruning, and manual prune helpers
|
||||
//
|
||||
// Example config:
|
||||
//
|
||||
// sinks:
|
||||
// - name: pg_main
|
||||
// driver: postgres
|
||||
// params:
|
||||
// uri: postgres://localhost:5432/feedkit?sslmode=disable
|
||||
// username: feedkit_user
|
||||
// password: feedkit_pass
|
||||
// prune: 3d # optional: prune rows older than now-3d on each write tx
|
||||
//
|
||||
// params.prune supports:
|
||||
// - Go duration strings (72h, 90m, 30s, ...)
|
||||
// - day/week suffixes (3d, 2w)
|
||||
//
|
||||
// If params.prune is omitted, automatic pruning is disabled.
|
||||
//
|
||||
// Example downstream wiring:
|
||||
//
|
||||
// sinkReg := sinks.NewRegistry()
|
||||
// sinks.RegisterBuiltins(sinkReg)
|
||||
// sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
|
||||
// Tables: []sinks.PostgresTable{
|
||||
// {
|
||||
// Name: "events",
|
||||
// Columns: []sinks.PostgresColumn{
|
||||
// {Name: "event_id", Type: "TEXT", Nullable: false},
|
||||
// {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
// {Name: "payload_json", Type: "JSONB", Nullable: false},
|
||||
// },
|
||||
// PrimaryKey: []string{"event_id"},
|
||||
// PruneColumn: "emitted_at", // required for retention pruning
|
||||
// },
|
||||
// },
|
||||
// MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) {
|
||||
// b, err := json.Marshal(e.Payload)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return []sinks.PostgresWrite{
|
||||
// {
|
||||
// Table: "events",
|
||||
// Values: map[string]any{
|
||||
// "event_id": e.ID,
|
||||
// "emitted_at": e.EmittedAt,
|
||||
// "payload_json": string(b),
|
||||
// },
|
||||
// },
|
||||
// }, nil
|
||||
// },
|
||||
// }))
|
||||
//
|
||||
// Manual pruning via type assertion (administrative helpers):
|
||||
//
|
||||
// if p, ok := sink.(sinks.PostgresPruner); ok {
|
||||
// _, _ = p.PruneKeepLatest(ctx, "events", 10000)
|
||||
// _, _ = p.PruneOlderThan(ctx, "events", time.Now().UTC().AddDate(0, -1, 0))
|
||||
// }
|
||||
package sinks
|
||||
@@ -1,30 +0,0 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type FileSink struct {
|
||||
name string
|
||||
path string
|
||||
}
|
||||
|
||||
func NewFileSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
|
||||
path, err := requireStringParam(cfg, "path")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FileSink{name: cfg.Name, path: path}, nil
|
||||
}
|
||||
|
||||
func (s *FileSink) Name() string { return s.name }
|
||||
|
||||
func (s *FileSink) Consume(ctx context.Context, e event.Event) error {
|
||||
_ = ctx
|
||||
_ = e
|
||||
return fmt.Errorf("file sink: TODO implement (path=%s)", s.path)
|
||||
}
|
||||
35
sinks/helpers.go
Normal file
35
sinks/helpers.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
|
||||
// requireStringParam returns a non-empty string sink param.
|
||||
//
|
||||
// This helper is intentionally local to sinks rather than config so
|
||||
// driver-specific validation stays close to the adapters that use it.
|
||||
func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
|
||||
v, ok := cfg.Params[key]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
|
||||
}
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// PostgresFactory returns a sink factory that builds the built-in Postgres sink
|
||||
// using the provided downstream schema definition.
|
||||
func PostgresFactory(schema PostgresSchema) Factory {
|
||||
return func(cfg config.SinkConfig) (Sink, error) {
|
||||
return NewPostgresSinkFromConfig(cfg, schema)
|
||||
}
|
||||
}
|
||||
46
sinks/helpers_test.go
Normal file
46
sinks/helpers_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) {
|
||||
factory := PostgresFactory(testPostgresSchema())
|
||||
if factory == nil {
|
||||
t.Fatalf("PostgresFactory() returned nil")
|
||||
}
|
||||
if _, err := factory(config.SinkConfig{}); err == nil {
|
||||
t.Fatalf("factory(config) expected parameter validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func testPostgresSchema() PostgresSchema {
|
||||
return PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "event_id", Type: "TEXT", Nullable: false},
|
||||
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
},
|
||||
PrimaryKey: []string{"event_id"},
|
||||
PruneColumn: "emitted_at",
|
||||
},
|
||||
},
|
||||
MapEvent: func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||
return []PostgresWrite{
|
||||
{
|
||||
Table: "events",
|
||||
Values: map[string]any{
|
||||
"event_id": e.ID,
|
||||
"emitted_at": e.EmittedAt,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
97
sinks/nats.go
Normal file
97
sinks/nats.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
"github.com/nats-io/nats.go"
|
||||
)
|
||||
|
||||
type NATSSink struct {
|
||||
name string
|
||||
url string
|
||||
subject string
|
||||
|
||||
mu sync.Mutex
|
||||
conn *nats.Conn
|
||||
}
|
||||
|
||||
func NewNATSSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
|
||||
url, err := requireStringParam(cfg, "url")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subject, err := requireStringParam(cfg, "subject")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NATSSink{name: cfg.Name, url: url, subject: subject}, nil
|
||||
}
|
||||
|
||||
func (r *NATSSink) Name() string { return r.name }
|
||||
|
||||
func (r *NATSSink) Consume(ctx context.Context, e event.Event) error {
|
||||
// Boundary validation: if something upstream violated invariants,
|
||||
// surface it loudly rather than printing partial nonsense.
|
||||
if err := e.Validate(); err != nil {
|
||||
return fmt.Errorf("NATS sink: invalid event: %w", err)
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := r.connect(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NATS sink: connect: %w", err)
|
||||
}
|
||||
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NATS sink: marshal event: %w", err)
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.Publish(r.subject, b); err != nil {
|
||||
return fmt.Errorf("NATS sink: publish: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *NATSSink) connect(ctx context.Context) (*nats.Conn, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.conn != nil && r.conn.Status() != nats.CLOSED {
|
||||
return r.conn, nil
|
||||
}
|
||||
|
||||
opts := []nats.Option{
|
||||
nats.Name(fmt.Sprintf("feedkit sink %s", r.name)),
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout := time.Until(deadline)
|
||||
if timeout <= 0 {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
opts = append(opts, nats.Timeout(timeout))
|
||||
}
|
||||
|
||||
conn, err := nats.Connect(r.url, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.conn = conn
|
||||
return conn, nil
|
||||
}
|
||||
47
sinks/nats_test.go
Normal file
47
sinks/nats_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
|
||||
func TestNewNATSSinkFromConfigRequiresSubject(t *testing.T) {
|
||||
sink, err := NewNATSSinkFromConfig(config.SinkConfig{
|
||||
Name: "nats-main",
|
||||
Driver: "nats",
|
||||
Params: map[string]any{
|
||||
"url": "nats://localhost:4222",
|
||||
"subject": "feedkit.events",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewNATSSinkFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
natsSink, ok := sink.(*NATSSink)
|
||||
if !ok {
|
||||
t.Fatalf("sink type = %T, want *NATSSink", sink)
|
||||
}
|
||||
if natsSink.subject != "feedkit.events" {
|
||||
t.Fatalf("subject = %q, want feedkit.events", natsSink.subject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNATSSinkFromConfigRejectsLegacyExchange(t *testing.T) {
|
||||
_, err := NewNATSSinkFromConfig(config.SinkConfig{
|
||||
Name: "nats-main",
|
||||
Driver: "nats",
|
||||
Params: map[string]any{
|
||||
"url": "nats://localhost:4222",
|
||||
"exchange": "feedkit.events",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewNATSSinkFromConfig() expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.subject is required") {
|
||||
t.Fatalf("error = %q, want params.subject is required", err)
|
||||
}
|
||||
}
|
||||
@@ -2,36 +2,437 @@ package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type PostgresSink struct {
|
||||
name string
|
||||
dsn string
|
||||
type postgresTx interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
|
||||
dsn, err := requireStringParam(cfg, "dsn")
|
||||
type postgresExecer interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type postgresDB interface {
|
||||
PingContext(ctx context.Context) error
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type sqlDBWrapper struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (w *sqlDBWrapper) PingContext(ctx context.Context) error {
|
||||
return w.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
|
||||
tx, err := w.db.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostgresSink{name: cfg.Name, dsn: dsn}, nil
|
||||
return &sqlTxWrapper{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
return w.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (w *sqlDBWrapper) Close() error {
|
||||
return w.db.Close()
|
||||
}
|
||||
|
||||
type sqlTxWrapper struct {
|
||||
tx *sql.Tx
|
||||
}
|
||||
|
||||
func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
return w.tx.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (w *sqlTxWrapper) Commit() error {
|
||||
return w.tx.Commit()
|
||||
}
|
||||
|
||||
func (w *sqlTxWrapper) Rollback() error {
|
||||
return w.tx.Rollback()
|
||||
}
|
||||
|
||||
var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
db, err := pgconn.Open(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sqlDBWrapper{db: db}, nil
|
||||
}
|
||||
|
||||
type PostgresSink struct {
|
||||
name string
|
||||
db postgresDB
|
||||
schema postgresSchemaCompiled
|
||||
pruneWindow time.Duration
|
||||
}
|
||||
|
||||
func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
|
||||
uri, err := requireStringParam(cfg, "uri")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username, err := requireStringParam(cfg, "username")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
password, err := requireStringParam(cfg, "password")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pruneWindow, err := parsePostgresPruneWindow(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
schema, err := compilePostgresSchema(schemaDef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
|
||||
URI: uri,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
|
||||
if err := s.initialize(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) Name() string { return p.name }
|
||||
|
||||
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
|
||||
_ = ctx
|
||||
|
||||
// Boundary validation: if something upstream violated invariants,
|
||||
// surface it loudly rather than printing partial nonsense.
|
||||
// surface it loudly rather than writing corrupt rows.
|
||||
if err := e.Validate(); err != nil {
|
||||
return fmt.Errorf("rabbitmq sink: invalid event: %w", err)
|
||||
return fmt.Errorf("postgres sink: invalid event: %w", err)
|
||||
}
|
||||
|
||||
// TODO implement Postgres transaction
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
writes, err := p.schema.mapEvent(ctx, e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("postgres sink: map event: %w", err)
|
||||
}
|
||||
if len(writes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("postgres sink: begin tx: %w", err)
|
||||
}
|
||||
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, w := range writes {
|
||||
tbl, err := p.schema.validateWrite(w)
|
||||
if err != nil {
|
||||
return fmt.Errorf("postgres sink: %w", err)
|
||||
}
|
||||
|
||||
query, args, err := buildInsertSQL(tbl, w)
|
||||
if err != nil {
|
||||
return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
|
||||
}
|
||||
}
|
||||
if p.pruneWindow > 0 {
|
||||
cutoff := time.Now().UTC().Add(-p.pruneWindow)
|
||||
for _, tableName := range p.schema.tableOrder {
|
||||
tbl := p.schema.tables[tableName]
|
||||
if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
|
||||
return fmt.Errorf("postgres sink: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("postgres sink: commit tx: %w", err)
|
||||
}
|
||||
committed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
|
||||
if keep < 0 {
|
||||
return 0, fmt.Errorf("postgres sink: keep must be >= 0")
|
||||
}
|
||||
tbl, err := p.lookupTable(table)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE ctid IN (
|
||||
SELECT ctid FROM %s
|
||||
ORDER BY %s DESC
|
||||
OFFSET $1
|
||||
)`,
|
||||
quotePostgresIdent(tbl.name),
|
||||
quotePostgresIdent(tbl.name),
|
||||
quotePostgresIdent(tbl.pruneColumn),
|
||||
)
|
||||
|
||||
res, err := p.db.ExecContext(ctx, query, keep)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
|
||||
}
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
|
||||
tbl, err := p.lookupTable(table)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres sink: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
|
||||
counts := make(map[string]int64, len(p.schema.tableOrder))
|
||||
for _, table := range p.schema.tableOrder {
|
||||
n, err := p.PruneKeepLatest(ctx, table, keep)
|
||||
if err != nil {
|
||||
return counts, err
|
||||
}
|
||||
counts[table] = n
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
|
||||
counts := make(map[string]int64, len(p.schema.tableOrder))
|
||||
for _, table := range p.schema.tableOrder {
|
||||
n, err := p.PruneOlderThan(ctx, table, cutoff)
|
||||
if err != nil {
|
||||
return counts, err
|
||||
}
|
||||
counts[table] = n
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) initialize() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for _, tableName := range p.schema.tableOrder {
|
||||
tbl := p.schema.tables[tableName]
|
||||
|
||||
createTableSQL := buildCreateTableSQL(tbl)
|
||||
if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
|
||||
return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
|
||||
}
|
||||
|
||||
for _, idx := range tbl.indexes {
|
||||
createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
|
||||
if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
|
||||
return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
|
||||
table = strings.TrimSpace(table)
|
||||
if table == "" {
|
||||
return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
|
||||
}
|
||||
tbl, ok := p.schema.tables[table]
|
||||
if !ok {
|
||||
return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
|
||||
}
|
||||
return tbl, nil
|
||||
}
|
||||
|
||||
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
|
||||
raw, ok := cfg.Params["prune"]
|
||||
if !ok || raw == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s, ok := raw.(string)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
|
||||
}
|
||||
|
||||
d, err := parsePostgresPruneDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return 0, fmt.Errorf("must not be empty")
|
||||
}
|
||||
|
||||
lower := strings.ToLower(s)
|
||||
if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
|
||||
unit := lower[len(lower)-1]
|
||||
n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
|
||||
}
|
||||
if n <= 0 {
|
||||
return 0, fmt.Errorf("must be > 0")
|
||||
}
|
||||
if unit == 'd' {
|
||||
return time.Duration(n) * 24 * time.Hour, nil
|
||||
}
|
||||
return time.Duration(n) * 7 * 24 * time.Hour, nil
|
||||
}
|
||||
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
|
||||
}
|
||||
if d <= 0 {
|
||||
return 0, fmt.Errorf("must be > 0")
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
|
||||
return fmt.Sprintf(
|
||||
`DELETE FROM %s WHERE %s < $1`,
|
||||
quotePostgresIdent(tbl.name),
|
||||
quotePostgresIdent(tbl.pruneColumn),
|
||||
)
|
||||
}
|
||||
|
||||
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
|
||||
query := buildPruneOlderThanSQL(tbl)
|
||||
res, err := execer.ExecContext(ctx, query, cutoff)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
|
||||
}
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func buildCreateTableSQL(tbl postgresTableCompiled) string {
|
||||
defs := make([]string, 0, len(tbl.columnOrder)+1)
|
||||
for _, colName := range tbl.columnOrder {
|
||||
col := tbl.columns[colName]
|
||||
def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
|
||||
if !col.Nullable {
|
||||
def += " NOT NULL"
|
||||
}
|
||||
defs = append(defs, def)
|
||||
}
|
||||
if len(tbl.primaryKey) > 0 {
|
||||
defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||
quotePostgresIdent(tbl.name),
|
||||
strings.Join(defs, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
|
||||
unique := ""
|
||||
if idx.Unique {
|
||||
unique = "UNIQUE "
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
unique,
|
||||
quotePostgresIdent(idx.Name),
|
||||
quotePostgresIdent(tableName),
|
||||
joinQuotedIdents(idx.Columns),
|
||||
)
|
||||
}
|
||||
|
||||
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
|
||||
cols := make([]string, 0, len(tbl.columnOrder))
|
||||
args := make([]any, 0, len(tbl.columnOrder))
|
||||
placeholders := make([]string, 0, len(tbl.columnOrder))
|
||||
|
||||
for i, colName := range tbl.columnOrder {
|
||||
v, ok := w.Values[colName]
|
||||
if !ok {
|
||||
return "", nil, fmt.Errorf("missing value for column %q", colName)
|
||||
}
|
||||
cols = append(cols, quotePostgresIdent(colName))
|
||||
args = append(args, v)
|
||||
placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(
|
||||
"INSERT INTO %s (%s) VALUES (%s)",
|
||||
quotePostgresIdent(tbl.name),
|
||||
strings.Join(cols, ", "),
|
||||
strings.Join(placeholders, ", "),
|
||||
)
|
||||
return q, args, nil
|
||||
}
|
||||
|
||||
func joinQuotedIdents(idents []string) string {
|
||||
quoted := make([]string, 0, len(idents))
|
||||
for _, s := range idents {
|
||||
quoted = append(quoted, quotePostgresIdent(s))
|
||||
}
|
||||
return strings.Join(quoted, ", ")
|
||||
}
|
||||
|
||||
func quotePostgresIdent(s string) string {
|
||||
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
239
sinks/postgres_schema.go
Normal file
239
sinks/postgres_schema.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// PostgresMapFunc maps one event into zero or more table writes.
|
||||
//
|
||||
// Returning zero writes means "this event is not mapped for this sink" and is
|
||||
// treated as a non-error no-op.
|
||||
type PostgresMapFunc func(ctx context.Context, e event.Event) ([]PostgresWrite, error)
|
||||
|
||||
// PostgresSchema describes the downstream-provided relational model and mapper
|
||||
// for one configured postgres sink.
|
||||
type PostgresSchema struct {
|
||||
Tables []PostgresTable
|
||||
MapEvent PostgresMapFunc
|
||||
}
|
||||
|
||||
type PostgresWrite struct {
|
||||
Table string
|
||||
Values map[string]any
|
||||
}
|
||||
|
||||
type PostgresTable struct {
|
||||
Name string
|
||||
Columns []PostgresColumn
|
||||
PrimaryKey []string
|
||||
PruneColumn string
|
||||
Indexes []PostgresIndex
|
||||
}
|
||||
|
||||
type PostgresColumn struct {
|
||||
Name string
|
||||
Type string
|
||||
Nullable bool
|
||||
}
|
||||
|
||||
type PostgresIndex struct {
|
||||
Name string
|
||||
Columns []string
|
||||
Unique bool
|
||||
}
|
||||
|
||||
// PostgresPruner is an optional interface exposed by PostgresSink so downstream
|
||||
// applications can call retention helpers via type assertion.
|
||||
type PostgresPruner interface {
|
||||
PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error)
|
||||
PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error)
|
||||
PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error)
|
||||
PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error)
|
||||
}
|
||||
|
||||
type postgresSchemaCompiled struct {
|
||||
tableOrder []string
|
||||
tables map[string]postgresTableCompiled
|
||||
mapEvent PostgresMapFunc
|
||||
}
|
||||
|
||||
type postgresTableCompiled struct {
|
||||
name string
|
||||
columns map[string]PostgresColumn
|
||||
columnOrder []string
|
||||
primaryKey []string
|
||||
pruneColumn string
|
||||
indexes []PostgresIndex
|
||||
}
|
||||
|
||||
func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) {
|
||||
if schema.MapEvent == nil {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required")
|
||||
}
|
||||
if len(schema.Tables) == 0 {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: at least one table is required")
|
||||
}
|
||||
|
||||
compiled := postgresSchemaCompiled{
|
||||
tableOrder: make([]string, 0, len(schema.Tables)),
|
||||
tables: make(map[string]postgresTableCompiled, len(schema.Tables)),
|
||||
mapEvent: schema.MapEvent,
|
||||
}
|
||||
|
||||
seenTables := map[string]bool{}
|
||||
seenIndexes := map[string]bool{}
|
||||
|
||||
for i, tbl := range schema.Tables {
|
||||
tableName := strings.TrimSpace(tbl.Name)
|
||||
if tableName == "" {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: tables[%d].name is required", i)
|
||||
}
|
||||
if seenTables[tableName] {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate table name %q", tableName)
|
||||
}
|
||||
seenTables[tableName] = true
|
||||
|
||||
if len(tbl.Columns) == 0 {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q must define at least one column", tableName)
|
||||
}
|
||||
|
||||
colOrder := make([]string, 0, len(tbl.Columns))
|
||||
colMap := make(map[string]PostgresColumn, len(tbl.Columns))
|
||||
for j, col := range tbl.Columns {
|
||||
colName := strings.TrimSpace(col.Name)
|
||||
if colName == "" {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q columns[%d].name is required", tableName, j)
|
||||
}
|
||||
if _, exists := colMap[colName]; exists {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q duplicate column %q", tableName, colName)
|
||||
}
|
||||
if strings.TrimSpace(col.Type) == "" {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q column %q type is required", tableName, colName)
|
||||
}
|
||||
colOrder = append(colOrder, colName)
|
||||
colMap[colName] = PostgresColumn{
|
||||
Name: colName,
|
||||
Type: strings.TrimSpace(col.Type),
|
||||
Nullable: col.Nullable,
|
||||
}
|
||||
}
|
||||
|
||||
pk, err := validatePostgresColumnSet(tableName, "primary key", tbl.PrimaryKey, colMap)
|
||||
if err != nil {
|
||||
return postgresSchemaCompiled{}, err
|
||||
}
|
||||
|
||||
pruneCol := strings.TrimSpace(tbl.PruneColumn)
|
||||
if pruneCol == "" {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column is required", tableName)
|
||||
}
|
||||
if _, ok := colMap[pruneCol]; !ok {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column %q not found in columns", tableName, pruneCol)
|
||||
}
|
||||
|
||||
indexes := make([]PostgresIndex, 0, len(tbl.Indexes))
|
||||
for j, idx := range tbl.Indexes {
|
||||
idxName := strings.TrimSpace(idx.Name)
|
||||
if idxName == "" {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q indexes[%d].name is required", tableName, j)
|
||||
}
|
||||
if len(idx.Columns) == 0 {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q index %q must include at least one column", tableName, idxName)
|
||||
}
|
||||
if seenIndexes[idxName] {
|
||||
return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate index name %q", idxName)
|
||||
}
|
||||
seenIndexes[idxName] = true
|
||||
|
||||
idxCols, err := validatePostgresColumnSet(tableName, fmt.Sprintf("index %q columns", idxName), idx.Columns, colMap)
|
||||
if err != nil {
|
||||
return postgresSchemaCompiled{}, err
|
||||
}
|
||||
|
||||
indexes = append(indexes, PostgresIndex{
|
||||
Name: idxName,
|
||||
Columns: idxCols,
|
||||
Unique: idx.Unique,
|
||||
})
|
||||
}
|
||||
|
||||
compiled.tableOrder = append(compiled.tableOrder, tableName)
|
||||
compiled.tables[tableName] = postgresTableCompiled{
|
||||
name: tableName,
|
||||
columns: colMap,
|
||||
columnOrder: colOrder,
|
||||
primaryKey: pk,
|
||||
pruneColumn: pruneCol,
|
||||
indexes: indexes,
|
||||
}
|
||||
}
|
||||
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func validatePostgresColumnSet(tableName, label string, cols []string, colMap map[string]PostgresColumn) ([]string, error) {
|
||||
if len(cols) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]string, 0, len(cols))
|
||||
seen := map[string]bool{}
|
||||
for _, c := range cols {
|
||||
name := strings.TrimSpace(c)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("postgres schema: table %q %s contains empty column name", tableName, label)
|
||||
}
|
||||
if seen[name] {
|
||||
return nil, fmt.Errorf("postgres schema: table %q %s contains duplicate column %q", tableName, label, name)
|
||||
}
|
||||
if _, ok := colMap[name]; !ok {
|
||||
return nil, fmt.Errorf("postgres schema: table %q %s references unknown column %q", tableName, label, name)
|
||||
}
|
||||
seen[name] = true
|
||||
out = append(out, name)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s postgresSchemaCompiled) validateWrite(w PostgresWrite) (postgresTableCompiled, error) {
|
||||
tableName := strings.TrimSpace(w.Table)
|
||||
if tableName == "" {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write table is required")
|
||||
}
|
||||
t, ok := s.tables[tableName]
|
||||
if !ok {
|
||||
return postgresTableCompiled{}, fmt.Errorf("table %q is not defined in postgres schema", tableName)
|
||||
}
|
||||
|
||||
if len(w.Values) == 0 {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include values", tableName)
|
||||
}
|
||||
|
||||
for k := range w.Values {
|
||||
if _, ok := t.columns[k]; !ok {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write for table %q includes unknown column %q", tableName, k)
|
||||
}
|
||||
}
|
||||
|
||||
if len(w.Values) != len(t.columnOrder) {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write for table %q must include all declared columns", tableName)
|
||||
}
|
||||
|
||||
for _, col := range t.columnOrder {
|
||||
v, ok := w.Values[col]
|
||||
if !ok {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write for table %q is missing column %q", tableName, col)
|
||||
}
|
||||
if v == nil {
|
||||
if c := t.columns[col]; !c.Nullable {
|
||||
return postgresTableCompiled{}, fmt.Errorf("write for table %q has nil value for non-null column %q", tableName, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
748
sinks/postgres_test.go
Normal file
748
sinks/postgres_test.go
Normal file
@@ -0,0 +1,748 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type fakeResult struct {
|
||||
rows int64
|
||||
}
|
||||
|
||||
func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") }
|
||||
func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil }
|
||||
|
||||
type execCall struct {
|
||||
query string
|
||||
args []any
|
||||
}
|
||||
|
||||
type fakeTx struct {
|
||||
execCalls []execCall
|
||||
execErr error
|
||||
execErrOnCall int
|
||||
execRows int64
|
||||
commitErr error
|
||||
rollbackErr error
|
||||
commitCalls int
|
||||
rollbackCalls int
|
||||
}
|
||||
|
||||
func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
|
||||
t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)})
|
||||
if t.execErr != nil && t.execErrOnCall == len(t.execCalls) {
|
||||
return nil, t.execErr
|
||||
}
|
||||
return fakeResult{rows: t.execRows}, nil
|
||||
}
|
||||
|
||||
func (t *fakeTx) Commit() error {
|
||||
t.commitCalls++
|
||||
return t.commitErr
|
||||
}
|
||||
|
||||
func (t *fakeTx) Rollback() error {
|
||||
t.rollbackCalls++
|
||||
return t.rollbackErr
|
||||
}
|
||||
|
||||
type fakeDB struct {
|
||||
pingErr error
|
||||
beginErr error
|
||||
execErr error
|
||||
execErrOnCall int
|
||||
execRows int64
|
||||
pingCalls int
|
||||
beginCalls int
|
||||
execCalls []execCall
|
||||
closeCalls int
|
||||
tx *fakeTx
|
||||
}
|
||||
|
||||
func (d *fakeDB) PingContext(_ context.Context) error {
|
||||
d.pingCalls++
|
||||
return d.pingErr
|
||||
}
|
||||
|
||||
func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) {
|
||||
d.beginCalls++
|
||||
if d.beginErr != nil {
|
||||
return nil, d.beginErr
|
||||
}
|
||||
if d.tx == nil {
|
||||
d.tx = &fakeTx{}
|
||||
}
|
||||
return d.tx, nil
|
||||
}
|
||||
|
||||
func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
|
||||
d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)})
|
||||
if d.execErr != nil && d.execErrOnCall == len(d.execCalls) {
|
||||
return nil, d.execErr
|
||||
}
|
||||
return fakeResult{rows: d.execRows}, nil
|
||||
}
|
||||
|
||||
func (d *fakeDB) Close() error {
|
||||
d.closeCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func withPostgresTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldOpen := openPostgresDB
|
||||
t.Cleanup(func() {
|
||||
openPostgresDB = oldOpen
|
||||
})
|
||||
}
|
||||
|
||||
func validTestEvent() event.Event {
|
||||
now := time.Now().UTC()
|
||||
return event.Event{
|
||||
ID: "evt-1",
|
||||
Kind: event.Kind("observation"),
|
||||
Source: "source-1",
|
||||
EmittedAt: now,
|
||||
Payload: map[string]any{
|
||||
"x": 1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
|
||||
return PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "event_id", Type: "TEXT", Nullable: false},
|
||||
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
{Name: "payload_json", Type: "JSONB", Nullable: false},
|
||||
},
|
||||
PrimaryKey: []string{"event_id"},
|
||||
PruneColumn: "emitted_at",
|
||||
Indexes: []PostgresIndex{
|
||||
{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
MapEvent: mapFn,
|
||||
}
|
||||
}
|
||||
|
||||
func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
|
||||
return PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "event_id", Type: "TEXT", Nullable: false},
|
||||
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
},
|
||||
PrimaryKey: []string{"event_id"},
|
||||
PruneColumn: "emitted_at",
|
||||
},
|
||||
{
|
||||
Name: "event_payloads",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "event_id", Type: "TEXT", Nullable: false},
|
||||
{Name: "payload_json", Type: "JSONB", Nullable: false},
|
||||
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
},
|
||||
PrimaryKey: []string{"event_id"},
|
||||
PruneColumn: "emitted_at",
|
||||
},
|
||||
},
|
||||
MapEvent: mapFn,
|
||||
}
|
||||
}
|
||||
|
||||
func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
|
||||
t.Helper()
|
||||
compiled, err := compilePostgresSchema(s)
|
||||
if err != nil {
|
||||
t.Fatalf("compile schema: %v", err)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
|
||||
_, err := compilePostgresSchema(PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "id", Type: "TEXT", Nullable: false},
|
||||
},
|
||||
PruneColumn: "missing_col",
|
||||
},
|
||||
},
|
||||
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid schema error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prune column") {
|
||||
t.Fatalf("unexpected schema validation error: %v", err)
|
||||
}
|
||||
|
||||
_, err = compilePostgresSchema(PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "id", Type: "TEXT", Nullable: false},
|
||||
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
|
||||
},
|
||||
PruneColumn: "emitted_at",
|
||||
Indexes: []PostgresIndex{
|
||||
{Name: "idx_events_empty", Columns: nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid index schema error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "at least one column") {
|
||||
t.Fatalf("unexpected index validation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
dbs := []*fakeDB{{}, {}}
|
||||
var gotCfgs []pgconn.ConnConfig
|
||||
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
gotCfgs = append(gotCfgs, cfg)
|
||||
db := dbs[len(gotCfgs)-1]
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
return nil, nil
|
||||
}))
|
||||
|
||||
for _, name := range []string{"pg_a", "pg_b"} {
|
||||
sink, err := factory(config.SinkConfig{
|
||||
Name: name,
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("factory(%q) error = %v", name, err)
|
||||
}
|
||||
if sink == nil {
|
||||
t.Fatalf("factory(%q) returned nil sink", name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(gotCfgs) != 2 {
|
||||
t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
|
||||
}
|
||||
if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" {
|
||||
t.Fatalf("first ConnConfig = %+v", gotCfgs[0])
|
||||
}
|
||||
for i, db := range dbs {
|
||||
if db.pingCalls != 1 {
|
||||
t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]any
|
||||
want string
|
||||
}{
|
||||
{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
|
||||
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
|
||||
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: tc.params,
|
||||
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("expected %q in error, got: %v", tc.want, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
},
|
||||
}, PostgresSchema{
|
||||
Tables: []PostgresTable{
|
||||
{
|
||||
Name: "events",
|
||||
Columns: []PostgresColumn{
|
||||
{Name: "id", Type: "TEXT", Nullable: false},
|
||||
},
|
||||
PruneColumn: "missing_col",
|
||||
},
|
||||
},
|
||||
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid schema error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "compile schema") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
db := &fakeDB{}
|
||||
var gotCfg pgconn.ConnConfig
|
||||
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
gotCfg = cfg
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://db.example.local:5432/feedkit?sslmode=disable",
|
||||
"username": "app_user",
|
||||
"password": "app_pass",
|
||||
},
|
||||
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||
if err != nil {
|
||||
t.Fatalf("new postgres sink: %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatalf("expected sink")
|
||||
}
|
||||
|
||||
if db.pingCalls != 1 {
|
||||
t.Fatalf("expected one ping, got %d", db.pingCalls)
|
||||
}
|
||||
if len(db.execCalls) != 2 {
|
||||
t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls))
|
||||
}
|
||||
if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) {
|
||||
t.Fatalf("unexpected create table query: %s", db.execCalls[0].query)
|
||||
}
|
||||
if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
|
||||
t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
|
||||
}
|
||||
|
||||
if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
|
||||
t.Fatalf("URI = %q", gotCfg.URI)
|
||||
}
|
||||
if gotCfg.Username != "app_user" {
|
||||
t.Fatalf("Username = %q, want app_user", gotCfg.Username)
|
||||
}
|
||||
if gotCfg.Password != "app_pass" {
|
||||
t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
|
||||
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
},
|
||||
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||
if err == nil {
|
||||
t.Fatalf("expected init error")
|
||||
}
|
||||
if db.closeCalls != 1 {
|
||||
t.Fatalf("expected db close on init failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want time.Duration
|
||||
}{
|
||||
{name: "go duration", in: "72h", want: 72 * time.Hour},
|
||||
{name: "days suffix", in: "3d", want: 72 * time.Hour},
|
||||
{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
|
||||
return &fakeDB{}, nil
|
||||
}
|
||||
|
||||
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"prune": tc.in,
|
||||
},
|
||||
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||
if err != nil {
|
||||
t.Fatalf("new postgres sink: %v", err)
|
||||
}
|
||||
|
||||
pg, ok := s.(*PostgresSink)
|
||||
if !ok {
|
||||
t.Fatalf("expected *PostgresSink, got %T", s)
|
||||
}
|
||||
if pg.pruneWindow != tc.want {
|
||||
t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
}{
|
||||
{name: "empty", in: ""},
|
||||
{name: "zero", in: "0"},
|
||||
{name: "negative", in: "-1h"},
|
||||
{name: "malformed", in: "abc"},
|
||||
{name: "fractional day", in: "1.5d"},
|
||||
{name: "wrong type", in: 5},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||
Name: "pg",
|
||||
Driver: "postgres",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"prune": tc.in,
|
||||
},
|
||||
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.prune") {
|
||||
t.Fatalf("expected params.prune error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
|
||||
db := &fakeDB{}
|
||||
called := 0
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
called++
|
||||
return nil, nil
|
||||
})),
|
||||
}
|
||||
|
||||
err := sink.Consume(context.Background(), event.Event{})
|
||||
if err == nil {
|
||||
t.Fatalf("expected invalid event error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid event") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if called != 0 {
|
||||
t.Fatalf("expected mapper not called for invalid events")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
|
||||
db := &fakeDB{}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
return nil, nil
|
||||
})),
|
||||
}
|
||||
|
||||
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||
t.Fatalf("consume: %v", err)
|
||||
}
|
||||
if db.beginCalls != 0 {
|
||||
t.Fatalf("expected no transaction for unmapped events")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
|
||||
tx := &fakeTx{}
|
||||
db := &fakeDB{tx: tx}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||
return []PostgresWrite{
|
||||
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||
}, nil
|
||||
})),
|
||||
}
|
||||
|
||||
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||
t.Fatalf("consume: %v", err)
|
||||
}
|
||||
if db.beginCalls != 1 {
|
||||
t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
|
||||
}
|
||||
if len(tx.execCalls) != 2 {
|
||||
t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
|
||||
}
|
||||
if tx.commitCalls != 1 {
|
||||
t.Fatalf("expected one commit, got %d", tx.commitCalls)
|
||||
}
|
||||
if tx.rollbackCalls != 0 {
|
||||
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
|
||||
tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
|
||||
db := &fakeDB{tx: tx}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||
return []PostgresWrite{
|
||||
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||
}, nil
|
||||
})),
|
||||
}
|
||||
|
||||
err := sink.Consume(context.Background(), validTestEvent())
|
||||
if err == nil {
|
||||
t.Fatalf("expected insert error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "insert into") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tx.commitCalls != 0 {
|
||||
t.Fatalf("expected no commit")
|
||||
}
|
||||
if tx.rollbackCalls != 1 {
|
||||
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
|
||||
tx := &fakeTx{}
|
||||
db := &fakeDB{tx: tx}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||
return []PostgresWrite{
|
||||
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||
}, nil
|
||||
})),
|
||||
pruneWindow: 24 * time.Hour,
|
||||
}
|
||||
|
||||
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||
t.Fatalf("consume: %v", err)
|
||||
}
|
||||
if len(tx.execCalls) != 4 {
|
||||
t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
|
||||
}
|
||||
if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
|
||||
t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
|
||||
}
|
||||
if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
|
||||
t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
|
||||
}
|
||||
if tx.commitCalls != 1 {
|
||||
t.Fatalf("expected one commit, got %d", tx.commitCalls)
|
||||
}
|
||||
if tx.rollbackCalls != 0 {
|
||||
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
|
||||
tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
|
||||
db := &fakeDB{tx: tx}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||
return []PostgresWrite{
|
||||
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||
}, nil
|
||||
})),
|
||||
pruneWindow: 24 * time.Hour,
|
||||
}
|
||||
|
||||
err := sink.Consume(context.Background(), validTestEvent())
|
||||
if err == nil {
|
||||
t.Fatalf("expected prune error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prune older than") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if tx.commitCalls != 0 {
|
||||
t.Fatalf("expected no commit")
|
||||
}
|
||||
if tx.rollbackCalls != 1 {
|
||||
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkPrunePerTable(t *testing.T) {
|
||||
db := &fakeDB{execRows: 7}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
return nil, nil
|
||||
})),
|
||||
}
|
||||
|
||||
rows, err := sink.PruneKeepLatest(context.Background(), "events", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("prune keep latest: %v", err)
|
||||
}
|
||||
if rows != 7 {
|
||||
t.Fatalf("unexpected rows affected: %d", rows)
|
||||
}
|
||||
if len(db.execCalls) != 1 {
|
||||
t.Fatalf("expected one prune query")
|
||||
}
|
||||
if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) {
|
||||
t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query)
|
||||
}
|
||||
if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 {
|
||||
t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args)
|
||||
}
|
||||
|
||||
cutoff := time.Now().UTC().Add(-24 * time.Hour)
|
||||
rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("prune older than: %v", err)
|
||||
}
|
||||
if rows != 7 {
|
||||
t.Fatalf("unexpected rows affected: %d", rows)
|
||||
}
|
||||
if len(db.execCalls) != 2 {
|
||||
t.Fatalf("expected two prune queries")
|
||||
}
|
||||
if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) {
|
||||
t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkPruneAllTables(t *testing.T) {
|
||||
db := &fakeDB{execRows: 3}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
return nil, nil
|
||||
})),
|
||||
}
|
||||
|
||||
keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
|
||||
if err != nil {
|
||||
t.Fatalf("prune all keep latest: %v", err)
|
||||
}
|
||||
if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
|
||||
t.Fatalf("unexpected keep counts: %#v", keepCounts)
|
||||
}
|
||||
|
||||
db.execCalls = nil
|
||||
olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("prune all older than: %v", err)
|
||||
}
|
||||
if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
|
||||
t.Fatalf("unexpected older-than counts: %#v", olderCounts)
|
||||
}
|
||||
if len(db.execCalls) != 2 {
|
||||
t.Fatalf("expected one prune call per table")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSinkPruneErrors(t *testing.T) {
|
||||
db := &fakeDB{}
|
||||
sink := &PostgresSink{
|
||||
name: "pg",
|
||||
db: db,
|
||||
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||
return nil, nil
|
||||
})),
|
||||
}
|
||||
|
||||
if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
|
||||
t.Fatalf("expected negative keep error")
|
||||
}
|
||||
if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
|
||||
t.Fatalf("expected unknown table error")
|
||||
}
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type RabbitMQSink struct {
|
||||
name string
|
||||
url string
|
||||
exchange string
|
||||
}
|
||||
|
||||
func NewRabbitMQSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
|
||||
url, err := requireStringParam(cfg, "url")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ex, err := requireStringParam(cfg, "exchange")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RabbitMQSink{name: cfg.Name, url: url, exchange: ex}, nil
|
||||
}
|
||||
|
||||
func (r *RabbitMQSink) Name() string { return r.name }
|
||||
|
||||
func (r *RabbitMQSink) Consume(ctx context.Context, e event.Event) error {
|
||||
_ = ctx
|
||||
|
||||
// Boundary validation: if something upstream violated invariants,
|
||||
// surface it loudly rather than printing partial nonsense.
|
||||
if err := e.Validate(); err != nil {
|
||||
return fmt.Errorf("rabbitmq sink: invalid event: %w", err)
|
||||
}
|
||||
|
||||
// TODO implement RabbitMQ publishing
|
||||
return nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package sinks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
@@ -21,13 +22,40 @@ func NewRegistry() *Registry {
|
||||
}
|
||||
|
||||
func (r *Registry) Register(driver string, f Factory) {
|
||||
if r == nil {
|
||||
panic("sinks.Registry.Register: registry cannot be nil")
|
||||
}
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "" {
|
||||
panic("sinks.Registry.Register: driver cannot be empty")
|
||||
}
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("sinks.Registry.Register: factory cannot be nil (driver=%q)", driver))
|
||||
}
|
||||
if r.byDriver == nil {
|
||||
r.byDriver = map[string]Factory{}
|
||||
}
|
||||
if _, exists := r.byDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("sinks.Registry.Register: driver %q already registered", driver))
|
||||
}
|
||||
r.byDriver[driver] = f
|
||||
}
|
||||
|
||||
func (r *Registry) Build(cfg config.SinkConfig) (Sink, error) {
|
||||
f, ok := r.byDriver[cfg.Driver]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown sink driver: %q", cfg.Driver)
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("sinks registry is nil")
|
||||
}
|
||||
return f(cfg)
|
||||
driver := strings.TrimSpace(cfg.Driver)
|
||||
f, ok := r.byDriver[driver]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown sink driver: %q", driver)
|
||||
}
|
||||
sink, err := f(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build sink %q: %w", driver, err)
|
||||
}
|
||||
if sink == nil {
|
||||
return nil, fmt.Errorf("build sink %q: factory returned nil sink", driver)
|
||||
}
|
||||
return sink, nil
|
||||
}
|
||||
|
||||
126
sinks/registry_test.go
Normal file
126
sinks/registry_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package sinks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type testSink struct{ name string }
|
||||
|
||||
func (s testSink) Name() string { return s.name }
|
||||
|
||||
func (s testSink) Consume(context.Context, event.Event) error { return nil }
|
||||
|
||||
func TestRegistryRegisterPanicsOnNilRegistry(t *testing.T) {
|
||||
var r *Registry
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("Register() expected panic on nil registry")
|
||||
}
|
||||
}()
|
||||
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
|
||||
}
|
||||
|
||||
func TestRegistryRegisterPanicsOnEmptyDriver(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("Register() expected panic on empty driver")
|
||||
}
|
||||
}()
|
||||
r.Register(" ", func(config.SinkConfig) (Sink, error) { return testSink{name: "x"}, nil })
|
||||
}
|
||||
|
||||
func TestRegistryRegisterPanicsOnNilFactory(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("Register() expected panic on nil factory")
|
||||
}
|
||||
}()
|
||||
r.Register("stdout", nil)
|
||||
}
|
||||
|
||||
func TestRegistryRegisterPanicsOnDuplicateDriver(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "a"}, nil })
|
||||
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatalf("Register() expected panic on duplicate driver")
|
||||
}
|
||||
}()
|
||||
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "b"}, nil })
|
||||
}
|
||||
|
||||
func TestRegistryBuildNilRegistryFails(t *testing.T) {
|
||||
var r *Registry
|
||||
_, err := r.Build(config.SinkConfig{Driver: "stdout"})
|
||||
if err == nil {
|
||||
t.Fatalf("Build() expected error for nil registry")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "registry is nil") {
|
||||
t.Fatalf("Build() error = %q, want registry is nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildTrimsDriver(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
|
||||
|
||||
sink, err := r.Build(config.SinkConfig{Name: "sink1", Driver: " stdout "})
|
||||
if err != nil {
|
||||
t.Fatalf("Build() error = %v", err)
|
||||
}
|
||||
if sink.Name() != "stdout" {
|
||||
t.Fatalf("Build() sink name = %q, want stdout", sink.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildWrapsFactoryError(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("broken", func(config.SinkConfig) (Sink, error) { return nil, errors.New("boom") })
|
||||
|
||||
_, err := r.Build(config.SinkConfig{Driver: "broken"})
|
||||
if err == nil {
|
||||
t.Fatalf("Build() expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `build sink "broken": boom`) {
|
||||
t.Fatalf("Build() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildRejectsNilSink(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register("nil_sink", func(config.SinkConfig) (Sink, error) { return nil, nil })
|
||||
|
||||
_, err := r.Build(config.SinkConfig{Driver: "nil_sink"})
|
||||
if err == nil {
|
||||
t.Fatalf("Build() expected nil sink error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "factory returned nil sink") {
|
||||
t.Fatalf("Build() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
RegisterBuiltins(r)
|
||||
|
||||
if len(r.byDriver) != 2 {
|
||||
t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver))
|
||||
}
|
||||
for _, driver := range []string{"stdout", "nats"} {
|
||||
if _, ok := r.byDriver[driver]; !ok {
|
||||
t.Fatalf("builtins missing driver %q", driver)
|
||||
}
|
||||
}
|
||||
if _, ok := r.byDriver["postgres"]; ok {
|
||||
t.Fatalf("builtins unexpectedly registered postgres driver")
|
||||
}
|
||||
}
|
||||
52
sources/doc.go
Normal file
52
sources/doc.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Package sources defines feedkit's input-source abstractions and source
|
||||
// registry.
|
||||
//
|
||||
// External API surface:
|
||||
// - Input: common source identity surface
|
||||
// - PollSource: polling source interface
|
||||
// - StreamSource: streaming source interface
|
||||
// - StreamRetryable / StreamFatal / IsStreamRetryable / IsStreamFatal:
|
||||
// stream exit classification helpers
|
||||
// - Registry / NewRegistry: source driver registry and builders
|
||||
// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
|
||||
// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling
|
||||
// helper
|
||||
//
|
||||
// Source drivers are domain-specific and registered into Registry by driver name.
|
||||
// Registry can then build configured sources from config.SourceConfig.
|
||||
//
|
||||
// A single source may emit 0..N events per poll or stream iteration, and those
|
||||
// events may span multiple event kinds.
|
||||
//
|
||||
// Optional helpers from helpers.go:
|
||||
// - DefaultEventID: default event ID policy for source implementations
|
||||
// - SingleEvent: construct and validate a one-element event slice
|
||||
// - ValidateExpectedKinds: compare configured expected kinds against source
|
||||
// advertised kinds when metadata is available
|
||||
//
|
||||
// HTTP-backed polling sources can share NewHTTPSource for generic HTTP config
|
||||
// parsing and conditional GET behavior. The helper understands:
|
||||
// - params.url
|
||||
// - params.user_agent
|
||||
// - params.conditional (optional, default true)
|
||||
// - params.http_timeout (optional, default transport.DefaultHTTPTimeout)
|
||||
// - params.http_response_body_limit_bytes (optional, default
|
||||
// transport.DefaultHTTPResponseBodyLimitBytes)
|
||||
//
|
||||
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
|
||||
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
|
||||
// treated as a successful unchanged poll.
|
||||
//
|
||||
// Postgres-backed polling sources can share NewPostgresQuerySource for generic
|
||||
// DB config parsing and query execution. The helper understands:
|
||||
// - params.uri
|
||||
// - params.username
|
||||
// - params.password
|
||||
// - params.query
|
||||
// - params.query_timeout (optional, default 30s)
|
||||
//
|
||||
// feedkit does not register a built-in postgres poll driver. Downstream daemons
|
||||
// should register domain-specific driver names that call
|
||||
// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering,
|
||||
// watermark policy, and event construction in their own source types.
|
||||
package sources
|
||||
140
sources/helpers.go
Normal file
140
sources/helpers.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// DefaultEventID applies feedkit's default Event.ID policy:
|
||||
//
|
||||
// - If upstream provides an ID, use it (trimmed).
|
||||
// - Otherwise, ID is "<Source>:<EffectiveAt>" when available.
|
||||
// - If EffectiveAt is unavailable, fall back to "<Source>:<EmittedAt>".
|
||||
//
|
||||
// Timestamps are encoded as RFC3339Nano in UTC.
|
||||
func DefaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
|
||||
if id := strings.TrimSpace(upstreamID); id != "" {
|
||||
return id
|
||||
}
|
||||
|
||||
src := strings.TrimSpace(sourceName)
|
||||
if src == "" {
|
||||
src = "UNKNOWN_SOURCE"
|
||||
}
|
||||
|
||||
if effectiveAt != nil && !effectiveAt.IsZero() {
|
||||
return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
t := emittedAt.UTC()
|
||||
if t.IsZero() {
|
||||
t = time.Now().UTC()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
|
||||
}
|
||||
|
||||
// SingleEvent constructs, validates, and returns a slice containing exactly one event.
|
||||
func SingleEvent(
|
||||
kind event.Kind,
|
||||
sourceName string,
|
||||
schema string,
|
||||
id string,
|
||||
emittedAt time.Time,
|
||||
effectiveAt *time.Time,
|
||||
payload any,
|
||||
) ([]event.Event, error) {
|
||||
if emittedAt.IsZero() {
|
||||
emittedAt = time.Now().UTC()
|
||||
} else {
|
||||
emittedAt = emittedAt.UTC()
|
||||
}
|
||||
|
||||
e := event.Event{
|
||||
ID: id,
|
||||
Kind: kind,
|
||||
Source: sourceName,
|
||||
EmittedAt: emittedAt,
|
||||
EffectiveAt: effectiveAt,
|
||||
Schema: schema,
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
if err := e.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []event.Event{e}, nil
|
||||
}
|
||||
|
||||
// ValidateExpectedKinds checks that configured source expected kinds are a subset
|
||||
// of the kinds advertised by the built source, when the source exposes kind
|
||||
// metadata. If the source does not advertise kinds, the check is skipped.
|
||||
func ValidateExpectedKinds(cfg config.SourceConfig, in Input) error {
|
||||
expectedKinds, err := parseExpectedKinds(cfg.ExpectedKinds())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(expectedKinds) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
advertisedKinds := advertisedSourceKinds(in)
|
||||
if len(advertisedKinds) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for kind := range expectedKinds {
|
||||
if !advertisedKinds[kind] {
|
||||
return fmt.Errorf(
|
||||
"configured expected kind %q not advertised by source (configured=%v advertised=%v)",
|
||||
kind,
|
||||
sortedKinds(expectedKinds),
|
||||
sortedKinds(advertisedKinds),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseExpectedKinds(raw []string) (map[event.Kind]bool, error) {
|
||||
kinds := map[event.Kind]bool{}
|
||||
for i, k := range raw {
|
||||
kind, err := event.ParseKind(k)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid expected kind at index %d (%q): %w", i, k, err)
|
||||
}
|
||||
kinds[kind] = true
|
||||
}
|
||||
return kinds, nil
|
||||
}
|
||||
|
||||
func advertisedSourceKinds(in Input) map[event.Kind]bool {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
kinds := map[event.Kind]bool{}
|
||||
if ks, ok := in.(KindsSource); ok {
|
||||
for _, kind := range ks.Kinds() {
|
||||
kinds[kind] = true
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortedKinds(kindSet map[event.Kind]bool) []string {
|
||||
out := make([]string, 0, len(kindSet))
|
||||
for kind := range kindSet {
|
||||
out = append(out, string(kind))
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
112
sources/helpers_test.go
Normal file
112
sources/helpers_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type testInput struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (s testInput) Name() string { return s.name }
|
||||
|
||||
type testKindsSource struct {
|
||||
testInput
|
||||
kinds []event.Kind
|
||||
}
|
||||
|
||||
func (s testKindsSource) Kinds() []event.Kind { return s.kinds }
|
||||
|
||||
func TestValidateExpectedKindsSubsetAllowed(t *testing.T) {
|
||||
cfg := config.SourceConfig{Kinds: []string{"observation"}}
|
||||
in := testKindsSource{
|
||||
testInput: testInput{name: "test"},
|
||||
kinds: []event.Kind{"observation", "forecast"},
|
||||
}
|
||||
|
||||
if err := ValidateExpectedKinds(cfg, in); err != nil {
|
||||
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExpectedKindsMismatchFails(t *testing.T) {
|
||||
cfg := config.SourceConfig{Kinds: []string{"alert"}}
|
||||
in := testKindsSource{
|
||||
testInput: testInput{name: "test"},
|
||||
kinds: []event.Kind{"observation", "forecast"},
|
||||
}
|
||||
|
||||
err := ValidateExpectedKinds(cfg, in)
|
||||
if err == nil {
|
||||
t.Fatalf("ValidateExpectedKinds() expected mismatch error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "configured expected kind") {
|
||||
t.Fatalf("ValidateExpectedKinds() error %q does not include expected message", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExpectedKindsNoMetadataSkipsCheck(t *testing.T) {
|
||||
cfg := config.SourceConfig{Kinds: []string{"alert"}}
|
||||
in := testInput{name: "test"}
|
||||
|
||||
if err := ValidateExpectedKinds(cfg, in); err != nil {
|
||||
t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultEventIDUsesUpstreamID(t *testing.T) {
|
||||
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
|
||||
got := DefaultEventID(" upstream-id ", "source", nil, emittedAt)
|
||||
if got != "upstream-id" {
|
||||
t.Fatalf("DefaultEventID() = %q, want upstream-id", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultEventIDPrefersEffectiveAt(t *testing.T) {
|
||||
effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 987654321, time.FixedZone("x", -6*3600))
|
||||
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
|
||||
|
||||
got := DefaultEventID("", "source", &effectiveAt, emittedAt)
|
||||
want := "source:" + effectiveAt.UTC().Format(time.RFC3339Nano)
|
||||
if got != want {
|
||||
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultEventIDFallsBackToEmittedAt(t *testing.T) {
|
||||
emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123456789, time.FixedZone("y", 3*3600))
|
||||
got := DefaultEventID("", "source", nil, emittedAt)
|
||||
want := "source:" + emittedAt.UTC().Format(time.RFC3339Nano)
|
||||
if got != want {
|
||||
t.Fatalf("DefaultEventID() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleEventBuildsValidatedSlice(t *testing.T) {
|
||||
effectiveAt := time.Date(2026, 3, 28, 16, 0, 0, 0, time.UTC)
|
||||
emittedAt := time.Date(2026, 3, 28, 15, 0, 0, 0, time.FixedZone("z", -5*3600))
|
||||
|
||||
got, err := SingleEvent(
|
||||
event.Kind("observation"),
|
||||
"source-a",
|
||||
"raw.example.v1",
|
||||
"evt-1",
|
||||
emittedAt,
|
||||
&effectiveAt,
|
||||
map[string]any{"ok": true},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SingleEvent() unexpected error: %v", err)
|
||||
}
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("SingleEvent() len = %d, want 1", len(got))
|
||||
}
|
||||
if got[0].EmittedAt != emittedAt.UTC() {
|
||||
t.Fatalf("SingleEvent() emittedAt = %s, want %s", got[0].EmittedAt, emittedAt.UTC())
|
||||
}
|
||||
}
|
||||
152
sources/http.go
Normal file
152
sources/http.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/transport"
|
||||
)
|
||||
|
||||
// HTTPSource is a reusable helper for polling HTTP-backed sources.
|
||||
//
|
||||
// It centralizes generic source config parsing (`params.url`,
|
||||
// `params.user_agent`, and optional `params.conditional`), default HTTP client
|
||||
// setup, and conditional GET validator handling. Concrete daemon sources remain
|
||||
// responsible for decoding the response body and constructing events.
|
||||
type HTTPSource struct {
|
||||
Driver string
|
||||
Name string
|
||||
URL string
|
||||
UserAgent string
|
||||
Accept string
|
||||
Conditional bool
|
||||
ResponseBodyLimitBytes int64
|
||||
Client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
validators transport.HTTPValidators
|
||||
}
|
||||
|
||||
// NewHTTPSource builds a generic HTTP polling helper from SourceConfig.
|
||||
//
|
||||
// Required params:
|
||||
// - params.url
|
||||
// - params.user_agent
|
||||
//
|
||||
// Optional params:
|
||||
// - params.conditional (default true): enable conditional GET using cached
|
||||
// ETag / Last-Modified validators
|
||||
func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTPSource, error) {
|
||||
name := strings.TrimSpace(cfg.Name)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%s: name is required", driver)
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
return nil, fmt.Errorf("%s %q: params are required (need params.url and params.user_agent)", driver, cfg.Name)
|
||||
}
|
||||
|
||||
url, ok := cfg.ParamString("url", "URL")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.url is required", driver, cfg.Name)
|
||||
}
|
||||
|
||||
userAgent, ok := cfg.ParamString("user_agent", "userAgent")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.user_agent is required", driver, cfg.Name)
|
||||
}
|
||||
|
||||
conditional := true
|
||||
if _, exists := cfg.Params["conditional"]; exists {
|
||||
var ok bool
|
||||
conditional, ok = cfg.ParamBool("conditional")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
timeout := transport.DefaultHTTPTimeout
|
||||
if _, exists := cfg.Params["http_timeout"]; exists {
|
||||
var ok bool
|
||||
timeout, ok = cfg.ParamDuration("http_timeout")
|
||||
if !ok || timeout <= 0 {
|
||||
return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes
|
||||
if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists {
|
||||
rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes")
|
||||
if !ok || rawLimit <= 0 {
|
||||
return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name)
|
||||
}
|
||||
bodyLimit = int64(rawLimit)
|
||||
}
|
||||
|
||||
return &HTTPSource{
|
||||
Driver: driver,
|
||||
Name: name,
|
||||
URL: url,
|
||||
UserAgent: userAgent,
|
||||
Accept: accept,
|
||||
Conditional: conditional,
|
||||
ResponseBodyLimitBytes: bodyLimit,
|
||||
Client: transport.NewHTTPClient(timeout),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetchBytesIfChanged fetches the configured URL and reports whether the
|
||||
// upstream content changed. An unchanged 304 response returns changed=false
|
||||
// with no body and no error.
|
||||
func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
|
||||
client := s.Client
|
||||
if client == nil {
|
||||
client = transport.NewHTTPClient(transport.DefaultHTTPTimeout)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
validators := s.validators
|
||||
s.mu.Unlock()
|
||||
|
||||
bodyLimit := s.ResponseBodyLimitBytes
|
||||
if bodyLimit <= 0 {
|
||||
bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes
|
||||
}
|
||||
|
||||
body, changed, next, err := transport.FetchBodyIfChangedWithLimit(
|
||||
ctx,
|
||||
client,
|
||||
s.URL,
|
||||
s.UserAgent,
|
||||
s.Accept,
|
||||
s.Conditional,
|
||||
validators,
|
||||
bodyLimit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err)
|
||||
}
|
||||
|
||||
if s.Conditional {
|
||||
s.mu.Lock()
|
||||
s.validators = next
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
return body, changed, nil
|
||||
}
|
||||
|
||||
// FetchJSONIfChanged fetches the configured URL and returns the raw response
|
||||
// body as json.RawMessage when content changed. An unchanged 304 response
|
||||
// returns changed=false with a nil body and no error.
|
||||
func (s *HTTPSource) FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error) {
|
||||
body, changed, err := s.FetchBytesIfChanged(ctx)
|
||||
if err != nil || !changed {
|
||||
return nil, changed, err
|
||||
}
|
||||
return json.RawMessage(body), true, nil
|
||||
}
|
||||
260
sources/http_test.go
Normal file
260
sources/http_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/transport"
|
||||
)
|
||||
|
||||
func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if !src.Conditional {
|
||||
t.Fatalf("Conditional = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceRejectsInvalidConditional(t *testing.T) {
|
||||
_, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"conditional": "sometimes",
|
||||
},
|
||||
}, "application/json")
|
||||
if err == nil {
|
||||
t.Fatalf("NewHTTPSource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.conditional must be a boolean") {
|
||||
t.Fatalf("NewHTTPSource() error = %q, want params.conditional must be a boolean", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"conditional": false,
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if src.Conditional {
|
||||
t.Fatalf("Conditional = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"http_timeout": "250ms",
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if src.Client == nil {
|
||||
t.Fatalf("Client = nil")
|
||||
}
|
||||
if src.Client.Timeout != 250*time.Millisecond {
|
||||
t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceBodyLimitOverride(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"http_response_body_limit_bytes": 12345,
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if src.ResponseBodyLimitBytes != 12345 {
|
||||
t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) {
|
||||
_, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"http_timeout": "soon",
|
||||
},
|
||||
}, "application/json")
|
||||
if err == nil {
|
||||
t.Fatalf("NewHTTPSource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") {
|
||||
t.Fatalf("NewHTTPSource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) {
|
||||
_, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
"http_response_body_limit_bytes": "abc",
|
||||
},
|
||||
}, "application/json")
|
||||
if err == nil {
|
||||
t.Fatalf("NewHTTPSource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") {
|
||||
t.Fatalf("NewHTTPSource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
|
||||
var call int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
switch call {
|
||||
case 1:
|
||||
w.Header().Set("ETag", `"v1"`)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
case 2:
|
||||
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
|
||||
t.Fatalf("second request If-None-Match = %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
default:
|
||||
t.Fatalf("unexpected call count %d", call)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": srv.URL,
|
||||
"user_agent": "test-agent",
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
|
||||
raw, changed, err := src.FetchJSONIfChanged(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("first FetchJSONIfChanged() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatalf("first FetchJSONIfChanged() changed = false, want true")
|
||||
}
|
||||
if got := string(raw); got != `{"ok":true}` {
|
||||
t.Fatalf("first FetchJSONIfChanged() body = %q", got)
|
||||
}
|
||||
|
||||
raw, changed, err = src.FetchJSONIfChanged(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("second FetchJSONIfChanged() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Fatalf("second FetchJSONIfChanged() changed = true, want false")
|
||||
}
|
||||
if raw != nil {
|
||||
t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": srv.URL,
|
||||
"user_agent": "test-agent",
|
||||
"http_response_body_limit_bytes": 4,
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
|
||||
_, _, err = src.FetchJSONIfChanged(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("FetchJSONIfChanged() error = nil, want limit error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "response body too large") {
|
||||
t.Fatalf("FetchJSONIfChanged() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes {
|
||||
t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) {
|
||||
src, err := NewHTTPSource("test_driver", config.SourceConfig{
|
||||
Name: "test-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"url": "https://example.invalid",
|
||||
"user_agent": "test-agent",
|
||||
},
|
||||
}, "application/json")
|
||||
if err != nil {
|
||||
t.Fatalf("NewHTTPSource() error = %v", err)
|
||||
}
|
||||
if src.Client == nil {
|
||||
t.Fatalf("Client = nil")
|
||||
}
|
||||
if src.Client.Timeout != transport.DefaultHTTPTimeout {
|
||||
t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout)
|
||||
}
|
||||
}
|
||||
117
sources/postgres.go
Normal file
117
sources/postgres.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
const defaultPostgresQueryTimeout = 30 * time.Second
|
||||
|
||||
type postgresQueryDB interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return pgconn.Open(ctx, cfg)
|
||||
}
|
||||
|
||||
// PostgresQuerySource is a reusable helper for polling Postgres-backed sources.
|
||||
//
|
||||
// It centralizes generic source config parsing and query execution. Concrete
|
||||
// daemon sources remain responsible for SQL semantics, row scanning, cursoring,
|
||||
// and event construction.
|
||||
type PostgresQuerySource struct {
|
||||
Driver string
|
||||
Name string
|
||||
SQL string
|
||||
QueryTimeout time.Duration
|
||||
|
||||
db postgresQueryDB
|
||||
}
|
||||
|
||||
// NewPostgresQuerySource builds a generic Postgres polling helper from
|
||||
// SourceConfig.
|
||||
//
|
||||
// Required params:
|
||||
// - params.uri
|
||||
// - params.username
|
||||
// - params.password
|
||||
// - params.query
|
||||
//
|
||||
// Optional params:
|
||||
// - params.query_timeout (default 30s)
|
||||
func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) {
|
||||
name := strings.TrimSpace(cfg.Name)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%s: name is required", driver)
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name)
|
||||
}
|
||||
|
||||
uri, ok := cfg.ParamString("uri")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name)
|
||||
}
|
||||
username, ok := cfg.ParamString("username")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name)
|
||||
}
|
||||
password, ok := cfg.ParamString("password")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name)
|
||||
}
|
||||
query, ok := cfg.ParamString("query")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name)
|
||||
}
|
||||
|
||||
queryTimeout := defaultPostgresQueryTimeout
|
||||
if _, exists := cfg.Params["query_timeout"]; exists {
|
||||
var ok bool
|
||||
queryTimeout, ok = cfg.ParamDuration("query_timeout")
|
||||
if !ok || queryTimeout <= 0 {
|
||||
return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{
|
||||
URI: uri,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err)
|
||||
}
|
||||
|
||||
return &PostgresQuerySource{
|
||||
Driver: driver,
|
||||
Name: name,
|
||||
SQL: query,
|
||||
QueryTimeout: queryTimeout,
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) {
|
||||
queryCtx := ctx
|
||||
if s.QueryTimeout > 0 {
|
||||
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout {
|
||||
// We intentionally do not cancel this derived context here because the
|
||||
// returned rows may still be reading from the database.
|
||||
queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(queryCtx, s.SQL, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
352
sources/postgres_test.go
Normal file
352
sources/postgres_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type fakePostgresQueryDB struct {
|
||||
queryErr error
|
||||
lastCtx context.Context
|
||||
lastQuery string
|
||||
lastArgs []any
|
||||
returnRows *sql.Rows
|
||||
}
|
||||
|
||||
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
db.lastCtx = ctx
|
||||
db.lastQuery = query
|
||||
db.lastArgs = append([]any(nil), args...)
|
||||
if db.queryErr != nil {
|
||||
return nil, db.queryErr
|
||||
}
|
||||
return db.returnRows, nil
|
||||
}
|
||||
|
||||
func withPostgresQuerySourceTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldOpen := openPostgresQueryDB
|
||||
t.Cleanup(func() {
|
||||
openPostgresQueryDB = oldOpen
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]any
|
||||
want string
|
||||
}{
|
||||
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
|
||||
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
|
||||
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
|
||||
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: tc.params,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
"query_timeout": "soon",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db := &fakePostgresQueryDB{}
|
||||
var gotCfg pgconn.ConnConfig
|
||||
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
gotCfg = cfg
|
||||
return db, nil
|
||||
}
|
||||
|
||||
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://db.example.local/feedkit",
|
||||
"username": "app_user",
|
||||
"password": "app_pass",
|
||||
"query": "SELECT * FROM observations",
|
||||
"query_timeout": "45s",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
if src.Name != "pg-source" {
|
||||
t.Fatalf("Name = %q, want pg-source", src.Name)
|
||||
}
|
||||
if src.QueryTimeout != 45*time.Second {
|
||||
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
|
||||
}
|
||||
if src.SQL != "SELECT * FROM observations" {
|
||||
t.Fatalf("SQL = %q", src.SQL)
|
||||
}
|
||||
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
|
||||
t.Fatalf("ConnConfig = %+v", gotCfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return nil, errors.New("db unavailable")
|
||||
}
|
||||
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := src.Query(ctx, "arg1")
|
||||
if err == nil {
|
||||
t.Fatalf("Query() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
|
||||
t.Fatalf("Query() error = %q", err)
|
||||
}
|
||||
if db.lastCtx == nil {
|
||||
t.Fatalf("lastCtx = nil")
|
||||
}
|
||||
if _, ok := db.lastCtx.Deadline(); !ok {
|
||||
t.Fatalf("expected derived deadline on query context")
|
||||
}
|
||||
if db.lastQuery != "SELECT 1" {
|
||||
t.Fatalf("lastQuery = %q", db.lastQuery)
|
||||
}
|
||||
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
|
||||
t.Fatalf("lastArgs = %#v", db.lastArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _ = src.Query(ctx)
|
||||
if db.lastCtx != ctx {
|
||||
t.Fatalf("expected source to reuse earlier caller deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
|
||||
defer cleanup()
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
type fakeDownstreamSource struct {
|
||||
pg *PostgresQuerySource
|
||||
}
|
||||
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
|
||||
rows, err := s.pg.Query(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []event.Event
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
if err := rows.Scan(&eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, event.Event{
|
||||
ID: eventID,
|
||||
Kind: event.Kind("observation"),
|
||||
Source: s.pg.Name,
|
||||
Schema: "raw.test.v1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"event_id": eventID},
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT event_id FROM events",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
|
||||
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("poll() error = %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("len(events) = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].ID != "evt-1" {
|
||||
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
rowsDriverMu sync.Mutex
|
||||
rowsDriverSeen = map[string]bool{}
|
||||
)
|
||||
|
||||
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
|
||||
t.Helper()
|
||||
|
||||
rowsDriverMu.Lock()
|
||||
if !rowsDriverSeen[driverName] {
|
||||
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
|
||||
rowsDriverSeen[driverName] = true
|
||||
}
|
||||
rowsDriverMu.Unlock()
|
||||
|
||||
db, err := sql.Open(driverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open() error = %v", err)
|
||||
}
|
||||
|
||||
return db, func() {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
|
||||
out := make([][]driver.Value, 0, len(in))
|
||||
for _, row := range in {
|
||||
copied := append([]driver.Value(nil), row...)
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type rowsTestDriver struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
|
||||
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestConn struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (c *rowsTestConn) Close() error { return nil }
|
||||
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
idx int
|
||||
}
|
||||
|
||||
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
|
||||
func (r *rowsTestRows) Close() error { return nil }
|
||||
|
||||
func (r *rowsTestRows) Next(dest []driver.Value) error {
|
||||
if r.idx >= len(r.rows) {
|
||||
return io.EOF
|
||||
}
|
||||
copy(dest, r.rows[r.idx])
|
||||
r.idx++
|
||||
return nil
|
||||
}
|
||||
@@ -7,43 +7,115 @@ import (
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
)
|
||||
|
||||
// Factory constructs a configured Source instance from config.
|
||||
// PollFactory constructs a configured PollSource instance from config.
|
||||
//
|
||||
// This is how concrete daemons (weatherfeeder/newsfeeder/...) register their
|
||||
// domain-specific source drivers (Open-Meteo, NWS, RSS, etc.) while feedkit
|
||||
// remains domain-agnostic.
|
||||
type Factory func(cfg config.SourceConfig) (Source, error)
|
||||
type PollFactory func(cfg config.SourceConfig) (PollSource, error)
|
||||
type StreamFactory func(cfg config.SourceConfig) (StreamSource, error)
|
||||
|
||||
type Registry struct {
|
||||
byDriver map[string]Factory
|
||||
byPollDriver map[string]PollFactory
|
||||
byStreamDriver map[string]StreamFactory
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{byDriver: map[string]Factory{}}
|
||||
return &Registry{
|
||||
byPollDriver: map[string]PollFactory{},
|
||||
byStreamDriver: map[string]StreamFactory{},
|
||||
}
|
||||
}
|
||||
|
||||
// Register associates a driver name (e.g. "openmeteo_observation") with a factory.
|
||||
//
|
||||
// The driver string is the "lookup key" used by config.sources[].driver.
|
||||
func (r *Registry) Register(driver string, f Factory) {
|
||||
// RegisterPoll associates a driver name with a polling-source factory.
|
||||
func (r *Registry) RegisterPoll(driver string, f PollFactory) {
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "" {
|
||||
// Panic is appropriate here: registering an empty driver is always a programmer error,
|
||||
// and it will lead to extremely confusing runtime behavior if allowed.
|
||||
panic("sources.Registry.Register: driver cannot be empty")
|
||||
panic("sources.Registry.RegisterPoll: driver cannot be empty")
|
||||
}
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("sources.Registry.Register: factory cannot be nil (driver=%q)", driver))
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterPoll: factory cannot be nil (driver=%q)", driver))
|
||||
}
|
||||
|
||||
r.byDriver[driver] = f
|
||||
if _, exists := r.byStreamDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterPoll: driver %q already registered as a stream source", driver))
|
||||
}
|
||||
if _, exists := r.byPollDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterPoll: driver %q already registered as a polling source", driver))
|
||||
}
|
||||
r.byPollDriver[driver] = f
|
||||
}
|
||||
|
||||
// Build constructs a Source from a SourceConfig by looking up cfg.Driver.
|
||||
func (r *Registry) Build(cfg config.SourceConfig) (Source, error) {
|
||||
f, ok := r.byDriver[cfg.Driver]
|
||||
// RegisterStream is the StreamSource equivalent of Register.
|
||||
func (r *Registry) RegisterStream(driver string, f StreamFactory) {
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "" {
|
||||
panic("sources.Registry.RegisterStream: driver cannot be empty")
|
||||
}
|
||||
if f == nil {
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterStream: factory cannot be nil (driver=%q)", driver))
|
||||
}
|
||||
if _, exists := r.byPollDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterStream: driver %q already registered as a polling source", driver))
|
||||
}
|
||||
if _, exists := r.byStreamDriver[driver]; exists {
|
||||
panic(fmt.Sprintf("sources.Registry.RegisterStream: driver %q already registered as a stream source", driver))
|
||||
}
|
||||
r.byStreamDriver[driver] = f
|
||||
}
|
||||
|
||||
// BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver.
|
||||
func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) {
|
||||
driver := strings.TrimSpace(cfg.Driver)
|
||||
if cfg.Mode.Normalize() == config.SourceModeStream {
|
||||
return nil, fmt.Errorf("source %q mode=stream cannot be built as polling source", cfg.Name)
|
||||
}
|
||||
|
||||
f, ok := r.byPollDriver[driver]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown source driver: %q", cfg.Driver)
|
||||
if _, streamExists := r.byStreamDriver[driver]; streamExists {
|
||||
return nil, fmt.Errorf("source driver %q is stream-only; cannot build as polling source", driver)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown source driver: %q", driver)
|
||||
}
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
// BuildInput can return either a polling Source or a StreamSource.
|
||||
func (r *Registry) BuildInput(cfg config.SourceConfig) (Input, error) {
|
||||
driver := strings.TrimSpace(cfg.Driver)
|
||||
mode := cfg.Mode.Normalize()
|
||||
if mode != config.SourceModeAuto && mode != config.SourceModePoll && mode != config.SourceModeStream {
|
||||
return nil, fmt.Errorf("source %q has invalid mode %q (expected \"poll\" or \"stream\")", cfg.Name, cfg.Mode)
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case config.SourceModePoll:
|
||||
f, ok := r.byPollDriver[driver]
|
||||
if !ok {
|
||||
if _, streamExists := r.byStreamDriver[driver]; streamExists {
|
||||
return nil, fmt.Errorf("source %q mode=poll conflicts with stream-only driver %q", cfg.Name, driver)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown source driver: %q", driver)
|
||||
}
|
||||
return f(cfg)
|
||||
case config.SourceModeStream:
|
||||
f, ok := r.byStreamDriver[driver]
|
||||
if !ok {
|
||||
if _, pollExists := r.byPollDriver[driver]; pollExists {
|
||||
return nil, fmt.Errorf("source %q mode=stream conflicts with polling driver %q", cfg.Name, driver)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown source driver: %q", driver)
|
||||
}
|
||||
return f(cfg)
|
||||
}
|
||||
|
||||
if f, ok := r.byStreamDriver[driver]; ok {
|
||||
return f(cfg)
|
||||
}
|
||||
if f, ok := r.byPollDriver[driver]; ok {
|
||||
return f(cfg)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown source driver: %q", driver)
|
||||
}
|
||||
|
||||
84
sources/registry_test.go
Normal file
84
sources/registry_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
type testPollSource struct{ name string }
|
||||
|
||||
func (s testPollSource) Name() string { return s.name }
|
||||
func (s testPollSource) Poll(context.Context) ([]event.Event, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type testStreamSource struct{ name string }
|
||||
|
||||
func (s testStreamSource) Name() string { return s.name }
|
||||
func (s testStreamSource) Run(context.Context, chan<- event.Event) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRegistryBuildInputModeConflicts(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.RegisterPoll("poll_driver", func(cfg config.SourceConfig) (PollSource, error) {
|
||||
return testPollSource{name: cfg.Name}, nil
|
||||
})
|
||||
r.RegisterStream("stream_driver", func(cfg config.SourceConfig) (StreamSource, error) {
|
||||
return testStreamSource{name: cfg.Name}, nil
|
||||
})
|
||||
|
||||
_, err := r.BuildInput(config.SourceConfig{
|
||||
Name: "s1",
|
||||
Driver: "stream_driver",
|
||||
Mode: config.SourceModePoll,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected mode conflict error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "mode=poll") {
|
||||
t.Fatalf("expected poll conflict error, got: %v", err)
|
||||
}
|
||||
|
||||
_, err = r.BuildInput(config.SourceConfig{
|
||||
Name: "s2",
|
||||
Driver: "poll_driver",
|
||||
Mode: config.SourceModeStream,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected mode conflict error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "mode=stream") {
|
||||
t.Fatalf("expected stream conflict error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryBuildInputAutoByDriverType(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.RegisterPoll("poll_driver", func(cfg config.SourceConfig) (PollSource, error) {
|
||||
return testPollSource{name: cfg.Name}, nil
|
||||
})
|
||||
r.RegisterStream("stream_driver", func(cfg config.SourceConfig) (StreamSource, error) {
|
||||
return testStreamSource{name: cfg.Name}, nil
|
||||
})
|
||||
|
||||
src, err := r.BuildInput(config.SourceConfig{Name: "p", Driver: "poll_driver"})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildInput poll auto failed: %v", err)
|
||||
}
|
||||
if _, ok := src.(PollSource); !ok {
|
||||
t.Fatalf("expected PollSource, got %T", src)
|
||||
}
|
||||
|
||||
src, err = r.BuildInput(config.SourceConfig{Name: "s", Driver: "stream_driver"})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildInput stream auto failed: %v", err)
|
||||
}
|
||||
if _, ok := src.(StreamSource); !ok {
|
||||
t.Fatalf("expected StreamSource, got %T", src)
|
||||
}
|
||||
}
|
||||
@@ -6,25 +6,44 @@ import (
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
)
|
||||
|
||||
// Source is a configured polling job that emits 0..N events per poll.
|
||||
// Input is the common surface shared by all source types.
|
||||
//
|
||||
// Source implementations live in domain modules (weatherfeeder/newsfeeder/...)
|
||||
// A source may be polling (PollSource) or event-driven (StreamSource).
|
||||
// Both source types emit domain-agnostic event.Event values.
|
||||
type Input interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// PollSource is a configured polling source that emits 0..N events per poll.
|
||||
//
|
||||
// PollSource implementations live in domain modules (weatherfeeder/newsfeeder/...)
|
||||
// and are registered into a feedkit sources.Registry.
|
||||
//
|
||||
// feedkit infrastructure treats Source as opaque; it just calls Poll()
|
||||
// feedkit infrastructure treats PollSource as opaque; it just calls Poll()
|
||||
// on the configured cadence and publishes the resulting events.
|
||||
type Source interface {
|
||||
type PollSource interface {
|
||||
// Name is the configured source name (used for logs and included in emitted events).
|
||||
Name() string
|
||||
|
||||
// Kind is the "primary kind" emitted by this source.
|
||||
//
|
||||
// This is mainly useful as a *safety check* (e.g. config says kind=forecast but
|
||||
// driver emits observation). Some future sources may emit multiple kinds; if/when
|
||||
// that happens, we can evolve this interface (e.g., make Kind optional, or remove it).
|
||||
Kind() event.Kind
|
||||
|
||||
// Poll fetches from upstream and returns 0..N events.
|
||||
// Poll fetches/processes one input batch and returns 0..N events.
|
||||
// A single poll can emit multiple event kinds.
|
||||
// Implementations should honor ctx.Done() for network calls and other I/O.
|
||||
Poll(ctx context.Context) ([]event.Event, error)
|
||||
}
|
||||
|
||||
// StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc).
|
||||
//
|
||||
// Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs.
|
||||
// It MUST NOT close out (the scheduler/daemon owns the bus).
|
||||
//
|
||||
// Stream sources can classify exits by wrapping errors with StreamRetryable or
|
||||
// StreamFatal. Plain non-nil errors are treated as retryable by the scheduler.
|
||||
type StreamSource interface {
|
||||
Input
|
||||
Run(ctx context.Context, out chan<- event.Event) error
|
||||
}
|
||||
|
||||
// KindsSource is an optional interface for sources that advertise multiple kinds.
|
||||
type KindsSource interface {
|
||||
Kinds() []event.Kind
|
||||
}
|
||||
|
||||
63
sources/stream_errors.go
Normal file
63
sources/stream_errors.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package sources
|
||||
|
||||
import "errors"
|
||||
|
||||
type streamRetryableError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *streamRetryableError) Error() string {
|
||||
if e.err == nil {
|
||||
return "retryable stream error"
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *streamRetryableError) Unwrap() error { return e.err }
|
||||
|
||||
type streamFatalError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *streamFatalError) Error() string {
|
||||
if e.err == nil {
|
||||
return "fatal stream error"
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *streamFatalError) Unwrap() error { return e.err }
|
||||
|
||||
// StreamRetryable marks a stream-source exit as retryable.
|
||||
func StreamRetryable(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &streamRetryableError{err: err}
|
||||
}
|
||||
|
||||
// StreamFatal marks a stream-source exit as fatal.
|
||||
func StreamFatal(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &streamFatalError{err: err}
|
||||
}
|
||||
|
||||
// IsStreamRetryable reports whether err contains a retryable stream marker.
|
||||
func IsStreamRetryable(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *streamRetryableError
|
||||
return errors.As(err, &target)
|
||||
}
|
||||
|
||||
// IsStreamFatal reports whether err contains a fatal stream marker.
|
||||
func IsStreamFatal(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *streamFatalError
|
||||
return errors.As(err, &target)
|
||||
}
|
||||
52
sources/stream_errors_test.go
Normal file
52
sources/stream_errors_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStreamRetryableWrapsThroughErrorChains(t *testing.T) {
|
||||
base := errors.New("retry me")
|
||||
err := fmt.Errorf("outer: %w", StreamRetryable(base))
|
||||
|
||||
if !IsStreamRetryable(err) {
|
||||
t.Fatalf("IsStreamRetryable() = false, want true")
|
||||
}
|
||||
if IsStreamFatal(err) {
|
||||
t.Fatalf("IsStreamFatal() = true, want false")
|
||||
}
|
||||
if !errors.Is(err, base) {
|
||||
t.Fatalf("errors.Is(err, base) = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamFatalWrapsThroughErrorChains(t *testing.T) {
|
||||
base := errors.New("fatal")
|
||||
err := fmt.Errorf("outer: %w", StreamFatal(base))
|
||||
|
||||
if !IsStreamFatal(err) {
|
||||
t.Fatalf("IsStreamFatal() = false, want true")
|
||||
}
|
||||
if IsStreamRetryable(err) {
|
||||
t.Fatalf("IsStreamRetryable() = true, want false")
|
||||
}
|
||||
if !errors.Is(err, base) {
|
||||
t.Fatalf("errors.Is(err, base) = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamErrorHelpersNil(t *testing.T) {
|
||||
if StreamRetryable(nil) != nil {
|
||||
t.Fatalf("StreamRetryable(nil) != nil")
|
||||
}
|
||||
if StreamFatal(nil) != nil {
|
||||
t.Fatalf("StreamFatal(nil) != nil")
|
||||
}
|
||||
if IsStreamRetryable(nil) {
|
||||
t.Fatalf("IsStreamRetryable(nil) = true")
|
||||
}
|
||||
if IsStreamFatal(nil) {
|
||||
t.Fatalf("IsStreamFatal(nil) = true")
|
||||
}
|
||||
}
|
||||
190
transport/http.go
Normal file
190
transport/http.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// FILE: ./transport/http.go
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies.
|
||||
// API responses should be small, so this protects us from accidental
|
||||
// or malicious large responses.
|
||||
const DefaultHTTPResponseBodyLimitBytes int64 = 2 << 21 // 4 MiB
|
||||
|
||||
// DefaultHTTPTimeout is the standard timeout used by HTTP sources.
|
||||
// Individual drivers may override this if they have a specific need.
|
||||
const DefaultHTTPTimeout = 10 * time.Second
|
||||
|
||||
// NewHTTPClient returns a simple http.Client configured with a timeout.
|
||||
// If timeout <= 0, DefaultHTTPTimeout is used.
|
||||
func NewHTTPClient(timeout time.Duration) *http.Client {
|
||||
if timeout <= 0 {
|
||||
timeout = DefaultHTTPTimeout
|
||||
}
|
||||
return &http.Client{Timeout: timeout}
|
||||
}
|
||||
|
||||
func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) {
|
||||
return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes)
|
||||
}
|
||||
|
||||
func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) {
|
||||
res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP %s", res.Status)
|
||||
}
|
||||
|
||||
return readValidatedBody(res.Body, bodyLimitBytes)
|
||||
}
|
||||
|
||||
// HTTPValidators are cache validators learned from prior successful GET responses.
|
||||
//
|
||||
// ETag is preferred when present. LastModified is used as a fallback validator
|
||||
// when ETag is unavailable.
|
||||
type HTTPValidators struct {
|
||||
ETag string
|
||||
LastModified string
|
||||
}
|
||||
|
||||
// FetchBodyIfChanged performs an HTTP GET and opportunistically uses conditional
|
||||
// request headers based on the provided validators.
|
||||
//
|
||||
// Behavior:
|
||||
// - if conditional is false, this behaves like a normal GET and leaves validators unchanged
|
||||
// - if validators.ETag is set, sends If-None-Match
|
||||
// - else if validators.LastModified is set, sends If-Modified-Since
|
||||
// - 304 Not Modified is treated as success with changed=false and no body
|
||||
// - 200 responses are treated as changed=true and still enforce the normal body checks
|
||||
//
|
||||
// Returned validators reflect any updates learned from the response headers.
|
||||
func FetchBodyIfChanged(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
url, userAgent, accept string,
|
||||
conditional bool,
|
||||
validators HTTPValidators,
|
||||
) ([]byte, bool, HTTPValidators, error) {
|
||||
return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes)
|
||||
}
|
||||
|
||||
func FetchBodyIfChangedWithLimit(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
url, userAgent, accept string,
|
||||
conditional bool,
|
||||
validators HTTPValidators,
|
||||
bodyLimitBytes int64,
|
||||
) ([]byte, bool, HTTPValidators, error) {
|
||||
headerName, headerValue := conditionalHeader(conditional, validators)
|
||||
|
||||
res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, headerName, headerValue)
|
||||
if err != nil {
|
||||
return nil, false, validators, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
switch res.StatusCode {
|
||||
case http.StatusNotModified:
|
||||
if conditional {
|
||||
validators = refreshValidators(validators, res.Header)
|
||||
}
|
||||
return nil, false, validators, nil
|
||||
default:
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return nil, false, validators, fmt.Errorf("HTTP %s", res.Status)
|
||||
}
|
||||
}
|
||||
|
||||
b, err := readValidatedBody(res.Body, bodyLimitBytes)
|
||||
if err != nil {
|
||||
return nil, false, validators, err
|
||||
}
|
||||
|
||||
if conditional {
|
||||
validators = replaceValidators(res.Header)
|
||||
}
|
||||
|
||||
return b, true, validators, nil
|
||||
}
|
||||
|
||||
func doRequest(ctx context.Context, client *http.Client, method, url, userAgent, accept, headerName, headerValue string) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if userAgent != "" {
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
}
|
||||
if accept != "" {
|
||||
req.Header.Set("Accept", accept)
|
||||
}
|
||||
if headerName != "" && headerValue != "" {
|
||||
req.Header.Set(headerName, headerValue)
|
||||
}
|
||||
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func conditionalHeader(enabled bool, validators HTTPValidators) (string, string) {
|
||||
if !enabled {
|
||||
return "", ""
|
||||
}
|
||||
if etag := strings.TrimSpace(validators.ETag); etag != "" {
|
||||
return "If-None-Match", etag
|
||||
}
|
||||
if lastModified := strings.TrimSpace(validators.LastModified); lastModified != "" {
|
||||
return "If-Modified-Since", lastModified
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func replaceValidators(header http.Header) HTTPValidators {
|
||||
return HTTPValidators{
|
||||
ETag: strings.TrimSpace(header.Get("ETag")),
|
||||
LastModified: strings.TrimSpace(header.Get("Last-Modified")),
|
||||
}
|
||||
}
|
||||
|
||||
func refreshValidators(current HTTPValidators, header http.Header) HTTPValidators {
|
||||
if etag := strings.TrimSpace(header.Get("ETag")); etag != "" {
|
||||
current.ETag = etag
|
||||
}
|
||||
if lastModified := strings.TrimSpace(header.Get("Last-Modified")); lastModified != "" {
|
||||
current.LastModified = lastModified
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) {
|
||||
if bodyLimitBytes <= 0 {
|
||||
bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes
|
||||
}
|
||||
|
||||
// Read at most bodyLimitBytes + 1 so we can detect overflow.
|
||||
limited := io.LimitReader(r, bodyLimitBytes+1)
|
||||
|
||||
b, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
return nil, fmt.Errorf("empty response body")
|
||||
}
|
||||
|
||||
if int64(len(b)) > bodyLimitBytes {
|
||||
return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
232
transport/http_test.go
Normal file
232
transport/http_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFetchBodyIfChangedPrefersETagAndTreats304AsUnchanged(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var call int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
switch call {
|
||||
case 1:
|
||||
if got := r.Header.Get("If-None-Match"); got != "" {
|
||||
t.Fatalf("first request If-None-Match = %q, want empty", got)
|
||||
}
|
||||
if got := r.Header.Get("If-Modified-Since"); got != "" {
|
||||
t.Fatalf("first request If-Modified-Since = %q, want empty", got)
|
||||
}
|
||||
w.Header().Set("ETag", `"v1"`)
|
||||
w.Header().Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT")
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
case 2:
|
||||
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
|
||||
t.Fatalf("second request If-None-Match = %q, want %q", got, `"v1"`)
|
||||
}
|
||||
if got := r.Header.Get("If-Modified-Since"); got != "" {
|
||||
t.Fatalf("second request If-Modified-Since = %q, want empty when ETag is cached", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
default:
|
||||
t.Fatalf("unexpected call count %d", call)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
validators := HTTPValidators{}
|
||||
body, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, validators)
|
||||
if err != nil {
|
||||
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
|
||||
}
|
||||
if got := string(body); got != `{"ok":true}` {
|
||||
t.Fatalf("first FetchBodyIfChanged() body = %q", got)
|
||||
}
|
||||
if got := next.ETag; got != `"v1"` {
|
||||
t.Fatalf("cached ETag = %q, want %q", got, `"v1"`)
|
||||
}
|
||||
if got := next.LastModified; got != "Mon, 02 Jan 2006 15:04:05 GMT" {
|
||||
t.Fatalf("cached Last-Modified = %q", got)
|
||||
}
|
||||
|
||||
body, changed, next, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, next)
|
||||
if err != nil {
|
||||
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
|
||||
}
|
||||
if body != nil {
|
||||
t.Fatalf("second FetchBodyIfChanged() body = %q, want nil", string(body))
|
||||
}
|
||||
if got := next.ETag; got != `"v1"` {
|
||||
t.Fatalf("cached ETag after 304 = %q, want %q", got, `"v1"`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchBodyIfChangedFallsBackToIfModifiedSince(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var call int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
switch call {
|
||||
case 1:
|
||||
w.Header().Set("Last-Modified", "Tue, 03 Jan 2006 15:04:05 GMT")
|
||||
_, _ = w.Write([]byte(`first`))
|
||||
case 2:
|
||||
if got := r.Header.Get("If-None-Match"); got != "" {
|
||||
t.Fatalf("second request If-None-Match = %q, want empty", got)
|
||||
}
|
||||
if got := r.Header.Get("If-Modified-Since"); got != "Tue, 03 Jan 2006 15:04:05 GMT" {
|
||||
t.Fatalf("second request If-Modified-Since = %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
default:
|
||||
t.Fatalf("unexpected call count %d", call)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
_, changed, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
|
||||
if err != nil {
|
||||
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
|
||||
}
|
||||
if got := validators.LastModified; got != "Tue, 03 Jan 2006 15:04:05 GMT" {
|
||||
t.Fatalf("cached Last-Modified = %q", got)
|
||||
}
|
||||
|
||||
_, changed, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
|
||||
if err != nil {
|
||||
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchBodyIfChangedClearsValidatorsOn200WithoutValidators(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var call int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
call++
|
||||
switch call {
|
||||
case 1:
|
||||
w.Header().Set("ETag", `"v1"`)
|
||||
_, _ = w.Write([]byte(`first`))
|
||||
case 2:
|
||||
if got := r.Header.Get("If-None-Match"); got != `"v1"` {
|
||||
t.Fatalf("second request If-None-Match = %q", got)
|
||||
}
|
||||
_, _ = w.Write([]byte(`second`))
|
||||
case 3:
|
||||
if got := r.Header.Get("If-None-Match"); got != "" {
|
||||
t.Fatalf("third request If-None-Match = %q, want empty", got)
|
||||
}
|
||||
if got := r.Header.Get("If-Modified-Since"); got != "" {
|
||||
t.Fatalf("third request If-Modified-Since = %q, want empty", got)
|
||||
}
|
||||
_, _ = w.Write([]byte(`third`))
|
||||
default:
|
||||
t.Fatalf("unexpected call count %d", call)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
_, _, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
|
||||
if err != nil {
|
||||
t.Fatalf("first FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
_, _, validators, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
|
||||
if err != nil {
|
||||
t.Fatalf("second FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if validators.ETag != "" || validators.LastModified != "" {
|
||||
t.Fatalf("validators after 200 without validators = %+v, want cleared", validators)
|
||||
}
|
||||
_, _, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
|
||||
if err != nil {
|
||||
t.Fatalf("third FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchBodyIfChangedConditionalDisabledSkipsConditionalHeaders(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var calls int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls++
|
||||
if got := r.Header.Get("If-None-Match"); got != "" {
|
||||
t.Fatalf("request If-None-Match = %q, want empty", got)
|
||||
}
|
||||
if got := r.Header.Get("If-Modified-Since"); got != "" {
|
||||
t.Fatalf("request If-Modified-Since = %q, want empty", got)
|
||||
}
|
||||
_, _ = w.Write([]byte(`body`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
validators := HTTPValidators{ETag: `"v1"`, LastModified: "Wed, 04 Jan 2006 15:04:05 GMT"}
|
||||
_, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", false, validators)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatalf("FetchBodyIfChanged() changed = false, want true")
|
||||
}
|
||||
if next != validators {
|
||||
t.Fatalf("validators changed when conditional disabled: got %+v want %+v", next, validators)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("calls = %d, want 1", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchBodyIfChangedAllowsEmpty304ButRejectsEmpty200(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
notModified := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
}))
|
||||
defer notModified.Close()
|
||||
|
||||
_, changed, _, err := FetchBodyIfChanged(
|
||||
context.Background(),
|
||||
notModified.Client(),
|
||||
notModified.URL,
|
||||
"test-agent",
|
||||
"",
|
||||
true,
|
||||
HTTPValidators{ETag: `"v1"`},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("304 FetchBodyIfChanged() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Fatalf("304 FetchBodyIfChanged() changed = true, want false")
|
||||
}
|
||||
|
||||
emptyBody := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer emptyBody.Close()
|
||||
|
||||
_, _, _, err = FetchBodyIfChanged(context.Background(), emptyBody.Client(), emptyBody.URL, "test-agent", "", true, HTTPValidators{})
|
||||
if err == nil {
|
||||
t.Fatalf("empty 200 FetchBodyIfChanged() error = nil, want error")
|
||||
}
|
||||
if err.Error() != "empty response body" {
|
||||
t.Fatalf("empty 200 FetchBodyIfChanged() error = %q", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user