From 3b92c2284dfb9916a6c14e24f4e14310b39bfb88 Mon Sep 17 00:00:00 2001 From: Eric Rakestraw Date: Sat, 28 Mar 2026 08:04:15 -0500 Subject: [PATCH] Added automatic pruning for configured Postgres sinks --- doc.go | 3 +- sinks/doc.go | 13 +++- sinks/postgres.go | 111 +++++++++++++++++++++++++----- sinks/postgres_test.go | 153 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 260 insertions(+), 20 deletions(-) diff --git a/doc.go b/doc.go index c2aefac..ced3e43 100644 --- a/doc.go +++ b/doc.go @@ -78,7 +78,8 @@ // Sink abstractions + sink registry. // Built-ins include stdout, NATS, and Postgres. For Postgres, downstream // code registers table schemas/mappers while feedkit manages DDL, writes, -// and optional prune helpers. +// optional automatic retention pruning (via sink params.prune), and +// manual prune helpers. Postgres table schemas must declare PruneColumn. // // Typical wiring (daemon main.go) // diff --git a/sinks/doc.go b/sinks/doc.go index db97dee..ecd4926 100644 --- a/sinks/doc.go +++ b/sinks/doc.go @@ -29,7 +29,7 @@ // and feedkit ownership: // - downstream code registers table schema + event mapping functions // - feedkit manages DB connection, create-if-missing DDL, transactional -// inserts, and prune helpers +// inserts, optional automatic retention pruning, and manual prune helpers // // Example config: // @@ -40,6 +40,13 @@ // uri: postgres://localhost:5432/feedkit?sslmode=disable // username: feedkit_user // password: feedkit_pass +// prune: 3d # optional: prune rows older than now-3d on each write tx +// +// params.prune supports: +// - Go duration strings (72h, 90m, 30s, ...) +// - day/week suffixes (3d, 2w) +// +// If params.prune is omitted, automatic pruning is disabled. // // Example downstream wiring: // @@ -53,7 +60,7 @@ // {Name: "payload_json", Type: "JSONB", Nullable: false}, // }, // PrimaryKey: []string{"event_id"}, -// PruneColumn: "emitted_at", +// PruneColumn: "emitted_at", // required for retention pruning // }, // }, // MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) { @@ -74,7 +81,7 @@ // }, // }) // -// Pruning via type assertion: +// Manual pruning via type assertion (administrative helpers): // // if p, ok := sink.(sinks.PostgresPruner); ok { // _, _ = p.PruneKeepLatest(ctx, "events", 10000) diff --git a/sinks/postgres.go b/sinks/postgres.go index 72a26a9..ed2afb1 100644 --- a/sinks/postgres.go +++ b/sinks/postgres.go @@ -22,6 +22,10 @@ type postgresTx interface { Rollback() error } +type postgresExecer interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + type postgresDB interface { PingContext(ctx context.Context) error BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) @@ -78,9 +82,10 @@ var openPostgresDB = func(dsn string) (postgresDB, error) { } type PostgresSink struct { - name string - db postgresDB - schema postgresSchemaCompiled + name string + db postgresDB + schema postgresSchemaCompiled + pruneWindow time.Duration } func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { @@ -96,6 +101,10 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { if err != nil { return nil, err } + pruneWindow, err := parsePostgresPruneWindow(cfg) + if err != nil { + return nil, err + } schema, ok := lookupPostgresSchema(cfg.Name) if !ok { @@ -112,7 +121,7 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err) } - s := &PostgresSink{name: cfg.Name, db: db, schema: schema} + s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow} if err := s.initialize(); err != nil { _ = db.Close() return nil, err @@ -168,6 +177,15 @@ func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error { return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err) } } + if p.pruneWindow > 0 { + cutoff := time.Now().UTC().Add(-p.pruneWindow) + for _, tableName := range p.schema.tableOrder { + tbl := p.schema.tables[tableName] + if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil { + return fmt.Errorf("postgres sink: %w", err) + } + } + } if err := tx.Commit(); err != nil { return fmt.Errorf("postgres sink: commit tx: %w", err) @@ -214,19 +232,9 @@ func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff return 0, err } - query := fmt.Sprintf( - `DELETE FROM %s WHERE %s < $1`, - quotePostgresIdent(tbl.name), - quotePostgresIdent(tbl.pruneColumn), - ) - - res, err := p.db.ExecContext(ctx, query, cutoff) + rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff) if err != nil { - return 0, fmt.Errorf("postgres sink: prune older than table %q: %w", tbl.name, err) - } - rows, err := res.RowsAffected() - if err != nil { - return 0, fmt.Errorf("postgres sink: prune older than table %q rows affected: %w", tbl.name, err) + return 0, fmt.Errorf("postgres sink: %w", err) } return rows, nil } @@ -309,6 +317,77 @@ func buildPostgresDSN(uri, username, password string) (string, error) { return u.String(), nil } +func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) { + raw, ok := cfg.Params["prune"] + if !ok || raw == nil { + return 0, nil + } + + s, ok := raw.(string) + if !ok { + return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name) + } + + d, err := parsePostgresPruneDuration(s) + if err != nil { + return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err) + } + return d, nil +} + +func parsePostgresPruneDuration(raw string) (time.Duration, error) { + s := strings.TrimSpace(raw) + if s == "" { + return 0, fmt.Errorf("must not be empty") + } + + lower := strings.ToLower(s) + if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") { + unit := lower[len(lower)-1] + n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1])) + if err != nil { + return 0, fmt.Errorf("must use a positive integer before %q", string(unit)) + } + if n <= 0 { + return 0, fmt.Errorf("must be > 0") + } + if unit == 'd' { + return time.Duration(n) * 24 * time.Hour, nil + } + return time.Duration(n) * 7 * 24 * time.Hour, nil + } + + d, err := time.ParseDuration(s) + if err != nil { + return 0, fmt.Errorf("must be a Go duration or use d/w suffixes") + } + if d <= 0 { + return 0, fmt.Errorf("must be > 0") + } + return d, nil +} + +func buildPruneOlderThanSQL(tbl postgresTableCompiled) string { + return fmt.Sprintf( + `DELETE FROM %s WHERE %s < $1`, + quotePostgresIdent(tbl.name), + quotePostgresIdent(tbl.pruneColumn), + ) +} + +func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) { + query := buildPruneOlderThanSQL(tbl) + res, err := execer.ExecContext(ctx, query, cutoff) + if err != nil { + return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err) + } + rows, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err) + } + return rows, nil +} + func buildCreateTableSQL(tbl postgresTableCompiled) string { defs := make([]string, 0, len(tbl.columnOrder)+1) for _, colName := range tbl.columnOrder { diff --git a/sinks/postgres_test.go b/sinks/postgres_test.go index 0778155..4353dbe 100644 --- a/sinks/postgres_test.go +++ b/sinks/postgres_test.go @@ -395,6 +395,94 @@ func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) { } } +func TestNewPostgresSinkFromConfig_PruneParamAccepted(t *testing.T) { + tests := []struct { + name string + in string + want time.Duration + }{ + {name: "go duration", in: "72h", want: 72 * time.Hour}, + {name: "days suffix", in: "3d", want: 72 * time.Hour}, + {name: "weeks suffix", in: "2w", want: 14 * 24 * time.Hour}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + withPostgresTestState(t) + + err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + if err != nil { + t.Fatalf("register schema: %v", err) + } + + openPostgresDB = func(_ string) (postgresDB, error) { + return &fakeDB{}, nil + } + + s, err := NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "user", + "password": "pass", + "prune": tc.in, + }, + }) + if err != nil { + t.Fatalf("new postgres sink: %v", err) + } + + pg, ok := s.(*PostgresSink) + if !ok { + t.Fatalf("expected *PostgresSink, got %T", s) + } + if pg.pruneWindow != tc.want { + t.Fatalf("prune window = %s, want %s", pg.pruneWindow, tc.want) + } + }) + } +} + +func TestNewPostgresSinkFromConfig_PruneParamRejected(t *testing.T) { + withPostgresTestState(t) + + tests := []struct { + name string + in any + }{ + {name: "empty", in: ""}, + {name: "zero", in: "0"}, + {name: "negative", in: "-1h"}, + {name: "malformed", in: "abc"}, + {name: "fractional day", in: "1.5d"}, + {name: "wrong type", in: 5}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "user", + "password": "pass", + "prune": tc.in, + }, + }) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "params.prune") { + t.Fatalf("expected params.prune error, got %v", err) + } + }) + } +} + func TestPostgresSinkConsume_InvalidEvent(t *testing.T) { db := &fakeDB{} called := 0 @@ -497,6 +585,71 @@ func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) { } } +func TestPostgresSinkConsume_AutoPruneRunsInSameTransaction(t *testing.T) { + tx := &fakeTx{} + db := &fakeDB{tx: tx} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + return []PostgresWrite{ + {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, + {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, + }, nil + })), + pruneWindow: 24 * time.Hour, + } + + if err := sink.Consume(context.Background(), validTestEvent()); err != nil { + t.Fatalf("consume: %v", err) + } + if len(tx.execCalls) != 4 { + t.Fatalf("expected 4 tx statements (2 inserts + 2 prunes), got %d", len(tx.execCalls)) + } + if !strings.Contains(tx.execCalls[2].query, `DELETE FROM "events"`) { + t.Fatalf("expected prune delete for events, got %s", tx.execCalls[2].query) + } + if !strings.Contains(tx.execCalls[3].query, `DELETE FROM "event_payloads"`) { + t.Fatalf("expected prune delete for event_payloads, got %s", tx.execCalls[3].query) + } + if tx.commitCalls != 1 { + t.Fatalf("expected one commit, got %d", tx.commitCalls) + } + if tx.rollbackCalls != 0 { + t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls) + } +} + +func TestPostgresSinkConsume_AutoPruneFailureRollsBack(t *testing.T) { + tx := &fakeTx{execErrOnCall: 3, execErr: errors.New("prune failed")} + db := &fakeDB{tx: tx} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + return []PostgresWrite{ + {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, + {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, + }, nil + })), + pruneWindow: 24 * time.Hour, + } + + err := sink.Consume(context.Background(), validTestEvent()) + if err == nil { + t.Fatalf("expected prune error") + } + if !strings.Contains(err.Error(), "prune older than") { + t.Fatalf("unexpected error: %v", err) + } + if tx.commitCalls != 0 { + t.Fatalf("expected no commit") + } + if tx.rollbackCalls != 1 { + t.Fatalf("expected rollback, got %d", tx.rollbackCalls) + } +} + func TestPostgresSinkPrune_PerTable(t *testing.T) { db := &fakeDB{execRows: 7} sink := &PostgresSink{