Added support for Postgres polling sources
This commit is contained in:
67
internal/postgres/postgres.go
Normal file
67
internal/postgres/postgres.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const initTimeout = 5 * time.Second
|
||||
|
||||
var (
|
||||
sqlOpen = sql.Open
|
||||
pingDB = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) }
|
||||
)
|
||||
|
||||
// ConnConfig describes the minimal connection settings shared by feedkit's
|
||||
// Postgres readers and writers.
|
||||
type ConnConfig struct {
|
||||
URI string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// BuildDSN validates a Postgres URI and injects credentials into it.
|
||||
func BuildDSN(cfg ConnConfig) (string, error) {
|
||||
u, err := url.Parse(strings.TrimSpace(cfg.URI))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uri: %w", err)
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing scheme")
|
||||
}
|
||||
if u.Host == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing host")
|
||||
}
|
||||
u.User = url.UserPassword(cfg.Username, cfg.Password)
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// Open builds a DSN, opens a database handle, and verifies connectivity with a
|
||||
// bounded ping before returning the handle.
|
||||
func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) {
|
||||
dsn, err := BuildDSN(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := sqlOpen("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, initTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := pingDB(pingCtx, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
165
internal/postgres/postgres_test.go
Normal file
165
internal/postgres/postgres_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func withPostgresPackageTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldSQLOpen := sqlOpen
|
||||
oldPingDB := pingDB
|
||||
t.Cleanup(func() {
|
||||
sqlOpen = oldSQLOpen
|
||||
pingDB = oldPingDB
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildDSNInjectsCredentials(t *testing.T) {
|
||||
dsn, err := BuildDSN(ConnConfig{
|
||||
URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ",
|
||||
Username: "app_user",
|
||||
Password: "app_pass",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildDSN() error = %v", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
if u.User == nil || u.User.Username() != "app_user" {
|
||||
t.Fatalf("username = %q, want app_user", u.User.Username())
|
||||
}
|
||||
pass, ok := u.User.Password()
|
||||
if !ok || pass != "app_pass" {
|
||||
t.Fatalf("password = %q, want app_pass", pass)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsInvalidURI(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid uri") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsMissingScheme(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing scheme") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDSNRejectsMissingHost(t *testing.T) {
|
||||
_, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"})
|
||||
if err == nil {
|
||||
t.Fatalf("BuildDSN() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "missing host") {
|
||||
t.Fatalf("BuildDSN() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPropagatesOpenFailure(t *testing.T) {
|
||||
withPostgresPackageTestState(t)
|
||||
|
||||
sqlOpen = func(_, _ string) (*sql.DB, error) {
|
||||
return nil, errors.New("open failed")
|
||||
}
|
||||
|
||||
_, err := Open(context.Background(), ConnConfig{
|
||||
URI: "postgres://db.example.local/feedkit",
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Open() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "open failed") {
|
||||
t.Fatalf("Open() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenPropagatesPingFailure(t *testing.T) {
|
||||
withPostgresPackageTestState(t)
|
||||
|
||||
const driverName = "feedkit_internal_postgres_ping_fail"
|
||||
registerPingTestDriver(driverName, errors.New("ping failed"))
|
||||
|
||||
sqlOpen = func(_, _ string) (*sql.DB, error) {
|
||||
return sql.Open(driverName, "")
|
||||
}
|
||||
|
||||
_, err := Open(context.Background(), ConnConfig{
|
||||
URI: "postgres://db.example.local/feedkit",
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("Open() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ping failed") {
|
||||
t.Fatalf("Open() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
pingDriverMu sync.Mutex
|
||||
pingDriverSeen = map[string]bool{}
|
||||
)
|
||||
|
||||
func registerPingTestDriver(name string, pingErr error) {
|
||||
pingDriverMu.Lock()
|
||||
defer pingDriverMu.Unlock()
|
||||
|
||||
if pingDriverSeen[name] {
|
||||
return
|
||||
}
|
||||
sql.Register(name, &pingTestDriver{pingErr: pingErr})
|
||||
pingDriverSeen[name] = true
|
||||
}
|
||||
|
||||
type pingTestDriver struct {
|
||||
pingErr error
|
||||
}
|
||||
|
||||
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
|
||||
return &pingTestConn{pingErr: d.pingErr}, nil
|
||||
}
|
||||
|
||||
type pingTestConn struct {
|
||||
pingErr error
|
||||
}
|
||||
|
||||
func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (c *pingTestConn) Close() error { return nil }
|
||||
func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
|
||||
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }
|
||||
|
||||
func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
|
||||
return &pingTestRows{}, nil
|
||||
}
|
||||
|
||||
type pingTestRows struct{}
|
||||
|
||||
func (r *pingTestRows) Columns() []string { return []string{"ok"} }
|
||||
func (r *pingTestRows) Close() error { return nil }
|
||||
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }
|
||||
@@ -4,18 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
_ "github.com/lib/pq"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
const postgresInitTimeout = 5 * time.Second
|
||||
|
||||
type postgresTx interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
Commit() error
|
||||
@@ -73,8 +70,8 @@ func (w *sqlTxWrapper) Rollback() error {
|
||||
return w.tx.Rollback()
|
||||
}
|
||||
|
||||
var openPostgresDB = func(dsn string) (postgresDB, error) {
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
db, err := pgconn.Open(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -111,12 +108,11 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema)
|
||||
return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
dsn, err := buildPostgresDSN(uri, username, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err)
|
||||
}
|
||||
|
||||
db, err := openPostgresDB(dsn)
|
||||
db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{
|
||||
URI: uri,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err)
|
||||
}
|
||||
@@ -264,13 +260,9 @@ func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time)
|
||||
}
|
||||
|
||||
func (p *PostgresSink) initialize() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := p.db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err)
|
||||
}
|
||||
|
||||
for _, tableName := range p.schema.tableOrder {
|
||||
tbl := p.schema.tables[tableName]
|
||||
|
||||
@@ -302,21 +294,6 @@ func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error)
|
||||
return tbl, nil
|
||||
}
|
||||
|
||||
func buildPostgresDSN(uri, username, password string) (string, error) {
|
||||
u, err := url.Parse(strings.TrimSpace(uri))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid uri: %w", err)
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing scheme")
|
||||
}
|
||||
if u.Host == "" {
|
||||
return "", fmt.Errorf("invalid uri: missing host")
|
||||
}
|
||||
u.User = url.UserPassword(username, password)
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) {
|
||||
raw, ok := cfg.Params["prune"]
|
||||
if !ok || raw == nil {
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type fakeResult struct {
|
||||
@@ -223,10 +223,13 @@ func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
dbs := []*fakeDB{{}, {}}
|
||||
var gotDSNs []string
|
||||
openPostgresDB = func(dsn string) (postgresDB, error) {
|
||||
gotDSNs = append(gotDSNs, dsn)
|
||||
db := dbs[len(gotDSNs)-1]
|
||||
var gotCfgs []pgconn.ConnConfig
|
||||
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
gotCfgs = append(gotCfgs, cfg)
|
||||
db := dbs[len(gotCfgs)-1]
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -252,8 +255,11 @@ func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(gotDSNs) != 2 {
|
||||
t.Fatalf("len(gotDSNs) = %d, want 2", len(gotDSNs))
|
||||
if len(gotCfgs) != 2 {
|
||||
t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs))
|
||||
}
|
||||
if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" {
|
||||
t.Fatalf("first ConnConfig = %+v", gotCfgs[0])
|
||||
}
|
||||
for i, db := range dbs {
|
||||
if db.pingCalls != 1 {
|
||||
@@ -327,9 +333,12 @@ func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
db := &fakeDB{}
|
||||
var gotDSN string
|
||||
openPostgresDB = func(dsn string) (postgresDB, error) {
|
||||
gotDSN = dsn
|
||||
var gotCfg pgconn.ConnConfig
|
||||
openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) {
|
||||
gotCfg = cfg
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -362,16 +371,14 @@ func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) {
|
||||
t.Fatalf("unexpected create index query: %s", db.execCalls[1].query)
|
||||
}
|
||||
|
||||
u, err := url.Parse(gotDSN)
|
||||
if err != nil {
|
||||
t.Fatalf("parse dsn: %v", err)
|
||||
if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" {
|
||||
t.Fatalf("URI = %q", gotCfg.URI)
|
||||
}
|
||||
if u.User == nil || u.User.Username() != "app_user" {
|
||||
t.Fatalf("dsn missing username: %q", gotDSN)
|
||||
if gotCfg.Username != "app_user" {
|
||||
t.Fatalf("Username = %q, want app_user", gotCfg.Username)
|
||||
}
|
||||
pass, ok := u.User.Password()
|
||||
if !ok || pass != "app_pass" {
|
||||
t.Fatalf("dsn missing password: %q", gotDSN)
|
||||
if gotCfg.Password != "app_pass" {
|
||||
t.Fatalf("Password = %q, want app_pass", gotCfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,7 +386,7 @@ func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")}
|
||||
openPostgresDB = func(_ string) (postgresDB, error) {
|
||||
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -415,7 +422,7 @@ func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
withPostgresTestState(t)
|
||||
|
||||
openPostgresDB = func(_ string) (postgresDB, error) {
|
||||
openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) {
|
||||
return &fakeDB{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
// stream exit classification helpers
|
||||
// - Registry / NewRegistry: source driver registry and builders
|
||||
// - HTTPSource / NewHTTPSource: reusable HTTP polling helper
|
||||
// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling
|
||||
// helper
|
||||
//
|
||||
// Source drivers are domain-specific and registered into Registry by driver name.
|
||||
// Registry can then build configured sources from config.SourceConfig.
|
||||
@@ -34,4 +36,17 @@
|
||||
// When validators are available, NewHTTPSource prefers ETag/If-None-Match and
|
||||
// falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is
|
||||
// treated as a successful unchanged poll.
|
||||
//
|
||||
// Postgres-backed polling sources can share NewPostgresQuerySource for generic
|
||||
// DB config parsing and query execution. The helper understands:
|
||||
// - params.uri
|
||||
// - params.username
|
||||
// - params.password
|
||||
// - params.query
|
||||
// - params.query_timeout (optional, default 30s)
|
||||
//
|
||||
// feedkit does not register a built-in postgres poll driver. Downstream daemons
|
||||
// should register domain-specific driver names that call
|
||||
// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering,
|
||||
// watermark policy, and event construction in their own source types.
|
||||
package sources
|
||||
|
||||
117
sources/postgres.go
Normal file
117
sources/postgres.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
const defaultPostgresQueryTimeout = 30 * time.Second
|
||||
|
||||
type postgresQueryDB interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return pgconn.Open(ctx, cfg)
|
||||
}
|
||||
|
||||
// PostgresQuerySource is a reusable helper for polling Postgres-backed sources.
|
||||
//
|
||||
// It centralizes generic source config parsing and query execution. Concrete
|
||||
// daemon sources remain responsible for SQL semantics, row scanning, cursoring,
|
||||
// and event construction.
|
||||
type PostgresQuerySource struct {
|
||||
Driver string
|
||||
Name string
|
||||
SQL string
|
||||
QueryTimeout time.Duration
|
||||
|
||||
db postgresQueryDB
|
||||
}
|
||||
|
||||
// NewPostgresQuerySource builds a generic Postgres polling helper from
|
||||
// SourceConfig.
|
||||
//
|
||||
// Required params:
|
||||
// - params.uri
|
||||
// - params.username
|
||||
// - params.password
|
||||
// - params.query
|
||||
//
|
||||
// Optional params:
|
||||
// - params.query_timeout (default 30s)
|
||||
func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) {
|
||||
name := strings.TrimSpace(cfg.Name)
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("%s: name is required", driver)
|
||||
}
|
||||
if cfg.Params == nil {
|
||||
return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name)
|
||||
}
|
||||
|
||||
uri, ok := cfg.ParamString("uri")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name)
|
||||
}
|
||||
username, ok := cfg.ParamString("username")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name)
|
||||
}
|
||||
password, ok := cfg.ParamString("password")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name)
|
||||
}
|
||||
query, ok := cfg.ParamString("query")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name)
|
||||
}
|
||||
|
||||
queryTimeout := defaultPostgresQueryTimeout
|
||||
if _, exists := cfg.Params["query_timeout"]; exists {
|
||||
var ok bool
|
||||
queryTimeout, ok = cfg.ParamDuration("query_timeout")
|
||||
if !ok || queryTimeout <= 0 {
|
||||
return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{
|
||||
URI: uri,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err)
|
||||
}
|
||||
|
||||
return &PostgresQuerySource{
|
||||
Driver: driver,
|
||||
Name: name,
|
||||
SQL: query,
|
||||
QueryTimeout: queryTimeout,
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) {
|
||||
queryCtx := ctx
|
||||
if s.QueryTimeout > 0 {
|
||||
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout {
|
||||
// We intentionally do not cancel this derived context here because the
|
||||
// returned rows may still be reading from the database.
|
||||
queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(queryCtx, s.SQL, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
352
sources/postgres_test.go
Normal file
352
sources/postgres_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type fakePostgresQueryDB struct {
|
||||
queryErr error
|
||||
lastCtx context.Context
|
||||
lastQuery string
|
||||
lastArgs []any
|
||||
returnRows *sql.Rows
|
||||
}
|
||||
|
||||
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
db.lastCtx = ctx
|
||||
db.lastQuery = query
|
||||
db.lastArgs = append([]any(nil), args...)
|
||||
if db.queryErr != nil {
|
||||
return nil, db.queryErr
|
||||
}
|
||||
return db.returnRows, nil
|
||||
}
|
||||
|
||||
func withPostgresQuerySourceTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldOpen := openPostgresQueryDB
|
||||
t.Cleanup(func() {
|
||||
openPostgresQueryDB = oldOpen
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]any
|
||||
want string
|
||||
}{
|
||||
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
|
||||
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
|
||||
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
|
||||
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: tc.params,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
"query_timeout": "soon",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db := &fakePostgresQueryDB{}
|
||||
var gotCfg pgconn.ConnConfig
|
||||
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
gotCfg = cfg
|
||||
return db, nil
|
||||
}
|
||||
|
||||
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://db.example.local/feedkit",
|
||||
"username": "app_user",
|
||||
"password": "app_pass",
|
||||
"query": "SELECT * FROM observations",
|
||||
"query_timeout": "45s",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
if src.Name != "pg-source" {
|
||||
t.Fatalf("Name = %q, want pg-source", src.Name)
|
||||
}
|
||||
if src.QueryTimeout != 45*time.Second {
|
||||
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
|
||||
}
|
||||
if src.SQL != "SELECT * FROM observations" {
|
||||
t.Fatalf("SQL = %q", src.SQL)
|
||||
}
|
||||
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
|
||||
t.Fatalf("ConnConfig = %+v", gotCfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return nil, errors.New("db unavailable")
|
||||
}
|
||||
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := src.Query(ctx, "arg1")
|
||||
if err == nil {
|
||||
t.Fatalf("Query() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
|
||||
t.Fatalf("Query() error = %q", err)
|
||||
}
|
||||
if db.lastCtx == nil {
|
||||
t.Fatalf("lastCtx = nil")
|
||||
}
|
||||
if _, ok := db.lastCtx.Deadline(); !ok {
|
||||
t.Fatalf("expected derived deadline on query context")
|
||||
}
|
||||
if db.lastQuery != "SELECT 1" {
|
||||
t.Fatalf("lastQuery = %q", db.lastQuery)
|
||||
}
|
||||
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
|
||||
t.Fatalf("lastArgs = %#v", db.lastArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _ = src.Query(ctx)
|
||||
if db.lastCtx != ctx {
|
||||
t.Fatalf("expected source to reuse earlier caller deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
|
||||
defer cleanup()
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
type fakeDownstreamSource struct {
|
||||
pg *PostgresQuerySource
|
||||
}
|
||||
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
|
||||
rows, err := s.pg.Query(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []event.Event
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
if err := rows.Scan(&eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, event.Event{
|
||||
ID: eventID,
|
||||
Kind: event.Kind("observation"),
|
||||
Source: s.pg.Name,
|
||||
Schema: "raw.test.v1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"event_id": eventID},
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT event_id FROM events",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
|
||||
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("poll() error = %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("len(events) = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].ID != "evt-1" {
|
||||
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
rowsDriverMu sync.Mutex
|
||||
rowsDriverSeen = map[string]bool{}
|
||||
)
|
||||
|
||||
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
|
||||
t.Helper()
|
||||
|
||||
rowsDriverMu.Lock()
|
||||
if !rowsDriverSeen[driverName] {
|
||||
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
|
||||
rowsDriverSeen[driverName] = true
|
||||
}
|
||||
rowsDriverMu.Unlock()
|
||||
|
||||
db, err := sql.Open(driverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open() error = %v", err)
|
||||
}
|
||||
|
||||
return db, func() {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
|
||||
out := make([][]driver.Value, 0, len(in))
|
||||
for _, row := range in {
|
||||
copied := append([]driver.Value(nil), row...)
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type rowsTestDriver struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
|
||||
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestConn struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (c *rowsTestConn) Close() error { return nil }
|
||||
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
idx int
|
||||
}
|
||||
|
||||
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
|
||||
func (r *rowsTestRows) Close() error { return nil }
|
||||
|
||||
func (r *rowsTestRows) Next(dest []driver.Value) error {
|
||||
if r.idx >= len(r.rows) {
|
||||
return io.EOF
|
||||
}
|
||||
copy(dest, r.rows[r.idx])
|
||||
r.idx++
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user