From 82f4584d2ab15d25751854a1b773a5236cfd0d2d Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Mon, 12 Jan 2026 14:18:43 -0500 Subject: [PATCH 01/15] Extract OidcAuthenticator trait from Frontegg Authenticator We extract - authenticate - validate_access_token Two methods used to authenticate HTTP and pgwire sessions out of the Frontegg authenticator. The goal is we can reuse these methods for a generic OIDC authenticator, used for self managed SSO --- Cargo.lock | 16 +++++ Cargo.toml | 2 + src/authenticator-types/Cargo.toml | 23 ++++++ src/authenticator-types/src/lib.rs | 60 ++++++++++++++++ src/balancerd/Cargo.toml | 1 + src/balancerd/src/lib.rs | 1 + src/environmentd/Cargo.toml | 1 + src/environmentd/src/http.rs | 1 + src/frontegg-auth/Cargo.toml | 2 + src/frontegg-auth/src/auth.rs | 109 +++++++++++++++++------------ src/pgwire/Cargo.toml | 1 + src/pgwire/src/protocol.rs | 3 +- 12 files changed, 175 insertions(+), 45 deletions(-) create mode 100644 src/authenticator-types/Cargo.toml create mode 100644 src/authenticator-types/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 8442f8133837e..352e7dcc3d1bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5598,6 +5598,17 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "mz-authenticator-types" +version = "0.0.0" +dependencies = [ + "async-trait", + "mz-repr", + "tokio", + "uuid", + "workspace-hack", +] + [[package]] name = "mz-avro" version = "0.7.0" @@ -5686,6 +5697,7 @@ dependencies = [ "launchdarkly-server-sdk", "mz-alloc", "mz-alloc-default", + "mz-authenticator-types", "mz-build-info", "mz-dyncfg", "mz-dyncfg-file", @@ -6311,6 +6323,7 @@ dependencies = [ "mz-alloc-default", "mz-auth", "mz-authenticator", + "mz-authenticator-types", "mz-aws-secrets-controller", "mz-build-info", "mz-catalog", @@ -6559,6 +6572,7 @@ name = "mz-frontegg-auth" version = "0.0.0" dependencies = [ "anyhow", + "async-trait", "axum", "base64 0.22.1", "clap", @@ -6566,6 +6580,7 @@ dependencies = [ "futures", "jsonwebtoken", "lru 0.16.3", + "mz-authenticator-types", "mz-ore", "mz-repr", "prometheus", @@ -7338,6 +7353,7 @@ dependencies = [ "mz-adapter-types", "mz-auth", "mz-authenticator", + "mz-authenticator-types", "mz-frontegg-auth", "mz-ore", "mz-pgcopy", diff --git a/Cargo.toml b/Cargo.toml index 9b9ad955c7c75..ac084c6193e36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "src/audit-log", "src/auth", "src/authenticator", + "src/authenticator-types", "src/avro", "src/aws-secrets-controller", "src/aws-util", @@ -133,6 +134,7 @@ default-members = [ "src/audit-log", "src/auth", "src/authenticator", + "src/authenticator-types", "src/avro", "src/aws-secrets-controller", "src/aws-util", diff --git a/src/authenticator-types/Cargo.toml b/src/authenticator-types/Cargo.toml new file mode 100644 index 0000000000000..71430fdfd7617 --- /dev/null +++ b/src/authenticator-types/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "mz-authenticator-types" +description = "Shared types for Materialize authentication." +version = "0.0.0" +edition.workspace = true +rust-version.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +async-trait = "0.1.89" +mz-repr = { path = "../repr" } +tokio = { version = "1.48.0", features = ["macros"] } +uuid = { version = "1.19.0", features = ["serde"] } +workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } + +[package.metadata.cargo-udeps.ignore] +normal = ["workspace-hack"] + +[features] +default = ["workspace-hack"] diff --git a/src/authenticator-types/src/lib.rs b/src/authenticator-types/src/lib.rs new file mode 100644 index 0000000000000..74b88d70157bb --- /dev/null +++ b/src/authenticator-types/src/lib.rs @@ -0,0 +1,60 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! Shared types for Materialize authentication. + +use std::fmt::Debug; + +use async_trait::async_trait; + +/// A handle to an authentication session. +/// +/// An authentication session represents a duration of time during which a +/// user's authentication is known to be valid. +/// +/// [`OidcAuthSessionHandle::external_metadata_rx`] can be used to receive events if +/// the session's metadata is updated. +/// +/// [`OidcAuthSessionHandle::expired`] can be used to learn if the session has +/// failed to refresh the validity of the API key. +#[async_trait] +pub trait OidcAuthSessionHandle: Debug + Send { + /// Returns the name of the user that created the session. + fn user(&self) -> &str; + /// Completes when the authentication session has expired. + async fn expired(&mut self); +} + +#[async_trait] +pub trait OidcAuthenticator { + /// The error type for the authenticator. + type Error; + /// The authenticator's session handle type. + type SessionHandle: OidcAuthSessionHandle; + /// Claims that have been validated by [`OidcAuthenticator::validate_access_token`]. + type ValidatedClaims; + /// Establishes a new authentication session. + /// If successful, returns a [`OidcAuthenticator::SessionHandle`] to the authentication session. + /// Otherwise, returns [`OidcAuthenticator::Error`]. + async fn authenticate( + &self, + expected_user: &str, + password: &str, + ) -> Result; + + /// Validates an access token, returning the validated claims. + /// + /// If `expected_user` is provided, the token's user name is additionally + /// validated to match `expected_user`. + fn validate_access_token( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result; +} diff --git a/src/balancerd/Cargo.toml b/src/balancerd/Cargo.toml index e58daf43c3cf7..245a77bc63f61 100644 --- a/src/balancerd/Cargo.toml +++ b/src/balancerd/Cargo.toml @@ -27,6 +27,7 @@ jsonwebtoken = "9.3.1" launchdarkly-server-sdk = { version = "2.6.2", default-features = false } mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } +mz-authenticator-types = { path = "../authenticator-types" } mz-build-info = { path = "../build-info" } mz-dyncfg-launchdarkly = { path = "../dyncfg-launchdarkly" } mz-dyncfg-file= { path = "../dyncfg-file" } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 67c898939731b..65b9f4a1526da 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -39,6 +39,7 @@ use futures::stream::BoxStream; use hyper::StatusCode; use hyper_util::rt::TokioIo; use launchdarkly_server_sdk as ld; +use mz_authenticator_types::OidcAuthenticator; use mz_build_info::{BuildInfo, build_info}; use mz_dyncfg::ConfigSet; use mz_frontegg_auth::Authenticator as FronteggAuthentication; diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index 254f13646cf06..60c89f4e3f243 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -43,6 +43,7 @@ mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } mz-auth = { path = "../auth" } mz-authenticator = { path = "../authenticator" } +mz-authenticator-types = { path = "../authenticator-types" } mz-aws-secrets-controller = { path = "../aws-secrets-controller" } mz-build-info = { path = "../build-info" } mz-adapter = { path = "../adapter" } diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index e06c7fb87cb25..452a710647208 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -45,6 +45,7 @@ use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSes use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache}; use mz_auth::password::Password; use mz_authenticator::Authenticator; +use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_controller::ReplicaHttpLocator; use mz_frontegg_auth::Error as FronteggError; use mz_http_util::DynamicFilterTarget; diff --git a/src/frontegg-auth/Cargo.toml b/src/frontegg-auth/Cargo.toml index 644256f687bee..f8a0d89e30b9b 100644 --- a/src/frontegg-auth/Cargo.toml +++ b/src/frontegg-auth/Cargo.toml @@ -11,12 +11,14 @@ workspace = true [dependencies] anyhow = "1.0.100" +async-trait = "0.1.89" base64 = "0.22.1" clap = { version = "4.5.23", features = ["wrap_help", "env", "derive"] } derivative = "2.2.0" futures = "0.3.31" jsonwebtoken = "9.3.1" lru = "0.16.3" +mz-authenticator-types = { path = "../authenticator-types" } mz-ore = { path = "../ore", features = ["network", "metrics"] } mz-repr = { path = "../repr" } prometheus = { version = "0.14.0", default-features = false } diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index a80b1a01aa31c..f4c68f1baf86f 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -15,11 +15,13 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context as _; +use async_trait::async_trait; use derivative::Derivative; use futures::FutureExt; use futures::future::Shared; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use lru::LruCache; +use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_ore::instrument; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::NowFn; @@ -166,27 +168,27 @@ impl Authenticator { Ok(Some(Self::new(config, client, registry))) } - /// Establishes a new authentication session. - /// - /// If successful, returns a handle to the authentication session. - /// Otherwise, returns the authentication error. - pub async fn authenticate( - &self, - expected_user: &str, - password: &str, - ) -> Result { - let password: AppPassword = password.parse()?; - match self.authenticate_inner(expected_user, password).await { - Ok(handle) => { - tracing::debug!("authentication successful"); - Ok(handle) - } - Err(e) => { - tracing::debug!(error = ?e, "authentication failed"); - Err(e) - } - } - } + // /// Establishes a new authentication session. + // /// + // /// If successful, returns a handle to the authentication session. + // /// Otherwise, returns the authentication error. + // pub async fn authenticate( + // &self, + // expected_user: &str, + // password: &str, + // ) -> Result { + // let password: AppPassword = password.parse()?; + // match self.authenticate_inner(expected_user, password).await { + // Ok(handle) => { + // tracing::debug!("authentication successful"); + // Ok(handle) + // } + // Err(e) => { + // tracing::debug!(error = ?e, "authentication failed"); + // Err(e) + // } + // } + // } #[instrument(level = "debug", fields(client_id = %password.client_id))] async fn authenticate_inner( @@ -298,6 +300,34 @@ impl Authenticator { }; request.await } +} + +#[async_trait] +impl OidcAuthenticator for Authenticator { + type Error = Error; + type SessionHandle = AuthSessionHandle; + type ValidatedClaims = ValidatedClaims; + /// Establishes a new authentication session. + /// + /// If successful, returns a handle to the authentication session. + /// Otherwise, returns the authentication error. + async fn authenticate( + &self, + expected_user: &str, + password: &str, + ) -> Result { + let password: AppPassword = password.parse()?; + match self.authenticate_inner(expected_user, password).await { + Ok(handle) => { + tracing::debug!("authentication successful"); + Ok(handle) + } + Err(e) => { + tracing::debug!(error = ?e, "authentication failed"); + Err(e) + } + } + } /// Validates an access token, returning the validated claims. /// @@ -309,7 +339,7 @@ impl Authenticator { /// /// If `expected_user` is provided, the token's user name is additionally /// validated to match `expected_user`. - pub fn validate_access_token( + fn validate_access_token( &self, token: &str, expected_user: Option<&str>, @@ -317,23 +347,13 @@ impl Authenticator { self.inner.validate_access_token(token, expected_user) } } - /// A handle to an authentication session. /// -/// An authentication session represents a duration of time during which a -/// user's authentication is known to be valid. -/// /// An authentication session begins with a successful API key exchange with /// Frontegg. While there is at least one outstanding handle to the session, the /// session's metadata and validity are refreshed with Frontegg at a regular /// interval. The session ends when all outstanding handles are dropped and the /// refresh interval is reached. -/// -/// [`AuthSessionHandle::external_metadata_rx`] can be used to receive events if -/// the session's metadata is updated. -/// -/// [`AuthSessionHandle::expired`] can be used to learn if the session has -/// failed to refresh the validity of the API key. #[derive(Debug, Clone)] pub struct AuthSessionHandle { ident: Arc, @@ -344,29 +364,30 @@ pub struct AuthSessionHandle { app_password: AppPassword, } -impl AuthSessionHandle { - /// Returns the name of the user that created the session. - pub fn user(&self) -> &str { +#[async_trait] +impl OidcAuthSessionHandle for AuthSessionHandle { + fn user(&self) -> &str { &self.ident.user } + async fn expired(&mut self) { + // We piggyback on the external metadata channel to determine session + // expiration. The external metadata channel is closed when the session + // expires. + let _ = self.external_metadata_rx.wait_for(|_| false).await; + } +} + +impl AuthSessionHandle { /// Returns the ID of the tenant that created the session. pub fn tenant_id(&self) -> Uuid { self.ident.tenant_id } - /// Mints a receiver for updates to the session user's external metadata. + /// Returns a receiver for updates to the session user's external metadata. pub fn external_metadata_rx(&self) -> watch::Receiver { self.external_metadata_rx.clone() } - - /// Completes when the authentication session has expired. - pub async fn expired(&mut self) { - // We piggyback on the external metadata channel to determine session - // expiration. The external metadata channel is closed when the session - // expires. - let _ = self.external_metadata_rx.wait_for(|_| false).await; - } } impl Drop for AuthSessionHandle { diff --git a/src/pgwire/Cargo.toml b/src/pgwire/Cargo.toml index e66ffa2306664..7d0b5d503f914 100644 --- a/src/pgwire/Cargo.toml +++ b/src/pgwire/Cargo.toml @@ -23,6 +23,7 @@ mz-adapter = { path = "../adapter" } mz-adapter-types = { path = "../adapter-types" } mz-auth = { path = "../auth" } mz-authenticator = { path = "../authenticator" } +mz-authenticator-types = { path = "../authenticator-types" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-ore = { path = "../ore", features = ["tracing"] } mz-pgcopy = { path = "../pgcopy" } diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index c6a3d68301a99..c5a45507b0822 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -31,6 +31,7 @@ use mz_adapter::{ }; use mz_auth::password::Password; use mz_authenticator::Authenticator; +use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_ore::cast::CastFrom; use mz_ore::netio::AsyncReady; use mz_ore::now::{EpochMillis, SYSTEM_TIME}; @@ -229,7 +230,7 @@ where internal_user_metadata: None, helm_chart_version, }); - let expired = async move { auth_session.expired().await }; + let expired = async move { auth_session.expired().await }.boxed(); (session, expired.left_future()) } Err(err) => { From 424f43872ebafff938d5759332214022429753ab Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Wed, 14 Jan 2026 12:11:28 -0500 Subject: [PATCH 02/15] Extract password retrieval into own method --- src/pgwire/src/protocol.rs | 58 +++++++------------------------------- 1 file changed, 10 insertions(+), 48 deletions(-) diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index c5a45507b0822..09032c41d1c4f 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -184,30 +184,11 @@ where let (mut session, expired) = match authenticator { Authenticator::Frontegg(frontegg) => { - conn.send(BackendMessage::AuthenticationCleartextPassword) - .await?; - conn.flush().await?; - let password = match conn.recv().await? { - Some(FrontendMessage::RawAuthentication(data)) => { - match decode_password(Cursor::new(&data)).ok() { - Some(FrontendMessage::Password { password }) => password, - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; - } - } - } - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; + let password = match request_cleartext_password(conn).await { + Ok(password) => password, + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; } }; @@ -245,30 +226,11 @@ where } } Authenticator::Password(adapter_client) => { - conn.send(BackendMessage::AuthenticationCleartextPassword) - .await?; - conn.flush().await?; - let password = match conn.recv().await? { - Some(FrontendMessage::RawAuthentication(data)) => { - match decode_password(Cursor::new(&data)).ok() { - Some(FrontendMessage::Password { password }) => Password(password), - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; - } - } - } - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; + let password = match request_cleartext_password(conn).await { + Ok(password) => Password(password), + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; } }; let auth_response = match adapter_client.authenticate(&user, &password).await { From fdb2c9ae1b5f931481b67b70a2a8a208df21a757 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Thu, 15 Jan 2026 13:24:56 -0500 Subject: [PATCH 03/15] Implement OIDC prototype - Create an OIDC authenticator kind and a minimal set of config variables using CLI args - Implement JWK fetch on validate and also cache by the JWK key id. --- Cargo.lock | 7 + oidc_auth_setup.md | 83 +++++ src/authenticator-types/src/lib.rs | 2 +- src/authenticator/Cargo.toml | 7 + src/authenticator/src/lib.rs | 5 + src/authenticator/src/oidc.rs | 300 ++++++++++++++++++ src/environmentd/src/environmentd/main.rs | 11 + src/environmentd/src/http.rs | 27 +- src/environmentd/src/lib.rs | 19 +- src/environmentd/src/test_util.rs | 2 + src/frontegg-auth/src/auth.rs | 24 +- .../ci/listener_configs/oidc.json | 38 +++ src/pgwire/src/protocol.rs | 96 ++++++ src/server-core/src/listeners.rs | 2 + 14 files changed, 596 insertions(+), 27 deletions(-) create mode 100644 oidc_auth_setup.md create mode 100644 src/authenticator/src/oidc.rs create mode 100644 src/materialized/ci/listener_configs/oidc.json diff --git a/Cargo.lock b/Cargo.lock index 352e7dcc3d1bf..b8bd29375a8a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5592,9 +5592,16 @@ dependencies = [ name = "mz-authenticator" version = "0.1.0" dependencies = [ + "async-trait", + "jsonwebtoken", "mz-adapter", + "mz-authenticator-types", "mz-frontegg-auth", + "reqwest", "serde", + "tokio", + "tracing", + "url", "workspace-hack", ] diff --git a/oidc_auth_setup.md b/oidc_auth_setup.md new file mode 100644 index 0000000000000..9a9dcb3a0bebe --- /dev/null +++ b/oidc_auth_setup.md @@ -0,0 +1,83 @@ +## Setting + +PGOAUTHDEBUG=UNSAFE psql 'host=192.168.215.3 user=employees dbname=promo oauth_issuer=http://host.docker.internal:4444 oauth_client_id=1186624a-7bed-44f8-867e-d3938a29c924 oauth_client_secret=88OlbvOkiDaPSypu94qK_WHDjG' + + +code_client=$(docker compose -f quickstart.yml exec hydra \ + hydra create client \ + --endpoint http://127.0.0.1:4445 \ + --grant-type authorization_code,refresh_token \ + --response-type code,id_token \ + --format json \ + --scope openid --scope offline --scope profile --scope email\ + --access-token-strategy jwt \ + --redirect-uri http://127.0.0.1:5555/callback) + +code_client_id=$(echo $code_client | jq -r '.client_id') +code_client_secret=$(echo $code_client | jq -r '.client_secret') + +docker compose -f quickstart.yml exec hydra \ + hydra perform authorization-code \ + --client-id $code_client_id \ + --client-secret $code_client_secret \ + --endpoint http://127.0.0.1:4444/ \ + --port 5555 \ + --scope openid --scope offline --scope profile --scope email + +## Deleting a client +hydra delete oauth2-client --endpoint http://localhost:4445 b1a93de1-e4dd-4da9-8e81-083ec4e89f6e=$ + +client id: 060a4f3d-1cac-46e4-b5a5-6b9c66cd9431 +secret: wAghHCKR_E26yuLRpSkaoz2epq + + + + + +device + +device_client=$(docker compose -f quickstart.yml exec hydra \ + hydra create client \ + --endpoint http://127.0.0.1:4445 \ + --format json \ + --name "my device app" \ + --grant-type urn:ietf:params:oauth:grant-type:device_code,refresh_token \ + --token-endpoint-auth-method none \ + --access-token-strategy jwt \ + --scope openid,offline_access,profile) + +device_client_id=$(echo $device_client | jq -r '.client_id') +device_client_secret=$(echo $device_client | jq -r '.client_secret') + +echo $device_client_id +echo $device_client_secret + + +docker compose -f quickstart.yml exec hydra \ + hydra perform device-code \ + --client-id $device_client_id \ + --client-secret $device_client_secret \ + --endpoint http://127.0.0.1:4444/ \ + --scope openid,offline_access + + +Visit http://host.docker.internal:4444/oauth2/device/verify and enter the code: mpGRAMPk + + +http://localhost:4444/.well-known/jwks.json + + bin/environmentd -- \ + --oidc-issuer="http://127.0.0.1:4444" \ + --listeners-config-path='src/materialized/ci/listener_configs/oidc.json' + + +Access token: + +eyJhbGciOiJSUzI1NiIsImtpZCI6Ijk3ZTJmOTJhLWM2YjQtNDQ0ZC1hNjZhLWY3Y2YwOTIwNzdhMyIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJlMGZkYzJkZC05YTU1LTRlZjEtYWU2Zi05YjQyYTJhMzA3NWYiLCJleHAiOjE3NjgyNTgzODksImV4dCI6e30sImlhdCI6MTc2ODI1NDc4OCwiaXNzIjoiaHR0cDovLzEyNy4wLjAuMTo0NDQ0IiwianRpIjoiNTk1N2RiN2YtZGVjZi00OWE0LTliNTMtOTBiM2M1ZDhlMWI5IiwibmJmIjoxNzY4MjU0Nzg4LCJzY3AiOlsib3BlbmlkIiwib2ZmbGluZSIsInByb2ZpbGUiLCJlbWFpbCJdLCJzdWIiOiJmb29AYmFyLmNvbSJ9.w_Vype6NRh_gAIowtja24alhINfXOGRavwq9Nd3gu1tW5l1zxDza6X5iPhMmSTnnlDm1dbekAVQ8tldZs5XDfycDFsFfuMa-IvsoQ3GUyglGMv-hdVa8hDBLGLonVkn5fZDAiotnRzDKZo1qGJ7nDGkV1_oO7DE_BqlXC6OebqQKXdyzZI4xXrreQvQCF0JiW4Kz7F3FZrJeIMyBgMgwgt1spi6YFuER-08l0ZPotrQ20KGhTHy0k-zpyjPUZA8vm8AAiyePvgIHh4pAm_0k4gG_fcX6rw5Hv3UsNtDH42b2QQhGgqY_gvBTNCxCW_wHmHtrgFYiIH7N3NwQE36ZJLSAVL9xuVdaV9km1ZSHAnJ5TdXrtB1wEEsjwYFIrv0AwUv-mlUk0QS7E_8Wv_-BqwgbE4TjdcTIe2-S85N3i7w_LkJT5D1tIwSKlotXCfRV_nTvrWAwar9bLBdynBXYAhwpzASCub_L4qqwCvrWnOPYIHqb9EQFsIEqYaKv_Iz5BLUaMC4fgymSDpsb_kujlQNmR1R_EfMIZA2noFQ8HZ3JJfWckYgLLpJL5RhHDZoQIQOpZL2RvXE1Ud8roT1f2sRGqotNJ93PcBuISzVzJ5ov3ZM7VU3QjhQ4q4Z5q6BQIAFW_j8WoLbWZ9KlCbhVTzXu0usjeI4BHf0_HluIkNE + + Run ``` + PGPASSWORD="eyJhbGciOiJSUzI1NiIsImtpZCI6Ijk3ZTJmOTJhLWM2YjQtNDQ0ZC1hNjZhLWY3Y2YwOTIwNzdhMyIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiIyY2MzZTE0Ny1iZTMxLTQzNGYtOWNiYy02MzY3Y2Y0NDE3Y2MiLCJleHAiOjE3Njg5Mzk5MTYsImV4dCI6e30sImlhdCI6MTc2ODkzNjMxNiwiaXNzIjoiaHR0cDovLzEyNy4wLjAuMTo0NDQ0IiwianRpIjoiMjlkMjY4YjgtOTc1ZC00NzhkLTkzOTktNDNjMDk2MTVmZGJlIiwibmJmIjoxNzY4OTM2MzE2LCJzY3AiOlsib3BlbmlkIiwib2ZmbGluZSIsInByb2ZpbGUiLCJlbWFpbCJdLCJzdWIiOiJmb29AYmFyLmNvbSJ9.LlxflqxSf8l0EsgR0F1mVW2JVBYAsVBeElZcqUK6CD-_wBfvgQzCQlIVZNRNYnVQYZwrDzJvrv4-niysYxJ-PRgMAp826nuZIjgQz5qMDovmrqO6UFXSW6pA1rR4N1tQVAYmCAoS-O3PgntJHUE5vpX19D28YbyyS60Yo1u5KzXkQdqZPrkStUhcHXJP-4CfX43k9ginbn23XtKqp0NXzMbJRXk1wY4P6Si-9pqeqLiD7CtNCRXbFowFBePsr9cYoQJRV6ausBRTFE7mXxybU7NFKvuerGNUFI5u0LKzuRsvhw5iuJHPi2PLxrxWjqh3idKgCRbFGJR-Vk63El7Z4O-piwREShyNDfAU1_KREIPt5-zxCp-qe0JhCYEPVikICRT2NF0c29ZzvaxsGEOb8PZBXgPRncfDAsz-fXTOr2MXuGsxIZBgcRx6oDR2mnGZKIXrLqRiBDwik66M2LDE7x5FQZqiha0y2_h_PwNhnDdWqcRAb-l48FqtZCXDi5V4zyfYCw24sWhXVyqLi5WGozavVSxCvZmUP3Qd1OvE9j2n3JOjlecY-7G3ccV1Te_uYNcALyo2DRaiA1mO7XHhmh-9W1_DOZWljmYF9j6qJhMft38N-6fB4Wp8U7vKwdHVfBk5dHb8q95qaLefJkaXk79vA28Wmu6_LHn7EFC0U_M" psql -h localhost -p 6875 -U foo@bar.com materialize + ``` diff --git a/src/authenticator-types/src/lib.rs b/src/authenticator-types/src/lib.rs index 74b88d70157bb..30793a325be28 100644 --- a/src/authenticator-types/src/lib.rs +++ b/src/authenticator-types/src/lib.rs @@ -52,7 +52,7 @@ pub trait OidcAuthenticator { /// /// If `expected_user` is provided, the token's user name is additionally /// validated to match `expected_user`. - fn validate_access_token( + async fn validate_access_token( &self, token: &str, expected_user: Option<&str>, diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index 3c34cd66028cb..b7976b026196c 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -8,9 +8,16 @@ rust-version.workspace = true publish = false [dependencies] +async-trait = "0.1" +jsonwebtoken = "9.3.1" mz-adapter = { path = "../adapter", default-features = false } +mz-authenticator-types = { path = "../authenticator-types" } mz-frontegg-auth = { path = "../frontegg-auth", default-features = false } +reqwest = "0.12.24" serde = { version = "1.0.219", features = ["derive"] } +tokio = { version = "1.48.0", default-features = false, features = ["sync"] } +tracing = "0.1.43" +url = "2.5.7" workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } [lints] diff --git a/src/authenticator/src/lib.rs b/src/authenticator/src/lib.rs index 6e1527d13b742..e45ce4b20122b 100644 --- a/src/authenticator/src/lib.rs +++ b/src/authenticator/src/lib.rs @@ -7,13 +7,18 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +pub mod oidc; + use mz_adapter::Client as AdapterClient; use mz_frontegg_auth::Authenticator as FronteggAuthenticator; +pub use oidc::{GenericOidcAuthenticator, OidcConfig, OidcError}; + #[derive(Debug, Clone)] pub enum Authenticator { Frontegg(FronteggAuthenticator), Password(AdapterClient), Sasl(AdapterClient), + Oidc(GenericOidcAuthenticator), None, } diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs new file mode 100644 index 0000000000000..0451e2070c2d5 --- /dev/null +++ b/src/authenticator/src/oidc.rs @@ -0,0 +1,300 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! OIDC Authentication for pgwire connections. +//! +//! This module provides JWT-based authentication using OpenID Connect (OIDC). +//! JWTs are validated locally using JWKS fetched from the configured provider. + +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use async_trait::async_trait; +use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; +use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; +use reqwest::Client as HttpClient; +use serde::{Deserialize, Serialize}; + +use tracing::warn; + +/// Command line arguments for OIDC authentication. +#[derive(Debug, Clone)] +pub struct OidcConfig { + /// OIDC issuer URL (e.g., "https://accounts.google.com"). + /// This is validated against the `iss` claim in the JWT. + pub oidc_issuer: String, +} + +/// Errors that can occur during OIDC authentication. +#[derive(Debug)] +pub enum OidcError { + /// Failed to parse OIDC configuration URL. + InvalidConfigUrl(url::ParseError), + /// Failed to fetch JWKS from provider. + JwksFetchFailed(String), + /// The key ID is missing in the token header. + MissingKid, + /// No matching key found in JWKS. + NoMatchingKey, + /// JWT validation error from jsonwebtoken. + Jwt(jsonwebtoken::errors::Error), + /// User does not match expected value. + WrongUser, +} + +impl std::fmt::Display for OidcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OidcError::InvalidConfigUrl(e) => { + write!(f, "failed to parse OIDC configuration URL: {}", e) + } + OidcError::JwksFetchFailed(e) => write!(f, "failed to fetch JWKS: {}", e), + OidcError::MissingKid => write!(f, "missing key ID in token header"), + OidcError::NoMatchingKey => write!(f, "no matching key in JWKS"), + OidcError::Jwt(e) => write!(f, "JWT error: {}", e), + OidcError::WrongUser => write!(f, "user does not match expected value"), + } + } +} + +impl std::error::Error for OidcError {} + +/// Claims extracted from a validated JWT. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OidcClaims { + /// Subject (user identifier). + pub sub: String, + /// Issuer. + pub iss: String, + /// Expiration time (Unix timestamp). + pub exp: i64, + /// Issued at time (Unix timestamp). + #[serde(default)] + pub iat: Option, + /// Email claim (commonly used for username). + #[serde(default)] + pub email: Option, +} + +impl OidcClaims { + /// Extract the username to use for the session. + /// + /// Priority: email > sub + // TODO (SangJunBak): Add a configuration variable to use a different username field. + pub fn username(&self) -> &str { + self.email.as_deref().unwrap_or(&self.sub) + } +} + +#[derive(Clone)] +struct OidcDecodingKey(DecodingKey); + +impl std::fmt::Debug for OidcDecodingKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JWKS").field("key", &"").finish() + } +} + +/// Session handle for generic OIDC authentication. +#[derive(Debug)] +pub struct GenericOidcSessionHandle { + user: String, +} + +#[async_trait] +impl OidcAuthSessionHandle for GenericOidcSessionHandle { + fn user(&self) -> &str { + &self.user + } + + async fn expired(&mut self) { + // This session never expires - wait forever + // TODO (SangJunBak): Implement expiration. + std::future::pending().await + } +} + +/// OIDC Authenticator that validates JWTs using JWKS. +/// +/// This implementation pre-fetches JWKS at construction time for synchronous +/// token validation. +#[derive(Clone, Debug)] +pub struct GenericOidcAuthenticator { + inner: Arc, +} + +#[derive(Debug)] +pub struct GenericOidcAuthenticatorInner { + issuer: String, + jwks_uri: String, + decoding_keys: Mutex>, + http_client: HttpClient, +} + +impl GenericOidcAuthenticator { + /// Create a new [`GenericOidcAuthenticator`] from [`OidcConfig`]. + pub fn new(config: OidcConfig) -> Result { + let issuer_url = + url::Url::parse(&config.oidc_issuer).map_err(OidcError::InvalidConfigUrl)?; + + // TODO (SangJunBak): Add a configuration variable for the JWKS set and + // a boolean jwksFetchFromIssuer. + let jwks_uri = issuer_url + .join(".well-known/jwks.json") + .map_err(OidcError::InvalidConfigUrl)? + .to_string(); + + Ok(Self { + inner: Arc::new(GenericOidcAuthenticatorInner { + issuer: config.oidc_issuer, + jwks_uri, + decoding_keys: Mutex::new(BTreeMap::new()), + http_client: HttpClient::new(), + }), + }) + } +} + +impl GenericOidcAuthenticatorInner { + /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys. + async fn fetch_jwks(self: &Self) -> Result, OidcError> { + let response = self + .http_client + .get(&self.jwks_uri) + .timeout(Duration::from_secs(10)) + .send() + .await + .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(OidcError::JwksFetchFailed(format!( + "HTTP {}", + response.status() + ))); + } + + let jwks: JwkSet = response + .json() + .await + .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?; + + let mut keys = BTreeMap::new(); + for jwk in jwks.keys { + match DecodingKey::from_jwk(&jwk) { + Ok(key) => { + if let Some(kid) = jwk.common.key_id { + keys.insert(kid, OidcDecodingKey(key)); + } + } + Err(e) => { + warn!("Failed to parse JWK: {}", e); + } + } + } + + if keys.is_empty() { + return Err(OidcError::JwksFetchFailed( + "no valid keys found in JWKS".to_string(), + )); + } + + Ok(keys) + } + + /// Find a decoding key matching the given key ID. + /// If the key is not found, fetch the JWKS and cache the keys. + async fn find_key(&self, kid: &String) -> Result { + { + let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned"); + + if let Some(key) = decoding_keys.get_mut(kid) { + return Ok(key.clone()); + } + } + + let new_decoding_keys = self.fetch_jwks().await?; + + let decoding_key = new_decoding_keys.get(kid).cloned(); + + { + let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned"); + decoding_keys.extend(new_decoding_keys); + } + + if let Some(key) = decoding_key { + return Ok(key); + } + + Err(OidcError::NoMatchingKey) + } + + async fn validate_access_token( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result { + // Decode header to get key ID (kid) and algorithm + let header = decode_header(token).map_err(OidcError::Jwt)?; + + let kid = header.kid.ok_or(OidcError::MissingKid)?; + // Find matching key from cached keys + let decoding_key = self.find_key(&kid).await?; + + // Set up validation + // TODO (SangJunBak): Make JWT expiration configurable. + let mut validation = Validation::new(header.alg); + validation.set_issuer(&[&self.issuer]); + // TODO (SangJunBak): Validate audience based on configuration. + validation.validate_aud = false; + + // Decode and validate the token + let token_data = + decode::(token, &(decoding_key.0), &validation).map_err(OidcError::Jwt)?; + + // Optionally validate expected user + if let Some(expected) = expected_user { + if token_data.claims.username() != expected { + return Err(OidcError::WrongUser); + } + } + + Ok(token_data.claims) + } +} + +#[async_trait] +impl OidcAuthenticator for GenericOidcAuthenticator { + type Error = OidcError; + type SessionHandle = GenericOidcSessionHandle; + type ValidatedClaims = OidcClaims; + + async fn authenticate( + &self, + expected_user: &str, + password: &str, + ) -> Result { + // The password is the JWT token + let claims = self + .validate_access_token(password, Some(expected_user)) + .await?; + + Ok(GenericOidcSessionHandle { + user: claims.username().to_string(), + }) + } + + async fn validate_access_token( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result { + self.inner.validate_access_token(token, expected_user).await + } +} diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 06d086e8e6567..879768bbbb30b 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -36,6 +36,7 @@ use mz_adapter_types::bootstrap_builtin_cluster_config::{ SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR, }; use mz_auth::password::Password; +use mz_authenticator::{GenericOidcAuthenticator, OidcConfig}; use mz_aws_secrets_controller::AwsSecretsController; use mz_build_info::BuildInfo; use mz_catalog::config::ClusterReplicaSizeMap; @@ -171,6 +172,10 @@ pub struct Args { /// Frontegg arguments. #[clap(flatten)] frontegg: FronteggCliArgs, + // === OIDC options. === + /// OIDC issuer URL (e.g., "https://accounts.google.com"). + #[clap(long, env = "MZ_OIDC_ISSUER")] + oidc_issuer: Option, // === Orchestrator options. === /// The service orchestrator implementation to use. #[structopt(long, value_enum, env = "ORCHESTRATOR")] @@ -743,6 +748,11 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { // Configure connections. let tls = args.tls.into_config()?; let frontegg = FronteggAuthenticator::from_args(args.frontegg, &metrics_registry)?; + let oidc = if let Some(oidc_issuer) = args.oidc_issuer { + Some(GenericOidcAuthenticator::new(OidcConfig { oidc_issuer })?) + } else { + None + }; let listeners_config: ListenersConfig = { let f = File::open(args.listeners_config_path)?; serde_json::from_reader(f)? @@ -1083,6 +1093,7 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { tls_reload_certs: mz_server_core::default_cert_reload_ticker(), external_login_password_mz_system: args.external_login_password_mz_system, frontegg, + oidc, cors_allowed_origin, egress_addresses: args.announce_egress_address, http_host_name: args.http_host_name, diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index 452a710647208..85d9995aa911e 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -1005,7 +1005,7 @@ async fn auth( (name, external_metadata_rx, None) } Some(Credentials::Token { token }) => { - let claims = frontegg.validate_access_token(&token, None)?; + let claims = frontegg.validate_access_token(&token, None).await?; let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata { user_id: claims.user_id, admin: claims.is_admin, @@ -1043,6 +1043,31 @@ async fn auth( include_www_authenticate_header, }); } + Authenticator::Oidc(oidc) => match creds { + Some(Credentials::Token { token }) => { + // Validate JWT token + let claims = oidc + .validate_access_token(&token, None) + .await + .map_err(|_| AuthError::InvalidCredentials)?; + let name = claims.username().to_string(); + (name, None, None) + } + Some(Credentials::Password { password, .. }) => { + // Allow JWT to be passed as password + let claims = oidc + .validate_access_token(&password.0, None) + .await + .map_err(|_| AuthError::InvalidCredentials)?; + let name = claims.username().to_string(); + (name, None, None) + } + None => { + return Err(AuthError::MissingHttpAuthentication { + include_www_authenticate_header, + }); + } + }, Authenticator::None => { // If no authentication, use whatever is in the HTTP auth // header (without checking the password), or fall back to the diff --git a/src/environmentd/src/lib.rs b/src/environmentd/src/lib.rs index 045980d57a658..7c48ed3574602 100644 --- a/src/environmentd/src/lib.rs +++ b/src/environmentd/src/lib.rs @@ -36,7 +36,7 @@ use mz_adapter_types::dyncfgs::{ WITH_0DT_DEPLOYMENT_MAX_WAIT, }; use mz_auth::password::Password; -use mz_authenticator::Authenticator; +use mz_authenticator::{Authenticator, GenericOidcAuthenticator}; use mz_build_info::{BuildInfo, build_info}; use mz_catalog::config::ClusterReplicaSizeMap; use mz_catalog::durable::BootstrapArgs; @@ -106,6 +106,8 @@ pub struct Config { pub external_login_password_mz_system: Option, /// Frontegg JWT authentication configuration. pub frontegg: Option, + /// OIDC JWT authentication configuration. + pub oidc: Option, /// Origins for which cross-origin resource sharing (CORS) for HTTP requests /// is permitted. pub cors_allowed_origin: AllowOrigin, @@ -261,6 +263,7 @@ impl Listener { active_connection_counter: ConnectionCounter, tls_reloading_context: Option, frontegg: Option, + oidc: Option, adapter_client: AdapterClient, metrics: MetricsConfig, helm_chart_version: Option, @@ -280,6 +283,9 @@ impl Listener { ), AuthenticatorKind::Password => Authenticator::Password(adapter_client.clone()), AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client.clone()), + AuthenticatorKind::Oidc => Authenticator::Oidc( + oidc.expect("OIDC config is required with AuthenticatorKind::Oidc"), + ), AuthenticatorKind::None => Authenticator::None, }; @@ -380,16 +386,23 @@ impl Listeners { let authenticator_frontegg_rx = authenticator_frontegg_rx.shared(); let (authenticator_password_tx, authenticator_password_rx) = oneshot::channel(); let authenticator_password_rx = authenticator_password_rx.shared(); + let (authenticator_oidc_tx, authenticator_oidc_rx) = oneshot::channel(); + let authenticator_oidc_rx = authenticator_oidc_rx.shared(); let (authenticator_none_tx, authenticator_none_rx) = oneshot::channel(); let authenticator_none_rx = authenticator_none_rx.shared(); - // We can only send the Frontegg and None variants immediately. + // We can only send the Frontegg, OIDC, and None variants immediately. // The Password variant requires an adapter client. if let Some(frontegg) = &config.frontegg { authenticator_frontegg_tx .send(Arc::new(Authenticator::Frontegg(frontegg.clone()))) .expect("rx known to be live"); } + if let Some(oidc) = &config.oidc { + authenticator_oidc_tx + .send(Arc::new(Authenticator::Oidc(oidc.clone()))) + .expect("rx known to be live"); + } authenticator_none_tx .send(Arc::new(Authenticator::None)) .expect("rx known to be live"); @@ -406,6 +419,7 @@ impl Listeners { AuthenticatorKind::Frontegg => authenticator_frontegg_rx.clone(), AuthenticatorKind::Password => authenticator_password_rx.clone(), AuthenticatorKind::Sasl => authenticator_password_rx.clone(), + AuthenticatorKind::Oidc => authenticator_oidc_rx.clone(), AuthenticatorKind::None => authenticator_none_rx.clone(), }; let source: &'static str = Box::leak(name.clone().into_boxed_str()); @@ -827,6 +841,7 @@ impl Listeners { active_connection_counter.clone(), tls_reloading_context.clone(), config.frontegg.clone(), + config.oidc.clone(), adapter_client.clone(), metrics.clone(), config.helm_chart_version.clone(), diff --git a/src/environmentd/src/test_util.rs b/src/environmentd/src/test_util.rs index 1d1f28e4d4fa8..2b5c9dd9f8b77 100644 --- a/src/environmentd/src/test_util.rs +++ b/src/environmentd/src/test_util.rs @@ -750,6 +750,8 @@ impl Listeners { connection_context, replica_http_locator: Default::default(), }, + // TODO (SangJunBak): Add a mock OIDC authenticator + oidc: None, secrets_controller, cloud_resource_controller: None, tls: config.tls, diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index f4c68f1baf86f..f5774506bebe5 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -168,28 +168,6 @@ impl Authenticator { Ok(Some(Self::new(config, client, registry))) } - // /// Establishes a new authentication session. - // /// - // /// If successful, returns a handle to the authentication session. - // /// Otherwise, returns the authentication error. - // pub async fn authenticate( - // &self, - // expected_user: &str, - // password: &str, - // ) -> Result { - // let password: AppPassword = password.parse()?; - // match self.authenticate_inner(expected_user, password).await { - // Ok(handle) => { - // tracing::debug!("authentication successful"); - // Ok(handle) - // } - // Err(e) => { - // tracing::debug!(error = ?e, "authentication failed"); - // Err(e) - // } - // } - // } - #[instrument(level = "debug", fields(client_id = %password.client_id))] async fn authenticate_inner( &self, @@ -339,7 +317,7 @@ impl OidcAuthenticator for Authenticator { /// /// If `expected_user` is provided, the token's user name is additionally /// validated to match `expected_user`. - fn validate_access_token( + async fn validate_access_token( &self, token: &str, expected_user: Option<&str>, diff --git a/src/materialized/ci/listener_configs/oidc.json b/src/materialized/ci/listener_configs/oidc.json new file mode 100644 index 0000000000000..28cc5ca0cc353 --- /dev/null +++ b/src/materialized/ci/listener_configs/oidc.json @@ -0,0 +1,38 @@ +{ + "sql": { + "external": { + "addr": "0.0.0.0:6875", + "authenticator_kind": "Oidc", + "allowed_roles": "NormalAndInternal", + "enable_tls": false + } + }, + "http": { + "external": { + "addr": "0.0.0.0:6876", + "authenticator_kind": "Oidc", + "allowed_roles": "NormalAndInternal", + "enable_tls": false, + "routes": { + "base": true, + "webhook": true, + "internal": true, + "metrics": false, + "profiling": true + } + }, + "metrics": { + "addr": "0.0.0.0:6878", + "authenticator_kind": "None", + "allowed_roles": "NormalAndInternal", + "enable_tls": false, + "routes": { + "base": false, + "webhook": false, + "internal": false, + "metrics": true, + "profiling": false + } + } + } +} diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 09032c41d1c4f..09e9803e72042 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -225,6 +225,59 @@ where } } } + Authenticator::Oidc(oidc) => { + tracing::info!("OIDC authentication"); + // OIDC authentication: JWT sent as password in cleartext flow + let jwt = match request_cleartext_password(conn).await { + Ok(password) => password, + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; + } + }; + + tracing::info!("JWT: {}", jwt); + + // Two steps: + // 1. Validate the JWT + // 2. Check the catalog to see if the user is a superuser + // 3. If the role doesn't exist, just create one. + + let auth_response = oidc.authenticate(&user, &jwt).await; + + match auth_response { + Ok(mut auth_session) => { + // Create a session based on the auth session. + // + // In particular, it's important that the username come from the + // auth session, as Frontegg may return an email address with + // different casing than the user supplied via the pgwire + // username field. We want to use the Frontegg casing as + // canonical. + let session = adapter_client.new_session(SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user: auth_session.user().into(), + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + // TODO (oidc_auth): Add superuser status to internal_user_metadata from catalog. + internal_user_metadata: None, + helm_chart_version, + }); + let expired = async move { auth_session.expired().await }.boxed(); + (session, expired.left_future()) + } + Err(err) => { + warn!(?err, "pgwire connection failed authentication"); + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_PASSWORD, + "invalid password", + )) + .await; + } + } + } Authenticator::Password(adapter_client) => { let password = match request_cleartext_password(conn).await { Ok(password) => Password(password), @@ -443,6 +496,7 @@ where let auth_session = pending().right_future(); (session, auth_session) } + Authenticator::None => { let session = adapter_client.new_session(SessionConfig { conn_id: conn.conn_id().clone(), @@ -645,6 +699,48 @@ fn split_options(value: &str) -> Vec { strs } +enum PasswordRequestError { + InvalidPasswordError(ErrorResponse), + IoError(io::Error), +} + +impl From for PasswordRequestError { + fn from(e: io::Error) -> Self { + PasswordRequestError::IoError(e) + } +} + +/// Requests a cleartext password from a connection and returns it if it is valid. +/// Sends an error response in the connection and returns None if the password +/// is not valid. +async fn request_cleartext_password( + conn: &mut FramedConn, +) -> Result +where + A: AsyncRead + AsyncWrite + Unpin, +{ + conn.send(BackendMessage::AuthenticationCleartextPassword) + .await?; + conn.flush().await?; + + if let Some(message) = conn.recv().await? { + if let FrontendMessage::RawAuthentication(data) = message { + if let Some(FrontendMessage::Password { password }) = + decode_password(Cursor::new(&data)).ok() + { + return Ok(password); + } + } + } + + Err(PasswordRequestError::InvalidPasswordError( + ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected Password message", + ), + )) +} + #[derive(Debug)] enum State { Ready, diff --git a/src/server-core/src/listeners.rs b/src/server-core/src/listeners.rs index e078cfdf9a1f6..4298a014097c4 100644 --- a/src/server-core/src/listeners.rs +++ b/src/server-core/src/listeners.rs @@ -22,6 +22,8 @@ pub enum AuthenticatorKind { Password, /// Authenticate users using SASL. Sasl, + /// Authenticate users using OIDC (JWT tokens). + Oidc, /// Do not authenticate users. Trust they are who they say they are without verification. #[default] None, From 8c1a6f96f10f80f1164560a0acc6a07173d3fa00 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Tue, 20 Jan 2026 11:40:59 -0500 Subject: [PATCH 04/15] Refactor internal_user_metadata out of session clients I noticed that we were doing this weird round trip of getting internal user metadata from the catalog during authentication, then passing it back when initializing the session. By just doing this on startup, we: - Remove extraneous code - Open up ease of creating a unified interface for OIDC clients --- src/adapter/src/client.rs | 22 +++++++++---- src/adapter/src/command.rs | 18 ++++------- src/adapter/src/config/backend.rs | 1 - src/adapter/src/coord/command_handler.rs | 35 +++++++++------------ src/adapter/src/session.rs | 11 ++++--- src/environmentd/src/http.rs | 40 +++++++----------------- src/environmentd/src/http/sql.rs | 1 - src/pgwire/src/protocol.rs | 21 ++----------- src/sql/src/session/user.rs | 2 ++ src/sql/src/session/vars.rs | 7 ++++- 10 files changed, 65 insertions(+), 93 deletions(-) diff --git a/src/adapter/src/client.rs b/src/adapter/src/client.rs index 175ccaa57741d..ae90292300ab4 100644 --- a/src/adapter/src/client.rs +++ b/src/adapter/src/client.rs @@ -33,6 +33,7 @@ use mz_ore::result::ResultExt; use mz_ore::task::AbortOnDropHandle; use mz_ore::thread::JoinOnDropHandle; use mz_ore::tracing::OpenTelemetryContext; +use mz_repr::user::InternalUserMetadata; use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType}; use mz_sql::ast::{Raw, Statement}; use mz_sql::catalog::{EnvironmentId, SessionCatalog}; @@ -51,8 +52,8 @@ use uuid::Uuid; use crate::catalog::Catalog; use crate::command::{ - AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, - SASLChallengeResponse, SASLVerifyProofResponse, + CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, SASLChallengeResponse, + SASLVerifyProofResponse, }; use crate::coord::{Coordinator, ExecuteContextGuard}; use crate::error::AdapterError; @@ -160,15 +161,15 @@ impl Client { &self, user: &String, password: &Password, - ) -> Result { + ) -> Result<(), AdapterError> { let (tx, rx) = oneshot::channel(); self.send(Command::AuthenticatePassword { role_name: user.to_string(), password: Some(password.clone()), tx, }); - let response = rx.await.expect("sender dropped")?; - Ok(response) + rx.await.expect("sender dropped")?; + Ok(()) } pub async fn generate_sasl_challenge( @@ -265,6 +266,7 @@ impl Client { optimizer_metrics, persist_client, statement_logging_frontend, + superuser_attribute, } = response; let peek_client = PeekClient::new( @@ -287,6 +289,15 @@ impl Client { }; let session = client.session(); + + // Apply the superuser attribute to the session's user if + // it exists. + if let Some(superuser_attribute) = superuser_attribute { + session.apply_internal_user_metadata(InternalUserMetadata { + superuser: superuser_attribute, + }); + } + session.initialize_role_metadata(role_id); let vars_mut = session.vars_mut(); for (name, val) in session_defaults { @@ -444,7 +455,6 @@ Issue a SQL query to get started. Need help? user: SUPPORT_USER.name.clone(), client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }); let mut session_client = self.startup(session).await?; diff --git a/src/adapter/src/command.rs b/src/adapter/src/command.rs index e53bbe3295b3e..0815681e4b270 100644 --- a/src/adapter/src/command.rs +++ b/src/adapter/src/command.rs @@ -86,7 +86,7 @@ pub enum Command { }, AuthenticatePassword { - tx: oneshot::Sender>, + tx: oneshot::Sender>, role_name: String, password: Option, }, @@ -378,6 +378,11 @@ pub struct Response { pub struct StartupResponse { /// RoleId for the user. pub role_id: RoleId, + /// The role's superuser attribute in the Catalog. + /// This attribute is None for Cloud. Cloud is able + /// to derive the role's superuser status from + /// [Session.external_metadata_rx](crate::session::Session::external_metadata_rx). + pub superuser_attribute: Option, /// A future that completes when all necessary Builtin Table writes have completed. #[derivative(Debug = "ignore")] pub write_notify: BuiltinTableAppendNotify, @@ -396,16 +401,6 @@ pub struct StartupResponse { pub statement_logging_frontend: StatementLoggingFrontend, } -/// The response to [`Client::authenticate`](crate::Client::authenticate). -#[derive(Derivative)] -#[derivative(Debug)] -pub struct AuthResponse { - /// RoleId for the user. - pub role_id: RoleId, - /// If the user is a superuser. - pub superuser: bool, -} - #[derive(Derivative)] #[derivative(Debug)] pub struct SASLChallengeResponse { @@ -419,7 +414,6 @@ pub struct SASLChallengeResponse { #[derivative(Debug)] pub struct SASLVerifyProofResponse { pub verifier: String, - pub auth_resp: AuthResponse, } // Facile implementation for `StartupResponse`, which does not use the `allowed` diff --git a/src/adapter/src/config/backend.rs b/src/adapter/src/config/backend.rs index e51f9540b92bc..fa6a52329bbbc 100644 --- a/src/adapter/src/config/backend.rs +++ b/src/adapter/src/config/backend.rs @@ -34,7 +34,6 @@ impl SystemParameterBackend { user: SYSTEM_USER.name.clone(), client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }); let session_client = client.startup(session).await?; diff --git a/src/adapter/src/coord/command_handler.rs b/src/adapter/src/coord/command_handler.rs index aa3a63b7a821b..3d806f74bd670 100644 --- a/src/adapter/src/coord/command_handler.rs +++ b/src/adapter/src/coord/command_handler.rs @@ -63,8 +63,8 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use uuid::Uuid; use crate::command::{ - AuthResponse, CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, - SASLVerifyProofResponse, StartupResponse, + CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, SASLVerifyProofResponse, + StartupResponse, }; use crate::coord::appends::PendingWriteTxn; use crate::coord::peek::PendingPeek; @@ -504,14 +504,8 @@ impl Coordinator { Ok(verifier) => { // Success only if role exists, allows login, and a real password hash was used. if login && real_hash.is_some() { - let role = role.expect("login implies role exists"); - let _ = tx.send(Ok(SASLVerifyProofResponse { - verifier, - auth_resp: AuthResponse { - role_id: role.id, - superuser: role.attributes.superuser.unwrap_or(false), - }, - })); + role.expect("login implies role exists"); + let _ = tx.send(Ok(SASLVerifyProofResponse { verifier })); } else { let _ = tx.send(Err(make_auth_err(role_present, login))); } @@ -604,7 +598,7 @@ impl Coordinator { #[mz_ore::instrument(level = "debug")] async fn handle_authenticate_password( &self, - tx: oneshot::Sender>, + tx: oneshot::Sender>, role_name: String, password: Option, ) { @@ -627,13 +621,11 @@ impl Coordinator { if let Some(auth) = self.catalog().try_get_role_auth_by_id(&role.id) { if let Some(hash) = &auth.password_hash { let hash = hash.clone(); - let role_id = role.id; - let superuser = role.attributes.superuser.unwrap_or(false); task::spawn_blocking( || "auth-check-hash", move || { let _ = match mz_auth::hash::scram256_verify(&password, &hash) { - Ok(_) => tx.send(Ok(AuthResponse { role_id, superuser })), + Ok(_) => tx.send(Ok(())), Err(_) => tx.send(Err(AdapterError::AuthenticationError( AuthenticationError::InvalidCredentials, ))), @@ -669,7 +661,7 @@ impl Coordinator { ) { // Early return if successful, otherwise cleanup any possible state. match self.handle_startup_inner(&user, &conn_id, &client_ip).await { - Ok((role_id, session_defaults)) => { + Ok((role_id, superuser_attribute, session_defaults)) => { let session_type = metrics::session_type_label_value(&user); self.metrics .active_sessions @@ -744,6 +736,7 @@ impl Coordinator { optimizer_metrics: self.optimizer_metrics.clone(), persist_client: self.persist_client.clone(), statement_logging_frontend, + superuser_attribute, }); if tx.send(resp).is_err() { // Failed to send to adapter, but everything is setup so we can terminate @@ -770,7 +763,7 @@ impl Coordinator { user: &User, _conn_id: &ConnectionId, client_ip: &Option, - ) -> Result<(RoleId, BTreeMap), AdapterError> { + ) -> Result<(RoleId, Option, BTreeMap), AdapterError> { if self.catalog().try_get_role_by_name(&user.name).is_none() { // If the user has made it to this point, that means they have been fully authenticated. // This includes preventing any user, except a pre-defined set of system users, from @@ -783,11 +776,13 @@ impl Coordinator { }; self.sequence_create_role_for_startup(plan).await?; } - let role_id = self + let role = self .catalog() .try_get_role_by_name(&user.name) - .expect("created above") - .id; + .expect("created above"); + let role_id = role.id; + + let superuser_attribute = role.attributes.superuser; if role_id.is_user() && !ALLOW_USER_SESSIONS.get(self.catalog().system_config().dyncfgs()) { return Err(AdapterError::UserSessionsDisallowed); @@ -874,7 +869,7 @@ impl Coordinator { // rather than eagerly on connection startup. This avoids expensive catalog_mut() calls // for the common case where connections never create temporary objects. - Ok((role_id, session_defaults)) + Ok((role_id, superuser_attribute, session_defaults)) } /// Handles an execute command. diff --git a/src/adapter/src/session.rs b/src/adapter/src/session.rs index 50f2dcff168fe..a60b4a154be4e 100644 --- a/src/adapter/src/session.rs +++ b/src/adapter/src/session.rs @@ -210,8 +210,6 @@ pub struct SessionConfig { /// An optional receiver that the session will periodically check for /// updates to a user's external metadata. pub external_metadata_rx: Option>, - /// The metadata of the user associated with the session. - pub internal_user_metadata: Option, /// Helm chart version pub helm_chart_version: Option, } @@ -299,7 +297,6 @@ impl Session { user: SYSTEM_USER.name.clone(), client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }, metrics, @@ -316,7 +313,6 @@ impl Session { user, client_ip, mut external_metadata_rx, - internal_user_metadata, helm_chart_version, }: SessionConfig, metrics: SessionMetrics, @@ -325,7 +321,7 @@ impl Session { let default_cluster = INTERNAL_USER_NAME_TO_DEFAULT_CLUSTER.get(&user); let user = User { name: user, - internal_metadata: internal_user_metadata, + internal_metadata: None, external_metadata: external_metadata_rx .as_mut() .map(|rx| rx.borrow_and_update().clone()), @@ -871,6 +867,11 @@ impl Session { self.vars.set_external_user_metadata(metadata); } + /// Applies the internal user metadata to the session. + pub fn apply_internal_user_metadata(&mut self, metadata: InternalUserMetadata) { + self.vars.set_internal_user_metadata(metadata); + } + /// Initializes the session's role metadata. pub fn initialize_role_metadata(&mut self, role_id: RoleId) { self.role_metadata = Some(RoleMetadata::new(role_id)); diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index 85d9995aa911e..a12cc7aeb4dad 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -54,7 +54,7 @@ use mz_ore::metrics::MetricsRegistry; use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7}; use mz_ore::str::StrExt; use mz_pgwire_common::{ConnectionCounter, ConnectionHandle}; -use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata}; +use mz_repr::user::ExternalUserMetadata; use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled}; use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server}; use mz_sql::session::metadata::SessionMetadata; @@ -550,11 +550,9 @@ async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl In ))); } }; - let superuser = matches!(username.as_str(), SYSTEM_USER_NAME); req.extensions_mut().insert(AuthedUser { name: username, external_metadata_rx: None, - internal_metadata: Some(InternalUserMetadata { superuser }), }); } Ok(next.run(req).await) @@ -572,7 +570,6 @@ enum ConnProtocol { pub struct AuthedUser { name: String, external_metadata_rx: Option>, - internal_metadata: Option, } pub struct AuthedClient { @@ -601,7 +598,6 @@ impl AuthedClient { user: user.name, client_ip: Some(peer_addr), external_metadata_rx: user.external_metadata_rx, - internal_user_metadata: user.internal_metadata, helm_chart_version, }); let connection_guard = active_connection_counter.allocate_connection(session.user())?; @@ -764,12 +760,9 @@ pub async fn handle_login( let Ok(adapter_client) = adapter_client_rx.clone().await else { return StatusCode::INTERNAL_SERVER_ERROR; }; - let auth_response = match adapter_client.authenticate(&username, &password).await { - Ok(auth_response) => auth_response, - Err(err) => { - warn!(?err, "HTTP login failed authentication"); - return StatusCode::UNAUTHORIZED; - } + if let Err(err) = adapter_client.authenticate(&username, &password).await { + warn!(?err, "HTTP login failed authentication"); + return StatusCode::UNAUTHORIZED; }; // Create session data @@ -777,9 +770,6 @@ pub async fn handle_login( username, created_at: SystemTime::now(), last_activity: SystemTime::now(), - internal_metadata: InternalUserMetadata { - superuser: auth_response.superuser, - }, }; // Store session data let session = session.and_then(|Extension(session)| Some(session)); @@ -836,7 +826,6 @@ async fn http_auth( req.extensions_mut().insert(AuthedUser { name: session_data.username, external_metadata_rx: None, - internal_metadata: Some(session_data.internal_metadata), }); return Ok(next.run(req).await); } @@ -996,13 +985,13 @@ async fn auth( allowed_roles: AllowedRoles, include_www_authenticate_header: bool, ) -> Result { - let (name, external_metadata_rx, internal_metadata) = match authenticator { + let (name, external_metadata_rx) = match authenticator { Authenticator::Frontegg(frontegg) => match creds { Some(Credentials::Password { username, password }) => { let auth_session = frontegg.authenticate(&username, &password.0).await?; let name = auth_session.user().into(); let external_metadata_rx = Some(auth_session.external_metadata_rx()); - (name, external_metadata_rx, None) + (name, external_metadata_rx) } Some(Credentials::Token { token }) => { let claims = frontegg.validate_access_token(&token, None).await?; @@ -1010,7 +999,7 @@ async fn auth( user_id: claims.user_id, admin: claims.is_admin, }); - (claims.user, Some(external_metadata_rx), None) + (claims.user, Some(external_metadata_rx)) } None => { return Err(AuthError::MissingHttpAuthentication { @@ -1020,14 +1009,11 @@ async fn auth( }, Authenticator::Password(adapter_client) => match creds { Some(Credentials::Password { username, password }) => { - let auth_response = adapter_client + adapter_client .authenticate(&username, &password) .await .map_err(|_| AuthError::InvalidCredentials)?; - let internal_metadata = InternalUserMetadata { - superuser: auth_response.superuser, - }; - (username, None, Some(internal_metadata)) + (username, None) } _ => { return Err(AuthError::MissingHttpAuthentication { @@ -1051,7 +1037,7 @@ async fn auth( .await .map_err(|_| AuthError::InvalidCredentials)?; let name = claims.username().to_string(); - (name, None, None) + (name, None) } Some(Credentials::Password { password, .. }) => { // Allow JWT to be passed as password @@ -1060,7 +1046,7 @@ async fn auth( .await .map_err(|_| AuthError::InvalidCredentials)?; let name = claims.username().to_string(); - (name, None, None) + (name, None) } None => { return Err(AuthError::MissingHttpAuthentication { @@ -1076,7 +1062,7 @@ async fn auth( Some(Credentials::Password { username, .. }) => username, _ => HTTP_DEFAULT_USER.name.to_owned(), }; - (name, None, None) + (name, None) } }; @@ -1085,7 +1071,6 @@ async fn auth( Ok(AuthedUser { name, external_metadata_rx, - internal_metadata, }) } @@ -1152,7 +1137,6 @@ pub struct TowerSessionData { username: String, created_at: SystemTime, last_activity: SystemTime, - internal_metadata: InternalUserMetadata, } #[cfg(test)] diff --git a/src/environmentd/src/http/sql.rs b/src/environmentd/src/http/sql.rs index e1f03a90d158d..d7d4b096c427f 100644 --- a/src/environmentd/src/http/sql.rs +++ b/src/environmentd/src/http/sql.rs @@ -311,7 +311,6 @@ pub async fn handle_sql_ws( Some(AuthedUser { name: session_data.username, external_metadata_rx: None, - internal_metadata: Some(session_data.internal_metadata), }) } else { None diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 09e9803e72042..98f04a0d78a5d 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -42,7 +42,6 @@ use mz_pgwire_common::{ ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS, }; -use mz_repr::user::InternalUserMetadata; use mz_repr::{ CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef, SqlRelationType, SqlScalarType, @@ -208,7 +207,6 @@ where user: auth_session.user().into(), client_ip: conn.peer_addr().clone(), external_metadata_rx: Some(auth_session.external_metadata_rx()), - internal_user_metadata: None, helm_chart_version, }); let expired = async move { auth_session.expired().await }.boxed(); @@ -238,11 +236,6 @@ where tracing::info!("JWT: {}", jwt); - // Two steps: - // 1. Validate the JWT - // 2. Check the catalog to see if the user is a superuser - // 3. If the role doesn't exist, just create one. - let auth_response = oidc.authenticate(&user, &jwt).await; match auth_response { @@ -260,8 +253,6 @@ where user: auth_session.user().into(), client_ip: conn.peer_addr().clone(), external_metadata_rx: None, - // TODO (oidc_auth): Add superuser status to internal_user_metadata from catalog. - internal_user_metadata: None, helm_chart_version, }); let expired = async move { auth_session.expired().await }.boxed(); @@ -286,7 +277,7 @@ where return conn.send(e).await; } }; - let auth_response = match adapter_client.authenticate(&user, &password).await { + match adapter_client.authenticate(&user, &password).await { Ok(resp) => resp, Err(err) => { warn!(?err, "pgwire connection failed authentication"); @@ -304,9 +295,6 @@ where user, client_ip: conn.peer_addr().clone(), external_metadata_rx: None, - internal_user_metadata: Some(InternalUserMetadata { - superuser: auth_response.superuser, - }), helm_chart_version, }); // No frontegg check, so auth session lasts indefinitely. @@ -411,7 +399,7 @@ where } }; - let auth_resp = match conn.recv().await? { + match conn.recv().await? { Some(FrontendMessage::RawAuthentication(data)) => { match decode_sasl_response(Cursor::new(&data)).ok() { Some(FrontendMessage::SASLResponse(response)) => { @@ -449,7 +437,6 @@ where )) .await?; conn.flush().await?; - resp.auth_resp } Err(_) => { return conn @@ -487,9 +474,6 @@ where user, client_ip: conn.peer_addr().clone(), external_metadata_rx: None, - internal_user_metadata: Some(InternalUserMetadata { - superuser: auth_resp.superuser, - }), helm_chart_version, }); // No frontegg check, so auth session lasts indefinitely. @@ -504,7 +488,6 @@ where user, client_ip: conn.peer_addr().clone(), external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version, }); // No frontegg check, so auth session lasts indefinitely. diff --git a/src/sql/src/session/user.rs b/src/sql/src/session/user.rs index 1ff00cca6376d..d36cde11a7ea1 100644 --- a/src/sql/src/session/user.rs +++ b/src/sql/src/session/user.rs @@ -67,6 +67,8 @@ pub struct User { pub name: String, /// Metadata about this user in an external system. pub external_metadata: Option, + /// Metadata about this user stored in the catalog, + /// such as its role's `SUPERUSER` attribute. pub internal_metadata: Option, } diff --git a/src/sql/src/session/vars.rs b/src/sql/src/session/vars.rs index 014a2f2308589..a4717f00b0130 100644 --- a/src/sql/src/session/vars.rs +++ b/src/sql/src/session/vars.rs @@ -85,7 +85,7 @@ use mz_persist_client::cfg::{ use mz_repr::adt::numeric::Numeric; use mz_repr::adt::timestamp::CheckedTimestamp; use mz_repr::bytes::ByteSize; -use mz_repr::user::ExternalUserMetadata; +use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata}; use mz_tracing::{CloneableEnvFilter, SerializableDirective}; use serde::Serialize; use thiserror::Error; @@ -845,6 +845,11 @@ impl SessionVars { .as_bytes() } + /// Sets the internal metadata associated with the user. + pub fn set_internal_user_metadata(&mut self, metadata: InternalUserMetadata) { + self.user.internal_metadata = Some(metadata); + } + /// Sets the external metadata associated with the user. pub fn set_external_user_metadata(&mut self, metadata: ExternalUserMetadata) { self.user.external_metadata = Some(metadata); From 4e64ba34acf1be06f3338e4b64a26f41cb6d25d4 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Tue, 20 Jan 2026 23:17:46 -0500 Subject: [PATCH 05/15] Add OIDC mock server for testing - Introduced a new `mz-oidc-mock` package - Implemented tests for OIDC authentication --- Cargo.lock | 21 +++ Cargo.toml | 2 + src/authenticator/src/lib.rs | 2 +- src/authenticator/src/oidc.rs | 1 + src/environmentd/Cargo.toml | 2 + src/environmentd/src/test_util.rs | 60 ++++++- src/environmentd/tests/auth.rs | 183 +++++++++++++++++++++ src/oidc-mock/Cargo.toml | 36 +++++ src/oidc-mock/src/lib.rs | 256 ++++++++++++++++++++++++++++++ 9 files changed, 560 insertions(+), 3 deletions(-) create mode 100644 src/oidc-mock/Cargo.toml create mode 100644 src/oidc-mock/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index b8bd29375a8a9..f20e6880177ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6347,6 +6347,7 @@ dependencies = [ "mz-license-keys", "mz-metrics", "mz-npm", + "mz-oidc-mock", "mz-orchestrator", "mz-orchestrator-kubernetes", "mz-orchestrator-process", @@ -6901,6 +6902,26 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "mz-oidc-mock" +version = "0.0.0" +dependencies = [ + "anyhow", + "axum", + "base64 0.22.1", + "jsonwebtoken", + "mz-authenticator", + "mz-ore", + "openssl", + "reqwest", + "serde", + "serde_json", + "tokio", + "tracing", + "uuid", + "workspace-hack", +] + [[package]] name = "mz-orchestrator" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index ac084c6193e36..3cf6846491a6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ members = [ "src/mz", "src/mz-debug", "src/npm", + "src/oidc-mock", "src/orchestrator", "src/orchestrator-kubernetes", "src/orchestrator-process", @@ -184,6 +185,7 @@ default-members = [ "src/mz", "src/mz-debug", "src/npm", + "src/oidc-mock", "src/orchestrator", "src/orchestrator-kubernetes", "src/orchestrator-process", diff --git a/src/authenticator/src/lib.rs b/src/authenticator/src/lib.rs index e45ce4b20122b..8525873656c89 100644 --- a/src/authenticator/src/lib.rs +++ b/src/authenticator/src/lib.rs @@ -12,7 +12,7 @@ pub mod oidc; use mz_adapter::Client as AdapterClient; use mz_frontegg_auth::Authenticator as FronteggAuthenticator; -pub use oidc::{GenericOidcAuthenticator, OidcConfig, OidcError}; +pub use oidc::{GenericOidcAuthenticator, OidcClaims, OidcConfig, OidcError}; #[derive(Debug, Clone)] pub enum Authenticator { diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index 0451e2070c2d5..52fcaff9f0757 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -156,6 +156,7 @@ impl GenericOidcAuthenticator { issuer: config.oidc_issuer, jwks_uri, decoding_keys: Mutex::new(BTreeMap::new()), + // TODO: Use same client code as frontegg-auth. http_client: HttpClient::new(), }), }) diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index 60c89f4e3f243..d3dcb3a2ec6ba 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -56,6 +56,7 @@ mz-dyncfg = { path = "../dyncfg" } mz-dyncfgs = { path = "../dyncfgs" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-frontegg-mock = { path = "../frontegg-mock", optional = true } +mz-oidc-mock = { path = "../oidc-mock", optional = true } mz-http-util = { path = "../http-util" } mz-interchange = { path = "../interchange" } mz-license-keys = { path = "../license-keys" } @@ -185,6 +186,7 @@ test = [ "postgres-openssl", "mz-tracing", "mz-frontegg-mock", + "mz-oidc-mock", "tracing-capture", "mz-orchestrator-tracing/capture", ] diff --git a/src/environmentd/src/test_util.rs b/src/environmentd/src/test_util.rs index 2b5c9dd9f8b77..ac5a895a5307f 100644 --- a/src/environmentd/src/test_util.rs +++ b/src/environmentd/src/test_util.rs @@ -90,6 +90,7 @@ use crate::{ CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig, WebSocketAuth, WebSocketResponse, }; +use mz_authenticator::GenericOidcAuthenticator; pub static KAFKA_ADDRS: LazyLock = LazyLock::new(|| env::var("KAFKA_ADDRS").unwrap_or_else(|_| "localhost:9092".into())); @@ -100,6 +101,7 @@ pub struct TestHarness { data_directory: Option, tls: Option, frontegg: Option, + oidc: Option, external_login_password_mz_system: Option, listeners_config: ListenersConfig, unsafe_mode: bool, @@ -137,6 +139,7 @@ impl Default for TestHarness { data_directory: None, tls: None, frontegg: None, + oidc: None, external_login_password_mz_system: None, listeners_config: ListenersConfig { sql: btreemap![ @@ -362,6 +365,60 @@ impl TestHarness { self } + pub fn with_oidc_auth(mut self, oidc: &GenericOidcAuthenticator) -> Self { + self.oidc = Some(oidc.clone()); + let enable_tls = self.tls.is_some(); + self.listeners_config = ListenersConfig { + sql: btreemap! { + "external".to_owned() => SqlListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Oidc, + allowed_roles: AllowedRoles::Normal, + enable_tls, + }, + "internal".to_owned() => SqlListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::None, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls: false, + }, + }, + http: btreemap! { + "external".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Oidc, + allowed_roles: AllowedRoles::Normal, + enable_tls, + }, + routes: HttpRoutesEnabled{ + base: true, + webhook: true, + internal: false, + metrics: false, + profiling: false, + }, + }, + "internal".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::None, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls: false, + }, + routes: HttpRoutesEnabled{ + base: true, + webhook: true, + internal: true, + metrics: true, + profiling: true, + }, + }, + }, + }; + self + } + pub fn with_password_auth(mut self, mz_system_password: Password) -> Self { self.external_login_password_mz_system = Some(mz_system_password); let enable_tls = self.tls.is_some(); @@ -750,8 +807,7 @@ impl Listeners { connection_context, replica_http_locator: Default::default(), }, - // TODO (SangJunBak): Add a mock OIDC authenticator - oidc: None, + oidc: config.oidc, secrets_controller, cloud_resource_controller: None, tls: config.tls, diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index 4c1a7312eac69..ec3526555439b 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -38,6 +38,7 @@ use hyper_util::rt::TokioExecutor; use itertools::Itertools; use jsonwebtoken::{self, DecodingKey, EncodingKey}; use mz_auth::password::Password; +use mz_authenticator::{GenericOidcAuthenticator, OidcConfig}; use mz_environmentd::test_util::{self, Ca, make_header, make_pg_tls}; use mz_environmentd::{WebSocketAuth, WebSocketResponse}; use mz_frontegg_auth::{ @@ -47,6 +48,7 @@ use mz_frontegg_auth::{ use mz_frontegg_mock::{ FronteggMockServer, models::ApiToken, models::TenantApiTokenConfig, models::UserConfig, }; +use mz_oidc_mock::OidcMockServer; use mz_ore::error::ErrorExt; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::{NowFn, SYSTEM_TIME}; @@ -1269,6 +1271,187 @@ async fn test_auth_base_require_tls_frontegg() { .await; } +/// Tests OIDC authentication with TLS required. +/// +/// This test verifies that users can authenticate using JWT tokens +/// over TLS connections +#[allow(clippy::unit_arg)] +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_base_require_tls_oidc() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + // Get PEM string from CA's key (same pattern as Frontegg TLS test). + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + }) + .unwrap(); + + let oidc_user = "user@example.com"; + let jwt_token = oidc_server.generate_jwt(oidc_user, Some(oidc_user)); + let expired_token = oidc_server.generate_jwt_with_exp(oidc_user, Some(oidc_user), 0); + let wrong_user_token = oidc_server.generate_jwt("other@example.com", Some("other@example.com")); + let wrong_issuer_token = oidc_server.generate_jwt_with_issuer( + oidc_user, + Some(oidc_user), + "https://wrong-issuer.com", + ); + + // Create Bearer auth header for HTTP tests. + let oidc_bearer = Authorization::bearer(&jwt_token).unwrap(); + let oidc_header_bearer = make_header(oidc_bearer); + + let oidc_header_basic = make_header(Authorization::basic(oidc_user, &jwt_token)); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "TlsMode::Require, OIDC", + &server, + &[ + // TLS with valid JWT should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&jwt_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // HTTP with Bearer auth should succeed. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTPS, + headers: &oidc_header_bearer, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // HTTP with basic username/password should succeed. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTPS, + headers: &oidc_header_basic, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // Ws with bearer token should succeed. + TestCase::Ws { + auth: &WebSocketAuth::Bearer { + token: jwt_token.clone(), + options: BTreeMap::default(), + }, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // Ws with basic username/password should succeed. + TestCase::Ws { + auth: &WebSocketAuth::Basic { + user: oidc_user.to_string(), + password: Password(jwt_token.clone()), + options: BTreeMap::default(), + }, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // No TLS should fail when server requires it. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&jwt_token)), + ssl_mode: SslMode::Disable, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!( + *err.code(), + SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION + ); + assert_eq!(err.message(), "TLS encryption is required"); + })), + }, + // HTTP without TLS should be rejected. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTP, + headers: &oidc_header_bearer, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: assert_http_rejected(), + }, + // Invalid JWT should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed("invalid-jwt-token")), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // Expired JWT should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&expired_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT for wrong user should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_user_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT with wrong issuer should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_issuer_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + ], + ) + .await; +} + #[allow(clippy::unit_arg)] #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` diff --git a/src/oidc-mock/Cargo.toml b/src/oidc-mock/Cargo.toml new file mode 100644 index 0000000000000..76b0fdd3a66c3 --- /dev/null +++ b/src/oidc-mock/Cargo.toml @@ -0,0 +1,36 @@ + +[package] +name = "mz-oidc-mock" +description = "OIDC mock server for testing." +version = "0.0.0" +edition.workspace = true +rust-version.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +anyhow = "1.0.100" +axum = "0.7.5" +base64 = "0.22.1" +jsonwebtoken = "9.3.1" +mz-authenticator = { path = "../authenticator" } +mz-ore = { path = "../ore", default-features = false, features = ["cli"] } +openssl = { version = "0.10.75", features = ["vendored"] } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.148" +tokio = { version = "1.48.0", default-features = false } +tracing = "0.1.44" +uuid = "1.19.0" +workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } + +[dev-dependencies] +openssl = { version = "0.10.75", features = ["vendored"] } +reqwest = { version = "0.12.28", features = ["json"] } + +[package.metadata.cargo-udeps.ignore] +normal = ["workspace-hack"] + +[features] +default = ["workspace-hack"] diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs new file mode 100644 index 0000000000000..cad82d2f7122d --- /dev/null +++ b/src/oidc-mock/src/lib.rs @@ -0,0 +1,256 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! OIDC mock server for testing. +//! +//! This module provides a mock OIDC server that serves JWKS endpoints +//! for validating JWT tokens in tests. + +use std::borrow::Cow; +use std::future::IntoFuture; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; + +use axum::extract::State; +use axum::routing::get; +use axum::{Json, Router}; +use base64::Engine; +use jsonwebtoken::{EncodingKey, Header, encode}; +use mz_authenticator::OidcClaims; +use mz_ore::now::NowFn; +use mz_ore::task::JoinHandle; +use openssl::pkey::{PKey, Private}; +use openssl::rsa::Rsa; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpListener; + +/// JWKS response structure. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct JwkSet { + pub keys: Vec, +} + +/// JSON Web Key structure. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Jwk { + pub kty: String, + pub kid: String, + #[serde(rename = "use")] + pub key_use: String, + pub alg: String, + pub n: String, + pub e: String, +} + +/// Shared context for the OIDC mock server. +struct OidcMockContext { + /// The issuer URL (base URL of this server). + issuer: String, + /// RSA public key in JWK format. + jwk: Jwk, +} + +/// OIDC mock server for testing. +pub struct OidcMockServer { + /// Base URL of the server (e.g., "http://127.0.0.1:12345"). + pub base_url: String, + /// The issuer URL (same as base_url, used for JWT iss claim). + pub issuer: String, + /// Key ID used in JWT headers. + pub kid: String, + /// Encoding key for signing JWTs (for generating test tokens). + pub encoding_key: EncodingKey, + /// Function for getting current time. + pub now: NowFn, + /// How long tokens should be valid (in seconds). + pub expires_in_secs: i64, + /// Handle to the server task. + pub handle: JoinHandle>, +} + +impl OidcMockServer { + /// Starts an [`OidcMockServer`]. + /// + /// Must be started from within a [`tokio::runtime::Runtime`]. + /// + /// # Arguments + /// + /// * `addr` - Optional address to bind to. If None, binds to localhost on a random port. + /// * `encoding_key` - PEM-encoded RSA private key string for signing JWTs. + /// * `kid` - Key ID to use in JWT headers and JWKS. + /// * `now` - Function for getting current time. + /// * `expires_in_secs` - How long tokens should be valid. + pub async fn start( + addr: Option<&SocketAddr>, + encoding_key: String, + kid: String, + now: NowFn, + expires_in_secs: i64, + ) -> Result { + // Convert PEM string to key. + let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?; + + // Parse the private key PEM to extract RSA components for JWKS. + let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?; + let rsa = pkey.rsa().expect("pkey should be RSA"); + + let addr = match addr { + Some(addr) => Cow::Borrowed(addr), + None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)), + }; + + let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }); + let base_url = format!("http://{}", listener.local_addr().unwrap()); + let issuer = base_url.clone(); + + // Extract RSA public key components from the decoding key + // We need to serialize the public key to get n and e values + let jwk = create_jwk(&kid, &rsa); + + let context = Arc::new(OidcMockContext { + issuer: issuer.clone(), + jwk, + }); + + let router = Router::new() + .route("/.well-known/jwks.json", get(handle_jwks)) + .route( + "/.well-known/openid-configuration", + get(handle_openid_config), + ) + .with_state(context); + + let server = axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ); + println!("oidc-mock listening..."); + println!(" HTTP address: {}", base_url); + let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()); + + Ok(OidcMockServer { + base_url, + issuer, + kid, + encoding_key: encoding_key_typed, + now, + expires_in_secs, + handle, + }) + } + + /// Generates a JWT token for testing. + /// + /// # Arguments + /// + /// * `sub` - Subject (user identifier). + /// * `email` - Optional email claim. + pub fn generate_jwt(&self, sub: &str, email: Option<&str>) -> String { + let now_ms = (self.now)(); + let now_secs = (now_ms / 1000) as i64; + + let claims = OidcClaims { + sub: sub.to_string(), + iss: self.issuer.clone(), + exp: now_secs + self.expires_in_secs, + iat: Some(now_secs), + email: email.map(|s| s.to_string()), + }; + + let mut header = Header::new(jsonwebtoken::Algorithm::RS256); + header.kid = Some(self.kid.clone()); + + encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") + } + + /// Generates a JWT token with a specific expiration time. + pub fn generate_jwt_with_exp(&self, sub: &str, email: Option<&str>, exp: i64) -> String { + let now_ms = (self.now)(); + let now_secs = (now_ms / 1000) as i64; + + let claims = OidcClaims { + sub: sub.to_string(), + iss: self.issuer.clone(), + exp, + iat: Some(now_secs), + email: email.map(|s| s.to_string()), + }; + + let mut header = Header::new(jsonwebtoken::Algorithm::RS256); + header.kid = Some(self.kid.clone()); + + encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") + } + + /// Generates a JWT token with a custom issuer (for testing invalid tokens). + pub fn generate_jwt_with_issuer(&self, sub: &str, email: Option<&str>, issuer: &str) -> String { + let now_ms = (self.now)(); + let now_secs = (now_ms / 1000) as i64; + + let claims = OidcClaims { + sub: sub.to_string(), + iss: issuer.to_string(), + exp: now_secs + self.expires_in_secs, + iat: Some(now_secs), + email: email.map(|s| s.to_string()), + }; + + let mut header = Header::new(jsonwebtoken::Algorithm::RS256); + header.kid = Some(self.kid.clone()); + + encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") + } + + /// Returns the JWKS URL for this server. + pub fn jwks_url(&self) -> String { + format!("{}/.well-known/jwks.json", self.base_url) + } +} + +/// Handler for JWKS endpoint. +async fn handle_jwks(State(context): State>) -> Json { + Json(JwkSet { + keys: vec![context.jwk.clone()], + }) +} + +/// OpenID Configuration response. +#[derive(Serialize)] +struct OpenIdConfiguration { + issuer: String, + jwks_uri: String, +} + +/// Handler for OpenID Configuration endpoint. +async fn handle_openid_config( + State(context): State>, +) -> Json { + Json(OpenIdConfiguration { + issuer: context.issuer.clone(), + jwks_uri: format!("{}/.well-known/jwks.json", context.issuer), + }) +} + +/// Creates a JWK from RSA key components. +fn create_jwk(kid: &str, rsa: &Rsa) -> Jwk { + let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD; + let n = rsa.n().to_vec(); + let e = rsa.e().to_vec(); + + Jwk { + kty: "RSA".to_string(), + kid: kid.to_string(), + key_use: "sig".to_string(), + alg: "RS256".to_string(), + n: engine.encode(n), + e: engine.encode(e), + } +} From 7aa7db6915c0791fdefd8ad72b985ba956a13bd5 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Tue, 20 Jan 2026 16:57:20 -0500 Subject: [PATCH 06/15] Implement OIDC option extraction - Add functionality to extract `oidc_auth_enabled` from startup options, allowing us to use the password authenticator in the future --- src/pgwire/src/protocol.rs | 125 ++++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 7 deletions(-) diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 98f04a0d78a5d..525178bcafd3a 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -160,6 +160,13 @@ where } let user = params.remove("user").unwrap_or_else(String::new); + let options = parse_options(params.get("options").unwrap_or(&String::new())); + + // If oidc_auth_enabled exists as an option, return its value and filter it from + // the remaining options. + // TODO (SangJunBak): Use oidc_auth_enabled boolean to implement OIDC flow instead of + // OIDC authenticator kind. + let (_oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options); // TODO move this somewhere it can be shared with HTTP let is_internal_user = INTERNAL_USER_NAMES.contains(&user); @@ -499,7 +506,7 @@ where let system_vars = adapter_client.get_system_vars().await; for (name, value) in params { let settings = match name.as_str() { - "options" => match parse_options(&value) { + "options" => match &options { Ok(opts) => opts, Err(()) => { session.add_notice(AdapterNotice::BadStartupSetting { @@ -509,7 +516,7 @@ where continue; } }, - _ => vec![(name, value)], + _ => &vec![(name, value)], }; for (key, val) in settings { const LOCAL: bool = false; @@ -517,13 +524,12 @@ where // (silently ignore errors on set), but erroring the connection // might be the better behavior. We maybe need to support more // options sent by psql and drivers before we can safely do this. - if let Err(err) = - session - .vars_mut() - .set(&system_vars, &key, VarInput::Flat(&val), LOCAL) + if let Err(err) = session + .vars_mut() + .set(&system_vars, key, VarInput::Flat(val), LOCAL) { session.add_notice(AdapterNotice::BadStartupSetting { - name: key, + name: key.clone(), reason: err.to_string(), }); } @@ -601,6 +607,31 @@ where } } +/// Gets `oidc_auth_enabled` from options if it exists. +/// Returns options with oidc_auth_enabled extracted +/// and the oidc_auth_enabled value. +fn extract_oidc_auth_enabled_from_options( + options: Result, ()>, +) -> (bool, Result, ()>) { + let options = match options { + Ok(opts) => opts, + Err(_) => return (false, options), + }; + + let mut new_options = Vec::new(); + let mut oidc_auth_enabled = false; + + for (k, v) in options { + if k == "oidc_auth_enabled" { + oidc_auth_enabled = v.parse::().unwrap_or(false); + } else { + new_options.push((k, v)); + } + } + + (oidc_auth_enabled, Ok(new_options)) +} + /// Returns (name, value) session settings pairs from an options value. /// /// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches @@ -3039,4 +3070,84 @@ mod test { assert_eq!(got, test.expect, "input: {}", test.input); } } + + #[mz_ore::test] + fn test_extract_oidc_auth_enabled_from_options() { + struct TestCase { + input: Result, ()>, + expect_enabled: bool, + expect_options: Result, ()>, + } + let tests = vec![ + // Empty options + TestCase { + input: Ok(vec![]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // Error input passthrough + TestCase { + input: Err(()), + expect_enabled: false, + expect_options: Err(()), + }, + // oidc_auth_enabled=true + TestCase { + input: Ok(vec![("oidc_auth_enabled", "true")]), + expect_enabled: true, + expect_options: Ok(vec![]), + }, + // oidc_auth_enabled=false + TestCase { + input: Ok(vec![("oidc_auth_enabled", "false")]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // Invalid oidc_auth_enabled value defaults to false + TestCase { + input: Ok(vec![("oidc_auth_enabled", "invalid")]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // No oidc_auth_enabled, other options preserved + TestCase { + input: Ok(vec![("key1", "val1"), ("key2", "val2")]), + expect_enabled: false, + expect_options: Ok(vec![("key1", "val1"), ("key2", "val2")]), + }, + // Mixed: oidc_auth_enabled with other options + TestCase { + input: Ok(vec![ + ("key1", "val1"), + ("oidc_auth_enabled", "true"), + ("key2", "val2"), + ]), + expect_enabled: true, + expect_options: Ok(vec![("key1", "val1"), ("key2", "val2")]), + }, + ]; + for test in tests { + let input = test.input.map(|r| { + r.into_iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect() + }); + let (got_enabled, got_options) = extract_oidc_auth_enabled_from_options(input.clone()); + let expect_options = test.expect_options.map(|r| { + r.into_iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect() + }); + assert_eq!( + got_enabled, test.expect_enabled, + "enabled mismatch for input: {:?}", + input + ); + assert_eq!( + got_options, expect_options, + "options mismatch for input: {:?}", + input + ); + } + } } From ffadefd2acd4a8cc176b88695a5943d82d2d1981 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Tue, 20 Jan 2026 23:27:18 -0500 Subject: [PATCH 07/15] Fix lint errors --- oidc_auth_setup.md | 83 ----------------------- src/adapter/src/command.rs | 2 +- src/adapter/src/coord/validity.rs | 1 - src/authenticator-types/src/lib.rs | 8 ++- src/authenticator/src/oidc.rs | 2 +- src/environmentd/src/environmentd/main.rs | 2 +- src/environmentd/tests/auth.rs | 5 +- src/oidc-mock/src/lib.rs | 6 +- src/pgwire/src/protocol.rs | 22 +----- src/sqllogictest/src/runner.rs | 1 + 10 files changed, 16 insertions(+), 116 deletions(-) delete mode 100644 oidc_auth_setup.md diff --git a/oidc_auth_setup.md b/oidc_auth_setup.md deleted file mode 100644 index 9a9dcb3a0bebe..0000000000000 --- a/oidc_auth_setup.md +++ /dev/null @@ -1,83 +0,0 @@ -## Setting - -PGOAUTHDEBUG=UNSAFE psql 'host=192.168.215.3 user=employees dbname=promo oauth_issuer=http://host.docker.internal:4444 oauth_client_id=1186624a-7bed-44f8-867e-d3938a29c924 oauth_client_secret=88OlbvOkiDaPSypu94qK_WHDjG' - - -code_client=$(docker compose -f quickstart.yml exec hydra \ - hydra create client \ - --endpoint http://127.0.0.1:4445 \ - --grant-type authorization_code,refresh_token \ - --response-type code,id_token \ - --format json \ - --scope openid --scope offline --scope profile --scope email\ - --access-token-strategy jwt \ - --redirect-uri http://127.0.0.1:5555/callback) - -code_client_id=$(echo $code_client | jq -r '.client_id') -code_client_secret=$(echo $code_client | jq -r '.client_secret') - -docker compose -f quickstart.yml exec hydra \ - hydra perform authorization-code \ - --client-id $code_client_id \ - --client-secret $code_client_secret \ - --endpoint http://127.0.0.1:4444/ \ - --port 5555 \ - --scope openid --scope offline --scope profile --scope email - -## Deleting a client -hydra delete oauth2-client --endpoint http://localhost:4445 b1a93de1-e4dd-4da9-8e81-083ec4e89f6e=$ - -client id: 060a4f3d-1cac-46e4-b5a5-6b9c66cd9431 -secret: wAghHCKR_E26yuLRpSkaoz2epq - - - - - -device - -device_client=$(docker compose -f quickstart.yml exec hydra \ - hydra create client \ - --endpoint http://127.0.0.1:4445 \ - --format json \ - --name "my device app" \ - --grant-type urn:ietf:params:oauth:grant-type:device_code,refresh_token \ - --token-endpoint-auth-method none \ - --access-token-strategy jwt \ - --scope openid,offline_access,profile) - -device_client_id=$(echo $device_client | jq -r '.client_id') -device_client_secret=$(echo $device_client | jq -r '.client_secret') - -echo $device_client_id -echo $device_client_secret - - -docker compose -f quickstart.yml exec hydra \ - hydra perform device-code \ - --client-id $device_client_id \ - --client-secret $device_client_secret \ - --endpoint http://127.0.0.1:4444/ \ - --scope openid,offline_access - - -Visit http://host.docker.internal:4444/oauth2/device/verify and enter the code: mpGRAMPk - - -http://localhost:4444/.well-known/jwks.json - - bin/environmentd -- \ - --oidc-issuer="http://127.0.0.1:4444" \ - --listeners-config-path='src/materialized/ci/listener_configs/oidc.json' - - -Access token: - -eyJhbGciOiJSUzI1NiIsImtpZCI6Ijk3ZTJmOTJhLWM2YjQtNDQ0ZC1hNjZhLWY3Y2YwOTIwNzdhMyIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiJlMGZkYzJkZC05YTU1LTRlZjEtYWU2Zi05YjQyYTJhMzA3NWYiLCJleHAiOjE3NjgyNTgzODksImV4dCI6e30sImlhdCI6MTc2ODI1NDc4OCwiaXNzIjoiaHR0cDovLzEyNy4wLjAuMTo0NDQ0IiwianRpIjoiNTk1N2RiN2YtZGVjZi00OWE0LTliNTMtOTBiM2M1ZDhlMWI5IiwibmJmIjoxNzY4MjU0Nzg4LCJzY3AiOlsib3BlbmlkIiwib2ZmbGluZSIsInByb2ZpbGUiLCJlbWFpbCJdLCJzdWIiOiJmb29AYmFyLmNvbSJ9.w_Vype6NRh_gAIowtja24alhINfXOGRavwq9Nd3gu1tW5l1zxDza6X5iPhMmSTnnlDm1dbekAVQ8tldZs5XDfycDFsFfuMa-IvsoQ3GUyglGMv-hdVa8hDBLGLonVkn5fZDAiotnRzDKZo1qGJ7nDGkV1_oO7DE_BqlXC6OebqQKXdyzZI4xXrreQvQCF0JiW4Kz7F3FZrJeIMyBgMgwgt1spi6YFuER-08l0ZPotrQ20KGhTHy0k-zpyjPUZA8vm8AAiyePvgIHh4pAm_0k4gG_fcX6rw5Hv3UsNtDH42b2QQhGgqY_gvBTNCxCW_wHmHtrgFYiIH7N3NwQE36ZJLSAVL9xuVdaV9km1ZSHAnJ5TdXrtB1wEEsjwYFIrv0AwUv-mlUk0QS7E_8Wv_-BqwgbE4TjdcTIe2-S85N3i7w_LkJT5D1tIwSKlotXCfRV_nTvrWAwar9bLBdynBXYAhwpzASCub_L4qqwCvrWnOPYIHqb9EQFsIEqYaKv_Iz5BLUaMC4fgymSDpsb_kujlQNmR1R_EfMIZA2noFQ8HZ3JJfWckYgLLpJL5RhHDZoQIQOpZL2RvXE1Ud8roT1f2sRGqotNJ93PcBuISzVzJ5ov3ZM7VU3QjhQ4q4Z5q6BQIAFW_j8WoLbWZ9KlCbhVTzXu0usjeI4BHf0_HluIkNE - - Run ``` - PGPASSWORD="eyJhbGciOiJSUzI1NiIsImtpZCI6Ijk3ZTJmOTJhLWM2YjQtNDQ0ZC1hNjZhLWY3Y2YwOTIwNzdhMyIsInR5cCI6IkpXVCJ9.eyJhdWQiOltdLCJjbGllbnRfaWQiOiIyY2MzZTE0Ny1iZTMxLTQzNGYtOWNiYy02MzY3Y2Y0NDE3Y2MiLCJleHAiOjE3Njg5Mzk5MTYsImV4dCI6e30sImlhdCI6MTc2ODkzNjMxNiwiaXNzIjoiaHR0cDovLzEyNy4wLjAuMTo0NDQ0IiwianRpIjoiMjlkMjY4YjgtOTc1ZC00NzhkLTkzOTktNDNjMDk2MTVmZGJlIiwibmJmIjoxNzY4OTM2MzE2LCJzY3AiOlsib3BlbmlkIiwib2ZmbGluZSIsInByb2ZpbGUiLCJlbWFpbCJdLCJzdWIiOiJmb29AYmFyLmNvbSJ9.LlxflqxSf8l0EsgR0F1mVW2JVBYAsVBeElZcqUK6CD-_wBfvgQzCQlIVZNRNYnVQYZwrDzJvrv4-niysYxJ-PRgMAp826nuZIjgQz5qMDovmrqO6UFXSW6pA1rR4N1tQVAYmCAoS-O3PgntJHUE5vpX19D28YbyyS60Yo1u5KzXkQdqZPrkStUhcHXJP-4CfX43k9ginbn23XtKqp0NXzMbJRXk1wY4P6Si-9pqeqLiD7CtNCRXbFowFBePsr9cYoQJRV6ausBRTFE7mXxybU7NFKvuerGNUFI5u0LKzuRsvhw5iuJHPi2PLxrxWjqh3idKgCRbFGJR-Vk63El7Z4O-piwREShyNDfAU1_KREIPt5-zxCp-qe0JhCYEPVikICRT2NF0c29ZzvaxsGEOb8PZBXgPRncfDAsz-fXTOr2MXuGsxIZBgcRx6oDR2mnGZKIXrLqRiBDwik66M2LDE7x5FQZqiha0y2_h_PwNhnDdWqcRAb-l48FqtZCXDi5V4zyfYCw24sWhXVyqLi5WGozavVSxCvZmUP3Qd1OvE9j2n3JOjlecY-7G3ccV1Te_uYNcALyo2DRaiA1mO7XHhmh-9W1_DOZWljmYF9j6qJhMft38N-6fB4Wp8U7vKwdHVfBk5dHb8q95qaLefJkaXk79vA28Wmu6_LHn7EFC0U_M" psql -h localhost -p 6875 -U foo@bar.com materialize - ``` diff --git a/src/adapter/src/command.rs b/src/adapter/src/command.rs index 0815681e4b270..e07728281c1aa 100644 --- a/src/adapter/src/command.rs +++ b/src/adapter/src/command.rs @@ -381,7 +381,7 @@ pub struct StartupResponse { /// The role's superuser attribute in the Catalog. /// This attribute is None for Cloud. Cloud is able /// to derive the role's superuser status from - /// [Session.external_metadata_rx](crate::session::Session::external_metadata_rx). + /// external_metadata_rx. pub superuser_attribute: Option, /// A future that completes when all necessary Builtin Table writes have completed. #[derivative(Debug = "ignore")] diff --git a/src/adapter/src/coord/validity.rs b/src/adapter/src/coord/validity.rs index 48ec3cffcd2db..8cf542cfbc4c9 100644 --- a/src/adapter/src/coord/validity.rs +++ b/src/adapter/src/coord/validity.rs @@ -223,7 +223,6 @@ mod tests { user, client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }, metrics.session_metrics(), diff --git a/src/authenticator-types/src/lib.rs b/src/authenticator-types/src/lib.rs index 30793a325be28..25be9de24d95c 100644 --- a/src/authenticator-types/src/lib.rs +++ b/src/authenticator-types/src/lib.rs @@ -18,14 +18,16 @@ use async_trait::async_trait; /// An authentication session represents a duration of time during which a /// user's authentication is known to be valid. /// -/// [`OidcAuthSessionHandle::external_metadata_rx`] can be used to receive events if -/// the session's metadata is updated. -/// /// [`OidcAuthSessionHandle::expired`] can be used to learn if the session has /// failed to refresh the validity of the API key. #[async_trait] pub trait OidcAuthSessionHandle: Debug + Send { /// Returns the name of the user that created the session. + // In particular, it's important that the username comes from the + // auth session, as the OIDC authenticator may return a user with + // different casing than the user supplied via the pgwire + // username field. We want to use the IdP's casing as + // canonical. fn user(&self) -> &str; /// Completes when the authentication session has expired. async fn expired(&mut self); diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index 52fcaff9f0757..a8f7f9fc780e2 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -27,7 +27,7 @@ use tracing::warn; /// Command line arguments for OIDC authentication. #[derive(Debug, Clone)] pub struct OidcConfig { - /// OIDC issuer URL (e.g., "https://accounts.google.com"). + /// OIDC issuer URL (e.g., ""). /// This is validated against the `iss` claim in the JWT. pub oidc_issuer: String, } diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 879768bbbb30b..97cf4d8e4dafc 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -173,7 +173,7 @@ pub struct Args { #[clap(flatten)] frontegg: FronteggCliArgs, // === OIDC options. === - /// OIDC issuer URL (e.g., "https://accounts.google.com"). + /// OIDC issuer URL (e.g., ""). #[clap(long, env = "MZ_OIDC_ISSUER")] oidc_issuer: Option, // === Orchestrator options. === diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index ec3526555439b..d877472090331 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -1273,7 +1273,7 @@ async fn test_auth_base_require_tls_frontegg() { /// Tests OIDC authentication with TLS required. /// -/// This test verifies that users can authenticate using JWT tokens +/// This test verifies that users can authenticate using OIDC tokens /// over TLS connections #[allow(clippy::unit_arg)] #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] @@ -1284,7 +1284,6 @@ async fn test_auth_base_require_tls_oidc() { .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) .unwrap(); - // Get PEM string from CA's key (same pattern as Frontegg TLS test). let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); let kid = "test-key-1".to_string(); @@ -1313,10 +1312,8 @@ async fn test_auth_base_require_tls_oidc() { "https://wrong-issuer.com", ); - // Create Bearer auth header for HTTP tests. let oidc_bearer = Authorization::bearer(&jwt_token).unwrap(); let oidc_header_bearer = make_header(oidc_bearer); - let oidc_header_basic = make_header(Authorization::basic(oidc_user, &jwt_token)); let server = test_util::TestHarness::default() diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs index cad82d2f7122d..4770a1f11d170 100644 --- a/src/oidc-mock/src/lib.rs +++ b/src/oidc-mock/src/lib.rs @@ -155,7 +155,7 @@ impl OidcMockServer { /// * `email` - Optional email claim. pub fn generate_jwt(&self, sub: &str, email: Option<&str>) -> String { let now_ms = (self.now)(); - let now_secs = (now_ms / 1000) as i64; + let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); let claims = OidcClaims { sub: sub.to_string(), @@ -174,7 +174,7 @@ impl OidcMockServer { /// Generates a JWT token with a specific expiration time. pub fn generate_jwt_with_exp(&self, sub: &str, email: Option<&str>, exp: i64) -> String { let now_ms = (self.now)(); - let now_secs = (now_ms / 1000) as i64; + let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); let claims = OidcClaims { sub: sub.to_string(), @@ -193,7 +193,7 @@ impl OidcMockServer { /// Generates a JWT token with a custom issuer (for testing invalid tokens). pub fn generate_jwt_with_issuer(&self, sub: &str, email: Option<&str>, issuer: &str) -> String { let now_ms = (self.now)(); - let now_secs = (now_ms / 1000) as i64; + let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); let claims = OidcClaims { sub: sub.to_string(), diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 525178bcafd3a..2b3d36da5e5b6 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -164,8 +164,9 @@ where // If oidc_auth_enabled exists as an option, return its value and filter it from // the remaining options. - // TODO (SangJunBak): Use oidc_auth_enabled boolean to implement OIDC flow instead of - // OIDC authenticator kind. + // TODO (SangJunBak): Use oidc_auth_enabled boolean and Authenticator::OIDC + // to decide whether or not we want to use OIDC authentication or password + // authentication. let (_oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options); // TODO move this somewhere it can be shared with HTTP @@ -201,13 +202,6 @@ where let auth_response = frontegg.authenticate(&user, &password).await; match auth_response { Ok(mut auth_session) => { - // Create a session based on the auth session. - // - // In particular, it's important that the username come from the - // auth session, as Frontegg may return an email address with - // different casing than the user supplied via the pgwire - // username field. We want to use the Frontegg casing as - // canonical. let session = adapter_client.new_session(SessionConfig { conn_id: conn.conn_id().clone(), uuid: conn_uuid, @@ -231,7 +225,6 @@ where } } Authenticator::Oidc(oidc) => { - tracing::info!("OIDC authentication"); // OIDC authentication: JWT sent as password in cleartext flow let jwt = match request_cleartext_password(conn).await { Ok(password) => password, @@ -241,19 +234,10 @@ where } }; - tracing::info!("JWT: {}", jwt); - let auth_response = oidc.authenticate(&user, &jwt).await; match auth_response { Ok(mut auth_session) => { - // Create a session based on the auth session. - // - // In particular, it's important that the username come from the - // auth session, as Frontegg may return an email address with - // different casing than the user supplied via the pgwire - // username field. We want to use the Frontegg casing as - // canonical. let session = adapter_client.new_session(SessionConfig { conn_id: conn.conn_id().clone(), uuid: conn_uuid, diff --git a/src/sqllogictest/src/runner.rs b/src/sqllogictest/src/runner.rs index 60641b1811a06..75b8aaf15d5a2 100644 --- a/src/sqllogictest/src/runner.rs +++ b/src/sqllogictest/src/runner.rs @@ -1171,6 +1171,7 @@ impl<'a> RunnerInner<'a> { cloud_resource_controller: None, tls: None, frontegg: None, + oidc: None, cors_allowed_origin: AllowOrigin::list([]), unsafe_mode: true, all_features: false, From 9e19191e94bed35eb5db76ab4fdce3eb238cbc6d Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Wed, 21 Jan 2026 19:20:48 -0500 Subject: [PATCH 08/15] Fetch jwks uri from openid-configuration endpoint Before we were mistakenly fetching from jwks.json instead of getting it from the configuration endpoint --- src/authenticator/src/oidc.rs | 68 +++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index a8f7f9fc780e2..c28f3c3466116 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -23,6 +23,7 @@ use reqwest::Client as HttpClient; use serde::{Deserialize, Serialize}; use tracing::warn; +use url::Url; /// Command line arguments for OIDC authentication. #[derive(Debug, Clone)] @@ -36,7 +37,9 @@ pub struct OidcConfig { #[derive(Debug)] pub enum OidcError { /// Failed to parse OIDC configuration URL. - InvalidConfigUrl(url::ParseError), + InvalidIssuerUrl(url::ParseError), + /// Failed to fetch OpenID configuration from provider. + OpenIdConfigFetchFailed(String), /// Failed to fetch JWKS from provider. JwksFetchFailed(String), /// The key ID is missing in the token header. @@ -52,8 +55,11 @@ pub enum OidcError { impl std::fmt::Display for OidcError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - OidcError::InvalidConfigUrl(e) => { - write!(f, "failed to parse OIDC configuration URL: {}", e) + OidcError::InvalidIssuerUrl(e) => { + write!(f, "failed to parse OIDC issuer URL: {}", e) + } + OidcError::OpenIdConfigFetchFailed(e) => { + write!(f, "failed to fetch OpenID configuration: {}", e) } OidcError::JwksFetchFailed(e) => write!(f, "failed to fetch JWKS: {}", e), OidcError::MissingKid => write!(f, "missing key ID in token header"), @@ -130,10 +136,17 @@ pub struct GenericOidcAuthenticator { inner: Arc, } +/// OpenID Connect Discovery document. +/// See: +#[derive(Debug, Deserialize)] +struct OpenIdConfiguration { + /// URL of the JWKS endpoint. + jwks_uri: String, +} + #[derive(Debug)] pub struct GenericOidcAuthenticatorInner { issuer: String, - jwks_uri: String, decoding_keys: Mutex>, http_client: HttpClient, } @@ -141,34 +154,53 @@ pub struct GenericOidcAuthenticatorInner { impl GenericOidcAuthenticator { /// Create a new [`GenericOidcAuthenticator`] from [`OidcConfig`]. pub fn new(config: OidcConfig) -> Result { - let issuer_url = - url::Url::parse(&config.oidc_issuer).map_err(OidcError::InvalidConfigUrl)?; - - // TODO (SangJunBak): Add a configuration variable for the JWKS set and - // a boolean jwksFetchFromIssuer. - let jwks_uri = issuer_url - .join(".well-known/jwks.json") - .map_err(OidcError::InvalidConfigUrl)? - .to_string(); + let http_client = HttpClient::new(); Ok(Self { inner: Arc::new(GenericOidcAuthenticatorInner { issuer: config.oidc_issuer, - jwks_uri, decoding_keys: Mutex::new(BTreeMap::new()), - // TODO: Use same client code as frontegg-auth. - http_client: HttpClient::new(), + http_client, }), }) } } impl GenericOidcAuthenticatorInner { + async fn fetch_jwks_uri(&self) -> Result { + let openid_config_url = Url::parse(&self.issuer) + .and_then(|url| url.join(".well-known/openid-configuration")) + .map_err(OidcError::InvalidIssuerUrl)?; + + // Fetch OpenID configuration to get the JWKS URI + let response = self + .http_client + .get(openid_config_url) + .timeout(Duration::from_secs(10)) + .send() + .await + .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(OidcError::OpenIdConfigFetchFailed(format!( + "HTTP {}", + response.status() + ))); + } + + let openid_config: OpenIdConfiguration = response + .json() + .await + .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?; + + Ok(openid_config.jwks_uri) + } /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys. - async fn fetch_jwks(self: &Self) -> Result, OidcError> { + async fn fetch_jwks(&self) -> Result, OidcError> { + let jwks_uri = self.fetch_jwks_uri().await?; let response = self .http_client - .get(&self.jwks_uri) + .get(&jwks_uri) .timeout(Duration::from_secs(10)) .send() .await From 92eb3d93dde5f59ab3543319d374c89feedff182 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Wed, 21 Jan 2026 19:35:51 -0500 Subject: [PATCH 09/15] Annotate mz imports with default-features --- Cargo.lock | 2 -- src/authenticator-types/Cargo.toml | 4 +--- src/authenticator/Cargo.toml | 2 +- src/balancerd/Cargo.toml | 2 +- src/environmentd/Cargo.toml | 4 ++-- src/frontegg-auth/Cargo.toml | 2 +- src/oidc-mock/Cargo.toml | 2 +- src/pgwire/Cargo.toml | 2 +- 8 files changed, 8 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f20e6880177ca..b47104a68fa39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5611,8 +5611,6 @@ version = "0.0.0" dependencies = [ "async-trait", "mz-repr", - "tokio", - "uuid", "workspace-hack", ] diff --git a/src/authenticator-types/Cargo.toml b/src/authenticator-types/Cargo.toml index 71430fdfd7617..c3dd86dbeafb8 100644 --- a/src/authenticator-types/Cargo.toml +++ b/src/authenticator-types/Cargo.toml @@ -11,9 +11,7 @@ workspace = true [dependencies] async-trait = "0.1.89" -mz-repr = { path = "../repr" } -tokio = { version = "1.48.0", features = ["macros"] } -uuid = { version = "1.19.0", features = ["serde"] } +mz-repr = { path = "../repr", default-features = false } workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } [package.metadata.cargo-udeps.ignore] diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index b7976b026196c..e3ad7bbcd9a95 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -11,7 +11,7 @@ publish = false async-trait = "0.1" jsonwebtoken = "9.3.1" mz-adapter = { path = "../adapter", default-features = false } -mz-authenticator-types = { path = "../authenticator-types" } +mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-frontegg-auth = { path = "../frontegg-auth", default-features = false } reqwest = "0.12.24" serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/balancerd/Cargo.toml b/src/balancerd/Cargo.toml index 245a77bc63f61..e475673a4359f 100644 --- a/src/balancerd/Cargo.toml +++ b/src/balancerd/Cargo.toml @@ -27,7 +27,7 @@ jsonwebtoken = "9.3.1" launchdarkly-server-sdk = { version = "2.6.2", default-features = false } mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } -mz-authenticator-types = { path = "../authenticator-types" } +mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-build-info = { path = "../build-info" } mz-dyncfg-launchdarkly = { path = "../dyncfg-launchdarkly" } mz-dyncfg-file= { path = "../dyncfg-file" } diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index d3dcb3a2ec6ba..1cfa4a520f0bf 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -43,7 +43,7 @@ mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } mz-auth = { path = "../auth" } mz-authenticator = { path = "../authenticator" } -mz-authenticator-types = { path = "../authenticator-types" } +mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-aws-secrets-controller = { path = "../aws-secrets-controller" } mz-build-info = { path = "../build-info" } mz-adapter = { path = "../adapter" } @@ -56,7 +56,7 @@ mz-dyncfg = { path = "../dyncfg" } mz-dyncfgs = { path = "../dyncfgs" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-frontegg-mock = { path = "../frontegg-mock", optional = true } -mz-oidc-mock = { path = "../oidc-mock", optional = true } +mz-oidc-mock = { path = "../oidc-mock", default-features = false, optional = true } mz-http-util = { path = "../http-util" } mz-interchange = { path = "../interchange" } mz-license-keys = { path = "../license-keys" } diff --git a/src/frontegg-auth/Cargo.toml b/src/frontegg-auth/Cargo.toml index f8a0d89e30b9b..e62ede194a0f5 100644 --- a/src/frontegg-auth/Cargo.toml +++ b/src/frontegg-auth/Cargo.toml @@ -18,7 +18,7 @@ derivative = "2.2.0" futures = "0.3.31" jsonwebtoken = "9.3.1" lru = "0.16.3" -mz-authenticator-types = { path = "../authenticator-types" } +mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-ore = { path = "../ore", features = ["network", "metrics"] } mz-repr = { path = "../repr" } prometheus = { version = "0.14.0", default-features = false } diff --git a/src/oidc-mock/Cargo.toml b/src/oidc-mock/Cargo.toml index 76b0fdd3a66c3..223a0390952dd 100644 --- a/src/oidc-mock/Cargo.toml +++ b/src/oidc-mock/Cargo.toml @@ -15,7 +15,7 @@ anyhow = "1.0.100" axum = "0.7.5" base64 = "0.22.1" jsonwebtoken = "9.3.1" -mz-authenticator = { path = "../authenticator" } +mz-authenticator = { path = "../authenticator", default-features = false } mz-ore = { path = "../ore", default-features = false, features = ["cli"] } openssl = { version = "0.10.75", features = ["vendored"] } serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/pgwire/Cargo.toml b/src/pgwire/Cargo.toml index 7d0b5d503f914..28e114649f302 100644 --- a/src/pgwire/Cargo.toml +++ b/src/pgwire/Cargo.toml @@ -23,7 +23,7 @@ mz-adapter = { path = "../adapter" } mz-adapter-types = { path = "../adapter-types" } mz-auth = { path = "../auth" } mz-authenticator = { path = "../authenticator" } -mz-authenticator-types = { path = "../authenticator-types" } +mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-frontegg-auth = { path = "../frontegg-auth" } mz-ore = { path = "../ore", features = ["tracing"] } mz-pgcopy = { path = "../pgcopy" } From d0a1751f84cd322c052c19e057944a99094518dc Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Wed, 21 Jan 2026 19:48:17 -0500 Subject: [PATCH 10/15] Refactor OIDC issuer URL documentation and improve password handling in authentication - Updated comments to use backticks for OIDC issuer URL - Changed password handling in oidc http/ws auth to include the username when validating access tokens. - Simplified OIDC mock server structure by removing the base URL field - Remove unneeded assertion on role existence --- src/adapter/src/coord/command_handler.rs | 1 - src/authenticator/src/oidc.rs | 2 +- src/environmentd/src/environmentd/main.rs | 2 +- src/environmentd/src/http.rs | 4 ++-- src/oidc-mock/src/lib.rs | 13 +++++-------- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/adapter/src/coord/command_handler.rs b/src/adapter/src/coord/command_handler.rs index 3d806f74bd670..821f794732dad 100644 --- a/src/adapter/src/coord/command_handler.rs +++ b/src/adapter/src/coord/command_handler.rs @@ -504,7 +504,6 @@ impl Coordinator { Ok(verifier) => { // Success only if role exists, allows login, and a real password hash was used. if login && real_hash.is_some() { - role.expect("login implies role exists"); let _ = tx.send(Ok(SASLVerifyProofResponse { verifier })); } else { let _ = tx.send(Err(make_auth_err(role_present, login))); diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index c28f3c3466116..d6ea6a731f3d9 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -28,7 +28,7 @@ use url::Url; /// Command line arguments for OIDC authentication. #[derive(Debug, Clone)] pub struct OidcConfig { - /// OIDC issuer URL (e.g., ""). + /// OIDC issuer URL (e.g., `https://accounts.google.com`). /// This is validated against the `iss` claim in the JWT. pub oidc_issuer: String, } diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 97cf4d8e4dafc..44df44630701c 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -173,7 +173,7 @@ pub struct Args { #[clap(flatten)] frontegg: FronteggCliArgs, // === OIDC options. === - /// OIDC issuer URL (e.g., ""). + /// OIDC issuer URL (e.g., `https://accounts.google.com`). #[clap(long, env = "MZ_OIDC_ISSUER")] oidc_issuer: Option, // === Orchestrator options. === diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index a12cc7aeb4dad..b63bf835d0e26 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -1039,10 +1039,10 @@ async fn auth( let name = claims.username().to_string(); (name, None) } - Some(Credentials::Password { password, .. }) => { + Some(Credentials::Password { username, password }) => { // Allow JWT to be passed as password let claims = oidc - .validate_access_token(&password.0, None) + .validate_access_token(&password.0, Some(&username)) .await .map_err(|_| AuthError::InvalidCredentials)?; let name = claims.username().to_string(); diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs index 4770a1f11d170..2bfc4e614302d 100644 --- a/src/oidc-mock/src/lib.rs +++ b/src/oidc-mock/src/lib.rs @@ -58,9 +58,8 @@ struct OidcMockContext { /// OIDC mock server for testing. pub struct OidcMockServer { - /// Base URL of the server (e.g., "http://127.0.0.1:12345"). - pub base_url: String, - /// The issuer URL (same as base_url, used for JWT iss claim). + /// The issuer URL. Used as the base URL of the server + /// and as the issuer for JWT iss claim. pub issuer: String, /// Key ID used in JWT headers. pub kid: String, @@ -108,8 +107,7 @@ impl OidcMockServer { let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| { panic!("error binding to {}: {}", addr, e); }); - let base_url = format!("http://{}", listener.local_addr().unwrap()); - let issuer = base_url.clone(); + let issuer = format!("http://{}", listener.local_addr().unwrap()); // Extract RSA public key components from the decoding key // We need to serialize the public key to get n and e values @@ -133,11 +131,10 @@ impl OidcMockServer { router.into_make_service_with_connect_info::(), ); println!("oidc-mock listening..."); - println!(" HTTP address: {}", base_url); + println!(" HTTP address: {}", issuer); let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()); Ok(OidcMockServer { - base_url, issuer, kid, encoding_key: encoding_key_typed, @@ -211,7 +208,7 @@ impl OidcMockServer { /// Returns the JWKS URL for this server. pub fn jwks_url(&self) -> String { - format!("{}/.well-known/jwks.json", self.base_url) + format!("{}/.well-known/jwks.json", self.issuer) } } From 021de6439e1eabd87d8dd6410fa5dffa755dc4e7 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Wed, 21 Jan 2026 18:33:01 -0500 Subject: [PATCH 11/15] Add OIDC audience validation support - Add aud claim - Introduced `oidc_audience` field in OidcConfig to validate JWT's `aud` claim. This follows the spec in the design doc - Added OIDC mock server audience claims support - Added tests for audience validation and when we don't need to --- Cargo.lock | 2 + src/authenticator/Cargo.toml | 4 + src/authenticator/src/oidc.rs | 52 ++++- src/environmentd/src/environmentd/main.rs | 9 +- src/environmentd/tests/auth.rs | 235 +++++++++++++++++++++- src/oidc-mock/src/lib.rs | 62 ++---- 6 files changed, 310 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b47104a68fa39..e0ba849150174 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5597,8 +5597,10 @@ dependencies = [ "mz-adapter", "mz-authenticator-types", "mz-frontegg-auth", + "mz-ore", "reqwest", "serde", + "serde_json", "tokio", "tracing", "url", diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index e3ad7bbcd9a95..3e04016a97841 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -20,6 +20,10 @@ tracing = "0.1.43" url = "2.5.7" workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } +[dev-dependencies] +mz-ore = { path = "../ore", default-features = false, features = ["test"] } +serde_json = "1.0" + [lints] workspace = true diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index d6ea6a731f3d9..16ba87ac0dca9 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -20,7 +20,7 @@ use async_trait::async_trait; use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use reqwest::Client as HttpClient; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use tracing::warn; use url::Url; @@ -31,6 +31,9 @@ pub struct OidcConfig { /// OIDC issuer URL (e.g., `https://accounts.google.com`). /// This is validated against the `iss` claim in the JWT. pub oidc_issuer: String, + /// Optional OIDC audience (client ID). + /// If set, validates that the JWT's `aud` claim contains this value. + pub oidc_audience: Option, } /// Errors that can occur during OIDC authentication. @@ -72,6 +75,22 @@ impl std::fmt::Display for OidcError { impl std::error::Error for OidcError {} +fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum StringOrVec { + String(String), + Vec(Vec), + } + + match StringOrVec::deserialize(deserializer)? { + StringOrVec::String(s) => Ok(vec![s]), + StringOrVec::Vec(v) => Ok(v), + } +} /// Claims extracted from a validated JWT. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OidcClaims { @@ -87,6 +106,9 @@ pub struct OidcClaims { /// Email claim (commonly used for username). #[serde(default)] pub email: Option, + /// Audience claim (can be single string or array in JWT). + #[serde(default, deserialize_with = "deserialize_string_or_vec")] + pub aud: Vec, } impl OidcClaims { @@ -147,6 +169,7 @@ struct OpenIdConfiguration { #[derive(Debug)] pub struct GenericOidcAuthenticatorInner { issuer: String, + audience: Option, decoding_keys: Mutex>, http_client: HttpClient, } @@ -159,6 +182,7 @@ impl GenericOidcAuthenticator { Ok(Self { inner: Arc::new(GenericOidcAuthenticatorInner { issuer: config.oidc_issuer, + audience: config.oidc_audience, decoding_keys: Mutex::new(BTreeMap::new()), http_client, }), @@ -284,8 +308,11 @@ impl GenericOidcAuthenticatorInner { // TODO (SangJunBak): Make JWT expiration configurable. let mut validation = Validation::new(header.alg); validation.set_issuer(&[&self.issuer]); - // TODO (SangJunBak): Validate audience based on configuration. - validation.validate_aud = false; + if let Some(ref audience) = self.audience { + validation.set_audience(&[audience]); + } else { + validation.validate_aud = false; + } // Decode and validate the token let token_data = @@ -331,3 +358,22 @@ impl OidcAuthenticator for GenericOidcAuthenticator { self.inner.validate_access_token(token, expected_user).await } } + +#[cfg(test)] +mod tests { + use super::*; + + #[mz_ore::test] + fn test_aud_single_string() { + let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#; + let claims: OidcClaims = serde_json::from_str(json).unwrap(); + assert_eq!(claims.aud, vec!["my-app"]); + } + + #[mz_ore::test] + fn test_aud_array() { + let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#; + let claims: OidcClaims = serde_json::from_str(json).unwrap(); + assert_eq!(claims.aud, vec!["app1", "app2"]); + } +} diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 44df44630701c..f4a148295ad90 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -176,6 +176,10 @@ pub struct Args { /// OIDC issuer URL (e.g., `https://accounts.google.com`). #[clap(long, env = "MZ_OIDC_ISSUER")] oidc_issuer: Option, + /// OIDC audience (client ID). If set, validates that the JWT's `aud` claim + /// contains this value. + #[clap(long, env = "MZ_OIDC_AUDIENCE")] + oidc_audience: Option, // === Orchestrator options. === /// The service orchestrator implementation to use. #[structopt(long, value_enum, env = "ORCHESTRATOR")] @@ -749,7 +753,10 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { let tls = args.tls.into_config()?; let frontegg = FronteggAuthenticator::from_args(args.frontegg, &metrics_registry)?; let oidc = if let Some(oidc_issuer) = args.oidc_issuer { - Some(GenericOidcAuthenticator::new(OidcConfig { oidc_issuer })?) + Some(GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer, + oidc_audience: args.oidc_audience, + })?) } else { None }; diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index d877472090331..fe6623b510a6a 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -48,7 +48,7 @@ use mz_frontegg_auth::{ use mz_frontegg_mock::{ FronteggMockServer, models::ApiToken, models::TenantApiTokenConfig, models::UserConfig, }; -use mz_oidc_mock::OidcMockServer; +use mz_oidc_mock::{GenerateJwtOptions, OidcMockServer}; use mz_ore::error::ErrorExt; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::{NowFn, SYSTEM_TIME}; @@ -1299,17 +1299,36 @@ async fn test_auth_base_require_tls_oidc() { let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: None, }) .unwrap(); let oidc_user = "user@example.com"; - let jwt_token = oidc_server.generate_jwt(oidc_user, Some(oidc_user)); - let expired_token = oidc_server.generate_jwt_with_exp(oidc_user, Some(oidc_user), 0); - let wrong_user_token = oidc_server.generate_jwt("other@example.com", Some("other@example.com")); - let wrong_issuer_token = oidc_server.generate_jwt_with_issuer( + let jwt_token = oidc_server.generate_jwt( oidc_user, - Some(oidc_user), - "https://wrong-issuer.com", + GenerateJwtOptions { + ..Default::default() + }, + ); + let expired_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + exp: Some(0), + ..Default::default() + }, + ); + let wrong_user_token = oidc_server.generate_jwt( + "other@example.com", + GenerateJwtOptions { + ..Default::default() + }, + ); + let wrong_issuer_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + issuer: Some("https://wrong-issuer.com"), + ..Default::default() + }, ); let oidc_bearer = Authorization::bearer(&jwt_token).unwrap(); @@ -1449,6 +1468,208 @@ async fn test_auth_base_require_tls_oidc() { .await; } +/// Tests OIDC audience validation. +/// +/// This test verifies that when an audience is configured, only JWTs with +/// matching `aud` claims are accepted. +#[allow(clippy::unit_arg)] +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_oidc_audience_validation() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let expected_audience = "my-app-client-id"; + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: Some(expected_audience.to_string()), + }) + .unwrap(); + + let oidc_user = "user@example.com"; + // Token with correct audience + let valid_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec![expected_audience.to_string()]), + ..Default::default() + }, + ); + // Token with correct audience among multiple audiences + let valid_multi_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["other-app".to_string(), expected_audience.to_string()]), + ..Default::default() + }, + ); + // Token with wrong audience + let wrong_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["wrong-app".to_string()]), + ..Default::default() + }, + ); + // Token with no audience + let no_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + ..Default::default() + }, + ); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "OIDC Audience Validation", + &server, + &[ + // JWT with correct audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with correct audience among multiple audiences should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_multi_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with wrong audience should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT with no audience should fail when audience is required. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&no_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + ], + ) + .await; +} + +/// Tests OIDC where we don't validate the audience claim. +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_oidc_audience_optional() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: None, + }) + .unwrap(); + + let oidc_user = "user@example.com"; + // Token with no audience + let no_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + ..Default::default() + }, + ); + + // Token with any audience + let valid_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["my-app-client-id".to_string()]), + ..Default::default() + }, + ); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "OIDC no audience validation", + &server, + &[ + // JWT with no audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&no_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with any audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + ], + ) + .await; +} + #[allow(clippy::unit_arg)] #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs index 2bfc4e614302d..c3167a38a136a 100644 --- a/src/oidc-mock/src/lib.rs +++ b/src/oidc-mock/src/lib.rs @@ -56,6 +56,19 @@ struct OidcMockContext { jwk: Jwk, } +/// Options for generating JWT tokens. +#[derive(Debug, Clone, Default)] +pub struct GenerateJwtOptions<'a> { + /// Optional email claim. + pub email: Option<&'a str>, + /// Custom expiration time. If None, uses server's default expires_in_secs. + pub exp: Option, + /// Custom issuer. If None, uses server's issuer. + pub issuer: Option<&'a str>, + /// Audience claim. If None, uses empty vec. + pub aud: Option>, +} + /// OIDC mock server for testing. pub struct OidcMockServer { /// The issuer URL. Used as the base URL of the server @@ -149,55 +162,18 @@ impl OidcMockServer { /// # Arguments /// /// * `sub` - Subject (user identifier). - /// * `email` - Optional email claim. - pub fn generate_jwt(&self, sub: &str, email: Option<&str>) -> String { - let now_ms = (self.now)(); - let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); - - let claims = OidcClaims { - sub: sub.to_string(), - iss: self.issuer.clone(), - exp: now_secs + self.expires_in_secs, - iat: Some(now_secs), - email: email.map(|s| s.to_string()), - }; - - let mut header = Header::new(jsonwebtoken::Algorithm::RS256); - header.kid = Some(self.kid.clone()); - - encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") - } - - /// Generates a JWT token with a specific expiration time. - pub fn generate_jwt_with_exp(&self, sub: &str, email: Option<&str>, exp: i64) -> String { - let now_ms = (self.now)(); - let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); - - let claims = OidcClaims { - sub: sub.to_string(), - iss: self.issuer.clone(), - exp, - iat: Some(now_secs), - email: email.map(|s| s.to_string()), - }; - - let mut header = Header::new(jsonwebtoken::Algorithm::RS256); - header.kid = Some(self.kid.clone()); - - encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") - } - - /// Generates a JWT token with a custom issuer (for testing invalid tokens). - pub fn generate_jwt_with_issuer(&self, sub: &str, email: Option<&str>, issuer: &str) -> String { + /// * `opts` - Optional JWT generation options. Use `Default::default()` for defaults. + pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String { let now_ms = (self.now)(); let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); let claims = OidcClaims { sub: sub.to_string(), - iss: issuer.to_string(), - exp: now_secs + self.expires_in_secs, + iss: opts.issuer.unwrap_or(&self.issuer).to_string(), + exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs), iat: Some(now_secs), - email: email.map(|s| s.to_string()), + email: opts.email.map(|s| s.to_string()), + aud: opts.aud.unwrap_or_default(), }; let mut header = Header::new(jsonwebtoken::Algorithm::RS256); From 1b04b151d4f0b25002a513e8adc49fc2478164e3 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Mon, 26 Jan 2026 11:45:09 -0500 Subject: [PATCH 12/15] Combine mz-authenticator-types into mz-auth Didn't realize we already had a shared crate! --- Cargo.lock | 19 +++------ Cargo.toml | 2 - src/auth/Cargo.toml | 2 + src/auth/src/lib.rs | 51 ++++++++++++++++++++++++ src/authenticator-types/Cargo.toml | 21 ---------- src/authenticator-types/src/lib.rs | 62 ------------------------------ src/authenticator/Cargo.toml | 2 +- src/authenticator/src/oidc.rs | 2 +- src/balancerd/Cargo.toml | 2 +- src/balancerd/src/lib.rs | 2 +- src/environmentd/Cargo.toml | 3 +- src/environmentd/src/http.rs | 2 +- src/frontegg-auth/Cargo.toml | 2 +- src/frontegg-auth/src/auth.rs | 2 +- src/pgwire/Cargo.toml | 3 +- src/pgwire/src/protocol.rs | 2 +- 16 files changed, 68 insertions(+), 111 deletions(-) delete mode 100644 src/authenticator-types/Cargo.toml delete mode 100644 src/authenticator-types/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index e0ba849150174..efdf0f0c0581b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5577,9 +5577,11 @@ dependencies = [ name = "mz-auth" version = "0.0.0" dependencies = [ + "async-trait", "base64 0.22.1", "itertools 0.14.0", "mz-ore", + "mz-repr", "openssl", "proptest", "proptest-derive", @@ -5595,7 +5597,7 @@ dependencies = [ "async-trait", "jsonwebtoken", "mz-adapter", - "mz-authenticator-types", + "mz-auth", "mz-frontegg-auth", "mz-ore", "reqwest", @@ -5607,15 +5609,6 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "mz-authenticator-types" -version = "0.0.0" -dependencies = [ - "async-trait", - "mz-repr", - "workspace-hack", -] - [[package]] name = "mz-avro" version = "0.7.0" @@ -5704,7 +5697,7 @@ dependencies = [ "launchdarkly-server-sdk", "mz-alloc", "mz-alloc-default", - "mz-authenticator-types", + "mz-auth", "mz-build-info", "mz-dyncfg", "mz-dyncfg-file", @@ -6330,7 +6323,6 @@ dependencies = [ "mz-alloc-default", "mz-auth", "mz-authenticator", - "mz-authenticator-types", "mz-aws-secrets-controller", "mz-build-info", "mz-catalog", @@ -6588,7 +6580,7 @@ dependencies = [ "futures", "jsonwebtoken", "lru 0.16.3", - "mz-authenticator-types", + "mz-auth", "mz-ore", "mz-repr", "prometheus", @@ -7381,7 +7373,6 @@ dependencies = [ "mz-adapter-types", "mz-auth", "mz-authenticator", - "mz-authenticator-types", "mz-frontegg-auth", "mz-ore", "mz-pgcopy", diff --git a/Cargo.toml b/Cargo.toml index 3cf6846491a6c..94fdee06e5e56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ members = [ "src/audit-log", "src/auth", "src/authenticator", - "src/authenticator-types", "src/avro", "src/aws-secrets-controller", "src/aws-util", @@ -135,7 +134,6 @@ default-members = [ "src/audit-log", "src/auth", "src/authenticator", - "src/authenticator-types", "src/avro", "src/aws-secrets-controller", "src/aws-util", diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index a5742eaaa18a8..24da93f948df5 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -10,8 +10,10 @@ publish = false workspace = true [dependencies] +async-trait = "0.1.89" base64 = "0.22.1" mz-ore = { path = "../ore", features = ["test"] } +mz-repr = { path = "../repr", default-features = false } workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } serde = "1.0.219" proptest-derive = "0.7.0" diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index 174a798681d2f..d8fe7180bee4a 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -9,3 +9,54 @@ pub mod hash; pub mod password; + +use async_trait::async_trait; +use std::fmt::Debug; + +/// A handle to an authentication session. +/// +/// An authentication session represents a duration of time during which a +/// user's authentication is known to be valid. +/// +/// [`OidcAuthSessionHandle::expired`] can be used to learn if the session has +/// failed to refresh the validity of the API key. +#[async_trait] +pub trait OidcAuthSessionHandle: Debug + Send { + /// Returns the name of the user that created the session. + // In particular, it's important that the username comes from the + // auth session, as the OIDC authenticator may return a user with + // different casing than the user supplied via the pgwire + // username field. We want to use the IdP's casing as + // canonical. + fn user(&self) -> &str; + /// Completes when the authentication session has expired. + async fn expired(&mut self); +} + +#[async_trait] +pub trait OidcAuthenticator { + /// The error type for the authenticator. + type Error; + /// The authenticator's session handle type. + type SessionHandle: OidcAuthSessionHandle; + /// Claims that have been validated by [`OidcAuthenticator::validate_access_token`]. + type ValidatedClaims; + /// Establishes a new authentication session. + /// If successful, returns a [`OidcAuthenticator::SessionHandle`] to the authentication session. + /// Otherwise, returns [`OidcAuthenticator::Error`]. + async fn authenticate( + &self, + expected_user: &str, + password: &str, + ) -> Result; + + /// Validates an access token, returning the validated claims. + /// + /// If `expected_user` is provided, the token's user name is additionally + /// validated to match `expected_user`. + async fn validate_access_token( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result; +} diff --git a/src/authenticator-types/Cargo.toml b/src/authenticator-types/Cargo.toml deleted file mode 100644 index c3dd86dbeafb8..0000000000000 --- a/src/authenticator-types/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "mz-authenticator-types" -description = "Shared types for Materialize authentication." -version = "0.0.0" -edition.workspace = true -rust-version.workspace = true -publish = false - -[lints] -workspace = true - -[dependencies] -async-trait = "0.1.89" -mz-repr = { path = "../repr", default-features = false } -workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } - -[package.metadata.cargo-udeps.ignore] -normal = ["workspace-hack"] - -[features] -default = ["workspace-hack"] diff --git a/src/authenticator-types/src/lib.rs b/src/authenticator-types/src/lib.rs deleted file mode 100644 index 25be9de24d95c..0000000000000 --- a/src/authenticator-types/src/lib.rs +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright Materialize, Inc. and contributors. All rights reserved. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0. - -//! Shared types for Materialize authentication. - -use std::fmt::Debug; - -use async_trait::async_trait; - -/// A handle to an authentication session. -/// -/// An authentication session represents a duration of time during which a -/// user's authentication is known to be valid. -/// -/// [`OidcAuthSessionHandle::expired`] can be used to learn if the session has -/// failed to refresh the validity of the API key. -#[async_trait] -pub trait OidcAuthSessionHandle: Debug + Send { - /// Returns the name of the user that created the session. - // In particular, it's important that the username comes from the - // auth session, as the OIDC authenticator may return a user with - // different casing than the user supplied via the pgwire - // username field. We want to use the IdP's casing as - // canonical. - fn user(&self) -> &str; - /// Completes when the authentication session has expired. - async fn expired(&mut self); -} - -#[async_trait] -pub trait OidcAuthenticator { - /// The error type for the authenticator. - type Error; - /// The authenticator's session handle type. - type SessionHandle: OidcAuthSessionHandle; - /// Claims that have been validated by [`OidcAuthenticator::validate_access_token`]. - type ValidatedClaims; - /// Establishes a new authentication session. - /// If successful, returns a [`OidcAuthenticator::SessionHandle`] to the authentication session. - /// Otherwise, returns [`OidcAuthenticator::Error`]. - async fn authenticate( - &self, - expected_user: &str, - password: &str, - ) -> Result; - - /// Validates an access token, returning the validated claims. - /// - /// If `expected_user` is provided, the token's user name is additionally - /// validated to match `expected_user`. - async fn validate_access_token( - &self, - token: &str, - expected_user: Option<&str>, - ) -> Result; -} diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index 3e04016a97841..2c693f447e0c2 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -11,7 +11,7 @@ publish = false async-trait = "0.1" jsonwebtoken = "9.3.1" mz-adapter = { path = "../adapter", default-features = false } -mz-authenticator-types = { path = "../authenticator-types", default-features = false } +mz-auth = { path = "../auth", default-features = false } mz-frontegg-auth = { path = "../frontegg-auth", default-features = false } reqwest = "0.12.24" serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index 16ba87ac0dca9..5ab056d19ace8 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -18,7 +18,7 @@ use std::time::Duration; use async_trait::async_trait; use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; -use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; +use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use reqwest::Client as HttpClient; use serde::{Deserialize, Deserializer, Serialize}; diff --git a/src/balancerd/Cargo.toml b/src/balancerd/Cargo.toml index e475673a4359f..a6e1fa57ae81d 100644 --- a/src/balancerd/Cargo.toml +++ b/src/balancerd/Cargo.toml @@ -27,7 +27,7 @@ jsonwebtoken = "9.3.1" launchdarkly-server-sdk = { version = "2.6.2", default-features = false } mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } -mz-authenticator-types = { path = "../authenticator-types", default-features = false } +mz-auth = { path = "../auth", default-features = false } mz-build-info = { path = "../build-info" } mz-dyncfg-launchdarkly = { path = "../dyncfg-launchdarkly" } mz-dyncfg-file= { path = "../dyncfg-file" } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 65b9f4a1526da..01ee55488f878 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -39,7 +39,7 @@ use futures::stream::BoxStream; use hyper::StatusCode; use hyper_util::rt::TokioIo; use launchdarkly_server_sdk as ld; -use mz_authenticator_types::OidcAuthenticator; +use mz_auth::OidcAuthenticator; use mz_build_info::{BuildInfo, build_info}; use mz_dyncfg::ConfigSet; use mz_frontegg_auth::Authenticator as FronteggAuthentication; diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index 1cfa4a520f0bf..9004a79d44694 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -41,9 +41,8 @@ maplit = "1.0.2" mime = "0.3.16" mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } -mz-auth = { path = "../auth" } +mz-auth = { path = "../auth", default-features = false } mz-authenticator = { path = "../authenticator" } -mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-aws-secrets-controller = { path = "../aws-secrets-controller" } mz-build-info = { path = "../build-info" } mz-adapter = { path = "../adapter" } diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index b63bf835d0e26..12f8dbe3e90a7 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -44,8 +44,8 @@ use hyper_util::rt::TokioIo; use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig}; use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache}; use mz_auth::password::Password; +use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_authenticator::Authenticator; -use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_controller::ReplicaHttpLocator; use mz_frontegg_auth::Error as FronteggError; use mz_http_util::DynamicFilterTarget; diff --git a/src/frontegg-auth/Cargo.toml b/src/frontegg-auth/Cargo.toml index e62ede194a0f5..47a10d33c8415 100644 --- a/src/frontegg-auth/Cargo.toml +++ b/src/frontegg-auth/Cargo.toml @@ -18,7 +18,7 @@ derivative = "2.2.0" futures = "0.3.31" jsonwebtoken = "9.3.1" lru = "0.16.3" -mz-authenticator-types = { path = "../authenticator-types", default-features = false } +mz-auth = { path = "../auth", default-features = false } mz-ore = { path = "../ore", features = ["network", "metrics"] } mz-repr = { path = "../repr" } prometheus = { version = "0.14.0", default-features = false } diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index f5774506bebe5..26991e79c0338 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -21,7 +21,7 @@ use futures::FutureExt; use futures::future::Shared; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use lru::LruCache; -use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; +use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_ore::instrument; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::NowFn; diff --git a/src/pgwire/Cargo.toml b/src/pgwire/Cargo.toml index 28e114649f302..ed7c70fe9bfec 100644 --- a/src/pgwire/Cargo.toml +++ b/src/pgwire/Cargo.toml @@ -21,9 +21,8 @@ futures = "0.3.31" itertools = "0.14.0" mz-adapter = { path = "../adapter" } mz-adapter-types = { path = "../adapter-types" } -mz-auth = { path = "../auth" } +mz-auth = { path = "../auth", default-features = false } mz-authenticator = { path = "../authenticator" } -mz-authenticator-types = { path = "../authenticator-types", default-features = false } mz-frontegg-auth = { path = "../frontegg-auth" } mz-ore = { path = "../ore", features = ["tracing"] } mz-pgcopy = { path = "../pgcopy" } diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 2b3d36da5e5b6..50b23883bd0aa 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -30,8 +30,8 @@ use mz_adapter::{ verify_datum_desc, }; use mz_auth::password::Password; +use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_authenticator::Authenticator; -use mz_authenticator_types::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_ore::cast::CastFrom; use mz_ore::netio::AsyncReady; use mz_ore::now::{EpochMillis, SYSTEM_TIME}; From ab4c25f1d18b4493553e964e2645c63df27df508 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Mon, 26 Jan 2026 13:31:37 -0500 Subject: [PATCH 13/15] Remove common OIDC authenticator Due to changed restrictions after a discussion with an external advisor, we decided we no longer need to implement the refresh token flow. However, this also means we no longer have the need for a shared OidcAuthenticator trait. --- Cargo.lock | 3 - .../design/20251215_oidc_authentication.md | 16 ++-- src/auth/src/lib.rs | 51 ----------- src/authenticator/Cargo.toml | 1 - src/authenticator/src/oidc.rs | 40 ++------- src/balancerd/Cargo.toml | 1 - src/balancerd/src/lib.rs | 1 - src/environmentd/src/http.rs | 3 +- src/frontegg-auth/Cargo.toml | 1 - src/frontegg-auth/src/auth.rs | 89 ++++++++++--------- src/pgwire/src/protocol.rs | 10 +-- 11 files changed, 61 insertions(+), 155 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index efdf0f0c0581b..34a3f5d607794 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5594,7 +5594,6 @@ dependencies = [ name = "mz-authenticator" version = "0.1.0" dependencies = [ - "async-trait", "jsonwebtoken", "mz-adapter", "mz-auth", @@ -5697,7 +5696,6 @@ dependencies = [ "launchdarkly-server-sdk", "mz-alloc", "mz-alloc-default", - "mz-auth", "mz-build-info", "mz-dyncfg", "mz-dyncfg-file", @@ -6572,7 +6570,6 @@ name = "mz-frontegg-auth" version = "0.0.0" dependencies = [ "anyhow", - "async-trait", "axum", "base64 0.22.1", "clap", diff --git a/doc/developer/design/20251215_oidc_authentication.md b/doc/developer/design/20251215_oidc_authentication.md index abbef1786965a..cbd00c98e00f8 100644 --- a/doc/developer/design/20251215_oidc_authentication.md +++ b/doc/developer/design/20251215_oidc_authentication.md @@ -51,12 +51,6 @@ spec: # Note: Not all JWT providers support .well-known/openid-configuration, # so use jwks directly if the provider doesn't support it. jwksFetchFromIssuer: true - # The OAuth 2.0 token endpoint where Materialize will request new access - # tokens using a refresh token (https://www.rfc-editor.org/rfc/rfc6749.html#section-6). - # Requires `grant_type` of `refresh_token` in the client. - # Optional. If not provided, sessions will expire when the access token expires - # and refresh tokens in the password field will be ignored. - tokenEndpoint: https://dev-123456.okta.com/oauth2/default/v1/token ``` Where in environmentd, it’ll look like so: @@ -82,9 +76,13 @@ When a user first logs in with a valid token, we create a role for them if one d ### Solution proposal: The user should be disabled from logging in when a user is removed from the upstream IDP. However, the database level role should still exist. +When doing pgwire Oidc authentication, we can accept a cleartext password that is the access token. The OIDC authenticator will do JWT authentication on the access token. If the token is expired, the session will not be established. We will not do any invalidation on the session if the session has already been authenticated/established, but the token is expired. Eventually, the token will expire and the user will not be able to authenticate a new session. This creates a tradeoff between security and developer experience, but is acceptable since organizations will have supplemental methods of deprovisioning users outside of the database. This accomplishes disabling a user from logging in, but the database role still existing. + +**Alternative: Use a refresh token flow to invalidate active sessions** + When doing pgwire Oidc authentication, we can accept a cleartext password of the form `access=&refresh=` where `&` is a delimiter and `refresh=` is optional. The OIDC authenticator will then try to authenticate again and fetch a new access token using the refresh token when close to expiration (using the token API URL in the spec above). If the refresh token doesn’t exist, the session will invalidate. This would require users to have their IdP client generate `refresh` tokens. For token expiration checking, in a task, we'll repeatedly wait for `(expiration - now) * 0.8` and see if it's less than a minute. This is also how we check token expiration in the Frontegg authenticator. We'll also implement a config variable to turn off this mechanism and have it default to true. -By suggesting a short time to live for access tokens, this accomplishes invalidating sessions on removal of a user from an IDP. When admins remove a user from an IDP, the next time the user tries to authenticate or refresh their access token, the token API will not allow the user to login but will keep the role in the database. +This approach would enhance security by ensuring that sessions are invalidated once the access token expires. However, it would also introduce additional complexity and degrade the developer experience, as it would require users to configure refresh tokens in their IdP. Additionally, some IdPs may impose rate limits on token refresh operations. By opting for a simpler design, we minimize potential incompatibilities with various IdPs. **Alternative: Use SASL Authentication using the OAUTHBEARER mechanism rather than a cleartext password** @@ -124,12 +122,8 @@ An MVP of what this might look like exists here: [https://github.com/Materialize ### Tests: - Successful login (e2e mzcompose) -- Invalidating the session on access token expiration and no refresh token (Rust unit test) -- A token should successfully refresh if the access token and refresh token are valid (Rust unit test) - Session should error if access token is invalid (Rust unit test) -- Session should error if refresh token is invalid (Rust unit test) - A user shouldn't be able to login as another user (Rust unit test) -- Removing a user from the upstream IDP should invalidate the refresh token (e2e mzcompose) - Platform-check simple login check (platform-check framework) - JWTs should only be accepted when a valid JWK is set (we do not want to accept JWTs that are not signed with a real, cryptographically sound key) diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index d8fe7180bee4a..174a798681d2f 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -9,54 +9,3 @@ pub mod hash; pub mod password; - -use async_trait::async_trait; -use std::fmt::Debug; - -/// A handle to an authentication session. -/// -/// An authentication session represents a duration of time during which a -/// user's authentication is known to be valid. -/// -/// [`OidcAuthSessionHandle::expired`] can be used to learn if the session has -/// failed to refresh the validity of the API key. -#[async_trait] -pub trait OidcAuthSessionHandle: Debug + Send { - /// Returns the name of the user that created the session. - // In particular, it's important that the username comes from the - // auth session, as the OIDC authenticator may return a user with - // different casing than the user supplied via the pgwire - // username field. We want to use the IdP's casing as - // canonical. - fn user(&self) -> &str; - /// Completes when the authentication session has expired. - async fn expired(&mut self); -} - -#[async_trait] -pub trait OidcAuthenticator { - /// The error type for the authenticator. - type Error; - /// The authenticator's session handle type. - type SessionHandle: OidcAuthSessionHandle; - /// Claims that have been validated by [`OidcAuthenticator::validate_access_token`]. - type ValidatedClaims; - /// Establishes a new authentication session. - /// If successful, returns a [`OidcAuthenticator::SessionHandle`] to the authentication session. - /// Otherwise, returns [`OidcAuthenticator::Error`]. - async fn authenticate( - &self, - expected_user: &str, - password: &str, - ) -> Result; - - /// Validates an access token, returning the validated claims. - /// - /// If `expected_user` is provided, the token's user name is additionally - /// validated to match `expected_user`. - async fn validate_access_token( - &self, - token: &str, - expected_user: Option<&str>, - ) -> Result; -} diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index 2c693f447e0c2..78c811634b565 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -8,7 +8,6 @@ rust-version.workspace = true publish = false [dependencies] -async-trait = "0.1" jsonwebtoken = "9.3.1" mz-adapter = { path = "../adapter", default-features = false } mz-auth = { path = "../auth", default-features = false } diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index 5ab056d19ace8..0ca99d5a8c0a1 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -16,9 +16,7 @@ use std::collections::BTreeMap; use std::sync::{Arc, Mutex}; use std::time::Duration; -use async_trait::async_trait; use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; -use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use reqwest::Client as HttpClient; use serde::{Deserialize, Deserializer, Serialize}; @@ -130,25 +128,6 @@ impl std::fmt::Debug for OidcDecodingKey { } } -/// Session handle for generic OIDC authentication. -#[derive(Debug)] -pub struct GenericOidcSessionHandle { - user: String, -} - -#[async_trait] -impl OidcAuthSessionHandle for GenericOidcSessionHandle { - fn user(&self) -> &str { - &self.user - } - - async fn expired(&mut self) { - // This session never expires - wait forever - // TODO (SangJunBak): Implement expiration. - std::future::pending().await - } -} - /// OIDC Authenticator that validates JWTs using JWKS. /// /// This implementation pre-fetches JWKS at construction time for synchronous @@ -292,7 +271,7 @@ impl GenericOidcAuthenticatorInner { Err(OidcError::NoMatchingKey) } - async fn validate_access_token( + pub async fn validate_access_token( &self, token: &str, expected_user: Option<&str>, @@ -329,28 +308,21 @@ impl GenericOidcAuthenticatorInner { } } -#[async_trait] -impl OidcAuthenticator for GenericOidcAuthenticator { - type Error = OidcError; - type SessionHandle = GenericOidcSessionHandle; - type ValidatedClaims = OidcClaims; - - async fn authenticate( +impl GenericOidcAuthenticator { + pub async fn authenticate( &self, expected_user: &str, password: &str, - ) -> Result { + ) -> Result { // The password is the JWT token let claims = self .validate_access_token(password, Some(expected_user)) .await?; - Ok(GenericOidcSessionHandle { - user: claims.username().to_string(), - }) + Ok(claims) } - async fn validate_access_token( + pub async fn validate_access_token( &self, token: &str, expected_user: Option<&str>, diff --git a/src/balancerd/Cargo.toml b/src/balancerd/Cargo.toml index a6e1fa57ae81d..e58daf43c3cf7 100644 --- a/src/balancerd/Cargo.toml +++ b/src/balancerd/Cargo.toml @@ -27,7 +27,6 @@ jsonwebtoken = "9.3.1" launchdarkly-server-sdk = { version = "2.6.2", default-features = false } mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } -mz-auth = { path = "../auth", default-features = false } mz-build-info = { path = "../build-info" } mz-dyncfg-launchdarkly = { path = "../dyncfg-launchdarkly" } mz-dyncfg-file= { path = "../dyncfg-file" } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 01ee55488f878..67c898939731b 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -39,7 +39,6 @@ use futures::stream::BoxStream; use hyper::StatusCode; use hyper_util::rt::TokioIo; use launchdarkly_server_sdk as ld; -use mz_auth::OidcAuthenticator; use mz_build_info::{BuildInfo, build_info}; use mz_dyncfg::ConfigSet; use mz_frontegg_auth::Authenticator as FronteggAuthentication; diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index 12f8dbe3e90a7..978d265b846b5 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -44,7 +44,6 @@ use hyper_util::rt::TokioIo; use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig}; use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache}; use mz_auth::password::Password; -use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_authenticator::Authenticator; use mz_controller::ReplicaHttpLocator; use mz_frontegg_auth::Error as FronteggError; @@ -994,7 +993,7 @@ async fn auth( (name, external_metadata_rx) } Some(Credentials::Token { token }) => { - let claims = frontegg.validate_access_token(&token, None).await?; + let claims = frontegg.validate_access_token(&token, None)?; let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata { user_id: claims.user_id, admin: claims.is_admin, diff --git a/src/frontegg-auth/Cargo.toml b/src/frontegg-auth/Cargo.toml index 47a10d33c8415..c74d5671c9cf0 100644 --- a/src/frontegg-auth/Cargo.toml +++ b/src/frontegg-auth/Cargo.toml @@ -11,7 +11,6 @@ workspace = true [dependencies] anyhow = "1.0.100" -async-trait = "0.1.89" base64 = "0.22.1" clap = { version = "4.5.23", features = ["wrap_help", "env", "derive"] } derivative = "2.2.0" diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index 26991e79c0338..a80b1a01aa31c 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -15,13 +15,11 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use anyhow::Context as _; -use async_trait::async_trait; use derivative::Derivative; use futures::FutureExt; use futures::future::Shared; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use lru::LruCache; -use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_ore::instrument; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::NowFn; @@ -168,6 +166,28 @@ impl Authenticator { Ok(Some(Self::new(config, client, registry))) } + /// Establishes a new authentication session. + /// + /// If successful, returns a handle to the authentication session. + /// Otherwise, returns the authentication error. + pub async fn authenticate( + &self, + expected_user: &str, + password: &str, + ) -> Result { + let password: AppPassword = password.parse()?; + match self.authenticate_inner(expected_user, password).await { + Ok(handle) => { + tracing::debug!("authentication successful"); + Ok(handle) + } + Err(e) => { + tracing::debug!(error = ?e, "authentication failed"); + Err(e) + } + } + } + #[instrument(level = "debug", fields(client_id = %password.client_id))] async fn authenticate_inner( &self, @@ -278,34 +298,6 @@ impl Authenticator { }; request.await } -} - -#[async_trait] -impl OidcAuthenticator for Authenticator { - type Error = Error; - type SessionHandle = AuthSessionHandle; - type ValidatedClaims = ValidatedClaims; - /// Establishes a new authentication session. - /// - /// If successful, returns a handle to the authentication session. - /// Otherwise, returns the authentication error. - async fn authenticate( - &self, - expected_user: &str, - password: &str, - ) -> Result { - let password: AppPassword = password.parse()?; - match self.authenticate_inner(expected_user, password).await { - Ok(handle) => { - tracing::debug!("authentication successful"); - Ok(handle) - } - Err(e) => { - tracing::debug!(error = ?e, "authentication failed"); - Err(e) - } - } - } /// Validates an access token, returning the validated claims. /// @@ -317,7 +309,7 @@ impl OidcAuthenticator for Authenticator { /// /// If `expected_user` is provided, the token's user name is additionally /// validated to match `expected_user`. - async fn validate_access_token( + pub fn validate_access_token( &self, token: &str, expected_user: Option<&str>, @@ -325,13 +317,23 @@ impl OidcAuthenticator for Authenticator { self.inner.validate_access_token(token, expected_user) } } + /// A handle to an authentication session. /// +/// An authentication session represents a duration of time during which a +/// user's authentication is known to be valid. +/// /// An authentication session begins with a successful API key exchange with /// Frontegg. While there is at least one outstanding handle to the session, the /// session's metadata and validity are refreshed with Frontegg at a regular /// interval. The session ends when all outstanding handles are dropped and the /// refresh interval is reached. +/// +/// [`AuthSessionHandle::external_metadata_rx`] can be used to receive events if +/// the session's metadata is updated. +/// +/// [`AuthSessionHandle::expired`] can be used to learn if the session has +/// failed to refresh the validity of the API key. #[derive(Debug, Clone)] pub struct AuthSessionHandle { ident: Arc, @@ -342,30 +344,29 @@ pub struct AuthSessionHandle { app_password: AppPassword, } -#[async_trait] -impl OidcAuthSessionHandle for AuthSessionHandle { - fn user(&self) -> &str { +impl AuthSessionHandle { + /// Returns the name of the user that created the session. + pub fn user(&self) -> &str { &self.ident.user } - async fn expired(&mut self) { - // We piggyback on the external metadata channel to determine session - // expiration. The external metadata channel is closed when the session - // expires. - let _ = self.external_metadata_rx.wait_for(|_| false).await; - } -} - -impl AuthSessionHandle { /// Returns the ID of the tenant that created the session. pub fn tenant_id(&self) -> Uuid { self.ident.tenant_id } - /// Returns a receiver for updates to the session user's external metadata. + /// Mints a receiver for updates to the session user's external metadata. pub fn external_metadata_rx(&self) -> watch::Receiver { self.external_metadata_rx.clone() } + + /// Completes when the authentication session has expired. + pub async fn expired(&mut self) { + // We piggyback on the external metadata channel to determine session + // expiration. The external metadata channel is closed when the session + // expires. + let _ = self.external_metadata_rx.wait_for(|_| false).await; + } } impl Drop for AuthSessionHandle { diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index 50b23883bd0aa..b47a2632ea7d0 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -30,7 +30,6 @@ use mz_adapter::{ verify_datum_desc, }; use mz_auth::password::Password; -use mz_auth::{OidcAuthSessionHandle, OidcAuthenticator}; use mz_authenticator::Authenticator; use mz_ore::cast::CastFrom; use mz_ore::netio::AsyncReady; @@ -210,7 +209,7 @@ where external_metadata_rx: Some(auth_session.external_metadata_rx()), helm_chart_version, }); - let expired = async move { auth_session.expired().await }.boxed(); + let expired = async move { auth_session.expired().await }; (session, expired.left_future()) } Err(err) => { @@ -237,17 +236,16 @@ where let auth_response = oidc.authenticate(&user, &jwt).await; match auth_response { - Ok(mut auth_session) => { + Ok(auth_session) => { let session = adapter_client.new_session(SessionConfig { conn_id: conn.conn_id().clone(), uuid: conn_uuid, - user: auth_session.user().into(), + user: auth_session.username().into(), client_ip: conn.peer_addr().clone(), external_metadata_rx: None, helm_chart_version, }); - let expired = async move { auth_session.expired().await }.boxed(); - (session, expired.left_future()) + (session, pending().right_future()) } Err(err) => { warn!(?err, "pgwire connection failed authentication"); From 436dd3fa2e09bed475cf067b0b164939e8af3176 Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Mon, 26 Jan 2026 16:08:52 -0500 Subject: [PATCH 14/15] Introduce sentinel type for authentication Before when we had to return internal user metadata data from the auth response, it meant we couldn't forget to call adapter_client.authenticate. By introducing a sentinel type, we make it harder for a developer to. We also combine `validate_access_token` into `authenticate` for GenericOidcAuthenticator. --- Cargo.lock | 1 - src/adapter/src/client.rs | 35 +++++---- src/adapter/src/command.rs | 5 +- src/adapter/src/config/backend.rs | 20 +++--- src/auth/Cargo.toml | 1 - src/auth/src/lib.rs | 10 +++ src/authenticator/src/oidc.rs | 22 ++---- src/balancerd/src/lib.rs | 2 +- src/environmentd/src/http.rs | 65 ++++++++++------- src/environmentd/src/http/sql.rs | 1 + src/environmentd/tests/server.rs | 35 +++++++++ src/frontegg-auth/src/auth.rs | 10 +-- src/pgwire/src/protocol.rs | 114 +++++++++++++++++------------- 13 files changed, 202 insertions(+), 119 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34a3f5d607794..a089de510d1d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5577,7 +5577,6 @@ dependencies = [ name = "mz-auth" version = "0.0.0" dependencies = [ - "async-trait", "base64 0.22.1", "itertools 0.14.0", "mz-ore", diff --git a/src/adapter/src/client.rs b/src/adapter/src/client.rs index ae90292300ab4..0e50cd33ee5b6 100644 --- a/src/adapter/src/client.rs +++ b/src/adapter/src/client.rs @@ -21,6 +21,7 @@ use derivative::Derivative; use futures::{Stream, StreamExt}; use itertools::Itertools; use mz_adapter_types::connection::{ConnectionId, ConnectionIdType}; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_build_info::BuildInfo; use mz_compute_types::ComputeInstanceId; @@ -149,19 +150,22 @@ impl Client { /// Creates a new session associated with this client for the given user. /// /// It is the caller's responsibility to have authenticated the user. - pub fn new_session(&self, config: SessionConfig) -> Session { + /// We pass in an Authenticated marker as a guardrail to ensure the + /// user has authenticated with an authenticator before creating a session. + pub fn new_session(&self, config: SessionConfig, _authenticated: Authenticated) -> Session { // We use the system clock to determine when a session connected to Materialize. This is not // intended to be 100% accurate and correct, so we don't burden the timestamp oracle with // generating a more correct timestamp. Session::new(self.build_info, config, self.metrics().session_metrics()) } - /// Preforms an authentication check for the given user. + /// Verifies the provided user's password against the + /// stored credentials in the catalog. pub async fn authenticate( &self, user: &String, password: &Password, - ) -> Result<(), AdapterError> { + ) -> Result { let (tx, rx) = oneshot::channel(); self.send(Command::AuthenticatePassword { role_name: user.to_string(), @@ -169,7 +173,7 @@ impl Client { tx, }); rx.await.expect("sender dropped")?; - Ok(()) + Ok(Authenticated) } pub async fn generate_sasl_challenge( @@ -193,7 +197,7 @@ impl Client { proof: &String, nonce: &String, mock_hash: &String, - ) -> Result { + ) -> Result<(SASLVerifyProofResponse, Authenticated), AdapterError> { let (tx, rx) = oneshot::channel(); self.send(Command::AuthenticateVerifySASLProof { role_name: user.to_string(), @@ -203,7 +207,7 @@ impl Client { tx, }); let response = rx.await.expect("sender dropped")?; - Ok(response) + Ok((response, Authenticated)) } /// Upgrades this client to a session client. @@ -449,14 +453,17 @@ Issue a SQL query to get started. Need help? ) -> Result + Send>>, anyhow::Error> { // Connect to the coordinator. let conn_id = self.new_conn_id()?; - let session = self.new_session(SessionConfig { - conn_id, - uuid: Uuid::new_v4(), - user: SUPPORT_USER.name.clone(), - client_ip: None, - external_metadata_rx: None, - helm_chart_version: None, - }); + let session = self.new_session( + SessionConfig { + conn_id, + uuid: Uuid::new_v4(), + user: SUPPORT_USER.name.clone(), + client_ip: None, + external_metadata_rx: None, + helm_chart_version: None, + }, + Authenticated, + ); let mut session_client = self.startup(session).await?; // Parse the SQL statement. diff --git a/src/adapter/src/command.rs b/src/adapter/src/command.rs index e07728281c1aa..05143a9e49b31 100644 --- a/src/adapter/src/command.rs +++ b/src/adapter/src/command.rs @@ -372,6 +372,9 @@ pub struct Response { pub otel_ctx: OpenTelemetryContext, } +#[derive(Debug, Clone, Copy)] +pub struct SuperuserAttribute(pub Option); + /// The response to [`Client::startup`](crate::Client::startup). #[derive(Derivative)] #[derivative(Debug)] @@ -382,7 +385,7 @@ pub struct StartupResponse { /// This attribute is None for Cloud. Cloud is able /// to derive the role's superuser status from /// external_metadata_rx. - pub superuser_attribute: Option, + pub superuser_attribute: SuperuserAttribute, /// A future that completes when all necessary Builtin Table writes have completed. #[derivative(Debug = "ignore")] pub write_notify: BuiltinTableAppendNotify, diff --git a/src/adapter/src/config/backend.rs b/src/adapter/src/config/backend.rs index fa6a52329bbbc..2d8cd74e1e41f 100644 --- a/src/adapter/src/config/backend.rs +++ b/src/adapter/src/config/backend.rs @@ -9,6 +9,7 @@ use std::collections::BTreeMap; +use mz_auth::Authenticated; use mz_sql::session::user::SYSTEM_USER; use tracing::{error, info}; use uuid::Uuid; @@ -28,14 +29,17 @@ pub struct SystemParameterBackend { impl SystemParameterBackend { pub async fn new(client: Client) -> Result { let conn_id = client.new_conn_id()?; - let session = client.new_session(SessionConfig { - conn_id, - uuid: Uuid::new_v4(), - user: SYSTEM_USER.name.clone(), - client_ip: None, - external_metadata_rx: None, - helm_chart_version: None, - }); + let session = client.new_session( + SessionConfig { + conn_id, + uuid: Uuid::new_v4(), + user: SYSTEM_USER.name.clone(), + client_ip: None, + external_metadata_rx: None, + helm_chart_version: None, + }, + Authenticated, + ); let session_client = client.startup(session).await?; Ok(Self { session_client }) } diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index 24da93f948df5..f912922afb018 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -10,7 +10,6 @@ publish = false workspace = true [dependencies] -async-trait = "0.1.89" base64 = "0.22.1" mz-ore = { path = "../ore", features = ["test"] } mz-repr = { path = "../repr", default-features = false } diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index 174a798681d2f..2db42e4360866 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -7,5 +7,15 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use serde::{Deserialize, Serialize}; + pub mod hash; pub mod password; + +/// A sentinel type signifying successful authentication. +/// +/// This type is used to establish an authenticated Adapter client session, +/// and should only be constructed by authenticators to indicate that authentication +/// has succeeded. It may also be used when authentication is not required. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Authenticated; diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs index 0ca99d5a8c0a1..a4d08f22fb4c9 100644 --- a/src/authenticator/src/oidc.rs +++ b/src/authenticator/src/oidc.rs @@ -17,6 +17,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; +use mz_auth::Authenticated; use reqwest::Client as HttpClient; use serde::{Deserialize, Deserializer, Serialize}; @@ -310,24 +311,15 @@ impl GenericOidcAuthenticatorInner { impl GenericOidcAuthenticator { pub async fn authenticate( - &self, - expected_user: &str, - password: &str, - ) -> Result { - // The password is the JWT token - let claims = self - .validate_access_token(password, Some(expected_user)) - .await?; - - Ok(claims) - } - - pub async fn validate_access_token( &self, token: &str, expected_user: Option<&str>, - ) -> Result { - self.inner.validate_access_token(token, expected_user).await + ) -> Result<(OidcClaims, Authenticated), OidcError> { + let claims = self + .inner + .validate_access_token(token, expected_user) + .await?; + Ok((claims, Authenticated)) } } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 67c898939731b..259957525bcd2 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -1394,7 +1394,7 @@ impl Resolver { let auth_response = auth.authenticate(user, &password).await; let auth_session = match auth_response { - Ok(auth_session) => auth_session, + Ok((auth_session, _)) => auth_session, Err(e) => { warn!("pgwire connection failed authentication: {}", e); // TODO: fix error codes. diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index 978d265b846b5..e5d097b5c5cd8 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -43,6 +43,7 @@ use hyper_openssl::client::legacy::MaybeHttpsStream; use hyper_util::rt::TokioIo; use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig}; use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache}; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_authenticator::Authenticator; use mz_controller::ReplicaHttpLocator; @@ -552,6 +553,7 @@ async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl In req.extensions_mut().insert(AuthedUser { name: username, external_metadata_rx: None, + authenticated: Authenticated, }); } Ok(next.run(req).await) @@ -569,6 +571,7 @@ enum ConnProtocol { pub struct AuthedUser { name: String, external_metadata_rx: Option>, + authenticated: Authenticated, } pub struct AuthedClient { @@ -591,14 +594,17 @@ impl AuthedClient { F: FnOnce(&mut AdapterSession), { let conn_id = adapter_client.new_conn_id()?; - let mut session = adapter_client.new_session(AdapterSessionConfig { - conn_id, - uuid: epoch_to_uuid_v7(&(now)()), - user: user.name, - client_ip: Some(peer_addr), - external_metadata_rx: user.external_metadata_rx, - helm_chart_version, - }); + let mut session = adapter_client.new_session( + AdapterSessionConfig { + conn_id, + uuid: epoch_to_uuid_v7(&(now)()), + user: user.name, + client_ip: Some(peer_addr), + external_metadata_rx: user.external_metadata_rx, + helm_chart_version, + }, + user.authenticated, + ); let connection_guard = active_connection_counter.allocate_connection(session.user())?; session_config(&mut session); @@ -759,16 +765,19 @@ pub async fn handle_login( let Ok(adapter_client) = adapter_client_rx.clone().await else { return StatusCode::INTERNAL_SERVER_ERROR; }; - if let Err(err) = adapter_client.authenticate(&username, &password).await { - warn!(?err, "HTTP login failed authentication"); - return StatusCode::UNAUTHORIZED; + let authenticated = match adapter_client.authenticate(&username, &password).await { + Ok(authenticated) => authenticated, + Err(err) => { + warn!(?err, "HTTP login failed authentication"); + return StatusCode::UNAUTHORIZED; + } }; - // Create session data let session_data = TowerSessionData { username, created_at: SystemTime::now(), last_activity: SystemTime::now(), + authenticated, }; // Store session data let session = session.and_then(|Extension(session)| Some(session)); @@ -825,6 +834,7 @@ async fn http_auth( req.extensions_mut().insert(AuthedUser { name: session_data.username, external_metadata_rx: None, + authenticated: session_data.authenticated, }); return Ok(next.run(req).await); } @@ -984,21 +994,22 @@ async fn auth( allowed_roles: AllowedRoles, include_www_authenticate_header: bool, ) -> Result { - let (name, external_metadata_rx) = match authenticator { + let (name, external_metadata_rx, authenticated) = match authenticator { Authenticator::Frontegg(frontegg) => match creds { Some(Credentials::Password { username, password }) => { - let auth_session = frontegg.authenticate(&username, &password.0).await?; + let (auth_session, authenticated) = + frontegg.authenticate(&username, &password.0).await?; let name = auth_session.user().into(); let external_metadata_rx = Some(auth_session.external_metadata_rx()); - (name, external_metadata_rx) + (name, external_metadata_rx, authenticated) } Some(Credentials::Token { token }) => { - let claims = frontegg.validate_access_token(&token, None)?; + let (claims, authenticated) = frontegg.validate_access_token(&token, None)?; let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata { user_id: claims.user_id, admin: claims.is_admin, }); - (claims.user, Some(external_metadata_rx)) + (claims.user, Some(external_metadata_rx), authenticated) } None => { return Err(AuthError::MissingHttpAuthentication { @@ -1008,11 +1019,11 @@ async fn auth( }, Authenticator::Password(adapter_client) => match creds { Some(Credentials::Password { username, password }) => { - adapter_client + let authenticated = adapter_client .authenticate(&username, &password) .await .map_err(|_| AuthError::InvalidCredentials)?; - (username, None) + (username, None, authenticated) } _ => { return Err(AuthError::MissingHttpAuthentication { @@ -1031,21 +1042,21 @@ async fn auth( Authenticator::Oidc(oidc) => match creds { Some(Credentials::Token { token }) => { // Validate JWT token - let claims = oidc - .validate_access_token(&token, None) + let (claims, authenticated) = oidc + .authenticate(&token, None) .await .map_err(|_| AuthError::InvalidCredentials)?; let name = claims.username().to_string(); - (name, None) + (name, None, authenticated) } Some(Credentials::Password { username, password }) => { // Allow JWT to be passed as password - let claims = oidc - .validate_access_token(&password.0, Some(&username)) + let (claims, authenticated) = oidc + .authenticate(&password.0, Some(&username)) .await .map_err(|_| AuthError::InvalidCredentials)?; let name = claims.username().to_string(); - (name, None) + (name, None, authenticated) } None => { return Err(AuthError::MissingHttpAuthentication { @@ -1061,7 +1072,7 @@ async fn auth( Some(Credentials::Password { username, .. }) => username, _ => HTTP_DEFAULT_USER.name.to_owned(), }; - (name, None) + (name, None, Authenticated) } }; @@ -1070,6 +1081,7 @@ async fn auth( Ok(AuthedUser { name, external_metadata_rx, + authenticated, }) } @@ -1136,6 +1148,7 @@ pub struct TowerSessionData { username: String, created_at: SystemTime, last_activity: SystemTime, + authenticated: Authenticated, } #[cfg(test)] diff --git a/src/environmentd/src/http/sql.rs b/src/environmentd/src/http/sql.rs index d7d4b096c427f..5604e8c703913 100644 --- a/src/environmentd/src/http/sql.rs +++ b/src/environmentd/src/http/sql.rs @@ -311,6 +311,7 @@ pub async fn handle_sql_ws( Some(AuthedUser { name: session_data.username, external_metadata_rx: None, + authenticated: session_data.authenticated, }) } else { None diff --git a/src/environmentd/tests/server.rs b/src/environmentd/tests/server.rs index 8eb0753f39f91..bbc5da419f119 100644 --- a/src/environmentd/tests/server.rs +++ b/src/environmentd/tests/server.rs @@ -2647,6 +2647,24 @@ fn test_internal_http_auth() { // can be explicitly set to mz_system assert!(res.text().unwrap().to_string().contains("mz_system")); + // Check that mz_system is a superuser + let json_superuser = serde_json::json!({"query": "SHOW is_superuser;"}); + let res = Client::new() + .post(url.clone()) + .header("x-materialize-user", "mz_system") + .json(&json_superuser) + .send() + .unwrap(); + + tracing::info!("response: {res:?}"); + assert_eq!( + res.status(), + StatusCode::OK, + "{:?}", + res.json::() + ); + assert!(res.text().unwrap().to_string().contains("on")); + let res = Client::new() .post(url.clone()) .header("x-materialize-user", "mz_support") @@ -2664,6 +2682,23 @@ fn test_internal_http_auth() { // can be explicitly set to mz_support assert!(res.text().unwrap().to_string().contains("mz_support")); + // Check that mz_support is not a superuser + let res = Client::new() + .post(url.clone()) + .header("x-materialize-user", "mz_support") + .json(&json_superuser) + .send() + .unwrap(); + + tracing::info!("response: {res:?}"); + assert_eq!( + res.status(), + StatusCode::OK, + "{:?}", + res.json::() + ); + assert!(res.text().unwrap().to_string().contains("off")); + let res = Client::new() .post(url.clone()) .header("x-materialize-user", "invalid value") diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index a80b1a01aa31c..bf32fe8e46268 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -20,6 +20,7 @@ use futures::FutureExt; use futures::future::Shared; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use lru::LruCache; +use mz_auth::Authenticated; use mz_ore::instrument; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::NowFn; @@ -174,12 +175,12 @@ impl Authenticator { &self, expected_user: &str, password: &str, - ) -> Result { + ) -> Result<(AuthSessionHandle, Authenticated), Error> { let password: AppPassword = password.parse()?; match self.authenticate_inner(expected_user, password).await { Ok(handle) => { tracing::debug!("authentication successful"); - Ok(handle) + Ok((handle, Authenticated)) } Err(e) => { tracing::debug!(error = ?e, "authentication failed"); @@ -313,8 +314,9 @@ impl Authenticator { &self, token: &str, expected_user: Option<&str>, - ) -> Result { - self.inner.validate_access_token(token, expected_user) + ) -> Result<(ValidatedClaims, Authenticated), Error> { + let claims = self.inner.validate_access_token(token, expected_user)?; + Ok((claims, Authenticated)) } } diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index b47a2632ea7d0..aeb0fb8a2cff0 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -29,6 +29,7 @@ use mz_adapter::{ AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics, verify_datum_desc, }; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_authenticator::Authenticator; use mz_ore::cast::CastFrom; @@ -200,15 +201,18 @@ where let auth_response = frontegg.authenticate(&user, &password).await; match auth_response { - Ok(mut auth_session) => { - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user: auth_session.user().into(), - client_ip: conn.peer_addr().clone(), - external_metadata_rx: Some(auth_session.external_metadata_rx()), - helm_chart_version, - }); + Ok((mut auth_session, authenticated)) => { + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user: auth_session.user().into(), + client_ip: conn.peer_addr().clone(), + external_metadata_rx: Some(auth_session.external_metadata_rx()), + helm_chart_version, + }, + authenticated, + ); let expired = async move { auth_session.expired().await }; (session, expired.left_future()) } @@ -233,18 +237,22 @@ where } }; - let auth_response = oidc.authenticate(&user, &jwt).await; - + let auth_response = oidc.authenticate(&jwt, Some(&user)).await; match auth_response { - Ok(auth_session) => { - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user: auth_session.username().into(), - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - helm_chart_version, - }); + Ok((claims, authenticated)) => { + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user: claims.username().into(), + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); + // No invalidation of the auth session once authenticated, + // so auth session lasts indefinitely. (session, pending().right_future()) } Err(err) => { @@ -266,7 +274,7 @@ where return conn.send(e).await; } }; - match adapter_client.authenticate(&user, &password).await { + let authenticated = match adapter_client.authenticate(&user, &password).await { Ok(resp) => resp, Err(err) => { warn!(?err, "pgwire connection failed authentication"); @@ -278,14 +286,17 @@ where .await; } }; - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) @@ -388,7 +399,7 @@ where } }; - match conn.recv().await? { + let authenticated = match conn.recv().await? { Some(FrontendMessage::RawAuthentication(data)) => { match decode_sasl_response(Cursor::new(&data)).ok() { Some(FrontendMessage::SASLResponse(response)) => { @@ -415,17 +426,18 @@ where ) .await { - Ok(resp) => { + Ok((proof_response, authenticated)) => { conn.send(BackendMessage::AuthenticationSASLFinal( SASLServerFinalMessage { kind: SASLServerFinalMessageKinds::Verifier( - resp.verifier, + proof_response.verifier, ), extensions: vec![], }, )) .await?; conn.flush().await?; + authenticated } Err(_) => { return conn @@ -457,28 +469,34 @@ where } }; - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) } Authenticator::None => { - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + Authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) From 6794a1ac7663b886f4b70b76b03f7508c828a1fe Mon Sep 17 00:00:00 2001 From: Sang Jun Bak Date: Mon, 26 Jan 2026 21:09:44 -0500 Subject: [PATCH 15/15] Wrap superuser attribute in a custom struct --- src/adapter/src/client.rs | 8 +++----- src/adapter/src/coord/command_handler.rs | 10 +++++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/adapter/src/client.rs b/src/adapter/src/client.rs index 0e50cd33ee5b6..a6439d02957a5 100644 --- a/src/adapter/src/client.rs +++ b/src/adapter/src/client.rs @@ -54,7 +54,7 @@ use uuid::Uuid; use crate::catalog::Catalog; use crate::command::{ CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, SASLChallengeResponse, - SASLVerifyProofResponse, + SASLVerifyProofResponse, SuperuserAttribute, }; use crate::coord::{Coordinator, ExecuteContextGuard}; use crate::error::AdapterError; @@ -296,10 +296,8 @@ impl Client { // Apply the superuser attribute to the session's user if // it exists. - if let Some(superuser_attribute) = superuser_attribute { - session.apply_internal_user_metadata(InternalUserMetadata { - superuser: superuser_attribute, - }); + if let SuperuserAttribute(Some(superuser)) = superuser_attribute { + session.apply_internal_user_metadata(InternalUserMetadata { superuser }); } session.initialize_role_metadata(role_id); diff --git a/src/adapter/src/coord/command_handler.rs b/src/adapter/src/coord/command_handler.rs index 821f794732dad..39c9e332992f8 100644 --- a/src/adapter/src/coord/command_handler.rs +++ b/src/adapter/src/coord/command_handler.rs @@ -64,7 +64,7 @@ use uuid::Uuid; use crate::command::{ CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, SASLVerifyProofResponse, - StartupResponse, + StartupResponse, SuperuserAttribute, }; use crate::coord::appends::PendingWriteTxn; use crate::coord::peek::PendingPeek; @@ -762,7 +762,7 @@ impl Coordinator { user: &User, _conn_id: &ConnectionId, client_ip: &Option, - ) -> Result<(RoleId, Option, BTreeMap), AdapterError> { + ) -> Result<(RoleId, SuperuserAttribute, BTreeMap), AdapterError> { if self.catalog().try_get_role_by_name(&user.name).is_none() { // If the user has made it to this point, that means they have been fully authenticated. // This includes preventing any user, except a pre-defined set of system users, from @@ -868,7 +868,11 @@ impl Coordinator { // rather than eagerly on connection startup. This avoids expensive catalog_mut() calls // for the common case where connections never create temporary objects. - Ok((role_id, superuser_attribute, session_defaults)) + Ok(( + role_id, + SuperuserAttribute(superuser_attribute), + session_defaults, + )) } /// Handles an execute command.