diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md new file mode 100644 index 0000000..b6dae57 --- /dev/null +++ b/cmd/rds-iam-psql/README.md @@ -0,0 +1,114 @@ +# rds-iam-psql + +A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords. + +## Why? + +RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell. + +## Installation + +```bash +go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest +``` + +Or build from source: + +```bash +cd ./cmd/rds-iam-psql +go build +``` + +## Prerequisites + +- **psql** installed and available in your PATH +- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.) +- **RDS IAM authentication enabled** on your database instance +- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) + +## Usage + +```bash +rds-iam-psql -host -user -db [options] +``` + +### Required Flags + +| Flag | Description | +|------|-------------| +| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` | +| `-user` | Database username configured for IAM auth | +| `-db` | Database name to connect to | + +### Optional Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-port` | `5432` | PostgreSQL port | +| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname | +| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) | +| `-psql` | `psql` | Path to the `psql` binary | +| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | +| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | + +## Examples + +Basic connection: + +```bash +rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp +``` + +With a specific AWS profile and schema: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -user app_user \ + -db myapp \ + -profile production \ + -search-path "app_schema,public" +``` + +Using a non-standard port and explicit region: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -port 5433 \ + -user admin \ + -db postgres \ + -region us-east-1 +``` + +## How It Works + +1. Loads your AWS credentials from the standard credential chain +2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken` +3. Launches `psql` with: + - `PGPASSWORD` set to the auth token + - `PGSSLMODE` set according to `-sslmode` + - `PGOPTIONS` set if `-search-path` is provided +4. Attaches stdin/stdout/stderr for interactive use + +## Setting Up IAM Auth on RDS + +1. Enable IAM authentication on your RDS instance +2. Create a database user and grant IAM privileges: + ```sql + CREATE USER myuser WITH LOGIN; + GRANT rds_iam TO myuser; + ``` +3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "rds-db:connect", + "Resource": "arn:aws:rds-db:::dbuser:/" + } + ] + } + ``` diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go new file mode 100644 index 0000000..8406c0d --- /dev/null +++ b/cmd/rds-iam-psql/main.go @@ -0,0 +1,172 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +func main() { + var ( + host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)") + port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") + user = flag.String("user", "", "Database user name") + dbName = flag.String("db", "", "Database name") + region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.") + profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)") + psqlPath = flag.String("psql", "psql", "Path to psql binary") + sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") + searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") + ) + flag.Parse() + + if *host == "" || *user == "" || *dbName == "" { + log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0]) + } + + ctx := context.Background() + + // Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password). + var cfg aws.Config + var err error + if *profile != "" { + cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile)) + } else { + cfg, err = awsconfig.LoadDefaultConfig(ctx) + } + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + + // Fail fast + print identity (account/arn/role-ish). + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } + + awsRegion := *region + if awsRegion == "" { + awsRegion = cfg.Region + } + + if awsRegion == "" { + log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") + } + + endpointWithPort := fmt.Sprintf("%s:%d", *host, *port) + + // Generate the IAM auth token. + authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials) + if err != nil { + log.Fatalf("failed to build RDS IAM auth token: %v", err) + } + + // Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE. + cmd := exec.Command( + *psqlPath, + "--host", *host, + "--port", fmt.Sprintf("%d", *port), + "--username", *user, + "--dbname", *dbName, + ) + + // Attach stdio so it behaves like an interactive shell. + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Inherit existing env and add PG vars. + env := os.Environ() + env = append(env, + "PGPASSWORD="+authToken, + "PGSSLMODE="+*sslMode, + ) + + // If a search path is provided, wire it through PGOPTIONS. + if sp := strings.TrimSpace(*searchPath); sp != "" { + add := "-c search_path=" + sp + + found := false + for i, e := range env { + if strings.HasPrefix(e, "PGOPTIONS=") { + current := strings.TrimPrefix(e, "PGOPTIONS=") + if strings.TrimSpace(current) == "" { + env[i] = "PGOPTIONS=" + add + } else { + env[i] = "PGOPTIONS=" + current + " " + add + } + found = true + break + } + } + if !found { + env = append(env, "PGOPTIONS="+add) + } + } + + cmd.Env = env + + // --- Ctrl-C handling --- + // The key idea: keep psql in the same foreground process group so it can read + // from the terminal. We intercept SIGINT only to prevent THIS wrapper from + // exiting; psql will still receive SIGINT normally and cancel the current + // query / line as expected. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to start psql: %v", err) + } + + waitCh := make(chan error, 1) + go func() { waitCh <- cmd.Wait() }() + + for { + select { + case sig := <-sigCh: + switch sig { + case os.Interrupt: + // Swallow SIGINT so this wrapper doesn't exit. + // psql still gets SIGINT (same terminal foreground process group). + continue + case syscall.SIGTERM: + // If we're being terminated, pass it through to psql and exit accordingly. + if cmd.Process != nil { + _ = cmd.Process.Signal(syscall.SIGTERM) + } + } + case err := <-waitCh: + // psql exited; now we exit with the same code. + if err == nil { + return + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("psql failed: %v", err) + } + } +} + +func printCallerIdentity(ctx context.Context, cfg aws.Config) error { + stsClient := sts.NewFromConfig(cfg) + + out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) + } + + fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) + return nil +}