diff --git a/README.md b/README.md index 6c51b9f..420809b 100644 --- a/README.md +++ b/README.md @@ -97,8 +97,13 @@ Compiles routes and fans out events to sinks with per-sink queue/worker isolatio ### `sinks` -Defines sink interface and sink registry. Built-ins include `stdout` and `nats`, with -additional sink implementations at varying maturity. +Defines sink interface and sink registry. Built-ins include: +- `stdout` +- `nats` +- `postgres` (downstream registers table schema + event mapper; feedkit handles create-if-missing DDL, transactional inserts, and optional prune APIs) + +Detailed Postgres configuration and wiring examples live in package docs: +`sinks/doc.go`. ## Typical wiring diff --git a/doc.go b/doc.go index 5cb7151..316fb9f 100644 --- a/doc.go +++ b/doc.go @@ -73,6 +73,9 @@ // // - sinks // Sink abstractions + sink registry. +// Built-ins include stdout, NATS, and Postgres. For Postgres, downstream +// code registers table schemas/mappers while feedkit manages DDL, writes, +// and optional prune helpers. 
// // Typical wiring (daemon main.go) // diff --git a/go.mod b/go.mod index 766877a..19f0751 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module gitea.maximumdirect.net/ejr/feedkit go 1.22 require ( + github.com/lib/pq v1.10.9 github.com/nats-io/nats.go v1.34.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 91387b2..191aaaa 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/nats-io/nats.go v1.34.0 h1:fnxnPCNiwIG5w08rlMcEKTUw4AV/nKyGCOJE8TdhSPk= github.com/nats-io/nats.go v1.34.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= diff --git a/sinks/doc.go b/sinks/doc.go new file mode 100644 index 0000000..db97dee --- /dev/null +++ b/sinks/doc.go @@ -0,0 +1,83 @@ +// Package sinks provides sink abstractions, a sink driver registry, and several +// built-in sink drivers. +// +// Built-in drivers: +// - stdout +// - nats +// - postgres +// +// # NATS built-in overview +// +// The NATS sink publishes each event as JSON to a configured subject. 
+// +// Required params: +// - url: NATS server URL (for example, nats://localhost:4222) +// - exchange: NATS subject to publish to +// +// Example config: +// +// sinks: +// - name: nats_main +// driver: nats +// params: +// url: nats://localhost:4222 +// exchange: feedkit.events +// +// # Postgres built-in overview +// +// The postgres sink is intentionally split between downstream daemon ownership +// and feedkit ownership: +// - downstream code registers table schema + event mapping functions +// - feedkit manages DB connection, create-if-missing DDL, transactional +// inserts, and prune helpers +// +// Example config: +// +// sinks: +// - name: pg_main +// driver: postgres +// params: +// uri: postgres://localhost:5432/feedkit?sslmode=disable +// username: feedkit_user +// password: feedkit_pass +// +// Example downstream wiring: +// +// sinks.MustRegisterPostgresSchema("pg_main", sinks.PostgresSchema{ +// Tables: []sinks.PostgresTable{ +// { +// Name: "events", +// Columns: []sinks.PostgresColumn{ +// {Name: "event_id", Type: "TEXT", Nullable: false}, +// {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false}, +// {Name: "payload_json", Type: "JSONB", Nullable: false}, +// }, +// PrimaryKey: []string{"event_id"}, +// PruneColumn: "emitted_at", +// }, +// }, +// MapEvent: func(ctx context.Context, e event.Event) ([]sinks.PostgresWrite, error) { +// b, err := json.Marshal(e.Payload) +// if err != nil { +// return nil, err +// } +// return []sinks.PostgresWrite{ +// { +// Table: "events", +// Values: map[string]any{ +// "event_id": e.ID, +// "emitted_at": e.EmittedAt, +// "payload_json": string(b), +// }, +// }, +// }, nil +// }, +// }) +// +// Pruning via type assertion: +// +// if p, ok := sink.(sinks.PostgresPruner); ok { +// _, _ = p.PruneKeepLatest(ctx, "events", 10000) +// _, _ = p.PruneOlderThan(ctx, "events", time.Now().UTC().AddDate(0, -1, 0)) +// } +package sinks diff --git a/sinks/postgres.go b/sinks/postgres.go index 9b9151b..72a26a9 100644 --- 
a/sinks/postgres.go +++ b/sinks/postgres.go @@ -2,36 +2,381 @@ package sinks import ( "context" + "database/sql" "fmt" + "net/url" + "strconv" + "strings" + "time" "gitea.maximumdirect.net/ejr/feedkit/config" "gitea.maximumdirect.net/ejr/feedkit/event" + _ "github.com/lib/pq" ) -type PostgresSink struct { - name string - dsn string +const postgresInitTimeout = 5 * time.Second + +type postgresTx interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + Commit() error + Rollback() error } -func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { - dsn, err := requireStringParam(cfg, "dsn") +type postgresDB interface { + PingContext(ctx context.Context) error + BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + Close() error +} + +type sqlDBWrapper struct { + db *sql.DB +} + +func (w *sqlDBWrapper) PingContext(ctx context.Context) error { + return w.db.PingContext(ctx) +} + +func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) { + tx, err := w.db.BeginTx(ctx, opts) if err != nil { return nil, err } - return &PostgresSink{name: cfg.Name, dsn: dsn}, nil + return &sqlTxWrapper{tx: tx}, nil +} + +func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + return w.db.ExecContext(ctx, query, args...) +} + +func (w *sqlDBWrapper) Close() error { + return w.db.Close() +} + +type sqlTxWrapper struct { + tx *sql.Tx +} + +func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + return w.tx.ExecContext(ctx, query, args...) 
+} + +func (w *sqlTxWrapper) Commit() error { + return w.tx.Commit() +} + +func (w *sqlTxWrapper) Rollback() error { + return w.tx.Rollback() +} + +var openPostgresDB = func(dsn string) (postgresDB, error) { + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, err + } + return &sqlDBWrapper{db: db}, nil +} + +type PostgresSink struct { + name string + db postgresDB + schema postgresSchemaCompiled +} + +func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) { + uri, err := requireStringParam(cfg, "uri") + if err != nil { + return nil, err + } + username, err := requireStringParam(cfg, "username") + if err != nil { + return nil, err + } + password, err := requireStringParam(cfg, "password") + if err != nil { + return nil, err + } + + schema, ok := lookupPostgresSchema(cfg.Name) + if !ok { + return nil, fmt.Errorf("postgres sink %q: no schema registered (call sinks.RegisterPostgresSchema before building sinks)", cfg.Name) + } + + dsn, err := buildPostgresDSN(uri, username, password) + if err != nil { + return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err) + } + + db, err := openPostgresDB(dsn) + if err != nil { + return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err) + } + + s := &PostgresSink{name: cfg.Name, db: db, schema: schema} + if err := s.initialize(); err != nil { + _ = db.Close() + return nil, err + } + + return s, nil } func (p *PostgresSink) Name() string { return p.name } func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error { - _ = ctx - // Boundary validation: if something upstream violated invariants, - // surface it loudly rather than printing partial nonsense. + // surface it loudly rather than writing corrupt rows. 
if err := e.Validate(); err != nil { return fmt.Errorf("postgres sink: invalid event: %w", err) } - // TODO implement Postgres transaction + if err := ctx.Err(); err != nil { + return err + } + + writes, err := p.schema.mapEvent(ctx, e) + if err != nil { + return fmt.Errorf("postgres sink: map event: %w", err) + } + if len(writes) == 0 { + return nil + } + + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("postgres sink: begin tx: %w", err) + } + + committed := false + defer func() { + if !committed { + _ = tx.Rollback() + } + }() + + for _, w := range writes { + tbl, err := p.schema.validateWrite(w) + if err != nil { + return fmt.Errorf("postgres sink: %w", err) + } + + query, args, err := buildInsertSQL(tbl, w) + if err != nil { + return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err) + } + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("postgres sink: commit tx: %w", err) + } + committed = true + return nil } + +func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) { + if keep < 0 { + return 0, fmt.Errorf("postgres sink: keep must be >= 0") + } + tbl, err := p.lookupTable(table) + if err != nil { + return 0, err + } + + query := fmt.Sprintf( + `DELETE FROM %s WHERE ctid IN ( + SELECT ctid FROM %s + ORDER BY %s DESC + OFFSET $1 +)`, + quotePostgresIdent(tbl.name), + quotePostgresIdent(tbl.name), + quotePostgresIdent(tbl.pruneColumn), + ) + + res, err := p.db.ExecContext(ctx, query, keep) + if err != nil { + return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err) + } + rows, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err) + } + return rows, nil +} + +func (p *PostgresSink) PruneOlderThan(ctx 
context.Context, table string, cutoff time.Time) (int64, error) { + tbl, err := p.lookupTable(table) + if err != nil { + return 0, err + } + + query := fmt.Sprintf( + `DELETE FROM %s WHERE %s < $1`, + quotePostgresIdent(tbl.name), + quotePostgresIdent(tbl.pruneColumn), + ) + + res, err := p.db.ExecContext(ctx, query, cutoff) + if err != nil { + return 0, fmt.Errorf("postgres sink: prune older than table %q: %w", tbl.name, err) + } + rows, err := res.RowsAffected() + if err != nil { + return 0, fmt.Errorf("postgres sink: prune older than table %q rows affected: %w", tbl.name, err) + } + return rows, nil +} + +func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) { + counts := make(map[string]int64, len(p.schema.tableOrder)) + for _, table := range p.schema.tableOrder { + n, err := p.PruneKeepLatest(ctx, table, keep) + if err != nil { + return counts, err + } + counts[table] = n + } + return counts, nil +} + +func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) { + counts := make(map[string]int64, len(p.schema.tableOrder)) + for _, table := range p.schema.tableOrder { + n, err := p.PruneOlderThan(ctx, table, cutoff) + if err != nil { + return counts, err + } + counts[table] = n + } + return counts, nil +} + +func (p *PostgresSink) initialize() error { + ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout) + defer cancel() + + if err := p.db.PingContext(ctx); err != nil { + return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err) + } + + for _, tableName := range p.schema.tableOrder { + tbl := p.schema.tables[tableName] + + createTableSQL := buildCreateTableSQL(tbl) + if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil { + return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err) + } + + for _, idx := range tbl.indexes { + createIndexSQL := buildCreateIndexSQL(tbl.name, idx) + if _, err := 
p.db.ExecContext(ctx, createIndexSQL); err != nil { + return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err) + } + } + } + + return nil +} + +func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) { + table = strings.TrimSpace(table) + if table == "" { + return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty") + } + tbl, ok := p.schema.tables[table] + if !ok { + return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table) + } + return tbl, nil +} + +func buildPostgresDSN(uri, username, password string) (string, error) { + u, err := url.Parse(strings.TrimSpace(uri)) + if err != nil { + return "", fmt.Errorf("invalid uri: %w", err) + } + if u.Scheme == "" { + return "", fmt.Errorf("invalid uri: missing scheme") + } + if u.Host == "" { + return "", fmt.Errorf("invalid uri: missing host") + } + u.User = url.UserPassword(username, password) + return u.String(), nil +} + +func buildCreateTableSQL(tbl postgresTableCompiled) string { + defs := make([]string, 0, len(tbl.columnOrder)+1) + for _, colName := range tbl.columnOrder { + col := tbl.columns[colName] + def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type) + if !col.Nullable { + def += " NOT NULL" + } + defs = append(defs, def) + } + if len(tbl.primaryKey) > 0 { + defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey))) + } + + return fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s (%s)", + quotePostgresIdent(tbl.name), + strings.Join(defs, ", "), + ) +} + +func buildCreateIndexSQL(tableName string, idx PostgresIndex) string { + unique := "" + if idx.Unique { + unique = "UNIQUE " + } + + return fmt.Sprintf( + "CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)", + unique, + quotePostgresIdent(idx.Name), + quotePostgresIdent(tableName), + joinQuotedIdents(idx.Columns), + ) +} + +func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) { + 
cols := make([]string, 0, len(tbl.columnOrder)) + args := make([]any, 0, len(tbl.columnOrder)) + placeholders := make([]string, 0, len(tbl.columnOrder)) + + for i, colName := range tbl.columnOrder { + v, ok := w.Values[colName] + if !ok { + return "", nil, fmt.Errorf("missing value for column %q", colName) + } + cols = append(cols, quotePostgresIdent(colName)) + args = append(args, v) + placeholders = append(placeholders, "$"+strconv.Itoa(i+1)) + } + + q := fmt.Sprintf( + "INSERT INTO %s (%s) VALUES (%s)", + quotePostgresIdent(tbl.name), + strings.Join(cols, ", "), + strings.Join(placeholders, ", "), + ) + return q, args, nil +} + +func joinQuotedIdents(idents []string) string { + quoted := make([]string, 0, len(idents)) + for _, s := range idents { + quoted = append(quoted, quotePostgresIdent(s)) + } + return strings.Join(quoted, ", ") +} + +func quotePostgresIdent(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` +} diff --git a/sinks/postgres_schema.go b/sinks/postgres_schema.go new file mode 100644 index 0000000..88bda3a --- /dev/null +++ b/sinks/postgres_schema.go @@ -0,0 +1,285 @@ +package sinks + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "gitea.maximumdirect.net/ejr/feedkit/event" +) + +// PostgresMapFunc maps one event into zero or more table writes. +// +// Returning zero writes means "this event is not mapped for this sink" and is +// treated as a non-error no-op. +type PostgresMapFunc func(ctx context.Context, e event.Event) ([]PostgresWrite, error) + +// PostgresSchema describes the downstream-provided relational model and mapper +// for one configured postgres sink. 
+type PostgresSchema struct { + Tables []PostgresTable + MapEvent PostgresMapFunc +} + +type PostgresWrite struct { + Table string + Values map[string]any +} + +type PostgresTable struct { + Name string + Columns []PostgresColumn + PrimaryKey []string + PruneColumn string + Indexes []PostgresIndex +} + +type PostgresColumn struct { + Name string + Type string + Nullable bool +} + +type PostgresIndex struct { + Name string + Columns []string + Unique bool +} + +// PostgresPruner is an optional interface exposed by PostgresSink so downstream +// applications can call retention helpers via type assertion. +type PostgresPruner interface { + PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) + PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) + PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) + PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) +} + +type postgresSchemaCompiled struct { + tableOrder []string + tables map[string]postgresTableCompiled + mapEvent PostgresMapFunc +} + +type postgresTableCompiled struct { + name string + columns map[string]PostgresColumn + columnOrder []string + primaryKey []string + pruneColumn string + indexes []PostgresIndex +} + +var ( + postgresSchemaRegistryMu sync.RWMutex + postgresSchemaRegistry = map[string]postgresSchemaCompiled{} +) + +// RegisterPostgresSchema registers one downstream schema by sink name. +// +// This should be called by downstream daemon wiring code before sink +// construction. Duplicate sink-name registrations are rejected. 
+func RegisterPostgresSchema(sinkName string, schema PostgresSchema) error { + sinkName = strings.TrimSpace(sinkName) + if sinkName == "" { + return fmt.Errorf("postgres schema: sink name cannot be empty") + } + + compiled, err := compilePostgresSchema(schema) + if err != nil { + return err + } + + postgresSchemaRegistryMu.Lock() + defer postgresSchemaRegistryMu.Unlock() + + if _, exists := postgresSchemaRegistry[sinkName]; exists { + return fmt.Errorf("postgres schema: sink %q already registered", sinkName) + } + + postgresSchemaRegistry[sinkName] = compiled + return nil +} + +func MustRegisterPostgresSchema(sinkName string, schema PostgresSchema) { + if err := RegisterPostgresSchema(sinkName, schema); err != nil { + panic(err) + } +} + +func lookupPostgresSchema(sinkName string) (postgresSchemaCompiled, bool) { + postgresSchemaRegistryMu.RLock() + defer postgresSchemaRegistryMu.RUnlock() + + s, ok := postgresSchemaRegistry[sinkName] + return s, ok +} + +func compilePostgresSchema(schema PostgresSchema) (postgresSchemaCompiled, error) { + if schema.MapEvent == nil { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: map function is required") + } + if len(schema.Tables) == 0 { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: at least one table is required") + } + + compiled := postgresSchemaCompiled{ + tableOrder: make([]string, 0, len(schema.Tables)), + tables: make(map[string]postgresTableCompiled, len(schema.Tables)), + mapEvent: schema.MapEvent, + } + + seenTables := map[string]bool{} + seenIndexes := map[string]bool{} + + for i, tbl := range schema.Tables { + tableName := strings.TrimSpace(tbl.Name) + if tableName == "" { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: tables[%d].name is required", i) + } + if seenTables[tableName] { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: duplicate table name %q", tableName) + } + seenTables[tableName] = true + + if len(tbl.Columns) == 0 { + return 
postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q must define at least one column", tableName) + } + + colOrder := make([]string, 0, len(tbl.Columns)) + colMap := make(map[string]PostgresColumn, len(tbl.Columns)) + for j, col := range tbl.Columns { + colName := strings.TrimSpace(col.Name) + if colName == "" { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q columns[%d].name is required", tableName, j) + } + if _, exists := colMap[colName]; exists { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q duplicate column %q", tableName, colName) + } + if strings.TrimSpace(col.Type) == "" { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q column %q type is required", tableName, colName) + } + colOrder = append(colOrder, colName) + colMap[colName] = PostgresColumn{ + Name: colName, + Type: strings.TrimSpace(col.Type), + Nullable: col.Nullable, + } + } + + pk, err := validatePostgresColumnSet(tableName, "primary key", tbl.PrimaryKey, colMap) + if err != nil { + return postgresSchemaCompiled{}, err + } + + pruneCol := strings.TrimSpace(tbl.PruneColumn) + if pruneCol == "" { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column is required", tableName) + } + if _, ok := colMap[pruneCol]; !ok { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q prune column %q not found in columns", tableName, pruneCol) + } + + indexes := make([]PostgresIndex, 0, len(tbl.Indexes)) + for j, idx := range tbl.Indexes { + idxName := strings.TrimSpace(idx.Name) + if idxName == "" { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q indexes[%d].name is required", tableName, j) + } + if len(idx.Columns) == 0 { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: table %q index %q must include at least one column", tableName, idxName) + } + if seenIndexes[idxName] { + return postgresSchemaCompiled{}, fmt.Errorf("postgres schema: 
duplicate index name %q", idxName) + } + seenIndexes[idxName] = true + + idxCols, err := validatePostgresColumnSet(tableName, fmt.Sprintf("index %q columns", idxName), idx.Columns, colMap) + if err != nil { + return postgresSchemaCompiled{}, err + } + + indexes = append(indexes, PostgresIndex{ + Name: idxName, + Columns: idxCols, + Unique: idx.Unique, + }) + } + + compiled.tableOrder = append(compiled.tableOrder, tableName) + compiled.tables[tableName] = postgresTableCompiled{ + name: tableName, + columns: colMap, + columnOrder: colOrder, + primaryKey: pk, + pruneColumn: pruneCol, + indexes: indexes, + } + } + + return compiled, nil +} + +func validatePostgresColumnSet(tableName, label string, cols []string, colMap map[string]PostgresColumn) ([]string, error) { + if len(cols) == 0 { + return nil, nil + } + out := make([]string, 0, len(cols)) + seen := map[string]bool{} + for _, c := range cols { + name := strings.TrimSpace(c) + if name == "" { + return nil, fmt.Errorf("postgres schema: table %q %s contains empty column name", tableName, label) + } + if seen[name] { + return nil, fmt.Errorf("postgres schema: table %q %s contains duplicate column %q", tableName, label, name) + } + if _, ok := colMap[name]; !ok { + return nil, fmt.Errorf("postgres schema: table %q %s references unknown column %q", tableName, label, name) + } + seen[name] = true + out = append(out, name) + } + return out, nil +} + +func (s postgresSchemaCompiled) validateWrite(w PostgresWrite) (postgresTableCompiled, error) { + tableName := strings.TrimSpace(w.Table) + if tableName == "" { + return postgresTableCompiled{}, fmt.Errorf("write table is required") + } + t, ok := s.tables[tableName] + if !ok { + return postgresTableCompiled{}, fmt.Errorf("table %q is not defined in postgres schema", tableName) + } + + if len(w.Values) == 0 { + return postgresTableCompiled{}, fmt.Errorf("write for table %q must include values", tableName) + } + + for k := range w.Values { + if _, ok := t.columns[k]; !ok { + 
return postgresTableCompiled{}, fmt.Errorf("write for table %q includes unknown column %q", tableName, k) + } + } + + if len(w.Values) != len(t.columnOrder) { + return postgresTableCompiled{}, fmt.Errorf("write for table %q must include all declared columns", tableName) + } + + for _, col := range t.columnOrder { + v, ok := w.Values[col] + if !ok { + return postgresTableCompiled{}, fmt.Errorf("write for table %q is missing column %q", tableName, col) + } + if v == nil { + if c := t.columns[col]; !c.Nullable { + return postgresTableCompiled{}, fmt.Errorf("write for table %q has nil value for non-null column %q", tableName, col) + } + } + } + + return t, nil +} diff --git a/sinks/postgres_test.go b/sinks/postgres_test.go new file mode 100644 index 0000000..0778155 --- /dev/null +++ b/sinks/postgres_test.go @@ -0,0 +1,590 @@ +package sinks + +import ( + "context" + "database/sql" + "errors" + "net/url" + "strings" + "testing" + "time" + + "gitea.maximumdirect.net/ejr/feedkit/config" + "gitea.maximumdirect.net/ejr/feedkit/event" +) + +type fakeResult struct { + rows int64 +} + +func (r fakeResult) LastInsertId() (int64, error) { return 0, errors.New("unsupported") } +func (r fakeResult) RowsAffected() (int64, error) { return r.rows, nil } + +type execCall struct { + query string + args []any +} + +type fakeTx struct { + execCalls []execCall + execErr error + execErrOnCall int + execRows int64 + commitErr error + rollbackErr error + commitCalls int + rollbackCalls int +} + +func (t *fakeTx) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) { + t.execCalls = append(t.execCalls, execCall{query: query, args: append([]any(nil), args...)}) + if t.execErr != nil && t.execErrOnCall == len(t.execCalls) { + return nil, t.execErr + } + return fakeResult{rows: t.execRows}, nil +} + +func (t *fakeTx) Commit() error { + t.commitCalls++ + return t.commitErr +} + +func (t *fakeTx) Rollback() error { + t.rollbackCalls++ + return t.rollbackErr +} + +type 
fakeDB struct { + pingErr error + beginErr error + execErr error + execErrOnCall int + execRows int64 + pingCalls int + beginCalls int + execCalls []execCall + closeCalls int + tx *fakeTx +} + +func (d *fakeDB) PingContext(_ context.Context) error { + d.pingCalls++ + return d.pingErr +} + +func (d *fakeDB) BeginTx(_ context.Context, _ *sql.TxOptions) (postgresTx, error) { + d.beginCalls++ + if d.beginErr != nil { + return nil, d.beginErr + } + if d.tx == nil { + d.tx = &fakeTx{} + } + return d.tx, nil +} + +func (d *fakeDB) ExecContext(_ context.Context, query string, args ...any) (sql.Result, error) { + d.execCalls = append(d.execCalls, execCall{query: query, args: append([]any(nil), args...)}) + if d.execErr != nil && d.execErrOnCall == len(d.execCalls) { + return nil, d.execErr + } + return fakeResult{rows: d.execRows}, nil +} + +func (d *fakeDB) Close() error { + d.closeCalls++ + return nil +} + +func resetPostgresSchemaRegistryForTest() { + postgresSchemaRegistryMu.Lock() + defer postgresSchemaRegistryMu.Unlock() + postgresSchemaRegistry = map[string]postgresSchemaCompiled{} +} + +func withPostgresTestState(t *testing.T) { + t.Helper() + + resetPostgresSchemaRegistryForTest() + oldOpen := openPostgresDB + t.Cleanup(func() { + openPostgresDB = oldOpen + resetPostgresSchemaRegistryForTest() + }) +} + +func validTestEvent() event.Event { + now := time.Now().UTC() + return event.Event{ + ID: "evt-1", + Kind: event.Kind("observation"), + Source: "source-1", + EmittedAt: now, + Payload: map[string]any{ + "x": 1, + }, + } +} + +func schemaOneTable(mapFn PostgresMapFunc) PostgresSchema { + return PostgresSchema{ + Tables: []PostgresTable{ + { + Name: "events", + Columns: []PostgresColumn{ + {Name: "event_id", Type: "TEXT", Nullable: false}, + {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false}, + {Name: "payload_json", Type: "JSONB", Nullable: false}, + }, + PrimaryKey: []string{"event_id"}, + PruneColumn: "emitted_at", + Indexes: []PostgresIndex{ + {Name: 
"idx_events_emitted_at", Columns: []string{"emitted_at"}}, + }, + }, + }, + MapEvent: mapFn, + } +} + +func schemaTwoTables(mapFn PostgresMapFunc) PostgresSchema { + return PostgresSchema{ + Tables: []PostgresTable{ + { + Name: "events", + Columns: []PostgresColumn{ + {Name: "event_id", Type: "TEXT", Nullable: false}, + {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false}, + }, + PrimaryKey: []string{"event_id"}, + PruneColumn: "emitted_at", + }, + { + Name: "event_payloads", + Columns: []PostgresColumn{ + {Name: "event_id", Type: "TEXT", Nullable: false}, + {Name: "payload_json", Type: "JSONB", Nullable: false}, + {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false}, + }, + PrimaryKey: []string{"event_id"}, + PruneColumn: "emitted_at", + }, + }, + MapEvent: mapFn, + } +} + +func mustCompileSchema(t *testing.T, s PostgresSchema) postgresSchemaCompiled { + t.Helper() + compiled, err := compilePostgresSchema(s) + if err != nil { + t.Fatalf("compile schema: %v", err) + } + return compiled +} + +func TestRegisterPostgresSchema(t *testing.T) { + withPostgresTestState(t) + + err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + if err != nil { + t.Fatalf("register schema: %v", err) + } + + if _, ok := lookupPostgresSchema("pg"); !ok { + t.Fatalf("expected schema registration") + } + + err = RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + if err == nil { + t.Fatalf("expected duplicate registration error") + } + if !strings.Contains(err.Error(), "already registered") { + t.Fatalf("unexpected duplicate error: %v", err) + } +} + +func TestRegisterPostgresSchema_RejectsInvalidSchema(t *testing.T) { + withPostgresTestState(t) + + err := RegisterPostgresSchema("pg", PostgresSchema{ + Tables: []PostgresTable{ + { + Name: "events", + Columns: []PostgresColumn{ + {Name: "id", Type: "TEXT", 
Nullable: false}, + }, + PruneColumn: "missing_col", + }, + }, + MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }, + }) + if err == nil { + t.Fatalf("expected invalid schema error") + } + if !strings.Contains(err.Error(), "prune column") { + t.Fatalf("unexpected schema validation error: %v", err) + } + + err = RegisterPostgresSchema("pg2", PostgresSchema{ + Tables: []PostgresTable{ + { + Name: "events", + Columns: []PostgresColumn{ + {Name: "id", Type: "TEXT", Nullable: false}, + {Name: "emitted_at", Type: "TIMESTAMPTZ", Nullable: false}, + }, + PruneColumn: "emitted_at", + Indexes: []PostgresIndex{ + {Name: "idx_events_empty", Columns: nil}, + }, + }, + }, + MapEvent: func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { return nil, nil }, + }) + if err == nil { + t.Fatalf("expected invalid index schema error") + } + if !strings.Contains(err.Error(), "at least one column") { + t.Fatalf("unexpected index validation error: %v", err) + } +} + +func TestNewPostgresSinkFromConfig_MissingParams(t *testing.T) { + withPostgresTestState(t) + + tests := []struct { + name string + params map[string]any + want string + }{ + {name: "missing uri", params: map[string]any{"username": "u", "password": "p"}, want: "params.uri"}, + {name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p"}, want: "params.username"}, + {name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u"}, want: "params.password"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: tc.params, + }) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), tc.want) { + t.Fatalf("expected %q in error, got: %v", tc.want, err) + } + }) + } +} + +func TestNewPostgresSinkFromConfig_MissingSchemaRegistration(t *testing.T) { + 
withPostgresTestState(t) + + _, err := NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "user", + "password": "pass", + }, + }) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "no schema registered") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNewPostgresSinkFromConfig_EagerInit(t *testing.T) { + withPostgresTestState(t) + + err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + if err != nil { + t.Fatalf("register schema: %v", err) + } + + db := &fakeDB{} + var gotDSN string + openPostgresDB = func(dsn string) (postgresDB, error) { + gotDSN = dsn + return db, nil + } + + s, err := NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://db.example.local:5432/feedkit?sslmode=disable", + "username": "app_user", + "password": "app_pass", + }, + }) + if err != nil { + t.Fatalf("new postgres sink: %v", err) + } + if s == nil { + t.Fatalf("expected sink") + } + + if db.pingCalls != 1 { + t.Fatalf("expected one ping, got %d", db.pingCalls) + } + if len(db.execCalls) != 2 { + t.Fatalf("expected 2 init exec calls (table + index), got %d", len(db.execCalls)) + } + if !strings.Contains(db.execCalls[0].query, `CREATE TABLE IF NOT EXISTS "events"`) { + t.Fatalf("unexpected create table query: %s", db.execCalls[0].query) + } + if !strings.Contains(db.execCalls[1].query, `CREATE INDEX IF NOT EXISTS "idx_events_emitted_at"`) { + t.Fatalf("unexpected create index query: %s", db.execCalls[1].query) + } + + u, err := url.Parse(gotDSN) + if err != nil { + t.Fatalf("parse dsn: %v", err) + } + if u.User == nil || u.User.Username() != "app_user" { + t.Fatalf("dsn missing username: %q", gotDSN) + } + pass, ok := u.User.Password() + if !ok || pass != "app_pass" { + 
t.Fatalf("dsn missing password: %q", gotDSN) + } +} + +// TestNewPostgresSinkFromConfig_InitFailureClosesDB verifies that a DDL error +// during eager init surfaces to the caller and closes the opened DB handle. +func TestNewPostgresSinkFromConfig_InitFailureClosesDB(t *testing.T) { + withPostgresTestState(t) + + err := RegisterPostgresSchema("pg", schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })) + if err != nil { + t.Fatalf("register schema: %v", err) + } + + db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")} + openPostgresDB = func(_ string) (postgresDB, error) { + return db, nil + } + + _, err = NewPostgresSinkFromConfig(config.SinkConfig{ + Name: "pg", + Driver: "postgres", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "user", + "password": "pass", + }, + }) + if err == nil { + t.Fatalf("expected init error") + } + if db.closeCalls != 1 { + t.Fatalf("expected db close on init failure") + } +} + +// TestPostgresSinkConsume_InvalidEvent verifies invalid events are rejected +// before the registered mapper is ever invoked. +func TestPostgresSinkConsume_InvalidEvent(t *testing.T) { + db := &fakeDB{} + called := 0 + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + called++ + return nil, nil + })), + } + + err := sink.Consume(context.Background(), event.Event{}) + if err == nil { + t.Fatalf("expected invalid event error") + } + if !strings.Contains(err.Error(), "invalid event") { + t.Fatalf("unexpected error: %v", err) + } + if called != 0 { + t.Fatalf("expected mapper not called for invalid events") + } +} + +// TestPostgresSinkConsume_UnmappedEventIsNoOp verifies a mapper returning zero +// writes is a no-op: no transaction is started. +func TestPostgresSinkConsume_UnmappedEventIsNoOp(t *testing.T) { + db := &fakeDB{} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })), + } + + if err := sink.Consume(context.Background(), validTestEvent()); err != nil { + t.Fatalf("consume: %v", err) + } + if db.beginCalls != 0 { + t.Fatalf("expected no transaction for unmapped events") + } +} + +func 
TestPostgresSinkConsume_OneEventWritesMultipleTablesAtomically(t *testing.T) { + // One event mapped to writes for two tables must run in a single + // transaction: one begin, two inserts, one commit, no rollback. + tx := &fakeTx{} + db := &fakeDB{tx: tx} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + return []PostgresWrite{ + {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, + {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, + }, nil + })), + } + + if err := sink.Consume(context.Background(), validTestEvent()); err != nil { + t.Fatalf("consume: %v", err) + } + if db.beginCalls != 1 { + t.Fatalf("expected one transaction begin, got %d", db.beginCalls) + } + if len(tx.execCalls) != 2 { + t.Fatalf("expected 2 insert statements, got %d", len(tx.execCalls)) + } + if tx.commitCalls != 1 { + t.Fatalf("expected one commit, got %d", tx.commitCalls) + } + if tx.rollbackCalls != 0 { + t.Fatalf("expected zero rollbacks, got %d", tx.rollbackCalls) + } +} + +// TestPostgresSinkConsume_InsertFailureRollsBack verifies that when any insert +// in the mapped batch fails (here the second exec), the transaction is rolled +// back and never committed. +func TestPostgresSinkConsume_InsertFailureRollsBack(t *testing.T) { + tx := &fakeTx{execErrOnCall: 2, execErr: errors.New("duplicate key")} + db := &fakeDB{tx: tx} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, e event.Event) ([]PostgresWrite, error) { + return []PostgresWrite{ + {Table: "events", Values: map[string]any{"event_id": e.ID, "emitted_at": e.EmittedAt}}, + {Table: "event_payloads", Values: map[string]any{"event_id": e.ID, "payload_json": `{}`, "emitted_at": e.EmittedAt}}, + }, nil + })), + } + + err := sink.Consume(context.Background(), validTestEvent()) + if err == nil { + t.Fatalf("expected insert error") + } + if !strings.Contains(err.Error(), "insert into") { + t.Fatalf("unexpected error: %v", err) + } + if tx.commitCalls != 0 { + t.Fatalf("expected no commit") + } + if tx.rollbackCalls != 1 { + t.Fatalf("expected rollback, got %d", 
tx.rollbackCalls) + } +} + +// TestPostgresSinkPrune_PerTable verifies the per-table prune helpers: each +// call issues one parameterized DELETE built from the table's prune column and +// reports the fake driver's rows-affected count. +func TestPostgresSinkPrune_PerTable(t *testing.T) { + db := &fakeDB{execRows: 7} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })), + } + + rows, err := sink.PruneKeepLatest(context.Background(), "events", 10) + if err != nil { + t.Fatalf("prune keep latest: %v", err) + } + if rows != 7 { + t.Fatalf("unexpected rows affected: %d", rows) + } + if len(db.execCalls) != 1 { + t.Fatalf("expected one prune query") + } + if !strings.Contains(db.execCalls[0].query, `ORDER BY "emitted_at" DESC`) { + t.Fatalf("unexpected keep-latest query: %s", db.execCalls[0].query) + } + if len(db.execCalls[0].args) != 1 || db.execCalls[0].args[0] != 10 { + t.Fatalf("unexpected keep-latest args: %#v", db.execCalls[0].args) + } + + cutoff := time.Now().UTC().Add(-24 * time.Hour) + rows, err = sink.PruneOlderThan(context.Background(), "events", cutoff) + if err != nil { + t.Fatalf("prune older than: %v", err) + } + if rows != 7 { + t.Fatalf("unexpected rows affected: %d", rows) + } + if len(db.execCalls) != 2 { + t.Fatalf("expected two prune queries") + } + if !strings.Contains(db.execCalls[1].query, `WHERE "emitted_at" < $1`) { + t.Fatalf("unexpected older-than query: %s", db.execCalls[1].query) + } +} + +// TestPostgresSinkPrune_AllTables verifies the all-tables prune helpers fan +// out one prune call per schema table and return per-table row counts. +func TestPostgresSinkPrune_AllTables(t *testing.T) { + db := &fakeDB{execRows: 3} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaTwoTables(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })), + } + + keepCounts, err := sink.PruneAllKeepLatest(context.Background(), 5) + if err != nil { + t.Fatalf("prune all keep latest: %v", err) + } + if len(keepCounts) != 2 || keepCounts["events"] != 3 || keepCounts["event_payloads"] != 3 { + t.Fatalf("unexpected keep counts: %#v", keepCounts) + } + + db.execCalls = nil + olderCounts, err := 
sink.PruneAllOlderThan(context.Background(), time.Now().UTC()) + if err != nil { + t.Fatalf("prune all older than: %v", err) + } + if len(olderCounts) != 2 || olderCounts["events"] != 3 || olderCounts["event_payloads"] != 3 { + t.Fatalf("unexpected older-than counts: %#v", olderCounts) + } + if len(db.execCalls) != 2 { + t.Fatalf("expected one prune call per table") + } +} + +// TestPostgresSinkPrune_Errors verifies prune argument validation: a negative +// keep count and an unknown table name must both return errors. +func TestPostgresSinkPrune_Errors(t *testing.T) { + db := &fakeDB{} + sink := &PostgresSink{ + name: "pg", + db: db, + schema: mustCompileSchema(t, schemaOneTable(func(_ context.Context, _ event.Event) ([]PostgresWrite, error) { + return nil, nil + })), + } + + if _, err := sink.PruneKeepLatest(context.Background(), "events", -1); err == nil { + t.Fatalf("expected negative keep error") + } + if _, err := sink.PruneOlderThan(context.Background(), "missing", time.Now().UTC()); err == nil { + t.Fatalf("expected unknown table error") + } +}