From 70188217a5e2fd032682b3ba8c3c17c25c41306f Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 30 Oct 2025 05:19:16 +0000 Subject: [PATCH 1/2] Add token federation examples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two comprehensive examples demonstrating token provider usage: 1. token_federation: Simple external token provider with federation 2. browser_oauth_federation: Full browser OAuth flow with automatic token exchange Both examples show real-world integration patterns for custom authentication. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- examples/browser_oauth_federation/main.go | 454 ++++++++++++++++++++++ examples/token_federation/main.go | 344 ++++++++++++++++ 2 files changed, 798 insertions(+) create mode 100644 examples/browser_oauth_federation/main.go create mode 100644 examples/token_federation/main.go diff --git a/examples/browser_oauth_federation/main.go b/examples/browser_oauth_federation/main.go new file mode 100644 index 00000000..760ba48c --- /dev/null +++ b/examples/browser_oauth_federation/main.go @@ -0,0 +1,454 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "strings" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + "github.com/databricks/databricks-sql-go/auth/oauth/u2m" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" + "github.com/joho/godotenv" +) + +func main() { + err := godotenv.Load() + if err != nil { + log.Printf("Warning: .env file not found: %v", err) + } + + fmt.Println("Browser OAuth with Token Federation Test") + fmt.Println("=========================================") + fmt.Println() + + // Get test mode from environment + testMode := os.Getenv("TEST_MODE") + if testMode == "" { + fmt.Println("TEST_MODE not set. Available modes:") + fmt.Println(" passthrough - Account-wide WIF Auth_Flow=0 (Token passthrough)") + fmt.Println(" u2m_federation - Account-wide WIF Auth_Flow=2 (U2M with federation)") + fmt.Println(" u2m_native - Native U2M without federation (baseline)") + fmt.Println(" external_token - Manual token passthrough (for testing exchange)") + os.Exit(1) + } + + switch testMode { + case "passthrough": + testTokenPassthrough() + case "u2m_federation": + testU2MWithFederation() + case "u2m_native": + testU2MNative() + case "external_token": + testExternalTokenWithFederation() + default: + log.Fatalf("Unknown test mode: %s", testMode) + } +} + +// testU2MNative tests native U2M OAuth without token federation (baseline) +func testU2MNative() { + fmt.Println("Test: Native U2M OAuth (Baseline - No Federation)") + fmt.Println("--------------------------------------------------") + fmt.Println("This uses Databricks' built-in OAuth without token exchange") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("HTTP Path: %s\n", httpPath) + fmt.Println() + + // Create U2M authenticator + authenticator, err := u2m.NewAuthenticator(host, 2*time.Minute) + if err != nil { + log.Fatal(err) + } + + // Create connector with native OAuth + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithAuthenticator(authenticator), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test connection + fmt.Println("Testing connection with browser OAuth...") + + // First try ping to see if we can establish connection + fmt.Println("Attempting to ping database...") + if err := db.Ping(); err != nil { + log.Fatalf("Ping failed: %v", err) + } + fmt.Println("✓ Ping successful") + + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Native U2M OAuth test PASSED") +} + +// testU2MWithFederation tests U2M OAuth with token federation (Account-wide WIF) +func testU2MWithFederation() { + fmt.Println("Test: U2M OAuth with Token Federation (Account-wide WIF)") + fmt.Println("---------------------------------------------------------") + fmt.Println("This tests Auth_Flow=2: Browser OAuth token → Federation exchange → Databricks token") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalIdpHost := os.Getenv("EXTERNAL_IDP_HOST") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + if externalIdpHost == "" { + fmt.Println("WARNING: EXTERNAL_IDP_HOST not set, using Databricks host") + externalIdpHost = host + } + + fmt.Printf("Databricks Host: %s\n", host) + fmt.Printf("HTTP Path: %s\n", httpPath) + fmt.Printf("External IdP Host: %s\n", externalIdpHost) + fmt.Println() + + // Step 1: Get token from external IdP using browser OAuth + fmt.Println("Step 1: Getting token from external IdP via browser OAuth...") + baseAuthenticator, err := u2m.NewAuthenticator(externalIdpHost, 2*time.Minute) + if err != nil { + log.Fatalf("Failed to create U2M authenticator: %v", err) + } + + // Wrap U2M authenticator as a token provider + u2mProvider := &U2MTokenProvider{authenticator: baseAuthenticator} + + // Step 2: Wrap with federation provider for automatic token exchange + fmt.Println("Step 2: Setting up federation provider for automatic token exchange...") + federationProvider := tokenprovider.NewFederationProvider(u2mProvider, host) + cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) + + // Create connector with federated authentication + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test connection + fmt.Println() + fmt.Println("Step 3: Testing connection (will trigger browser OAuth and token exchange)...") + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ U2M with Token Federation test PASSED") + fmt.Println() + fmt.Println("Token flow: Browser OAuth → External IdP Token → Token Exchange → Databricks Token") +} + +// testTokenPassthrough tests manual token passthrough with federation +func testTokenPassthrough() { + fmt.Println("Test: Token Passthrough with Federation (Auth_Flow=0)") + fmt.Println("------------------------------------------------------") + fmt.Println("This tests passing an external token that gets exchanged automatically") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalToken := os.Getenv("EXTERNAL_TOKEN") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + if externalToken == "" { + log.Fatal("EXTERNAL_TOKEN must be set (get from external IdP)") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) + fmt.Println() + + // Create static token provider + baseProvider := tokenprovider.NewStaticTokenProvider(externalToken) + + // Wrap with federation provider + federationProvider := tokenprovider.NewFederationProvider(baseProvider, host) + cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + fmt.Println("Testing connection with token passthrough...") + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Token Passthrough with Federation test PASSED") +} + +// testExternalTokenWithFederation tests manual token exchange process +func testExternalTokenWithFederation() { + fmt.Println("Test: Manual Token Exchange (for debugging)") + fmt.Println("-------------------------------------------") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalToken := os.Getenv("EXTERNAL_TOKEN") + + if host == "" || httpPath == "" || externalToken == "" { + log.Fatal("DATABRICKS_HOST, DATABRICKS_HTTPPATH, and EXTERNAL_TOKEN must be set") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) + fmt.Println() + + // Manual token exchange + fmt.Println("Step 1: Manually exchanging token...") + exchangedToken, err := manualTokenExchange(host, externalToken) + if err != nil { + log.Fatalf("Token exchange failed: %v", err) + } + + fmt.Printf("✓ Token exchange successful\n") + fmt.Printf(" Exchanged token length: %d chars\n", len(exchangedToken)) + fmt.Println() + + // Test connection with exchanged token + fmt.Println("Step 2: Testing connection with exchanged token...") + provider := tokenprovider.NewStaticTokenProvider(exchangedToken) + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Manual Token Exchange test PASSED") +} + +// Helper: Test database connection with queries +func testConnection(db *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Test 1: Simple query + var result int + err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result) + if err != nil { + return fmt.Errorf("simple query failed: %w", err) + } + fmt.Printf("✓ SELECT 1 returned: %d\n", result) + + // Test 2: Range query + rows, err := db.QueryContext(ctx, "SELECT * FROM RANGE(5)") + if err != nil { + return fmt.Errorf("range query failed: %w", err) + } + defer rows.Close() + + count := 0 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return fmt.Errorf("scan failed: %w", err) + } + count++ + } + fmt.Printf("✓ SELECT FROM RANGE(5) returned %d rows\n", count) + + // Test 3: Current user query + var username string + err = db.QueryRowContext(ctx, "SELECT CURRENT_USER()").Scan(&username) + if err != nil { + return fmt.Errorf("current user query failed: %w", err) + } + fmt.Printf("✓ Connected as user: %s\n", username) + + return nil +} + +// Helper: Manual token exchange (for debugging/testing) +func manualTokenExchange(databricksHost, subjectToken string) (string, error) { + exchangeURL := databricksHost + if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { + exchangeURL = "https://" + exchangeURL + } + if !strings.HasSuffix(exchangeURL, "/") { + exchangeURL += "/" + } + exchangeURL += "oidc/v1/token" + + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + data.Set("scope", "sql") + data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") + data.Set("subject_token", subjectToken) + + req, err := http.NewRequest("POST", exchangeURL, strings.NewReader(data.Encode())) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", err + } + + return tokenResp.AccessToken, nil +} + +// Helper: Get token issuer from JWT (for logging) +func getTokenIssuer(tokenString string) string { + parts := strings.Split(tokenString, ".") + if len(parts) < 2 { + return "not a JWT" + } + + // Decode payload (second part) + payload, err := decodeBase64(parts[1]) + if err != nil { + return "invalid JWT" + } + + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return "invalid JWT" + } + + if iss, ok := claims["iss"].(string); ok { + return iss + } + + return "unknown" +} + +func decodeBase64(s string) ([]byte, error) { + // Add padding if needed + switch len(s) % 4 { + case 2: + s += "==" + case 3: + s += "=" + } + return io.ReadAll(strings.NewReader(s)) +} + +// U2MTokenProvider wraps U2M authenticator as a TokenProvider +type U2MTokenProvider struct { + authenticator interface { + Authenticate(*http.Request) error + } +} + +func (p *U2MTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + // Create a dummy request to trigger authentication + req, err := http.NewRequestWithContext(ctx, "GET", "http://dummy", nil) + if err != nil { + return nil, err + } + + // Authenticate will add Authorization header + if err := p.authenticator.Authenticate(req); err != nil { + return nil, err + } + + // Extract token from Authorization header + authHeader := req.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("no authorization header set") + } + + // Parse "Bearer " + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid authorization header format") + } + + return &tokenprovider.Token{ + AccessToken: parts[1], + TokenType: parts[0], + }, nil +} + +func (p *U2MTokenProvider) Name() string { + return "u2m-browser-oauth" +} diff --git a/examples/token_federation/main.go b/examples/token_federation/main.go new file mode 100644 index 00000000..e2deeaff --- /dev/null +++ b/examples/token_federation/main.go @@ -0,0 +1,344 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" +) + +func main() { + // Get configuration from environment + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) + if err != nil { + port = 443 + } + + fmt.Println("Token Federation Examples") + fmt.Println("=========================") + + // Choose which example to run based on environment variable + example := os.Getenv("TOKEN_EXAMPLE") + if example == "" { + example = "static" + } + + switch example { + case "static": + runStaticTokenExample(host, httpPath, port) + case "external": + runExternalTokenExample(host, httpPath, port) + case "cached": + runCachedTokenExample(host, httpPath, port) + case "custom": + runCustomProviderExample(host, httpPath, port) + case "oauth": + runOAuthServiceExample(host, httpPath, port) + default: + log.Fatalf("Unknown example: %s", example) + } +} + +// Example 1: Static token (simplest case) +func runStaticTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 1: Static Token Provider") + fmt.Println("---------------------------------") + + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + log.Fatal("DATABRICKS_ACCESS_TOKEN not set") + } + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithStaticToken(token), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test the connection + var result int + err = db.QueryRow("SELECT 1").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected successfully using static token\n") + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 2: External token provider (token passthrough) +func runExternalTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 2: External Token Provider (Passthrough)") + fmt.Println("------------------------------------------------") + + // Simulate getting token from external source + tokenFunc := func() (string, error) { + // In real scenario, this could: + // - Read from a file + // - Call another service + // - Retrieve from a secret manager + // - Get from environment variable + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + return "", fmt.Errorf("no token available") + } + fmt.Println(" → Fetching token from external source...") + return token, nil + } + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithExternalToken(tokenFunc), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test the connection + var result int + err = db.QueryRow("SELECT 2").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected successfully using external token provider\n") + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 3: Cached token provider +func runCachedTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 3: Cached Token Provider") + fmt.Println("--------------------------------") + + callCount := 0 + // Create a token provider that tracks how many times it's called + baseProvider := tokenprovider.NewExternalTokenProvider(func() (string, error) { + callCount++ + fmt.Printf(" → Token provider called (count: %d)\n", callCount) + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + return "", fmt.Errorf("no token available") + } + return token, nil + }) + + // Wrap with caching + cachedProvider := tokenprovider.NewCachedTokenProvider(baseProvider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Run multiple queries - token should only be fetched once due to caching + for i := 1; i <= 3; i++ { + var result int + err = db.QueryRow(fmt.Sprintf("SELECT %d", i)).Scan(&result) + if err != nil { + log.Fatal(err) + } + fmt.Printf("✓ Query %d result: %d\n", i, result) + } + + fmt.Printf("✓ Token was fetched %d time(s) (should be 1 due to caching)\n", callCount) +} + +// Example 4: Custom token provider with expiry +func runCustomProviderExample(host, httpPath string, port int) { + fmt.Println("\nExample 4: Custom Token Provider with Expiry") + fmt.Println("--------------------------------------------") + + // Custom provider that simulates token with expiry + provider := &CustomExpiringTokenProvider{ + baseToken: os.Getenv("DATABRICKS_ACCESS_TOKEN"), + expiry: 1 * time.Hour, + } + + // Wrap with caching to handle refresh + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var result int + err = db.QueryRow("SELECT 42").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected with custom provider\n") + fmt.Printf("✓ Token expires at: %s\n", provider.lastToken.ExpiresAt.Format(time.RFC3339)) + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 5: OAuth service token provider +func runOAuthServiceExample(host, httpPath string, port int) { + fmt.Println("\nExample 5: OAuth Service Token Provider") + fmt.Println("---------------------------------------") + + oauthEndpoint := os.Getenv("OAUTH_TOKEN_ENDPOINT") + clientID := os.Getenv("OAUTH_CLIENT_ID") + clientSecret := os.Getenv("OAUTH_CLIENT_SECRET") + + if oauthEndpoint == "" || clientID == "" || clientSecret == "" { + fmt.Println("⚠ Skipping OAuth example (OAUTH_TOKEN_ENDPOINT, OAUTH_CLIENT_ID, or OAUTH_CLIENT_SECRET not set)") + return + } + + provider := &OAuthServiceTokenProvider{ + endpoint: oauthEndpoint, + clientID: clientID, + clientSecret: clientSecret, + } + + // Wrap with caching for efficiency + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var result string + err = db.QueryRow("SELECT 'OAuth Success'").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected with OAuth service token\n") + fmt.Printf("✓ Test query result: %s\n", result) +} + +// CustomExpiringTokenProvider simulates a provider with token expiry +type CustomExpiringTokenProvider struct { + baseToken string + expiry time.Duration + lastToken *tokenprovider.Token +} + +func (p *CustomExpiringTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + if p.baseToken == "" { + return nil, fmt.Errorf("no base token configured") + } + + fmt.Println(" → Generating new token with expiry...") + p.lastToken = &tokenprovider.Token{ + AccessToken: p.baseToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(p.expiry), + } + + return p.lastToken, nil +} + +func (p *CustomExpiringTokenProvider) Name() string { + return "custom-expiring" +} + +// OAuthServiceTokenProvider gets tokens from an OAuth service +type OAuthServiceTokenProvider struct { + endpoint string + clientID string + clientSecret string +} + +func (p *OAuthServiceTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + fmt.Printf(" → Fetching token from OAuth service: %s\n", p.endpoint) + + // Create OAuth request + req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, nil) + if err != nil { + return nil, err + } + + req.SetBasicAuth(p.clientID, p.clientSecret) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Make request + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OAuth service returned %d: %s", resp.StatusCode, body) + } + + // Parse response + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, err + } + + token := &tokenprovider.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + } + + if tokenResp.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + return token, nil +} + +func (p *OAuthServiceTokenProvider) Name() string { + return "oauth-service" +} From d3f0fb919dc112a8e36be18ed45bf650ec9a686e Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 28 Jan 2026 10:39:17 +0000 Subject: [PATCH 2/2] simplify --- examples/browser_oauth_federation/main.go | 435 ++-------------------- examples/token_federation/main.go | 334 +++++------------ 2 files changed, 115 insertions(+), 654 deletions(-) diff --git a/examples/browser_oauth_federation/main.go b/examples/browser_oauth_federation/main.go index 760ba48c..b19302ad 100644 --- a/examples/browser_oauth_federation/main.go +++ b/examples/browser_oauth_federation/main.go @@ -1,454 +1,83 @@ +// Example: Browser OAuth (U2M) Authentication +// +// This example demonstrates User-to-Machine (U2M) OAuth authentication, +// which opens a browser for the user to log in interactively. +// +// Environment variables: +// - DATABRICKS_HOST: Databricks workspace hostname +// - DATABRICKS_HTTPPATH: SQL warehouse HTTP path package main import ( "context" "database/sql" - "encoding/json" "fmt" - "io" "log" - "net/http" - "net/url" "os" - "strings" "time" dbsql "github.com/databricks/databricks-sql-go" "github.com/databricks/databricks-sql-go/auth/oauth/u2m" - "github.com/databricks/databricks-sql-go/auth/tokenprovider" "github.com/joho/godotenv" ) func main() { - err := godotenv.Load() - if err != nil { - log.Printf("Warning: .env file not found: %v", err) - } - - fmt.Println("Browser OAuth with Token Federation Test") - fmt.Println("=========================================") - fmt.Println() - - // Get test mode from environment - testMode := os.Getenv("TEST_MODE") - if testMode == "" { - fmt.Println("TEST_MODE not set. Available modes:") - fmt.Println(" passthrough - Account-wide WIF Auth_Flow=0 (Token passthrough)") - fmt.Println(" u2m_federation - Account-wide WIF Auth_Flow=2 (U2M with federation)") - fmt.Println(" u2m_native - Native U2M without federation (baseline)") - fmt.Println(" external_token - Manual token passthrough (for testing exchange)") - os.Exit(1) - } - - switch testMode { - case "passthrough": - testTokenPassthrough() - case "u2m_federation": - testU2MWithFederation() - case "u2m_native": - testU2MNative() - case "external_token": - testExternalTokenWithFederation() - default: - log.Fatalf("Unknown test mode: %s", testMode) + // Load .env file if present + if err := godotenv.Load(); err != nil { + log.Printf("Note: .env file not found") } -} - -// testU2MNative tests native U2M OAuth without token federation (baseline) -func testU2MNative() { - fmt.Println("Test: Native U2M OAuth (Baseline - No Federation)") - fmt.Println("--------------------------------------------------") - fmt.Println("This uses Databricks' built-in OAuth without token exchange") - fmt.Println() host := os.Getenv("DATABRICKS_HOST") httpPath := os.Getenv("DATABRICKS_HTTPPATH") if host == "" || httpPath == "" { - log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + log.Fatal("Required: DATABRICKS_HOST and DATABRICKS_HTTPPATH") } + fmt.Println("Browser OAuth (U2M) Example") + fmt.Println("===========================") fmt.Printf("Host: %s\n", host) - fmt.Printf("HTTP Path: %s\n", httpPath) - fmt.Println() + fmt.Printf("Path: %s\n\n", httpPath) - // Create U2M authenticator + // Create U2M authenticator - this will open a browser for login authenticator, err := u2m.NewAuthenticator(host, 2*time.Minute) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to create authenticator: %v", err) } - // Create connector with native OAuth + // Create database connector connector, err := dbsql.NewConnector( dbsql.WithServerHostname(host), dbsql.WithHTTPPath(httpPath), dbsql.WithAuthenticator(authenticator), ) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to create connector: %v", err) } db := sql.OpenDB(connector) defer db.Close() - // Test connection - fmt.Println("Testing connection with browser OAuth...") - - // First try ping to see if we can establish connection - fmt.Println("Attempting to ping database...") + // Test connection - this triggers browser OAuth flow + fmt.Println("Connecting (browser will open for login)...") if err := db.Ping(); err != nil { - log.Fatalf("Ping failed: %v", err) - } - fmt.Println("✓ Ping successful") - - if err := testConnection(db); err != nil { - log.Fatalf("Connection test failed: %v", err) - } - - fmt.Println() - fmt.Println("✓ Native U2M OAuth test PASSED") -} - -// testU2MWithFederation tests U2M OAuth with token federation (Account-wide WIF) -func testU2MWithFederation() { - fmt.Println("Test: U2M OAuth with Token Federation (Account-wide WIF)") - fmt.Println("---------------------------------------------------------") - fmt.Println("This tests Auth_Flow=2: Browser OAuth token → Federation exchange → Databricks token") - fmt.Println() - - host := os.Getenv("DATABRICKS_HOST") - httpPath := os.Getenv("DATABRICKS_HTTPPATH") - externalIdpHost := os.Getenv("EXTERNAL_IDP_HOST") - - if host == "" || httpPath == "" { - log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") - } - - if externalIdpHost == "" { - fmt.Println("WARNING: EXTERNAL_IDP_HOST not set, using Databricks host") - externalIdpHost = host - } - - fmt.Printf("Databricks Host: %s\n", host) - fmt.Printf("HTTP Path: %s\n", httpPath) - fmt.Printf("External IdP Host: %s\n", externalIdpHost) - fmt.Println() - - // Step 1: Get token from external IdP using browser OAuth - fmt.Println("Step 1: Getting token from external IdP via browser OAuth...") - baseAuthenticator, err := u2m.NewAuthenticator(externalIdpHost, 2*time.Minute) - if err != nil { - log.Fatalf("Failed to create U2M authenticator: %v", err) - } - - // Wrap U2M authenticator as a token provider - u2mProvider := &U2MTokenProvider{authenticator: baseAuthenticator} - - // Step 2: Wrap with federation provider for automatic token exchange - fmt.Println("Step 2: Setting up federation provider for automatic token exchange...") - federationProvider := tokenprovider.NewFederationProvider(u2mProvider, host) - cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) - - // Create connector with federated authentication - connector, err := dbsql.NewConnector( - dbsql.WithServerHostname(host), - dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), - ) - if err != nil { - log.Fatal(err) - } - - db := sql.OpenDB(connector) - defer db.Close() - - // Test connection - fmt.Println() - fmt.Println("Step 3: Testing connection (will trigger browser OAuth and token exchange)...") - if err := testConnection(db); err != nil { - log.Fatalf("Connection test failed: %v", err) - } - - fmt.Println() - fmt.Println("✓ U2M with Token Federation test PASSED") - fmt.Println() - fmt.Println("Token flow: Browser OAuth → External IdP Token → Token Exchange → Databricks Token") -} - -// testTokenPassthrough tests manual token passthrough with federation -func testTokenPassthrough() { - fmt.Println("Test: Token Passthrough with Federation (Auth_Flow=0)") - fmt.Println("------------------------------------------------------") - fmt.Println("This tests passing an external token that gets exchanged automatically") - fmt.Println() - - host := os.Getenv("DATABRICKS_HOST") - httpPath := os.Getenv("DATABRICKS_HTTPPATH") - externalToken := os.Getenv("EXTERNAL_TOKEN") - - if host == "" || httpPath == "" { - log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") - } - - if externalToken == "" { - log.Fatal("EXTERNAL_TOKEN must be set (get from external IdP)") - } - - fmt.Printf("Host: %s\n", host) - fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) - fmt.Println() - - // Create static token provider - baseProvider := tokenprovider.NewStaticTokenProvider(externalToken) - - // Wrap with federation provider - federationProvider := tokenprovider.NewFederationProvider(baseProvider, host) - cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) - - connector, err := dbsql.NewConnector( - dbsql.WithServerHostname(host), - dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), - ) - if err != nil { - log.Fatal(err) - } - - db := sql.OpenDB(connector) - defer db.Close() - - fmt.Println("Testing connection with token passthrough...") - if err := testConnection(db); err != nil { - log.Fatalf("Connection test failed: %v", err) - } - - fmt.Println() - fmt.Println("✓ Token Passthrough with Federation test PASSED") -} - -// testExternalTokenWithFederation tests manual token exchange process -func testExternalTokenWithFederation() { - fmt.Println("Test: Manual Token Exchange (for debugging)") - fmt.Println("-------------------------------------------") - fmt.Println() - - host := os.Getenv("DATABRICKS_HOST") - httpPath := os.Getenv("DATABRICKS_HTTPPATH") - externalToken := os.Getenv("EXTERNAL_TOKEN") - - if host == "" || httpPath == "" || externalToken == "" { - log.Fatal("DATABRICKS_HOST, DATABRICKS_HTTPPATH, and EXTERNAL_TOKEN must be set") - } - - fmt.Printf("Host: %s\n", host) - fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) - fmt.Println() - - // Manual token exchange - fmt.Println("Step 1: Manually exchanging token...") - exchangedToken, err := manualTokenExchange(host, externalToken) - if err != nil { - log.Fatalf("Token exchange failed: %v", err) - } - - fmt.Printf("✓ Token exchange successful\n") - fmt.Printf(" Exchanged token length: %d chars\n", len(exchangedToken)) - fmt.Println() - - // Test connection with exchanged token - fmt.Println("Step 2: Testing connection with exchanged token...") - provider := tokenprovider.NewStaticTokenProvider(exchangedToken) - cachedProvider := tokenprovider.NewCachedTokenProvider(provider) - - connector, err := dbsql.NewConnector( - dbsql.WithServerHostname(host), - dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), - ) - if err != nil { - log.Fatal(err) - } - - db := sql.OpenDB(connector) - defer db.Close() - - if err := testConnection(db); err != nil { - log.Fatalf("Connection test failed: %v", err) + log.Fatalf("Connection failed: %v", err) } + fmt.Println("✓ Connected successfully") - fmt.Println() - fmt.Println("✓ Manual Token Exchange test PASSED") -} - -// Helper: Test database connection with queries -func testConnection(db *sql.DB) error { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + // Run test queries + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // Test 1: Simple query var result int - err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result) - if err != nil { - return fmt.Errorf("simple query failed: %w", err) - } - fmt.Printf("✓ SELECT 1 returned: %d\n", result) - - // Test 2: Range query - rows, err := db.QueryContext(ctx, "SELECT * FROM RANGE(5)") - if err != nil { - return fmt.Errorf("range query failed: %w", err) - } - defer rows.Close() - - count := 0 - for rows.Next() { - var id int64 - if err := rows.Scan(&id); err != nil { - return fmt.Errorf("scan failed: %w", err) - } - count++ - } - fmt.Printf("✓ SELECT FROM RANGE(5) returned %d rows\n", count) - - // Test 3: Current user query - var username string - err = db.QueryRowContext(ctx, "SELECT CURRENT_USER()").Scan(&username) - if err != nil { - return fmt.Errorf("current user query failed: %w", err) - } - fmt.Printf("✓ Connected as user: %s\n", username) - - return nil -} - -// Helper: Manual token exchange (for debugging/testing) -func manualTokenExchange(databricksHost, subjectToken string) (string, error) { - exchangeURL := databricksHost - if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { - exchangeURL = "https://" + exchangeURL - } - if !strings.HasSuffix(exchangeURL, "/") { - exchangeURL += "/" - } - exchangeURL += "oidc/v1/token" - - data := url.Values{} - data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") - data.Set("scope", "sql") - data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") - data.Set("subject_token", subjectToken) - - req, err := http.NewRequest("POST", exchangeURL, strings.NewReader(data.Encode())) - if err != nil { - return "", err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "*/*") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body)) - } - - var tokenResp struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` - } - - if err := json.Unmarshal(body, &tokenResp); err != nil { - return "", err + if err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result); err != nil { + log.Fatalf("Query failed: %v", err) } + fmt.Printf("✓ SELECT 1 = %d\n", result) - return tokenResp.AccessToken, nil -} - -// Helper: Get token issuer from JWT (for logging) -func getTokenIssuer(tokenString string) string { - parts := strings.Split(tokenString, ".") - if len(parts) < 2 { - return "not a JWT" - } - - // Decode payload (second part) - payload, err := decodeBase64(parts[1]) - if err != nil { - return "invalid JWT" + var user string + if err := db.QueryRowContext(ctx, "SELECT CURRENT_USER()").Scan(&user); err != nil { + log.Fatalf("Query failed: %v", err) } - - var claims map[string]interface{} - if err := json.Unmarshal(payload, &claims); err != nil { - return "invalid JWT" - } - - if iss, ok := claims["iss"].(string); ok { - return iss - } - - return "unknown" -} - -func decodeBase64(s string) ([]byte, error) { - // Add padding if needed - switch len(s) % 4 { - case 2: - s += "==" - case 3: - s += "=" - } - return io.ReadAll(strings.NewReader(s)) -} - -// U2MTokenProvider wraps U2M authenticator as a TokenProvider -type U2MTokenProvider struct { - authenticator interface { - Authenticate(*http.Request) error - } -} - -func (p *U2MTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { - // Create a dummy request to trigger authentication - req, err := http.NewRequestWithContext(ctx, "GET", "http://dummy", nil) - if err != nil { - return nil, err - } - - // Authenticate will add Authorization header - if err := p.authenticator.Authenticate(req); err != nil { - return nil, err - } - - // Extract token from Authorization header - authHeader := req.Header.Get("Authorization") - if authHeader == "" { - return nil, fmt.Errorf("no authorization header set") - } - - // Parse "Bearer " - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid authorization header format") - } - - return &tokenprovider.Token{ - AccessToken: parts[1], - TokenType: parts[0], - }, nil -} - -func (p *U2MTokenProvider) Name() string { - return "u2m-browser-oauth" + fmt.Printf("✓ Logged in as: %s\n", user) } diff --git a/examples/token_federation/main.go b/examples/token_federation/main.go index e2deeaff..85f675e0 100644 --- a/examples/token_federation/main.go +++ b/examples/token_federation/main.go @@ -1,34 +1,36 @@ +// Example: Token Provider Authentication +// +// This example demonstrates different ways to provide authentication tokens: +// 1. Static token - hardcoded/env token +// 2. External token - dynamic token from a function +// 3. Custom provider - full TokenProvider implementation +// +// Environment variables: +// - DATABRICKS_HOST: Databricks workspace hostname +// - DATABRICKS_HTTPPATH: SQL warehouse HTTP path +// - DATABRICKS_ACCESS_TOKEN: Access token for authentication +// - TOKEN_EXAMPLE: Which example to run (static, external, custom) package main import ( "context" "database/sql" - "encoding/json" + "database/sql/driver" "fmt" - "io" "log" - "net/http" "os" - "strconv" "time" dbsql "github.com/databricks/databricks-sql-go" "github.com/databricks/databricks-sql-go/auth/tokenprovider" + "github.com/joho/godotenv" ) func main() { - // Get configuration from environment - host := os.Getenv("DATABRICKS_HOST") - httpPath := os.Getenv("DATABRICKS_HTTPPATH") - port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) - if err != nil { - port = 443 + if err := godotenv.Load(); err != nil { + log.Printf("Note: .env file not found") } - fmt.Println("Token Federation Examples") - fmt.Println("=========================") - - // Choose which example to run based on environment variable example := os.Getenv("TOKEN_EXAMPLE") if example == "" { example = "static" @@ -36,33 +38,31 @@ func main() { switch example { case "static": - runStaticTokenExample(host, httpPath, port) + runStaticExample() case "external": - runExternalTokenExample(host, httpPath, port) - case "cached": - runCachedTokenExample(host, httpPath, port) + runExternalExample() case "custom": - runCustomProviderExample(host, httpPath, port) - case "oauth": - runOAuthServiceExample(host, httpPath, port) + runCustomProviderExample() default: - log.Fatalf("Unknown example: %s", example) + log.Fatalf("Unknown example: %s (use: static, external, custom)", example) } } -// Example 1: Static token (simplest case) -func runStaticTokenExample(host, httpPath string, port int) { - fmt.Println("\nExample 1: Static Token Provider") - fmt.Println("---------------------------------") +// runStaticExample uses a static token from environment variable +func runStaticExample() { + fmt.Println("Static Token Example") + fmt.Println("====================") + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") token := os.Getenv("DATABRICKS_ACCESS_TOKEN") - if token == "" { - log.Fatal("DATABRICKS_ACCESS_TOKEN not set") + + if host == "" || httpPath == "" || token == "" { + log.Fatal("Required: DATABRICKS_HOST, DATABRICKS_HTTPPATH, DATABRICKS_ACCESS_TOKEN") } connector, err := dbsql.NewConnector( dbsql.WithServerHostname(host), - dbsql.WithPort(port), dbsql.WithHTTPPath(httpPath), dbsql.WithStaticToken(token), ) @@ -70,43 +70,34 @@ func runStaticTokenExample(host, httpPath string, port int) { log.Fatal(err) } - db := sql.OpenDB(connector) - defer db.Close() + testConnection(connector) +} - // Test the connection - var result int - err = db.QueryRow("SELECT 1").Scan(&result) - if err != nil { - log.Fatal(err) - } +// runExternalExample uses a function that returns tokens on-demand +func runExternalExample() { + fmt.Println("External Token Example") + fmt.Println("======================") - fmt.Printf("✓ Connected successfully using static token\n") - fmt.Printf("✓ Test query result: %d\n", result) -} + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") -// Example 2: External token provider (token passthrough) -func runExternalTokenExample(host, httpPath string, port int) { - fmt.Println("\nExample 2: External Token Provider (Passthrough)") - fmt.Println("------------------------------------------------") + if host == "" || httpPath == "" { + log.Fatal("Required: DATABRICKS_HOST, DATABRICKS_HTTPPATH") + } - // Simulate getting token from external source + // Token function - called each time a token is needed + // In practice, this could read from a file, call an API, etc. tokenFunc := func() (string, error) { - // In real scenario, this could: - // - Read from a file - // - Call another service - // - Retrieve from a secret manager - // - Get from environment variable token := os.Getenv("DATABRICKS_ACCESS_TOKEN") if token == "" { - return "", fmt.Errorf("no token available") + return "", fmt.Errorf("DATABRICKS_ACCESS_TOKEN not set") } - fmt.Println(" → Fetching token from external source...") + fmt.Println(" → Token fetched from external source") return token, nil } connector, err := dbsql.NewConnector( dbsql.WithServerHostname(host), - dbsql.WithPort(port), dbsql.WithHTTPPath(httpPath), dbsql.WithExternalToken(tokenFunc), ) @@ -114,231 +105,72 @@ func runExternalTokenExample(host, httpPath string, port int) { log.Fatal(err) } - db := sql.OpenDB(connector) - defer db.Close() - - // Test the connection - var result int - err = db.QueryRow("SELECT 2").Scan(&result) - if err != nil { - log.Fatal(err) - } - - fmt.Printf("✓ Connected successfully using external token provider\n") - fmt.Printf("✓ Test query result: %d\n", result) + testConnection(connector) } -// Example 3: Cached token provider -func runCachedTokenExample(host, httpPath string, port int) { - fmt.Println("\nExample 3: Cached Token Provider") - fmt.Println("--------------------------------") +// runCustomProviderExample uses a custom TokenProvider implementation +func runCustomProviderExample() { + fmt.Println("Custom Provider Example") + fmt.Println("=======================") - callCount := 0 - // Create a token provider that tracks how many times it's called - baseProvider := tokenprovider.NewExternalTokenProvider(func() (string, error) { - callCount++ - fmt.Printf(" → Token provider called (count: %d)\n", callCount) - token := os.Getenv("DATABRICKS_ACCESS_TOKEN") - if token == "" { - return "", fmt.Errorf("no token available") - } - return token, nil - }) - - // Wrap with caching - cachedProvider := tokenprovider.NewCachedTokenProvider(baseProvider) - - connector, err := dbsql.NewConnector( - dbsql.WithServerHostname(host), - dbsql.WithPort(port), - dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), - ) - if err != nil { - log.Fatal(err) - } - - db := sql.OpenDB(connector) - defer db.Close() + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") - // Run multiple queries - token should only be fetched once due to caching - for i := 1; i <= 3; i++ { - var result int - err = db.QueryRow(fmt.Sprintf("SELECT %d", i)).Scan(&result) - if err != nil { - log.Fatal(err) - } - fmt.Printf("✓ Query %d result: %d\n", i, result) + if host == "" || httpPath == "" || token == "" { + log.Fatal("Required: DATABRICKS_HOST, DATABRICKS_HTTPPATH, DATABRICKS_ACCESS_TOKEN") } - fmt.Printf("✓ Token was fetched %d time(s) (should be 1 due to caching)\n", callCount) -} - -// Example 4: Custom token provider with expiry -func runCustomProviderExample(host, httpPath string, port int) { - fmt.Println("\nExample 4: Custom Token Provider with Expiry") - fmt.Println("--------------------------------------------") - - // Custom provider that simulates token with expiry - provider := &CustomExpiringTokenProvider{ - baseToken: os.Getenv("DATABRICKS_ACCESS_TOKEN"), - expiry: 1 * time.Hour, + // Custom provider with expiry tracking + provider := &ExpiringTokenProvider{ + token: token, + lifetime: 1 * time.Hour, } - // Wrap with caching to handle refresh - cachedProvider := tokenprovider.NewCachedTokenProvider(provider) - connector, err := dbsql.NewConnector( dbsql.WithServerHostname(host), - dbsql.WithPort(port), dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), + dbsql.WithTokenProvider(provider), ) if err != nil { log.Fatal(err) } - db := sql.OpenDB(connector) - defer db.Close() - - var result int - err = db.QueryRow("SELECT 42").Scan(&result) - if err != nil { - log.Fatal(err) - } - - fmt.Printf("✓ Connected with custom provider\n") - fmt.Printf("✓ Token expires at: %s\n", provider.lastToken.ExpiresAt.Format(time.RFC3339)) - fmt.Printf("✓ Test query result: %d\n", result) + testConnection(connector) + fmt.Printf(" Token expires: %s\n", provider.expiresAt.Format(time.RFC3339)) } -// Example 5: OAuth service token provider -func runOAuthServiceExample(host, httpPath string, port int) { - fmt.Println("\nExample 5: OAuth Service Token Provider") - fmt.Println("---------------------------------------") - - oauthEndpoint := os.Getenv("OAUTH_TOKEN_ENDPOINT") - clientID := os.Getenv("OAUTH_CLIENT_ID") - clientSecret := os.Getenv("OAUTH_CLIENT_SECRET") - - if oauthEndpoint == "" || clientID == "" || clientSecret == "" { - fmt.Println("⚠ Skipping OAuth example (OAUTH_TOKEN_ENDPOINT, OAUTH_CLIENT_ID, or OAUTH_CLIENT_SECRET not set)") - return - } - - provider := &OAuthServiceTokenProvider{ - endpoint: oauthEndpoint, - clientID: clientID, - clientSecret: clientSecret, - } - - // Wrap with caching for efficiency - cachedProvider := tokenprovider.NewCachedTokenProvider(provider) - - connector, err := dbsql.NewConnector( - dbsql.WithServerHostname(host), - dbsql.WithPort(port), - dbsql.WithHTTPPath(httpPath), - dbsql.WithTokenProvider(cachedProvider), - ) - if err != nil { - log.Fatal(err) - } - +// testConnection verifies the connection works +func testConnection(connector driver.Connector) { db := sql.OpenDB(connector) defer db.Close() - var result string - err = db.QueryRow("SELECT 'OAuth Success'").Scan(&result) - if err != nil { - log.Fatal(err) - } - - fmt.Printf("✓ Connected with OAuth service token\n") - fmt.Printf("✓ Test query result: %s\n", result) -} - -// CustomExpiringTokenProvider simulates a provider with token expiry -type CustomExpiringTokenProvider struct { - baseToken string - expiry time.Duration - lastToken *tokenprovider.Token -} + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() -func (p *CustomExpiringTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { - if p.baseToken == "" { - return nil, fmt.Errorf("no base token configured") - } - - fmt.Println(" → Generating new token with expiry...") - p.lastToken = &tokenprovider.Token{ - AccessToken: p.baseToken, - TokenType: "Bearer", - ExpiresAt: time.Now().Add(p.expiry), + var result int + if err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result); err != nil { + log.Fatalf("Query failed: %v", err) } - - return p.lastToken, nil -} - -func (p *CustomExpiringTokenProvider) Name() string { - return "custom-expiring" + fmt.Printf("✓ Connected, SELECT 1 = %d\n", result) } -// OAuthServiceTokenProvider gets tokens from an OAuth service -type OAuthServiceTokenProvider struct { - endpoint string - clientID string - clientSecret string +// ExpiringTokenProvider is a custom TokenProvider with expiry support +type ExpiringTokenProvider struct { + token string + lifetime time.Duration + expiresAt time.Time } -func (p *OAuthServiceTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { - fmt.Printf(" → Fetching token from OAuth service: %s\n", p.endpoint) - - // Create OAuth request - req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, nil) - if err != nil { - return nil, err - } - - req.SetBasicAuth(p.clientID, p.clientSecret) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - // Make request - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("OAuth service returned %d: %s", resp.StatusCode, body) - } - - // Parse response - var tokenResp struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - } - - if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return nil, err - } - - token := &tokenprovider.Token{ - AccessToken: tokenResp.AccessToken, - TokenType: tokenResp.TokenType, - } - - if tokenResp.ExpiresIn > 0 { - token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - } - - return token, nil +func (p *ExpiringTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + p.expiresAt = time.Now().Add(p.lifetime) + return &tokenprovider.Token{ + AccessToken: p.token, + TokenType: "Bearer", + ExpiresAt: p.expiresAt, + }, nil } -func (p *OAuthServiceTokenProvider) Name() string { - return "oauth-service" +func (p *ExpiringTokenProvider) Name() string { + return "expiring-token" }