462 lines
12 KiB
Go
462 lines
12 KiB
Go
package sinks
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"gitea.maximumdirect.net/ejr/feedkit/config"
|
|
"gitea.maximumdirect.net/ejr/feedkit/event"
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// postgresInitTimeout is the single deadline shared by everything
// PostgresSink.initialize does: the connectivity ping plus every
// CREATE TABLE / CREATE INDEX statement.
const postgresInitTimeout = time.Second * 5
|
|
|
|
// postgresExecer is the statement-execution method shared by a database
// handle and an open transaction; execPruneOlderThan accepts either one.
type postgresExecer interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// postgresTx is the subset of *sql.Tx that PostgresSink relies on.
type postgresTx interface {
	postgresExecer
	Commit() error
	Rollback() error
}

// postgresDB is the subset of *sql.DB that PostgresSink relies on. BeginTx
// yields the postgresTx abstraction instead of a concrete *sql.Tx, so the
// sink never touches database/sql transaction types directly.
type postgresDB interface {
	postgresExecer
	PingContext(ctx context.Context) error
	BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error)
	Close() error
}

// sqlDBWrapper adapts *sql.DB to postgresDB. Every method delegates directly
// except BeginTx, which wraps the resulting *sql.Tx in a sqlTxWrapper.
type sqlDBWrapper struct {
	db *sql.DB
}

// PingContext delegates to (*sql.DB).PingContext.
func (d *sqlDBWrapper) PingContext(ctx context.Context) error {
	return d.db.PingContext(ctx)
}

// BeginTx starts a transaction on the underlying *sql.DB and wraps it.
func (d *sqlDBWrapper) BeginTx(ctx context.Context, opts *sql.TxOptions) (postgresTx, error) {
	var (
		raw *sql.Tx
		err error
	)
	if raw, err = d.db.BeginTx(ctx, opts); err != nil {
		// Return a nil interface, not a nil *sqlTxWrapper.
		return nil, err
	}
	return &sqlTxWrapper{tx: raw}, nil
}

// ExecContext delegates to (*sql.DB).ExecContext.
func (d *sqlDBWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return d.db.ExecContext(ctx, query, args...)
}

// Close delegates to (*sql.DB).Close.
func (d *sqlDBWrapper) Close() error {
	return d.db.Close()
}

// sqlTxWrapper adapts *sql.Tx to postgresTx by pure delegation.
type sqlTxWrapper struct {
	tx *sql.Tx
}

// ExecContext delegates to (*sql.Tx).ExecContext.
func (t *sqlTxWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	return t.tx.ExecContext(ctx, query, args...)
}

// Commit delegates to (*sql.Tx).Commit.
func (t *sqlTxWrapper) Commit() error {
	return t.tx.Commit()
}

