Files
feedkit/sinks/postgres_test.go
Eric Rakestraw eb9a7cb349 Refactor feedkit boundaries ahead of v1
Remove global Postgres schema registration in favor of explicit schema-aware sink factory wiring, and update weatherfeeder to register the Postgres sink explicitly. Add optional per-source HTTP timeout and response body limit overrides while keeping feedkit defaults. Remove remaining legacy source/config compatibility surfaces, including singular kind support and old source registry/type aliases, and migrate weatherfeeder sources to plural `Kinds()` metadata. Clean up related docs, tests, and sample config to match the new Postgres, HTTP, and NATS configuration model.
2026-03-28 13:52:48 -05:00

742 lines
20 KiB
Go

package sinks
import (
"context"
"database/sql"
"errors"
"net/url"
"strings"
"testing"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
)
// fakeResult is a minimal sql.Result: RowsAffected reports a fixed count and
// LastInsertId is deliberately unsupported.
type fakeResult struct {
	rows int64
}

// LastInsertId always fails with an "unsupported" error.
func (res fakeResult) LastInsertId() (int64, error) {
	return 0, errors.New("unsupported")
}

// RowsAffected returns the configured row count and never fails.
func (res fakeResult) RowsAffected() (int64, error) {
	return res.rows, nil
}

// execCall records a single ExecContext invocation: the SQL text plus a
// private copy of its bind arguments (so later caller mutation is not seen).
type execCall struct {
	query string
	args  []any
}

// fakeTx is a scripted transaction double. It records every statement, can
// fail one chosen statement, and counts Commit/Rollback invocations.
type fakeTx struct {
	// execCalls accumulates every statement, including the one that fails.
	execCalls []execCall
	// execErr is returned by the execErrOnCall-th ExecContext call (1-based);
	// every other call succeeds.
	execErr       error
	execErrOnCall int
	// execRows is reported as RowsAffected by successful statements.
	execRows int64

	commitErr   error // returned by Commit
	rollbackErr error // returned by Rollback

	commitCalls   int
	rollbackCalls int
}

// ExecContext records the statement first, then either injects execErr (when
// this is the configured call number) or reports execRows affected rows.
func (tx *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	recorded := execCall{query: query, args: append([]any(nil), args...)}
	tx.execCalls = append(tx.execCalls, recorded)

	callNumber := len(tx.execCalls)
	if tx.execErr == nil || callNumber != tx.execErrOnCall {
		return fakeResult{rows: tx.execRows}, nil
	}
	return nil, tx.execErr
}

// Commit counts the call and returns commitErr.
func (tx *fakeTx) Commit() error {
	tx.commitCalls++
	return tx.commitErr
}

// Rollback counts the call and returns rollbackErr.
func (tx *fakeTx) Rollback() error {
	tx.rollbackCalls++
	return tx.rollbackErr
}
// fakeDB is a scripted postgresDB double. It counts Ping/BeginTx/Close calls,
// records statements executed directly on the handle, and can inject a
// failure into each operation.
type fakeDB struct {
	pingErr       error // returned by PingContext
	beginErr      error // returned by BeginTx instead of a transaction
	execErr       error // returned by the execErrOnCall-th ExecContext (1-based)
	execErrOnCall int
	execRows      int64 // RowsAffected reported by successful ExecContext calls
	pingCalls     int
	beginCalls    int
	execCalls     []execCall
	closeCalls    int
	tx            *fakeTx // handed out by BeginTx; created lazily when nil
}

// PingContext counts the call and returns pingErr.
func (db *fakeDB) PingContext(context.Context) error {
	db.pingCalls++
	return db.pingErr
}

// BeginTx counts the call and, unless beginErr is set, returns the shared
// fakeTx, creating an empty one on first use so tests can inspect it later.
func (db *fakeDB) BeginTx(context.Context, *sql.TxOptions) (postgresTx, error) {
	db.beginCalls++
	switch {
	case db.beginErr != nil:
		return nil, db.beginErr
	case db.tx == nil:
		db.tx = &fakeTx{}
	}
	return db.tx, nil
}

// ExecContext records the statement, then either injects execErr (when this
// is the configured call number) or reports execRows affected rows.
func (db *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	db.execCalls = append(db.execCalls, execCall{query: query, args: append([]any(nil), args...)})
	if n := len(db.execCalls); db.execErr != nil && n == db.execErrOnCall {
		return nil, db.execErr
	}
	return fakeResult{rows: db.execRows}, nil
}

