Files
feedkit/sinks/postgres.go

439 lines
11 KiB
Go

package sinks
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"gitea.maximumdirect.net/ejr/feedkit/config"
"gitea.maximumdirect.net/ejr/feedkit/event"
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)
// postgresExecer is the minimal statement-execution capability shared by a
// database handle and an open transaction. Helpers that only need to run a
// statement (e.g. pruning) accept this so they work on either.
type postgresExecer interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// postgresTx is the subset of *sql.Tx the sink relies on: executing
// statements plus finishing the transaction one way or the other.
type postgresTx interface {
	postgresExecer
	Commit() error
	Rollback() error
}

// postgresDB is the subset of *sql.DB the sink relies on. It is an interface
// so the real connection can be swapped for a fake (see openPostgresDB).
type postgresDB interface {
	postgresExecer
	PingContext(ctx context.Context) error
	BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
	Close() error
}
// sqlDBWrapper adapts a *sql.DB to the postgresDB interface. The only
// non-trivial adaptation is BeginTx, which wraps the returned *sql.Tx so it
// satisfies postgresTx.
type sqlDBWrapper struct {
	db *sql.DB
}

// PingContext forwards to (*sql.DB).PingContext.
func (d *sqlDBWrapper) PingContext(ctx context.Context) error {
	return d.db.PingContext(ctx)
}

// BeginTx starts a transaction on the underlying *sql.DB and wraps it.
// On failure it returns an untyped nil postgresTx together with the error.
func (d *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
	rawTx, err := d.db.BeginTx(ctx, opts)
	if err == nil {
		return &sqlTxWrapper{tx: rawTx}, nil
	}
	return nil, err
}

// ExecContext forwards to (*sql.DB).ExecContext.
func (d *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return d.db.ExecContext(ctx, query, args...)
}

// Close forwards to (*sql.DB).Close.
func (d *sqlDBWrapper) Close() error {
	return d.db.Close()
}
// sqlTxWrapper adapts a *sql.Tx to the postgresTx interface by plain
// delegation.
type sqlTxWrapper struct {
	tx *sql.Tx
}

// ExecContext runs query inside the wrapped transaction.
func (t *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return t.tx.ExecContext(ctx, query, args...)
}

// Commit forwards to (*sql.Tx).Commit.
func (t *sqlTxWrapper) Commit() error {
	return t.tx.Commit()
}

