package sources

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"io"
	"strings"
	"testing"
	"time"

	"gitea.maximumdirect.net/ejr/feedkit/config"
	"gitea.maximumdirect.net/ejr/feedkit/event"
	pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
)

// fakePostgresQueryDB is a test double handed to PostgresQuerySource in place
// of a real database. It records the arguments of the most recent
// QueryContext call and returns a canned result.
type fakePostgresQueryDB struct {
	queryErr   error
	lastCtx    context.Context
	lastQuery  string
	lastArgs   []any
	returnRows *sql.Rows
}

// QueryContext records ctx, query, and a copy of args, then returns queryErr
// when it is set and returnRows otherwise.
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	db.lastCtx = ctx
	db.lastQuery = query
	db.lastArgs = append([]any(nil), args...)
	if db.queryErr != nil {
		return nil, db.queryErr
	}
	return db.returnRows, nil
}

// withPostgresQuerySourceTestState snapshots the package-level
// openPostgresQueryDB hook and restores it when the test finishes, so a test
// may replace the hook without leaking the stub into other tests.
func withPostgresQuerySourceTestState(t *testing.T) {
	t.Helper()
	oldOpen := openPostgresQueryDB
	t.Cleanup(func() { openPostgresQueryDB = oldOpen })
}

// TestNewPostgresQuerySourceMissingParams checks that each required param is
// reported by name when it is absent.
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
	withPostgresQuerySourceTestState(t)

	tests := []struct {
		name   string
		params map[string]any
		want   string
	}{
		{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
		{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
		{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
		{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
				Name:   "pg-source",
				Driver: "test_driver",
				Params: tc.params,
			})
			if err == nil {
				t.Fatalf("NewPostgresQuerySource() error = nil, want error")
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
			}
		})
	}
}

// TestNewPostgresQuerySourceRejectsInvalidQueryTimeout checks that a
// query_timeout that does not parse as a duration is rejected.
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
	// Snapshot the open hook like the other constructor tests, so a change in
	// validation order cannot leave a modified hook behind.
	withPostgresQuerySourceTestState(t)

	_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
		Name:   "pg-source",
		Driver: "test_driver",
		Params: map[string]any{
			"uri":           "postgres://localhost/db",
			"username":      "u",
			"password":      "p",
			"query":         "SELECT 1",
			"query_timeout": "soon",
		},
	})
	if err == nil {
		t.Fatalf("NewPostgresQuerySource() error = nil, want error")
	}
	if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
		t.Fatalf("NewPostgresQuerySource() error = %q", err)
	}
}

// TestNewPostgresQuerySourceSuccessfulConstruction checks that valid params
// populate the source and are forwarded to the open hook.
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
	withPostgresQuerySourceTestState(t)

	db := &fakePostgresQueryDB{}
	var gotCfg pgconn.ConnConfig
	openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
		gotCfg = cfg
		return db, nil
	}

	src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
		Name:   "pg-source",
		Driver: "test_driver",
		Params: map[string]any{
			"uri":           "postgres://db.example.local/feedkit",
			"username":      "app_user",
			"password":      "app_pass",
			"query":         "SELECT * FROM observations",
			"query_timeout": "45s",
		},
	})
	if err != nil {
		t.Fatalf("NewPostgresQuerySource() error = %v", err)
	}
	if src.Name != "pg-source" {
		t.Fatalf("Name = %q, want pg-source", src.Name)
	}
	if src.QueryTimeout != 45*time.Second {
		t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
	}
	if src.SQL != "SELECT * FROM observations" {
		t.Fatalf("SQL = %q", src.SQL)
	}
	if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
		t.Fatalf("ConnConfig = %+v", gotCfg)
	}
}

// TestNewPostgresQuerySourceOpenFailure checks that an error from the open
// hook is surfaced with the driver and source name as context.
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
	withPostgresQuerySourceTestState(t)

	openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
		return nil, errors.New("db unavailable")
	}

	_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
		Name:   "pg-source",
		Driver: "test_driver",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "u",
			"password": "p",
			"query":    "SELECT 1",
		},
	})
	if err == nil {
		t.Fatalf("NewPostgresQuerySource() error = nil, want error")
	}
	if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
		t.Fatalf("NewPostgresQuerySource() error = %q", err)
	}
}

// TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError checks that Query:
//   - bounds the caller's context by QueryTimeout;
//   - forwards the SQL and args unchanged;
//   - prefixes query errors with the driver and source name.
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
	db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
	src := &PostgresQuerySource{
		Driver:       "test_driver",
		Name:         "pg-source",
		SQL:          "SELECT 1",
		QueryTimeout: 30 * time.Second,
		db:           db,
	}

	_, err := src.Query(context.Background(), "arg1")
	if err == nil {
		t.Fatalf("Query() error = nil, want error")
	}
	if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
		t.Fatalf("Query() error = %q", err)
	}
	if db.lastCtx == nil {
		t.Fatalf("lastCtx = nil")
	}
	deadline, ok := db.lastCtx.Deadline()
	if !ok {
		t.Fatalf("expected derived deadline on query context")
	}
	// The caller passed no deadline, so the one observed must come from
	// QueryTimeout and therefore cannot lie further out than it.
	if remaining := time.Until(deadline); remaining > src.QueryTimeout {
		t.Fatalf("query deadline is %s away, want at most %s", remaining, src.QueryTimeout)
	}
	if db.lastQuery != "SELECT 1" {
		t.Fatalf("lastQuery = %q", db.lastQuery)
	}
	if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
		t.Fatalf("lastArgs = %#v", db.lastArgs)
	}
}

