Compare commits

7 Commits

| SHA1 |
|---|
| eb9a7cb349 |
| 3281368922 |
| 3ef93faf69 |
| 4910440756 |
| 3b92c2284d |
| 215afe1acf |
| 4572c53580 |

README.md (168 lines changed)
````diff
@@ -1,118 +1,92 @@
 # feedkit
 
-feedkit provides domain-agnostic plumbing for feed-processing daemons.
+`feedkit` is a small Go toolkit for building feed-processing daemons.
 
-A daemon built on feedkit typically:
-- ingests upstream input (polling APIs or consuming streams)
+It gives you the reusable plumbing around collection, processing, routing, and
+emission, while leaving domain concepts, schemas, and application wiring in
+your daemon. The intended shape is a family of sibling applications such as
+`weatherfeeder`, `newsfeeder`, or `earthquakefeeder` that all share the same
+infrastructure patterns without sharing domain logic.
+
+## What It Does
+
+A daemon built on `feedkit` typically:
+- ingests upstream input by polling HTTP APIs or consuming streams
 - emits domain-agnostic `event.Event` values
-- applies optional processing (normalization, dedupe, policy)
-- routes events to sinks (stdout, NATS, files, databases, etc.)
+- optionally processes those events with stages like dedupe or normalization
+- routes events to one or more sinks such as stdout, NATS, or Postgres
+
+Conceptually, the pipeline is:
+
+`Collect -> Process -> Route -> Emit`
 
 ## Philosophy
 
-feedkit is not a framework. It provides small composable packages and leaves
-lifecycle, domain schemas, and domain-specific validation in your daemon.
+`feedkit` is intentionally not a framework.
+
+It does not try to own:
+- your domain payload schemas
+- your domain event kinds
+- your daemon lifecycle or `main.go`
+- your observability stack or deployment model
 
-## Conceptual pipeline
+Instead, it provides small composable packages that are easy to wire together in
+different daemons.
 
-Collect -> Process (optional stages, including normalize) -> Route -> Emit
+## When To Use It
 
-| Stage | Package(s) |
-|---|---|
-| Collect | `sources`, `scheduler` |
-| Process | `pipeline`, `processors`, `normalize` (optional stage) |
-| Route | `dispatch` |
-| Emit | `sinks` |
-| Configure | `config` |
+`feedkit` is a good fit when you want:
+- multiple small ingestion daemons with shared infrastructure patterns
+- clear separation between raw upstream payloads and normalized canonical models
+- reusable routing and sink behavior across domains
+- strong config and event-envelope conventions without centralizing domain rules
 
-## Core packages
+It is a poor fit if you want a monolithic framework that dictates application
+structure end-to-end.
 
-### `config`
+## Built-In Capabilities
 
-Loads YAML config with strict decoding and domain-agnostic validation.
+`feedkit` currently includes:
 
-`SourceConfig` supports both source modes:
-- `mode: poll` requires `every`
-- `mode: stream` forbids `every`
-- omitted `mode` means auto (inferred from the registered driver type)
+- strict YAML config loading and validation
+- polling and streaming source abstractions
+- scheduler orchestration for configured sources
+- optional pipeline processors
+- built-in dedupe and normalization processors
+- route compilation and sink fanout
+- built-in sinks for `stdout`, `nats`, and `postgres`
 
-It also supports optional expected source kinds:
-- `kinds: ["observation", "alert"]` (preferred)
-- `kind: "observation"` (legacy fallback)
+The Postgres sink is intentionally split between feedkit-owned infrastructure
+and daemon-owned schema mapping. `feedkit` manages connection setup, DDL,
+writes, and pruning; downstream applications define the schema and event mapper.
 
-### `event`
+## Typical Wiring
 
-Defines the domain-agnostic event envelope (`event.Event`) used across the system.
+At a high level, a daemon built on `feedkit` does this:
 
-### `sources`
-
-Defines source interfaces and driver registry:
-
-```go
-type Input interface {
-	Name() string
-}
-
-type PollSource interface {
-	Input
-	Poll(ctx context.Context) ([]event.Event, error)
-}
-
-type StreamSource interface {
-	Input
-	Run(ctx context.Context, out chan<- event.Event) error
-}
-```
-
-Notes:
-- a poll can emit `0..N` events
-- stream sources emit events continuously
-- a single source may emit multiple event kinds
-- driver implementations live in downstream daemons and are registered via `sources.Registry`
-
-### `scheduler`
-
-Runs one goroutine per source job:
-- poll sources: cadence driven (`every` + jitter)
-- stream sources: continuous run loop
-
-### `pipeline`
-
-Optional processing chain between collection and dispatch.
-Processors can transform, drop, or reject events.
-
-### `processors`
-
-Defines the generic processor interface and a named-driver registry used by
-daemons to build ordered processor chains.
-
-### `normalize`
-
-Concrete normalization processor implementation. Typical use: sources emit raw
-payload events, then a normalize stage maps them to canonical schemas.
-
-### `dispatch`
-
-Compiles routes and fans out events to sinks with per-sink queue/worker isolation.
-
-### `sinks`
-
-Defines sink interface and sink registry. Built-ins include `stdout` and `nats`, with
-additional sink implementations at varying maturity.
-
-## Typical wiring
-
 1. Load config.
-2. Register/build sources from `cfg.Sources`.
-3. Register/build sinks from `cfg.Sinks`.
-4. Compile routes.
-5. Start scheduler (`sources -> bus`).
-6. Start dispatcher (`bus -> pipeline -> sinks`).
+2. Register domain-specific source drivers.
+3. Register built-in and/or custom sinks.
+4. Build sources, sinks, and optional processor chain from config.
+5. Compile routes.
+6. Start the scheduler and dispatcher.
 
-## Non-goals
+The package docs are the better source of truth for code-level details. In
+particular, each subpackage `doc.go` describes its external API surface and any
+optional helper APIs in `helpers.go`.
 
-feedkit intentionally does not:
-- define domain payload schemas
-- enforce domain-specific event kinds
-- own application lifecycle
-- prescribe observability stack choices
+## Package Layout
+
+The major packages are:
+- `config`: config loading and validation
+- `event`: the domain-agnostic event envelope
+- `sources`: source interfaces and reusable source helpers
+- `scheduler`: source execution and cadence management
+- `processors`: processor interfaces and registry
+- `processors/dedupe`: built-in in-memory dedupe processor
+- `processors/normalize`: built-in normalization processor and helpers
+- `pipeline`: optional processor chain
+- `dispatch`: route compilation and fanout
+- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers
+
+The root package docs in `doc.go` provide a concise package-by-package map for
+Go documentation consumers.
````
```diff
@@ -69,19 +69,13 @@ type SourceConfig struct {
 	// If set, it describes the expected emitted event kinds for this source.
 	Kinds []string `yaml:"kinds"`
 
-	// Kind is the legacy singular form. Prefer "kinds".
-	// If both kind and kinds are set, validation fails.
-	Kind string `yaml:"kind"`
-
 	// Params are driver-specific settings (URL, headers, station IDs, API keys, etc.).
 	// The driver implementation is responsible for reading/validating these.
 	Params map[string]any `yaml:"params"`
 }
 
 // ExpectedKinds returns normalized expected kinds from config.
-// "kinds" takes precedence; "kind" is used as a legacy fallback.
 func (cfg SourceConfig) ExpectedKinds() []string {
-	if len(cfg.Kinds) > 0 {
 	out := make([]string, 0, len(cfg.Kinds))
 	for _, k := range cfg.Kinds {
 		k = strings.TrimSpace(k)
@@ -90,18 +84,16 @@ func (cfg SourceConfig) ExpectedKinds() []string {
 		}
 		out = append(out, k)
 	}
-	return out
-	}
-	if k := strings.TrimSpace(cfg.Kind); k != "" {
-		return []string{k}
-	}
-	return nil
+	if len(out) == 0 {
+		return nil
+	}
+	return out
 }
 
 // SinkConfig describes one output sink adapter.
 type SinkConfig struct {
 	Name   string         `yaml:"name"`
-	Driver string         `yaml:"driver"` // "stdout", "file", "postgres", "rabbitmq", ...
+	Driver string         `yaml:"driver"` // "stdout", "nats", "postgres", ...
 	Params map[string]any `yaml:"params"` // sink-specific settings
 }
```
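The simplified `ExpectedKinds` above trims entries, drops blanks, and returns nil when nothing survives. A standalone sketch of that normalization logic (the function name and signature here are local stand-ins, not feedkit's API):

```go
package main

import (
	"fmt"
	"strings"
)

// expectedKinds mirrors the post-diff SourceConfig.ExpectedKinds behavior:
// trim each entry, skip blanks, and return nil when no entries survive.
func expectedKinds(kinds []string) []string {
	out := make([]string, 0, len(kinds))
	for _, k := range kinds {
		k = strings.TrimSpace(k)
		if k == "" {
			continue
		}
		out = append(out, k)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

func main() {
	fmt.Println(expectedKinds([]string{" observation ", "forecast"})) // [observation forecast]
	fmt.Println(expectedKinds([]string{"  "}) == nil)                 // true
}
```

Returning nil rather than an empty slice lets callers use `len(...) == 0` and nil checks interchangeably, matching the test expectations below.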
```diff
@@ -12,20 +12,12 @@ func TestSourceConfigExpectedKinds(t *testing.T) {
 		want []string
 	}{
 		{
-			name: "plural kinds preferred",
+			name: "plural kinds normalized",
 			cfg: SourceConfig{
 				Kinds: []string{" observation ", "forecast"},
-				Kind:  "alert",
 			},
 			want: []string{"observation", "forecast"},
 		},
-		{
-			name: "legacy singular fallback",
-			cfg: SourceConfig{
-				Kind: " alert ",
-			},
-			want: []string{"alert"},
-		},
 		{
 			name: "empty kinds",
 			cfg:  SourceConfig{},
```
```diff
@@ -105,13 +105,7 @@ func (c *Config) Validate() error {
 			}
 		}
 
-		// Kind/Kinds (optional)
-		if s.Kind != "" && len(s.Kinds) > 0 {
-			m.Add(fieldErr(path+".kind", `cannot be set when "kinds" is provided (use only "kinds")`))
-		}
-		if s.Kind != "" && strings.TrimSpace(s.Kind) == "" {
-			m.Add(fieldErr(path+".kind", "cannot be blank (omit it entirely, or provide a non-empty string)"))
-		}
+		// Kinds (optional)
 		for j, k := range s.Kinds {
 			kpath := fmt.Sprintf("%s.kinds[%d]", path, j)
 			if strings.TrimSpace(k) == "" {
@@ -141,7 +135,7 @@ func (c *Config) Validate() error {
 		}
 
 		if strings.TrimSpace(s.Driver) == "" {
-			m.Add(fieldErr(path+".driver", "is required (stdout|file|postgres|rabbitmq|...)"))
+			m.Add(fieldErr(path+".driver", "is required (stdout|nats|postgres|...)"))
 		}
 
 		// Params can be nil; that's fine.
```
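For orientation, a config that should pass the validation shown in this diff might look like the following sketch. Field names follow `SourceConfig`/`SinkConfig` as they appear here; the `openweather` driver name and all values are hypothetical stand-ins registered by a downstream daemon:

```yaml
sources:
  - name: weather-poll
    driver: openweather   # hypothetical daemon-registered driver
    mode: poll            # poll requires "every"; stream forbids it
    every: 5m
    kinds: ["observation"] # optional expected emitted kinds
    params:                # driver-specific; validated by the driver itself
      station_id: "KLGA"
sinks:
  - name: console
    driver: stdout        # built-in drivers include stdout, nats, postgres
```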
```diff
@@ -114,31 +114,6 @@ func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) {
 	}
 }
 
-func TestValidate_SourceKindAndKindsConflict(t *testing.T) {
-	cfg := &Config{
-		Sources: []SourceConfig{
-			{
-				Name:   "src1",
-				Driver: "driver1",
-				Every:  Duration{Duration: time.Minute},
-				Kind:   "observation",
-				Kinds:  []string{"forecast"},
-			},
-		},
-		Sinks: []SinkConfig{
-			{Name: "sink1", Driver: "stdout"},
-		},
-	}
-
-	err := cfg.Validate()
-	if err == nil {
-		t.Fatalf("expected error, got nil")
-	}
-	if !strings.Contains(err.Error(), `sources[0].kind`) {
-		t.Fatalf("expected error to mention sources[0].kind, got: %v", err)
-	}
-}
-
 func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) {
 	cfg := &Config{
 		Sources: []SourceConfig{
```
doc.go (122 lines changed)

```diff
@@ -1,110 +1,56 @@
-// Package feedkit provides domain-agnostic plumbing for feed-processing daemons.
+// Package feedkit provides a high-level map of the feedkit package set.
 //
-// A feed daemon ingests upstream input, turns it into event.Event values, applies
-// optional processing, and emits to sinks.
+// Most real applications do not import the root package directly. Instead, they
+// compose the subpackages that handle configuration, collection, processing,
+// routing, and sinks.
 //
-// Conceptual flow:
+// The usual flow through feedkit is:
 //
-//	Collect -> Process (optional stages, including normalize) -> Route -> Emit
+//	Collect -> Process -> Route -> Emit
 //
-// In feedkit this maps to:
-//
-//	Collect: sources + scheduler
-//	Process: pipeline + processors + normalize (optional stage)
-//	Route: dispatch
-//	Emit: sinks
-//	Config: config
-//
-// feedkit intentionally does not define domain payload schemas or domain-specific
-// validation rules. Those belong in each concrete daemon.
-//
-// Public packages
+// That flow maps to packages like this:
 //
 // - config
-//	YAML config loading/validation (strict decode + domain-agnostic checks).
-//
-//	SourceConfig supports both polling and streaming sources:
-//
-//	- mode: "poll" | "stream" | omitted (auto by driver type)
-//
-//	- every: poll interval (required for mode="poll")
-//
-//	- kinds: optional expected emitted kinds
-//
-//	- kind: legacy singular fallback
-//
-//	- params: driver-specific settings
+//	Loads and validates daemon config. This package owns domain-agnostic
+//	config shape and consistency checks.
 //
 // - event
-//	Domain-agnostic event envelope (ID, Kind, Source, EmittedAt, Schema, Payload).
+//	Defines the event.Event envelope shared across sources, processors,
+//	dispatch, and sinks.
 //
 // - sources
-//	Source abstractions and source-driver registry.
-//
-//	There are two source interfaces:
-//
-//	- PollSource: Poll(ctx) ([]event.Event, error)
-//
-//	- StreamSource: Run(ctx, out) error
-//
-//	Both share Input{Name()}. A source may emit 0..N events per poll/run step,
-//	and may emit multiple event kinds.
+//	Defines polling and streaming source interfaces, the source registry, and
+//	reusable source helpers.
 //
 // - scheduler
-//	Runs one goroutine per job:
-//
-//	- PollSource jobs run on Every (+ jitter)
-//
-//	- StreamSource jobs run continuously
-//
-// - pipeline
-//	Processor chain between scheduler and dispatch.
-//	Processors can transform, drop, or reject events.
+//	Runs configured sources on a cadence or as long-lived stream workers.
 //
 // - processors
-//	Generic processor interface and named factory registry for wiring chains.
+//	Defines the generic processor interface and registry used to build
+//	ordered processor chains.
 //
-// - normalize
-//	Concrete pipeline processor for raw->canonical mapping.
-//	If no normalizer matches, the event passes through unchanged by default.
+// - processors/dedupe
+//	Built-in in-memory dedupe processor keyed by Event.ID.
+//
+// - processors/normalize
+//	Built-in normalization processor plus helper APIs for raw-to-canonical
+//	event mapping.
+//
+// - pipeline
+//	Applies an ordered processor chain between collection and dispatch.
 //
 // - dispatch
-//	Routes events to sinks and isolates slow sinks via per-sink queues/workers.
+//	Compiles routes and fans events out to sinks with per-sink isolation.
 //
 // - sinks
-//	Sink abstractions + sink registry.
+//	Defines sink interfaces, the sink registry, schema-free built-in sinks,
+//	and explicit Postgres factory helpers.
 //
-// Typical wiring (daemon main.go)
+// feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds,
+// upstream-specific parsing, and daemon lifecycle remain the responsibility of
+// each concrete application.
 //
-//  1. Load config.
-//  2. Register source drivers and build sources from config.Sources.
-//  3. Register sink drivers and build sinks from config.Sinks.
-//  4. Compile routes.
-//  5. Start scheduler (sources -> bus) and dispatcher (bus -> pipeline -> sinks).
-//
-// Sketch:
-//
-//	cfg, _ := config.Load("config.yml")
-//	srcReg := sources.NewRegistry()
-//	// domain registers poll/stream drivers...
-//
-//	var jobs []scheduler.Job
-//	for _, sc := range cfg.Sources {
-//		src, _ := srcReg.BuildInput(sc)
-//		jobs = append(jobs, scheduler.Job{
-//			Source: src,
-//			Every:  sc.Every.Duration,
-//		})
-//	}
-//
-//	bus := make(chan event.Event, 256)
-//	s := &scheduler.Scheduler{Jobs: jobs, Out: bus, Logf: logf}
-//	// start dispatcher similarly...
-//
-// # Context and cancellation
-//
-// All blocking work should honor context cancellation:
-//   - source polling/streaming I/O
-//   - sink consumption
-//   - any expensive processor work
+// For repository-level overview and usage narrative, see README.md. For
+// code-level details, each subpackage doc.go is the source of truth for that
+// package's public API surface and optional helpers.
 package feedkit
```
go.mod (1 line changed)

```diff
@@ -3,6 +3,7 @@ module gitea.maximumdirect.net/ejr/feedkit
 go 1.22
 
 require (
+	github.com/lib/pq v1.10.9
 	github.com/nats-io/nats.go v1.34.0
 	gopkg.in/yaml.v3 v3.0.1
 )
```
go.sum (2 lines changed)

```diff
@@ -1,5 +1,7 @@
 github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
 github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/nats-io/nats.go v1.34.0 h1:fnxnPCNiwIG5w08rlMcEKTUw4AV/nKyGCOJE8TdhSPk=
 github.com/nats-io/nats.go v1.34.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
 github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
```
```diff
@@ -1,17 +0,0 @@
-// Package normalize provides a concrete normalization processor for feedkit pipelines.
-//
-// Motivation:
-// Many daemons have sources that:
-//  1. fetch raw upstream data (often JSON), and
-//  2. transform it into a domain's normalized payload format.
-//
-// Doing both steps inside Source.Poll works, but tends to make sources large and
-// encourages duplication (unit conversions, common mapping helpers, etc.).
-//
-// This package lets a source emit a "raw" event (e.g., Schema="raw.openweather.current.v1",
-// Payload=json.RawMessage), and then a normalize.Processor can convert it into a
-// normalized event (e.g., Schema="weather.observation.v1", Payload=WeatherObservation{}).
-//
-// Key property: normalization is optional.
-// If no Normalizer matches an event, Processor passes it through unchanged by default.
-package normalize
```
```diff
@@ -1,5 +0,0 @@
-package pipeline
-
-// Placeholder for dedupe processor:
-// - key by Event.ID or computed key
-// - in-memory store first; later optional Postgres-backed
```
```diff
@@ -1,5 +0,0 @@
-package pipeline
-
-// Placeholder for rate limit processor:
-// - per source/kind sink routing limits
-// - cooldown windows
```
processors/dedupe/doc.go (new file, 28 lines)

````go
// Package dedupe provides a default in-memory LRU deduplication processor.
//
// The processor keys strictly by event.Event.ID:
// - first-seen IDs pass through
// - repeated IDs are dropped
//
// The in-memory seen-ID set is bounded by a required maxEntries capacity.
// When capacity is exceeded, the least recently used ID is evicted.
//
// Typical registry wiring:
//
// ```go
// reg := processors.NewRegistry()
// reg.Register("dedupe", dedupe.Factory(10_000))
//
// reg.Register("normalize", func() (processors.Processor, error) {
//     return normalize.NewProcessor(myNormalizers, false), nil
// })
//
// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
//
// if err != nil {
//     // handle wiring error
// }
//
// p := &pipeline.Pipeline{Processors: chain}
// ```
package dedupe
````
processors/dedupe/processor.go (new file, 89 lines)

```go
package dedupe

import (
	"container/list"
	"context"
	"fmt"
	"strings"
	"sync"

	"gitea.maximumdirect.net/ejr/feedkit/event"
	"gitea.maximumdirect.net/ejr/feedkit/processors"
)

// Processor drops duplicate events by Event.ID using an in-memory LRU.
type Processor struct {
	maxEntries int

	mu    sync.Mutex
	order *list.List               // most-recent at front, least-recent at back
	byID  map[string]*list.Element // id -> list element (element.Value is string id)
}

var _ processors.Processor = (*Processor)(nil)

// NewProcessor constructs a dedupe processor with a required max entry count.
func NewProcessor(maxEntries int) (*Processor, error) {
	if maxEntries <= 0 {
		return nil, fmt.Errorf("dedupe: maxEntries must be > 0, got %d", maxEntries)
	}

	return &Processor{
		maxEntries: maxEntries,
		order:      list.New(),
		byID:       make(map[string]*list.Element, maxEntries),
	}, nil
}

// Factory returns a processors.Factory that constructs Processor instances.
func Factory(maxEntries int) processors.Factory {
	return func() (processors.Processor, error) {
		return NewProcessor(maxEntries)
	}
}

// Process implements processors.Processor.
func (p *Processor) Process(_ context.Context, in event.Event) (*event.Event, error) {
	if p == nil {
		return nil, fmt.Errorf("dedupe: processor is nil")
	}
	if p.maxEntries <= 0 {
		return nil, fmt.Errorf("dedupe: processor maxEntries must be > 0")
	}

	id := strings.TrimSpace(in.ID)
	if id == "" {
		return nil, fmt.Errorf("dedupe: event ID is required")
	}

	p.mu.Lock()

	if p.order == nil || p.byID == nil {
		p.mu.Unlock()
		return nil, fmt.Errorf("dedupe: processor is not initialized")
	}

	if elem, exists := p.byID[id]; exists {
		p.order.MoveToFront(elem)
		p.mu.Unlock()
		return nil, nil
	}

	elem := p.order.PushFront(id)
	p.byID[id] = elem

	if p.order.Len() > p.maxEntries {
		oldest := p.order.Back()
		if oldest != nil {
			p.order.Remove(oldest)
			if oldestID, ok := oldest.Value.(string); ok {
				delete(p.byID, oldestID)
			}
		}
	}

	p.mu.Unlock()

	out := in
	return &out, nil
}
```
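Locking and error handling aside, the heart of `processor.go` above is a bounded seen-set built on `container/list`. A standalone sketch of just that LRU behavior (the `lruSet` type and `Seen` method are illustrative names, not feedkit API):

```go
package main

import (
	"container/list"
	"fmt"
)

// lruSet is a bounded set of recently seen IDs. Seen reports whether id was
// already present; when it was not, the id is inserted and, if capacity is
// exceeded, the least recently used entry is evicted.
type lruSet struct {
	max   int
	order *list.List               // most-recent at front
	byID  map[string]*list.Element // id -> element holding that id
}

func newLRUSet(max int) *lruSet {
	return &lruSet{max: max, order: list.New(), byID: make(map[string]*list.Element)}
}

func (s *lruSet) Seen(id string) bool {
	if elem, ok := s.byID[id]; ok {
		s.order.MoveToFront(elem) // duplicate: refresh recency
		return true
	}
	s.byID[id] = s.order.PushFront(id)
	if s.order.Len() > s.max {
		oldest := s.order.Back()
		s.order.Remove(oldest)
		delete(s.byID, oldest.Value.(string))
	}
	return false
}

func main() {
	s := newLRUSet(2)
	fmt.Println(s.Seen("a"), s.Seen("b"), s.Seen("a")) // false false true
	fmt.Println(s.Seen("c"))                           // false; evicts "b"
	fmt.Println(s.Seen("b"))                           // false again: "b" was evicted
}
```

This mirrors the eviction-and-promotion sequence asserted in `TestProcessorLRUEvictionAndPromotion` below.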
processors/dedupe/processor_test.go (new file, 163 lines)

```go
package dedupe

import (
	"context"
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/event"
	"gitea.maximumdirect.net/ejr/feedkit/processors"
)

func TestNewProcessorValidation(t *testing.T) {
	t.Run("rejects non-positive maxEntries", func(t *testing.T) {
		for _, maxEntries := range []int{0, -1} {
			p, err := NewProcessor(maxEntries)
			if err == nil {
				t.Fatalf("expected error for maxEntries=%d, got nil", maxEntries)
			}
			if p != nil {
				t.Fatalf("expected nil processor for maxEntries=%d", maxEntries)
			}
			if !strings.Contains(err.Error(), "maxEntries") {
				t.Fatalf("unexpected error: %v", err)
			}
		}
	})

	t.Run("accepts positive maxEntries", func(t *testing.T) {
		p, err := NewProcessor(1)
		if err != nil {
			t.Fatalf("NewProcessor error: %v", err)
		}
		if p == nil {
			t.Fatalf("expected processor, got nil")
		}
	})
}

func TestProcessorFirstSeenAndDuplicate(t *testing.T) {
	p, err := NewProcessor(8)
	if err != nil {
		t.Fatalf("NewProcessor error: %v", err)
	}

	ctx := context.Background()
	first := testEvent("evt-1")

	out, err := p.Process(ctx, first)
	if err != nil {
		t.Fatalf("Process first error: %v", err)
	}
	if out == nil {
		t.Fatalf("expected first event to pass through")
	}
	if out.ID != first.ID {
		t.Fatalf("expected unchanged ID %q, got %q", first.ID, out.ID)
	}

	out, err = p.Process(ctx, first)
	if err != nil {
		t.Fatalf("Process duplicate error: %v", err)
	}
	if out != nil {
		t.Fatalf("expected duplicate to be dropped, got %#v", out)
	}

	out, err = p.Process(ctx, testEvent("evt-2"))
	if err != nil {
		t.Fatalf("Process second unique error: %v", err)
	}
	if out == nil {
		t.Fatalf("expected second unique event to pass through")
	}
}

func TestProcessorLRUEvictionAndPromotion(t *testing.T) {
	p, err := NewProcessor(2)
	if err != nil {
		t.Fatalf("NewProcessor error: %v", err)
	}

	ctx := context.Background()

	mustPass(t, p, ctx, "a")
	mustPass(t, p, ctx, "b")
	mustDrop(t, p, ctx, "a") // promote "a" so "b" becomes least-recently-used
	mustPass(t, p, ctx, "c") // evicts "b"
	mustDrop(t, p, ctx, "a") // "a" should still be tracked after promotion
	mustPass(t, p, ctx, "b") // "b" was evicted, so now it passes again
}

func TestProcessorRejectsBlankID(t *testing.T) {
	p, err := NewProcessor(4)
	if err != nil {
		t.Fatalf("NewProcessor error: %v", err)
	}

	in := testEvent(" ")
	out, err := p.Process(context.Background(), in)
	if err == nil {
		t.Fatalf("expected error for blank ID")
	}
	if out != nil {
		t.Fatalf("expected nil output on error, got %#v", out)
	}
	if !strings.Contains(err.Error(), "event ID is required") {
		t.Fatalf("unexpected error: %v", err)
	}
}

func TestFactoryWithRegistry(t *testing.T) {
	r := processors.NewRegistry()
	r.Register("dedupe", Factory(3))

	p, err := r.Build("dedupe")
	if err != nil {
		t.Fatalf("Build error: %v", err)
	}
	if p == nil {
		t.Fatalf("expected processor, got nil")
	}

	out, err := p.Process(context.Background(), testEvent("evt-factory-1"))
	if err != nil {
		t.Fatalf("Process error: %v", err)
	}
	if out == nil {
		t.Fatalf("expected first event to pass through")
	}
}

func mustPass(t *testing.T, p *Processor, ctx context.Context, id string) {
	t.Helper()
	out, err := p.Process(ctx, testEvent(id))
	if err != nil {
		t.Fatalf("expected pass for id=%q, got error: %v", id, err)
	}
	if out == nil {
		t.Fatalf("expected pass for id=%q, got drop", id)
	}
}

func mustDrop(t *testing.T, p *Processor, ctx context.Context, id string) {
	t.Helper()
	out, err := p.Process(ctx, testEvent(id))
	if err != nil {
		t.Fatalf("expected drop for id=%q, got error: %v", id, err)
	}
	if out != nil {
		t.Fatalf("expected drop for id=%q, got output", id)
	}
}

func testEvent(id string) event.Event {
	return event.Event{
		ID:        id,
		Kind:      event.Kind("observation"),
		Source:    "source-1",
		EmittedAt: time.Now().UTC(),
		Payload:   map[string]any{"ok": true},
	}
}
```
@@ -9,11 +9,13 @@
 // Example:
 //
 // reg := processors.NewRegistry()
+// reg.Register("dedupe", dedupe.Factory(10_000))
 // reg.Register("normalize", func() (processors.Processor, error) {
+//	// import "gitea.maximumdirect.net/ejr/feedkit/processors/normalize"
 //	return normalize.NewProcessor(myNormalizers, false), nil
 // })
 //
-// chain, err := reg.BuildChain([]string{"normalize"})
+// chain, err := reg.BuildChain([]string{"dedupe", "normalize"})
 // if err != nil {
 //	// handle wiring error
 // }
19 processors/normalize/doc.go Normal file
@@ -0,0 +1,19 @@
+// Package normalize provides the feedkit normalization processor and related
+// helper APIs for raw-to-canonical event mapping.
+//
+// External API surface:
+//   - Processor: concrete processors.Processor implementation
+//   - Normalizer / Func: normalization interface and ergonomic function adapter
+//
+// Optional helpers from helpers.go:
+//   - PayloadJSONBytes: extract supported JSON-shaped payloads into bytes
+//   - DecodeJSONPayload: decode an event payload into a typed struct
+//   - FinalizeEvent: copy the input event envelope onto a normalized output
+//
+// Typical usage: sources emit raw events (often with json.RawMessage payloads),
+// then a normalize.Processor converts matching raw schemas into canonical
+// payloads.
+//
+// Key property: normalization is optional. If no Normalizer matches an event,
+// Processor passes it through unchanged by default.
+package normalize
84 processors/normalize/helpers.go Normal file
@@ -0,0 +1,84 @@
+package normalize
+
+import (
+	"encoding/json"
+	"fmt"
+	"time"
+
+	"gitea.maximumdirect.net/ejr/feedkit/event"
+)
+
+// PayloadJSONBytes extracts a JSON payload into bytes suitable for json.Unmarshal.
+//
+// Supported payload shapes:
+//   - json.RawMessage
+//   - []byte
+//   - string
+//   - map[string]any
+func PayloadJSONBytes(e event.Event) ([]byte, error) {
+	if e.Payload == nil {
+		return nil, fmt.Errorf("payload is nil")
+	}
+
+	switch v := e.Payload.(type) {
+	case json.RawMessage:
+		if len(v) == 0 {
+			return nil, fmt.Errorf("payload is empty json.RawMessage")
+		}
+		return []byte(v), nil
+	case []byte:
+		if len(v) == 0 {
+			return nil, fmt.Errorf("payload is empty []byte")
+		}
+		return v, nil
+	case string:
+		if v == "" {
+			return nil, fmt.Errorf("payload is empty string")
+		}
+		return []byte(v), nil
+	case map[string]any:
+		b, err := json.Marshal(v)
+		if err != nil {
+			return nil, fmt.Errorf("marshal map payload: %w", err)
+		}
+		return b, nil
+	default:
+		return nil, fmt.Errorf("unsupported payload type %T", e.Payload)
+	}
+}
+
+// DecodeJSONPayload extracts the event payload as bytes and unmarshals it into T.
+func DecodeJSONPayload[T any](in event.Event) (T, error) {
+	var zero T
+
+	b, err := PayloadJSONBytes(in)
+	if err != nil {
+		return zero, fmt.Errorf("extract payload: %w", err)
+	}
+
+	var parsed T
+	if err := json.Unmarshal(b, &parsed); err != nil {
+		return zero, fmt.Errorf("decode raw payload: %w", err)
+	}
+
+	return parsed, nil
+}
+
+// FinalizeEvent builds the output event envelope by copying the input and applying
+// the new schema/payload, plus optional EffectiveAt.
+func FinalizeEvent(in event.Event, outSchema string, outPayload any, effectiveAt time.Time) (*event.Event, error) {
+	out := in
+	out.Schema = outSchema
+	out.Payload = outPayload
+
+	if !effectiveAt.IsZero() {
+		t := effectiveAt.UTC()
+		out.EffectiveAt = &t
+	}
+
+	if err := out.Validate(); err != nil {
+		return nil, err
+	}
+
+	return &out, nil
+}
118 processors/normalize/helpers_test.go Normal file
@@ -0,0 +1,118 @@
+package normalize
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+	"time"
+
+	"gitea.maximumdirect.net/ejr/feedkit/event"
+)
+
+func TestPayloadJSONBytesSupportedShapes(t *testing.T) {
+	cases := []struct {
+		name    string
+		payload any
+		want    string
+	}{
+		{name: "rawmessage", payload: json.RawMessage(`{"a":1}`), want: `{"a":1}`},
+		{name: "bytes", payload: []byte(`{"a":2}`), want: `{"a":2}`},
+		{name: "string", payload: `{"a":3}`, want: `{"a":3}`},
+		{name: "map", payload: map[string]any{"a": 4}, want: `{"a":4}`},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
+			if err != nil {
+				t.Fatalf("PayloadJSONBytes() unexpected error: %v", err)
+			}
+			if string(got) != tc.want {
+				t.Fatalf("PayloadJSONBytes() = %s, want %s", string(got), tc.want)
+			}
+		})
+	}
+}
+
+func TestPayloadJSONBytesRejectsInvalidPayloads(t *testing.T) {
+	cases := []struct {
+		name    string
+		payload any
+		want    string
+	}{
+		{name: "nil", payload: nil, want: "payload is nil"},
+		{name: "empty rawmessage", payload: json.RawMessage{}, want: "payload is empty json.RawMessage"},
+		{name: "empty bytes", payload: []byte{}, want: "payload is empty []byte"},
+		{name: "empty string", payload: "", want: "payload is empty string"},
+		{name: "unsupported", payload: 123, want: "unsupported payload type"},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			_, err := PayloadJSONBytes(event.Event{Payload: tc.payload})
+			if err == nil {
+				t.Fatalf("PayloadJSONBytes() expected error")
+			}
+			if !strings.Contains(err.Error(), tc.want) {
+				t.Fatalf("PayloadJSONBytes() error = %q, want substring %q", err, tc.want)
+			}
+		})
+	}
+}
+
+func TestDecodeJSONPayload(t *testing.T) {
+	type payload struct {
+		Name string `json:"name"`
+	}
+
+	got, err := DecodeJSONPayload[payload](event.Event{
+		Payload: json.RawMessage(`{"name":"alice"}`),
+	})
+	if err != nil {
+		t.Fatalf("DecodeJSONPayload() unexpected error: %v", err)
+	}
+	if got.Name != "alice" {
+		t.Fatalf("DecodeJSONPayload() = %#v, want name alice", got)
+	}
+}
+
+func TestFinalizeEventPreservesEnvelopeAndEffectiveAtBehavior(t *testing.T) {
+	existingEffectiveAt := time.Date(2026, 3, 28, 11, 0, 0, 0, time.UTC)
+	in := event.Event{
+		ID:          "evt-1",
+		Kind:        event.Kind("observation"),
+		Source:      "source-a",
+		EmittedAt:   time.Date(2026, 3, 28, 12, 0, 0, 0, time.UTC),
+		EffectiveAt: &existingEffectiveAt,
+		Schema:      "raw.example.v1",
+		Payload:     map[string]any{"old": true},
+	}
+
+	out, err := FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, time.Time{})
+	if err != nil {
+		t.Fatalf("FinalizeEvent() unexpected error: %v", err)
+	}
+	if out.ID != in.ID || out.Kind != in.Kind || out.Source != in.Source || out.EmittedAt != in.EmittedAt {
+		t.Fatalf("FinalizeEvent() changed preserved envelope fields: %#v", out)
+	}
+	if out.EffectiveAt == nil || !out.EffectiveAt.Equal(existingEffectiveAt) {
+		t.Fatalf("FinalizeEvent() effectiveAt = %#v, want preserved existing value", out.EffectiveAt)
+	}
+
+	nextEffectiveAt := time.Date(2026, 3, 28, 13, 0, 0, 0, time.FixedZone("x", -4*3600))
+	out, err = FinalizeEvent(in, "example.v1", map[string]any{"value": 1.234567}, nextEffectiveAt)
+	if err != nil {
+		t.Fatalf("FinalizeEvent() unexpected overwrite error: %v", err)
+	}
+	if out.EffectiveAt == nil || !out.EffectiveAt.Equal(nextEffectiveAt.UTC()) {
+		t.Fatalf("FinalizeEvent() effectiveAt = %#v, want %s", out.EffectiveAt, nextEffectiveAt.UTC())
+	}
+
+	payloadMap, ok := out.Payload.(map[string]any)
+	if !ok {
+		t.Fatalf("FinalizeEvent() payload type = %T, want map[string]any", out.Payload)
+	}
+	if payloadMap["value"] != 1.234567 {
+		t.Fatalf("FinalizeEvent() payload value = %#v, want unrounded 1.234567", payloadMap["value"])
+	}
+}
@@ -1,7 +0,0 @@
-package scheduler
-
-// Placeholder for per-source worker logic:
-// - ticker loop
-// - jitter
-// - backoff on errors
-// - emits events into scheduler.Out
@@ -1,11 +1,6 @@
 package sinks
 
-import (
-	"fmt"
-	"strings"
-
-	"gitea.maximumdirect.net/ejr/feedkit/config"
-)
+import "gitea.maximumdirect.net/ejr/feedkit/config"
 
 // RegisterBuiltins registers sink drivers included in this binary.
 //
@@ -17,39 +12,8 @@ func RegisterBuiltins(r *Registry) {
 		return NewStdoutSink(cfg.Name), nil
 	})
 
-	// File sink: writes/archives events somewhere on disk.
-	r.Register("file", func(cfg config.SinkConfig) (Sink, error) {
-		return NewFileSinkFromConfig(cfg)
-	})
-
-	// Postgres sink: persists events durably.
-	r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) {
-		return NewPostgresSinkFromConfig(cfg)
-	})
-
 	// NATS sink: publishes events to a broker for downstream consumers.
 	r.Register("nats", func(cfg config.SinkConfig) (Sink, error) {
 		return NewNATSSinkFromConfig(cfg)
 	})
 }
-
-// ---- helpers for validating sink params ----
-//
-// These helpers live in sinks (not config) on purpose:
-// - config is domain-agnostic and should not embed driver-specific validation helpers.
-// - sinks are adapters; validating their own params here keeps the logic near the driver.
-
-func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
-	v, ok := cfg.Params[key]
-	if !ok {
-		return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
-	}
-	s, ok := v.(string)
-	if !ok {
-		return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
-	}
-	if strings.TrimSpace(s) == "" {
-		return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
-	}
-	return s, nil
-}
101 sinks/doc.go Normal file
@@ -0,0 +1,101 @@
+// Package sinks defines the feedkit sink interface, sink driver registry, and
+// built-in infrastructure sinks.
+//
+// External API surface:
+//   - Sink: adapter interface that consumes event.Event values
+//   - Registry / NewRegistry: named sink factory registry
+//   - RegisterBuiltins: registers the schema-free built-in sink drivers
+//
+// Built-in sink implementations:
+//   - stdout
+//   - nats
+//   - postgres
+//
+// Optional helpers from helpers.go:
+//   - PostgresFactory: returns a sink factory for the built-in Postgres sink
+//     using a provided downstream schema
+//
+// # NATS built-in overview
+//
+// The NATS sink publishes each event as JSON to a configured subject.
+//
+// Required params:
+//   - url: NATS server URL (for example, nats://localhost:4222)
+//   - subject: NATS subject to publish to
+//
+// Example config:
+//
+//	sinks:
+//	  - name: nats_main
+//	    driver: nats
+//	    params:
+//	      url: nats://localhost:4222
+//	      subject: feedkit.events
+//
+// # Postgres built-in overview
+//
+// The postgres sink is intentionally split between downstream daemon ownership
+// and feedkit ownership:
+//   - downstream code registers table schema + event mapping functions
+//   - feedkit manages DB connection, create-if-missing DDL, transactional
+//     inserts, optional automatic retention pruning, and manual prune helpers
+//
+// Example config:
+//
+//	sinks:
+//	  - name: pg_main
+//	    driver: postgres
+//	    params:
+//	      uri: postgres://localhost:5432/feedkit?sslmode=disable
+//	      username: feedkit_user
+//	      password: feedkit_pass
+//	      prune: 3d # optional: prune rows older than now-3d on each write tx
+//
+// params.prune supports:
+//   - Go duration strings (72h, 90m, 30s, ...)
+//   - day/week suffixes (3d, 2w)
+//
+// If params.prune is omitted, automatic pruning is disabled.
+//
+// Example downstream wiring:
+//
+//	sinkReg := sinks.NewRegistry()
+//	sinks.RegisterBuiltins(sinkReg)
+//	sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{
+//		Tables: []sinks.PostgresTable{
+//			{
+//				Name: "events",
+//				Columns: []sinks.PostgresColumn{
+//					{Name: "event_id", Type: "TEXT", Nullable: false},
+//					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
+//					{Name: "payload_json", Type: "JSONB", Nullable: false},
+//				},
+//				PrimaryKey:  []string{"event_id"},
+//				PruneColumn: "emitted_at", // required for retention pruning
+//			},
+//		},
+//		MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) {
+//			b, err := json.Marshal(e.Payload)
+//			if err != nil {
+//				return nil, err
+//			}
+//			return []sinks.PostgresWrite{
+//				{
+//					Table: "events",
+//					Values: map[string]any{
+//						"event_id":     e.ID,
+//						"emitted_at":   e.EmittedAt,
+//						"payload_json": string(b),
+//					},
+//				},
+//			}, nil
+//		},
+//	}))
+//
+// Manual pruning via type assertion (administrative helpers):
+//
+//	if p, ok := sink.(sinks.PostgresPruner); ok {
+//		_, _ = p.PruneKeepLatest(ctx, "events", 10000)
+//		_, _ = p.PruneOlderThan(ctx, "events", time.Now().UTC().AddDate(0, -1, 0))
+//	}
+package sinks
@@ -1,30 +0,0 @@
-package sinks
-
-import (
-	"context"
-	"fmt"
-
-	"gitea.maximumdirect.net/ejr/feedkit/config"
-	"gitea.maximumdirect.net/ejr/feedkit/event"
-)
-
-type FileSink struct {
-	name string
-	path string
-}
-
-func NewFileSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
-	path, err := requireStringParam(cfg, "path")
-	if err != nil {
-		return nil, err
-	}
-	return &FileSink{name: cfg.Name, path: path}, nil
-}
-
-func (s *FileSink) Name() string { return s.name }
-
-func (s *FileSink) Consume(ctx context.Context, e event.Event) error {
-	_ = ctx
-	_ = e
-	return fmt.Errorf("file sink: TODO implement (path=%s)", s.path)
-}
35 sinks/helpers.go Normal file
@@ -0,0 +1,35 @@
+package sinks
+
+import (
+	"fmt"
+	"strings"
+
+	"gitea.maximumdirect.net/ejr/feedkit/config"
+)
+
+// requireStringParam returns a non-empty string sink param.
+//
+// This helper is intentionally local to sinks rather than config so
+// driver-specific validation stays close to the adapters that use it.
+func requireStringParam(cfg config.SinkConfig, key string) (string, error) {
+	v, ok := cfg.Params[key]
+	if !ok {
+		return "", fmt.Errorf("sink %q: params.%s is required", cfg.Name, key)
+	}
+	s, ok := v.(string)
+	if !ok {
+		return "", fmt.Errorf("sink %q: params.%s must be a string", cfg.Name, key)
+	}
+	if strings.TrimSpace(s) == "" {
+		return "", fmt.Errorf("sink %q: params.%s cannot be empty", cfg.Name, key)
+	}
+	return s, nil
+}
+
+// PostgresFactory returns a sink factory that builds the built-in Postgres sink
+// using the provided downstream schema definition.
+func PostgresFactory(schema PostgresSchema) Factory {
+	return func(cfg config.SinkConfig) (Sink, error) {
+		return NewPostgresSinkFromConfig(cfg, schema)
+	}
+}
46 sinks/helpers_test.go Normal file
@@ -0,0 +1,46 @@
+package sinks
+
+import (
+	"context"
+	"testing"
+
+	"gitea.maximumdirect.net/ejr/feedkit/config"
+	"gitea.maximumdirect.net/ejr/feedkit/event"
+)
+
+func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) {
+	factory := PostgresFactory(testPostgresSchema())
+	if factory == nil {
+		t.Fatalf("PostgresFactory() returned nil")
+	}
+	if _, err := factory(config.SinkConfig{}); err == nil {
+		t.Fatalf("factory(config) expected parameter validation error")
+	}
+}
+
+func testPostgresSchema() PostgresSchema {
+	return PostgresSchema{
+		Tables: []PostgresTable{
+			{
+				Name: "events",
+				Columns: []PostgresColumn{
+					{Name: "event_id", Type: "TEXT", Nullable: false},
+					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
+				},
+				PrimaryKey:  []string{"event_id"},
+				PruneColumn: "emitted_at",
+			},
+		},
+		MapEvent: func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
+			return []PostgresWrite{
+				{
+					Table: "events",
+					Values: map[string]any{
+						"event_id":   e.ID,
+						"emitted_at": e.EmittedAt,
+					},
+				},
+			}, nil
+		},
+	}
+}
@@ -15,7 +15,7 @@ import (
 type NATSSink struct {
 	name    string
 	url     string
-	exchange string
+	subject string
 
 	mu   sync.Mutex
 	conn *nats.Conn
@@ -26,11 +26,11 @@ func NewNATSSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
 	if err != nil {
 		return nil, err
 	}
-	ex, err := requireStringParam(cfg, "exchange")
+	subject, err := requireStringParam(cfg, "subject")
 	if err != nil {
 		return nil, err
 	}
-	return &NATSSink{name: cfg.Name, url: url, exchange: ex}, nil
+	return &NATSSink{name: cfg.Name, url: url, subject: subject}, nil
 }
 
 func (r *NATSSink) Name() string { return r.name }
@@ -59,7 +59,7 @@ func (r *NATSSink) Consume(ctx context.Context, e event.Event) error {
 	if err := ctx.Err(); err != nil {
 		return err
 	}
-	if err := conn.Publish(r.exchange, b); err != nil {
+	if err := conn.Publish(r.subject, b); err != nil {
 		return fmt.Errorf("NATS sink: publish: %w", err)
 	}
 	return nil
47 sinks/nats_test.go Normal file
@@ -0,0 +1,47 @@
+package sinks
+
+import (
+	"strings"
+	"testing"
+
+	"gitea.maximumdirect.net/ejr/feedkit/config"
+)
+
+func TestNewNATSSinkFromConfigRequiresSubject(t *testing.T) {
+	sink, err := NewNATSSinkFromConfig(config.SinkConfig{
+		Name:   "nats-main",
+		Driver: "nats",
+		Params: map[string]any{
+			"url":     "nats://localhost:4222",
+			"subject": "feedkit.events",
+		},
+	})
+	if err != nil {
+		t.Fatalf("NewNATSSinkFromConfig() error = %v", err)
+	}
+
+	natsSink, ok := sink.(*NATSSink)
+	if !ok {
+		t.Fatalf("sink type = %T, want *NATSSink", sink)
+	}
+	if natsSink.subject != "feedkit.events" {
+		t.Fatalf("subject = %q, want feedkit.events", natsSink.subject)
+	}
+}
+
+func TestNewNATSSinkFromConfigRejectsLegacyExchange(t *testing.T) {
+	_, err := NewNATSSinkFromConfig(config.SinkConfig{
+		Name:   "nats-main",
+		Driver: "nats",
+		Params: map[string]any{
+			"url":      "nats://localhost:4222",
+			"exchange": "feedkit.events",
+		},
+	})
+	if err == nil {
+		t.Fatalf("NewNATSSinkFromConfig() expected error")
+	}
+	if !strings.Contains(err.Error(), "params.subject is required") {
+		t.Fatalf("error = %q, want params.subject is required", err)
+	}
+}
@@ -2,36 +2,460 @@ package sinks
 
 import (
 	"context"
+	"database/sql"
 	"fmt"
+	"net/url"
+	"strconv"
+	"strings"
+	"time"
 
 	"gitea.maximumdirect.net/ejr/feedkit/config"
 	"gitea.maximumdirect.net/ejr/feedkit/event"
+
+	_ "github.com/lib/pq"
 )
 
-type PostgresSink struct {
-	name string
-	dsn  string
-}
-
-func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
-	dsn, err := requireStringParam(cfg, "dsn")
-	if err != nil {
-		return nil, err
-	}
-	return &PostgresSink{name: cfg.Name, dsn: dsn}, nil
-}
+const postgresInitTimeout = 5 * time.Second
+
+type postgresTx interface {
+	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+	Commit() error
+	Rollback() error
+}
+
+type postgresExecer interface {
+	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+}
+
+type postgresDB interface {
+	PingContext(ctx context.Context) error
+	BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
+	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+	Close() error
+}
+
+type sqlDBWrapper struct {
+	db *sql.DB
+}
+
+func (w *sqlDBWrapper) PingContext(ctx context.Context) error {
+	return w.db.PingContext(ctx)
+}
+
+func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
+	tx, err := w.db.BeginTx(ctx, opts)
+	if err != nil {
+		return nil, err
+	}
+	return &sqlTxWrapper{tx: tx}, nil
+}
+
+func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	return w.db.ExecContext(ctx, query, args...)
+}
+
+func (w *sqlDBWrapper) Close() error {
+	return w.db.Close()
+}
+
+type sqlTxWrapper struct {
+	tx *sql.Tx
+}
+
+func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
+	return w.tx.ExecContext(ctx, query, args...)
+}
+
+func (w *sqlTxWrapper) Commit() error {
+	return w.tx.Commit()
+}
+
+func (w *sqlTxWrapper) Rollback() error {
+	return w.tx.Rollback()
+}
+
+var openPostgresDB = func(dsn string) (postgresDB, error) {
+	db, err := sql.Open("postgres", dsn)
+	if err != nil {
+		return nil, err
+	}
+	return &sqlDBWrapper{db: db}, nil
+}
+
+type PostgresSink struct {
+	name        string
+	db          postgresDB
+	schema      postgresSchemaCompiled
+	pruneWindow time.Duration
+}
+
+func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
+	uri, err := requireStringParam(cfg, "uri")
+	if err != nil {
+		return nil, err
+	}
+	username, err := requireStringParam(cfg, "username")
+	if err != nil {
+		return nil, err
+	}
+	password, err := requireStringParam(cfg, "password")
+	if err != nil {
+		return nil, err
+	}
+	pruneWindow, err := parsePostgresPruneWindow(cfg)
+	if err != nil {
+		return nil, err
+	}
+
+	schema, err := compilePostgresSchema(schemaDef)
+	if err != nil {
+		return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
+	}
+
+	dsn, err := buildPostgresDSN(uri, username, password)
+	if err != nil {
+		return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err)
+	}
+
+	db, err := openPostgresDB(dsn)
+	if err != nil {
+		return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
+	}
+
+	s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
+	if err := s.initialize(); err != nil {
+		_ = db.Close()
+		return nil, err
+	}
+
+	return s, nil
+}
 
 func (p *PostgresSink) Name() string { return p.name }
 
 func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
-	_ = ctx
 
 	// Boundary validation: if something upstream violated invariants,
-	// surface it loudly rather than printing partial nonsense.
+	// surface it loudly rather than writing corrupt rows.
 	if err := e.Validate(); err != nil {
 		return fmt.Errorf("postgres sink: invalid event: %w", err)
 	}
 
-	// TODO implement Postgres transaction
-	return nil
-}
+	if err := ctx.Err(); err != nil {
+		return err
+	}
+
+	writes, err := p.schema.mapEvent(ctx, e)
+	if err != nil {
+		return fmt.Errorf("postgres sink: map event: %w", err)
+	}
+	if len(writes) == 0 {
+		return nil
+	}
+
+	tx, err := p.db.BeginTx(ctx, nil)
+	if err != nil {
+		return fmt.Errorf("postgres sink: begin tx: %w", err)
+	}
+
+	committed := false
+	defer func() {
+		if !committed {
+			_ = tx.Rollback()
+		}
+	}()
+
+	for _, w := range writes {
+		tbl, err := p.schema.validateWrite(w)
+		if err != nil {
+			return fmt.Errorf("postgres sink: %w", err)
+		}
+
+		query, args, err := buildInsertSQL(tbl, w)
+		if err != nil {
+			return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
+		}
+		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
+			return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
+		}
+	}
+
+	if p.pruneWindow > 0 {
+		cutoff := time.Now().UTC().Add(-p.pruneWindow)
+		for _, tableName := range p.schema.tableOrder {
+			tbl := p.schema.tables[tableName]
+			if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
+				return fmt.Errorf("postgres sink: %w", err)
+			}
+		}
+	}
+
+	if err := tx.Commit(); err != nil {
+		return fmt.Errorf("postgres sink: commit tx: %w", err)
+	}
+	committed = true
+
+	return nil
+}
+
+func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
+	if keep < 0 {
+		return 0, fmt.Errorf("postgres sink: keep must be >= 0")
+	}
+	tbl, err := p.lookupTable(table)
+	if err != nil {
+		return 0, err
+	}
+
+	query := fmt.Sprintf(
+		`DELETE FROM %s WHERE ctid IN (
+			SELECT ctid FROM %s
+			ORDER BY %s DESC
+			OFFSET $1
+		)`,
+		quotePostgresIdent(tbl.name),
+		quotePostgresIdent(tbl.name),
+		quotePostgresIdent(tbl.pruneColumn),
+	)
+
+	res, err := p.db.ExecContext(ctx, query, keep)
+	if err != nil {
+		return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
+	}
+	rows, err := res.RowsAffected()
+	if err != nil {
+		return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
+	}
+	return rows, nil
+}
+
+func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
+	tbl, err := p.lookupTable(table)
+	if err != nil {
+		return 0, err
+	}
+
+	rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("postgres sink: %w", err)
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
|
||||||
|
counts := make(map[string]int64, len(p.schema.tableOrder))
|
||||||
|
for _, table := range p.schema.tableOrder {
|
||||||
|
n, err := p.PruneKeepLatest(ctx, table, keep)
|
||||||
|
if err != nil {
|
||||||
|
return counts, err
|
||||||
|
}
|
||||||
|
counts[table] = n
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
|
||||||
|
counts := make(map[string]int64, len(p.schema.tableOrder))
|
||||||
|
for _, table := range p.schema.tableOrder {
|
||||||
|
n, err := p.PruneOlderThan(ctx, table, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return counts, err
|
||||||
|
}
|
||||||
|
counts[table] = n
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PostgresSink) initialize() error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := p.db.PingContext(ctx); err != nil {
|
||||||
|
return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tableName := range p.schema.tableOrder {
|
||||||
|
tbl := p.schema.tables[tableName]
|
||||||
|
|
||||||
|
createTableSQL := buildCreateTableSQL(tbl)
|
||||||
|
if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
|
||||||
|
return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, idx := range tbl.indexes {
|
||||||
|
createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
|
||||||
|
if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
|
||||||
|
return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
|
||||||
|
table = strings.TrimSpace(table)
|
||||||
|
if table == "" {
|
||||||
|
return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
|
||||||
|
}
|
||||||
|
tbl, ok := p.schema.tables[table]
|
||||||
|
if !ok {
|
||||||
|
return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
|
||||||
|
}
|
||||||
|
return tbl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPostgresDSN(uri, username, password string) (string, error) {
|
||||||
|
u, err := url.Parse(strings.TrimSpace(uri))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid uri: %w", err)
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
return "", fmt.Errorf("invalid uri: missing scheme")
|
||||||
|
}
|
||||||
|
if u.Host == "" {
|
||||||
|
return "", fmt.Errorf("invalid uri: missing host")
|
||||||
|
}
|
||||||
|
u.User = url.UserPassword(username, password)
|
||||||
|
return u.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
|
||||||
|
raw, ok := cfg.Params["prune"]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok := raw.(string)
|
||||||
|
if !ok {
|
||||||
|
return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := parsePostgresPruneDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
|
||||||
|
s := strings.TrimSpace(raw)
|
||||||
|
if s == "" {
|
||||||
|
return 0, fmt.Errorf("must not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(s)
|
||||||
|
if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
|
||||||
|
unit := lower[len(lower)-1]
|
||||||
|
n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
|
||||||
|
}
|
||||||
|
if n <= 0 {
|
||||||
|
return 0, fmt.Errorf("must be > 0")
|
||||||
|
}
|
||||||
|
if unit == 'd' {
|
||||||
|
return time.Duration(n) * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
return time.Duration(n) * 7 * 24 * time.Hour, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
|
||||||
|
}
|
||||||
|
if d <= 0 {
|
||||||
|
return 0, fmt.Errorf("must be > 0")
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`DELETE FROM %s WHERE %s < $1`,
|
||||||
|
quotePostgresIdent(tbl.name),
|
||||||
|
quotePostgresIdent(tbl.pruneColumn),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
|
||||||
|
query := buildPruneOlderThanSQL(tbl)
|
||||||
|
res, err := execer.ExecContext(ctx, query, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
|
||||||
|
}
|
||||||
|
rows, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCreateTableSQL(tbl postgresTableCompiled) string {
|
||||||
|
defs := make([]string, 0, len(tbl.columnOrder)+1)
|
||||||
|
for _, colName := range tbl.columnOrder {
|
||||||
|
col := tbl.columns[colName]
|
||||||
|
def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
|
||||||
|
if !col.Nullable {
|
||||||
|
def += " NOT NULL"
|
||||||
|
}
|
||||||
|
defs = append(defs, def)
|
||||||
|
}
|
||||||
|
if len(tbl.primaryKey) > 0 {
|
||||||
|
defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
||||||
|
quotePostgresIdent(tbl.name),
|
||||||
|
strings.Join(defs, ", "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
|
||||||
|
unique := ""
|
||||||
|
if idx.Unique {
|
||||||
|
unique = "UNIQUE "
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
||||||
|
unique,
|
||||||
|
quotePostgresIdent(idx.Name),
|
||||||
|
quotePostgresIdent(tableName),
|
||||||
|
joinQuotedIdents(idx.Columns),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
|
||||||
|
cols := make([]string, 0, len(tbl.columnOrder))
|
||||||
|
args := make([]any, 0, len(tbl.columnOrder))
|
||||||
|
placeholders := make([]string, 0, len(tbl.columnOrder))
|
||||||
|
|
||||||
|
for i, colName := range tbl.columnOrder {
|
||||||
|
v, ok := w.Values[colName]
|
||||||
|
if !ok {
|
||||||
|
return "", nil, fmt.Errorf("missing value for column %q", colName)
|
||||||
|
}
|
||||||
|
cols = append(cols, quotePostgresIdent(colName))
|
||||||
|
args = append(args, v)
|
||||||
|
placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
|
||||||
|
}
|
||||||
|
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"INSERT INTO %s (%s) VALUES (%s)",
|
||||||
|
quotePostgresIdent(tbl.name),
|
||||||
|
strings.Join(cols, ", "),
|
||||||
|
strings.Join(placeholders, ", "),
|
||||||
|
)
|
||||||
|
return q, args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinQuotedIdents(idents []string) string {
|
||||||
|
quoted := make([]string, 0, len(idents))
|
||||||
|
for _, s := range idents {
|
||||||
|
quoted = append(quoted, quotePostgresIdent(s))
|
||||||
|
}
|
||||||
|
return strings.Join(quoted, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func quotePostgresIdent(s string) string {
|
||||||
|
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
||||||
|
}
|
||||||
|
sinks/postgres_schema.go (new file, 239 lines)
package sinks

import (
	"context"
	"fmt"
	"strings"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/event"
)

// PostgresMapFunc maps one event into zero or more table writes.
//
// Returning zero writes means "this event is not mapped for this sink" and is
// treated as a non-error no-op.
type PostgresMapFunc func(ctx context.Context, e event.Event) ([]PostgresWrite, error)

// PostgresSchema describes the downstream-provided relational model and mapper
// for one configured postgres sink.
type PostgresSchema struct {
	Tables   []PostgresTable
	MapEvent PostgresMapFunc
}

type PostgresWrite struct {
	Table  string
	Values map[string]any
}

type PostgresTable struct {
	Name        string
	Columns     []PostgresColumn
	PrimaryKey  []string
	PruneColumn string
	Indexes     []PostgresIndex
}

type PostgresColumn struct {
	Name     string
	Type     string
	Nullable bool
}

type PostgresIndex struct {
	Name    string
	Columns []string
	Unique  bool
}

// PostgresPruner is an optional interface exposed by PostgresSink so downstream
// applications can call retention helpers via type assertion.
type PostgresPruner interface {
	PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error)
	PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error)
	PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error)
	PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error)
}

type postgresSchemaCompiled struct {
	tableOrder []string
	tables     map[string]postgresTableCompiled
	mapEvent   PostgresMapFunc
}

type postgresTableCompiled struct {
	name        string
	columns     map[string]PostgresColumn
	columnOrder []string
	primaryKey  []string
	pruneColumn string
	indexes     []PostgresIndex
}

func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) {
	if schema.MapEvent == nil {
		return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required")
	}
	if len(schema.Tables) == 0 {
		return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: at least one table is required")
	}

	compiled := postgresSchemaCompiled{
		tableOrder: make([]string, 0, len(schema.Tables)),
		tables:     make(map[string]postgresTableCompiled, len(schema.Tables)),
		mapEvent:   schema.MapEvent,
	}

	seenTables := map[string]bool{}
	seenIndexes := map[string]bool{}

	for i, tbl := range schema.Tables {
		tableName := strings.TrimSpace(tbl.Name)
		if tableName == "" {
			return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: tables[%d].name is required", i)
		}
		if seenTables[tableName] {
			return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate table name %q", tableName)
		}
		seenTables[tableName] = true

		if len(tbl.Columns) == 0 {
			return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q must define at least one column", tableName)
		}

		colOrder := make([]string, 0, len(tbl.Columns))
		colMap := make(map[string]PostgresColumn, len(tbl.Columns))
		for j, col := range tbl.Columns {
			colName := strings.TrimSpace(col.Name)
			if colName == "" {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q columns[%d].name is required", tableName, j)
			}
			if _, exists := colMap[colName]; exists {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q duplicate column %q", tableName, colName)
			}
			if strings.TrimSpace(col.Type) == "" {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q column %q type is required", tableName, colName)
			}
			colOrder = append(colOrder, colName)
			colMap[colName] = PostgresColumn{
				Name:     colName,
				Type:     strings.TrimSpace(col.Type),
				Nullable: col.Nullable,
			}
		}

		pk, err := validatePostgresColumnSet(tableName, "primary key", tbl.PrimaryKey, colMap)
		if err != nil {
			return postgresSchemaCompiled{}, err
		}

		pruneCol := strings.TrimSpace(tbl.PruneColumn)
		if pruneCol == "" {
			return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column is required", tableName)
		}
		if _, ok := colMap[pruneCol]; !ok {
			return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column %q not found in columns", tableName, pruneCol)
		}

		indexes := make([]PostgresIndex, 0, len(tbl.Indexes))
		for j, idx := range tbl.Indexes {
			idxName := strings.TrimSpace(idx.Name)
			if idxName == "" {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q indexes[%d].name is required", tableName, j)
			}
			if len(idx.Columns) == 0 {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q index %q must include at least one column", tableName, idxName)
			}
			if seenIndexes[idxName] {
				return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate index name %q", idxName)
			}
			seenIndexes[idxName] = true

			idxCols, err := validatePostgresColumnSet(tableName, fmt.Sprintf("index %q columns", idxName), idx.Columns, colMap)
			if err != nil {
				return postgresSchemaCompiled{}, err
			}

			indexes = append(indexes, PostgresIndex{
				Name:    idxName,
				Columns: idxCols,
				Unique:  idx.Unique,
			})
		}

		compiled.tableOrder = append(compiled.tableOrder, tableName)
		compiled.tables[tableName] = postgresTableCompiled{
			name:        tableName,
			columns:     colMap,
			columnOrder: colOrder,
			primaryKey:  pk,
			pruneColumn: pruneCol,
			indexes:     indexes,
		}
	}

	return compiled, nil
}

func validatePostgresColumnSet(tableName, label string, cols []string, colMap map[string]PostgresColumn) ([]string, error) {
	if len(cols) == 0 {
		return nil, nil
	}
	out := make([]string, 0, len(cols))
	seen := map[string]bool{}
	for _, c := range cols {
		name := strings.TrimSpace(c)
		if name == "" {
			return nil, fmt.Errorf("postgres schema: table %q %s contains empty column name", tableName, label)
		}
		if seen[name] {
			return nil, fmt.Errorf("postgres schema: table %q %s contains duplicate column %q", tableName, label, name)
		}
		if _, ok := colMap[name]; !ok {
			return nil, fmt.Errorf("postgres schema: table %q %s references unknown column %q", tableName, label, name)
		}
		seen[name] = true
		out = append(out, name)
	}
	return out, nil
}

func (s postgresSchemaCompiled) validateWrite(w PostgresWrite) (postgresTableCompiled, error) {
	tableName := strings.TrimSpace(w.Table)
	if tableName == "" {
		return postgresTableCompiled{}, fmt.Errorf("write table is required")
	}
	t, ok := s.tables[tableName]
	if !ok {
		return postgresTableCompiled{}, fmt.Errorf("table %q is not defined in postgres schema", tableName)
	}

	if len(w.Values) == 0 {
		return postgresTableCompiled{}, fmt.Errorf("write for table %q must include values", tableName)
	}

	for k := range w.Values {
		if _, ok := t.columns[k]; !ok {
			return postgresTableCompiled{}, fmt.Errorf("write for table %q includes unknown column %q", tableName, k)
		}
	}

	if len(w.Values) != len(t.columnOrder) {
		return postgresTableCompiled{}, fmt.Errorf("write for table %q must include all declared columns", tableName)
	}

	for _, col := range t.columnOrder {
		v, ok := w.Values[col]
		if !ok {
			return postgresTableCompiled{}, fmt.Errorf("write for table %q is missing column %q", tableName, col)
		}
		if v == nil {
			if c := t.columns[col]; !c.Nullable {
				return postgresTableCompiled{}, fmt.Errorf("write for table %q has nil value for non-null column %q", tableName, col)
			}
		}
	}

	return t, nil
}
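validateWrite enforces an exact cover between a write's value keys and the table's declared columns: no unknown columns, none missing. A self-contained sketch of that check (the names here are illustrative, not the feedkit API):

```go
package main

import "fmt"

// checkWrite mirrors the exact-cover rule in validateWrite above: every key
// in values must be a declared column, and every declared column must appear.
func checkWrite(declared []string, values map[string]any) error {
	cols := make(map[string]bool, len(declared))
	for _, c := range declared {
		cols[c] = true
	}
	for k := range values {
		if !cols[k] {
			return fmt.Errorf("unknown column %q", k)
		}
	}
	// No unknowns, so equal lengths imply every declared column is present.
	if len(values) != len(declared) {
		return fmt.Errorf("write must include all %d declared columns", len(declared))
	}
	return nil
}

func main() {
	declared := []string{"event_id", "emitted_at", "payload_json"}
	ok := map[string]any{"event_id": "e1", "emitted_at": "now", "payload_json": "{}"}
	bad := map[string]any{"event_id": "e1", "oops": 1}
	fmt.Println(checkWrite(declared, ok))  // <nil>
	fmt.Println(checkWrite(declared, bad)) // unknown column "oops"
}
```

Rejecting partial writes keeps the insert builder simple: it can assume one positional argument per declared column, in declaration order.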
sinks/postgres_test.go (new file, 741 lines)
package sinks

import (
	"context"
	"database/sql"
	"errors"
	"net/url"
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
)

type fakeResult struct {
	rows int64
}

func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") }
func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil }

type execCall struct {
	query string
	args  []any
}

type fakeTx struct {
	execCalls     []execCall
	execErr       error
	execErrOnCall int
	execRows      int64
	commitErr     error
	rollbackErr   error
	commitCalls   int
	rollbackCalls int
}

func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)})
	if t.execErr != nil && t.execErrOnCall == len(t.execCalls) {
		return nil, t.execErr
	}
	return fakeResult{rows: t.execRows}, nil
}

func (t *fakeTx) Commit() error {
	t.commitCalls++
	return t.commitErr
}

func (t *fakeTx) Rollback() error {
	t.rollbackCalls++
	return t.rollbackErr
}

type fakeDB struct {
	pingErr       error
	beginErr      error
	execErr       error
	execErrOnCall int
	execRows      int64
	pingCalls     int
	beginCalls    int
	execCalls     []execCall
	closeCalls    int
	tx            *fakeTx
}

func (d *fakeDB) PingContext(_ context.Context) error {
	d.pingCalls++
	return d.pingErr
}

func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) {
	d.beginCalls++
	if d.beginErr != nil {
		return nil, d.beginErr
	}
	if d.tx == nil {
		d.tx = &fakeTx{}
	}
	return d.tx, nil
}

func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)})
	if d.execErr != nil && d.execErrOnCall == len(d.execCalls) {
		return nil, d.execErr
	}
	return fakeResult{rows: d.execRows}, nil
}

func (d *fakeDB) Close() error {
	d.closeCalls++
	return nil
}

func withPostgresTestState(t *testing.T) {
	t.Helper()

	oldOpen := openPostgresDB
	t.Cleanup(func() {
		openPostgresDB = oldOpen
	})
}

func validTestEvent() event.Event {
	now := time.Now().UTC()
	return event.Event{
		ID:        "evt-1",
		Kind:      event.Kind("observation"),
		Source:    "source-1",
		EmittedAt: now,
		Payload: map[string]any{
			"x": 1,
		},
	}
}

func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
	return PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
					{Name: "payload_json", Type: "JSONB", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
				Indexes: []PostgresIndex{
					{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
				},
			},
		},
		MapEvent: mapFn,
	}
}

func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
	return PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
			},
			{
				Name: "event_payloads",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "payload_json", Type: "JSONB", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
			},
		},
		MapEvent: mapFn,
	}
}

func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
	t.Helper()
	compiled, err := compilePostgresSchema(s)
	if err != nil {
		t.Fatalf("compile schema: %v", err)
	}
	return compiled
}

func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
	_, err := compilePostgresSchema(PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
				},
				PruneColumn: "missing_col",
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid schema error")
	}
	if !strings.Contains(err.Error(), "prune column") {
		t.Fatalf("unexpected schema validation error: %v", err)
	}

	_, err = compilePostgresSchema(PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PruneColumn: "emitted_at",
				Indexes: []PostgresIndex{
					{Name: "idx_events_empty", Columns: nil},
				},
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid index schema error")
	}
	if !strings.Contains(err.Error(), "at least one column") {
		t.Fatalf("unexpected index validation error: %v", err)
	}
}

func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
	withPostgresTestState(t)

	dbs := []*fakeDB{{}, {}}
	var gotDSNs []string
	openPostgresDB = func(dsn string) (postgresDB, error) {
		gotDSNs = append(gotDSNs, dsn)
		db := dbs[len(gotDSNs)-1]
		return db, nil
	}

	factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
		return nil, nil
	}))

	for _, name := range []string{"pg_a", "pg_b"} {
		sink, err := factory(config.SinkConfig{
			Name:   name,
			Driver: "postgres",
			Params: map[string]any{
				"uri":      "postgres://localhost/db",
				"username": "user",
				"password": "pass",
			},
		})
		if err != nil {
			t.Fatalf("factory(%q) error = %v", name, err)
		}
		if sink == nil {
			t.Fatalf("factory(%q) returned nil sink", name)
		}
	}

	if len(gotDSNs) != 2 {
		t.Fatalf("len(gotDSNs) = %d, want 2", len(gotDSNs))
	}
	for i, db := range dbs {
		if db.pingCalls != 1 {
			t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
		}
	}
}

func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
	withPostgresTestState(t)

	tests := []struct {
		name   string
		params map[string]any
		want   string
	}{
		{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
		{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
		{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: tc.params,
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
			if err == nil {
				t.Fatalf("expected error")
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("expected %q in error, got: %v", tc.want, err)
			}
		})
	}
}

func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
	withPostgresTestState(t)

	_, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "user",
			"password": "pass",
		},
	}, PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
				},
				PruneColumn: "missing_col",
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid schema error")
	}
	if !strings.Contains(err.Error(), "compile schema") {
		t.Fatalf("unexpected error: %v", err)
	}
}

func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
	withPostgresTestState(t)

	db := &fakeDB{}
	var gotDSN string
	openPostgresDB = func(dsn string) (postgresDB, error) {
		gotDSN = dsn
		return db, nil
	}

	s, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://db.example.local:5432/feedkit?sslmode=disable",
			"username": "app_user",
			"password": "app_pass",
		},
	}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
	if err != nil {
		t.Fatalf("new postgres sink: %v", err)
	}
	if s == nil {
		t.Fatalf("expected sink")
	}

	if db.pingCalls != 1 {
		t.Fatalf("expected one ping, got %d", db.pingCalls)
	}
	if len(db.execCalls) != 2 {
		t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls))
	}
	if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) {
		t.Fatalf("unexpected create table query: %s", db.execCalls[0].query)
	}
	if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
		t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
	}

	u, err := url.Parse(gotDSN)
	if err != nil {
		t.Fatalf("parse dsn: %v", err)
	}
	if u.User == nil || u.User.Username() != "app_user" {
		t.Fatalf("dsn missing username: %q", gotDSN)
	}
	pass, ok := u.User.Password()
	if !ok || pass != "app_pass" {
		t.Fatalf("dsn missing password: %q", gotDSN)
	}
}

func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
	withPostgresTestState(t)

	db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
	openPostgresDB = func(_ string) (postgresDB, error) {
		return db, nil
	}

	_, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "user",
			"password": "pass",
		},
	}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
	if err == nil {
|
||||||
|
t.Fatalf("expected init error")
|
||||||
|
}
|
||||||
|
if db.closeCalls != 1 {
|
||||||
|
t.Fatalf("expected db close on init failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{name: "go duration", in: "72h", want: 72 * time.Hour},
|
||||||
|
{name: "days suffix", in: "3d", want: 72 * time.Hour},
|
||||||
|
{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
withPostgresTestState(t)
|
||||||
|
|
||||||
|
openPostgresDB = func(_ string) (postgresDB, error) {
|
||||||
|
return &fakeDB{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||||
|
Name: "pg",
|
||||||
|
Driver: "postgres",
|
||||||
|
Params: map[string]any{
|
||||||
|
"uri": "postgres://localhost/db",
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"prune": tc.in,
|
||||||
|
},
|
||||||
|
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new postgres sink: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pg, ok := s.(*PostgresSink)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected *PostgresSink, got %T", s)
|
||||||
|
}
|
||||||
|
if pg.pruneWindow != tc.want {
|
||||||
|
t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
|
||||||
|
withPostgresTestState(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in any
|
||||||
|
}{
|
||||||
|
{name: "empty", in: ""},
|
||||||
|
{name: "zero", in: "0"},
|
||||||
|
{name: "negative", in: "-1h"},
|
||||||
|
{name: "malformed", in: "abc"},
|
||||||
|
{name: "fractional day", in: "1.5d"},
|
||||||
|
{name: "wrong type", in: 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := NewPostgresSinkFromConfig(config.SinkConfig{
|
||||||
|
Name: "pg",
|
||||||
|
Driver: "postgres",
|
||||||
|
Params: map[string]any{
|
||||||
|
"uri": "postgres://localhost/db",
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"prune": tc.in,
|
||||||
|
},
|
||||||
|
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "params.prune") {
|
||||||
|
t.Fatalf("expected params.prune error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
|
||||||
|
db := &fakeDB{}
|
||||||
|
called := 0
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||||
|
called++
|
||||||
|
return nil, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sink.Consume(context.Background(), event.Event{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected invalid event error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "invalid event") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if called != 0 {
|
||||||
|
t.Fatalf("expected mapper not called for invalid events")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
|
||||||
|
db := &fakeDB{}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||||
|
return nil, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||||
|
t.Fatalf("consume: %v", err)
|
||||||
|
}
|
||||||
|
if db.beginCalls != 0 {
|
||||||
|
t.Fatalf("expected no transaction for unmapped events")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
|
||||||
|
tx := &fakeTx{}
|
||||||
|
db := &fakeDB{tx: tx}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||||
|
return []PostgresWrite{
|
||||||
|
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||||
|
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||||
|
}, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||||
|
t.Fatalf("consume: %v", err)
|
||||||
|
}
|
||||||
|
if db.beginCalls != 1 {
|
||||||
|
t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
|
||||||
|
}
|
||||||
|
if len(tx.execCalls) != 2 {
|
||||||
|
t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
|
||||||
|
}
|
||||||
|
if tx.commitCalls != 1 {
|
||||||
|
t.Fatalf("expected one commit, got %d", tx.commitCalls)
|
||||||
|
}
|
||||||
|
if tx.rollbackCalls != 0 {
|
||||||
|
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
|
||||||
|
tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
|
||||||
|
db := &fakeDB{tx: tx}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||||
|
return []PostgresWrite{
|
||||||
|
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||||
|
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||||
|
}, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sink.Consume(context.Background(), validTestEvent())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected insert error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "insert into") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if tx.commitCalls != 0 {
|
||||||
|
t.Fatalf("expected no commit")
|
||||||
|
}
|
||||||
|
if tx.rollbackCalls != 1 {
|
||||||
|
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
|
||||||
|
tx := &fakeTx{}
|
||||||
|
db := &fakeDB{tx: tx}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||||
|
return []PostgresWrite{
|
||||||
|
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||||
|
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||||
|
}, nil
|
||||||
|
})),
|
||||||
|
pruneWindow: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
|
||||||
|
t.Fatalf("consume: %v", err)
|
||||||
|
}
|
||||||
|
if len(tx.execCalls) != 4 {
|
||||||
|
t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
|
||||||
|
}
|
||||||
|
if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
|
||||||
|
t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
|
||||||
|
}
|
||||||
|
if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
|
||||||
|
t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
|
||||||
|
}
|
||||||
|
if tx.commitCalls != 1 {
|
||||||
|
t.Fatalf("expected one commit, got %d", tx.commitCalls)
|
||||||
|
}
|
||||||
|
if tx.rollbackCalls != 0 {
|
||||||
|
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
|
||||||
|
tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
|
||||||
|
db := &fakeDB{tx: tx}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
|
||||||
|
return []PostgresWrite{
|
||||||
|
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
|
||||||
|
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
|
||||||
|
}, nil
|
||||||
|
})),
|
||||||
|
pruneWindow: 24 * time.Hour,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sink.Consume(context.Background(), validTestEvent())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected prune error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "prune older than") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if tx.commitCalls != 0 {
|
||||||
|
t.Fatalf("expected no commit")
|
||||||
|
}
|
||||||
|
if tx.rollbackCalls != 1 {
|
||||||
|
t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkPrunePerTable(t *testing.T) {
|
||||||
|
db := &fakeDB{execRows: 7}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||||
|
return nil, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := sink.PruneKeepLatest(context.Background(), "events", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prune keep latest: %v", err)
|
||||||
|
}
|
||||||
|
if rows != 7 {
|
||||||
|
t.Fatalf("unexpected rows affected: %d", rows)
|
||||||
|
}
|
||||||
|
if len(db.execCalls) != 1 {
|
||||||
|
t.Fatalf("expected one prune query")
|
||||||
|
}
|
||||||
|
if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) {
|
||||||
|
t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query)
|
||||||
|
}
|
||||||
|
if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 {
|
||||||
|
t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args)
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff := time.Now().UTC().Add(-24 * time.Hour)
|
||||||
|
rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prune older than: %v", err)
|
||||||
|
}
|
||||||
|
if rows != 7 {
|
||||||
|
t.Fatalf("unexpected rows affected: %d", rows)
|
||||||
|
}
|
||||||
|
if len(db.execCalls) != 2 {
|
||||||
|
t.Fatalf("expected two prune queries")
|
||||||
|
}
|
||||||
|
if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) {
|
||||||
|
t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkPruneAllTables(t *testing.T) {
|
||||||
|
db := &fakeDB{execRows: 3}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||||
|
return nil, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prune all keep latest: %v", err)
|
||||||
|
}
|
||||||
|
if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
|
||||||
|
t.Fatalf("unexpected keep counts: %#v", keepCounts)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.execCalls = nil
|
||||||
|
olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prune all older than: %v", err)
|
||||||
|
}
|
||||||
|
if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
|
||||||
|
t.Fatalf("unexpected older-than counts: %#v", olderCounts)
|
||||||
|
}
|
||||||
|
if len(db.execCalls) != 2 {
|
||||||
|
t.Fatalf("expected one prune call per table")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresSinkPruneErrors(t *testing.T) {
|
||||||
|
db := &fakeDB{}
|
||||||
|
sink := &PostgresSink{
|
||||||
|
name: "pg",
|
||||||
|
db: db,
|
||||||
|
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
|
||||||
|
return nil, nil
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
|
||||||
|
t.Fatalf("expected negative keep error")
|
||||||
|
}
|
||||||
|
if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
|
||||||
|
t.Fatalf("expected unknown table error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package sinks

import (
	"fmt"
	"strings"

	"gitea.maximumdirect.net/ejr/feedkit/config"
)
@@ -21,13 +22,40 @@ func NewRegistry() *Registry {
}

func (r *Registry) Register(driver string, f Factory) {
	if r == nil {
		panic("sinks.Registry.Register: registry cannot be nil")
	}
	driver = strings.TrimSpace(driver)
	if driver == "" {
		panic("sinks.Registry.Register: driver cannot be empty")
	}
	if f == nil {
		panic(fmt.Sprintf("sinks.Registry.Register: factory cannot be nil (driver=%q)", driver))
	}
	if r.byDriver == nil {
		r.byDriver = map[string]Factory{}
	}
	if _, exists := r.byDriver[driver]; exists {
		panic(fmt.Sprintf("sinks.Registry.Register: driver %q already registered", driver))
	}
	r.byDriver[driver] = f
}

func (r *Registry) Build(cfg config.SinkConfig) (Sink, error) {
	if r == nil {
		return nil, fmt.Errorf("sinks registry is nil")
	}
	driver := strings.TrimSpace(cfg.Driver)
	f, ok := r.byDriver[driver]
	if !ok {
		return nil, fmt.Errorf("unknown sink driver: %q", driver)
	}
	sink, err := f(cfg)
	if err != nil {
		return nil, fmt.Errorf("build sink %q: %w", driver, err)
	}
	if sink == nil {
		return nil, fmt.Errorf("build sink %q: factory returned nil sink", driver)
	}
	return sink, nil
}
126	sinks/registry_test.go	Normal file
@@ -0,0 +1,126 @@
package sinks

import (
	"context"
	"errors"
	"strings"
	"testing"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
)

type testSink struct{ name string }

func (s testSink) Name() string { return s.name }

func (s testSink) Consume(context.Context, event.Event) error { return nil }

func TestRegistryRegisterPanicsOnNilRegistry(t *testing.T) {
	var r *Registry
	defer func() {
		if recover() == nil {
			t.Fatalf("Register() expected panic on nil registry")
		}
	}()
	r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })
}

func TestRegistryRegisterPanicsOnEmptyDriver(t *testing.T) {
	r := NewRegistry()
	defer func() {
		if recover() == nil {
			t.Fatalf("Register() expected panic on empty driver")
		}
	}()
	r.Register(" ", func(config.SinkConfig) (Sink, error) { return testSink{name: "x"}, nil })
}

func TestRegistryRegisterPanicsOnNilFactory(t *testing.T) {
	r := NewRegistry()
	defer func() {
		if recover() == nil {
			t.Fatalf("Register() expected panic on nil factory")
		}
	}()
	r.Register("stdout", nil)
}

func TestRegistryRegisterPanicsOnDuplicateDriver(t *testing.T) {
	r := NewRegistry()
	r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "a"}, nil })

	defer func() {
		if recover() == nil {
			t.Fatalf("Register() expected panic on duplicate driver")
		}
	}()
	r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "b"}, nil })
}

func TestRegistryBuildNilRegistryFails(t *testing.T) {
	var r *Registry
	_, err := r.Build(config.SinkConfig{Driver: "stdout"})
	if err == nil {
		t.Fatalf("Build() expected error for nil registry")
	}
	if !strings.Contains(err.Error(), "registry is nil") {
		t.Fatalf("Build() error = %q, want registry is nil", err)
	}
}

func TestRegistryBuildTrimsDriver(t *testing.T) {
	r := NewRegistry()
	r.Register("stdout", func(config.SinkConfig) (Sink, error) { return testSink{name: "stdout"}, nil })

	sink, err := r.Build(config.SinkConfig{Name: "sink1", Driver: " stdout "})
	if err != nil {
		t.Fatalf("Build() error = %v", err)
	}
	if sink.Name() != "stdout" {
		t.Fatalf("Build() sink name = %q, want stdout", sink.Name())
	}
}

func TestRegistryBuildWrapsFactoryError(t *testing.T) {
	r := NewRegistry()
	r.Register("broken", func(config.SinkConfig) (Sink, error) { return nil, errors.New("boom") })

	_, err := r.Build(config.SinkConfig{Driver: "broken"})
	if err == nil {
		t.Fatalf("Build() expected error")
	}
	if !strings.Contains(err.Error(), `build sink "broken": boom`) {
		t.Fatalf("Build() error = %q", err)
	}
}

func TestRegistryBuildRejectsNilSink(t *testing.T) {
	r := NewRegistry()
	r.Register("nil_sink", func(config.SinkConfig) (Sink, error) { return nil, nil })

	_, err := r.Build(config.SinkConfig{Driver: "nil_sink"})
	if err == nil {
		t.Fatalf("Build() expected nil sink error")
	}
	if !strings.Contains(err.Error(), "factory returned nil sink") {
		t.Fatalf("Build() error = %q", err)
	}
}

func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) {
	r := NewRegistry()
	RegisterBuiltins(r)

	if len(r.byDriver) != 2 {
		t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver))
	}
	for _, driver := range []string{"stdout", "nats"} {
		if _, ok := r.byDriver[driver]; !ok {
			t.Fatalf("builtins missing driver %q", driver)
		}
	}
	if _, ok := r.byDriver["postgres"]; ok {
		t.Fatalf("builtins unexpectedly registered postgres driver")
	}
}
@@ -1,14 +1,35 @@
// Package sources defines feedkit's input-source abstractions and source
// registry.
//
// External API surface:
// - Input: common source identity surface
// - PollSource: polling source interface
// - StreamSource: streaming source interface
// - Registry / NewRegistry: source driver registry and builders
// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
//
// Source drivers are domain-specific and registered into Registry by driver name.
// Registry can then build configured sources from config.SourceConfig.
//
// A single source may emit 0..N events per poll or stream iteration, and those
// events may span multiple event kinds.
//
// Optional helpers from helpers.go:
// - DefaultEventID: default event ID policy for source implementations
// - SingleEvent: construct and validate a one-element event slice
// - ValidateExpectedKinds: compare configured expected kinds against source
//   advertised kinds when metadata is available
//
// HTTP-backed polling sources can share NewHTTPSource for generic HTTP config
// parsing and conditional GET behavior. The helper understands:
// - params.url
// - params.user_agent
// - params.conditional (optional, default true)
// - params.http_timeout (optional, default transport.DefaultHTTPTimeout)
// - params.http_response_body_limit_bytes (optional, default
//   transport.DefaultHTTPResponseBodyLimitBytes)
//
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
// treated as a successful unchanged poll.
package sources
140	sources/helpers.go	Normal file
@@ -0,0 +1,140 @@
package sources

import (
	"fmt"
	"sort"
	"strings"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
)

// DefaultEventID applies feedkit's default Event.ID policy:
//
// - If upstream provides an ID, use it (trimmed).
// - Otherwise, ID is "<Source>:<EffectiveAt>" when available.
// - If EffectiveAt is unavailable, fall back to "<Source>:<EmittedAt>".
//
// Timestamps are encoded as RFC3339Nano in UTC.
func DefaultEventID(upstreamID, sourceName string, effectiveAt *time.Time, emittedAt time.Time) string {
	if id := strings.TrimSpace(upstreamID); id != "" {
		return id
	}

	src := strings.TrimSpace(sourceName)
	if src == "" {
		src = "UNKNOWN_SOURCE"
	}

	if effectiveAt != nil && !effectiveAt.IsZero() {
		return fmt.Sprintf("%s:%s", src, effectiveAt.UTC().Format(time.RFC3339Nano))
	}

	t := emittedAt.UTC()
	if t.IsZero() {
		t = time.Now().UTC()
	}

	return fmt.Sprintf("%s:%s", src, t.Format(time.RFC3339Nano))
}

// SingleEvent constructs, validates, and returns a slice containing exactly one event.
func SingleEvent(
	kind event.Kind,
	sourceName string,
	schema string,
	id string,
	emittedAt time.Time,
	effectiveAt *time.Time,
	payload any,
) ([]event.Event, error) {
	if emittedAt.IsZero() {
		emittedAt = time.Now().UTC()
	} else {
		emittedAt = emittedAt.UTC()
	}

	e := event.Event{
		ID:          id,
		Kind:        kind,
		Source:      sourceName,
		EmittedAt:   emittedAt,
		EffectiveAt: effectiveAt,
		Schema:      schema,
		Payload:     payload,
	}

	if err := e.Validate(); err != nil {
		return nil, err
	}

	return []event.Event{e}, nil
}

// ValidateExpectedKinds checks that configured source expected kinds are a subset
// of the kinds advertised by the built source, when the source exposes kind
// metadata. If the source does not advertise kinds, the check is skipped.
func ValidateExpectedKinds(cfg config.SourceConfig, in Input) error {
	expectedKinds, err := parseExpectedKinds(cfg.ExpectedKinds())
	if err != nil {
		return err
	}
	if len(expectedKinds) == 0 {
		return nil
	}

	advertisedKinds := advertisedSourceKinds(in)
	if len(advertisedKinds) == 0 {
		return nil
	}

	for kind := range expectedKinds {
		if !advertisedKinds[kind] {
			return fmt.Errorf(
				"configured expected kind %q not advertised by source (configured=%v advertised=%v)",
				kind,
				sortedKinds(expectedKinds),
				sortedKinds(advertisedKinds),
			)
		}
	}
	return nil
}

func parseExpectedKinds(raw []string) (map[event.Kind]bool, error) {
	kinds := map[event.Kind]bool{}
	for i, k := range raw {
		kind, err := event.ParseKind(k)
		if err != nil {
			return nil, fmt.Errorf("invalid expected kind at index %d (%q): %w", i, k, err)
		}
		kinds[kind] = true
	}
	return kinds, nil
}

func advertisedSourceKinds(in Input) map[event.Kind]bool {
	if in == nil {
		return nil
	}

	kinds := map[event.Kind]bool{}
	if ks, ok := in.(KindsSource); ok {
		for _, kind := range ks.Kinds() {
			kinds[kind] = true
		}
		return kinds
	}

	return nil
}

func sortedKinds(kindSet map[event.Kind]bool) []string {
	out := make([]string, 0, len(kindSet))
	for kind := range kindSet {
		out = append(out, string(kind))
	}
	sort.Strings(out)
	return out
}
112	sources/helpers_test.go	Normal file
@@ -0,0 +1,112 @@
package sources

import (
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
)

type testInput struct {
	name string
}

func (s testInput) Name() string { return s.name }

type testKindsSource struct {
	testInput
	kinds []event.Kind
}

func (s testKindsSource) Kinds() []event.Kind { return s.kinds }

func TestValidateExpectedKindsSubsetAllowed(t *testing.T) {
	cfg := config.SourceConfig{Kinds: []string{"observation"}}
	in := testKindsSource{
		testInput: testInput{name: "test"},
		kinds:     []event.Kind{"observation", "forecast"},
	}

	if err := ValidateExpectedKinds(cfg, in); err != nil {
		t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
	}
}

func TestValidateExpectedKindsMismatchFails(t *testing.T) {
	cfg := config.SourceConfig{Kinds: []string{"alert"}}
	in := testKindsSource{
		testInput: testInput{name: "test"},
		kinds:     []event.Kind{"observation", "forecast"},
	}

	err := ValidateExpectedKinds(cfg, in)
	if err == nil {
		t.Fatalf("ValidateExpectedKinds() expected mismatch error, got nil")
	}
	if !strings.Contains(err.Error(), "configured expected kind") {
		t.Fatalf("ValidateExpectedKinds() error %q does not include expected message", err)
	}
}

func TestValidateExpectedKindsNoMetadataSkipsCheck(t *testing.T) {
	cfg := config.SourceConfig{Kinds: []string{"alert"}}
	in := testInput{name: "test"}

	if err := ValidateExpectedKinds(cfg, in); err != nil {
		t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err)
	}
}

func TestDefaultEventIDUsesUpstreamID(t *testing.T) {
	emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)
	got := DefaultEventID(" upstream-id ", "source", nil, emittedAt)
	if got != "upstream-id" {
		t.Fatalf("DefaultEventID() = %q, want upstream-id", got)
	}
}

func TestDefaultEventIDPrefersEffectiveAt(t *testing.T) {
	effectiveAt := time.Date(2026, 3, 28, 16, 4, 5, 987654321, time.FixedZone("x", -6*3600))
	emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123, time.UTC)

	got := DefaultEventID("", "source", &effectiveAt, emittedAt)
	want := "source:" + effectiveAt.UTC().Format(time.RFC3339Nano)
	if got != want {
		t.Fatalf("DefaultEventID() = %q, want %q", got, want)
	}
}

func TestDefaultEventIDFallsBackToEmittedAt(t *testing.T) {
	emittedAt := time.Date(2026, 3, 28, 15, 4, 5, 123456789, time.FixedZone("y", 3*3600))
	got := DefaultEventID("", "source", nil, emittedAt)
	want := "source:" + emittedAt.UTC().Format(time.RFC3339Nano)
	if got != want {
		t.Fatalf("DefaultEventID() = %q, want %q", got, want)
	}
}

func TestSingleEventBuildsValidatedSlice(t *testing.T) {
	effectiveAt := time.Date(2026, 3, 28, 16, 0, 0, 0, time.UTC)
	emittedAt := time.Date(2026, 3, 28, 15, 0, 0, 0, time.FixedZone("z", -5*3600))

	got, err := SingleEvent(
		event.Kind("observation"),
		"source-a",
		"raw.example.v1",
		"evt-1",
		emittedAt,
		&effectiveAt,
		map[string]any{"ok": true},
	)
	if err != nil {
		t.Fatalf("SingleEvent() unexpected error: %v", err)
	}
	if len(got) != 1 {
		t.Fatalf("SingleEvent() len = %d, want 1", len(got))
|
||||||
|
}
|
||||||
|
if got[0].EmittedAt != emittedAt.UTC() {
|
||||||
|
t.Fatalf("SingleEvent() emittedAt = %s, want %s", got[0].EmittedAt, emittedAt.UTC())
|
||||||
|
}
|
||||||
|
}
|
152	sources/http.go	Normal file
@@ -0,0 +1,152 @@
package sources

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"sync"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/transport"
)

// HTTPSource is a reusable helper for polling HTTP-backed sources.
//
// It centralizes generic source config parsing (`params.url`,
// `params.user_agent`, and optional `params.conditional`), default HTTP client
// setup, and conditional GET validator handling. Concrete daemon sources remain
// responsible for decoding the response body and constructing events.
type HTTPSource struct {
	Driver                 string
	Name                   string
	URL                    string
	UserAgent              string
	Accept                 string
	Conditional            bool
	ResponseBodyLimitBytes int64
	Client                 *http.Client

	mu         sync.Mutex
	validators transport.HTTPValidators
}

// NewHTTPSource builds a generic HTTP polling helper from SourceConfig.
//
// Required params:
//   - params.url
//   - params.user_agent
//
// Optional params:
//   - params.conditional (default true): enable conditional GET using cached
//     ETag / Last-Modified validators
func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTPSource, error) {
	name := strings.TrimSpace(cfg.Name)
	if name == "" {
		return nil, fmt.Errorf("%s: name is required", driver)
	}
	if cfg.Params == nil {
		return nil, fmt.Errorf("%s %q: params are required (need params.url and params.user_agent)", driver, cfg.Name)
	}

	url, ok := cfg.ParamString("url", "URL")
	if !ok {
		return nil, fmt.Errorf("%s %q: params.url is required", driver, cfg.Name)
	}

	userAgent, ok := cfg.ParamString("user_agent", "userAgent")
	if !ok {
		return nil, fmt.Errorf("%s %q: params.user_agent is required", driver, cfg.Name)
	}

	conditional := true
	if _, exists := cfg.Params["conditional"]; exists {
		var ok bool
		conditional, ok = cfg.ParamBool("conditional")
		if !ok {
			return nil, fmt.Errorf("source %q: params.conditional must be a boolean", cfg.Name)
		}
	}

	timeout := transport.DefaultHTTPTimeout
	if _, exists := cfg.Params["http_timeout"]; exists {
		var ok bool
		timeout, ok = cfg.ParamDuration("http_timeout")
		if !ok || timeout <= 0 {
			return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name)
		}
	}

	bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes
	if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists {
		rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes")
		if !ok || rawLimit <= 0 {
			return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name)
		}
		bodyLimit = int64(rawLimit)
	}

	return &HTTPSource{
		Driver:                 driver,
		Name:                   name,
		URL:                    url,
		UserAgent:              userAgent,
		Accept:                 accept,
		Conditional:            conditional,
		ResponseBodyLimitBytes: bodyLimit,
		Client:                 transport.NewHTTPClient(timeout),
	}, nil
}

// FetchBytesIfChanged fetches the configured URL and reports whether the
// upstream content changed. An unchanged 304 response returns changed=false
// with no body and no error.
func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, error) {
	client := s.Client
	if client == nil {
		client = transport.NewHTTPClient(transport.DefaultHTTPTimeout)
	}

	s.mu.Lock()
	validators := s.validators
	s.mu.Unlock()

	bodyLimit := s.ResponseBodyLimitBytes
	if bodyLimit <= 0 {
		bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes
	}

	body, changed, next, err := transport.FetchBodyIfChangedWithLimit(
		ctx,
		client,
		s.URL,
		s.UserAgent,
		s.Accept,
		s.Conditional,
		validators,
		bodyLimit,
	)
	if err != nil {
		return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err)
	}

	if s.Conditional {
		s.mu.Lock()
		s.validators = next
		s.mu.Unlock()
	}

	return body, changed, nil
}

// FetchJSONIfChanged fetches the configured URL and returns the raw response
// body as json.RawMessage when content changed. An unchanged 304 response
// returns changed=false with a nil body and no error.
func (s *HTTPSource) FetchJSONIfChanged(ctx context.Context) (json.RawMessage, bool, error) {
	body, changed, err := s.FetchBytesIfChanged(ctx)
	if err != nil || !changed {
		return nil, changed, err
	}
	return json.RawMessage(body), true, nil
}
260	sources/http_test.go	Normal file
@@ -0,0 +1,260 @@
package sources

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/transport"
)

func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":        "https://example.invalid",
			"user_agent": "test-agent",
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if !src.Conditional {
		t.Fatalf("Conditional = false, want true")
	}
}

func TestNewHTTPSourceRejectsInvalidConditional(t *testing.T) {
	_, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":         "https://example.invalid",
			"user_agent":  "test-agent",
			"conditional": "sometimes",
		},
	}, "application/json")
	if err == nil {
		t.Fatalf("NewHTTPSource() error = nil, want error")
	}
	if !strings.Contains(err.Error(), "params.conditional must be a boolean") {
		t.Fatalf("NewHTTPSource() error = %q, want params.conditional must be a boolean", err)
	}
}

func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":         "https://example.invalid",
			"user_agent":  "test-agent",
			"conditional": false,
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if src.Conditional {
		t.Fatalf("Conditional = true, want false")
	}
}

func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":          "https://example.invalid",
			"user_agent":   "test-agent",
			"http_timeout": "250ms",
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if src.Client == nil {
		t.Fatalf("Client = nil")
	}
	if src.Client.Timeout != 250*time.Millisecond {
		t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout)
	}
}

func TestNewHTTPSourceBodyLimitOverride(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":                            "https://example.invalid",
			"user_agent":                     "test-agent",
			"http_response_body_limit_bytes": 12345,
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if src.ResponseBodyLimitBytes != 12345 {
		t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes)
	}
}

func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) {
	_, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":          "https://example.invalid",
			"user_agent":   "test-agent",
			"http_timeout": "soon",
		},
	}, "application/json")
	if err == nil {
		t.Fatalf("NewHTTPSource() error = nil, want error")
	}
	if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") {
		t.Fatalf("NewHTTPSource() error = %q", err)
	}
}

func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) {
	_, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":                            "https://example.invalid",
			"user_agent":                     "test-agent",
			"http_response_body_limit_bytes": "abc",
		},
	}, "application/json")
	if err == nil {
		t.Fatalf("NewHTTPSource() error = nil, want error")
	}
	if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") {
		t.Fatalf("NewHTTPSource() error = %q", err)
	}
}

func TestHTTPSourceFetchJSONIfChanged(t *testing.T) {
	var call int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		call++
		switch call {
		case 1:
			w.Header().Set("ETag", `"v1"`)
			_, _ = w.Write([]byte(`{"ok":true}`))
		case 2:
			if got := r.Header.Get("If-None-Match"); got != `"v1"` {
				t.Fatalf("second request If-None-Match = %q", got)
			}
			w.WriteHeader(http.StatusNotModified)
		default:
			t.Fatalf("unexpected call count %d", call)
		}
	}))
	defer srv.Close()

	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":        srv.URL,
			"user_agent": "test-agent",
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}

	raw, changed, err := src.FetchJSONIfChanged(context.Background())
	if err != nil {
		t.Fatalf("first FetchJSONIfChanged() error = %v", err)
	}
	if !changed {
		t.Fatalf("first FetchJSONIfChanged() changed = false, want true")
	}
	if got := string(raw); got != `{"ok":true}` {
		t.Fatalf("first FetchJSONIfChanged() body = %q", got)
	}

	raw, changed, err = src.FetchJSONIfChanged(context.Background())
	if err != nil {
		t.Fatalf("second FetchJSONIfChanged() error = %v", err)
	}
	if changed {
		t.Fatalf("second FetchJSONIfChanged() changed = true, want false")
	}
	if raw != nil {
		t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw))
	}
}

func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte(`{"ok":true}`))
	}))
	defer srv.Close()

	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":                            srv.URL,
			"user_agent":                     "test-agent",
			"http_response_body_limit_bytes": 4,
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}

	_, _, err = src.FetchJSONIfChanged(context.Background())
	if err == nil {
		t.Fatalf("FetchJSONIfChanged() error = nil, want limit error")
	}
	if !strings.Contains(err.Error(), "response body too large") {
		t.Fatalf("FetchJSONIfChanged() error = %q", err)
	}
}

func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":        "https://example.invalid",
			"user_agent": "test-agent",
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes {
		t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes)
	}
}

func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) {
	src, err := NewHTTPSource("test_driver", config.SourceConfig{
		Name:   "test-source",
		Driver: "test_driver",
		Params: map[string]any{
			"url":        "https://example.invalid",
			"user_agent": "test-agent",
		},
	}, "application/json")
	if err != nil {
		t.Fatalf("NewHTTPSource() error = %v", err)
	}
	if src.Client == nil {
		t.Fatalf("Client = nil")
	}
	if src.Client.Timeout != transport.DefaultHTTPTimeout {
		t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout)
	}
}
@@ -15,9 +15,6 @@ import (
 type PollFactory func(cfg config.SourceConfig) (PollSource, error)
 type StreamFactory func(cfg config.SourceConfig) (StreamSource, error)
 
-// Factory is the legacy alias for poll source factories.
-type Factory = PollFactory
-
 type Registry struct {
 	byPollDriver   map[string]PollFactory
 	byStreamDriver map[string]StreamFactory
@@ -30,13 +27,6 @@ func NewRegistry() *Registry {
 	}
 }
 
-// Register associates a driver name (e.g. "openmeteo_observation") with a factory.
-//
-// The driver string is the "lookup key" used by config.sources[].driver.
-func (r *Registry) Register(driver string, f PollFactory) {
-	r.RegisterPoll(driver, f)
-}
-
 // RegisterPoll associates a driver name with a polling-source factory.
 func (r *Registry) RegisterPoll(driver string, f PollFactory) {
 	driver = strings.TrimSpace(driver)
@@ -75,11 +65,6 @@ func (r *Registry) RegisterStream(driver string, f StreamFactory) {
 	r.byStreamDriver[driver] = f
 }
 
-// Build constructs a polling source from a SourceConfig by looking up cfg.Driver.
-func (r *Registry) Build(cfg config.SourceConfig) (PollSource, error) {
-	return r.BuildPoll(cfg)
-}
-
 // BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver.
 func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) {
 	driver := strings.TrimSpace(cfg.Driver)
@@ -31,9 +31,6 @@ type PollSource interface {
 	Poll(ctx context.Context) ([]event.Event, error)
 }
 
-// Source is a compatibility alias for the legacy polling-source name.
-type Source = PollSource
-
 // StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc).
 //
 // Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs.
@@ -43,12 +40,6 @@ type StreamSource interface {
 	Run(ctx context.Context, out chan<- event.Event) error
 }
 
-// KindSource is an optional interface for sources that advertise one "primary" kind.
-// This is legacy-friendly but no longer required.
-type KindSource interface {
-	Kind() event.Kind
-}
-
 // KindsSource is an optional interface for sources that advertise multiple kinds.
 type KindsSource interface {
 	Kinds() []event.Kind
@@ -6,13 +6,14 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 	"time"
 )
 
-// maxResponseBodyBytes is a hard safety limit on HTTP response bodies.
+// DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies.
 // API responses should be small, so this protects us from accidental
 // or malicious large responses.
-const maxResponseBodyBytes = 2 << 21 // 4 MiB
+const DefaultHTTPResponseBodyLimitBytes int64 = 2 << 21 // 4 MiB
 
 // DefaultHTTPTimeout is the standard timeout used by HTTP sources.
 // Individual drivers may override this if they have a specific need.
@@ -28,7 +29,95 @@ func NewHTTPClient(timeout time.Duration) *http.Client {
 }
 
 func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) {
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes)
+}
+
+func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) {
+	res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "")
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode < 200 || res.StatusCode >= 300 {
+		return nil, fmt.Errorf("HTTP %s", res.Status)
+	}
+
+	return readValidatedBody(res.Body, bodyLimitBytes)
+}
+
+// HTTPValidators are cache validators learned from prior successful GET responses.
+//
+// ETag is preferred when present. LastModified is used as a fallback validator
+// when ETag is unavailable.
+type HTTPValidators struct {
+	ETag         string
+	LastModified string
+}
+
+// FetchBodyIfChanged performs an HTTP GET and opportunistically uses conditional
+// request headers based on the provided validators.
+//
+// Behavior:
+//   - if conditional is false, this behaves like a normal GET and leaves validators unchanged
+//   - if validators.ETag is set, sends If-None-Match
+//   - else if validators.LastModified is set, sends If-Modified-Since
+//   - 304 Not Modified is treated as success with changed=false and no body
+//   - 200 responses are treated as changed=true and still enforce the normal body checks
+//
+// Returned validators reflect any updates learned from the response headers.
+func FetchBodyIfChanged(
+	ctx context.Context,
+	client *http.Client,
+	url, userAgent, accept string,
+	conditional bool,
+	validators HTTPValidators,
+) ([]byte, bool, HTTPValidators, error) {
+	return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes)
+}
+
+func FetchBodyIfChangedWithLimit(
+	ctx context.Context,
+	client *http.Client,
+	url, userAgent, accept string,
+	conditional bool,
+	validators HTTPValidators,
+	bodyLimitBytes int64,
+) ([]byte, bool, HTTPValidators, error) {
+	headerName, headerValue := conditionalHeader(conditional, validators)
+
+	res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, headerName, headerValue)
+	if err != nil {
+		return nil, false, validators, err
+	}
+	defer res.Body.Close()
+
+	switch res.StatusCode {
+	case http.StatusNotModified:
+		if conditional {
+			validators = refreshValidators(validators, res.Header)
+		}
+		return nil, false, validators, nil
+	default:
+		if res.StatusCode < 200 || res.StatusCode >= 300 {
+			return nil, false, validators, fmt.Errorf("HTTP %s", res.Status)
+		}
+	}
+
+	b, err := readValidatedBody(res.Body, bodyLimitBytes)
+	if err != nil {
+		return nil, false, validators, err
+	}
+
+	if conditional {
+		validators = replaceValidators(res.Header)
+	}
+
+	return b, true, validators, nil
+}
+
+func doRequest(ctx context.Context, client *http.Client, method, url, userAgent, accept, headerName, headerValue string) (*http.Response, error) {
+	req, err := http.NewRequestWithContext(ctx, method, url, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -39,19 +128,50 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept
 	if accept != "" {
 		req.Header.Set("Accept", accept)
 	}
+	if headerName != "" && headerValue != "" {
+		req.Header.Set(headerName, headerValue)
+	}
 
-	res, err := client.Do(req)
-	if err != nil {
-		return nil, err
-	}
-	defer res.Body.Close()
-
-	if res.StatusCode < 200 || res.StatusCode >= 300 {
-		return nil, fmt.Errorf("HTTP %s", res.Status)
-	}
-
-	// Read at most maxResponseBodyBytes + 1 so we can detect overflow.
-	limited := io.LimitReader(res.Body, maxResponseBodyBytes+1)
+	return client.Do(req)
+}
+
+func conditionalHeader(enabled bool, validators HTTPValidators) (string, string) {
+	if !enabled {
+		return "", ""
+	}
+	if etag := strings.TrimSpace(validators.ETag); etag != "" {
+		return "If-None-Match", etag
+	}
+	if lastModified := strings.TrimSpace(validators.LastModified); lastModified != "" {
+		return "If-Modified-Since", lastModified
+	}
+	return "", ""
+}
+
+func replaceValidators(header http.Header) HTTPValidators {
+	return HTTPValidators{
+		ETag:         strings.TrimSpace(header.Get("ETag")),
+		LastModified: strings.TrimSpace(header.Get("Last-Modified")),
+	}
+}
+
+func refreshValidators(current HTTPValidators, header http.Header) HTTPValidators {
+	if etag := strings.TrimSpace(header.Get("ETag")); etag != "" {
+		current.ETag = etag
+	}
+	if lastModified := strings.TrimSpace(header.Get("Last-Modified")); lastModified != "" {
+		current.LastModified = lastModified
+	}
+	return current
+}
+
+func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) {
+	if bodyLimitBytes <= 0 {
+		bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes
+	}
+
+	// Read at most bodyLimitBytes + 1 so we can detect overflow.
+	limited := io.LimitReader(r, bodyLimitBytes+1)
 
 	b, err := io.ReadAll(limited)
 	if err != nil {
@@ -62,8 +182,8 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept
 		return nil, fmt.Errorf("empty response body")
 	}
 
-	if len(b) > maxResponseBodyBytes {
-		return nil, fmt.Errorf("response body too large (>%d bytes)", maxResponseBodyBytes)
+	if int64(len(b)) > bodyLimitBytes {
+		return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes)
 	}
 
 	return b, nil
232	transport/http_test.go	Normal file
@@ -0,0 +1,232 @@
|
package transport

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestFetchBodyIfChangedPrefersETagAndTreats304AsUnchanged(t *testing.T) {
	t.Helper()

	var call int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		call++
		switch call {
		case 1:
			if got := r.Header.Get("If-None-Match"); got != "" {
				t.Fatalf("first request If-None-Match = %q, want empty", got)
			}
			if got := r.Header.Get("If-Modified-Since"); got != "" {
				t.Fatalf("first request If-Modified-Since = %q, want empty", got)
			}
			w.Header().Set("ETag", `"v1"`)
			w.Header().Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 GMT")
			_, _ = w.Write([]byte(`{"ok":true}`))
		case 2:
			if got := r.Header.Get("If-None-Match"); got != `"v1"` {
				t.Fatalf("second request If-None-Match = %q, want %q", got, `"v1"`)
			}
			if got := r.Header.Get("If-Modified-Since"); got != "" {
				t.Fatalf("second request If-Modified-Since = %q, want empty when ETag is cached", got)
			}
			w.WriteHeader(http.StatusNotModified)
		default:
			t.Fatalf("unexpected call count %d", call)
		}
	}))
	defer srv.Close()

	validators := HTTPValidators{}
	body, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, validators)
	if err != nil {
		t.Fatalf("first FetchBodyIfChanged() error = %v", err)
	}
	if !changed {
		t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
	}
	if got := string(body); got != `{"ok":true}` {
		t.Fatalf("first FetchBodyIfChanged() body = %q", got)
	}
	if got := next.ETag; got != `"v1"` {
		t.Fatalf("cached ETag = %q, want %q", got, `"v1"`)
	}
	if got := next.LastModified; got != "Mon, 02 Jan 2006 15:04:05 GMT" {
		t.Fatalf("cached Last-Modified = %q", got)
	}

	body, changed, next, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "application/json", true, next)
	if err != nil {
		t.Fatalf("second FetchBodyIfChanged() error = %v", err)
	}
	if changed {
		t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
	}
	if body != nil {
		t.Fatalf("second FetchBodyIfChanged() body = %q, want nil", string(body))
	}
	if got := next.ETag; got != `"v1"` {
		t.Fatalf("cached ETag after 304 = %q, want %q", got, `"v1"`)
	}
}

func TestFetchBodyIfChangedFallsBackToIfModifiedSince(t *testing.T) {
	t.Helper()

	var call int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		call++
		switch call {
		case 1:
			w.Header().Set("Last-Modified", "Tue, 03 Jan 2006 15:04:05 GMT")
			_, _ = w.Write([]byte(`first`))
		case 2:
			if got := r.Header.Get("If-None-Match"); got != "" {
				t.Fatalf("second request If-None-Match = %q, want empty", got)
			}
			if got := r.Header.Get("If-Modified-Since"); got != "Tue, 03 Jan 2006 15:04:05 GMT" {
				t.Fatalf("second request If-Modified-Since = %q", got)
			}
			w.WriteHeader(http.StatusNotModified)
		default:
			t.Fatalf("unexpected call count %d", call)
		}
	}))
	defer srv.Close()

	_, changed, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
	if err != nil {
		t.Fatalf("first FetchBodyIfChanged() error = %v", err)
	}
	if !changed {
		t.Fatalf("first FetchBodyIfChanged() changed = false, want true")
	}
	if got := validators.LastModified; got != "Tue, 03 Jan 2006 15:04:05 GMT" {
		t.Fatalf("cached Last-Modified = %q", got)
	}

	_, changed, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
	if err != nil {
		t.Fatalf("second FetchBodyIfChanged() error = %v", err)
	}
	if changed {
		t.Fatalf("second FetchBodyIfChanged() changed = true, want false")
	}
}

func TestFetchBodyIfChangedClearsValidatorsOn200WithoutValidators(t *testing.T) {
	t.Helper()

	var call int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		call++
		switch call {
		case 1:
			w.Header().Set("ETag", `"v1"`)
			_, _ = w.Write([]byte(`first`))
		case 2:
			if got := r.Header.Get("If-None-Match"); got != `"v1"` {
				t.Fatalf("second request If-None-Match = %q", got)
			}
			_, _ = w.Write([]byte(`second`))
		case 3:
			if got := r.Header.Get("If-None-Match"); got != "" {
				t.Fatalf("third request If-None-Match = %q, want empty", got)
			}
			if got := r.Header.Get("If-Modified-Since"); got != "" {
				t.Fatalf("third request If-Modified-Since = %q, want empty", got)
			}
			_, _ = w.Write([]byte(`third`))
		default:
			t.Fatalf("unexpected call count %d", call)
		}
	}))
	defer srv.Close()

	_, _, validators, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, HTTPValidators{})
	if err != nil {
		t.Fatalf("first FetchBodyIfChanged() error = %v", err)
	}
	_, _, validators, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
	if err != nil {
		t.Fatalf("second FetchBodyIfChanged() error = %v", err)
	}
	if validators.ETag != "" || validators.LastModified != "" {
		t.Fatalf("validators after 200 without validators = %+v, want cleared", validators)
	}
	_, _, _, err = FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", true, validators)
	if err != nil {
		t.Fatalf("third FetchBodyIfChanged() error = %v", err)
	}
}

func TestFetchBodyIfChangedConditionalDisabledSkipsConditionalHeaders(t *testing.T) {
	t.Helper()

	var calls int
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls++
		if got := r.Header.Get("If-None-Match"); got != "" {
			t.Fatalf("request If-None-Match = %q, want empty", got)
		}
		if got := r.Header.Get("If-Modified-Since"); got != "" {
			t.Fatalf("request If-Modified-Since = %q, want empty", got)
		}
		_, _ = w.Write([]byte(`body`))
	}))
	defer srv.Close()

	validators := HTTPValidators{ETag: `"v1"`, LastModified: "Wed, 04 Jan 2006 15:04:05 GMT"}
	_, changed, next, err := FetchBodyIfChanged(context.Background(), srv.Client(), srv.URL, "test-agent", "", false, validators)
	if err != nil {
		t.Fatalf("FetchBodyIfChanged() error = %v", err)
	}
	if !changed {
		t.Fatalf("FetchBodyIfChanged() changed = false, want true")
	}
	if next != validators {
		t.Fatalf("validators changed when conditional disabled: got %+v want %+v", next, validators)
	}
	if calls != 1 {
		t.Fatalf("calls = %d, want 1", calls)
	}
}

func TestFetchBodyIfChangedAllowsEmpty304ButRejectsEmpty200(t *testing.T) {
	t.Helper()

	notModified := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotModified)
	}))
	defer notModified.Close()

	_, changed, _, err := FetchBodyIfChanged(
		context.Background(),
		notModified.Client(),
		notModified.URL,
		"test-agent",
		"",
		true,
		HTTPValidators{ETag: `"v1"`},
	)
	if err != nil {
		t.Fatalf("304 FetchBodyIfChanged() error = %v", err)
	}
	if changed {
		t.Fatalf("304 FetchBodyIfChanged() changed = true, want false")
	}

	emptyBody := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer emptyBody.Close()

	_, _, _, err = FetchBodyIfChanged(context.Background(), emptyBody.Client(), emptyBody.URL, "test-agent", "", true, HTTPValidators{})
	if err == nil {
		t.Fatalf("empty 200 FetchBodyIfChanged() error = nil, want error")
	}
	if err.Error() != "empty response body" {
		t.Fatalf("empty 200 FetchBodyIfChanged() error = %q", err)
	}
}