diff --git a/README.md b/README.md index 3bfa696..53e441c 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ The major packages are: - `processors/normalize`: built-in normalization processor and helpers - `pipeline`: optional processor chain - `dispatch`: route compilation and fanout -- `sinks`: sink interfaces, built-ins, and Postgres registration helpers +- `sinks`: sink interfaces, built-ins, and explicit Postgres factory helpers The root package docs in `doc.go` provide a concise package-by-package map for Go documentation consumers. diff --git a/config/config.go b/config/config.go index f5f8468..ec6438c 100644 --- a/config/config.go +++ b/config/config.go @@ -69,33 +69,25 @@ type SourceConfig struct { // If set, it describes the expected emitted event kinds for this source. Kinds []string `yaml:"kinds"` - // Kind is the legacy singular form. Prefer "kinds". - // If both kind and kinds are set, validation fails. - Kind string `yaml:"kind"` - // Params are driver-specific settings (URL, headers, station IDs, API keys, etc.). // The driver implementation is responsible for reading/validating these. Params map[string]any `yaml:"params"` } // ExpectedKinds returns normalized expected kinds from config. -// "kinds" takes precedence; "kind" is used as a legacy fallback. func (cfg SourceConfig) ExpectedKinds() []string { - if len(cfg.Kinds) > 0 { - out := make([]string, 0, len(cfg.Kinds)) - for _, k := range cfg.Kinds { - k = strings.TrimSpace(k) - if k == "" { - continue - } - out = append(out, k) + out := make([]string, 0, len(cfg.Kinds)) + for _, k := range cfg.Kinds { + k = strings.TrimSpace(k) + if k == "" { + continue } - return out + out = append(out, k) } - if k := strings.TrimSpace(cfg.Kind); k != "" { - return []string{k} + if len(out) == 0 { + return nil } - return nil + return out } // SinkConfig describes one output sink adapter. diff --git a/config/config_test.go b/config/config_test.go index 6dd26f7..e127690 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -12,20 +12,12 @@ func TestSourceConfigExpectedKinds(t *testing.T) { want []string }{ { - name: "plural kinds preferred", + name: "plural kinds normalized", cfg: SourceConfig{ Kinds: []string{" observation ", "forecast"}, - Kind: "alert", }, want: []string{"observation", "forecast"}, }, - { - name: "legacy singular fallback", - cfg: SourceConfig{ - Kind: " alert ", - }, - want: []string{"alert"}, - }, { name: "empty kinds", cfg: SourceConfig{}, diff --git a/config/load.go b/config/load.go index 1c15f86..b8952a7 100644 --- a/config/load.go +++ b/config/load.go @@ -105,13 +105,7 @@ func (c *Config) Validate() error { } } - // Kind/Kinds (optional) - if s.Kind != "" && len(s.Kinds) > 0 { - m.Add(fieldErr(path+".kind", `cannot be set when "kinds" is provided (use only "kinds")`)) - } - if s.Kind != "" && strings.TrimSpace(s.Kind) == "" { - m.Add(fieldErr(path+".kind", "cannot be blank (omit it entirely, or provide a non-empty string)")) - } + // Kinds (optional) for j, k := range s.Kinds { kpath := fmt.Sprintf("%s.kinds[%d]", path, j) if strings.TrimSpace(k) == "" { diff --git a/config/validate_test.go b/config/validate_test.go index d4b09fd..def4a0c 100644 --- a/config/validate_test.go +++ b/config/validate_test.go @@ -114,31 +114,6 @@ func TestValidate_SourceModeRejectsUnknownValue(t *testing.T) { } } -func TestValidate_SourceKindAndKindsConflict(t *testing.T) { - cfg := &Config{ - Sources: []SourceConfig{ - { - Name: "src1", - Driver: "driver1", - Every: Duration{Duration: time.Minute}, - Kind: "observation", - Kinds: []string{"forecast"}, - }, - }, - Sinks: []SinkConfig{ - {Name: "sink1", Driver: "stdout"}, - }, - } - - err := cfg.Validate() - if err == nil { - t.Fatalf("expected error, got nil") - } - if !strings.Contains(err.Error(), `sources[0].kind`) { - t.Fatalf("expected error to mention sources[0].kind, got: %v", err) - } -} - func TestValidate_SourceKindsRejectBlankEntries(t *testing.T) { cfg := &Config{ Sources: []SourceConfig{ diff --git a/doc.go b/doc.go index 7d077e1..7f98343 100644 --- a/doc.go +++ b/doc.go @@ -43,8 +43,8 @@ // Compiles routes and fans events out to sinks with per-sink isolation. // // - sinks -// Defines sink interfaces, the sink registry, built-in sinks, and Postgres -// schema registration helpers. +// Defines sink interfaces, the sink registry, schema-free built-in sinks, +// and explicit Postgres factory helpers. // // feedkit is intentionally domain-agnostic. Domain schemas, domain event kinds, // upstream-specific parsing, and daemon lifecycle remain the responsibility of diff --git a/sinks/builtins.go b/sinks/builtins.go index 5049c96..5d2bdbc 100644 --- a/sinks/builtins.go +++ b/sinks/builtins.go @@ -12,11 +12,6 @@ func RegisterBuiltins(r *Registry) { return NewStdoutSink(cfg.Name), nil }) - // Postgres sink: persists events durably. - r.Register("postgres", func(cfg config.SinkConfig) (Sink, error) { - return NewPostgresSinkFromConfig(cfg) - }) - // NATS sink: publishes events to a broker for downstream consumers. r.Register("nats", func(cfg config.SinkConfig) (Sink, error) { return NewNATSSinkFromConfig(cfg) diff --git a/sinks/doc.go b/sinks/doc.go index 7962d5c..4b0634c 100644 --- a/sinks/doc.go +++ b/sinks/doc.go @@ -4,16 +4,16 @@ // External API surface: // - Sink: adapter interface that consumes event.Event values // - Registry / NewRegistry: named sink factory registry -// - RegisterBuiltins: registers the built-in sink drivers in this binary +// - RegisterBuiltins: registers the schema-free built-in sink drivers // -// Built-in drivers: +// Built-in sink implementations: // - stdout // - nats // - postgres // // Optional helpers from helpers.go: -// - RegisterPostgresSchemaForConfiguredSinks: registers one Postgres schema -// for each configured sink using driver=postgres +// - PostgresFactory: returns a sink factory for the built-in Postgres sink +// using a provided downstream schema // // # NATS built-in overview // @@ -59,7 +59,9 @@ // // Example downstream wiring: // -// sinks.MustRegisterPostgresSchema("pg_main", sinks.PostgresSchema{ +// sinkReg := sinks.NewRegistry() +// sinks.RegisterBuiltins(sinkReg) +// sinkReg.Register("postgres", sinks.PostgresFactory(sinks.PostgresSchema{ // Tables: []sinks.PostgresTable{ // { // Name: "events", @@ -88,7 +90,7 @@ // }, // }, nil // }, -// }) +// })) // // Manual pruning via type assertion (administrative helpers): // diff --git a/sinks/helpers.go b/sinks/helpers.go index 0f1a933..e18338a 100644 --- a/sinks/helpers.go +++ b/sinks/helpers.go @@ -26,21 +26,10 @@ func requireStringParam(cfg config.SinkConfig, key string) (string, error) { return s, nil } -// RegisterPostgresSchemaForConfiguredSinks registers one Postgres schema for each -// configured sink using driver=postgres. -func RegisterPostgresSchemaForConfiguredSinks(cfg *config.Config, schema PostgresSchema) error { - if cfg == nil { - return fmt.Errorf("register postgres schemas: config is nil") +// PostgresFactory returns a sink factory that builds the built-in Postgres sink +// using the provided downstream schema definition. +func PostgresFactory(schema PostgresSchema) Factory { + return func(cfg config.SinkConfig) (Sink, error) { + return NewPostgresSinkFromConfig(cfg, schema) } - - for i, sk := range cfg.Sinks { - if !strings.EqualFold(strings.TrimSpace(sk.Driver), "postgres") { - continue - } - if err := RegisterPostgresSchema(sk.Name, schema); err != nil { - return fmt.Errorf("register postgres schema for sinks[%d] name=%q: %w", i, sk.Name, err) - } - } - - return nil } diff --git a/sinks/helpers_test.go b/sinks/helpers_test.go index 8e318bd..6f38076 100644 --- a/sinks/helpers_test.go +++ b/sinks/helpers_test.go @@ -2,55 +2,19 @@ package sinks import ( "context" - "fmt" - "strings" "testing" - "time" "gitea.maximumdirect.net/ejr/feedkit/config" "gitea.maximumdirect.net/ejr/feedkit/event" ) -func TestRegisterPostgresSchemaForConfiguredSinksNilConfig(t *testing.T) { - err := RegisterPostgresSchemaForConfiguredSinks(nil, testPostgresSchema()) - if err == nil { - t.Fatalf("RegisterPostgresSchemaForConfiguredSinks(nil) expected error") +func TestPostgresFactoryReturnsWorkingFactory(t *testing.T) { + factory := PostgresFactory(testPostgresSchema()) + if factory == nil { + t.Fatalf("PostgresFactory() returned nil") } - if !strings.Contains(err.Error(), "config is nil") { - t.Fatalf("error = %q, want config is nil", err) - } -} - -func TestRegisterPostgresSchemaForConfiguredSinksNonPostgresNoOp(t *testing.T) { - cfg := &config.Config{ - Sinks: []config.SinkConfig{ - {Name: uniqueSinkName("stdout"), Driver: "stdout"}, - {Name: uniqueSinkName("nats"), Driver: "nats"}, - }, - } - - if err := RegisterPostgresSchemaForConfiguredSinks(cfg, testPostgresSchema()); err != nil { - t.Fatalf("RegisterPostgresSchemaForConfiguredSinks(non-postgres) error = %v", err) - } -} - -func TestRegisterPostgresSchemaForConfiguredSinksDuplicateRegistrationFails(t *testing.T) { - cfg := &config.Config{ - Sinks: []config.SinkConfig{ - {Name: uniqueSinkName("pg"), Driver: "postgres"}, - }, - } - - if err := RegisterPostgresSchemaForConfiguredSinks(cfg, testPostgresSchema()); err != nil { - t.Fatalf("first RegisterPostgresSchemaForConfiguredSinks() error = %v", err) - } - - err := RegisterPostgresSchemaForConfiguredSinks(cfg, testPostgresSchema()) - if err == nil { - t.Fatalf("second RegisterPostgresSchemaForConfiguredSinks() expected duplicate error") - } - if !strings.Contains(err.Error(), "already registered") { - t.Fatalf("error = %q, want already registered", err) + if _, err := factory(config.SinkConfig{}); err == nil { + t.Fatalf("factory(config) expected parameter validation error") } } @@ -80,7 +44,3 @@ func testPostgresSchema() PostgresSchema { }, } } - -func uniqueSinkName(prefix string) string { - return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano()) -} diff --git a/sinks/postgres.go b/sinks/postgres.go index ed2afb1..5a0ea4a 100644 --- a/sinks/postgres.go +++ b/sinks/postgres.go @@ -88,7 +88,7 @@ type PostgresSink struct { pruneWindow time.Duration } -func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { +func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) { uri, err := requireStringParam(cfg, "uri") if err != nil { return nil, err @@ -106,9 +106,9 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { return nil, err } - schema, ok := lookupPostgresSchema(cfg.Name) - if !ok { - return nil, fmt.Errorf("postgres sink %q: no schema registered (call sinks.RegisterPostgresSchema before building sinks)", cfg.Name) + schema, err := compilePostgresSchema(schemaDef) + if err != nil { + return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err) } dsn, err := buildPostgresDSN(uri, username, password) diff --git a/sinks/postgres_schema.go b/sinks/postgres_schema.go index 88bda3a..b8757b0 100644 --- a/sinks/postgres_schema.go +++ b/sinks/postgres_schema.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "sync" "time" "gitea.maximumdirect.net/ejr/feedkit/event" @@ -72,51 +71,6 @@ type postgresTableCompiled struct { indexes []PostgresIndex } -var ( - postgresSchemaRegistryMu sync.RWMutex - postgresSchemaRegistry = map[string]postgresSchemaCompiled{} -) - -// RegisterPostgresSchema registers one downstream schema by sink name. -// -// This should be called by downstream daemon wiring code before sink -// construction. Duplicate sink-name registrations are rejected. -func RegisterPostgresSchema(sinkName string, schema PostgresSchema) error { - sinkName = strings.TrimSpace(sinkName) - if sinkName == "" { - return fmt.Errorf("postgres schema: sink name cannot be empty") - } - - compiled, err := compilePostgresSchema(schema) - if err != nil { - return err - } - - postgresSchemaRegistryMu.Lock() - defer postgresSchemaRegistryMu.Unlock() - - if _, exists := postgresSchemaRegistry[sinkName]; exists { - return fmt.Errorf("postgres schema: sink %q already registered", sinkName) - } - - postgresSchemaRegistry[sinkName] = compiled - return nil -} - -func MustRegisterPostgresSchema(sinkName string, schema PostgresSchema) { - if err := RegisterPostgresSchema(sinkName, schema); err != nil { - panic(err) - } -} - -func lookupPostgresSchema(sinkName string) (postgresSchemaCompiled, bool) { - postgresSchemaRegistryMu.RLock() - defer postgresSchemaRegistryMu.RUnlock() - - s, ok := postgresSchemaRegistry[sinkName] - return s, ok -} - func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) { if schema.MapEvent == nil { return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required") diff --git a/sinks/postgres_test.go b/sinks/postgres_test.go index 4353dbe..2ff98d0 100644 --- a/sinks/postgres_test.go +++ b/sinks/postgres_test.go @@ -96,20 +96,12 @@ func (d *fakeDB) Close() error { return nil } -func resetPostgresSchemaRegistryForTest() { - postgresSchemaRegistryMu.Lock() - defer postgresSchemaRegistryMu.Unlock() - postgresSchemaRegistry = map[string]postgresSchemaCompiled{} -} - func withPostgresTestState(t *testing.T) { t.Helper() - resetPostgresSchemaRegistryForTest() oldOpen := openPostgresDB t.Cleanup(func() { openPostgresDB = oldOpen - resetPostgresSchemaRegistryForTest() }) } @@ -183,35 +175,8 @@ func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled { return compiled } -func TestRegisterPostgresSchema(t *testing.T) { - withPostgresTestState(t) - - err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { - return nil, nil - })) - if err != nil { - t.Fatalf("register schema: %v", err) - } - - if _, ok := lookupPostgresSchema("pg"); !ok { - t.Fatalf("expected schema registration") - } - - err = RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { - return nil, nil - })) - if err == nil { - t.Fatalf("expected duplicate registration error") - } - if !strings.Contains(err.Error(), "already registered") { - t.Fatalf("unexpected duplicate error: %v", err) - } -} - -func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) { - withPostgresTestState(t) - - err := RegisterPostgresSchema("pg", PostgresSchema{ +func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) { + _, err := compilePostgresSchema(PostgresSchema{ Tables: []PostgresTable{ { Name: "events", @@ -230,7 +195,7 @@ func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) { t.Fatalf("unexpected schema validation error: %v", err) } - err = RegisterPostgresSchema("pg2", PostgresSchema{ + _, err = compilePostgresSchema(PostgresSchema{ Tables: []PostgresTable{ { Name: "events", @@ -254,7 +219,50 @@ func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) { } } -func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) { +func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) { + withPostgresTestState(t) + + dbs := []*fakeDB{{}, {}} + var gotDSNs []string + openPostgresDB = func(dsn string) (postgresDB, error) { + gotDSNs = append(gotDSNs, dsn) + db := dbs[len(gotDSNs)-1] + return db, nil + } + + factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + + for _, name := range []string{"pg_a", "pg_b"} { + sink, err := factory(config.SinkConfig{ + Name: name, + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "user", + "password": "pass", + }, + }) + if err != nil { + t.Fatalf("factory(%q) error = %v", name, err) + } + if sink == nil { + t.Fatalf("factory(%q) returned nil sink", name) + } + } + + if len(gotDSNs) != 2 { + t.Fatalf("len(gotDSNs) = %d, want 2", len(gotDSNs)) + } + for i, db := range dbs { + if db.pingCalls != 1 { + t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls) + } + } +} + +func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) { withPostgresTestState(t) tests := []struct { @@ -273,7 +281,7 @@ func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) { Name: "pg", Driver: "postgres", Params: tc.params, - }) + }, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil })) if err == nil { t.Fatalf("expected error") } @@ -284,7 +292,7 @@ func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) { } } -func TestNewPostgresSinkFromConfig_MissingSchemaRegistration(t *testing.T) { +func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) { withPostgresTestState(t) _, err := NewPostgresSinkFromConfig(config.SinkConfig{ @@ -295,25 +303,29 @@ func TestNewPostgresSinkFromConfig_MissingSchemaRegistration(t *testing.T) { "username": "user", "password": "pass", }, + }, PostgresSchema{ + Tables: []PostgresTable{ + { + Name: "events", + Columns: []PostgresColumn{ + {Name: "id", Type: "TEXT", Nullable: false}, + }, + PruneColumn: "missing_col", + }, + }, + MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }, }) if err == nil { - t.Fatalf("expected error") + t.Fatalf("expected invalid schema error") } - if !strings.Contains(err.Error(), "no schema registered") { + if !strings.Contains(err.Error(), "compile schema") { t.Fatalf("unexpected error: %v", err) } } -func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) { +func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) { withPostgresTestState(t) - err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { - return nil, nil - })) - if err != nil { - t.Fatalf("register schema: %v", err) - } - db := &fakeDB{} var gotDSN string openPostgresDB = func(dsn string) (postgresDB, error) { @@ -329,7 +341,7 @@ func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) { "username": "app_user", "password": "app_pass", }, - }) + }, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil })) if err != nil { t.Fatalf("new postgres sink: %v", err) } @@ -363,22 +375,15 @@ func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) { } } -func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) { +func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) { withPostgresTestState(t) - err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { - return nil, nil - })) - if err != nil { - t.Fatalf("register schema: %v", err) - } - db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")} openPostgresDB = func(_ string) (postgresDB, error) { return db, nil } - _, err = NewPostgresSinkFromConfig(config.SinkConfig{ + _, err := NewPostgresSinkFromConfig(config.SinkConfig{ Name: "pg", Driver: "postgres", Params: map[string]any{ @@ -386,7 +391,7 @@ func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) { "username": "user", "password": "pass", }, - }) + }, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil })) if err == nil { t.Fatalf("expected init error") } @@ -395,7 +400,7 @@ func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) { } } -func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) { +func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) { tests := []struct { name string in string @@ -410,13 +415,6 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) { t.Run(tc.name, func(t *testing.T) { withPostgresTestState(t) - err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { - return nil, nil - })) - if err != nil { - t.Fatalf("register schema: %v", err) - } - openPostgresDB = func(_ string) (postgresDB, error) { return &fakeDB{}, nil } @@ -430,7 +428,7 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) { "password": "pass", "prune": tc.in, }, - }) + }, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil })) if err != nil { t.Fatalf("new postgres sink: %v", err) } @@ -446,7 +444,7 @@ func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) { } } -func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) { +func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) { withPostgresTestState(t) tests := []struct { @@ -472,7 +470,7 @@ func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) { "password": "pass", "prune": tc.in, }, - }) + }, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil })) if err == nil { t.Fatalf("expected error") } @@ -483,7 +481,7 @@ func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) { } } -func TestPostgresSinkConsume_InvalidEvent(t *testing.T) { +func TestPostgresSinkConsumeInvalidEvent(t *testing.T) { db := &fakeDB{} called := 0 sink := &PostgresSink{ @@ -507,7 +505,7 @@ func TestPostgresSinkConsume_InvalidEvent(t *testing.T) { } } -func TestPostgresSinkConsume_UnmappedEventIsNoOp(t *testing.T) { +func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) { db := &fakeDB{} sink := &PostgresSink{ name: "pg", @@ -525,7 +523,7 @@ func TestPostgresSinkConsume_UnmappedEventIsNoOp(t *testing.T) { } } -func TestPostgresSinkConsume_OneEventWritesMultipleTablesAtomically(t *testing.T) { +func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) { tx := &fakeTx{} db := &fakeDB{tx: tx} sink := &PostgresSink{ @@ -556,7 +554,7 @@ func TestPostgresSinkConsume_OneEventWritesMultipleTablesAtomically(t *testing.T } } -func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) { +func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) { tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")} db := &fakeDB{tx: tx} sink := &PostgresSink{ @@ -585,13 +583,13 @@ func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) { } } -func TestPostgresSinkConsume_AutoPruneRunsInSameTransaction(t *testing.T) { +func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) { tx := &fakeTx{} db := &fakeDB{tx: tx} sink := &PostgresSink{ - name: "pg", - db: db, - schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { return []PostgresWrite{ {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, @@ -620,13 +618,13 @@ func TestPostgresSinkConsume_AutoPruneRunsInSameTransaction(t *testing.T) { } } -func TestPostgresSinkConsume_AutoPruneFailureRollsBack(t *testing.T) { +func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) { tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")} db := &fakeDB{tx: tx} sink := &PostgresSink{ - name: "pg", - db: db, - schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { return []PostgresWrite{ {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, @@ -650,7 +648,7 @@ func TestPostgresSinkConsume_AutoPruneFailureRollsBack(t *testing.T) { } } -func TestPostgresSinkPrune_PerTable(t *testing.T) { +func TestPostgresSinkPrunePerTable(t *testing.T) { db := &fakeDB{execRows: 7} sink := &PostgresSink{ name: "pg", @@ -693,7 +691,7 @@ func TestPostgresSinkPrune_PerTable(t *testing.T) { } } -func TestPostgresSinkPrune_AllTables(t *testing.T) { +func TestPostgresSinkPruneAllTables(t *testing.T) { db := &fakeDB{execRows: 3} sink := &PostgresSink{ name: "pg", @@ -724,7 +722,7 @@ func TestPostgresSinkPrune_AllTables(t *testing.T) { } } -func TestPostgresSinkPrune_Errors(t *testing.T) { +func TestPostgresSinkPruneErrors(t *testing.T) { db := &fakeDB{} sink := &PostgresSink{ name: "pg", diff --git a/sinks/registry_test.go b/sinks/registry_test.go index ba18d00..3ae23ec 100644 --- a/sinks/registry_test.go +++ b/sinks/registry_test.go @@ -112,15 +112,15 @@ func TestRegisterBuiltinsExposesExpectedDrivers(t *testing.T) { r := NewRegistry() RegisterBuiltins(r) - if len(r.byDriver) != 3 { - t.Fatalf("len(byDriver) = %d, want 3", len(r.byDriver)) + if len(r.byDriver) != 2 { + t.Fatalf("len(byDriver) = %d, want 2", len(r.byDriver)) } - for _, driver := range []string{"stdout", "nats", "postgres"} { + for _, driver := range []string{"stdout", "nats"} { if _, ok := r.byDriver[driver]; !ok { t.Fatalf("builtins missing driver %q", driver) } } - if _, ok := r.byDriver["file"]; ok { - t.Fatalf("builtins unexpectedly registered file driver") + if _, ok := r.byDriver["postgres"]; ok { + t.Fatalf("builtins unexpectedly registered postgres driver") } } diff --git a/sources/doc.go b/sources/doc.go index 007c2b2..befd5a6 100644 --- a/sources/doc.go +++ b/sources/doc.go @@ -25,6 +25,9 @@ // - params.url // - params.user_agent // - params.conditional (optional, default true) +// - params.http_timeout (optional, default transport.DefaultHTTPTimeout) +// - params.http_response_body_limit_bytes (optional, default +// transport.DefaultHTTPResponseBodyLimitBytes) // // When validators are available, NewHTTPSource prefers ETag/If-None-Match and // falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is diff --git a/sources/helpers.go b/sources/helpers.go index 2984ab9..f41fb7a 100644 --- a/sources/helpers.go +++ b/sources/helpers.go @@ -127,11 +127,6 @@ func advertisedSourceKinds(in Input) map[event.Kind]bool { return kinds } - if ks, ok := in.(KindSource); ok { - kinds[ks.Kind()] = true - return kinds - } - return nil } diff --git a/sources/helpers_test.go b/sources/helpers_test.go index 57d4f8a..ad99114 100644 --- a/sources/helpers_test.go +++ b/sources/helpers_test.go @@ -15,13 +15,6 @@ type testInput struct { func (s testInput) Name() string { return s.name } -type testKindSource struct { - testInput - kind event.Kind -} - -func (s testKindSource) Kind() event.Kind { return s.kind } - type testKindsSource struct { testInput kinds []event.Kind @@ -29,18 +22,6 @@ type testKindsSource struct { func (s testKindsSource) Kinds() []event.Kind { return s.kinds } -func TestValidateExpectedKindsLegacyKindFallback(t *testing.T) { - cfg := config.SourceConfig{Kind: "observation"} - in := testKindSource{ - testInput: testInput{name: "test"}, - kind: event.Kind("observation"), - } - - if err := ValidateExpectedKinds(cfg, in); err != nil { - t.Fatalf("ValidateExpectedKinds() unexpected error: %v", err) - } -} - func TestValidateExpectedKindsSubsetAllowed(t *testing.T) { cfg := config.SourceConfig{Kinds: []string{"observation"}} in := testKindsSource{ diff --git a/sources/http.go b/sources/http.go index c124e9d..2a89183 100644 --- a/sources/http.go +++ b/sources/http.go @@ -19,13 +19,14 @@ import ( // setup, and conditional GET validator handling. Concrete daemon sources remain // responsible for decoding the response body and constructing events. type HTTPSource struct { - Driver string - Name string - URL string - UserAgent string - Accept string - Conditional bool - Client *http.Client + Driver string + Name string + URL string + UserAgent string + Accept string + Conditional bool + ResponseBodyLimitBytes int64 + Client *http.Client mu sync.Mutex validators transport.HTTPValidators @@ -68,14 +69,33 @@ func NewHTTPSource(driver string, cfg config.SourceConfig, accept string) (*HTTP } } + timeout := transport.DefaultHTTPTimeout + if _, exists := cfg.Params["http_timeout"]; exists { + var ok bool + timeout, ok = cfg.ParamDuration("http_timeout") + if !ok || timeout <= 0 { + return nil, fmt.Errorf("source %q: params.http_timeout must be a positive duration", cfg.Name) + } + } + + bodyLimit := transport.DefaultHTTPResponseBodyLimitBytes + if _, exists := cfg.Params["http_response_body_limit_bytes"]; exists { + rawLimit, ok := cfg.ParamInt("http_response_body_limit_bytes") + if !ok || rawLimit <= 0 { + return nil, fmt.Errorf("source %q: params.http_response_body_limit_bytes must be a positive integer", cfg.Name) + } + bodyLimit = int64(rawLimit) + } + return &HTTPSource{ - Driver: driver, - Name: name, - URL: url, - UserAgent: userAgent, - Accept: accept, - Conditional: conditional, - Client: transport.NewHTTPClient(transport.DefaultHTTPTimeout), + Driver: driver, + Name: name, + URL: url, + UserAgent: userAgent, + Accept: accept, + Conditional: conditional, + ResponseBodyLimitBytes: bodyLimit, + Client: transport.NewHTTPClient(timeout), }, nil } @@ -92,7 +112,12 @@ func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, err validators := s.validators s.mu.Unlock() - body, changed, next, err := transport.FetchBodyIfChanged( + bodyLimit := s.ResponseBodyLimitBytes + if bodyLimit <= 0 { + bodyLimit = transport.DefaultHTTPResponseBodyLimitBytes + } + + body, changed, next, err := transport.FetchBodyIfChangedWithLimit( ctx, client, s.URL, @@ -100,6 +125,7 @@ func (s *HTTPSource) FetchBytesIfChanged(ctx context.Context) ([]byte, bool, err s.Accept, s.Conditional, validators, + bodyLimit, ) if err != nil { return nil, false, fmt.Errorf("%s %q: %w", s.Driver, s.Name, err) diff --git a/sources/http_test.go b/sources/http_test.go index 2b8341b..f0990a9 100644 --- a/sources/http_test.go +++ b/sources/http_test.go @@ -6,8 +6,10 @@ import ( "net/http/httptest" "strings" "testing" + "time" "gitea.maximumdirect.net/ejr/feedkit/config" + "gitea.maximumdirect.net/ejr/feedkit/transport" ) func TestNewHTTPSourceConditionalDefaultsTrue(t *testing.T) { @@ -63,6 +65,81 @@ func TestNewHTTPSourceConditionalCanBeExplicitlyFalse(t *testing.T) { } } +func TestNewHTTPSourceHTTPTimeoutOverride(t *testing.T) { + src, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + "http_timeout": "250ms", + }, + }, "application/json") + if err != nil { + t.Fatalf("NewHTTPSource() error = %v", err) + } + if src.Client == nil { + t.Fatalf("Client = nil") + } + if src.Client.Timeout != 250*time.Millisecond { + t.Fatalf("Client.Timeout = %s, want 250ms", src.Client.Timeout) + } +} + +func TestNewHTTPSourceBodyLimitOverride(t *testing.T) { + src, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + "http_response_body_limit_bytes": 12345, + }, + }, "application/json") + if err != nil { + t.Fatalf("NewHTTPSource() error = %v", err) + } + if src.ResponseBodyLimitBytes != 12345 { + t.Fatalf("ResponseBodyLimitBytes = %d, want 12345", src.ResponseBodyLimitBytes) + } +} + +func TestNewHTTPSourceRejectsInvalidHTTPTimeout(t *testing.T) { + _, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + "http_timeout": "soon", + }, + }, "application/json") + if err == nil { + t.Fatalf("NewHTTPSource() error = nil, want error") + } + if !strings.Contains(err.Error(), "params.http_timeout must be a positive duration") { + t.Fatalf("NewHTTPSource() error = %q", err) + } +} + +func TestNewHTTPSourceRejectsInvalidBodyLimit(t *testing.T) { + _, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + "http_response_body_limit_bytes": "abc", + }, + }, "application/json") + if err == nil { + t.Fatalf("NewHTTPSource() error = nil, want error") + } + if !strings.Contains(err.Error(), "params.http_response_body_limit_bytes must be a positive integer") { + t.Fatalf("NewHTTPSource() error = %q", err) + } +} + func TestHTTPSourceFetchJSONIfChanged(t *testing.T) { var call int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -116,3 +193,68 @@ func TestHTTPSourceFetchJSONIfChanged(t *testing.T) { t.Fatalf("second FetchJSONIfChanged() body = %q, want nil", string(raw)) } } + +func TestHTTPSourceFetchJSONIfChangedHonorsBodyLimitOverride(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + src, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": srv.URL, + "user_agent": "test-agent", + "http_response_body_limit_bytes": 4, + }, + }, "application/json") + if err != nil { + t.Fatalf("NewHTTPSource() error = %v", err) + } + + _, _, err = src.FetchJSONIfChanged(context.Background()) + if err == nil { + t.Fatalf("FetchJSONIfChanged() error = nil, want limit error") + } + if !strings.Contains(err.Error(), "response body too large") { + t.Fatalf("FetchJSONIfChanged() error = %q", err) + } +} + +func TestNewHTTPSourceUsesDefaultBodyLimitWhenUnset(t *testing.T) { + src, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + }, + }, "application/json") + if err != nil { + t.Fatalf("NewHTTPSource() error = %v", err) + } + if src.ResponseBodyLimitBytes != transport.DefaultHTTPResponseBodyLimitBytes { + t.Fatalf("ResponseBodyLimitBytes = %d, want %d", src.ResponseBodyLimitBytes, transport.DefaultHTTPResponseBodyLimitBytes) + } +} + +func TestNewHTTPSourceUsesDefaultTimeoutWhenUnset(t *testing.T) { + src, err := NewHTTPSource("test_driver", config.SourceConfig{ + Name: "test-source", + Driver: "test_driver", + Params: map[string]any{ + "url": "https://example.invalid", + "user_agent": "test-agent", + }, + }, "application/json") + if err != nil { + t.Fatalf("NewHTTPSource() error = %v", err) + } + if src.Client == nil { + t.Fatalf("Client = nil") + } + if src.Client.Timeout != transport.DefaultHTTPTimeout { + t.Fatalf("Client.Timeout = %s, want %s", src.Client.Timeout, transport.DefaultHTTPTimeout) + } +} diff --git a/sources/registry.go b/sources/registry.go index c141933..0ab3bf4 100644 --- a/sources/registry.go +++ b/sources/registry.go @@ -15,9 +15,6 @@ import ( type PollFactory func(cfg config.SourceConfig) (PollSource, error) type StreamFactory func(cfg config.SourceConfig) (StreamSource, error) -// Factory is the legacy alias for poll source factories. -type Factory = PollFactory - type Registry struct { byPollDriver map[string]PollFactory byStreamDriver map[string]StreamFactory @@ -30,13 +27,6 @@ func NewRegistry() *Registry { } } -// Register associates a driver name (e.g. "openmeteo_observation") with a factory. -// -// The driver string is the "lookup key" used by config.sources[].driver. -func (r *Registry) Register(driver string, f PollFactory) { - r.RegisterPoll(driver, f) -} - // RegisterPoll associates a driver name with a polling-source factory. func (r *Registry) RegisterPoll(driver string, f PollFactory) { driver = strings.TrimSpace(driver) @@ -75,11 +65,6 @@ func (r *Registry) RegisterStream(driver string, f StreamFactory) { r.byStreamDriver[driver] = f } -// Build constructs a polling source from a SourceConfig by looking up cfg.Driver. -func (r *Registry) Build(cfg config.SourceConfig) (PollSource, error) { - return r.BuildPoll(cfg) -} - // BuildPoll constructs a polling source from a SourceConfig by looking up cfg.Driver. func (r *Registry) BuildPoll(cfg config.SourceConfig) (PollSource, error) { driver := strings.TrimSpace(cfg.Driver) diff --git a/sources/source.go b/sources/source.go index 908df44..efb954b 100644 --- a/sources/source.go +++ b/sources/source.go @@ -31,9 +31,6 @@ type PollSource interface { Poll(ctx context.Context) ([]event.Event, error) } -// Source is a compatibility alias for the legacy polling-source name. -type Source = PollSource - // StreamSource is an event-driven source (NATS/RabbitMQ/MQTT/etc). // // Run should block, producing events into `out` until ctx is cancelled or a fatal error occurs. @@ -43,12 +40,6 @@ type StreamSource interface { Run(ctx context.Context, out chan<- event.Event) error } -// KindSource is an optional interface for sources that advertise one "primary" kind. -// This is legacy-friendly but no longer required. -type KindSource interface { - Kind() event.Kind -} - // KindsSource is an optional interface for sources that advertise multiple kinds. type KindsSource interface { Kinds() []event.Kind diff --git a/transport/http.go b/transport/http.go index 79d068c..599d5a6 100644 --- a/transport/http.go +++ b/transport/http.go @@ -10,10 +10,10 @@ import ( "time" ) -// maxResponseBodyBytes is a hard safety limit on HTTP response bodies. +// DefaultHTTPResponseBodyLimitBytes is a hard safety limit on HTTP response bodies. // API responses should be small, so this protects us from accidental // or malicious large responses. -const maxResponseBodyBytes = 2 << 21 // 4 MiB +const DefaultHTTPResponseBodyLimitBytes int64 = 2 << 21 // 4 MiB // DefaultHTTPTimeout is the standard timeout used by HTTP sources. // Individual drivers may override this if they have a specific need. @@ -29,6 +29,10 @@ func NewHTTPClient(timeout time.Duration) *http.Client { } func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept string) ([]byte, error) { + return FetchBodyWithLimit(ctx, client, url, userAgent, accept, DefaultHTTPResponseBodyLimitBytes) +} + +func FetchBodyWithLimit(ctx context.Context, client *http.Client, url, userAgent, accept string, bodyLimitBytes int64) ([]byte, error) { res, err := doRequest(ctx, client, http.MethodGet, url, userAgent, accept, "", "") if err != nil { return nil, err @@ -39,7 +43,7 @@ func FetchBody(ctx context.Context, client *http.Client, url, userAgent, accept return nil, fmt.Errorf("HTTP %s", res.Status) } - return readValidatedBody(res.Body) + return readValidatedBody(res.Body, bodyLimitBytes) } // HTTPValidators are cache validators learned from prior successful GET responses. @@ -68,6 +72,17 @@ func FetchBodyIfChanged( url, userAgent, accept string, conditional bool, validators HTTPValidators, +) ([]byte, bool, HTTPValidators, error) { + return FetchBodyIfChangedWithLimit(ctx, client, url, userAgent, accept, conditional, validators, DefaultHTTPResponseBodyLimitBytes) +} + +func FetchBodyIfChangedWithLimit( + ctx context.Context, + client *http.Client, + url, userAgent, accept string, + conditional bool, + validators HTTPValidators, + bodyLimitBytes int64, ) ([]byte, bool, HTTPValidators, error) { headerName, headerValue := conditionalHeader(conditional, validators) @@ -89,7 +104,7 @@ func FetchBodyIfChanged( } } - b, err := readValidatedBody(res.Body) + b, err := readValidatedBody(res.Body, bodyLimitBytes) if err != nil { return nil, false, validators, err } @@ -150,9 +165,13 @@ func refreshValidators(current HTTPValidators, header http.Header) HTTPValidator return current } -func readValidatedBody(r io.Reader) ([]byte, error) { - // Read at most maxResponseBodyBytes + 1 so we can detect overflow. - limited := io.LimitReader(r, maxResponseBodyBytes+1) +func readValidatedBody(r io.Reader, bodyLimitBytes int64) ([]byte, error) { + if bodyLimitBytes <= 0 { + bodyLimitBytes = DefaultHTTPResponseBodyLimitBytes + } + + // Read at most bodyLimitBytes + 1 so we can detect overflow. + limited := io.LimitReader(r, bodyLimitBytes+1) b, err := io.ReadAll(limited) if err != nil { @@ -163,8 +182,8 @@ func readValidatedBody(r io.Reader) ([]byte, error) { return nil, fmt.Errorf("empty response body") } - if len(b) > maxResponseBodyBytes { - return nil, fmt.Errorf("response body too large (>%d bytes)", maxResponseBodyBytes) + if int64(len(b)) > bodyLimitBytes { + return nil, fmt.Errorf("response body too large (>%d bytes)", bodyLimitBytes) } return b, nil