diff --git a/Cargo.lock b/Cargo.lock index 8442f8133837e..a089de510d1d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5580,6 +5580,7 @@ dependencies = [ "base64 0.22.1", "itertools 0.14.0", "mz-ore", + "mz-repr", "openssl", "proptest", "proptest-derive", @@ -5592,9 +5593,17 @@ dependencies = [ name = "mz-authenticator" version = "0.1.0" dependencies = [ + "jsonwebtoken", "mz-adapter", + "mz-auth", "mz-frontegg-auth", + "mz-ore", + "reqwest", "serde", + "serde_json", + "tokio", + "tracing", + "url", "workspace-hack", ] @@ -6327,6 +6336,7 @@ dependencies = [ "mz-license-keys", "mz-metrics", "mz-npm", + "mz-oidc-mock", "mz-orchestrator", "mz-orchestrator-kubernetes", "mz-orchestrator-process", @@ -6566,6 +6576,7 @@ dependencies = [ "futures", "jsonwebtoken", "lru 0.16.3", + "mz-auth", "mz-ore", "mz-repr", "prometheus", @@ -6879,6 +6890,26 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "mz-oidc-mock" +version = "0.0.0" +dependencies = [ + "anyhow", + "axum", + "base64 0.22.1", + "jsonwebtoken", + "mz-authenticator", + "mz-ore", + "openssl", + "reqwest", + "serde", + "serde_json", + "tokio", + "tracing", + "uuid", + "workspace-hack", +] + [[package]] name = "mz-orchestrator" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 9b9ad955c7c75..94fdee06e5e56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ members = [ "src/mz", "src/mz-debug", "src/npm", + "src/oidc-mock", "src/orchestrator", "src/orchestrator-kubernetes", "src/orchestrator-process", @@ -182,6 +183,7 @@ default-members = [ "src/mz", "src/mz-debug", "src/npm", + "src/oidc-mock", "src/orchestrator", "src/orchestrator-kubernetes", "src/orchestrator-process", diff --git a/doc/developer/design/20251215_oidc_authentication.md b/doc/developer/design/20251215_oidc_authentication.md index abbef1786965a..cbd00c98e00f8 100644 --- a/doc/developer/design/20251215_oidc_authentication.md +++ b/doc/developer/design/20251215_oidc_authentication.md @@ -51,12 +51,6 @@ spec: # Note: Not all JWT providers support .well-known/openid-configuration, # so use jwks directly if the provider doesn't support it. jwksFetchFromIssuer: true - # The OAuth 2.0 token endpoint where Materialize will request new access - # tokens using a refresh token (https://www.rfc-editor.org/rfc/rfc6749.html#section-6). - # Requires `grant_type` of `refresh_token` in the client. - # Optional. If not provided, sessions will expire when the access token expires - # and refresh tokens in the password field will be ignored. - tokenEndpoint: https://dev-123456.okta.com/oauth2/default/v1/token ``` Where in environmentd, it’ll look like so: @@ -82,9 +76,13 @@ When a user first logs in with a valid token, we create a role for them if one d ### Solution proposal: The user should be disabled from logging in when a user is removed from the upstream IDP. However, the database level role should still exist. +When doing pgwire Oidc authentication, we can accept a cleartext password that is the access token. The OIDC authenticator will do JWT authentication on the access token. If the token is expired, the session will not be established. We will not do any invalidation on the session if the session has already been authenticated/established, but the token is expired. Eventually, the token will expire and the user will not be able to authenticate a new session. This creates a tradeoff between security and developer experience, but is acceptable since organizations will have supplemental methods of deprovisioning users outside of the database. This accomplishes disabling a user from logging in, but the database role still existing. + +**Alternative: Use a refresh token flow to invalidate active sessions** + When doing pgwire Oidc authentication, we can accept a cleartext password of the form `access=&refresh=` where `&` is a delimiter and `refresh=` is optional. The OIDC authenticator will then try to authenticate again and fetch a new access token using the refresh token when close to expiration (using the token API URL in the spec above). If the refresh token doesn’t exist, the session will invalidate. This would require users to have their IdP client generate `refresh` tokens. For token expiration checking, in a task, we'll repeatedly wait for `(expiration - now) * 0.8` and see if it's less than a minute. This is also how we check token expiration in the Frontegg authenticator. We'll also implement a config variable to turn off this mechanism and have it default to true. -By suggesting a short time to live for access tokens, this accomplishes invalidating sessions on removal of a user from an IDP. When admins remove a user from an IDP, the next time the user tries to authenticate or refresh their access token, the token API will not allow the user to login but will keep the role in the database. +This approach would enhance security by ensuring that sessions are invalidated once the access token expires. However, it would also introduce additional complexity and degrade the developer experience, as it would require users to configure refresh tokens in their IdP. Additionally, some IdPs may impose rate limits on token refresh operations. By opting for a simpler design, we minimize potential incompatibilities with various IdPs. **Alternative: Use SASL Authentication using the OAUTHBEARER mechanism rather than a cleartext password** @@ -124,12 +122,8 @@ An MVP of what this might look like exists here: [https://github.com/Materialize ### Tests: - Successful login (e2e mzcompose) -- Invalidating the session on access token expiration and no refresh token (Rust unit test) -- A token should successfully refresh if the access token and refresh token are valid (Rust unit test) - Session should error if access token is invalid (Rust unit test) -- Session should error if refresh token is invalid (Rust unit test) - A user shouldn't be able to login as another user (Rust unit test) -- Removing a user from the upstream IDP should invalidate the refresh token (e2e mzcompose) - Platform-check simple login check (platform-check framework) - JWTs should only be accepted when a valid JWK is set (we do not want to accept JWTs that are not signed with a real, cryptographically sound key) diff --git a/src/adapter/src/client.rs b/src/adapter/src/client.rs index 175ccaa57741d..a6439d02957a5 100644 --- a/src/adapter/src/client.rs +++ b/src/adapter/src/client.rs @@ -21,6 +21,7 @@ use derivative::Derivative; use futures::{Stream, StreamExt}; use itertools::Itertools; use mz_adapter_types::connection::{ConnectionId, ConnectionIdType}; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_build_info::BuildInfo; use mz_compute_types::ComputeInstanceId; @@ -33,6 +34,7 @@ use mz_ore::result::ResultExt; use mz_ore::task::AbortOnDropHandle; use mz_ore::thread::JoinOnDropHandle; use mz_ore::tracing::OpenTelemetryContext; +use mz_repr::user::InternalUserMetadata; use mz_repr::{CatalogItemId, ColumnIndex, Row, SqlScalarType}; use mz_sql::ast::{Raw, Statement}; use mz_sql::catalog::{EnvironmentId, SessionCatalog}; @@ -51,8 +53,8 @@ use uuid::Uuid; use crate::catalog::Catalog; use crate::command::{ - AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, - SASLChallengeResponse, SASLVerifyProofResponse, + CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, SASLChallengeResponse, + SASLVerifyProofResponse, SuperuserAttribute, }; use crate::coord::{Coordinator, ExecuteContextGuard}; use crate::error::AdapterError; @@ -148,27 +150,30 @@ impl Client { /// Creates a new session associated with this client for the given user. /// /// It is the caller's responsibility to have authenticated the user. - pub fn new_session(&self, config: SessionConfig) -> Session { + /// We pass in an Authenticated marker as a guardrail to ensure the + /// user has authenticated with an authenticator before creating a session. + pub fn new_session(&self, config: SessionConfig, _authenticated: Authenticated) -> Session { // We use the system clock to determine when a session connected to Materialize. This is not // intended to be 100% accurate and correct, so we don't burden the timestamp oracle with // generating a more correct timestamp. Session::new(self.build_info, config, self.metrics().session_metrics()) } - /// Preforms an authentication check for the given user. + /// Verifies the provided user's password against the + /// stored credentials in the catalog. pub async fn authenticate( &self, user: &String, password: &Password, - ) -> Result { + ) -> Result { let (tx, rx) = oneshot::channel(); self.send(Command::AuthenticatePassword { role_name: user.to_string(), password: Some(password.clone()), tx, }); - let response = rx.await.expect("sender dropped")?; - Ok(response) + rx.await.expect("sender dropped")?; + Ok(Authenticated) } pub async fn generate_sasl_challenge( @@ -192,7 +197,7 @@ impl Client { proof: &String, nonce: &String, mock_hash: &String, - ) -> Result { + ) -> Result<(SASLVerifyProofResponse, Authenticated), AdapterError> { let (tx, rx) = oneshot::channel(); self.send(Command::AuthenticateVerifySASLProof { role_name: user.to_string(), @@ -202,7 +207,7 @@ impl Client { tx, }); let response = rx.await.expect("sender dropped")?; - Ok(response) + Ok((response, Authenticated)) } /// Upgrades this client to a session client. @@ -265,6 +270,7 @@ impl Client { optimizer_metrics, persist_client, statement_logging_frontend, + superuser_attribute, } = response; let peek_client = PeekClient::new( @@ -287,6 +293,13 @@ impl Client { }; let session = client.session(); + + // Apply the superuser attribute to the session's user if + // it exists. + if let SuperuserAttribute(Some(superuser)) = superuser_attribute { + session.apply_internal_user_metadata(InternalUserMetadata { superuser }); + } + session.initialize_role_metadata(role_id); let vars_mut = session.vars_mut(); for (name, val) in session_defaults { @@ -438,15 +451,17 @@ Issue a SQL query to get started. Need help? ) -> Result + Send>>, anyhow::Error> { // Connect to the coordinator. let conn_id = self.new_conn_id()?; - let session = self.new_session(SessionConfig { - conn_id, - uuid: Uuid::new_v4(), - user: SUPPORT_USER.name.clone(), - client_ip: None, - external_metadata_rx: None, - internal_user_metadata: None, - helm_chart_version: None, - }); + let session = self.new_session( + SessionConfig { + conn_id, + uuid: Uuid::new_v4(), + user: SUPPORT_USER.name.clone(), + client_ip: None, + external_metadata_rx: None, + helm_chart_version: None, + }, + Authenticated, + ); let mut session_client = self.startup(session).await?; // Parse the SQL statement. diff --git a/src/adapter/src/command.rs b/src/adapter/src/command.rs index e53bbe3295b3e..05143a9e49b31 100644 --- a/src/adapter/src/command.rs +++ b/src/adapter/src/command.rs @@ -86,7 +86,7 @@ pub enum Command { }, AuthenticatePassword { - tx: oneshot::Sender>, + tx: oneshot::Sender>, role_name: String, password: Option, }, @@ -372,12 +372,20 @@ pub struct Response { pub otel_ctx: OpenTelemetryContext, } +#[derive(Debug, Clone, Copy)] +pub struct SuperuserAttribute(pub Option); + /// The response to [`Client::startup`](crate::Client::startup). #[derive(Derivative)] #[derivative(Debug)] pub struct StartupResponse { /// RoleId for the user. pub role_id: RoleId, + /// The role's superuser attribute in the Catalog. + /// This attribute is None for Cloud. Cloud is able + /// to derive the role's superuser status from + /// external_metadata_rx. + pub superuser_attribute: SuperuserAttribute, /// A future that completes when all necessary Builtin Table writes have completed. #[derivative(Debug = "ignore")] pub write_notify: BuiltinTableAppendNotify, @@ -396,16 +404,6 @@ pub struct StartupResponse { pub statement_logging_frontend: StatementLoggingFrontend, } -/// The response to [`Client::authenticate`](crate::Client::authenticate). -#[derive(Derivative)] -#[derivative(Debug)] -pub struct AuthResponse { - /// RoleId for the user. - pub role_id: RoleId, - /// If the user is a superuser. - pub superuser: bool, -} - #[derive(Derivative)] #[derivative(Debug)] pub struct SASLChallengeResponse { @@ -419,7 +417,6 @@ pub struct SASLChallengeResponse { #[derivative(Debug)] pub struct SASLVerifyProofResponse { pub verifier: String, - pub auth_resp: AuthResponse, } // Facile implementation for `StartupResponse`, which does not use the `allowed` diff --git a/src/adapter/src/config/backend.rs b/src/adapter/src/config/backend.rs index e51f9540b92bc..2d8cd74e1e41f 100644 --- a/src/adapter/src/config/backend.rs +++ b/src/adapter/src/config/backend.rs @@ -9,6 +9,7 @@ use std::collections::BTreeMap; +use mz_auth::Authenticated; use mz_sql::session::user::SYSTEM_USER; use tracing::{error, info}; use uuid::Uuid; @@ -28,15 +29,17 @@ pub struct SystemParameterBackend { impl SystemParameterBackend { pub async fn new(client: Client) -> Result { let conn_id = client.new_conn_id()?; - let session = client.new_session(SessionConfig { - conn_id, - uuid: Uuid::new_v4(), - user: SYSTEM_USER.name.clone(), - client_ip: None, - external_metadata_rx: None, - internal_user_metadata: None, - helm_chart_version: None, - }); + let session = client.new_session( + SessionConfig { + conn_id, + uuid: Uuid::new_v4(), + user: SYSTEM_USER.name.clone(), + client_ip: None, + external_metadata_rx: None, + helm_chart_version: None, + }, + Authenticated, + ); let session_client = client.startup(session).await?; Ok(Self { session_client }) } diff --git a/src/adapter/src/coord/command_handler.rs b/src/adapter/src/coord/command_handler.rs index aa3a63b7a821b..39c9e332992f8 100644 --- a/src/adapter/src/coord/command_handler.rs +++ b/src/adapter/src/coord/command_handler.rs @@ -63,8 +63,8 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; use uuid::Uuid; use crate::command::{ - AuthResponse, CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, - SASLVerifyProofResponse, StartupResponse, + CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, SASLVerifyProofResponse, + StartupResponse, SuperuserAttribute, }; use crate::coord::appends::PendingWriteTxn; use crate::coord::peek::PendingPeek; @@ -504,14 +504,7 @@ impl Coordinator { Ok(verifier) => { // Success only if role exists, allows login, and a real password hash was used. if login && real_hash.is_some() { - let role = role.expect("login implies role exists"); - let _ = tx.send(Ok(SASLVerifyProofResponse { - verifier, - auth_resp: AuthResponse { - role_id: role.id, - superuser: role.attributes.superuser.unwrap_or(false), - }, - })); + let _ = tx.send(Ok(SASLVerifyProofResponse { verifier })); } else { let _ = tx.send(Err(make_auth_err(role_present, login))); } @@ -604,7 +597,7 @@ impl Coordinator { #[mz_ore::instrument(level = "debug")] async fn handle_authenticate_password( &self, - tx: oneshot::Sender>, + tx: oneshot::Sender>, role_name: String, password: Option, ) { @@ -627,13 +620,11 @@ impl Coordinator { if let Some(auth) = self.catalog().try_get_role_auth_by_id(&role.id) { if let Some(hash) = &auth.password_hash { let hash = hash.clone(); - let role_id = role.id; - let superuser = role.attributes.superuser.unwrap_or(false); task::spawn_blocking( || "auth-check-hash", move || { let _ = match mz_auth::hash::scram256_verify(&password, &hash) { - Ok(_) => tx.send(Ok(AuthResponse { role_id, superuser })), + Ok(_) => tx.send(Ok(())), Err(_) => tx.send(Err(AdapterError::AuthenticationError( AuthenticationError::InvalidCredentials, ))), @@ -669,7 +660,7 @@ impl Coordinator { ) { // Early return if successful, otherwise cleanup any possible state. match self.handle_startup_inner(&user, &conn_id, &client_ip).await { - Ok((role_id, session_defaults)) => { + Ok((role_id, superuser_attribute, session_defaults)) => { let session_type = metrics::session_type_label_value(&user); self.metrics .active_sessions @@ -744,6 +735,7 @@ impl Coordinator { optimizer_metrics: self.optimizer_metrics.clone(), persist_client: self.persist_client.clone(), statement_logging_frontend, + superuser_attribute, }); if tx.send(resp).is_err() { // Failed to send to adapter, but everything is setup so we can terminate @@ -770,7 +762,7 @@ impl Coordinator { user: &User, _conn_id: &ConnectionId, client_ip: &Option, - ) -> Result<(RoleId, BTreeMap), AdapterError> { + ) -> Result<(RoleId, SuperuserAttribute, BTreeMap), AdapterError> { if self.catalog().try_get_role_by_name(&user.name).is_none() { // If the user has made it to this point, that means they have been fully authenticated. // This includes preventing any user, except a pre-defined set of system users, from @@ -783,11 +775,13 @@ impl Coordinator { }; self.sequence_create_role_for_startup(plan).await?; } - let role_id = self + let role = self .catalog() .try_get_role_by_name(&user.name) - .expect("created above") - .id; + .expect("created above"); + let role_id = role.id; + + let superuser_attribute = role.attributes.superuser; if role_id.is_user() && !ALLOW_USER_SESSIONS.get(self.catalog().system_config().dyncfgs()) { return Err(AdapterError::UserSessionsDisallowed); @@ -874,7 +868,11 @@ impl Coordinator { // rather than eagerly on connection startup. This avoids expensive catalog_mut() calls // for the common case where connections never create temporary objects. - Ok((role_id, session_defaults)) + Ok(( + role_id, + SuperuserAttribute(superuser_attribute), + session_defaults, + )) } /// Handles an execute command. diff --git a/src/adapter/src/coord/validity.rs b/src/adapter/src/coord/validity.rs index 48ec3cffcd2db..8cf542cfbc4c9 100644 --- a/src/adapter/src/coord/validity.rs +++ b/src/adapter/src/coord/validity.rs @@ -223,7 +223,6 @@ mod tests { user, client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }, metrics.session_metrics(), diff --git a/src/adapter/src/session.rs b/src/adapter/src/session.rs index 50f2dcff168fe..a60b4a154be4e 100644 --- a/src/adapter/src/session.rs +++ b/src/adapter/src/session.rs @@ -210,8 +210,6 @@ pub struct SessionConfig { /// An optional receiver that the session will periodically check for /// updates to a user's external metadata. pub external_metadata_rx: Option>, - /// The metadata of the user associated with the session. - pub internal_user_metadata: Option, /// Helm chart version pub helm_chart_version: Option, } @@ -299,7 +297,6 @@ impl Session { user: SYSTEM_USER.name.clone(), client_ip: None, external_metadata_rx: None, - internal_user_metadata: None, helm_chart_version: None, }, metrics, @@ -316,7 +313,6 @@ impl Session { user, client_ip, mut external_metadata_rx, - internal_user_metadata, helm_chart_version, }: SessionConfig, metrics: SessionMetrics, @@ -325,7 +321,7 @@ impl Session { let default_cluster = INTERNAL_USER_NAME_TO_DEFAULT_CLUSTER.get(&user); let user = User { name: user, - internal_metadata: internal_user_metadata, + internal_metadata: None, external_metadata: external_metadata_rx .as_mut() .map(|rx| rx.borrow_and_update().clone()), @@ -871,6 +867,11 @@ impl Session { self.vars.set_external_user_metadata(metadata); } + /// Applies the internal user metadata to the session. + pub fn apply_internal_user_metadata(&mut self, metadata: InternalUserMetadata) { + self.vars.set_internal_user_metadata(metadata); + } + /// Initializes the session's role metadata. pub fn initialize_role_metadata(&mut self, role_id: RoleId) { self.role_metadata = Some(RoleMetadata::new(role_id)); diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index a5742eaaa18a8..f912922afb018 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] base64 = "0.22.1" mz-ore = { path = "../ore", features = ["test"] } +mz-repr = { path = "../repr", default-features = false } workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } serde = "1.0.219" proptest-derive = "0.7.0" diff --git a/src/auth/src/lib.rs b/src/auth/src/lib.rs index 174a798681d2f..2db42e4360866 100644 --- a/src/auth/src/lib.rs +++ b/src/auth/src/lib.rs @@ -7,5 +7,15 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use serde::{Deserialize, Serialize}; + pub mod hash; pub mod password; + +/// A sentinel type signifying successful authentication. +/// +/// This type is used to establish an authenticated Adapter client session, +/// and should only be constructed by authenticators to indicate that authentication +/// has succeeded. It may also be used when authentication is not required. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Authenticated; diff --git a/src/authenticator/Cargo.toml b/src/authenticator/Cargo.toml index 3c34cd66028cb..78c811634b565 100644 --- a/src/authenticator/Cargo.toml +++ b/src/authenticator/Cargo.toml @@ -8,11 +8,21 @@ rust-version.workspace = true publish = false [dependencies] +jsonwebtoken = "9.3.1" mz-adapter = { path = "../adapter", default-features = false } +mz-auth = { path = "../auth", default-features = false } mz-frontegg-auth = { path = "../frontegg-auth", default-features = false } +reqwest = "0.12.24" serde = { version = "1.0.219", features = ["derive"] } +tokio = { version = "1.48.0", default-features = false, features = ["sync"] } +tracing = "0.1.43" +url = "2.5.7" workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } +[dev-dependencies] +mz-ore = { path = "../ore", default-features = false, features = ["test"] } +serde_json = "1.0" + [lints] workspace = true diff --git a/src/authenticator/src/lib.rs b/src/authenticator/src/lib.rs index 6e1527d13b742..8525873656c89 100644 --- a/src/authenticator/src/lib.rs +++ b/src/authenticator/src/lib.rs @@ -7,13 +7,18 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +pub mod oidc; + use mz_adapter::Client as AdapterClient; use mz_frontegg_auth::Authenticator as FronteggAuthenticator; +pub use oidc::{GenericOidcAuthenticator, OidcClaims, OidcConfig, OidcError}; + #[derive(Debug, Clone)] pub enum Authenticator { Frontegg(FronteggAuthenticator), Password(AdapterClient), Sasl(AdapterClient), + Oidc(GenericOidcAuthenticator), None, } diff --git a/src/authenticator/src/oidc.rs b/src/authenticator/src/oidc.rs new file mode 100644 index 0000000000000..a4d08f22fb4c9 --- /dev/null +++ b/src/authenticator/src/oidc.rs @@ -0,0 +1,343 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! OIDC Authentication for pgwire connections. +//! +//! This module provides JWT-based authentication using OpenID Connect (OIDC). +//! JWTs are validated locally using JWKS fetched from the configured provider. + +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use jsonwebtoken::{DecodingKey, Validation, decode, decode_header, jwk::JwkSet}; +use mz_auth::Authenticated; +use reqwest::Client as HttpClient; +use serde::{Deserialize, Deserializer, Serialize}; + +use tracing::warn; +use url::Url; + +/// Command line arguments for OIDC authentication. +#[derive(Debug, Clone)] +pub struct OidcConfig { + /// OIDC issuer URL (e.g., `https://accounts.google.com`). + /// This is validated against the `iss` claim in the JWT. + pub oidc_issuer: String, + /// Optional OIDC audience (client ID). + /// If set, validates that the JWT's `aud` claim contains this value. + pub oidc_audience: Option, +} + +/// Errors that can occur during OIDC authentication. +#[derive(Debug)] +pub enum OidcError { + /// Failed to parse OIDC configuration URL. + InvalidIssuerUrl(url::ParseError), + /// Failed to fetch OpenID configuration from provider. + OpenIdConfigFetchFailed(String), + /// Failed to fetch JWKS from provider. + JwksFetchFailed(String), + /// The key ID is missing in the token header. + MissingKid, + /// No matching key found in JWKS. + NoMatchingKey, + /// JWT validation error from jsonwebtoken. + Jwt(jsonwebtoken::errors::Error), + /// User does not match expected value. + WrongUser, +} + +impl std::fmt::Display for OidcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OidcError::InvalidIssuerUrl(e) => { + write!(f, "failed to parse OIDC issuer URL: {}", e) + } + OidcError::OpenIdConfigFetchFailed(e) => { + write!(f, "failed to fetch OpenID configuration: {}", e) + } + OidcError::JwksFetchFailed(e) => write!(f, "failed to fetch JWKS: {}", e), + OidcError::MissingKid => write!(f, "missing key ID in token header"), + OidcError::NoMatchingKey => write!(f, "no matching key in JWKS"), + OidcError::Jwt(e) => write!(f, "JWT error: {}", e), + OidcError::WrongUser => write!(f, "user does not match expected value"), + } + } +} + +impl std::error::Error for OidcError {} + +fn deserialize_string_or_vec<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum StringOrVec { + String(String), + Vec(Vec), + } + + match StringOrVec::deserialize(deserializer)? { + StringOrVec::String(s) => Ok(vec![s]), + StringOrVec::Vec(v) => Ok(v), + } +} +/// Claims extracted from a validated JWT. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OidcClaims { + /// Subject (user identifier). + pub sub: String, + /// Issuer. + pub iss: String, + /// Expiration time (Unix timestamp). + pub exp: i64, + /// Issued at time (Unix timestamp). + #[serde(default)] + pub iat: Option, + /// Email claim (commonly used for username). + #[serde(default)] + pub email: Option, + /// Audience claim (can be single string or array in JWT). + #[serde(default, deserialize_with = "deserialize_string_or_vec")] + pub aud: Vec, +} + +impl OidcClaims { + /// Extract the username to use for the session. + /// + /// Priority: email > sub + // TODO (SangJunBak): Add a configuration variable to use a different username field. + pub fn username(&self) -> &str { + self.email.as_deref().unwrap_or(&self.sub) + } +} + +#[derive(Clone)] +struct OidcDecodingKey(DecodingKey); + +impl std::fmt::Debug for OidcDecodingKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JWKS").field("key", &"").finish() + } +} + +/// OIDC Authenticator that validates JWTs using JWKS. +/// +/// This implementation pre-fetches JWKS at construction time for synchronous +/// token validation. +#[derive(Clone, Debug)] +pub struct GenericOidcAuthenticator { + inner: Arc, +} + +/// OpenID Connect Discovery document. +/// See: +#[derive(Debug, Deserialize)] +struct OpenIdConfiguration { + /// URL of the JWKS endpoint. + jwks_uri: String, +} + +#[derive(Debug)] +pub struct GenericOidcAuthenticatorInner { + issuer: String, + audience: Option, + decoding_keys: Mutex>, + http_client: HttpClient, +} + +impl GenericOidcAuthenticator { + /// Create a new [`GenericOidcAuthenticator`] from [`OidcConfig`]. + pub fn new(config: OidcConfig) -> Result { + let http_client = HttpClient::new(); + + Ok(Self { + inner: Arc::new(GenericOidcAuthenticatorInner { + issuer: config.oidc_issuer, + audience: config.oidc_audience, + decoding_keys: Mutex::new(BTreeMap::new()), + http_client, + }), + }) + } +} + +impl GenericOidcAuthenticatorInner { + async fn fetch_jwks_uri(&self) -> Result { + let openid_config_url = Url::parse(&self.issuer) + .and_then(|url| url.join(".well-known/openid-configuration")) + .map_err(OidcError::InvalidIssuerUrl)?; + + // Fetch OpenID configuration to get the JWKS URI + let response = self + .http_client + .get(openid_config_url) + .timeout(Duration::from_secs(10)) + .send() + .await + .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(OidcError::OpenIdConfigFetchFailed(format!( + "HTTP {}", + response.status() + ))); + } + + let openid_config: OpenIdConfiguration = response + .json() + .await + .map_err(|e| OidcError::OpenIdConfigFetchFailed(e.to_string()))?; + + Ok(openid_config.jwks_uri) + } + /// Fetch JWKS from the provider and parse into a map of key IDs to decoding keys. + async fn fetch_jwks(&self) -> Result, OidcError> { + let jwks_uri = self.fetch_jwks_uri().await?; + let response = self + .http_client + .get(&jwks_uri) + .timeout(Duration::from_secs(10)) + .send() + .await + .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?; + + if !response.status().is_success() { + return Err(OidcError::JwksFetchFailed(format!( + "HTTP {}", + response.status() + ))); + } + + let jwks: JwkSet = response + .json() + .await + .map_err(|e| OidcError::JwksFetchFailed(e.to_string()))?; + + let mut keys = BTreeMap::new(); + for jwk in jwks.keys { + match DecodingKey::from_jwk(&jwk) { + Ok(key) => { + if let Some(kid) = jwk.common.key_id { + keys.insert(kid, OidcDecodingKey(key)); + } + } + Err(e) => { + warn!("Failed to parse JWK: {}", e); + } + } + } + + if keys.is_empty() { + return Err(OidcError::JwksFetchFailed( + "no valid keys found in JWKS".to_string(), + )); + } + + Ok(keys) + } + + /// Find a decoding key matching the given key ID. + /// If the key is not found, fetch the JWKS and cache the keys. + async fn find_key(&self, kid: &String) -> Result { + { + let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned"); + + if let Some(key) = decoding_keys.get_mut(kid) { + return Ok(key.clone()); + } + } + + let new_decoding_keys = self.fetch_jwks().await?; + + let decoding_key = new_decoding_keys.get(kid).cloned(); + + { + let mut decoding_keys = self.decoding_keys.lock().expect("lock poisoned"); + decoding_keys.extend(new_decoding_keys); + } + + if let Some(key) = decoding_key { + return Ok(key); + } + + Err(OidcError::NoMatchingKey) + } + + pub async fn validate_access_token( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result { + // Decode header to get key ID (kid) and algorithm + let header = decode_header(token).map_err(OidcError::Jwt)?; + + let kid = header.kid.ok_or(OidcError::MissingKid)?; + // Find matching key from cached keys + let decoding_key = self.find_key(&kid).await?; + + // Set up validation + // TODO (SangJunBak): Make JWT expiration configurable. + let mut validation = Validation::new(header.alg); + validation.set_issuer(&[&self.issuer]); + if let Some(ref audience) = self.audience { + validation.set_audience(&[audience]); + } else { + validation.validate_aud = false; + } + + // Decode and validate the token + let token_data = + decode::(token, &(decoding_key.0), &validation).map_err(OidcError::Jwt)?; + + // Optionally validate expected user + if let Some(expected) = expected_user { + if token_data.claims.username() != expected { + return Err(OidcError::WrongUser); + } + } + + Ok(token_data.claims) + } +} + +impl GenericOidcAuthenticator { + pub async fn authenticate( + &self, + token: &str, + expected_user: Option<&str>, + ) -> Result<(OidcClaims, Authenticated), OidcError> { + let claims = self + .inner + .validate_access_token(token, expected_user) + .await?; + Ok((claims, Authenticated)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[mz_ore::test] + fn test_aud_single_string() { + let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":"my-app"}"#; + let claims: OidcClaims = serde_json::from_str(json).unwrap(); + assert_eq!(claims.aud, vec!["my-app"]); + } + + #[mz_ore::test] + fn test_aud_array() { + let json = r#"{"sub":"user","iss":"issuer","exp":1234,"aud":["app1","app2"]}"#; + let claims: OidcClaims = serde_json::from_str(json).unwrap(); + assert_eq!(claims.aud, vec!["app1", "app2"]); + } +} diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 67c898939731b..259957525bcd2 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -1394,7 +1394,7 @@ impl Resolver { let auth_response = auth.authenticate(user, &password).await; let auth_session = match auth_response { - Ok(auth_session) => auth_session, + Ok((auth_session, _)) => auth_session, Err(e) => { warn!("pgwire connection failed authentication: {}", e); // TODO: fix error codes. diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index 254f13646cf06..9004a79d44694 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -41,7 +41,7 @@ maplit = "1.0.2" mime = "0.3.16" mz-alloc = { path = "../alloc" } mz-alloc-default = { path = "../alloc-default", optional = true } -mz-auth = { path = "../auth" } +mz-auth = { path = "../auth", default-features = false } mz-authenticator = { path = "../authenticator" } mz-aws-secrets-controller = { path = "../aws-secrets-controller" } mz-build-info = { path = "../build-info" } @@ -55,6 +55,7 @@ mz-dyncfg = { path = "../dyncfg" } mz-dyncfgs = { path = "../dyncfgs" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-frontegg-mock = { path = "../frontegg-mock", optional = true } +mz-oidc-mock = { path = "../oidc-mock", default-features = false, optional = true } mz-http-util = { path = "../http-util" } mz-interchange = { path = "../interchange" } mz-license-keys = { path = "../license-keys" } @@ -184,6 +185,7 @@ test = [ "postgres-openssl", "mz-tracing", "mz-frontegg-mock", + "mz-oidc-mock", "tracing-capture", "mz-orchestrator-tracing/capture", ] diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 06d086e8e6567..f4a148295ad90 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -36,6 +36,7 @@ use mz_adapter_types::bootstrap_builtin_cluster_config::{ SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR, }; use mz_auth::password::Password; +use mz_authenticator::{GenericOidcAuthenticator, OidcConfig}; use mz_aws_secrets_controller::AwsSecretsController; use mz_build_info::BuildInfo; use mz_catalog::config::ClusterReplicaSizeMap; @@ -171,6 +172,14 @@ pub struct Args { /// Frontegg arguments. #[clap(flatten)] frontegg: FronteggCliArgs, + // === OIDC options. === + /// OIDC issuer URL (e.g., `https://accounts.google.com`). + #[clap(long, env = "MZ_OIDC_ISSUER")] + oidc_issuer: Option, + /// OIDC audience (client ID). If set, validates that the JWT's `aud` claim + /// contains this value. + #[clap(long, env = "MZ_OIDC_AUDIENCE")] + oidc_audience: Option, // === Orchestrator options. === /// The service orchestrator implementation to use. #[structopt(long, value_enum, env = "ORCHESTRATOR")] @@ -743,6 +752,14 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { // Configure connections. let tls = args.tls.into_config()?; let frontegg = FronteggAuthenticator::from_args(args.frontegg, &metrics_registry)?; + let oidc = if let Some(oidc_issuer) = args.oidc_issuer { + Some(GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer, + oidc_audience: args.oidc_audience, + })?) + } else { + None + }; let listeners_config: ListenersConfig = { let f = File::open(args.listeners_config_path)?; serde_json::from_reader(f)? @@ -1083,6 +1100,7 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { tls_reload_certs: mz_server_core::default_cert_reload_ticker(), external_login_password_mz_system: args.external_login_password_mz_system, frontegg, + oidc, cors_allowed_origin, egress_addresses: args.announce_egress_address, http_host_name: args.http_host_name, diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index e06c7fb87cb25..e5d097b5c5cd8 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -43,6 +43,7 @@ use hyper_openssl::client::legacy::MaybeHttpsStream; use hyper_util::rt::TokioIo; use mz_adapter::session::{Session as AdapterSession, SessionConfig as AdapterSessionConfig}; use mz_adapter::{AdapterError, AdapterNotice, Client, SessionClient, WebhookAppenderCache}; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_authenticator::Authenticator; use mz_controller::ReplicaHttpLocator; @@ -53,7 +54,7 @@ use mz_ore::metrics::MetricsRegistry; use mz_ore::now::{NowFn, SYSTEM_TIME, epoch_to_uuid_v7}; use mz_ore::str::StrExt; use mz_pgwire_common::{ConnectionCounter, ConnectionHandle}; -use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata}; +use mz_repr::user::ExternalUserMetadata; use mz_server_core::listeners::{AllowedRoles, AuthenticatorKind, HttpRoutesEnabled}; use mz_server_core::{Connection, ConnectionHandler, ReloadingSslContext, Server}; use mz_sql::session::metadata::SessionMetadata; @@ -549,11 +550,10 @@ async fn x_materialize_user_header_auth(mut req: Request, next: Next) -> impl In ))); } }; - let superuser = matches!(username.as_str(), SYSTEM_USER_NAME); req.extensions_mut().insert(AuthedUser { name: username, external_metadata_rx: None, - internal_metadata: Some(InternalUserMetadata { superuser }), + authenticated: Authenticated, }); } Ok(next.run(req).await) @@ -571,7 +571,7 @@ enum ConnProtocol { pub struct AuthedUser { name: String, external_metadata_rx: Option>, - internal_metadata: Option, + authenticated: Authenticated, } pub struct AuthedClient { @@ -594,15 +594,17 @@ impl AuthedClient { F: FnOnce(&mut AdapterSession), { let conn_id = adapter_client.new_conn_id()?; - let mut session = adapter_client.new_session(AdapterSessionConfig { - conn_id, - uuid: epoch_to_uuid_v7(&(now)()), - user: user.name, - client_ip: Some(peer_addr), - external_metadata_rx: user.external_metadata_rx, - internal_user_metadata: user.internal_metadata, - helm_chart_version, - }); + let mut session = adapter_client.new_session( + AdapterSessionConfig { + conn_id, + uuid: epoch_to_uuid_v7(&(now)()), + user: user.name, + client_ip: Some(peer_addr), + external_metadata_rx: user.external_metadata_rx, + helm_chart_version, + }, + user.authenticated, + ); let connection_guard = active_connection_counter.allocate_connection(session.user())?; session_config(&mut session); @@ -763,22 +765,19 @@ pub async fn handle_login( let Ok(adapter_client) = adapter_client_rx.clone().await else { return StatusCode::INTERNAL_SERVER_ERROR; }; - let auth_response = match adapter_client.authenticate(&username, &password).await { - Ok(auth_response) => auth_response, + let authenticated = match adapter_client.authenticate(&username, &password).await { + Ok(authenticated) => authenticated, Err(err) => { warn!(?err, "HTTP login failed authentication"); return StatusCode::UNAUTHORIZED; } }; - // Create session data let session_data = TowerSessionData { username, created_at: SystemTime::now(), last_activity: SystemTime::now(), - internal_metadata: InternalUserMetadata { - superuser: auth_response.superuser, - }, + authenticated, }; // Store session data let session = session.and_then(|Extension(session)| Some(session)); @@ -835,7 +834,7 @@ async fn http_auth( req.extensions_mut().insert(AuthedUser { name: session_data.username, external_metadata_rx: None, - internal_metadata: Some(session_data.internal_metadata), + authenticated: session_data.authenticated, }); return Ok(next.run(req).await); } @@ -995,21 +994,22 @@ async fn auth( allowed_roles: AllowedRoles, include_www_authenticate_header: bool, ) -> Result { - let (name, external_metadata_rx, internal_metadata) = match authenticator { + let (name, external_metadata_rx, authenticated) = match authenticator { Authenticator::Frontegg(frontegg) => match creds { Some(Credentials::Password { username, password }) => { - let auth_session = frontegg.authenticate(&username, &password.0).await?; + let (auth_session, authenticated) = + frontegg.authenticate(&username, &password.0).await?; let name = auth_session.user().into(); let external_metadata_rx = Some(auth_session.external_metadata_rx()); - (name, external_metadata_rx, None) + (name, external_metadata_rx, authenticated) } Some(Credentials::Token { token }) => { - let claims = frontegg.validate_access_token(&token, None)?; + let (claims, authenticated) = frontegg.validate_access_token(&token, None)?; let (_, external_metadata_rx) = watch::channel(ExternalUserMetadata { user_id: claims.user_id, admin: claims.is_admin, }); - (claims.user, Some(external_metadata_rx), None) + (claims.user, Some(external_metadata_rx), authenticated) } None => { return Err(AuthError::MissingHttpAuthentication { @@ -1019,14 +1019,11 @@ async fn auth( }, Authenticator::Password(adapter_client) => match creds { Some(Credentials::Password { username, password }) => { - let auth_response = adapter_client + let authenticated = adapter_client .authenticate(&username, &password) .await .map_err(|_| AuthError::InvalidCredentials)?; - let internal_metadata = InternalUserMetadata { - superuser: auth_response.superuser, - }; - (username, None, Some(internal_metadata)) + (username, None, authenticated) } _ => { return Err(AuthError::MissingHttpAuthentication { @@ -1042,6 +1039,31 @@ async fn auth( include_www_authenticate_header, }); } + Authenticator::Oidc(oidc) => match creds { + Some(Credentials::Token { token }) => { + // Validate JWT token + let (claims, authenticated) = oidc + .authenticate(&token, None) + .await + .map_err(|_| AuthError::InvalidCredentials)?; + let name = claims.username().to_string(); + (name, None, authenticated) + } + Some(Credentials::Password { username, password }) => { + // Allow JWT to be passed as password + let (claims, authenticated) = oidc + .authenticate(&password.0, Some(&username)) + .await + .map_err(|_| AuthError::InvalidCredentials)?; + let name = claims.username().to_string(); + (name, None, authenticated) + } + None => { + return Err(AuthError::MissingHttpAuthentication { + include_www_authenticate_header, + }); + } + }, Authenticator::None => { // If no authentication, use whatever is in the HTTP auth // header (without checking the password), or fall back to the @@ -1050,7 +1072,7 @@ async fn auth( Some(Credentials::Password { username, .. }) => username, _ => HTTP_DEFAULT_USER.name.to_owned(), }; - (name, None, None) + (name, None, Authenticated) } }; @@ -1059,7 +1081,7 @@ async fn auth( Ok(AuthedUser { name, external_metadata_rx, - internal_metadata, + authenticated, }) } @@ -1126,7 +1148,7 @@ pub struct TowerSessionData { username: String, created_at: SystemTime, last_activity: SystemTime, - internal_metadata: InternalUserMetadata, + authenticated: Authenticated, } #[cfg(test)] diff --git a/src/environmentd/src/http/sql.rs b/src/environmentd/src/http/sql.rs index e1f03a90d158d..5604e8c703913 100644 --- a/src/environmentd/src/http/sql.rs +++ b/src/environmentd/src/http/sql.rs @@ -311,7 +311,7 @@ pub async fn handle_sql_ws( Some(AuthedUser { name: session_data.username, external_metadata_rx: None, - internal_metadata: Some(session_data.internal_metadata), + authenticated: session_data.authenticated, }) } else { None diff --git a/src/environmentd/src/lib.rs b/src/environmentd/src/lib.rs index 045980d57a658..7c48ed3574602 100644 --- a/src/environmentd/src/lib.rs +++ b/src/environmentd/src/lib.rs @@ -36,7 +36,7 @@ use mz_adapter_types::dyncfgs::{ WITH_0DT_DEPLOYMENT_MAX_WAIT, }; use mz_auth::password::Password; -use mz_authenticator::Authenticator; +use mz_authenticator::{Authenticator, GenericOidcAuthenticator}; use mz_build_info::{BuildInfo, build_info}; use mz_catalog::config::ClusterReplicaSizeMap; use mz_catalog::durable::BootstrapArgs; @@ -106,6 +106,8 @@ pub struct Config { pub external_login_password_mz_system: Option, /// Frontegg JWT authentication configuration. pub frontegg: Option, + /// OIDC JWT authentication configuration. + pub oidc: Option, /// Origins for which cross-origin resource sharing (CORS) for HTTP requests /// is permitted. pub cors_allowed_origin: AllowOrigin, @@ -261,6 +263,7 @@ impl Listener { active_connection_counter: ConnectionCounter, tls_reloading_context: Option, frontegg: Option, + oidc: Option, adapter_client: AdapterClient, metrics: MetricsConfig, helm_chart_version: Option, @@ -280,6 +283,9 @@ impl Listener { ), AuthenticatorKind::Password => Authenticator::Password(adapter_client.clone()), AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client.clone()), + AuthenticatorKind::Oidc => Authenticator::Oidc( + oidc.expect("OIDC config is required with AuthenticatorKind::Oidc"), + ), AuthenticatorKind::None => Authenticator::None, }; @@ -380,16 +386,23 @@ impl Listeners { let authenticator_frontegg_rx = authenticator_frontegg_rx.shared(); let (authenticator_password_tx, authenticator_password_rx) = oneshot::channel(); let authenticator_password_rx = authenticator_password_rx.shared(); + let (authenticator_oidc_tx, authenticator_oidc_rx) = oneshot::channel(); + let authenticator_oidc_rx = authenticator_oidc_rx.shared(); let (authenticator_none_tx, authenticator_none_rx) = oneshot::channel(); let authenticator_none_rx = authenticator_none_rx.shared(); - // We can only send the Frontegg and None variants immediately. + // We can only send the Frontegg, OIDC, and None variants immediately. // The Password variant requires an adapter client. if let Some(frontegg) = &config.frontegg { authenticator_frontegg_tx .send(Arc::new(Authenticator::Frontegg(frontegg.clone()))) .expect("rx known to be live"); } + if let Some(oidc) = &config.oidc { + authenticator_oidc_tx + .send(Arc::new(Authenticator::Oidc(oidc.clone()))) + .expect("rx known to be live"); + } authenticator_none_tx .send(Arc::new(Authenticator::None)) .expect("rx known to be live"); @@ -406,6 +419,7 @@ impl Listeners { AuthenticatorKind::Frontegg => authenticator_frontegg_rx.clone(), AuthenticatorKind::Password => authenticator_password_rx.clone(), AuthenticatorKind::Sasl => authenticator_password_rx.clone(), + AuthenticatorKind::Oidc => authenticator_oidc_rx.clone(), AuthenticatorKind::None => authenticator_none_rx.clone(), }; let source: &'static str = Box::leak(name.clone().into_boxed_str()); @@ -827,6 +841,7 @@ impl Listeners { active_connection_counter.clone(), tls_reloading_context.clone(), config.frontegg.clone(), + config.oidc.clone(), adapter_client.clone(), metrics.clone(), config.helm_chart_version.clone(), diff --git a/src/environmentd/src/test_util.rs b/src/environmentd/src/test_util.rs index 1d1f28e4d4fa8..ac5a895a5307f 100644 --- a/src/environmentd/src/test_util.rs +++ b/src/environmentd/src/test_util.rs @@ -90,6 +90,7 @@ use crate::{ CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig, WebSocketAuth, WebSocketResponse, }; +use mz_authenticator::GenericOidcAuthenticator; pub static KAFKA_ADDRS: LazyLock = LazyLock::new(|| env::var("KAFKA_ADDRS").unwrap_or_else(|_| "localhost:9092".into())); @@ -100,6 +101,7 @@ pub struct TestHarness { data_directory: Option, tls: Option, frontegg: Option, + oidc: Option, external_login_password_mz_system: Option, listeners_config: ListenersConfig, unsafe_mode: bool, @@ -137,6 +139,7 @@ impl Default for TestHarness { data_directory: None, tls: None, frontegg: None, + oidc: None, external_login_password_mz_system: None, listeners_config: ListenersConfig { sql: btreemap![ @@ -362,6 +365,60 @@ impl TestHarness { self } + pub fn with_oidc_auth(mut self, oidc: &GenericOidcAuthenticator) -> Self { + self.oidc = Some(oidc.clone()); + let enable_tls = self.tls.is_some(); + self.listeners_config = ListenersConfig { + sql: btreemap! { + "external".to_owned() => SqlListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Oidc, + allowed_roles: AllowedRoles::Normal, + enable_tls, + }, + "internal".to_owned() => SqlListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::None, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls: false, + }, + }, + http: btreemap! { + "external".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Oidc, + allowed_roles: AllowedRoles::Normal, + enable_tls, + }, + routes: HttpRoutesEnabled{ + base: true, + webhook: true, + internal: false, + metrics: false, + profiling: false, + }, + }, + "internal".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::None, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls: false, + }, + routes: HttpRoutesEnabled{ + base: true, + webhook: true, + internal: true, + metrics: true, + profiling: true, + }, + }, + }, + }; + self + } + pub fn with_password_auth(mut self, mz_system_password: Password) -> Self { self.external_login_password_mz_system = Some(mz_system_password); let enable_tls = self.tls.is_some(); @@ -750,6 +807,7 @@ impl Listeners { connection_context, replica_http_locator: Default::default(), }, + oidc: config.oidc, secrets_controller, cloud_resource_controller: None, tls: config.tls, diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index 4c1a7312eac69..fe6623b510a6a 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -38,6 +38,7 @@ use hyper_util::rt::TokioExecutor; use itertools::Itertools; use jsonwebtoken::{self, DecodingKey, EncodingKey}; use mz_auth::password::Password; +use mz_authenticator::{GenericOidcAuthenticator, OidcConfig}; use mz_environmentd::test_util::{self, Ca, make_header, make_pg_tls}; use mz_environmentd::{WebSocketAuth, WebSocketResponse}; use mz_frontegg_auth::{ @@ -47,6 +48,7 @@ use mz_frontegg_auth::{ use mz_frontegg_mock::{ FronteggMockServer, models::ApiToken, models::TenantApiTokenConfig, models::UserConfig, }; +use mz_oidc_mock::{GenerateJwtOptions, OidcMockServer}; use mz_ore::error::ErrorExt; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::{NowFn, SYSTEM_TIME}; @@ -1269,6 +1271,405 @@ async fn test_auth_base_require_tls_frontegg() { .await; } +/// Tests OIDC authentication with TLS required. +/// +/// This test verifies that users can authenticate using OIDC tokens +/// over TLS connections +#[allow(clippy::unit_arg)] +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_base_require_tls_oidc() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: None, + }) + .unwrap(); + + let oidc_user = "user@example.com"; + let jwt_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + ..Default::default() + }, + ); + let expired_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + exp: Some(0), + ..Default::default() + }, + ); + let wrong_user_token = oidc_server.generate_jwt( + "other@example.com", + GenerateJwtOptions { + ..Default::default() + }, + ); + let wrong_issuer_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + issuer: Some("https://wrong-issuer.com"), + ..Default::default() + }, + ); + + let oidc_bearer = Authorization::bearer(&jwt_token).unwrap(); + let oidc_header_bearer = make_header(oidc_bearer); + let oidc_header_basic = make_header(Authorization::basic(oidc_user, &jwt_token)); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "TlsMode::Require, OIDC", + &server, + &[ + // TLS with valid JWT should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&jwt_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // HTTP with Bearer auth should succeed. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTPS, + headers: &oidc_header_bearer, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // HTTP with basic username/password should succeed. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTPS, + headers: &oidc_header_basic, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // Ws with bearer token should succeed. + TestCase::Ws { + auth: &WebSocketAuth::Bearer { + token: jwt_token.clone(), + options: BTreeMap::default(), + }, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // Ws with basic username/password should succeed. + TestCase::Ws { + auth: &WebSocketAuth::Basic { + user: oidc_user.to_string(), + password: Password(jwt_token.clone()), + options: BTreeMap::default(), + }, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // No TLS should fail when server requires it. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&jwt_token)), + ssl_mode: SslMode::Disable, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!( + *err.code(), + SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION + ); + assert_eq!(err.message(), "TLS encryption is required"); + })), + }, + // HTTP without TLS should be rejected. + TestCase::Http { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + scheme: Scheme::HTTP, + headers: &oidc_header_bearer, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: assert_http_rejected(), + }, + // Invalid JWT should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed("invalid-jwt-token")), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // Expired JWT should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&expired_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT for wrong user should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_user_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT with wrong issuer should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_issuer_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + ], + ) + .await; +} + +/// Tests OIDC audience validation. +/// +/// This test verifies that when an audience is configured, only JWTs with +/// matching `aud` claims are accepted. +#[allow(clippy::unit_arg)] +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_oidc_audience_validation() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let expected_audience = "my-app-client-id"; + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: Some(expected_audience.to_string()), + }) + .unwrap(); + + let oidc_user = "user@example.com"; + // Token with correct audience + let valid_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec![expected_audience.to_string()]), + ..Default::default() + }, + ); + // Token with correct audience among multiple audiences + let valid_multi_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["other-app".to_string(), expected_audience.to_string()]), + ..Default::default() + }, + ); + // Token with wrong audience + let wrong_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["wrong-app".to_string()]), + ..Default::default() + }, + ); + // Token with no audience + let no_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + ..Default::default() + }, + ); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "OIDC Audience Validation", + &server, + &[ + // JWT with correct audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with correct audience among multiple audiences should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_multi_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with wrong audience should fail. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&wrong_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + // JWT with no audience should fail when audience is required. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&no_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::DbErr(Box::new(|err| { + assert_eq!(err.message(), "invalid password"); + assert_eq!(*err.code(), SqlState::INVALID_PASSWORD); + })), + }, + ], + ) + .await; +} + +/// Tests OIDC where we don't validate the audience claim. +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_auth_oidc_audience_optional() { + let ca = Ca::new_root("test ca").unwrap(); + let (server_cert, server_key) = ca + .request_cert("server", vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]) + .unwrap(); + + let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); + + let kid = "test-key-1".to_string(); + let oidc_server = OidcMockServer::start( + None, + encoding_key, + kid, + SYSTEM_TIME.clone(), + i64::try_from(EXPIRES_IN_SECS).unwrap(), + ) + .await + .unwrap(); + + let oidc_auth = GenericOidcAuthenticator::new(OidcConfig { + oidc_issuer: oidc_server.issuer.clone(), + oidc_audience: None, + }) + .unwrap(); + + let oidc_user = "user@example.com"; + // Token with no audience + let no_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + ..Default::default() + }, + ); + + // Token with any audience + let valid_aud_token = oidc_server.generate_jwt( + oidc_user, + GenerateJwtOptions { + aud: Some(vec!["my-app-client-id".to_string()]), + ..Default::default() + }, + ); + + let server = test_util::TestHarness::default() + .with_tls(server_cert, server_key) + .with_oidc_auth(&oidc_auth) + .start() + .await; + + run_tests( + "OIDC no audience validation", + &server, + &[ + // JWT with no audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&no_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + // JWT with any audience should succeed. + TestCase::Pgwire { + user_to_auth_as: oidc_user, + user_reported_by_system: oidc_user, + password: Some(Cow::Borrowed(&valid_aud_token)), + ssl_mode: SslMode::Require, + configure: Box::new(|b| Ok(b.set_verify(SslVerifyMode::NONE))), + assert: Assert::Success, + }, + ], + ) + .await; +} + #[allow(clippy::unit_arg)] #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` diff --git a/src/environmentd/tests/server.rs b/src/environmentd/tests/server.rs index 8eb0753f39f91..bbc5da419f119 100644 --- a/src/environmentd/tests/server.rs +++ b/src/environmentd/tests/server.rs @@ -2647,6 +2647,24 @@ fn test_internal_http_auth() { // can be explicitly set to mz_system assert!(res.text().unwrap().to_string().contains("mz_system")); + // Check that mz_system is a superuser + let json_superuser = serde_json::json!({"query": "SHOW is_superuser;"}); + let res = Client::new() + .post(url.clone()) + .header("x-materialize-user", "mz_system") + .json(&json_superuser) + .send() + .unwrap(); + + tracing::info!("response: {res:?}"); + assert_eq!( + res.status(), + StatusCode::OK, + "{:?}", + res.json::() + ); + assert!(res.text().unwrap().to_string().contains("on")); + let res = Client::new() .post(url.clone()) .header("x-materialize-user", "mz_support") @@ -2664,6 +2682,23 @@ fn test_internal_http_auth() { // can be explicitly set to mz_support assert!(res.text().unwrap().to_string().contains("mz_support")); + // Check that mz_support is not a superuser + let res = Client::new() + .post(url.clone()) + .header("x-materialize-user", "mz_support") + .json(&json_superuser) + .send() + .unwrap(); + + tracing::info!("response: {res:?}"); + assert_eq!( + res.status(), + StatusCode::OK, + "{:?}", + res.json::() + ); + assert!(res.text().unwrap().to_string().contains("off")); + let res = Client::new() .post(url.clone()) .header("x-materialize-user", "invalid value") diff --git a/src/frontegg-auth/Cargo.toml b/src/frontegg-auth/Cargo.toml index 644256f687bee..c74d5671c9cf0 100644 --- a/src/frontegg-auth/Cargo.toml +++ b/src/frontegg-auth/Cargo.toml @@ -17,6 +17,7 @@ derivative = "2.2.0" futures = "0.3.31" jsonwebtoken = "9.3.1" lru = "0.16.3" +mz-auth = { path = "../auth", default-features = false } mz-ore = { path = "../ore", features = ["network", "metrics"] } mz-repr = { path = "../repr" } prometheus = { version = "0.14.0", default-features = false } diff --git a/src/frontegg-auth/src/auth.rs b/src/frontegg-auth/src/auth.rs index a80b1a01aa31c..bf32fe8e46268 100644 --- a/src/frontegg-auth/src/auth.rs +++ b/src/frontegg-auth/src/auth.rs @@ -20,6 +20,7 @@ use futures::FutureExt; use futures::future::Shared; use jsonwebtoken::{Algorithm, DecodingKey, Validation}; use lru::LruCache; +use mz_auth::Authenticated; use mz_ore::instrument; use mz_ore::metrics::MetricsRegistry; use mz_ore::now::NowFn; @@ -174,12 +175,12 @@ impl Authenticator { &self, expected_user: &str, password: &str, - ) -> Result { + ) -> Result<(AuthSessionHandle, Authenticated), Error> { let password: AppPassword = password.parse()?; match self.authenticate_inner(expected_user, password).await { Ok(handle) => { tracing::debug!("authentication successful"); - Ok(handle) + Ok((handle, Authenticated)) } Err(e) => { tracing::debug!(error = ?e, "authentication failed"); @@ -313,8 +314,9 @@ impl Authenticator { &self, token: &str, expected_user: Option<&str>, - ) -> Result { - self.inner.validate_access_token(token, expected_user) + ) -> Result<(ValidatedClaims, Authenticated), Error> { + let claims = self.inner.validate_access_token(token, expected_user)?; + Ok((claims, Authenticated)) } } diff --git a/src/materialized/ci/listener_configs/oidc.json b/src/materialized/ci/listener_configs/oidc.json new file mode 100644 index 0000000000000..28cc5ca0cc353 --- /dev/null +++ b/src/materialized/ci/listener_configs/oidc.json @@ -0,0 +1,38 @@ +{ + "sql": { + "external": { + "addr": "0.0.0.0:6875", + "authenticator_kind": "Oidc", + "allowed_roles": "NormalAndInternal", + "enable_tls": false + } + }, + "http": { + "external": { + "addr": "0.0.0.0:6876", + "authenticator_kind": "Oidc", + "allowed_roles": "NormalAndInternal", + "enable_tls": false, + "routes": { + "base": true, + "webhook": true, + "internal": true, + "metrics": false, + "profiling": true + } + }, + "metrics": { + "addr": "0.0.0.0:6878", + "authenticator_kind": "None", + "allowed_roles": "NormalAndInternal", + "enable_tls": false, + "routes": { + "base": false, + "webhook": false, + "internal": false, + "metrics": true, + "profiling": false + } + } + } +} diff --git a/src/oidc-mock/Cargo.toml b/src/oidc-mock/Cargo.toml new file mode 100644 index 0000000000000..223a0390952dd --- /dev/null +++ b/src/oidc-mock/Cargo.toml @@ -0,0 +1,36 @@ + +[package] +name = "mz-oidc-mock" +description = "OIDC mock server for testing." +version = "0.0.0" +edition.workspace = true +rust-version.workspace = true +publish = false + +[lints] +workspace = true + +[dependencies] +anyhow = "1.0.100" +axum = "0.7.5" +base64 = "0.22.1" +jsonwebtoken = "9.3.1" +mz-authenticator = { path = "../authenticator", default-features = false } +mz-ore = { path = "../ore", default-features = false, features = ["cli"] } +openssl = { version = "0.10.75", features = ["vendored"] } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.148" +tokio = { version = "1.48.0", default-features = false } +tracing = "0.1.44" +uuid = "1.19.0" +workspace-hack = { version = "0.0.0", path = "../workspace-hack", optional = true } + +[dev-dependencies] +openssl = { version = "0.10.75", features = ["vendored"] } +reqwest = { version = "0.12.28", features = ["json"] } + +[package.metadata.cargo-udeps.ignore] +normal = ["workspace-hack"] + +[features] +default = ["workspace-hack"] diff --git a/src/oidc-mock/src/lib.rs b/src/oidc-mock/src/lib.rs new file mode 100644 index 0000000000000..c3167a38a136a --- /dev/null +++ b/src/oidc-mock/src/lib.rs @@ -0,0 +1,229 @@ +// Copyright Materialize, Inc. and contributors. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! OIDC mock server for testing. +//! +//! This module provides a mock OIDC server that serves JWKS endpoints +//! for validating JWT tokens in tests. + +use std::borrow::Cow; +use std::future::IntoFuture; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; + +use axum::extract::State; +use axum::routing::get; +use axum::{Json, Router}; +use base64::Engine; +use jsonwebtoken::{EncodingKey, Header, encode}; +use mz_authenticator::OidcClaims; +use mz_ore::now::NowFn; +use mz_ore::task::JoinHandle; +use openssl::pkey::{PKey, Private}; +use openssl::rsa::Rsa; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpListener; + +/// JWKS response structure. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct JwkSet { + pub keys: Vec, +} + +/// JSON Web Key structure. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Jwk { + pub kty: String, + pub kid: String, + #[serde(rename = "use")] + pub key_use: String, + pub alg: String, + pub n: String, + pub e: String, +} + +/// Shared context for the OIDC mock server. +struct OidcMockContext { + /// The issuer URL (base URL of this server). + issuer: String, + /// RSA public key in JWK format. + jwk: Jwk, +} + +/// Options for generating JWT tokens. +#[derive(Debug, Clone, Default)] +pub struct GenerateJwtOptions<'a> { + /// Optional email claim. + pub email: Option<&'a str>, + /// Custom expiration time. If None, uses server's default expires_in_secs. + pub exp: Option, + /// Custom issuer. If None, uses server's issuer. + pub issuer: Option<&'a str>, + /// Audience claim. If None, uses empty vec. + pub aud: Option>, +} + +/// OIDC mock server for testing. +pub struct OidcMockServer { + /// The issuer URL. Used as the base URL of the server + /// and as the issuer for JWT iss claim. + pub issuer: String, + /// Key ID used in JWT headers. + pub kid: String, + /// Encoding key for signing JWTs (for generating test tokens). + pub encoding_key: EncodingKey, + /// Function for getting current time. + pub now: NowFn, + /// How long tokens should be valid (in seconds). + pub expires_in_secs: i64, + /// Handle to the server task. + pub handle: JoinHandle>, +} + +impl OidcMockServer { + /// Starts an [`OidcMockServer`]. + /// + /// Must be started from within a [`tokio::runtime::Runtime`]. + /// + /// # Arguments + /// + /// * `addr` - Optional address to bind to. If None, binds to localhost on a random port. + /// * `encoding_key` - PEM-encoded RSA private key string for signing JWTs. + /// * `kid` - Key ID to use in JWT headers and JWKS. + /// * `now` - Function for getting current time. + /// * `expires_in_secs` - How long tokens should be valid. + pub async fn start( + addr: Option<&SocketAddr>, + encoding_key: String, + kid: String, + now: NowFn, + expires_in_secs: i64, + ) -> Result { + // Convert PEM string to key. + let encoding_key_typed = EncodingKey::from_rsa_pem(encoding_key.as_bytes())?; + + // Parse the private key PEM to extract RSA components for JWKS. + let pkey = PKey::private_key_from_pem(encoding_key.as_bytes())?; + let rsa = pkey.rsa().expect("pkey should be RSA"); + + let addr = match addr { + Some(addr) => Cow::Borrowed(addr), + None => Cow::Owned(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)), + }; + + let listener = TcpListener::bind(*addr).await.unwrap_or_else(|e| { + panic!("error binding to {}: {}", addr, e); + }); + let issuer = format!("http://{}", listener.local_addr().unwrap()); + + // Extract RSA public key components from the decoding key + // We need to serialize the public key to get n and e values + let jwk = create_jwk(&kid, &rsa); + + let context = Arc::new(OidcMockContext { + issuer: issuer.clone(), + jwk, + }); + + let router = Router::new() + .route("/.well-known/jwks.json", get(handle_jwks)) + .route( + "/.well-known/openid-configuration", + get(handle_openid_config), + ) + .with_state(context); + + let server = axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ); + println!("oidc-mock listening..."); + println!(" HTTP address: {}", issuer); + let handle = mz_ore::task::spawn(|| "oidc-mock-server", server.into_future()); + + Ok(OidcMockServer { + issuer, + kid, + encoding_key: encoding_key_typed, + now, + expires_in_secs, + handle, + }) + } + + /// Generates a JWT token for testing. + /// + /// # Arguments + /// + /// * `sub` - Subject (user identifier). + /// * `opts` - Optional JWT generation options. Use `Default::default()` for defaults. + pub fn generate_jwt(&self, sub: &str, opts: GenerateJwtOptions<'_>) -> String { + let now_ms = (self.now)(); + let now_secs = i64::try_from(now_ms / 1000).expect("timestamp must fit in i64"); + + let claims = OidcClaims { + sub: sub.to_string(), + iss: opts.issuer.unwrap_or(&self.issuer).to_string(), + exp: opts.exp.unwrap_or(now_secs + self.expires_in_secs), + iat: Some(now_secs), + email: opts.email.map(|s| s.to_string()), + aud: opts.aud.unwrap_or_default(), + }; + + let mut header = Header::new(jsonwebtoken::Algorithm::RS256); + header.kid = Some(self.kid.clone()); + + encode(&header, &claims, &self.encoding_key).expect("failed to encode JWT") + } + + /// Returns the JWKS URL for this server. + pub fn jwks_url(&self) -> String { + format!("{}/.well-known/jwks.json", self.issuer) + } +} + +/// Handler for JWKS endpoint. +async fn handle_jwks(State(context): State>) -> Json { + Json(JwkSet { + keys: vec![context.jwk.clone()], + }) +} + +/// OpenID Configuration response. +#[derive(Serialize)] +struct OpenIdConfiguration { + issuer: String, + jwks_uri: String, +} + +/// Handler for OpenID Configuration endpoint. +async fn handle_openid_config( + State(context): State>, +) -> Json { + Json(OpenIdConfiguration { + issuer: context.issuer.clone(), + jwks_uri: format!("{}/.well-known/jwks.json", context.issuer), + }) +} + +/// Creates a JWK from RSA key components. +fn create_jwk(kid: &str, rsa: &Rsa) -> Jwk { + let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD; + let n = rsa.n().to_vec(); + let e = rsa.e().to_vec(); + + Jwk { + kty: "RSA".to_string(), + kid: kid.to_string(), + key_use: "sig".to_string(), + alg: "RS256".to_string(), + n: engine.encode(n), + e: engine.encode(e), + } +} diff --git a/src/pgwire/Cargo.toml b/src/pgwire/Cargo.toml index e66ffa2306664..ed7c70fe9bfec 100644 --- a/src/pgwire/Cargo.toml +++ b/src/pgwire/Cargo.toml @@ -21,7 +21,7 @@ futures = "0.3.31" itertools = "0.14.0" mz-adapter = { path = "../adapter" } mz-adapter-types = { path = "../adapter-types" } -mz-auth = { path = "../auth" } +mz-auth = { path = "../auth", default-features = false } mz-authenticator = { path = "../authenticator" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-ore = { path = "../ore", features = ["tracing"] } diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index c6a3d68301a99..aeb0fb8a2cff0 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -29,6 +29,7 @@ use mz_adapter::{ AdapterError, AdapterNotice, ExecuteContextGuard, ExecuteResponse, PeekResponseUnary, metrics, verify_datum_desc, }; +use mz_auth::Authenticated; use mz_auth::password::Password; use mz_authenticator::Authenticator; use mz_ore::cast::CastFrom; @@ -41,7 +42,6 @@ use mz_pgwire_common::{ ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS, }; -use mz_repr::user::InternalUserMetadata; use mz_repr::{ CatalogItemId, ColumnIndex, Datum, RelationDesc, RowArena, RowIterator, RowRef, SqlRelationType, SqlScalarType, @@ -160,6 +160,14 @@ where } let user = params.remove("user").unwrap_or_else(String::new); + let options = parse_options(params.get("options").unwrap_or(&String::new())); + + // If oidc_auth_enabled exists as an option, return its value and filter it from + // the remaining options. + // TODO (SangJunBak): Use oidc_auth_enabled boolean and Authenticator::OIDC + // to decide whether or not we want to use OIDC authentication or password + // authentication. + let (_oidc_auth_enabled, options) = extract_oidc_auth_enabled_from_options(options); // TODO move this somewhere it can be shared with HTTP let is_internal_user = INTERNAL_USER_NAMES.contains(&user); @@ -183,52 +191,28 @@ where let (mut session, expired) = match authenticator { Authenticator::Frontegg(frontegg) => { - conn.send(BackendMessage::AuthenticationCleartextPassword) - .await?; - conn.flush().await?; - let password = match conn.recv().await? { - Some(FrontendMessage::RawAuthentication(data)) => { - match decode_password(Cursor::new(&data)).ok() { - Some(FrontendMessage::Password { password }) => password, - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; - } - } - } - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; + let password = match request_cleartext_password(conn).await { + Ok(password) => password, + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; } }; let auth_response = frontegg.authenticate(&user, &password).await; match auth_response { - Ok(mut auth_session) => { - // Create a session based on the auth session. - // - // In particular, it's important that the username come from the - // auth session, as Frontegg may return an email address with - // different casing than the user supplied via the pgwire - // username field. We want to use the Frontegg casing as - // canonical. - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user: auth_session.user().into(), - client_ip: conn.peer_addr().clone(), - external_metadata_rx: Some(auth_session.external_metadata_rx()), - internal_user_metadata: None, - helm_chart_version, - }); + Ok((mut auth_session, authenticated)) => { + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user: auth_session.user().into(), + client_ip: conn.peer_addr().clone(), + external_metadata_rx: Some(auth_session.external_metadata_rx()), + helm_chart_version, + }, + authenticated, + ); let expired = async move { auth_session.expired().await }; (session, expired.left_future()) } @@ -243,34 +227,54 @@ where } } } - Authenticator::Password(adapter_client) => { - conn.send(BackendMessage::AuthenticationCleartextPassword) - .await?; - conn.flush().await?; - let password = match conn.recv().await? { - Some(FrontendMessage::RawAuthentication(data)) => { - match decode_password(Cursor::new(&data)).ok() { - Some(FrontendMessage::Password { password }) => Password(password), - _ => { - return conn - .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", - )) - .await; - } - } + Authenticator::Oidc(oidc) => { + // OIDC authentication: JWT sent as password in cleartext flow + let jwt = match request_cleartext_password(conn).await { + Ok(password) => password, + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; } - _ => { + }; + + let auth_response = oidc.authenticate(&jwt, Some(&user)).await; + match auth_response { + Ok((claims, authenticated)) => { + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user: claims.username().into(), + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); + // No invalidation of the auth session once authenticated, + // so auth session lasts indefinitely. + (session, pending().right_future()) + } + Err(err) => { + warn!(?err, "pgwire connection failed authentication"); return conn .send(ErrorResponse::fatal( - SqlState::INVALID_AUTHORIZATION_SPECIFICATION, - "expected Password message", + SqlState::INVALID_PASSWORD, + "invalid password", )) .await; } + } + } + Authenticator::Password(adapter_client) => { + let password = match request_cleartext_password(conn).await { + Ok(password) => Password(password), + Err(PasswordRequestError::IoError(e)) => return Err(e), + Err(PasswordRequestError::InvalidPasswordError(e)) => { + return conn.send(e).await; + } }; - let auth_response = match adapter_client.authenticate(&user, &password).await { + let authenticated = match adapter_client.authenticate(&user, &password).await { Ok(resp) => resp, Err(err) => { warn!(?err, "pgwire connection failed authentication"); @@ -282,17 +286,17 @@ where .await; } }; - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - internal_user_metadata: Some(InternalUserMetadata { - superuser: auth_response.superuser, - }), - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) @@ -395,7 +399,7 @@ where } }; - let auth_resp = match conn.recv().await? { + let authenticated = match conn.recv().await? { Some(FrontendMessage::RawAuthentication(data)) => { match decode_sasl_response(Cursor::new(&data)).ok() { Some(FrontendMessage::SASLResponse(response)) => { @@ -422,18 +426,18 @@ where ) .await { - Ok(resp) => { + Ok((proof_response, authenticated)) => { conn.send(BackendMessage::AuthenticationSASLFinal( SASLServerFinalMessage { kind: SASLServerFinalMessageKinds::Verifier( - resp.verifier, + proof_response.verifier, ), extensions: vec![], }, )) .await?; conn.flush().await?; - resp.auth_resp + authenticated } Err(_) => { return conn @@ -465,31 +469,34 @@ where } }; - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - internal_user_metadata: Some(InternalUserMetadata { - superuser: auth_resp.superuser, - }), - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) } + Authenticator::None => { - let session = adapter_client.new_session(SessionConfig { - conn_id: conn.conn_id().clone(), - uuid: conn_uuid, - user, - client_ip: conn.peer_addr().clone(), - external_metadata_rx: None, - internal_user_metadata: None, - helm_chart_version, - }); + let session = adapter_client.new_session( + SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + helm_chart_version, + }, + Authenticated, + ); // No frontegg check, so auth session lasts indefinitely. let auth_session = pending().right_future(); (session, auth_session) @@ -499,7 +506,7 @@ where let system_vars = adapter_client.get_system_vars().await; for (name, value) in params { let settings = match name.as_str() { - "options" => match parse_options(&value) { + "options" => match &options { Ok(opts) => opts, Err(()) => { session.add_notice(AdapterNotice::BadStartupSetting { @@ -509,7 +516,7 @@ where continue; } }, - _ => vec![(name, value)], + _ => &vec![(name, value)], }; for (key, val) in settings { const LOCAL: bool = false; @@ -517,13 +524,12 @@ where // (silently ignore errors on set), but erroring the connection // might be the better behavior. We maybe need to support more // options sent by psql and drivers before we can safely do this. - if let Err(err) = - session - .vars_mut() - .set(&system_vars, &key, VarInput::Flat(&val), LOCAL) + if let Err(err) = session + .vars_mut() + .set(&system_vars, key, VarInput::Flat(val), LOCAL) { session.add_notice(AdapterNotice::BadStartupSetting { - name: key, + name: key.clone(), reason: err.to_string(), }); } @@ -601,6 +607,31 @@ where } } +/// Gets `oidc_auth_enabled` from options if it exists. +/// Returns options with oidc_auth_enabled extracted +/// and the oidc_auth_enabled value. +fn extract_oidc_auth_enabled_from_options( + options: Result, ()>, +) -> (bool, Result, ()>) { + let options = match options { + Ok(opts) => opts, + Err(_) => return (false, options), + }; + + let mut new_options = Vec::new(); + let mut oidc_auth_enabled = false; + + for (k, v) in options { + if k == "oidc_auth_enabled" { + oidc_auth_enabled = v.parse::().unwrap_or(false); + } else { + new_options.push((k, v)); + } + } + + (oidc_auth_enabled, Ok(new_options)) +} + /// Returns (name, value) session settings pairs from an options value. /// /// From Postgres, see pg_split_opts in postinit.c and process_postgres_switches @@ -682,6 +713,48 @@ fn split_options(value: &str) -> Vec { strs } +enum PasswordRequestError { + InvalidPasswordError(ErrorResponse), + IoError(io::Error), +} + +impl From for PasswordRequestError { + fn from(e: io::Error) -> Self { + PasswordRequestError::IoError(e) + } +} + +/// Requests a cleartext password from a connection and returns it if it is valid. +/// Sends an error response in the connection and returns None if the password +/// is not valid. +async fn request_cleartext_password( + conn: &mut FramedConn, +) -> Result +where + A: AsyncRead + AsyncWrite + Unpin, +{ + conn.send(BackendMessage::AuthenticationCleartextPassword) + .await?; + conn.flush().await?; + + if let Some(message) = conn.recv().await? { + if let FrontendMessage::RawAuthentication(data) = message { + if let Some(FrontendMessage::Password { password }) = + decode_password(Cursor::new(&data)).ok() + { + return Ok(password); + } + } + } + + Err(PasswordRequestError::InvalidPasswordError( + ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected Password message", + ), + )) +} + #[derive(Debug)] enum State { Ready, @@ -2997,4 +3070,84 @@ mod test { assert_eq!(got, test.expect, "input: {}", test.input); } } + + #[mz_ore::test] + fn test_extract_oidc_auth_enabled_from_options() { + struct TestCase { + input: Result, ()>, + expect_enabled: bool, + expect_options: Result, ()>, + } + let tests = vec![ + // Empty options + TestCase { + input: Ok(vec![]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // Error input passthrough + TestCase { + input: Err(()), + expect_enabled: false, + expect_options: Err(()), + }, + // oidc_auth_enabled=true + TestCase { + input: Ok(vec![("oidc_auth_enabled", "true")]), + expect_enabled: true, + expect_options: Ok(vec![]), + }, + // oidc_auth_enabled=false + TestCase { + input: Ok(vec![("oidc_auth_enabled", "false")]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // Invalid oidc_auth_enabled value defaults to false + TestCase { + input: Ok(vec![("oidc_auth_enabled", "invalid")]), + expect_enabled: false, + expect_options: Ok(vec![]), + }, + // No oidc_auth_enabled, other options preserved + TestCase { + input: Ok(vec![("key1", "val1"), ("key2", "val2")]), + expect_enabled: false, + expect_options: Ok(vec![("key1", "val1"), ("key2", "val2")]), + }, + // Mixed: oidc_auth_enabled with other options + TestCase { + input: Ok(vec![ + ("key1", "val1"), + ("oidc_auth_enabled", "true"), + ("key2", "val2"), + ]), + expect_enabled: true, + expect_options: Ok(vec![("key1", "val1"), ("key2", "val2")]), + }, + ]; + for test in tests { + let input = test.input.map(|r| { + r.into_iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect() + }); + let (got_enabled, got_options) = extract_oidc_auth_enabled_from_options(input.clone()); + let expect_options = test.expect_options.map(|r| { + r.into_iter() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect() + }); + assert_eq!( + got_enabled, test.expect_enabled, + "enabled mismatch for input: {:?}", + input + ); + assert_eq!( + got_options, expect_options, + "options mismatch for input: {:?}", + input + ); + } + } } diff --git a/src/server-core/src/listeners.rs b/src/server-core/src/listeners.rs index e078cfdf9a1f6..4298a014097c4 100644 --- a/src/server-core/src/listeners.rs +++ b/src/server-core/src/listeners.rs @@ -22,6 +22,8 @@ pub enum AuthenticatorKind { Password, /// Authenticate users using SASL. Sasl, + /// Authenticate users using OIDC (JWT tokens). + Oidc, /// Do not authenticate users. Trust they are who they say they are without verification. #[default] None, diff --git a/src/sql/src/session/user.rs b/src/sql/src/session/user.rs index 1ff00cca6376d..d36cde11a7ea1 100644 --- a/src/sql/src/session/user.rs +++ b/src/sql/src/session/user.rs @@ -67,6 +67,8 @@ pub struct User { pub name: String, /// Metadata about this user in an external system. pub external_metadata: Option, + /// Metadata about this user stored in the catalog, + /// such as its role's `SUPERUSER` attribute. pub internal_metadata: Option, } diff --git a/src/sql/src/session/vars.rs b/src/sql/src/session/vars.rs index 014a2f2308589..a4717f00b0130 100644 --- a/src/sql/src/session/vars.rs +++ b/src/sql/src/session/vars.rs @@ -85,7 +85,7 @@ use mz_persist_client::cfg::{ use mz_repr::adt::numeric::Numeric; use mz_repr::adt::timestamp::CheckedTimestamp; use mz_repr::bytes::ByteSize; -use mz_repr::user::ExternalUserMetadata; +use mz_repr::user::{ExternalUserMetadata, InternalUserMetadata}; use mz_tracing::{CloneableEnvFilter, SerializableDirective}; use serde::Serialize; use thiserror::Error; @@ -845,6 +845,11 @@ impl SessionVars { .as_bytes() } + /// Sets the internal metadata associated with the user. + pub fn set_internal_user_metadata(&mut self, metadata: InternalUserMetadata) { + self.user.internal_metadata = Some(metadata); + } + /// Sets the external metadata associated with the user. pub fn set_external_user_metadata(&mut self, metadata: ExternalUserMetadata) { self.user.external_metadata = Some(metadata); diff --git a/src/sqllogictest/src/runner.rs b/src/sqllogictest/src/runner.rs index 60641b1811a06..75b8aaf15d5a2 100644 --- a/src/sqllogictest/src/runner.rs +++ b/src/sqllogictest/src/runner.rs @@ -1171,6 +1171,7 @@ impl<'a> RunnerInner<'a> { cloud_resource_controller: None, tls: None, frontegg: None, + oidc: None, cors_allowed_origin: AllowOrigin::list([]), unsafe_mode: true, all_features: false,