// Rollback forwards to (*sql.Tx).Rollback.
func (t *sqlTxWrapper) Rollback() error {
	return t.tx.Rollback()
}
// openPostgresDB opens a connection via the internal pgconn package and wraps
// it as a postgresDB. It is a package-level variable rather than a function,
// presumably so tests can substitute a fake opener.
var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
	sqlDB, openErr := pgconn.Open(ctx, cfg)
	if openErr != nil {
		// Return a literal nil so callers never see a non-nil interface
		// wrapping a nil *sqlDBWrapper.
		return nil, openErr
	}
	wrapped := &sqlDBWrapper{db: sqlDB}
	return wrapped, nil
}
// PostgresSink writes events into Postgres tables described by a compiled
// schema. Construct it with NewPostgresSinkFromConfig.
type PostgresSink struct {
// name is the sink name taken from the sink config (cfg.Name).
name string
// db is the database handle used for inserts, pruning and table setup.
db postgresDB
// schema is the compiled table schema used to map events to rows and to
// generate CREATE TABLE / CREATE INDEX statements.
schema postgresSchemaCompiled
// pruneWindow, when > 0, makes every Consume delete rows older than
// now-minus-pruneWindow from all schema tables; 0 disables that.
pruneWindow time.Duration
}
// NewPostgresSinkFromConfig builds a PostgresSink from cfg and schemaDef.
//
// Required string params are "uri", "username" and "password". The optional
// "prune" param sets a retention window that is applied on every Consume.
// The schema is compiled before any connection is opened. After connecting,
// the schema's tables and indexes are ensured. If that setup fails, the
// connection is closed and the setup error is returned.
func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) (Sink, error) {
	// Read the required params in a fixed order so the first missing one is
	// the one reported.
	creds := make(map[string]string, 3)
	for _, key := range []string{"uri", "username", "password"} {
		val, err := requireStringParam(cfg, key)
		if err != nil {
			return nil, err
		}
		creds[key] = val
	}

	pruneWindow, err := parsePostgresPruneWindow(cfg)
	if err != nil {
		return nil, err
	}

	schema, err := compilePostgresSchema(schemaDef)
	if err != nil {
		return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
	}

	db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
		URI:      creds["uri"],
		Username: creds["username"],
		Password: creds["password"],
	})
	if err != nil {
		return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
	}

	sink := &PostgresSink{
		name:        cfg.Name,
		db:          db,
		schema:      schema,
		pruneWindow: pruneWindow,
	}
	if initErr := sink.initialize(); initErr != nil {
		// Best-effort cleanup; the initialization error is the one to report.
		_ = db.Close()
		return nil, initErr
	}
	return sink, nil
}
// Name returns the sink's configured name.
func (p *PostgresSink) Name() string {
	return p.name
}
// Consume persists e to Postgres.
//
// The event is validated, then mapped to zero or more row writes by the
// compiled schema. An event that maps to no writes is a successful no-op and
// never opens a transaction.
//
// Otherwise all rows are inserted in one transaction. If a prune window is
// configured, rows older than now-minus-window are deleted from every schema
// table in that same transaction before it commits. Any failure rolls the
// transaction back.
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
	// Boundary validation: if something upstream violated invariants,
	// surface it loudly rather than writing corrupt rows.
	if verr := e.Validate(); verr != nil {
		return fmt.Errorf("postgres sink: invalid event: %w", verr)
	}
	if cerr := ctx.Err(); cerr != nil {
		return cerr
	}

	pending, err := p.schema.mapEvent(ctx, e)
	switch {
	case err != nil:
		return fmt.Errorf("postgres sink: map event: %w", err)
	case len(pending) == 0:
		return nil
	}

	tx, err := p.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("postgres sink: begin tx: %w", err)
	}
	// Roll back on every exit except a successful Commit. The rollback error
	// is dropped because every such exit already returns an error.
	done := false
	defer func() {
		if done {
			return
		}
		_ = tx.Rollback()
	}()

	for i := range pending {
		tbl, verr := p.schema.validateWrite(pending[i])
		if verr != nil {
			return fmt.Errorf("postgres sink: %w", verr)
		}
		query, args, berr := buildInsertSQL(tbl, pending[i])
		if berr != nil {
			return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, berr)
		}
		if _, xerr := tx.ExecContext(ctx, query, args...); xerr != nil {
			return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, xerr)
		}
	}

	if p.pruneWindow > 0 {
		cutoff := time.Now().UTC().Add(-p.pruneWindow)
		for _, name := range p.schema.tableOrder {
			if _, perr := execPruneOlderThan(ctx, tx, p.schema.tables[name], cutoff); perr != nil {
				return fmt.Errorf("postgres sink: %w", perr)
			}
		}
	}

	if cerr := tx.Commit(); cerr != nil {
		return fmt.Errorf("postgres sink: commit tx: %w", cerr)
	}
	done = true
	return nil
}
// PruneKeepLatest deletes all but the newest keep rows of table and returns
// how many rows were deleted. "Newest" is decided by the table's prune column
// in descending order, and keep == 0 deletes every row. Rows are targeted by
// ctid, so no primary key is needed. The delete runs directly on p.db, not
// inside a Consume transaction.
//
// NOTE(review): PostgreSQL sorts NULLs first under DESC. Rows with a NULL
// prune column would therefore count as the newest and be kept; confirm the
// prune column is NOT NULL. Ties on the prune column are broken arbitrarily.
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
	if keep < 0 {
		return 0, fmt.Errorf("postgres sink: keep must be >= 0")
	}
	tbl, err := p.lookupTable(table)
	if err != nil {
		return 0, err
	}

	target := quotePostgresIdent(tbl.name)
	orderBy := quotePostgresIdent(tbl.pruneColumn)
	query := "DELETE FROM " + target + " WHERE ctid IN (\n" +
		"SELECT ctid FROM " + target + "\n" +
		"ORDER BY " + orderBy + " DESC\n" +
		"OFFSET $1\n" +
		")"

	res, err := p.db.ExecContext(ctx, query, keep)
	if err != nil {
		return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
	}
	deleted, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
	}
	return deleted, nil
}
// PruneOlderThan deletes rows of table whose prune column is strictly before
// cutoff, running directly on p.db. It returns the number of rows deleted.
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
	tbl, lookupErr := p.lookupTable(table)
	if lookupErr != nil {
		return 0, lookupErr
	}
	deleted, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
	if err != nil {
		return 0, fmt.Errorf("postgres sink: %w", err)
	}
	return deleted, nil
}
// PruneAllKeepLatest applies PruneKeepLatest to every schema table in schema
// order and returns the deleted-row count per table. It stops at the first
// error. In that case the returned map still holds the counts for the tables
// already pruned.
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
	deleted := make(map[string]int64, len(p.schema.tableOrder))
	for _, name := range p.schema.tableOrder {
		n, pruneErr := p.PruneKeepLatest(ctx, name, keep)
		if pruneErr != nil {
			return deleted, pruneErr
		}
		deleted[name] = n
	}
	return deleted, nil
}
// PruneAllOlderThan applies PruneOlderThan with the same cutoff to every
// schema table in schema order and returns the deleted-row count per table.
// It stops at the first error. In that case the returned map still holds the
// counts for the tables already pruned.
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
	deleted := make(map[string]int64, len(p.schema.tableOrder))
	for _, name := range p.schema.tableOrder {
		n, pruneErr := p.PruneOlderThan(ctx, name, cutoff)
		if pruneErr != nil {
			return deleted, pruneErr
		}
		deleted[name] = n
	}
	return deleted, nil
}
// initialize ensures every schema table exists, in schema order. Each table's
// declared indexes are ensured right after the table itself. The statements
// come from buildCreateTableSQL / buildCreateIndexSQL (CREATE ... IF NOT
// EXISTS), so an existing table with a different shape is left untouched.
// All statements share one 5-second deadline, and the first failure aborts.
func (p *PostgresSink) initialize() error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	for _, name := range p.schema.tableOrder {
		tbl := p.schema.tables[name]
		if _, err := p.db.ExecContext(ctx, buildCreateTableSQL(tbl)); err != nil {
			return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
		}
		for _, idx := range tbl.indexes {
			if _, err := p.db.ExecContext(ctx, buildCreateIndexSQL(tbl.name, idx)); err != nil {
				return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
			}
		}
	}
	return nil
}
// lookupTable resolves a caller-supplied table name, after trimming
// surrounding whitespace, to its compiled definition. Blank and unknown names
// are errors.
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
	key := strings.TrimSpace(table)
	if key == "" {
		return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
	}
	if tbl, found := p.schema.tables[key]; found {
		return tbl, nil
	}
	return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", key)
}
// parsePostgresPruneWindow reads the optional params.prune setting.
// If it is missing or null, it returns 0, meaning no per-Consume pruning.
// If it is a string, it is parsed by parsePostgresPruneDuration.
// Any other type is rejected.
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
	switch v := cfg.Params["prune"].(type) {
	case nil:
		return 0, nil
	case string:
		d, err := parsePostgresPruneDuration(v)
		if err != nil {
			return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, v, err)
		}
		return d, nil
	default:
		return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
	}
}
// parsePostgresPruneDuration parses a prune window such as "72h", "3d" or "2w".
//
// Accepted forms (surrounding whitespace is ignored):
//   - a positive integer followed by "d" (days of 24h) or "w" (weeks of 7×24h),
//     case-insensitive, e.g. "3d", "2W", " 10 d ";
//   - anything time.ParseDuration accepts, e.g. "72h", "90m", "1h30m".
//
// The result is always strictly positive. The following are rejected:
//   - empty input;
//   - a non-integer day or week count;
//   - zero or negative values;
//   - day or week counts too large for a time.Duration (about 292 years).
//
// Error messages are phrased to follow the caller's
// "params.prune ... is invalid: " prefix.
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
	// Largest representable time.Duration (math.MaxInt64 nanoseconds).
	const maxDuration = time.Duration(1<<63 - 1)

	s := strings.TrimSpace(raw)
	if s == "" {
		return 0, fmt.Errorf("must not be empty")
	}
	lower := strings.ToLower(s)

	var unit time.Duration
	switch lower[len(lower)-1] {
	case 'd':
		unit = 24 * time.Hour
	case 'w':
		unit = 7 * 24 * time.Hour
	}
	if unit > 0 {
		suffix := lower[len(lower)-1:]
		n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
		if err != nil {
			return 0, fmt.Errorf("must use a positive integer before %q", suffix)
		}
		if n <= 0 {
			return 0, fmt.Errorf("must be > 0")
		}
		// Guard the multiplication. Without this check, a large count such as
		// "106752d" or "999999999w" wraps around int64 and yields a negative or
		// bogus window; a non-positive window silently disables pruning.
		limit := maxDuration / unit
		if time.Duration(n) > limit {
			return 0, fmt.Errorf("must not exceed %d%s", int64(limit), suffix)
		}
		return time.Duration(n) * unit, nil
	}

	// time.ParseDuration reports its own overflow errors.
	d, err := time.ParseDuration(s)
	if err != nil {
		return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
	}
	if d <= 0 {
		return 0, fmt.Errorf("must be > 0")
	}
	return d, nil
}
// buildPruneOlderThanSQL returns a DELETE that removes every row of tbl whose
// prune column is strictly less than the single bind parameter $1.
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
	return "DELETE FROM " + quotePostgresIdent(tbl.name) +
		" WHERE " + quotePostgresIdent(tbl.pruneColumn) + " < $1"
}
// execPruneOlderThan runs the prune-older-than DELETE for tbl against execer.
// execer may be the database handle or an open transaction. The function
// returns the number of rows deleted. Errors name the table but carry no
// "postgres sink" prefix; callers add that.
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
	res, err := execer.ExecContext(ctx, buildPruneOlderThanSQL(tbl), cutoff)
	if err != nil {
		return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
	}
	deleted, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
	}
	return deleted, nil
}
// buildCreateTableSQL renders a CREATE TABLE IF NOT EXISTS statement for tbl.
// Columns appear in columnOrder, each with its declared type; non-nullable
// columns get NOT NULL. A PRIMARY KEY clause is appended when the table
// declares one. Column types are inserted verbatim and are not quoted.
func buildCreateTableSQL(tbl postgresTableCompiled) string {
	parts := make([]string, 0, len(tbl.columnOrder)+1)
	for _, name := range tbl.columnOrder {
		c := tbl.columns[name]
		nullability := ""
		if !c.Nullable {
			nullability = " NOT NULL"
		}
		parts = append(parts, fmt.Sprintf("%s %s%s", quotePostgresIdent(c.Name), c.Type, nullability))
	}
	if len(tbl.primaryKey) != 0 {
		parts = append(parts, "PRIMARY KEY ("+joinQuotedIdents(tbl.primaryKey)+")")
	}
	return "CREATE TABLE IF NOT EXISTS " + quotePostgresIdent(tbl.name) +
		" (" + strings.Join(parts, ", ") + ")"
}
// buildCreateIndexSQL renders a CREATE [UNIQUE] INDEX IF NOT EXISTS statement
// for idx on tableName, quoting the index, table and column identifiers.
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
	verb := "CREATE INDEX"
	if idx.Unique {
		verb = "CREATE UNIQUE INDEX"
	}
	return verb + " IF NOT EXISTS " + quotePostgresIdent(idx.Name) +
		" ON " + quotePostgresIdent(tableName) +
		" (" + joinQuotedIdents(idx.Columns) + ")"
}
// buildInsertSQL renders a parameterized INSERT for w into tbl. It returns the
// query plus the bind arguments in placeholder order ($1, $2, ...).
//
// Every column in tbl.columnOrder must have an entry in w.Values, or an error
// is returned. Extra entries in w.Values are ignored.
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
	n := len(tbl.columnOrder)
	quoted := make([]string, n)
	params := make([]string, n)
	values := make([]any, n)
	for i, name := range tbl.columnOrder {
		v, present := w.Values[name]
		if !present {
			return "", nil, fmt.Errorf("missing value for column %q", name)
		}
		quoted[i] = quotePostgresIdent(name)
		params[i] = "$" + strconv.Itoa(i+1)
		values[i] = v
	}
	query := "INSERT INTO " + quotePostgresIdent(tbl.name) +
		" (" + strings.Join(quoted, ", ") + ")" +
		" VALUES (" + strings.Join(params, ", ") + ")"
	return query, values, nil
}
// joinQuotedIdents quotes each identifier with quotePostgresIdent and joins
// them with ", ". An empty or nil slice yields "".
func joinQuotedIdents(idents []string) string {
	var b strings.Builder
	for i, id := range idents {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(quotePostgresIdent(id))
	}
	return b.String()
}

// quotePostgresIdent wraps s in double quotes and doubles any embedded double
// quote. This is the standard SQL delimited-identifier escaping, so arbitrary
// names cannot break out of the identifier.
func quotePostgresIdent(s string) string {
	const dq = "\""
	return dq + strings.ReplaceAll(s, dq, dq+dq) + dq
}