Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions cmd/rds-iam-psql/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# rds-iam-psql

A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords.

## Why?

RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell.

## Installation

```bash
go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest
```

Or build from source:

```bash
cd ./cmd/rds-iam-psql
go build
```

## Prerequisites

- **psql** installed and available in your PATH
- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.)
- **RDS IAM authentication enabled** on your database instance
- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`)

## Usage

```bash
rds-iam-psql -host <rds-endpoint> -user <db-user> -db <database-name> [options]
```

### Required Flags

| Flag | Description |
|------|-------------|
| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` |
| `-user` | Database username configured for IAM auth |
| `-db` | Database name to connect to |

### Optional Flags

| Flag | Default | Description |
|------|---------|-------------|
| `-port` | `5432` | PostgreSQL port |
| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname |
| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) |
| `-psql` | `psql` | Path to the `psql` binary |
| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) |
| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) |

## Examples

Basic connection:

```bash
rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp
```

With a specific AWS profile and schema:

```bash
rds-iam-psql \
-host mydb.abc123.us-east-1.rds.amazonaws.com \
-user app_user \
-db myapp \
-profile production \
-search-path "app_schema,public"
```

Using a non-standard port and explicit region:

```bash
rds-iam-psql \
-host mydb.abc123.us-east-1.rds.amazonaws.com \
-port 5433 \
-user admin \
-db postgres \
-region us-east-1
```

## How It Works

1. Loads your AWS credentials from the standard credential chain
2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken`
3. Launches `psql` with:
- `PGPASSWORD` set to the auth token
- `PGSSLMODE` set according to `-sslmode`
- `PGOPTIONS` set if `-search-path` is provided
4. Attaches stdin/stdout/stderr for interactive use

## Setting Up IAM Auth on RDS

1. Enable IAM authentication on your RDS instance
2. Create a database user and grant IAM privileges:
```sql
CREATE USER myuser WITH LOGIN;
GRANT rds_iam TO myuser;
```
3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "rds-db:connect",
"Resource": "arn:aws:rds-db:<region>:<account-id>:dbuser:<dbi-resource-id>/<db-user>"
}
]
}
```
172 changes: 172 additions & 0 deletions cmd/rds-iam-psql/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/service/sts"
)

func main() {
var (
host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)")
port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)")
user = flag.String("user", "", "Database user name")
dbName = flag.String("db", "", "Database name")
region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.")
profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)")
psqlPath = flag.String("psql", "psql", "Path to psql binary")
sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)")
searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')")
)
flag.Parse()

if *host == "" || *user == "" || *dbName == "" {
log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0])
}

ctx := context.Background()

// Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password).
var cfg aws.Config
var err error
if *profile != "" {
cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile))
} else {
cfg, err = awsconfig.LoadDefaultConfig(ctx)
}
if err != nil {
log.Fatalf("failed to load AWS config: %v", err)
}

// Fail fast + print identity (account/arn/role-ish).
if err := printCallerIdentity(ctx, cfg); err != nil {
log.Fatalf("AWS credentials check failed: %v", err)
}

awsRegion := *region
if awsRegion == "" {
awsRegion = cfg.Region
}

if awsRegion == "" {
log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile")
}

endpointWithPort := fmt.Sprintf("%s:%d", *host, *port)

// Generate the IAM auth token.
authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials)
if err != nil {
log.Fatalf("failed to build RDS IAM auth token: %v", err)
}

// Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE.
cmd := exec.Command(
*psqlPath,
"--host", *host,
"--port", fmt.Sprintf("%d", *port),
"--username", *user,
"--dbname", *dbName,
)

// Attach stdio so it behaves like an interactive shell.
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

// Inherit existing env and add PG vars.
env := os.Environ()
env = append(env,
"PGPASSWORD="+authToken,
"PGSSLMODE="+*sslMode,
)

// If a search path is provided, wire it through PGOPTIONS.
if sp := strings.TrimSpace(*searchPath); sp != "" {
add := "-c search_path=" + sp

found := false
for i, e := range env {
if strings.HasPrefix(e, "PGOPTIONS=") {
current := strings.TrimPrefix(e, "PGOPTIONS=")
if strings.TrimSpace(current) == "" {
env[i] = "PGOPTIONS=" + add
} else {
env[i] = "PGOPTIONS=" + current + " " + add
}
found = true
break
}
}
if !found {
env = append(env, "PGOPTIONS="+add)
}
}

cmd.Env = env

// --- Ctrl-C handling ---
// The key idea: keep psql in the same foreground process group so it can read
// from the terminal. We intercept SIGINT only to prevent THIS wrapper from
// exiting; psql will still receive SIGINT normally and cancel the current
// query / line as expected.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(sigCh)

if err := cmd.Start(); err != nil {
log.Fatalf("failed to start psql: %v", err)
}

waitCh := make(chan error, 1)
go func() { waitCh <- cmd.Wait() }()

for {
select {
case sig := <-sigCh:
switch sig {
case os.Interrupt:
// Swallow SIGINT so this wrapper doesn't exit.
// psql still gets SIGINT (same terminal foreground process group).
continue
case syscall.SIGTERM:
// If we're being terminated, pass it through to psql and exit accordingly.
if cmd.Process != nil {
_ = cmd.Process.Signal(syscall.SIGTERM)
}
}
case err := <-waitCh:
// psql exited; now we exit with the same code.
if err == nil {
return
}
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
log.Fatalf("psql failed: %v", err)
}
}
}

func printCallerIdentity(ctx context.Context, cfg aws.Config) error {
stsClient := sts.NewFromConfig(cfg)

out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err)
}

fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn))
return nil
}