package sinks

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
	_ "github.com/lib/pq" // registers the "postgres" driver used by sql.Open below
)

// postgresInitTimeout bounds the ping and the table/index creation performed
// by (*PostgresSink).initialize.
const postgresInitTimeout = 5 * time.Second

// postgresTx is the subset of *sql.Tx the sink relies on.
type postgresTx interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	Commit() error
	Rollback() error
}

// postgresExecer is anything that can execute a statement. Both postgresDB
// and postgresTx satisfy it, which lets the prune helpers run either inside
// or outside a transaction.
type postgresExecer interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// postgresDB is the subset of *sql.DB the sink relies on. BeginTx returns the
// postgresTx interface rather than *sql.Tx.
type postgresDB interface {
	PingContext(ctx context.Context) error
	BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	Close() error
}

// sqlDBWrapper adapts *sql.DB to postgresDB.
type sqlDBWrapper struct {
	db *sql.DB
}

// PingContext forwards to the wrapped *sql.DB.
func (w *sqlDBWrapper) PingContext(ctx context.Context) error {
	return w.db.PingContext(ctx)
}

// BeginTx starts a transaction on the wrapped *sql.DB and returns it wrapped
// as a postgresTx.
func (w *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
	tx, err := w.db.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	return &sqlTxWrapper{tx: tx}, nil
}

// ExecContext forwards to the wrapped *sql.DB.
func (w *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return w.db.ExecContext(ctx, query, args...)
}

// Close forwards to the wrapped *sql.DB.
func (w *sqlDBWrapper) Close() error {
	return w.db.Close()
}

// sqlTxWrapper adapts *sql.Tx to postgresTx.
type sqlTxWrapper struct {
	tx *sql.Tx
}

// ExecContext forwards to the wrapped *sql.Tx.
func (w *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return w.tx.ExecContext(ctx, query, args...)
}

// Commit forwards to the wrapped *sql.Tx.
func (w *sqlTxWrapper) Commit() error {
	return w.tx.Commit()
}

// Rollback forwards to the wrapped *sql.Tx.
func (w *sqlTxWrapper) Rollback() error {
	return w.tx.Rollback()
}

// openPostgresDB opens a database handle for dsn using the "postgres" driver.
// It is a package-level variable so the opener can be replaced
// (NOTE(review): presumably by tests — confirm against callers).
var openPostgresDB = func(dsn string) (postgresDB, error) {
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		return nil, err
	}
	return &sqlDBWrapper{db: db}, nil
}

// PostgresSink writes events into Postgres tables described by a compiled
// schema and can optionally prune old rows after each write.
type PostgresSink struct {
	name        string
	db          postgresDB
	schema      postgresSchemaCompiled
	pruneWindow time.Duration // 0 disables pruning in Consume
}

// NewPostgresSinkFromConfig builds a PostgresSink from cfg.
//
// It requires the string params "uri", "username" and "password", accepts an
// optional "prune" duration param, and requires a schema to be registered
// under cfg.Name. The database is opened, pinged, and its tables and indexes
// are created if missing before the sink is returned.
func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
	uri, err := requireStringParam(cfg, "uri")
	if err != nil {
		return nil, err
	}
	username, err := requireStringParam(cfg, "username")
	if err != nil {
		return nil, err
	}
	password, err := requireStringParam(cfg, "password")
	if err != nil {
		return nil, err
	}
	pruneWindow, err := parsePostgresPruneWindow(cfg)
	if err != nil {
		return nil, err
	}

	schema, ok := lookupPostgresSchema(cfg.Name)
	if !ok {
		return nil, fmt.Errorf("postgres sink %q: no schema registered (call sinks.RegisterPostgresSchema before building sinks)", cfg.Name)
	}

	dsn, err := buildPostgresDSN(uri, username, password)
	if err != nil {
		return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err)
	}

	db, err := openPostgresDB(dsn)
	if err != nil {
		return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
	}

	s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
	if err := s.initialize(); err != nil {
		// Best-effort cleanup: the initialization error is the one worth
		// reporting, so a failure to close is deliberately dropped.
		_ = db.Close()
		return nil, err
	}

	return s, nil
}

// Name returns the sink's configured name.
func (p *PostgresSink) Name() string { return p.name }