// Close counts the call; it never fails.
func (db *fakeDB) Close() error {
	db.closeCalls++
	return nil
}
// withPostgresTestState snapshots the package-level openPostgresDB hook and
// restores it when the test (or subtest) finishes, so tests may replace the
// hook freely.
func withPostgresTestState(t *testing.T) {
	t.Helper()
	saved := openPostgresDB
	t.Cleanup(func() { openPostgresDB = saved })
}
// validTestEvent returns a fully populated observation event (ID, kind,
// source, UTC emission time and a small payload) for use in Consume tests.
func validTestEvent() event.Event {
	payload := map[string]any{"x": 1}
	return event.Event{
		ID:        "evt-1",
		Kind:      event.Kind("observation"),
		Source:    "source-1",
		EmittedAt: time.Now().UTC(),
		Payload:   payload,
	}
}
// schemaOneTable builds a single-table schema ("events") keyed by event_id,
// pruned on emitted_at and indexed on emitted_at, using mapFn as the event
// mapper.
func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
	events := PostgresTable{
		Name: "events",
		Columns: []PostgresColumn{
			{Name: "event_id", Type: "TEXT", Nullable: false},
			{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
			{Name: "payload_json", Type: "JSONB", Nullable: false},
		},
		PrimaryKey:  []string{"event_id"},
		PruneColumn: "emitted_at",
		Indexes: []PostgresIndex{
			{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
		},
	}
	return PostgresSchema{
		Tables:   []PostgresTable{events},
		MapEvent: mapFn,
	}
}
// schemaTwoTables builds a two-table schema ("events" then "event_payloads"),
// both keyed by event_id and pruned on emitted_at, with no secondary indexes.
// mapFn is used as the event mapper.
func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
	eventsTable := PostgresTable{
		Name: "events",
		Columns: []PostgresColumn{
			{Name: "event_id", Type: "TEXT", Nullable: false},
			{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
		},
		PrimaryKey:  []string{"event_id"},
		PruneColumn: "emitted_at",
	}
	payloadsTable := PostgresTable{
		Name: "event_payloads",
		Columns: []PostgresColumn{
			{Name: "event_id", Type: "TEXT", Nullable: false},
			{Name: "payload_json", Type: "JSONB", Nullable: false},
			{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
		},
		PrimaryKey:  []string{"event_id"},
		PruneColumn: "emitted_at",
	}
	return PostgresSchema{
		Tables:   []PostgresTable{eventsTable, payloadsTable},
		MapEvent: mapFn,
	}
}
// mustCompileSchema compiles s and aborts the calling test on failure, so
// tests can build a PostgresSink literal without repeating error handling.
func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
	t.Helper()
	out, compileErr := compilePostgresSchema(s)
	if compileErr != nil {
		t.Fatalf("compile schema: %v", compileErr)
	}
	return out
}
func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
_, err := compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
},
PruneColumn: "missing_col",
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid schema error")
}
if !strings.Contains(err.Error(), "prune column") {
t.Fatalf("unexpected schema validation error: %v", err)
}
_, err = compilePostgresSchema(PostgresSchema{
Tables: []PostgresTable{
{
Name: "events",
Columns: []PostgresColumn{
{Name: "id", Type: "TEXT", Nullable: false},
{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
},
PruneColumn: "emitted_at",
Indexes: []PostgresIndex{
{Name: "idx_events_empty", Columns: nil},
},
},
},
MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
})
if err == nil {
t.Fatalf("expected invalid index schema error")
}
if !strings.Contains(err.Error(), "at least one column") {
t.Fatalf("unexpected index validation error: %v", err)
}
}
// TestPostgresFactoryBuildsMultipleSinksWithSameSchema checks that a single
// PostgresFactory (one schema) can build several sinks, each of which opens
// and pings its own database handle.
func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
	withPostgresTestState(t)
	dbs := []*fakeDB{{}, {}}
	var gotDSNs []string
	openPostgresDB = func(dsn string) (postgresDB, error) {
		// Report an unexpected extra open as an error (surfaced by the factory
		// below) rather than panicking on an out-of-range index.
		if len(gotDSNs) >= len(dbs) {
			return nil, errors.New("unexpected extra openPostgresDB call")
		}
		gotDSNs = append(gotDSNs, dsn)
		return dbs[len(gotDSNs)-1], nil
	}
	factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
		return nil, nil
	}))
	for _, name := range []string{"pg_a", "pg_b"} {
		sink, err := factory(config.SinkConfig{
			Name:   name,
			Driver: "postgres",
			Params: map[string]any{
				"uri":      "postgres://localhost/db",
				"username": "user",
				"password": "pass",
			},
		})
		if err != nil {
			t.Fatalf("factory(%q) error = %v", name, err)
		}
		if sink == nil {
			t.Fatalf("factory(%q) returned nil sink", name)
		}
	}
	if len(gotDSNs) != 2 {
		t.Fatalf("len(gotDSNs) = %d, want 2", len(gotDSNs))
	}
	for i, db := range dbs {
		if db.pingCalls != 1 {
			t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
		}
	}
}
// TestNewPostgresSinkFromConfigMissingParams checks that each required
// connection parameter (uri, username, password) is validated and named in
// the returned error.
func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
	withPostgresTestState(t)
	// Stub the opener so the test never reaches a real database, even if
	// parameter validation regressed and let the config through.
	openPostgresDB = func(_ string) (postgresDB, error) {
		return &fakeDB{}, nil
	}
	tests := []struct {
		name   string
		params map[string]any
		want   string
	}{
		{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
		{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
		{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: tc.params,
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
			if err == nil {
				t.Fatalf("expected error")
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("expected %q in error, got: %v", tc.want, err)
			}
		})
	}
}
// TestNewPostgresSinkFromConfigRejectsInvalidSchema checks that a schema that
// fails compilation (prune column not declared) is reported by the
// constructor with "compile schema" context.
func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
	withPostgresTestState(t)
	// Stub the opener so this test stays hermetic no matter where in the
	// constructor the schema is compiled relative to opening the database.
	openPostgresDB = func(_ string) (postgresDB, error) {
		return &fakeDB{}, nil
	}
	_, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "user",
			"password": "pass",
		},
	}, PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
				},
				PruneColumn: "missing_col",
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid schema error")
	}
	if !strings.Contains(err.Error(), "compile schema") {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestNewPostgresSinkFromConfigEagerInit checks that construction pings the
// database, runs the table and index DDL immediately, and embeds the
// configured credentials in the DSN handed to the opener.
func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
	withPostgresTestState(t)
	fake := &fakeDB{}
	var capturedDSN string
	openPostgresDB = func(dsn string) (postgresDB, error) {
		capturedDSN = dsn
		return fake, nil
	}

	cfg := config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://db.example.local:5432/feedkit?sslmode=disable",
			"username": "app_user",
			"password": "app_pass",
		},
	}
	noop := func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }

	built, err := NewPostgresSinkFromConfig(cfg, schemaOneTable(noop))
	switch {
	case err != nil:
		t.Fatalf("new postgres sink: %v", err)
	case built == nil:
		t.Fatalf("expected sink")
	case fake.pingCalls != 1:
		t.Fatalf("expected one ping, got %d", fake.pingCalls)
	}

	calls := fake.execCalls
	if len(calls) != 2 {
		t.Fatalf("expected 2 init exec calls (table + index), got %d", len(calls))
	}
	if q := calls[0].query; !strings.Contains(q, `CREATE TABLE IF NOT EXISTS "events"`) {
		t.Fatalf("unexpected create table query: %s", q)
	}
	if q := calls[1].query; !strings.Contains(q, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
		t.Fatalf("unexpected create index query: %s", q)
	}

	parsed, err := url.Parse(capturedDSN)
	if err != nil {
		t.Fatalf("parse dsn: %v", err)
	}
	user := parsed.User
	if user == nil || user.Username() != "app_user" {
		t.Fatalf("dsn missing username: %q", capturedDSN)
	}
	if pw, set := user.Password(); !set || pw != "app_pass" {
		t.Fatalf("dsn missing password: %q", capturedDSN)
	}
}
// TestNewPostgresSinkFromConfigInitFailureClosesDB makes the first DDL
// statement fail and checks that the constructor reports the error and closes
// the database handle it opened.
func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
	withPostgresTestState(t)
	failing := &fakeDB{execErr: errors.New("ddl failed"), execErrOnCall: 1}
	openPostgresDB = func(_ string) (postgresDB, error) { return failing, nil }

	params := map[string]any{
		"uri":      "postgres://localhost/db",
		"username": "user",
		"password": "pass",
	}
	noop := func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }

	if _, err := NewPostgresSinkFromConfig(config.SinkConfig{Name: "pg", Driver: "postgres", Params: params}, schemaOneTable(noop)); err == nil {
		t.Fatalf("expected init error")
	}
	if failing.closeCalls != 1 {
		t.Fatalf("expected db close on init failure")
	}
}
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
tests := []struct {
name string
in string
want time.Duration
}{
{name: "go duration", in: "72h", want: 72 * time.Hour},
{name: "days suffix", in: "3d", want: 72 * time.Hour},
{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
withPostgresTestState(t)
openPostgresDB = func(_ string) (postgresDB, error) {
return &fakeDB{}, nil
}
s, err := NewPostgresSinkFromConfig(config.SinkConfig{
Name: "pg",
Driver: "postgres",
Params: map[string]any{
"uri": "postgres://localhost/db",
"username": "user",
"password": "pass",
"prune": tc.in,
},
}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
if err != nil {
t.Fatalf("new postgres sink: %v", err)
}
pg, ok := s.(*PostgresSink)
if !ok {
t.Fatalf("expected *PostgresSink, got %T", s)
}
if pg.pruneWindow != tc.want {
t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
}
})
}
}
// TestNewPostgresSinkFromConfigPruneParamRejected checks that empty, zero,
// negative, malformed, fractional-day and non-string params.prune values are
// rejected with an error naming params.prune.
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
	withPostgresTestState(t)
	// Stub the opener so the test never reaches a real database, even if
	// prune validation regressed and let the config through.
	openPostgresDB = func(_ string) (postgresDB, error) {
		return &fakeDB{}, nil
	}
	tests := []struct {
		name string
		in   any
	}{
		{name: "empty", in: ""},
		{name: "zero", in: "0"},
		{name: "negative", in: "-1h"},
		{name: "malformed", in: "abc"},
		{name: "fractional day", in: "1.5d"},
		{name: "wrong type", in: 5},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: map[string]any{
					"uri":      "postgres://localhost/db",
					"username": "user",
					"password": "pass",
					"prune":    tc.in,
				},
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }))
			if err == nil {
				t.Fatalf("expected error")
			}
			if !strings.Contains(err.Error(), "params.prune") {
				t.Fatalf("expected params.prune error, got %v", err)
			}
		})
	}
}
// TestPostgresSinkConsumeInvalidEvent checks that a zero-value event is
// rejected as invalid before the mapper runs and before any database work
// (transaction or statement) is attempted.
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
	db := &fakeDB{}
	called := 0
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			called++
			return nil, nil
		})),
	}
	err := sink.Consume(context.Background(), event.Event{})
	if err == nil {
		t.Fatalf("expected invalid event error")
	}
	if !strings.Contains(err.Error(), "invalid event") {
		t.Fatalf("unexpected error: %v", err)
	}
	if called != 0 {
		t.Fatalf("expected mapper not called for invalid events")
	}
	// Rejection must also precede any database activity, not just mapping.
	if db.beginCalls != 0 || len(db.execCalls) != 0 {
		t.Fatalf("expected no database activity for invalid events, got %d begins and %d execs", db.beginCalls, len(db.execCalls))
	}
}
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
db := &fakeDB{}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
return nil, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 0 {
t.Fatalf("expected no transaction for unmapped events")
}
}
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if db.beginCalls != 1 {
t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
}
if len(tx.execCalls) != 2 {
t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
// TestPostgresSinkConsumeInsertFailureRollsBack makes the second INSERT fail
// and checks that Consume reports an "insert into" error and rolls the
// transaction back instead of committing it.
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
	// The second statement (the event_payloads insert) fails mid-transaction.
	tx := &fakeTx{execErr: errors.New("duplicate key"), execErrOnCall: 2}
	mapBoth := func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
		return []PostgresWrite{
			{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
			{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
		}, nil
	}
	sink := &PostgresSink{name: "pg", db: &fakeDB{tx: tx}, schema: mustCompileSchema(t, schemaTwoTables(mapBoth))}

	err := sink.Consume(context.Background(), validTestEvent())
	switch {
	case err == nil:
		t.Fatalf("expected insert error")
	case !strings.Contains(err.Error(), "insert into"):
		t.Fatalf("unexpected error: %v", err)
	case tx.commitCalls != 0:
		t.Fatalf("expected no commit")
	case tx.rollbackCalls != 1:
		t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
	}
}
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
tx := &fakeTx{}
db := &fakeDB{tx: tx}
sink := &PostgresSink{
name: "pg",
db: db,
schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
return []PostgresWrite{
{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
}, nil
})),
pruneWindow: 24 * time.Hour,
}
if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
t.Fatalf("consume: %v", err)
}
if len(tx.execCalls) != 4 {
t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
}
if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
}
if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
}
if tx.commitCalls != 1 {
t.Fatalf("expected one commit, got %d", tx.commitCalls)
}
if tx.rollbackCalls != 0 {
t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
}
}
// TestPostgresSinkConsumeAutoPruneFailureRollsBack makes the first prune
// statement (third statement overall) fail and checks that Consume reports a
// "prune older than" error and rolls back the whole transaction.
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
	// Statements 1-2 are inserts; statement 3 is the first prune DELETE.
	tx := &fakeTx{execErr: errors.New("prune failed"), execErrOnCall: 3}
	mapBoth := func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
		return []PostgresWrite{
			{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
			{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
		}, nil
	}
	sink := &PostgresSink{
		name:        "pg",
		db:          &fakeDB{tx: tx},
		schema:      mustCompileSchema(t, schemaTwoTables(mapBoth)),
		pruneWindow: 24 * time.Hour,
	}

	err := sink.Consume(context.Background(), validTestEvent())
	switch {
	case err == nil:
		t.Fatalf("expected prune error")
	case !strings.Contains(err.Error(), "prune older than"):
		t.Fatalf("unexpected error: %v", err)
	case tx.commitCalls != 0:
		t.Fatalf("expected no commit")
	case tx.rollbackCalls != 1:
		t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
	}
}
// TestPostgresSinkPrunePerTable exercises the single-table prune helpers:
// PruneKeepLatest (ordered by the prune column, keep count bound as the only
// argument) and PruneOlderThan (filtered by the prune column against $1).
// Both should report the driver's RowsAffected.
func TestPostgresSinkPrunePerTable(t *testing.T) {
	fake := &fakeDB{execRows: 7}
	noop := func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }
	sink := &PostgresSink{name: "pg", db: fake, schema: mustCompileSchema(t, schemaOneTable(noop))}
	ctx := context.Background()

	rows, err := sink.PruneKeepLatest(ctx, "events", 10)
	if err != nil {
		t.Fatalf("prune keep latest: %v", err)
	}
	if rows != 7 {
		t.Fatalf("unexpected rows affected: %d", rows)
	}
	if len(fake.execCalls) != 1 {
		t.Fatalf("expected one prune query")
	}
	keep := fake.execCalls[0]
	if !strings.Contains(keep.query, `ORDER BY "emitted_at" DESC`) {
		t.Fatalf("unexpected keep-latest query: %s", keep.query)
	}
	if len(keep.args) != 1 || keep.args[0] != 10 {
		t.Fatalf("unexpected keep-latest args: %#v", keep.args)
	}

	rows, err = sink.PruneOlderThan(ctx, "events", time.Now().UTC().Add(-24*time.Hour))
	if err != nil {
		t.Fatalf("prune older than: %v", err)
	}
	if rows != 7 {
		t.Fatalf("unexpected rows affected: %d", rows)
	}
	if len(fake.execCalls) != 2 {
		t.Fatalf("expected two prune queries")
	}
	if older := fake.execCalls[1].query; !strings.Contains(older, `WHERE "emitted_at" < $1`) {
		t.Fatalf("unexpected older-than query: %s", older)
	}
}
// TestPostgresSinkPruneAllTables checks that the all-table prune helpers
// issue exactly one statement per schema table and return per-table
// RowsAffected counts keyed by table name.
func TestPostgresSinkPruneAllTables(t *testing.T) {
	db := &fakeDB{execRows: 3}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}
	keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
	if err != nil {
		t.Fatalf("prune all keep latest: %v", err)
	}
	if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
		t.Fatalf("unexpected keep counts: %#v", keepCounts)
	}
	// Verify the keep-latest phase too, before clearing the log for the next phase.
	if len(db.execCalls) != 2 {
		t.Fatalf("expected one keep-latest prune call per table, got %d", len(db.execCalls))
	}
	db.execCalls = nil
	olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
	if err != nil {
		t.Fatalf("prune all older than: %v", err)
	}
	if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
		t.Fatalf("unexpected older-than counts: %#v", olderCounts)
	}
	if len(db.execCalls) != 2 {
		t.Fatalf("expected one prune call per table, got %d", len(db.execCalls))
	}
}
// TestPostgresSinkPruneErrors checks that a negative keep count and an
// unknown table name are rejected, and that neither rejected request reaches
// the database.
func TestPostgresSinkPruneErrors(t *testing.T) {
	db := &fakeDB{}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}
	if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
		t.Fatalf("expected negative keep error")
	}
	if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
		t.Fatalf("expected unknown table error")
	}
	// Validation failures must short-circuit before any SQL is issued.
	if len(db.execCalls) != 0 {
		t.Fatalf("expected no prune queries for rejected requests, got %d", len(db.execCalls))
	}
}