diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..9bdf605 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -6,6 +6,8 @@ import ( "fmt" "log" "net/url" + "slices" + "strings" "time" "database/sql" @@ -20,6 +22,8 @@ import ( "github.com/lib/pq" ) +const defaultPostgresPort = "5432" + type baseConnectionStringProvider interface { getBaseConnectionString(ctx context.Context) (string, error) } @@ -212,3 +216,76 @@ func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { return db } +// NewPostgresqlConnectorFromDSN constructs a PostgresqlConnector from either a normal +// Postgres DSN/connection string or the custom postgres+rds-iam DSN used for RDS IAM auth. +// +// IAM example 1: postgres+rds-iam://@[:]/ +// +// Optional query params (for cross-account IAM): +// - assume_role_arn: role ARN to assume. +// - assume_role_session_name: only used when assume_role_arn is set; defaults to "pgutils-rds-iam" if omitted. +// +// IAM example 2: postgres+rds-iam://@[:]/?assume_role_arn=...&assume_role_session_name=... +func NewPostgresqlConnectorFromDSN(ctx context.Context, dsn string) (*PostgresqlConnector, error) { + if dsn == "" { + return nil, errors.New("DSN cannot be empty") + } + + u, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf("failed to parse DSN: %w", err) + } + + if u.Scheme != "postgres+rds-iam" { // Not our custom scheme: hand off to existing DSN handling. + return NewPostgresqlConnectorFromConnectionString(dsn), nil + } + + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, fmt.Errorf("postgres+rds-iam DSN must not include a password") + } + } + if user == "" { + return nil, fmt.Errorf("postgres+rds-iam DSN missing username") + } + + host := u.Hostname() + if host == "" { + return nil, fmt.Errorf("postgres+rds-iam DSN missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := []string{"assume_role_arn", "assume_role_session_name"} + for k := range q { + if !slices.Contains(supportedParams, k) { + return nil, fmt.Errorf("postgres+rds-iam DSN has unsupported query parameter: %s", k) + } + } + + cfg := &IAMAuthConfig{ + RDSEndpoint: host + ":" + port, + User: user, + Database: dbName, + } + + assumeRoleARN := q.Get("assume_role_arn") + if assumeRoleARN != "" { + cfg.AssumeRoleARN = assumeRoleARN + cfg.AssumeRoleSessionName = q.Get("assume_role_session_name") + } + + return NewPostgresqlConnectorWithIAMAuth(ctx, cfg) +}