// Consume validates e, maps it to table writes through the schema, and inserts
// every write in a single transaction. When a prune window is configured, rows
// older than now minus that window are deleted from every schema table in the
// same transaction. If the event maps to no writes, Consume returns nil
// without opening a transaction.
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
	// Boundary validation: if something upstream violated invariants,
	// surface it loudly rather than writing corrupt rows.
	if err := e.Validate(); err != nil {
		return fmt.Errorf("postgres sink: invalid event: %w", err)
	}
	if err := ctx.Err(); err != nil {
		return err
	}

	writes, err := p.schema.mapEvent(ctx, e)
	if err != nil {
		return fmt.Errorf("postgres sink: map event: %w", err)
	}
	if len(writes) == 0 {
		return nil
	}

	tx, err := p.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("postgres sink: begin tx: %w", err)
	}

	committed := false
	defer func() {
		if !committed {
			// The error that caused the early return is already being
			// reported; a rollback failure adds nothing actionable.
			_ = tx.Rollback()
		}
	}()

	for _, w := range writes {
		tbl, err := p.schema.validateWrite(w)
		if err != nil {
			return fmt.Errorf("postgres sink: %w", err)
		}

		query, args, err := buildInsertSQL(tbl, w)
		if err != nil {
			return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
		}

		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
			return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
		}
	}

	if p.pruneWindow > 0 {
		cutoff := time.Now().UTC().Add(-p.pruneWindow)
		for _, tableName := range p.schema.tableOrder {
			tbl := p.schema.tables[tableName]
			if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
				return fmt.Errorf("postgres sink: %w", err)
			}
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("postgres sink: commit tx: %w", err)
	}
	committed = true
	return nil
}

// PruneKeepLatest deletes every row of table except the keep rows that sort
// first when ordered by the table's prune column descending. It returns the
// number of rows deleted. keep must be >= 0 and table must be a table of the
// sink's schema.
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
	if keep < 0 {
		return 0, fmt.Errorf("postgres sink: keep must be >= 0")
	}

	tbl, err := p.lookupTable(table)
	if err != nil {
		return 0, err
	}

	// Rows are addressed by ctid, so no primary key is needed to select the
	// rows that fall past the first keep entries.
	query := fmt.Sprintf(
		`DELETE FROM %s WHERE ctid IN ( SELECT ctid FROM %s ORDER BY %s DESC OFFSET $1 )`,
		quotePostgresIdent(tbl.name),
		quotePostgresIdent(tbl.name),
		quotePostgresIdent(tbl.pruneColumn),
	)

	res, err := p.db.ExecContext(ctx, query, keep)
	if err != nil {
		return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
	}
	rows, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
	}
	return rows, nil
}

// PruneOlderThan deletes rows of table whose prune column is strictly less
// than cutoff and returns the number of rows deleted.
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
	tbl, err := p.lookupTable(table)
	if err != nil {
		return 0, err
	}

	rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
	if err != nil {
		return 0, fmt.Errorf("postgres sink: %w", err)
	}
	return rows, nil
}

// PruneAllKeepLatest runs PruneKeepLatest on every schema table in schema
// order. It stops at the first failure and returns the counts gathered so far
// together with that error.
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
	counts := make(map[string]int64, len(p.schema.tableOrder))
	for _, table := range p.schema.tableOrder {
		n, err := p.PruneKeepLatest(ctx, table, keep)
		if err != nil {
			return counts, err
		}
		counts[table] = n
	}
	return counts, nil
}

// PruneAllOlderThan runs PruneOlderThan on every schema table in schema order.
// It stops at the first failure and returns the counts gathered so far
// together with that error.
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
	counts := make(map[string]int64, len(p.schema.tableOrder))
	for _, table := range p.schema.tableOrder {
		n, err := p.PruneOlderThan(ctx, table, cutoff)
		if err != nil {
			return counts, err
		}
		counts[table] = n
	}
	return counts, nil
}

// initialize pings the database and then issues CREATE TABLE IF NOT EXISTS
// and CREATE INDEX IF NOT EXISTS for every table and index in the schema.
// The whole sequence shares one postgresInitTimeout deadline.
func (p *PostgresSink) initialize() error {
	ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout)
	defer cancel()

	if err := p.db.PingContext(ctx); err != nil {
		return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err)
	}

	for _, tableName := range p.schema.tableOrder {
		tbl := p.schema.tables[tableName]

		createTableSQL := buildCreateTableSQL(tbl)
		if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
			return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
		}

		for _, idx := range tbl.indexes {
			createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
			if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
				return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
			}
		}
	}

	return nil
}

// lookupTable returns the compiled table registered under table (after
// trimming surrounding whitespace), or an error if the name is empty or not
// part of the schema.
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
	table = strings.TrimSpace(table)
	if table == "" {
		return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
	}
	tbl, ok := p.schema.tables[table]
	if !ok {
		return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
	}
	return tbl, nil
}

// buildPostgresDSN parses uri, requires it to carry a scheme and a host, and
// returns it with its user info set to username/password. Any credentials
// already present in uri are replaced.
func buildPostgresDSN(uri, username, password string) (string, error) {
	u, err := url.Parse(strings.TrimSpace(uri))
	if err != nil {
		return "", fmt.Errorf("invalid uri: %w", err)
	}
	if u.Scheme == "" {
		return "", fmt.Errorf("invalid uri: missing scheme")
	}
	if u.Host == "" {
		return "", fmt.Errorf("invalid uri: missing host")
	}

	u.User = url.UserPassword(username, password)
	return u.String(), nil
}

