package sinks

import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
	pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)

// fakeResult is a minimal sql.Result that reports a fixed affected-row count.
type fakeResult struct {
	rows int64
}

// LastInsertId always fails; the fake does not model insert IDs.
func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") }

// RowsAffected returns the configured row count.
func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil }

// execCall records one ExecContext invocation: the SQL text and a copy of
// its arguments.
type execCall struct {
	query string
	args  []any
}

// fakeTx is an in-memory transaction double that records every statement and
// counts Commit/Rollback calls.
type fakeTx struct {
	// execCalls lists every ExecContext call in order.
	execCalls []execCall
	// execErr is returned by the execErrOnCall-th ExecContext call (1-based).
	execErr       error
	execErrOnCall int
	// execRows is the RowsAffected value reported by successful calls.
	execRows int64

	commitErr   error
	rollbackErr error

	commitCalls   int
	rollbackCalls int
}

// ExecContext records the call and fails with execErr when this is the
// execErrOnCall-th call; otherwise it reports execRows affected rows.
func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	// Copy args so later mutation of the caller's slice cannot change the record.
	t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)})
	if t.execErr != nil && t.execErrOnCall == len(t.execCalls) {
		return nil, t.execErr
	}
	return fakeResult{rows: t.execRows}, nil
}

// Commit counts the call and returns commitErr.
func (t *fakeTx) Commit() error {
	t.commitCalls++
	return t.commitErr
}

// Rollback counts the call and returns rollbackErr.
func (t *fakeTx) Rollback() error {
	t.rollbackCalls++
	return t.rollbackErr
}

// fakeDB is an in-memory database double that counts pings, transaction
// begins and closes, and records statements executed outside a transaction.
type fakeDB struct {
	pingErr  error
	beginErr error
	// execErr is returned by the execErrOnCall-th ExecContext call (1-based).
	execErr       error
	execErrOnCall int
	// execRows is the RowsAffected value reported by successful calls.
	execRows int64

	pingCalls  int
	beginCalls int
	execCalls  []execCall
	closeCalls int

	// tx is handed out by BeginTx; it is created lazily when nil.
	tx *fakeTx
}

// PingContext counts the call and returns pingErr.
func (d *fakeDB) PingContext(_ context.Context) error {
	d.pingCalls++
	return d.pingErr
}

// BeginTx counts the call and returns beginErr if set; otherwise it returns
// the shared fakeTx, creating it on first use.
func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) {
	d.beginCalls++
	if d.beginErr != nil {
		return nil, d.beginErr
	}
	if d.tx == nil {
		d.tx = &fakeTx{}
	}
	return d.tx, nil
}

// ExecContext records the call and fails with execErr when this is the
// execErrOnCall-th call; otherwise it reports execRows affected rows.
func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) {
	// Copy args so later mutation of the caller's slice cannot change the record.
	d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)})
	if d.execErr != nil && d.execErrOnCall == len(d.execCalls) {
		return nil, d.execErr
	}
	return fakeResult{rows: d.execRows}, nil
}

// Close counts the call and always succeeds.
func (d *fakeDB) Close() error {
	d.closeCalls++
	return nil
}

// withPostgresTestState snapshots the package-level openPostgresDB hook and
// restores it when the test (or subtest) owning t finishes.
func withPostgresTestState(t *testing.T) {
	t.Helper()
	oldOpen := openPostgresDB
	t.Cleanup(func() {
		openPostgresDB = oldOpen
	})
}

// withFailingPostgresOpen saves openPostgresDB via withPostgresTestState and
// replaces it with a stub that always returns an error. Validation tests use
// it so that a constructor which reaches the open step fails in-process with
// a recognizable message instead of invoking whatever openPostgresDB pointed
// at before (NOTE(review): presumably a real connection attempt; confirm).
func withFailingPostgresOpen(t *testing.T) {
	t.Helper()
	withPostgresTestState(t)
	openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
		return nil, errors.New("unexpected openPostgresDB call")
	}
}

// validTestEvent returns an event with ID, Kind, Source, EmittedAt and
// Payload all populated; EmittedAt is the current UTC time.
func validTestEvent() event.Event {
	now := time.Now().UTC()
	return event.Event{
		ID:        "evt-1",
		Kind:      event.Kind("observation"),
		Source:    "source-1",
		EmittedAt: now,
		Payload: map[string]any{
			"x": 1,
		},
	}
}