// Rollback delegates to (*sql.Tx).Rollback.
func (t *sqlTxWrapper) Rollback() error {
	return t.tx.Rollback()
}
|
|
|
|
var openPostgresDB = func(dsn string) (postgresDB, error) {
|
|
db, err := sql.Open("postgres", dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &sqlDBWrapper{db: db}, nil
|
|
}
|
|
|
|
type PostgresSink struct {
|
|
name string
|
|
db postgresDB
|
|
schema postgresSchemaCompiled
|
|
pruneWindow time.Duration
|
|
}
|
|
|
|
func NewPostgresSinkFromConfig(cfg config.SinkConfig) (Sink, error) {
|
|
uri, err := requireStringParam(cfg, "uri")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
username, err := requireStringParam(cfg, "username")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
password, err := requireStringParam(cfg, "password")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pruneWindow, err := parsePostgresPruneWindow(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
schema, ok := lookupPostgresSchema(cfg.Name)
|
|
if !ok {
|
|
return nil, fmt.Errorf("postgres sink %q: no schema registered (call sinks.RegisterPostgresSchema before building sinks)", cfg.Name)
|
|
}
|
|
|
|
dsn, err := buildPostgresDSN(uri, username, password)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err)
|
|
}
|
|
|
|
db, err := openPostgresDB(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
|
|
}
|
|
|
|
s := &PostgresSink{name: cfg.Name, db: db, schema: schema, pruneWindow: pruneWindow}
|
|
if err := s.initialize(); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// Name returns the sink name taken from its configuration.
func (p *PostgresSink) Name() string {
	return p.name
}
|
|
|
|
func (p *PostgresSink) Consume(ctx context.Context, e event.Event) error {
|
|
// Boundary validation: if something upstream violated invariants,
|
|
// surface it loudly rather than writing corrupt rows.
|
|
if err := e.Validate(); err != nil {
|
|
return fmt.Errorf("postgres sink: invalid event: %w", err)
|
|
}
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
writes, err := p.schema.mapEvent(ctx, e)
|
|
if err != nil {
|
|
return fmt.Errorf("postgres sink: map event: %w", err)
|
|
}
|
|
if len(writes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
tx, err := p.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("postgres sink: begin tx: %w", err)
|
|
}
|
|
|
|
committed := false
|
|
defer func() {
|
|
if !committed {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
for _, w := range writes {
|
|
tbl, err := p.schema.validateWrite(w)
|
|
if err != nil {
|
|
return fmt.Errorf("postgres sink: %w", err)
|
|
}
|
|
|
|
query, args, err := buildInsertSQL(tbl, w)
|
|
if err != nil {
|
|
return fmt.Errorf("postgres sink: build insert for table %q: %w", tbl.name, err)
|
|
}
|
|
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
|
return fmt.Errorf("postgres sink: insert into %q: %w", tbl.name, err)
|
|
}
|
|
}
|
|
if p.pruneWindow > 0 {
|
|
cutoff := time.Now().UTC().Add(-p.pruneWindow)
|
|
for _, tableName := range p.schema.tableOrder {
|
|
tbl := p.schema.tables[tableName]
|
|
if _, err := execPruneOlderThan(ctx, tx, tbl, cutoff); err != nil {
|
|
return fmt.Errorf("postgres sink: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("postgres sink: commit tx: %w", err)
|
|
}
|
|
committed = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *PostgresSink) PruneKeepLatest(ctx context.Context, table string, keep int) (int64, error) {
|
|
if keep < 0 {
|
|
return 0, fmt.Errorf("postgres sink: keep must be >= 0")
|
|
}
|
|
tbl, err := p.lookupTable(table)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
`DELETE FROM %s WHERE ctid IN (
|
|
SELECT ctid FROM %s
|
|
ORDER BY %s DESC
|
|
OFFSET $1
|
|
)`,
|
|
quotePostgresIdent(tbl.name),
|
|
quotePostgresIdent(tbl.name),
|
|
quotePostgresIdent(tbl.pruneColumn),
|
|
)
|
|
|
|
res, err := p.db.ExecContext(ctx, query, keep)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("postgres sink: prune keep latest table %q: %w", tbl.name, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("postgres sink: prune keep latest table %q rows affected: %w", tbl.name, err)
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func (p *PostgresSink) PruneOlderThan(ctx context.Context, table string, cutoff time.Time) (int64, error) {
|
|
tbl, err := p.lookupTable(table)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
rows, err := execPruneOlderThan(ctx, p.db, tbl, cutoff)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("postgres sink: %w", err)
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func (p *PostgresSink) PruneAllKeepLatest(ctx context.Context, keep int) (map[string]int64, error) {
|
|
counts := make(map[string]int64, len(p.schema.tableOrder))
|
|
for _, table := range p.schema.tableOrder {
|
|
n, err := p.PruneKeepLatest(ctx, table, keep)
|
|
if err != nil {
|
|
return counts, err
|
|
}
|
|
counts[table] = n
|
|
}
|
|
return counts, nil
|
|
}
|
|
|
|
func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) (map[string]int64, error) {
|
|
counts := make(map[string]int64, len(p.schema.tableOrder))
|
|
for _, table := range p.schema.tableOrder {
|
|
n, err := p.PruneOlderThan(ctx, table, cutoff)
|
|
if err != nil {
|
|
return counts, err
|
|
}
|
|
counts[table] = n
|
|
}
|
|
return counts, nil
|
|
}
|
|
|
|
func (p *PostgresSink) initialize() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout)
|
|
defer cancel()
|
|
|
|
if err := p.db.PingContext(ctx); err != nil {
|
|
return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err)
|
|
}
|
|
|
|
for _, tableName := range p.schema.tableOrder {
|
|
tbl := p.schema.tables[tableName]
|
|
|
|
createTableSQL := buildCreateTableSQL(tbl)
|
|
if _, err := p.db.ExecContext(ctx, createTableSQL); err != nil {
|
|
return fmt.Errorf("postgres sink %q: ensure table %q: %w", p.name, tbl.name, err)
|
|
}
|
|
|
|
for _, idx := range tbl.indexes {
|
|
createIndexSQL := buildCreateIndexSQL(tbl.name, idx)
|
|
if _, err := p.db.ExecContext(ctx, createIndexSQL); err != nil {
|
|
return fmt.Errorf("postgres sink %q: ensure index %q on %q: %w", p.name, idx.Name, tbl.name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) {
|
|
table = strings.TrimSpace(table)
|
|
if table == "" {
|
|
return postgresTableCompiled{}, fmt.Errorf("postgres sink: table cannot be empty")
|
|
}
|
|
tbl, ok := p.schema.tables[table]
|
|
if !ok {
|
|
return postgresTableCompiled{}, fmt.Errorf("postgres sink: unknown table %q", table)
|
|
}
|
|
return tbl, nil
|
|
}
|
|
|
|
// buildPostgresDSN turns a connection URI into a URL-style lib/pq DSN with
// username and password set as percent-escaped userinfo. Any userinfo already
// present in uri is replaced. Surrounding whitespace in uri is ignored.
//
// Only the postgres:// and postgresql:// schemes are accepted. lib/pq parses
// a DSN as a URL only when it has one of those prefixes and otherwise treats
// it as a key=value connection string. Any other scheme would therefore fail
// later, at connect time, with a confusing driver error instead of here.
func buildPostgresDSN(uri, username, password string) (string, error) {
	u, err := url.Parse(strings.TrimSpace(uri))
	if err != nil {
		return "", fmt.Errorf("invalid uri: %w", err)
	}
	// url.Parse lowercases the scheme, so the comparisons below are exact.
	switch u.Scheme {
	case "":
		return "", fmt.Errorf("invalid uri: missing scheme")
	case "postgres", "postgresql":
		// Supported by lib/pq's URL parser.
	default:
		return "", fmt.Errorf("invalid uri: unsupported scheme %q (want postgres or postgresql)", u.Scheme)
	}
	if u.Host == "" {
		return "", fmt.Errorf("invalid uri: missing host")
	}
	u.User = url.UserPassword(username, password)
	return u.String(), nil
}
|
|
|
|
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
|
|
raw, ok := cfg.Params["prune"]
|
|
if !ok || raw == nil {
|
|
return 0, nil
|
|
}
|
|
|
|
s, ok := raw.(string)
|
|
if !ok {
|
|
return 0, fmt.Errorf("sink %q: params.prune must be a string duration (e.g. 72h, 3d, 2w)", cfg.Name)
|
|
}
|
|
|
|
d, err := parsePostgresPruneDuration(s)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("sink %q: params.prune %q is invalid: %w", cfg.Name, s, err)
|
|
}
|
|
return d, nil
|
|
}
|
|
|
|
// parsePostgresPruneDuration parses a strictly positive retention window.
//
// It accepts anything time.ParseDuration accepts (e.g. "72h", "90m"). It also
// accepts an integer followed by "d" (24h days) or "w" (7-day weeks), in any
// case and with optional whitespace (e.g. "3d", " 2W ", "3 d").
//
// It rejects empty input, zero and negative values, and day/week counts whose
// duration would overflow time.Duration (about 292 years).
func parsePostgresPruneDuration(raw string) (time.Duration, error) {
	s := strings.TrimSpace(raw)
	if s == "" {
		return 0, fmt.Errorf("must not be empty")
	}

	lower := strings.ToLower(s)
	if strings.HasSuffix(lower, "d") || strings.HasSuffix(lower, "w") {
		unit := lower[len(lower)-1]
		n, err := strconv.Atoi(strings.TrimSpace(lower[:len(lower)-1]))
		if err != nil {
			return 0, fmt.Errorf("must use a positive integer before %q", string(unit))
		}
		if n <= 0 {
			return 0, fmt.Errorf("must be > 0")
		}

		per := 24 * time.Hour
		if unit == 'w' {
			per *= 7
		}
		// time.Duration is int64 nanoseconds. Without this check the
		// multiplication silently wraps ("20000w" becomes negative), turning a
		// typo into a wrong retention window instead of a config error.
		const maxDuration = time.Duration(1<<63 - 1)
		if limit := maxDuration / per; time.Duration(n) > limit {
			return 0, fmt.Errorf("must be at most %d%s", int64(limit), string(unit))
		}
		return time.Duration(n) * per, nil
	}

	d, err := time.ParseDuration(s)
	if err != nil {
		return 0, fmt.Errorf("must be a Go duration or use d/w suffixes")
	}
	if d <= 0 {
		return 0, fmt.Errorf("must be > 0")
	}
	return d, nil
}
|
|
|
|
func buildPruneOlderThanSQL(tbl postgresTableCompiled) string {
|
|
return fmt.Sprintf(
|
|
`DELETE FROM %s WHERE %s < $1`,
|
|
quotePostgresIdent(tbl.name),
|
|
quotePostgresIdent(tbl.pruneColumn),
|
|
)
|
|
}
|
|
|
|
func execPruneOlderThan(ctx context.Context, execer postgresExecer, tbl postgresTableCompiled, cutoff time.Time) (int64, error) {
|
|
query := buildPruneOlderThanSQL(tbl)
|
|
res, err := execer.ExecContext(ctx, query, cutoff)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("prune older than table %q: %w", tbl.name, err)
|
|
}
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("prune older than table %q rows affected: %w", tbl.name, err)
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
func buildCreateTableSQL(tbl postgresTableCompiled) string {
|
|
defs := make([]string, 0, len(tbl.columnOrder)+1)
|
|
for _, colName := range tbl.columnOrder {
|
|
col := tbl.columns[colName]
|
|
def := fmt.Sprintf("%s %s", quotePostgresIdent(col.Name), col.Type)
|
|
if !col.Nullable {
|
|
def += " NOT NULL"
|
|
}
|
|
defs = append(defs, def)
|
|
}
|
|
if len(tbl.primaryKey) > 0 {
|
|
defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", joinQuotedIdents(tbl.primaryKey)))
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"CREATE TABLE IF NOT EXISTS %s (%s)",
|
|
quotePostgresIdent(tbl.name),
|
|
strings.Join(defs, ", "),
|
|
)
|
|
}
|
|
|
|
func buildCreateIndexSQL(tableName string, idx PostgresIndex) string {
|
|
unique := ""
|
|
if idx.Unique {
|
|
unique = "UNIQUE "
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"CREATE %sINDEX IF NOT EXISTS %s ON %s (%s)",
|
|
unique,
|
|
quotePostgresIdent(idx.Name),
|
|
quotePostgresIdent(tableName),
|
|
joinQuotedIdents(idx.Columns),
|
|
)
|
|
}
|
|
|
|
func buildInsertSQL(tbl postgresTableCompiled, w PostgresWrite) (string, []any, error) {
|
|
cols := make([]string, 0, len(tbl.columnOrder))
|
|
args := make([]any, 0, len(tbl.columnOrder))
|
|
placeholders := make([]string, 0, len(tbl.columnOrder))
|
|
|
|
for i, colName := range tbl.columnOrder {
|
|
v, ok := w.Values[colName]
|
|
if !ok {
|
|
return "", nil, fmt.Errorf("missing value for column %q", colName)
|
|
}
|
|
cols = append(cols, quotePostgresIdent(colName))
|
|
args = append(args, v)
|
|
placeholders = append(placeholders, "$"+strconv.Itoa(i+1))
|
|
}
|
|
|
|
q := fmt.Sprintf(
|
|
"INSERT INTO %s (%s) VALUES (%s)",
|
|
quotePostgresIdent(tbl.name),
|
|
strings.Join(cols, ", "),
|
|
strings.Join(placeholders, ", "),
|
|
)
|
|
return q, args, nil
|
|
}
|
|
|
|
// joinQuotedIdents quotes each identifier with quotePostgresIdent and joins
// them with ", ". An empty or nil slice yields "".
func joinQuotedIdents(idents []string) string {
	var b strings.Builder
	for i, ident := range idents {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(quotePostgresIdent(ident))
	}
	return b.String()
}

// quotePostgresIdent wraps s in double quotes and doubles any embedded double
// quote, producing a PostgreSQL delimited identifier.
func quotePostgresIdent(s string) string {
	const q = `"`
	return q + strings.ReplaceAll(s, q, q+q) + q
}
|