diff --git a/internal/postgres/postgres.go b/internal/postgres/postgres.go new file mode 100644 index 0000000..5aced41 --- /dev/null +++ b/internal/postgres/postgres.go @@ -0,0 +1,67 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "net/url" + "strings" + "time" + + _ "github.com/lib/pq" +) + +const initTimeout = 5 * time.Second + +var ( + sqlOpen = sql.Open + pingDB = func(ctx context.Context, db *sql.DB) error { return db.PingContext(ctx) } +) + +// ConnConfig describes the minimal connection settings shared by feedkit's +// Postgres readers and writers. +type ConnConfig struct { + URI string + Username string + Password string +} + +// BuildDSN validates a Postgres URI and injects credentials into it. +func BuildDSN(cfg ConnConfig) (string, error) { + u, err := url.Parse(strings.TrimSpace(cfg.URI)) + if err != nil { + return "", fmt.Errorf("invalid uri: %w", err) + } + if u.Scheme == "" { + return "", fmt.Errorf("invalid uri: missing scheme") + } + if u.Host == "" { + return "", fmt.Errorf("invalid uri: missing host") + } + u.User = url.UserPassword(cfg.Username, cfg.Password) + return u.String(), nil +} + +// Open builds a DSN, opens a database handle, and verifies connectivity with a +// bounded ping before returning the handle. +func Open(ctx context.Context, cfg ConnConfig) (*sql.DB, error) { + dsn, err := BuildDSN(cfg) + if err != nil { + return nil, err + } + + db, err := sqlOpen("postgres", dsn) + if err != nil { + return nil, err + } + + pingCtx, cancel := context.WithTimeout(ctx, initTimeout) + defer cancel() + + if err := pingDB(pingCtx, db); err != nil { + _ = db.Close() + return nil, err + } + + return db, nil +} diff --git a/internal/postgres/postgres_test.go b/internal/postgres/postgres_test.go new file mode 100644 index 0000000..d83cedf --- /dev/null +++ b/internal/postgres/postgres_test.go @@ -0,0 +1,165 @@ +package postgres + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "net/url" + "strings" + "sync" + "testing" +) + +func withPostgresPackageTestState(t *testing.T) { + t.Helper() + + oldSQLOpen := sqlOpen + oldPingDB := pingDB + t.Cleanup(func() { + sqlOpen = oldSQLOpen + pingDB = oldPingDB + }) +} + +func TestBuildDSNInjectsCredentials(t *testing.T) { + dsn, err := BuildDSN(ConnConfig{ + URI: " postgres://db.example.local:5432/feedkit?sslmode=disable ", + Username: "app_user", + Password: "app_pass", + }) + if err != nil { + t.Fatalf("BuildDSN() error = %v", err) + } + + u, err := url.Parse(dsn) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + if u.User == nil || u.User.Username() != "app_user" { + t.Fatalf("username = %q, want app_user", u.User.Username()) + } + pass, ok := u.User.Password() + if !ok || pass != "app_pass" { + t.Fatalf("password = %q, want app_pass", pass) + } +} + +func TestBuildDSNRejectsInvalidURI(t *testing.T) { + _, err := BuildDSN(ConnConfig{URI: "http://[::1", Username: "u", Password: "p"}) + if err == nil { + t.Fatalf("BuildDSN() error = nil, want error") + } + if !strings.Contains(err.Error(), "invalid uri") { + t.Fatalf("BuildDSN() error = %q", err) + } +} + +func TestBuildDSNRejectsMissingScheme(t *testing.T) { + _, err := BuildDSN(ConnConfig{URI: "//db.example.local/feedkit", Username: "u", Password: "p"}) + if err == nil { + t.Fatalf("BuildDSN() error = nil, want error") + } + if !strings.Contains(err.Error(), "missing scheme") { + t.Fatalf("BuildDSN() error = %q", err) + } +} + +func TestBuildDSNRejectsMissingHost(t *testing.T) { + _, err := BuildDSN(ConnConfig{URI: "postgres:///feedkit", Username: "u", Password: "p"}) + if err == nil { + t.Fatalf("BuildDSN() error = nil, want error") + } + if !strings.Contains(err.Error(), "missing host") { + t.Fatalf("BuildDSN() error = %q", err) + } +} + +func TestOpenPropagatesOpenFailure(t *testing.T) { + withPostgresPackageTestState(t) + + sqlOpen = func(_, _ string) (*sql.DB, error) { + return nil, errors.New("open failed") + } + + _, err := Open(context.Background(), ConnConfig{ + URI: "postgres://db.example.local/feedkit", + Username: "u", + Password: "p", + }) + if err == nil { + t.Fatalf("Open() error = nil, want error") + } + if !strings.Contains(err.Error(), "open failed") { + t.Fatalf("Open() error = %q", err) + } +} + +func TestOpenPropagatesPingFailure(t *testing.T) { + withPostgresPackageTestState(t) + + const driverName = "feedkit_internal_postgres_ping_fail" + registerPingTestDriver(driverName, errors.New("ping failed")) + + sqlOpen = func(_, _ string) (*sql.DB, error) { + return sql.Open(driverName, "") + } + + _, err := Open(context.Background(), ConnConfig{ + URI: "postgres://db.example.local/feedkit", + Username: "u", + Password: "p", + }) + if err == nil { + t.Fatalf("Open() error = nil, want error") + } + if !strings.Contains(err.Error(), "ping failed") { + t.Fatalf("Open() error = %q", err) + } +} + +var ( + pingDriverMu sync.Mutex + pingDriverSeen = map[string]bool{} +) + +func registerPingTestDriver(name string, pingErr error) { + pingDriverMu.Lock() + defer pingDriverMu.Unlock() + + if pingDriverSeen[name] { + return + } + sql.Register(name, &pingTestDriver{pingErr: pingErr}) + pingDriverSeen[name] = true +} + +type pingTestDriver struct { + pingErr error +} + +func (d *pingTestDriver) Open(string) (driver.Conn, error) { + return &pingTestConn{pingErr: d.pingErr}, nil +} + +type pingTestConn struct { + pingErr error +} + +func (c *pingTestConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("not implemented") +} +func (c *pingTestConn) Close() error { return nil } +func (c *pingTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") } +func (c *pingTestConn) Ping(context.Context) error { return c.pingErr } + +func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) { + return &pingTestRows{}, nil +} + +type pingTestRows struct{} + +func (r *pingTestRows) Columns() []string { return []string{"ok"} } +func (r *pingTestRows) Close() error { return nil } +func (r *pingTestRows) Next([]driver.Value) error { return io.EOF } diff --git a/sinks/postgres.go b/sinks/postgres.go index 5a0ea4a..c9f4461 100644 --- a/sinks/postgres.go +++ b/sinks/postgres.go @@ -4,18 +4,15 @@ import ( "context" "database/sql" "fmt" - "net/url" "strconv" "strings" "time" "gitea.maximumdirect.net/ejr/feedkit/config" "gitea.maximumdirect.net/ejr/feedkit/event" - _ "github.com/lib/pq" + pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres" ) -const postgresInitTimeout = 5 * time.Second - type postgresTx interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) Commit() error @@ -73,8 +70,8 @@ func (w *sqlTxWrapper) Rollback() error { return w.tx.Rollback() } -var openPostgresDB = func(dsn string) (postgresDB, error) { - db, err := sql.Open("postgres", dsn) +var openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) { + db, err := pgconn.Open(ctx, cfg) if err != nil { return nil, err } @@ -111,12 +108,11 @@ func NewPostgresSinkFromConfig(cfg config.SinkConfig, schemaDef PostgresSchema) return nil, fmt.Errorf("postgres sink %q: compile schema: %w", cfg.Name, err) } - dsn, err := buildPostgresDSN(uri, username, password) - if err != nil { - return nil, fmt.Errorf("postgres sink %q: build dsn: %w", cfg.Name, err) - } - - db, err := openPostgresDB(dsn) + db, err := openPostgresDB(context.Background(), pgconn.ConnConfig{ + URI: uri, + Username: username, + Password: password, + }) if err != nil { return nil, fmt.Errorf("postgres sink %q: open db: %w", cfg.Name, err) } @@ -264,13 +260,9 @@ func (p *PostgresSink) PruneAllOlderThan(ctx context.Context, cutoff time.Time) } func (p *PostgresSink) initialize() error { - ctx, cancel := context.WithTimeout(context.Background(), postgresInitTimeout) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if err := p.db.PingContext(ctx); err != nil { - return fmt.Errorf("postgres sink %q: ping db: %w", p.name, err) - } - for _, tableName := range p.schema.tableOrder { tbl := p.schema.tables[tableName] @@ -302,21 +294,6 @@ func (p *PostgresSink) lookupTable(table string) (postgresTableCompiled, error) return tbl, nil } -func buildPostgresDSN(uri, username, password string) (string, error) { - u, err := url.Parse(strings.TrimSpace(uri)) - if err != nil { - return "", fmt.Errorf("invalid uri: %w", err) - } - if u.Scheme == "" { - return "", fmt.Errorf("invalid uri: missing scheme") - } - if u.Host == "" { - return "", fmt.Errorf("invalid uri: missing host") - } - u.User = url.UserPassword(username, password) - return u.String(), nil -} - func parsePostgresPruneWindow(cfg config.SinkConfig) (time.Duration, error) { raw, ok := cfg.Params["prune"] if !ok || raw == nil { diff --git a/sinks/postgres_test.go b/sinks/postgres_test.go index 2ff98d0..df9d4d8 100644 --- a/sinks/postgres_test.go +++ b/sinks/postgres_test.go @@ -4,13 +4,13 @@ import ( "context" "database/sql" "errors" - "net/url" "strings" "testing" "time" "gitea.maximumdirect.net/ejr/feedkit/config" "gitea.maximumdirect.net/ejr/feedkit/event" + pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres" ) type fakeResult struct { @@ -223,10 +223,13 @@ func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) { withPostgresTestState(t) dbs := []*fakeDB{{}, {}} - var gotDSNs []string - openPostgresDB = func(dsn string) (postgresDB, error) { - gotDSNs = append(gotDSNs, dsn) - db := dbs[len(gotDSNs)-1] + var gotCfgs []pgconn.ConnConfig + openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) { + gotCfgs = append(gotCfgs, cfg) + db := dbs[len(gotCfgs)-1] + if err := db.PingContext(ctx); err != nil { + return nil, err + } return db, nil } @@ -252,8 +255,11 @@ func TestPostgresFactoryBuildsMultipleSinksWithSameSchema(t *testing.T) { } } - if len(gotDSNs) != 2 { - t.Fatalf("len(gotDSNs) = %d, want 2", len(gotDSNs)) + if len(gotCfgs) != 2 { + t.Fatalf("len(gotCfgs) = %d, want 2", len(gotCfgs)) + } + if gotCfgs[0].Username != "user" || gotCfgs[0].Password != "pass" { + t.Fatalf("first ConnConfig = %+v", gotCfgs[0]) } for i, db := range dbs { if db.pingCalls != 1 { @@ -327,9 +333,12 @@ func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) { withPostgresTestState(t) db := &fakeDB{} - var gotDSN string - openPostgresDB = func(dsn string) (postgresDB, error) { - gotDSN = dsn + var gotCfg pgconn.ConnConfig + openPostgresDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresDB, error) { + gotCfg = cfg + if err := db.PingContext(ctx); err != nil { + return nil, err + } return db, nil } @@ -362,16 +371,14 @@ func TestNewPostgresSinkFromConfigEagerInit(t *testing.T) { t.Fatalf("unexpected create index query: %s", db.execCalls[1].query) } - u, err := url.Parse(gotDSN) - if err != nil { - t.Fatalf("parse dsn: %v", err) + if gotCfg.URI != "postgres://db.example.local:5432/feedkit?sslmode=disable" { + t.Fatalf("URI = %q", gotCfg.URI) } - if u.User == nil || u.User.Username() != "app_user" { - t.Fatalf("dsn missing username: %q", gotDSN) + if gotCfg.Username != "app_user" { + t.Fatalf("Username = %q, want app_user", gotCfg.Username) } - pass, ok := u.User.Password() - if !ok || pass != "app_pass" { - t.Fatalf("dsn missing password: %q", gotDSN) + if gotCfg.Password != "app_pass" { + t.Fatalf("Password = %q, want app_pass", gotCfg.Password) } } @@ -379,7 +386,7 @@ func TestNewPostgresSinkFromConfigInitFailureClosesDB(t *testing.T) { withPostgresTestState(t) db := &fakeDB{execErrOnCall: 1, execErr: errors.New("ddl failed")} - openPostgresDB = func(_ string) (postgresDB, error) { + openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) { return db, nil } @@ -415,7 +422,7 @@ func TestNewPostgresSinkFromConfigPruneParamAccepted(t *testing.T) { t.Run(tc.name, func(t *testing.T) { withPostgresTestState(t) - openPostgresDB = func(_ string) (postgresDB, error) { + openPostgresDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresDB, error) { return &fakeDB{}, nil } diff --git a/sources/doc.go b/sources/doc.go index 70d4eb4..a922bcf 100644 --- a/sources/doc.go +++ b/sources/doc.go @@ -9,6 +9,8 @@ // stream exit classification helpers // - Registry / NewRegistry: source driver registry and builders // - HTTPSource / NewHTTPSource: reusable HTTP polling helper +// - PostgresQuerySource / NewPostgresQuerySource: reusable Postgres polling +// helper // // Source drivers are domain-specific and registered into Registry by driver name. // Registry can then build configured sources from config.SourceConfig. @@ -34,4 +36,17 @@ // When validators are available, NewHTTPSource prefers ETag/If-None-Match and // falls back to Last-Modified/If-Modified-Since. A 304 Not Modified response is // treated as a successful unchanged poll. +// +// Postgres-backed polling sources can share NewPostgresQuerySource for generic +// DB config parsing and query execution. The helper understands: +// - params.uri +// - params.username +// - params.password +// - params.query +// - params.query_timeout (optional, default 30s) +// +// feedkit does not register a built-in postgres poll driver. Downstream daemons +// should register domain-specific driver names that call +// NewPostgresQuerySource, then keep SQL semantics, row scanning, ordering, +// watermark policy, and event construction in their own source types. package sources diff --git a/sources/postgres.go b/sources/postgres.go new file mode 100644 index 0000000..34f02d1 --- /dev/null +++ b/sources/postgres.go @@ -0,0 +1,117 @@ +package sources + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "gitea.maximumdirect.net/ejr/feedkit/config" + pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres" +) + +const defaultPostgresQueryTimeout = 30 * time.Second + +type postgresQueryDB interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +var openPostgresQueryDB = func(ctx context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) { + return pgconn.Open(ctx, cfg) +} + +// PostgresQuerySource is a reusable helper for polling Postgres-backed sources. +// +// It centralizes generic source config parsing and query execution. Concrete +// daemon sources remain responsible for SQL semantics, row scanning, cursoring, +// and event construction. +type PostgresQuerySource struct { + Driver string + Name string + SQL string + QueryTimeout time.Duration + + db postgresQueryDB +} + +// NewPostgresQuerySource builds a generic Postgres polling helper from +// SourceConfig. +// +// Required params: +// - params.uri +// - params.username +// - params.password +// - params.query +// +// Optional params: +// - params.query_timeout (default 30s) +func NewPostgresQuerySource(driver string, cfg config.SourceConfig) (*PostgresQuerySource, error) { + name := strings.TrimSpace(cfg.Name) + if name == "" { + return nil, fmt.Errorf("%s: name is required", driver) + } + if cfg.Params == nil { + return nil, fmt.Errorf("%s %q: params are required (need params.uri, params.username, params.password, and params.query)", driver, cfg.Name) + } + + uri, ok := cfg.ParamString("uri") + if !ok { + return nil, fmt.Errorf("%s %q: params.uri is required", driver, cfg.Name) + } + username, ok := cfg.ParamString("username") + if !ok { + return nil, fmt.Errorf("%s %q: params.username is required", driver, cfg.Name) + } + password, ok := cfg.ParamString("password") + if !ok { + return nil, fmt.Errorf("%s %q: params.password is required", driver, cfg.Name) + } + query, ok := cfg.ParamString("query") + if !ok { + return nil, fmt.Errorf("%s %q: params.query is required", driver, cfg.Name) + } + + queryTimeout := defaultPostgresQueryTimeout + if _, exists := cfg.Params["query_timeout"]; exists { + var ok bool + queryTimeout, ok = cfg.ParamDuration("query_timeout") + if !ok || queryTimeout <= 0 { + return nil, fmt.Errorf("source %q: params.query_timeout must be a positive duration", cfg.Name) + } + } + + db, err := openPostgresQueryDB(context.Background(), pgconn.ConnConfig{ + URI: uri, + Username: username, + Password: password, + }) + if err != nil { + return nil, fmt.Errorf("%s %q: open db: %w", driver, cfg.Name, err) + } + + return &PostgresQuerySource{ + Driver: driver, + Name: name, + SQL: query, + QueryTimeout: queryTimeout, + db: db, + }, nil +} + +func (s *PostgresQuerySource) Query(ctx context.Context, args ...any) (*sql.Rows, error) { + queryCtx := ctx + if s.QueryTimeout > 0 { + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > s.QueryTimeout { + // We intentionally do not cancel this derived context here because the + // returned rows may still be reading from the database. + queryCtx, _ = context.WithTimeout(ctx, s.QueryTimeout) + } + } + + rows, err := s.db.QueryContext(queryCtx, s.SQL, args...) + if err != nil { + return nil, fmt.Errorf("%s %q: query: %w", s.Driver, s.Name, err) + } + return rows, nil +} diff --git a/sources/postgres_test.go b/sources/postgres_test.go new file mode 100644 index 0000000..40982aa --- /dev/null +++ b/sources/postgres_test.go @@ -0,0 +1,352 @@ +package sources + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "sync" + "testing" + "time" + + "gitea.maximumdirect.net/ejr/feedkit/config" + "gitea.maximumdirect.net/ejr/feedkit/event" + pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres" +) + +type fakePostgresQueryDB struct { + queryErr error + lastCtx context.Context + lastQuery string + lastArgs []any + returnRows *sql.Rows +} + +func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + db.lastCtx = ctx + db.lastQuery = query + db.lastArgs = append([]any(nil), args...) + if db.queryErr != nil { + return nil, db.queryErr + } + return db.returnRows, nil +} + +func withPostgresQuerySourceTestState(t *testing.T) { + t.Helper() + + oldOpen := openPostgresQueryDB + t.Cleanup(func() { + openPostgresQueryDB = oldOpen + }) +} + +func TestNewPostgresQuerySourceMissingParams(t *testing.T) { + withPostgresQuerySourceTestState(t) + + tests := []struct { + name string + params map[string]any + want string + }{ + {name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"}, + {name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"}, + {name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"}, + {name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPostgresQuerySource("test_driver", config.SourceConfig{ + Name: "pg-source", + Driver: "test_driver", + Params: tc.params, + }) + if err == nil { + t.Fatalf("NewPostgresQuerySource() error = nil, want error") + } + if !strings.Contains(err.Error(), tc.want) { + t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want) + } + }) + } +} + +func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) { + _, err := NewPostgresQuerySource("test_driver", config.SourceConfig{ + Name: "pg-source", + Driver: "test_driver", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "u", + "password": "p", + "query": "SELECT 1", + "query_timeout": "soon", + }, + }) + if err == nil { + t.Fatalf("NewPostgresQuerySource() error = nil, want error") + } + if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") { + t.Fatalf("NewPostgresQuerySource() error = %q", err) + } +} + +func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) { + withPostgresQuerySourceTestState(t) + + db := &fakePostgresQueryDB{} + var gotCfg pgconn.ConnConfig + openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) { + gotCfg = cfg + return db, nil + } + + src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{ + Name: "pg-source", + Driver: "test_driver", + Params: map[string]any{ + "uri": "postgres://db.example.local/feedkit", + "username": "app_user", + "password": "app_pass", + "query": "SELECT * FROM observations", + "query_timeout": "45s", + }, + }) + if err != nil { + t.Fatalf("NewPostgresQuerySource() error = %v", err) + } + if src.Name != "pg-source" { + t.Fatalf("Name = %q, want pg-source", src.Name) + } + if src.QueryTimeout != 45*time.Second { + t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout) + } + if src.SQL != "SELECT * FROM observations" { + t.Fatalf("SQL = %q", src.SQL) + } + if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" { + t.Fatalf("ConnConfig = %+v", gotCfg) + } +} + +func TestNewPostgresQuerySourceOpenFailure(t *testing.T) { + withPostgresQuerySourceTestState(t) + + openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) { + return nil, errors.New("db unavailable") + } + + _, err := NewPostgresQuerySource("test_driver", config.SourceConfig{ + Name: "pg-source", + Driver: "test_driver", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "u", + "password": "p", + "query": "SELECT 1", + }, + }) + if err == nil { + t.Fatalf("NewPostgresQuerySource() error = nil, want error") + } + if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) { + t.Fatalf("NewPostgresQuerySource() error = %q", err) + } +} + +func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) { + db := &fakePostgresQueryDB{queryErr: errors.New("query failed")} + src := &PostgresQuerySource{ + Driver: "test_driver", + Name: "pg-source", + SQL: "SELECT 1", + QueryTimeout: 30 * time.Second, + db: db, + } + + ctx := context.Background() + _, err := src.Query(ctx, "arg1") + if err == nil { + t.Fatalf("Query() error = nil, want error") + } + if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) { + t.Fatalf("Query() error = %q", err) + } + if db.lastCtx == nil { + t.Fatalf("lastCtx = nil") + } + if _, ok := db.lastCtx.Deadline(); !ok { + t.Fatalf("expected derived deadline on query context") + } + if db.lastQuery != "SELECT 1" { + t.Fatalf("lastQuery = %q", db.lastQuery) + } + if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" { + t.Fatalf("lastArgs = %#v", db.lastArgs) + } +} + +func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) { + db := &fakePostgresQueryDB{queryErr: errors.New("query failed")} + src := &PostgresQuerySource{ + Driver: "test_driver", + Name: "pg-source", + SQL: "SELECT 1", + QueryTimeout: 30 * time.Second, + db: db, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + _, _ = src.Query(ctx) + if db.lastCtx != ctx { + t.Fatalf("expected source to reuse earlier caller deadline") + } +} + +func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) { + withPostgresQuerySourceTestState(t) + + db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}}) + defer cleanup() + + openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) { + return db, nil + } + + type fakeDownstreamSource struct { + pg *PostgresQuerySource + } + poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) { + rows, err := s.pg.Query(ctx) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []event.Event + for rows.Next() { + var eventID string + if err := rows.Scan(&eventID); err != nil { + return nil, err + } + out = append(out, event.Event{ + ID: eventID, + Kind: event.Kind("observation"), + Source: s.pg.Name, + Schema: "raw.test.v1", + EmittedAt: time.Now().UTC(), + Payload: map[string]any{"event_id": eventID}, + }) + } + if err := rows.Err(); err != nil { + return nil, err + } + return out, nil + } + + pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{ + Name: "pg-source", + Driver: "test_driver", + Params: map[string]any{ + "uri": "postgres://localhost/db", + "username": "u", + "password": "p", + "query": "SELECT event_id FROM events", + }, + }) + if err != nil { + t.Fatalf("NewPostgresQuerySource() error = %v", err) + } + + events, err := poll(fakeDownstreamSource{pg: pg}, context.Background()) + if err != nil { + t.Fatalf("poll() error = %v", err) + } + if len(events) != 1 { + t.Fatalf("len(events) = %d, want 1", len(events)) + } + if events[0].ID != "evt-1" { + t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID) + } +} + +var ( + rowsDriverMu sync.Mutex + rowsDriverSeen = map[string]bool{} +) + +func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) { + t.Helper() + + rowsDriverMu.Lock() + if !rowsDriverSeen[driverName] { + sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)}) + rowsDriverSeen[driverName] = true + } + rowsDriverMu.Unlock() + + db, err := sql.Open(driverName, "") + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + + return db, func() { + _ = db.Close() + } +} + +func cloneDriverRows(in [][]driver.Value) [][]driver.Value { + out := make([][]driver.Value, 0, len(in)) + for _, row := range in { + copied := append([]driver.Value(nil), row...) + out = append(out, copied) + } + return out +} + +type rowsTestDriver struct { + columns []string + rows [][]driver.Value +} + +func (d *rowsTestDriver) Open(string) (driver.Conn, error) { + return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil +} + +type rowsTestConn struct { + columns []string + rows [][]driver.Value +} + +func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("not implemented") +} +func (c *rowsTestConn) Close() error { return nil } +func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") } + +func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) { + return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil +} + +type rowsTestRows struct { + columns []string + rows [][]driver.Value + idx int +} + +func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) } +func (r *rowsTestRows) Close() error { return nil } + +func (r *rowsTestRows) Next(dest []driver.Value) error { + if r.idx >= len(r.rows) { + return io.EOF + } + copy(dest, r.rows[r.idx]) + r.idx++ + return nil +}