// schemaOneTable builds a schema with a single "events" table (three
// columns, primary key, prune column and one index) that uses mapFn as its
// event mapper.
func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema {
	return PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
					{Name: "payload_json", Type: "JSONB", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
				Indexes: []PostgresIndex{
					{Name: "idx_events_emitted_at", Columns: []string{"emitted_at"}},
				},
			},
		},
		MapEvent: mapFn,
	}
}

// schemaTwoTables builds a schema with "events" and "event_payloads" tables,
// both prunable on "emitted_at" and without indexes, that uses mapFn as its
// event mapper.
func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema {
	return PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
			},
			{
				Name: "event_payloads",
				Columns: []PostgresColumn{
					{Name: "event_id", Type: "TEXT", Nullable: false},
					{Name: "payload_json", Type: "JSONB", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PrimaryKey:  []string{"event_id"},
				PruneColumn: "emitted_at",
			},
		},
		MapEvent: mapFn,
	}
}

// mustCompileSchema compiles s and fails the test immediately on error.
func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled {
	t.Helper()
	compiled, err := compilePostgresSchema(s)
	if err != nil {
		t.Fatalf("compile schema: %v", err)
	}
	return compiled
}

// TestCompilePostgresSchemaRejectsInvalidSchema expects compilation to fail
// for a prune column that is not a declared column and for an index with no
// columns, and checks a fragment of each error message.
func TestCompilePostgresSchemaRejectsInvalidSchema(t *testing.T) {
	// Case 1: PruneColumn names a column the table does not declare.
	_, err := compilePostgresSchema(PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
				},
				PruneColumn: "missing_col",
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid schema error")
	}
	if !strings.Contains(err.Error(), "prune column") {
		t.Fatalf("unexpected schema validation error: %v", err)
	}

	// Case 2: an index is declared with a nil column list.
	_, err = compilePostgresSchema(PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
					{Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false},
				},
				PruneColumn: "emitted_at",
				Indexes: []PostgresIndex{
					{Name: "idx_events_empty", Columns: nil},
				},
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid index schema error")
	}
	if !strings.Contains(err.Error(), "at least one column") {
		t.Fatalf("unexpected index validation error: %v", err)
	}
}

// TestPostgresFactoryBuildsMultipleSinksWithSameSchema builds two sinks from
// one factory and expects each build to open (and ping) its own database
// with the configured credentials.
func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
	withPostgresTestState(t)

	dbs := []*fakeDB{{}, {}}
	var gotCfgs []pgconn.ConnConfig
	openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
		gotCfgs = append(gotCfgs, cfg)
		// Fail cleanly instead of panicking on an out-of-range index if the
		// factory opens more databases than the test prepared; the recorded
		// config still trips the length assertion below.
		if len(gotCfgs) > len(dbs) {
			return nil, errors.New("unexpected extra openPostgresDB call")
		}
		db := dbs[len(gotCfgs)-1]
		if err := db.PingContext(ctx); err != nil {
			return nil, err
		}
		return db, nil
	}

	factory := PostgresFactory(schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
		return nil, nil
	}))

	for _, name := range []string{"pg_a", "pg_b"} {
		sink, err := factory(config.SinkConfig{
			Name:   name,
			Driver: "postgres",
			Params: map[string]any{
				"uri":      "postgres://localhost/db",
				"username": "user",
				"password": "pass",
			},
		})
		if err != nil {
			t.Fatalf("factory(%q) error = %v", name, err)
		}
		if sink == nil {
			t.Fatalf("factory(%q) returned nil sink", name)
		}
	}

	if len(gotCfgs) != 2 {
		t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
	}
	// Every sink, not just the first, must receive the configured credentials.
	for i, cfg := range gotCfgs {
		if cfg.Username != "user" || cfg.Password != "pass" {
			t.Fatalf("gotCfgs[%d] = %+v, want Username=%q Password=%q", i, cfg, "user", "pass")
		}
	}
	for i, db := range dbs {
		if db.pingCalls != 1 {
			t.Fatalf("db[%d] pingCalls = %d, want 1", i, db.pingCalls)
		}
	}
}

