package postgres

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"io"
	"net/url"
	"strings"
	"sync"
	"testing"
)

// withPostgresPackageTestState snapshots the package-level sqlOpen and pingDB
// hooks and restores them when the test finishes, so a test may overwrite
// them freely. Because these hooks are package-level variables, tests that
// use this helper must not call t.Parallel.
func withPostgresPackageTestState(t *testing.T) {
	t.Helper()

	oldSQLOpen := sqlOpen
	oldPingDB := pingDB
	t.Cleanup(func() {
		sqlOpen = oldSQLOpen
		pingDB = oldPingDB
	})
}

// assertBuildDSNRejects calls BuildDSN with uri and placeholder credentials
// and fails the test unless BuildDSN returns an error whose message contains
// wantSubstr.
func assertBuildDSNRejects(t *testing.T, uri, wantSubstr string) {
	t.Helper()

	_, err := BuildDSN(ConnConfig{URI: uri, Username: "u", Password: "p"})
	if err == nil {
		t.Fatalf("BuildDSN(%q) error = nil, want error containing %q", uri, wantSubstr)
	}
	if !strings.Contains(err.Error(), wantSubstr) {
		t.Fatalf("BuildDSN(%q) error = %q, want it to contain %q", uri, err, wantSubstr)
	}
}

// assertOpenFails calls Open against a fixed, well-formed test URI and fails
// the test unless Open returns an error whose message contains wantSubstr.
// Callers are expected to have replaced sqlOpen (and optionally pingDB) first.
func assertOpenFails(t *testing.T, wantSubstr string) {
	t.Helper()

	_, err := Open(context.Background(), ConnConfig{
		URI:      "postgres://db.example.local/feedkit",
		Username: "u",
		Password: "p",
	})
	if err == nil {
		t.Fatalf("Open() error = nil, want error containing %q", wantSubstr)
	}
	if !strings.Contains(err.Error(), wantSubstr) {
		t.Fatalf("Open() error = %q, want it to contain %q", err, wantSubstr)
	}
}

func TestBuildDSNInjectsCredentials(t *testing.T) {
	// The surrounding spaces check that BuildDSN tolerates untrimmed input.
	dsn, err := BuildDSN(ConnConfig{
		URI:      " postgres://db.example.local:5432/feedkit?sslmode=disable ",
		Username: "app_user",
		Password: "app_pass",
	})
	if err != nil {
		t.Fatalf("BuildDSN() error = %v", err)
	}

	u, err := url.Parse(dsn)
	if err != nil {
		t.Fatalf("url.Parse() error = %v", err)
	}

	if u.User == nil {
		t.Fatal("BuildDSN() returned a DSN without userinfo, want app_user:app_pass")
	}
	if got := u.User.Username(); got != "app_user" {
		t.Fatalf("username = %q, want app_user", got)
	}
	pass, ok := u.User.Password()
	if !ok || pass != "app_pass" {
		t.Fatalf("password = %q (set=%v), want app_pass", pass, ok)
	}
}

func TestBuildDSNRejectsInvalidURI(t *testing.T) {
	assertBuildDSNRejects(t, "http://[::1", "invalid uri")
}

func TestBuildDSNRejectsMissingScheme(t *testing.T) {
	assertBuildDSNRejects(t, "//db.example.local/feedkit", "missing scheme")
}

func TestBuildDSNRejectsMissingHost(t *testing.T) {
	assertBuildDSNRejects(t, "postgres:///feedkit", "missing host")
}

func TestOpenPropagatesOpenFailure(t *testing.T) {
	withPostgresPackageTestState(t)

	sqlOpen = func(_, _ string) (*sql.DB, error) {
		return nil, errors.New("open failed")
	}

	assertOpenFails(t, "open failed")
}

func TestOpenPropagatesPingFailure(t *testing.T) {
	withPostgresPackageTestState(t)

	const driverName = "feedkit_internal_postgres_ping_fail"
	registerPingTestDriver(driverName, errors.New("ping failed"))

	// Route Open to the fake driver; the DSN is irrelevant to it.
	sqlOpen = func(_, _ string) (*sql.DB, error) {
		return sql.Open(driverName, "")
	}

	assertOpenFails(t, "ping failed")
}

var (
	pingDriverMu   sync.Mutex
	pingDriverSeen = map[string]bool{}
)

// registerPingTestDriver registers, at most once per name, a fake
// database/sql driver whose connections return pingErr from Ping.
//
// sql.Register panics when a name is registered twice, and the registry is
// process-wide, so repeated runs in one process (e.g. go test -count=2) must
// skip re-registration. Consequently a later call with an already-seen name
// keeps the pingErr from the first call.
func registerPingTestDriver(name string, pingErr error) {
	pingDriverMu.Lock()
	defer pingDriverMu.Unlock()

	if pingDriverSeen[name] {
		return
	}
	sql.Register(name, &pingTestDriver{pingErr: pingErr})
	pingDriverSeen[name] = true
}

// Compile-time checks that the fakes implement the driver interfaces the
// tests rely on. database/sql only invokes Ping when a connection implements
// driver.Pinger and otherwise treats the ping as successful, so a signature
// drift here would silently turn the ping-failure test into a false failure.
var (
	_ driver.Driver         = (*pingTestDriver)(nil)
	_ driver.Conn           = (*pingTestConn)(nil)
	_ driver.Pinger         = (*pingTestConn)(nil)
	_ driver.QueryerContext = (*pingTestConn)(nil)
	_ driver.Rows           = (*pingTestRows)(nil)
)

// pingTestDriver is a fake driver.Driver whose connections fail Ping with
// pingErr (nil means Ping succeeds).
type pingTestDriver struct {
	pingErr error
}

// Open ignores the DSN and returns a fake connection carrying d.pingErr.
func (d *pingTestDriver) Open(string) (driver.Conn, error) {
	return &pingTestConn{pingErr: d.pingErr}, nil
}

// pingTestConn is a minimal fake connection: Ping returns pingErr, queries
// yield an empty result set, and statements/transactions are unsupported.
type pingTestConn struct {
	pingErr error
}

func (c *pingTestConn) Prepare(string) (driver.Stmt, error) {
	return nil, errors.New("not implemented")
}

func (c *pingTestConn) Close() error { return nil }

func (c *pingTestConn) Begin() (driver.Tx, error) {
	return nil, errors.New("not implemented")
}

// Ping implements driver.Pinger so (*sql.DB).PingContext surfaces pingErr.
func (c *pingTestConn) Ping(context.Context) error { return c.pingErr }

func (c *pingTestConn) QueryContext(context.Context, string, []driver.NamedValue) (driver.Rows, error) {
	return &pingTestRows{}, nil
}

// pingTestRows is an empty single-column result set.
type pingTestRows struct{}

func (r *pingTestRows) Columns() []string { return []string{"ok"} }

func (r *pingTestRows) Close() error { return nil }

// Next reports io.EOF immediately, meaning there are no rows.
func (r *pingTestRows) Next([]driver.Value) error { return io.EOF }