Added support for Postgres polling sources
This commit is contained in:
352
sources/postgres_test.go
Normal file
352
sources/postgres_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.maximumdirect.net/ejr/feedkit/config"
|
||||
"gitea.maximumdirect.net/ejr/feedkit/event"
|
||||
pgconn "gitea.maximumdirect.net/ejr/feedkit/internal/postgres"
|
||||
)
|
||||
|
||||
type fakePostgresQueryDB struct {
|
||||
queryErr error
|
||||
lastCtx context.Context
|
||||
lastQuery string
|
||||
lastArgs []any
|
||||
returnRows *sql.Rows
|
||||
}
|
||||
|
||||
func (db *fakePostgresQueryDB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
db.lastCtx = ctx
|
||||
db.lastQuery = query
|
||||
db.lastArgs = append([]any(nil), args...)
|
||||
if db.queryErr != nil {
|
||||
return nil, db.queryErr
|
||||
}
|
||||
return db.returnRows, nil
|
||||
}
|
||||
|
||||
func withPostgresQuerySourceTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
oldOpen := openPostgresQueryDB
|
||||
t.Cleanup(func() {
|
||||
openPostgresQueryDB = oldOpen
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceMissingParams(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]any
|
||||
want string
|
||||
}{
|
||||
{name: "missing uri", params: map[string]any{"username": "u", "password": "p", "query": "SELECT 1"}, want: "params.uri"},
|
||||
{name: "missing username", params: map[string]any{"uri": "postgres://localhost/db", "password": "p", "query": "SELECT 1"}, want: "params.username"},
|
||||
{name: "missing password", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "query": "SELECT 1"}, want: "params.password"},
|
||||
{name: "missing query", params: map[string]any{"uri": "postgres://localhost/db", "username": "u", "password": "p"}, want: "params.query"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: tc.params,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q, want substring %q", err, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceRejectsInvalidQueryTimeout(t *testing.T) {
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
"query_timeout": "soon",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "params.query_timeout must be a positive duration") {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceSuccessfulConstruction(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db := &fakePostgresQueryDB{}
|
||||
var gotCfg pgconn.ConnConfig
|
||||
openPostgresQueryDB = func(_ context.Context, cfg pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
gotCfg = cfg
|
||||
return db, nil
|
||||
}
|
||||
|
||||
src, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://db.example.local/feedkit",
|
||||
"username": "app_user",
|
||||
"password": "app_pass",
|
||||
"query": "SELECT * FROM observations",
|
||||
"query_timeout": "45s",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
if src.Name != "pg-source" {
|
||||
t.Fatalf("Name = %q, want pg-source", src.Name)
|
||||
}
|
||||
if src.QueryTimeout != 45*time.Second {
|
||||
t.Fatalf("QueryTimeout = %s, want 45s", src.QueryTimeout)
|
||||
}
|
||||
if src.SQL != "SELECT * FROM observations" {
|
||||
t.Fatalf("SQL = %q", src.SQL)
|
||||
}
|
||||
if gotCfg.Username != "app_user" || gotCfg.Password != "app_pass" {
|
||||
t.Fatalf("ConnConfig = %+v", gotCfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPostgresQuerySourceOpenFailure(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return nil, errors.New("db unavailable")
|
||||
}
|
||||
|
||||
_, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT 1",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": open db: db unavailable`) {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %q", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryAppliesTimeoutAndWrapsError(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := src.Query(ctx, "arg1")
|
||||
if err == nil {
|
||||
t.Fatalf("Query() error = nil, want error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), `test_driver "pg-source": query: query failed`) {
|
||||
t.Fatalf("Query() error = %q", err)
|
||||
}
|
||||
if db.lastCtx == nil {
|
||||
t.Fatalf("lastCtx = nil")
|
||||
}
|
||||
if _, ok := db.lastCtx.Deadline(); !ok {
|
||||
t.Fatalf("expected derived deadline on query context")
|
||||
}
|
||||
if db.lastQuery != "SELECT 1" {
|
||||
t.Fatalf("lastQuery = %q", db.lastQuery)
|
||||
}
|
||||
if len(db.lastArgs) != 1 || db.lastArgs[0] != "arg1" {
|
||||
t.Fatalf("lastArgs = %#v", db.lastArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceQueryUsesEarlierCallerDeadline(t *testing.T) {
|
||||
db := &fakePostgresQueryDB{queryErr: errors.New("query failed")}
|
||||
src := &PostgresQuerySource{
|
||||
Driver: "test_driver",
|
||||
Name: "pg-source",
|
||||
SQL: "SELECT 1",
|
||||
QueryTimeout: 30 * time.Second,
|
||||
db: db,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _ = src.Query(ctx)
|
||||
if db.lastCtx != ctx {
|
||||
t.Fatalf("expected source to reuse earlier caller deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresQuerySourceSupportsDownstreamPollingPattern(t *testing.T) {
|
||||
withPostgresQuerySourceTestState(t)
|
||||
|
||||
db, cleanup := openRowsTestDB(t, "feedkit_sources_pg_rows", []string{"event_id"}, [][]driver.Value{{"evt-1"}})
|
||||
defer cleanup()
|
||||
|
||||
openPostgresQueryDB = func(_ context.Context, _ pgconn.ConnConfig) (postgresQueryDB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
type fakeDownstreamSource struct {
|
||||
pg *PostgresQuerySource
|
||||
}
|
||||
poll := func(s fakeDownstreamSource, ctx context.Context) ([]event.Event, error) {
|
||||
rows, err := s.pg.Query(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []event.Event
|
||||
for rows.Next() {
|
||||
var eventID string
|
||||
if err := rows.Scan(&eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, event.Event{
|
||||
ID: eventID,
|
||||
Kind: event.Kind("observation"),
|
||||
Source: s.pg.Name,
|
||||
Schema: "raw.test.v1",
|
||||
EmittedAt: time.Now().UTC(),
|
||||
Payload: map[string]any{"event_id": eventID},
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
pg, err := NewPostgresQuerySource("test_driver", config.SourceConfig{
|
||||
Name: "pg-source",
|
||||
Driver: "test_driver",
|
||||
Params: map[string]any{
|
||||
"uri": "postgres://localhost/db",
|
||||
"username": "u",
|
||||
"password": "p",
|
||||
"query": "SELECT event_id FROM events",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewPostgresQuerySource() error = %v", err)
|
||||
}
|
||||
|
||||
events, err := poll(fakeDownstreamSource{pg: pg}, context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("poll() error = %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("len(events) = %d, want 1", len(events))
|
||||
}
|
||||
if events[0].ID != "evt-1" {
|
||||
t.Fatalf("events[0].ID = %q, want evt-1", events[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
rowsDriverMu sync.Mutex
|
||||
rowsDriverSeen = map[string]bool{}
|
||||
)
|
||||
|
||||
func openRowsTestDB(t *testing.T, driverName string, columns []string, rows [][]driver.Value) (*sql.DB, func()) {
|
||||
t.Helper()
|
||||
|
||||
rowsDriverMu.Lock()
|
||||
if !rowsDriverSeen[driverName] {
|
||||
sql.Register(driverName, &rowsTestDriver{columns: append([]string(nil), columns...), rows: cloneDriverRows(rows)})
|
||||
rowsDriverSeen[driverName] = true
|
||||
}
|
||||
rowsDriverMu.Unlock()
|
||||
|
||||
db, err := sql.Open(driverName, "")
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open() error = %v", err)
|
||||
}
|
||||
|
||||
return db, func() {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func cloneDriverRows(in [][]driver.Value) [][]driver.Value {
|
||||
out := make([][]driver.Value, 0, len(in))
|
||||
for _, row := range in {
|
||||
copied := append([]driver.Value(nil), row...)
|
||||
out = append(out, copied)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type rowsTestDriver struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (d *rowsTestDriver) Open(string) (driver.Conn, error) {
|
||||
return &rowsTestConn{columns: append([]string(nil), d.columns...), rows: cloneDriverRows(d.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestConn struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
}
|
||||
|
||||
func (c *rowsTestConn) Prepare(string) (driver.Stmt, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (c *rowsTestConn) Close() error { return nil }
|
||||
func (c *rowsTestConn) Begin() (driver.Tx, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (c *rowsTestConn) QueryContext(_ context.Context, _ string, _ []driver.NamedValue) (driver.Rows, error) {
|
||||
return &rowsTestRows{columns: append([]string(nil), c.columns...), rows: cloneDriverRows(c.rows)}, nil
|
||||
}
|
||||
|
||||
type rowsTestRows struct {
|
||||
columns []string
|
||||
rows [][]driver.Value
|
||||
idx int
|
||||
}
|
||||
|
||||
func (r *rowsTestRows) Columns() []string { return append([]string(nil), r.columns...) }
|
||||
func (r *rowsTestRows) Close() error { return nil }
|
||||
|
||||
func (r *rowsTestRows) Next(dest []driver.Value) error {
|
||||
if r.idx >= len(r.rows) {
|
||||
return io.EOF
|
||||
}
|
||||
copy(dest, r.rows[r.idx])
|
||||
r.idx++
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user