// TestNewPostgresSinkFromConfigMissingParams expects an error naming the
// missing parameter when uri, username or password is absent.
func TestNewPostgresSinkFromConfigMissingParams(t *testing.T) {
	// The open hook is stubbed to fail so this test stays in-process.
	withFailingPostgresOpen(t)

	tests := []struct {
		name   string
		params map[string]any
		want   string
	}{
		{name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"},
		{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"},
		{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: tc.params,
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
				return nil, nil
			}))
			if err == nil {
				t.Fatalf("expected error")
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("expected %q in error, got: %v", tc.want, err)
			}
		})
	}
}

// TestNewPostgresSinkFromConfigRejectsInvalidSchema expects a "compile
// schema" error when the schema's prune column is not a declared column.
func TestNewPostgresSinkFromConfigRejectsInvalidSchema(t *testing.T) {
	// The open hook is stubbed to fail so this test stays in-process.
	withFailingPostgresOpen(t)

	_, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "user",
			"password": "pass",
		},
	}, PostgresSchema{
		Tables: []PostgresTable{
			{
				Name: "events",
				Columns: []PostgresColumn{
					{Name: "id", Type: "TEXT", Nullable: false},
				},
				PruneColumn: "missing_col",
			},
		},
		MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil },
	})
	if err == nil {
		t.Fatalf("expected invalid schema error")
	}
	if !strings.Contains(err.Error(), "compile schema") {
		t.Fatalf("unexpected error: %v", err)
	}
}

// TestNewPostgresSinkFromConfigEagerInit expects construction to open the
// database with the configured URI and credentials and to issue the CREATE
// TABLE and CREATE INDEX statements, in that order, before returning.
func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
	withPostgresTestState(t)

	db := &fakeDB{}
	var gotCfg pgconn.ConnConfig
	openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
		gotCfg = cfg
		if err := db.PingContext(ctx); err != nil {
			return nil, err
		}
		return db, nil
	}

	s, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://db.example.local:5432/feedkit?sslmode=disable",
			"username": "app_user",
			"password": "app_pass",
		},
	}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
		return nil, nil
	}))
	if err != nil {
		t.Fatalf("new postgres sink: %v", err)
	}
	if s == nil {
		t.Fatalf("expected sink")
	}

	if db.pingCalls != 1 {
		t.Fatalf("expected one ping, got %d", db.pingCalls)
	}
	// The length check guards the indexed assertions that follow.
	if len(db.execCalls) != 2 {
		t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls))
	}
	if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) {
		t.Fatalf("unexpected create table query: %s", db.execCalls[0].query)
	}
	if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) {
		t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
	}

	if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
		t.Fatalf("URI = %q", gotCfg.URI)
	}
	if gotCfg.Username != "app_user" {
		t.Fatalf("Username = %q, want app_user", gotCfg.Username)
	}
	if gotCfg.Password != "app_pass" {
		t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
	}
}

// TestNewPostgresSinkFromConfigInitFailureClosesDB expects the database to
// be closed exactly once when the first init statement fails.
func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
	withPostgresTestState(t)

	db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
	openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
		return db, nil
	}

	_, err := NewPostgresSinkFromConfig(config.SinkConfig{
		Name:   "pg",
		Driver: "postgres",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "user",
			"password": "pass",
		},
	}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
		return nil, nil
	}))
	if err == nil {
		t.Fatalf("expected init error")
	}
	if db.closeCalls != 1 {
		t.Fatalf("expected db close on init failure")
	}
}

// TestNewPostgresSinkFromConfigPruneParamAccepted expects Go durations and
// the "d"/"w" suffixes in params.prune to be parsed into the sink's prune
// window.
func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want time.Duration
	}{
		{name: "go duration", in: "72h", want: 72 * time.Hour},
		{name: "days suffix", in: "3d", want: 72 * time.Hour},
		{name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Each subtest installs, and on exit restores, its own open hook.
			withPostgresTestState(t)
			openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
				return &fakeDB{}, nil
			}

			s, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: map[string]any{
					"uri":      "postgres://localhost/db",
					"username": "user",
					"password": "pass",
					"prune":    tc.in,
				},
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
				return nil, nil
			}))
			if err != nil {
				t.Fatalf("new postgres sink: %v", err)
			}

			pg, ok := s.(*PostgresSink)
			if !ok {
				t.Fatalf("expected *PostgresSink, got %T", s)
			}
			if pg.pruneWindow != tc.want {
				t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want)
			}
		})
	}
}