// TestPostgresQuerySourceQueryUsesEarlierCallerDeadline checks that when the
// caller's deadline is already sooner than QueryTimeout, Query passes the
// caller's context through unchanged.
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
	db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
	src := &PostgresQuerySource{
		Driver:       "test_driver",
		Name:         "pg-source",
		SQL:          "SELECT 1",
		QueryTimeout: 30 * time.Second,
		db:           db,
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	// Only the context identity matters here; the query error is expected.
	_, _ = src.Query(ctx)
	if db.lastCtx != ctx {
		t.Fatalf("expected source to reuse earlier caller deadline")
	}
}

// TestPostgresQuerySourceSupportsDownstreamPollingPattern exercises the way a
// downstream source is expected to use PostgresQuerySource. It runs the
// query, scans each row, and turns rows into events, backed by an in-memory
// database/sql driver.
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
	withPostgresQuerySourceTestState(t)

	db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
	defer cleanup()

	openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
		return db, nil
	}

	type fakeDownstreamSource struct {
		pg *PostgresQuerySource
	}

	poll := func(ctx context.Context, s fakeDownstreamSource) ([]event.Event, error) {
		rows, err := s.pg.Query(ctx)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		var out []event.Event
		for rows.Next() {
			var eventID string
			if err := rows.Scan(&eventID); err != nil {
				return nil, err
			}
			out = append(out, event.Event{
				ID:        eventID,
				Kind:      event.Kind("observation"),
				Source:    s.pg.Name,
				Schema:    "raw.test.v1",
				EmittedAt: time.Now().UTC(),
				Payload:   map[string]any{"event_id": eventID},
			})
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return out, nil
	}

	pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
		Name:   "pg-source",
		Driver: "test_driver",
		Params: map[string]any{
			"uri":      "postgres://localhost/db",
			"username": "u",
			"password": "p",
			"query":    "SELECT event_id FROM events",
		},
	})
	if err != nil {
		t.Fatalf("NewPostgresQuerySource() error = %v", err)
	}

	events, err := poll(context.Background(), fakeDownstreamSource{pg: pg})
	if err != nil {
		t.Fatalf("poll() error = %v", err)
	}
	if len(events) != 1 {
		t.Fatalf("len(events) = %d, want 1", len(events))
	}
	if events[0].ID != "evt-1" {
		t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
	}
}

// openRowsTestDB returns a *sql.DB backed by an in-memory driver. Every query
// on it yields the given columns and rows. The returned func closes the DB.
//
// The fake driver is attached via sql.OpenDB rather than sql.Register. This
// gives every call its own copy of the fixture data instead of silently
// reusing whatever was registered first under the same name. driverName is
// kept for call-site compatibility; it no longer needs to be unique.
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
	t.Helper()

	db := sql.OpenDB(&rowsTestDriver{
		columns: append([]string(nil), columns...),
		rows:    cloneDriverRows(rows),
	})
	return db, func() { _ = db.Close() }
}

// cloneDriverRows deep-copies a row set, one level per row, so fixtures are
// never shared between drivers, connections, and result sets.
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
	out := make([][]driver.Value, 0, len(in))
	for _, row := range in {
		copied := append([]driver.Value(nil), row...)
		out = append(out, copied)
	}
	return out
}

// Compile-time checks that the fakes satisfy the interfaces database/sql
// relies on. QueryerContext matters in particular: without it database/sql
// falls back to Prepare, which rowsTestConn rejects.
var (
	_ driver.Connector      = (*rowsTestDriver)(nil)
	_ driver.QueryerContext = (*rowsTestConn)(nil)
)

// rowsTestDriver is a minimal database/sql driver that serves a fixed result
// set. It also implements driver.Connector so it can be passed to sql.OpenDB.
type rowsTestDriver struct {
	columns []string
	rows    [][]driver.Value
}

// Open returns a connection holding its own copy of the fixture data; the DSN
// is ignored.
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
	return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
}

// Connect implements driver.Connector by delegating to Open.
func (d *rowsTestDriver) Connect(context.Context) (driver.Conn, error) {
	return d.Open("")
}

// Driver implements driver.Connector.
func (d *rowsTestDriver) Driver() driver.Driver {
	return d
}

// rowsTestConn answers every query with the same fixed result set. Prepare
// and Begin are not supported.
type rowsTestConn struct {
	columns []string
	rows    [][]driver.Value
}

func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
	return nil, errors.New("not implemented")
}

func (c *rowsTestConn) Close() error {
	return nil
}

func (c *rowsTestConn) Begin() (driver.Tx, error) {
	return nil, errors.New("not implemented")
}

// QueryContext ignores the query text and args and returns a fresh cursor
// over a copy of the connection's rows.
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
	return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
}

// rowsTestRows is a forward-only cursor over an in-memory row set.
type rowsTestRows struct {
	columns []string
	rows    [][]driver.Value
	idx     int
}

func (r *rowsTestRows) Columns() []string {
	return append([]string(nil), r.columns...)
}

func (r *rowsTestRows) Close() error {
	return nil
}

// Next copies the current row into dest and advances the cursor, returning
// io.EOF once all rows have been consumed.
func (r *rowsTestRows) Next(dest []driver.Value) error {
	if r.idx >= len(r.rows) {
		return io.EOF
	}
	copy(dest, r.rows[r.idx])
	r.idx++
	return nil
}