// parsePostgresPruneWindow reads the optional "prune" param from cfg. A
// missing or nil param yields 0 (pruning disabled); otherwise the value must
// be a string accepted by parsePostgresPruneDuration.
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
	raw, ok := cfg.Params["prune"]
	if !ok || raw == nil {
		return 0, nil
	}

	s, ok := raw.(string)
	if !ok {
		return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
	}

	d, err := parsePostgresPruneDuration(s)
	if err != nil {
		return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
	}
	return d, nil
}

// parsePostgresPruneDuration parses a strictly positive duration.
//
// Two forms are accepted: a positive integer followed by "d" (days) or "w"
// (weeks), matched case-insensitively, or any string understood by
// time.ParseDuration. Day and week counts whose product does not fit in a
// time.Duration are rejected instead of silently wrapping around, because a
// wrapped value could either disable pruning or shrink the retention window.
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
	s := strings.TrimSpace(raw)
	if s == "" {
		return 0, fmt.Errorf("must not be empty")
	}

	lower := strings.ToLower(s)
	if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
		unit := lower[len(lower)-1]
		n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
		if err != nil {
			return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
		}
		if n <= 0 {
			return 0, fmt.Errorf("must be > 0")
		}

		step := 24 * time.Hour
		if unit == 'w' {
			step = 7 * 24 * time.Hour
		}

		// Guard the multiplication below against int64 overflow.
		const maxDuration = time.Duration(1<<63 - 1)
		if int64(n) > int64(maxDuration/step) {
			return 0, fmt.Errorf("exceeds the maximum supported duration")
		}
		return time.Duration(n) * step, nil
	}

	d, err := time.ParseDuration(s)
	if err != nil {
		return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
	}
	if d <= 0 {
		return 0, fmt.Errorf("must be > 0")
	}
	return d, nil
}

// buildPruneOlderThanSQL returns the DELETE statement that removes rows whose
// prune column is less than the single placeholder argument.
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
	return fmt.Sprintf(
		`DELETE FROM %s WHERE %s < $1`,
		quotePostgresIdent(tbl.name),
		quotePostgresIdent(tbl.pruneColumn),
	)
}

// execPruneOlderThan runs the prune-older-than DELETE for tbl on execer with
// cutoff as its argument and returns the number of rows affected.
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
	query := buildPruneOlderThanSQL(tbl)
	res, err := execer.ExecContext(ctx, query, cutoff)
	if err != nil {
		return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
	}
	rows, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
	}
	return rows, nil
}

// buildCreateTableSQL returns a CREATE TABLE IF NOT EXISTS statement for tbl,
// listing columns in tbl.columnOrder and appending a PRIMARY KEY clause when
// tbl.primaryKey is non-empty. Column types are emitted verbatim.
func buildCreateTableSQL(tbl postgresTableCompiled) string {
	defs := make([]string, 0, len(tbl.columnOrder)+1)
	for _, colName := range tbl.columnOrder {
		col := tbl.columns[colName]
		def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
		if !col.Nullable {
			def += " NOT NULL"
		}
		defs = append(defs, def)
	}
	if len(tbl.primaryKey) > 0 {
		defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
	}

	return fmt.Sprintf(
		"CREATE TABLE IF NOT EXISTS %s (%s)",
		quotePostgresIdent(tbl.name),
		strings.Join(defs, ", "),
	)
}

// buildCreateIndexSQL returns a CREATE [UNIQUE] INDEX IF NOT EXISTS statement
// for idx on tableName.
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
	unique := ""
	if idx.Unique {
		unique = "UNIQUE "
	}
	return fmt.Sprintf(
		"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
		unique,
		quotePostgresIdent(idx.Name),
		quotePostgresIdent(tableName),
		joinQuotedIdents(idx.Columns),
	)
}

// buildInsertSQL returns a parameterized INSERT for w covering every column of
// tbl in tbl.columnOrder, together with the matching argument list. It fails
// if w.Values lacks an entry for any column.
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
	cols := make([]string, 0, len(tbl.columnOrder))
	args := make([]any, 0, len(tbl.columnOrder))
	placeholders := make([]string, 0, len(tbl.columnOrder))

	for i, colName := range tbl.columnOrder {
		v, ok := w.Values[colName]
		if !ok {
			return "", nil, fmt.Errorf("missing value for column %q", colName)
		}
		cols = append(cols, quotePostgresIdent(colName))
		args = append(args, v)
		placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
	}

	q := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		quotePostgresIdent(tbl.name),
		strings.Join(cols, ", "),
		strings.Join(placeholders, ", "),
	)
	return q, args, nil
}

// joinQuotedIdents quotes each identifier and joins them with ", ".
func joinQuotedIdents(idents []string) string {
	quoted := make([]string, 0, len(idents))
	for _, s := range idents {
		quoted = append(quoted, quotePostgresIdent(s))
	}
	return strings.Join(quoted, ", ")
}

// quotePostgresIdent wraps s in double quotes, doubling any embedded double
// quote so the result is a single quoted identifier.
func quotePostgresIdent(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}