// TestNewPostgresSinkFromConfigPruneParamRejected expects a "params.prune"
// error for empty, zero, negative, malformed, fractional-day and non-string
// prune values.
func TestNewPostgresSinkFromConfigPruneParamRejected(t *testing.T) {
	// The open hook is stubbed to fail so this test stays in-process.
	withFailingPostgresOpen(t)

	tests := []struct {
		name string
		in   any
	}{
		{name: "empty", in: ""},
		{name: "zero", in: "0"},
		{name: "negative", in: "-1h"},
		{name: "malformed", in: "abc"},
		{name: "fractional day", in: "1.5d"},
		{name: "wrong type", in: 5},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresSinkFromConfig(config.SinkConfig{
				Name:   "pg",
				Driver: "postgres",
				Params: map[string]any{
					"uri":      "postgres://localhost/db",
					"username": "user",
					"password": "pass",
					"prune":    tc.in,
				},
			}, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
				return nil, nil
			}))
			if err == nil {
				t.Fatalf("expected error")
			}
			if !strings.Contains(err.Error(), "params.prune") {
				t.Fatalf("expected params.prune error, got %v", err)
			}
		})
	}
}

// TestPostgresSinkConsumeInvalidEvent expects Consume to reject a zero-value
// event with an "invalid event" error without invoking the mapper.
func TestPostgresSinkConsumeInvalidEvent(t *testing.T) {
	db := &fakeDB{}
	called := 0
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			called++
			return nil, nil
		})),
	}

	err := sink.Consume(context.Background(), event.Event{})
	if err == nil {
		t.Fatalf("expected invalid event error")
	}
	if !strings.Contains(err.Error(), "invalid event") {
		t.Fatalf("unexpected error: %v", err)
	}
	if called != 0 {
		t.Fatalf("expected mapper not called for invalid events")
	}
}

// TestPostgresSinkConsumeUnmappedEventIsNoOp expects Consume to succeed
// without beginning a transaction when the mapper returns no writes.
func TestPostgresSinkConsumeUnmappedEventIsNoOp(t *testing.T) {
	db := &fakeDB{}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}

	if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
		t.Fatalf("consume: %v", err)
	}
	if db.beginCalls != 0 {
		t.Fatalf("expected no transaction for unmapped events")
	}
}

// TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically expects two
// mapped writes to run as two statements in one transaction that is
// committed once and never rolled back.
func TestPostgresSinkConsumeOneEventWritesMultipleTablesAtomically(t *testing.T) {
	tx := &fakeTx{}
	db := &fakeDB{tx: tx}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
			return []PostgresWrite{
				{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
				{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
			}, nil
		})),
	}

	if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
		t.Fatalf("consume: %v", err)
	}
	if db.beginCalls != 1 {
		t.Fatalf("expected one transaction begin, got %d", db.beginCalls)
	}
	if len(tx.execCalls) != 2 {
		t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls))
	}
	if tx.commitCalls != 1 {
		t.Fatalf("expected one commit, got %d", tx.commitCalls)
	}
	if tx.rollbackCalls != 0 {
		t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
	}
}

// TestPostgresSinkConsumeInsertFailureRollsBack expects a failure on the
// second insert to surface as an "insert into" error with one rollback and
// no commit.
func TestPostgresSinkConsumeInsertFailureRollsBack(t *testing.T) {
	// The second statement in the transaction (the second insert) fails.
	tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")}
	db := &fakeDB{tx: tx}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
			return []PostgresWrite{
				{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
				{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
			}, nil
		})),
	}

	err := sink.Consume(context.Background(), validTestEvent())
	if err == nil {
		t.Fatalf("expected insert error")
	}
	if !strings.Contains(err.Error(), "insert into") {
		t.Fatalf("unexpected error: %v", err)
	}
	if tx.commitCalls != 0 {
		t.Fatalf("expected no commit")
	}
	if tx.rollbackCalls != 1 {
		t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
	}
}

// TestPostgresSinkConsumeAutoPruneRunsInSameTransaction expects a sink with
// a prune window to issue one DELETE per table after the inserts, inside the
// same committed transaction.
func TestPostgresSinkConsumeAutoPruneRunsInSameTransaction(t *testing.T) {
	tx := &fakeTx{}
	db := &fakeDB{tx: tx}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
			return []PostgresWrite{
				{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
				{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
			}, nil
		})),
		pruneWindow: 24 * time.Hour,
	}

	if err := sink.Consume(context.Background(), validTestEvent()); err != nil {
		t.Fatalf("consume: %v", err)
	}
	// The length check guards the indexed assertions that follow.
	if len(tx.execCalls) != 4 {
		t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls))
	}
	if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) {
		t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query)
	}
	if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) {
		t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query)
	}
	if tx.commitCalls != 1 {
		t.Fatalf("expected one commit, got %d", tx.commitCalls)
	}
	if tx.rollbackCalls != 0 {
		t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls)
	}
}

// TestPostgresSinkConsumeAutoPruneFailureRollsBack expects a failure on the
// first prune statement to surface as a "prune older than" error with one
// rollback and no commit.
func TestPostgresSinkConsumeAutoPruneFailureRollsBack(t *testing.T) {
	// Calls 1 and 2 are the inserts; call 3 is the first prune DELETE.
	tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")}
	db := &fakeDB{tx: tx}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) {
			return []PostgresWrite{
				{Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}},
				{Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}},
			}, nil
		})),
		pruneWindow: 24 * time.Hour,
	}

	err := sink.Consume(context.Background(), validTestEvent())
	if err == nil {
		t.Fatalf("expected prune error")
	}
	if !strings.Contains(err.Error(), "prune older than") {
		t.Fatalf("unexpected error: %v", err)
	}
	if tx.commitCalls != 0 {
		t.Fatalf("expected no commit")
	}
	if tx.rollbackCalls != 1 {
		t.Fatalf("expected rollback, got %d", tx.rollbackCalls)
	}
}

// TestPostgresSinkPrunePerTable checks the single-table prune helpers: each
// issues exactly one statement on the database and returns its affected-row
// count.
func TestPostgresSinkPrunePerTable(t *testing.T) {
	db := &fakeDB{execRows: 7}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}

	rows, err := sink.PruneKeepLatest(context.Background(), "events", 10)
	if err != nil {
		t.Fatalf("prune keep latest: %v", err)
	}
	if rows != 7 {
		t.Fatalf("unexpected rows affected: %d", rows)
	}
	if len(db.execCalls) != 1 {
		t.Fatalf("expected one prune query")
	}
	if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) {
		t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query)
	}
	// The recorded argument is compared as an interface value against the
	// int constant 10, so both dynamic type and value must match.
	if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 {
		t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args)
	}

	cutoff := time.Now().UTC().Add(-24 * time.Hour)
	rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff)
	if err != nil {
		t.Fatalf("prune older than: %v", err)
	}
	if rows != 7 {
		t.Fatalf("unexpected rows affected: %d", rows)
	}
	if len(db.execCalls) != 2 {
		t.Fatalf("expected two prune queries")
	}
	if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) {
		t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query)
	}
}

// TestPostgresSinkPruneAllTables checks the all-table prune helpers: each
// returns a per-table affected-row map covering both schema tables.
func TestPostgresSinkPruneAllTables(t *testing.T) {
	db := &fakeDB{execRows: 3}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}

	keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5)
	if err != nil {
		t.Fatalf("prune all keep latest: %v", err)
	}
	if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 {
		t.Fatalf("unexpected keep counts: %#v", keepCounts)
	}

	// Reset the record so the count below covers only the older-than pass.
	db.execCalls = nil
	olderCounts, err := sink.PruneAllOlderThan(context.Background(), time.Now().UTC())
	if err != nil {
		t.Fatalf("prune all older than: %v", err)
	}
	if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 {
		t.Fatalf("unexpected older-than counts: %#v", olderCounts)
	}
	if len(db.execCalls) != 2 {
		t.Fatalf("expected one prune call per table")
	}
}

// TestPostgresSinkPruneErrors expects the prune helpers to reject a negative
// keep count and an unknown table name.
func TestPostgresSinkPruneErrors(t *testing.T) {
	db := &fakeDB{}
	sink := &PostgresSink{
		name: "pg",
		db:   db,
		schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) {
			return nil, nil
		})),
	}

	if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil {
		t.Fatalf("expected negative keep error")
	}
	if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil {
		t.Fatalf("expected unknown table error")
	}
}