diff --git a/Cargo.lock b/Cargo.lock index 03039ce9..02eb08da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -224,6 +259,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.36" @@ -326,7 +371,7 @@ dependencies = [ [[package]] name = "config" -version = "0.3.13" +version = "0.3.14" dependencies = [ "chrono", "clap", @@ -408,13 +453,15 @@ dependencies = [ [[package]] name = "crypto" -version = "0.3.13" +version = "0.3.14" dependencies = [ + "aes-gcm", "base64", "pem", "rand", "rsa", "sha2", + "snafu", "testutils", ] @@ -425,6 +472,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] @@ -449,6 +497,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.20.11" @@ -836,6 +893,16 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.31.1" @@ -1223,6 +1290,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1573,6 +1649,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl" version = "0.10.72" @@ -1720,6 +1802,18 @@ version = "0.3.32" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.11.0" @@ -2486,7 +2580,7 @@ dependencies = [ [[package]] name = "testutils" -version = "0.3.13" +version = "0.3.14" dependencies = [ "pem", "rsa", @@ -2690,7 +2784,7 @@ dependencies = [ [[package]] name = "tower" -version = "0.3.13" +version = "0.3.14" dependencies = [ "tokio", "tower-api", @@ -2714,7 +2808,7 @@ dependencies = [ [[package]] name = "tower-api" -version = "0.3.13" +version = "0.3.14" dependencies = [ "log", "reqwest", @@ -2727,7 +2821,7 @@ dependencies = [ [[package]] name = "tower-cmd" -version = "0.3.13" +version = "0.3.14" dependencies = [ "anyhow", "bytes", @@ -2748,6 +2842,7 @@ dependencies = [ "serde", "serde_json", "simple_logger", + "snafu", "spinners", "tokio", "tokio-util", @@ -2766,7 +2861,7 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-package" -version = "0.3.13" +version = "0.3.14" dependencies = [ "async-compression", "config", @@ -2785,7 +2880,7 @@ dependencies = [ [[package]] name = "tower-runtime" -version = "0.3.13" +version = "0.3.14" dependencies = [ "chrono", "log", @@ -2802,7 +2897,7 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tower-version" -version = "0.3.13" +version = "0.3.14" dependencies = [ "anyhow", "chrono", @@ -2873,6 +2968,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 0894f68b..f28e5351 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,7 @@ resolver = "2" [workspace.package] edition = "2021" -version = "0.3.13" - +version = "0.3.14" description = "Tower is the best way to host Python data apps in production" @@ -15,6 +14,7 @@ license = "MIT" repository = "https://github.com/tower/tower-cli" [workspace.dependencies] +aes-gcm = "0.10" anyhow = "1.0.95" async-compression = { version = "0.4", features = ["tokio", "gzip"] } base64 = "0.22" diff --git a/README.md b/README.md index 9b236a83..f50450aa 100644 --- a/README.md +++ b/README.md @@ -123,5 +123,12 @@ easily. Then you can `import tower` and you're off to the races! uv run python ``` +To run tests: + +```bash +uv sync --locked --all-extras --dev +uv run pytest tests +``` + If you need to get the latest OpenAPI SDK, you can run `./scripts/generate-python-api-client.sh`. 
diff --git a/crates/crypto/Cargo.toml b/crates/crypto/Cargo.toml
index dffea22f..87977f89 100644
--- a/crates/crypto/Cargo.toml
+++ b/crates/crypto/Cargo.toml
@@ -4,9 +4,11 @@ version = { workspace = true }
 edition = "2021"
 
 [dependencies]
+aes-gcm = { workspace = true }
 base64 = { workspace = true }
 pem = { workspace = true }
 rand = { workspace = true }
 rsa = { workspace = true }
 sha2 = { workspace = true }
+snafu = { workspace = true }
 testutils = { workspace = true }
diff --git a/crates/crypto/src/errors.rs b/crates/crypto/src/errors.rs
new file mode 100644
index 00000000..5edb1796
--- /dev/null
+++ b/crates/crypto/src/errors.rs
@@ -0,0 +1,40 @@
+use snafu::prelude::*;
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("invalid message"))]
+    InvalidMessage,
+
+    #[snafu(display("invalid encoding"))]
+    InvalidEncoding,
+
+    #[snafu(display("cryptography error"))]
+    CryptographyError,
+
+    #[snafu(display("base64 error"))]
+    Base64Error,
+}
+
+impl From<std::string::FromUtf8Error> for Error {
+    fn from(_error: std::string::FromUtf8Error) -> Self {
+        Self::InvalidEncoding
+    }
+}
+
+impl From<aes_gcm::Error> for Error {
+    fn from(_error: aes_gcm::Error) -> Self {
+        Self::CryptographyError
+    }
+}
+
+impl From<rsa::Error> for Error {
+    fn from(_error: rsa::Error) -> Self {
+        Self::CryptographyError
+    }
+}
+
+impl From<base64::DecodeError> for Error {
+    fn from(_error: base64::DecodeError) -> Self {
+        Self::Base64Error
+    }
+}
diff --git a/crates/crypto/src/lib.rs b/crates/crypto/src/lib.rs
index 98a38b0d..92bc1fac 100644
--- a/crates/crypto/src/lib.rs
+++ b/crates/crypto/src/lib.rs
@@ -1,47 +1,81 @@
-use sha2::{Sha256, Digest, digest::DynDigest};
-use rand::rngs::OsRng;
+use sha2::Sha256;
 use base64::prelude::*;
+use rand::rngs::OsRng;
+use rand::RngCore;
+use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce}; // Or Aes256GcmSiv, Aes256GcmHs
+use aes_gcm::aead::Aead;
 use rsa::{
-    RsaPrivateKey, RsaPublicKey, Oaep,
+    Oaep, RsaPrivateKey, RsaPublicKey,
     traits::PublicKeyParts,
-    pkcs1::EncodeRsaPublicKey,
+    pkcs8::EncodePublicKey,
 };
 
-/// encrypt manages the process of encrypting long messages using the RSA algorithm and OAEP
-/// padding. It takes a public key and a plaintext message and returns the ciphertext.
-pub fn encrypt(key: RsaPublicKey, plaintext: String) -> String {
-    let mut rng = OsRng;
-    let hash = Sha256::new();
-    let bytes = key.n().bits() / 8;
-    let step = bytes - 2*hash.output_size() - 2;
-    let chunks = plaintext.as_bytes().chunks(step);
-    let mut res = vec![];
-
-    for chunk in chunks {
-        let padding = Oaep::new::<Sha256>();
-        let encrypted = key.encrypt(&mut rng, padding, chunk).unwrap();
-        res.extend(encrypted);
-    }
-
-    BASE64_STANDARD.encode(res)
+mod errors;
+pub use errors::Error;
+
+/// encrypt encrypts plaintext with a randomly-generated AES-256 key and IV, then encrypts the AES
+/// key with RSA-OAEP using the provided public key. The result is a non-URL-safe base64-encoded
+/// string.
+pub fn encrypt(
+    key: RsaPublicKey,
+    plaintext: String
+) -> Result<String, Error> {
+    // Generate a random 32-byte AES key
+    let mut aes_key = [0u8; 32];
+    OsRng.fill_bytes(&mut aes_key);
+
+    // Generate a random 12-byte IV
+    let mut iv = [0u8; 12];
+    OsRng.fill_bytes(&mut iv);
+
+    // Create AES cipher (GCM mode)
+    let aes_cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
+
+    // Encrypt the message
+    let nonce = Nonce::from_slice(&iv); // 12 bytes; unique per message
+    let ciphertext = aes_cipher.encrypt(nonce, plaintext.as_bytes())?;
+
+    // Encrypt the AES key with RSA-OAEP
+    let padding = Oaep::new::<Sha256>();
+    let encrypted_key = key.encrypt(&mut OsRng, padding, &aes_key)?;
+
+    // Combine encrypted key + IV + ciphertext
+    let mut result = Vec::new();
+    result.extend_from_slice(&encrypted_key);
+    result.extend_from_slice(&iv);
+    result.extend_from_slice(&ciphertext);
+
+    // Encode the result as base64
+    Ok(BASE64_STANDARD.encode(&result))
 }
 
-/// decrypt takes a given RSA Private Key and the relevant ciphertext and decrypts it into
-/// plaintext. It's expected that the message was encrypted using OAEP padding and SHA256 digest.
-pub fn decrypt(key: RsaPrivateKey, ciphertext: String) -> String {
-    let decoded = BASE64_STANDARD.decode(ciphertext.as_bytes()).unwrap();
+/// decrypt uses `key` to decrypt an AES-256 key that's prepended to the ciphertext. The decrypted
+/// key is then used to decrypt the suffix of `ciphertext` which contains the relevant message.
+/// It's expected that the message was encrypted using OAEP padding and SHA256 digest.
+pub fn decrypt(key: RsaPrivateKey, ciphertext: String) -> Result<String, Error> {
+    let decoded = BASE64_STANDARD.decode(ciphertext)?;
 
-    let step = key.n().bits() / 8;
-    let chunks: Vec<&[u8]> = decoded.chunks(step).collect();
-    let mut res = vec![];
+    let n = key.size();
+    let (ciphered_key, suffix) = decoded.split_at(n);
 
-    for (_, chunk) in chunks.iter().enumerate() {
-        let padding = Oaep::new::<Sha256>();
-        let decrypted = key.decrypt(padding, chunk).unwrap();
-        res.extend(decrypted);
+    let key = key.decrypt(
+        Oaep::new::<Sha256>(),
+        ciphered_key,
+    )?;
+
+    let aes_key = Key::<Aes256Gcm>::from_slice(&key);
+    let cipher = Aes256Gcm::new(aes_key);
+
+    // Check if the suffix is at least 12 bytes (96 bits) for the IV
+    if suffix.len() < 12 {
+        return Err(Error::InvalidMessage);
     }
 
-    String::from_utf8(res).unwrap()
+    let (iv, ciphertext) = suffix.split_at(12);
+    let nonce = Nonce::from_slice(iv);
+
+    let plaintext = cipher.decrypt(nonce, ciphertext)?;
+    Ok(String::from_utf8(plaintext)?)
 }
 
 /// generate_key_pair creates a new 2048-bit public and private key for use in
@@ -55,22 +89,22 @@ pub fn generate_key_pair() -> (RsaPrivateKey, RsaPublicKey) {
 
 /// serialize_public_key takes an RSA public key and serializes it into a PEM-encoded string.
 pub fn serialize_public_key(key: RsaPublicKey) -> String {
-    key.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF).unwrap()
+    key.to_public_key_pem(rsa::pkcs8::LineEnding::LF).unwrap()
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
     use rand::{distributions::Alphanumeric, Rng};
-    use rsa::pkcs1::DecodeRsaPublicKey;
+    use rsa::pkcs8::DecodePublicKey;
 
     #[test]
     fn test_encrypt_decrypt() {
         let (private_key, public_key) = testutils::crypto::get_test_keys();
         let plaintext = "Hello, World!".to_string();
 
-        let ciphertext = encrypt(public_key, plaintext.clone());
-        let decrypted = decrypt(private_key, ciphertext);
+        let ciphertext = encrypt(public_key, plaintext.clone()).unwrap();
+        let decrypted = decrypt(private_key, ciphertext).unwrap();
 
         assert_eq!(plaintext, decrypted);
     }
@@ -85,8 +119,8 @@ mod test {
             .map(char::from)
             .collect();
 
-        let ciphertext = encrypt(public_key, plaintext.clone());
-        let decrypted = decrypt(private_key, ciphertext);
+        let ciphertext = encrypt(public_key, plaintext.clone()).unwrap();
+        let decrypted = decrypt(private_key, ciphertext).unwrap();
 
         assert_eq!(plaintext, decrypted);
     }
@@ -95,7 +129,7 @@ fn test_serialize_public_key() {
         let (_private_key, public_key) = testutils::crypto::get_test_keys();
         let serialized = serialize_public_key(public_key.clone());
-        let deserialized = RsaPublicKey::from_pkcs1_pem(&serialized).unwrap();
+        let deserialized = RsaPublicKey::from_public_key_pem(&serialized).unwrap();
 
         assert_eq!(public_key, deserialized);
     }
diff --git a/crates/tower-cmd/Cargo.toml b/crates/tower-cmd/Cargo.toml
index 5afbdf9d..9f57eb04 100644
--- a/crates/tower-cmd/Cargo.toml
+++ b/crates/tower-cmd/Cargo.toml
@@ -23,6 +23,7 @@ rsa = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 simple_logger = { workspace = true }
+snafu = { workspace = true }
 spinners = { workspace = true }
 tokio = { workspace = true }
 tokio-util = { workspace = true }
diff --git a/crates/tower-cmd/src/api.rs b/crates/tower-cmd/src/api.rs
index f4b11c42..819b8efe 100644
--- a/crates/tower-cmd/src/api.rs
+++ b/crates/tower-cmd/src/api.rs
@@ -108,6 +108,24 @@ pub async fn export_secrets(config: &Config, env: &str, all: bool, public_key: r
     unwrap_api_response(tower_api::apis::default_api::export_secrets(api_config, params)).await
 }
 
+pub async fn export_catalogs(config: &Config, env: &str, all: bool, public_key: rsa::RsaPublicKey) -> Result> {
+    let api_config = &config.into();
+
+    let params = tower_api::apis::default_api::ExportCatalogsParams {
+        export_catalogs_params: tower_api::models::ExportCatalogsParams {
+            schema: None,
+            all,
+            public_key: crypto::serialize_public_key(public_key),
+            environment: env.to_string(),
+            page: 1,
+            page_size: 100,
+        },
+    };
+
+    unwrap_api_response(tower_api::apis::default_api::export_catalogs(api_config, params)).await
+}
+
+
 pub async fn list_secrets(config: &Config, env: &str, all: bool) -> Result> {
     let api_config = &config.into();
 
@@ -247,6 +265,17 @@ impl ResponseEntity for tower_api::apis::default_api::ExportSecretsSuccess {
     }
 }
 
+impl ResponseEntity for tower_api::apis::default_api::ExportCatalogsSuccess {
+    type Data = tower_api::models::ExportCatalogsResponse;
+
+    fn extract_data(self) -> Option<Self::Data> {
+        match self {
+            Self::Status200(data) => Some(data),
+            Self::UnknownValue(_) => None,
+        }
+    }
+}
+
 impl ResponseEntity for tower_api::apis::default_api::CreateSecretSuccess {
     type Data = tower_api::models::CreateSecretResponse;
 
diff --git a/crates/tower-cmd/src/error.rs b/crates/tower-cmd/src/error.rs
index e69de29b..7ba35b0a 100644
--- a/crates/tower-cmd/src/error.rs
+++ b/crates/tower-cmd/src/error.rs
@@ -0,0 +1,20 @@
+use snafu::prelude::*;
+
+#[derive(Debug, Snafu)]
+pub enum Error {
+    #[snafu(display("fetching catalogs failed"))]
+    FetchingCatalogsFailed,
+
+    #[snafu(display("fetching secrets failed"))]
+    FetchingSecretsFailed,
+
+    #[snafu(display("cryptography error"))]
+    CryptographyError,
+}
+
+impl From<crypto::Error> for Error {
+    fn from(err: crypto::Error) -> Self {
+        log::debug!("cryptography error: {:?}", err);
+        Self::CryptographyError
+    }
+}
diff --git a/crates/tower-cmd/src/lib.rs b/crates/tower-cmd/src/lib.rs
index 3fe628c9..76906b5c 100644
--- a/crates/tower-cmd/src/lib.rs
+++ b/crates/tower-cmd/src/lib.rs
@@ -5,6 +5,7 @@ mod apps;
 mod deploy;
 pub mod output;
 pub mod api;
+pub mod error;
 mod run;
 mod secrets;
 mod session;
@@ -12,6 +13,8 @@ mod teams;
 mod util;
 mod version;
 
+pub use error::Error;
+
 pub struct App {
     session: Option<Session>,
     cmd: Command,
diff --git a/crates/tower-cmd/src/run.rs b/crates/tower-cmd/src/run.rs
index 7e78ca7f..e46b7c12 100644
--- a/crates/tower-cmd/src/run.rs
+++ b/crates/tower-cmd/src/run.rs
@@ -8,6 +8,7 @@ use tower_runtime::{local::LocalApp, App, AppLauncher, OutputReceiver};
 use crate::{
     output,
     api,
+    Error,
 };
 
 pub fn run_cmd() -> Command {
@@ -85,8 +86,22 @@ async fn do_run_local(config: Config, path: PathBuf, mut params: HashMap = AppLauncher::default();
 
     if let Err(err) = launcher
-        .launch(package, env, secrets, params, HashMap::new())
+        .launch(package, env, secrets, params, catalogs)
         .await
     {
         output::runtime_error(err);
@@ -236,26 +251,51 @@ fn get_app_name(cmd: Option<(&str, &ArgMatches)>) -> Option<String> {
 
 /// get_secrets manages the process of getting secrets from the Tower server in a way that can be
 /// used by the local runtime during local app execution.
-async fn get_secrets(config: &Config, env: &str) -> HashMap<String, String> {
+async fn get_secrets(config: &Config, env: &str) -> Result<HashMap<String, String>, Error> {
     let (private_key, public_key) = crypto::generate_key_pair();
 
-    let mut spinner = output::spinner("Getting secrets...");
-
     match api::export_secrets(&config, env, false, public_key).await {
         Ok(res) => {
-            spinner.success();
-            res.secrets
-                .into_iter()
-                .map(|secret| {
-                    let decrypted_value = crypto::decrypt(private_key.clone(), secret.encrypted_value.to_string());
-                    (secret.name, decrypted_value)
-                })
-                .collect()
+            let mut secrets = HashMap::new();
+
+            for secret in res.secrets {
+                // we will decrypt each property and inject it into the vals map.
+                let decrypted_value = crypto::decrypt(private_key.clone(), secret.encrypted_value.to_string())?;
+                secrets.insert(secret.name, decrypted_value);
+            }
+
+            Ok(secrets)
+        },
+        Err(err) => {
+            output::tower_error(err);
+            Err(Error::FetchingSecretsFailed)
+        }
+    }
+}
+
+/// get_catalogs manages the process of exporting catalogs, decrypting their properties, and
+/// preparing them for injection into the environment during app execution
+async fn get_catalogs(config: &Config, env: &str) -> Result<HashMap<String, String>, Error> {
+    let (private_key, public_key) = crypto::generate_key_pair();
+
+    match api::export_catalogs(&config, env, false, public_key).await {
+        Ok(res) => {
+            let mut vals = HashMap::new();
+
+            for catalog in res.catalogs {
+                // we will decrypt each property and inject it into the vals map.
+ for property in catalog.properties { + let decrypted_value = crypto::decrypt(private_key.clone(), property.encrypted_value.to_string())?; + let name = create_pyiceberg_catalog_property_name(&catalog.name, &property.name); + vals.insert(name, decrypted_value); + } + } + + Ok(vals) }, Err(err) => { - spinner.failure(); output::tower_error(err); - HashMap::new() + Err(Error::FetchingCatalogsFailed) } } } @@ -323,3 +363,11 @@ async fn monitor_status(mut app: LocalApp) { } } } + +fn create_pyiceberg_catalog_property_name(catalog_name: &str, property_name: &str) -> String { + let catalog_name = catalog_name.replace('.', "_").replace(':', "_").to_uppercase(); + let property_name = property_name.replace('.', "_").replace(':', "_").to_uppercase(); + + format!("PYICEBERG_CATALOG__{}__{}", catalog_name, property_name) +} + diff --git a/crates/tower-cmd/src/secrets.rs b/crates/tower-cmd/src/secrets.rs index 614a8595..1cfb16ca 100644 --- a/crates/tower-cmd/src/secrets.rs +++ b/crates/tower-cmd/src/secrets.rs @@ -103,7 +103,7 @@ pub async fn do_list(config: Config, args: &ArgMatches) { let decrypted_value = crypto::decrypt( private_key.clone(), secret.encrypted_value.clone(), - ); + ).unwrap(); vec![ secret.name.clone(), @@ -203,7 +203,7 @@ async fn encrypt_and_create_secret( output::die("Failed to parse public key"); }); - let encrypted_value = encrypt(public_key, value.to_string()); + let encrypted_value = encrypt(public_key, value.to_string()).unwrap(); let preview = create_preview(value); api::create_secret(&config, name, environment, &encrypted_value, &preview).await diff --git a/pyproject.toml b/pyproject.toml index 71efe0ab..13a48d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,7 @@ build-backend = "maturin" [project] name = "tower" -version = "0.3.13" - - - - - - +version = "0.3.14" description = "Tower CLI and runtime environment for Tower." @@ -68,8 +62,8 @@ tower = { workspace = true } [dependency-groups] dev = [ - "openapi-python-client>=0.12.1", + "openapi-python-client>=0.12.1", "pytest>=8.3.5", "pytest-httpx>=0.35.0", - "pyiceberg[sql-sqlite]>=0.9.0", + "pyiceberg[sql-sqlite]>=0.9.0", ] diff --git a/src/tower/__init__.py b/src/tower/__init__.py index 129529ba..4dbeca74 100644 --- a/src/tower/__init__.py +++ b/src/tower/__init__.py @@ -13,6 +13,7 @@ from ._client import ( run_app, wait_for_run, + wait_for_runs, ) from ._features import override_get_attr, get_available_features, is_feature_enabled diff --git a/src/tower/_client.py b/src/tower/_client.py index d30c6ada..0d3a0b0b 100644 --- a/src/tower/_client.py +++ b/src/tower/_client.py @@ -1,8 +1,18 @@ import os import time -from typing import Dict, Optional +import httpx +from typing import List, Dict, Optional from ._context import TowerContext +from .exceptions import ( + NotFoundException, + UnauthorizedException, + UnknownException, + UnhandledRunStateException, + RunFailedError, + TimeoutException, +) + from .tower_api_client import AuthenticatedClient from .tower_api_client.api.default import describe_run as describe_run_api from .tower_api_client.api.default import run_app as run_app_api @@ -27,23 +37,9 @@ # app somewhere. 
DEFAULT_TOWER_ENVIRONMENT = "default" - -def _env_client(ctx: TowerContext) -> AuthenticatedClient: - tower_url = ctx.tower_url - - if not tower_url.endswith("/v1"): - if tower_url.endswith("/"): - tower_url += "v1" - else: - tower_url += "/v1" - - return AuthenticatedClient( - verify_ssl=False, - base_url=tower_url, - token=ctx.api_key, - auth_header_name="X-API-Key", - prefix="", - ) +# DEFAULT_NUM_TIMEOUT_RETRIES is the number of times to retry querying the Tower +# API before we just give up entirely. +DEFAULT_NUM_TIMEOUT_RETRIES = 5 def run_app( @@ -52,9 +48,26 @@ def run_app( parameters: Optional[Dict[str, str]] = None, ) -> Run: """ - `run_app` invokes an app based on the configured environment. You can - supply an optional `environment` override, and an optional dict - `parameters` to pass into the app. + Run a Tower application with specified parameters and environment. + + This function initiates a new run of a Tower application identified by its slug. + The run can be configured with an optional environment override and runtime parameters. + If no environment is specified, the default environment from the Tower context is used. + + Args: + slug (str): The unique identifier of the application to run. + environment (Optional[str]): The environment to run the application in. + If not provided, uses the default environment from the Tower context. + parameters (Optional[Dict[str, str]]): A dictionary of key-value pairs + to pass as parameters to the application run. + + Returns: + Run: A Run object containing information about the initiated application run, + including the app_slug and run number. + + Raises: + RuntimeError: If there is an error initiating the run or if the Tower API + returns an error response. """ ctx = TowerContext.build() client = _env_client(ctx) @@ -84,17 +97,234 @@ def run_app( return output.run -def wait_for_run(run: Run) -> None: +def wait_for_run( + run: Run, + timeout: Optional[float] = 86_400.0, # one day + raise_on_failure: bool = False, +) -> Run: """ - `wait_for_run` waits for a run to reach a terminal state by polling the - Tower API every 2 seconds for the latest status. If the app returns a - terminal status (`exited`, `errored`, `cancelled`, or `crashed`) then this - function returns. + Wait for a Tower app run to reach a terminal state by polling the Tower API. + + This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT) + to check the status of the specified run. The function returns when the run reaches + a terminal state (exited, errored, cancelled, or crashed). + + Args: + run (Run): The Run object containing the app_slug and number of the run to monitor. + timeout (Optional[float]): Maximum time to wait in seconds before raising a + TimeoutException. Defaults to one day (86,400 seconds). + raise_on_failure (bool): If True, raises a RunFailedError when the run fails. + If False, returns the failed run object. Defaults to False. + + Returns: + Run: The final state of the run after completion or failure. + + Raises: + TimeoutException: If the specified timeout is reached before the run completes. + RunFailedError: If raise_on_failure is True and the run fails. + UnhandledRunStateException: If the run enters an unexpected state. + UnknownException: If there are persistent problems communicating with the Tower API. + NotFoundException: If the run cannot be found. + UnauthorizedException: If the API key is invalid or unauthorized. 
""" ctx = TowerContext.build() - client = _env_client(ctx) + retries = 0 + + # We use this to track the timeout, if one is defined. + start_time = time.time() while True: + # We check for a timeout at the top of the loop because we want to + # avoid waiting unnecessarily for the timeout hitting the Tower API if + # we've enounctered some sort of operational problem there. + if timeout is not None: + if _time_since(start_time) > timeout: + raise TimeoutException(_time_since(start_time)) + + # We time this out to avoid waiting forever on the API. + try: + desc = _check_run_status(ctx, run, timeout=2.0) + retries = 0 + + if _is_successful_run(desc): + return desc + elif _is_failed_run(desc): + if raise_on_failure: + raise RunFailedError(desc.app_slug, desc.number, desc.status) + else: + return desc + + elif _is_run_awaiting_completion(desc): + time.sleep(WAIT_TIMEOUT) + else: + raise UnhandledRunStateException(desc.status) + except TimeoutException: + # timed out in the API, we want to keep trying this for a while + # (assuming we didn't hit the global timeout limit) until we give + # up entirely. + retries += 1 + + if retries >= DEFAULT_NUM_TIMEOUT_RETRIES: + raise UnknownException("There was a problem with the Tower API.") + + +def wait_for_runs( + runs: List[Run], + timeout: Optional[float] = 86_400.0, # one day + raise_on_failure: bool = False, +) -> tuple[List[Run], List[Run]]: + """ + Wait for multiple Tower app runs to reach terminal states by polling the Tower API. + + This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT) + to check the status of all specified runs. The function returns when all runs reach + terminal states (`exited`, `errored`, `cancelled`, or `crashed`). + + Args: + runs (List[Run]): A list of Run objects to monitor. + timeout (Optional[float]): Maximum time to wait in seconds before raising a + TimeoutException. Defaults to one day (86,400 seconds). + raise_on_failure (bool): If True, raises a RunFailedError when any run fails. + If False, failed runs are returned in the failed_runs list. Defaults to False. + + Returns: + tuple[List[Run], List[Run]]: A tuple containing two lists: + - successful_runs: List of runs that completed successfully (status: 'exited') + - failed_runs: List of runs that failed (status: 'crashed', 'cancelled', or 'errored') + + Raises: + TimeoutException: If the specified timeout is reached before all runs complete. + RunFailedError: If raise_on_failure is True and any run fails. + UnhandledRunStateException: If a run enters an unexpected state. + UnknownException: If there are persistent problems communicating with the Tower API. + NotFoundException: If any run cannot be found. + UnauthorizedException: If the API key is invalid or unauthorized. + """ + ctx = TowerContext.build() + retries = 0 + + # We use this to track the timeout, if one is defined. + start_time = time.time() + + awaiting_runs = runs + successful_runs = [] + failed_runs = [] + + while len(awaiting_runs) > 0: + run = awaiting_runs.pop(0) + + # Check the overall timeout at the top of the loop in case we've + # spent a load of time deeper inside the loop on reties, etc. 
+        if timeout is not None:
+            if _time_since(start_time) > timeout:
+                raise TimeoutException(_time_since(start_time))
+
+        try:
+            desc = _check_run_status(ctx, run, timeout=2.0)
+            retries = 0
+
+            if _is_successful_run(desc):
+                successful_runs.append(desc)
+            elif _is_failed_run(desc):
+                if raise_on_failure:
+                    raise RunFailedError(desc.app_slug, desc.number, desc.status)
+                else:
+                    failed_runs.append(desc)
+
+            elif _is_run_awaiting_completion(desc):
+                time.sleep(WAIT_TIMEOUT)
+
+                # We need to re-add this run to the list so we check it again
+                # in the future. We add it to the back since we took it off the
+                # front, effectively moving to the next run.
+                awaiting_runs.append(desc)
+            else:
+                raise UnhandledRunStateException(desc.status)
+        except TimeoutException:
+            # timed out in the API, we want to keep trying this for a while
+            # (assuming we didn't hit the global timeout limit) until we give
+            # up entirely.
+            retries += 1
+
+            if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
+                raise UnknownException("There was a problem with the Tower API.")
+            else:
+                # Add the item back on the list for retry later on.
+                awaiting_runs.append(run)
+
+    return (successful_runs, failed_runs)
+
+
+def _is_failed_run(run: Run) -> bool:
+    """
+    Check if the given run has failed.
+
+    Args:
+        run (Run): The Run object containing the status to check.
+
+    Returns:
+        bool: True if the run has failed, False otherwise.
+    """
+    return run.status in ["crashed", "cancelled", "errored"]
+
+
+def _is_successful_run(run: Run) -> bool:
+    """
+    Check if a given run was successful.
+
+    Args:
+        run (Run): The Run object containing the status to check.
+
+    Returns:
+        bool: True if the run was successful, False otherwise.
+    """
+    return run.status in ["exited"]
+
+
+def _is_run_awaiting_completion(run: Run) -> bool:
+    """
+    Check if a given run is either running or expected to run in the near future.
+
+    Args:
+        run (Run): The Run object containing the status to check.
+
+    Returns:
+        bool: True if the run is awaiting execution or currently running, False otherwise.
+    """
+    return run.status in ["pending", "scheduled", "running"]
+
+
+def _env_client(ctx: TowerContext, timeout: Optional[float] = None) -> AuthenticatedClient:
+    tower_url = ctx.tower_url
+
+    if not tower_url.endswith("/v1"):
+        if tower_url.endswith("/"):
+            tower_url += "v1"
+        else:
+            tower_url += "/v1"
+
+    return AuthenticatedClient(
+        verify_ssl=False,
+        base_url=tower_url,
+        token=ctx.api_key,
+        auth_header_name="X-API-Key",
+        prefix="",
+        timeout=timeout,
+    )
+
+
+def _time_since(start_time: float) -> float:
+    return time.time() - start_time
+
+
+def _check_run_status(
+    ctx: TowerContext,
+    run: Run,
+    timeout: Optional[float] = 2.0,  # per-request timeout, in seconds
+) -> Run:
+    client = _env_client(ctx, timeout=timeout)
+
+    try:
         output: Optional[Union[DescribeRunResponse, ErrorModel]] = describe_run_api.sync(
             slug=run.app_slug,
             seq=run.number,
@@ -102,20 +332,23 @@
         )
 
         if output is None:
-            raise RuntimeError("Error fetching run")
-        else:
-            if isinstance(output, ErrorModel):
-                raise RuntimeError(f"Error fetching run: {output.title}")
+            raise UnknownException("Failed to fetch run")
+        elif isinstance(output, ErrorModel):
+            # If it was a 404 error, that means that we couldn't find this
+            # app for some reason. This is really only relevant on the
+            # first time that we check--if we could find the run, but then
+            # suddenly couldn't that's a really big problem I'd say.
+ if output.status == 404: + raise NotFoundException(output.detail) + elif output.status == 401: + # NOTE: Most of the time, this shouldn't happen? + raise UnauthorizedException(output.detail) else: - desc = output.run - - if desc.status == "exited": - return - elif desc.status == "failed": - return - elif desc.status == "canceled": - return - elif desc.status == "errored": - return - else: - time.sleep(WAIT_TIMEOUT) + raise UnknownException(output.detail) + else: + # There was a run object, so let's return that. + return output.run + except httpx.TimeoutException: + # If we received a timeout from the API then we should raise our own + # timeout type. + raise TimeoutException(timeout) diff --git a/src/tower/_tables.py b/src/tower/_tables.py index 63d0294c..8c5ea8fd 100644 --- a/src/tower/_tables.py +++ b/src/tower/_tables.py @@ -36,6 +36,30 @@ class Table: """ def __init__(self, context: TowerContext, table: IcebergTable): + """ + Initialize a new Table instance that wraps an Iceberg table. + + This constructor creates a Table object that provides a high-level interface + for interacting with an Iceberg table. It initializes the table statistics + tracking and stores the necessary context and table references. + + Args: + context (TowerContext): The context in which the table operates, providing + configuration and environment settings. + table (IcebergTable): The underlying Iceberg table instance to be wrapped. + + Attributes: + _stats (RowsAffectedInformation): Tracks the number of rows affected by + insert and update operations. Initialized with zero counts. + _context (TowerContext): The context in which the table operates. + _table (IcebergTable): The underlying Iceberg table instance. + + Example: + >>> # Create a table reference and load it + >>> table_ref = tables("my_table") + >>> table = table_ref.load() # This internally calls Table.__init__ + """ + self._stats = RowsAffectedInformation(0, 0) self._context = context self._table = table @@ -43,7 +67,23 @@ def __init__(self, context: TowerContext, table: IcebergTable): def read(self) -> pl.DataFrame: """ - Reads from the Iceberg tables. Returns the results as a Polars DataFrame. + Reads all data from the Iceberg table and returns it as a Polars DataFrame. + + This method executes a full table scan and materializes the results into memory + as a Polars DataFrame. For large tables, consider using `to_polars()` to get a + LazyFrame that can be processed incrementally. + + Returns: + pl.DataFrame: A Polars DataFrame containing all rows from the table. + + Example: + >>> table = tables("my_table").load() + >>> # Read all data into a DataFrame + >>> df = table.read() + >>> # Perform operations on the DataFrame + >>> filtered_df = df.filter(pl.col("age") > 30) + >>> # Get basic statistics + >>> print(df.describe()) """ # We call `collect` here to force the execution of the query and get # the result as a DataFrame. @@ -52,29 +92,88 @@ def read(self) -> pl.DataFrame: def to_polars(self) -> pl.LazyFrame: """ - Converts the table to a Polars LazyFrame. This is useful when you - understand Polars and you want to do something more complicated. - """ + Converts the table to a Polars LazyFrame for efficient, lazy evaluation. + + This method returns a LazyFrame that allows for building complex query plans + without immediately executing them. 
This is particularly useful for: + - Processing large tables that don't fit in memory + - Building complex transformations and aggregations + - Optimizing query performance through Polars' query optimizer + + Returns: + pl.LazyFrame: A Polars LazyFrame representing the table data. + + Example: + >>> table = tables("my_table").load() + >>> # Create a lazy query plan + >>> lazy_df = table.to_polars() + >>> # Build complex transformations + >>> result = (lazy_df + ... .filter(pl.col("age") > 30) + ... .groupby("department") + ... .agg(pl.col("salary").mean()) + ... .sort("department")) + >>> # Execute the plan + >>> final_df = result.collect() + """ return pl.scan_iceberg(self._table) def rows_affected(self) -> RowsAffectedInformation: """ - Returns the stats for the table. This includes the number of inserts, - updates, and deletes. + Returns statistics about the number of rows affected by write operations on the table. + + This method tracks the cumulative number of rows that have been inserted or updated + through operations like `insert()` and `upsert()`. Note that delete operations are + not currently tracked due to limitations in the underlying Iceberg implementation. + + Returns: + RowsAffectedInformation: An object containing: + - inserts (int): Total number of rows inserted + - updates (int): Total number of rows updated + + Example: + >>> table = tables("my_table").load() + >>> # Insert some data + >>> table.insert(new_data) + >>> # Upsert some data + >>> table.upsert(updated_data, join_cols=["id"]) + >>> # Check the impact of our operations + >>> stats = table.rows_affected() + >>> print(f"Inserted {stats.inserts} rows") + >>> print(f"Updated {stats.updates} rows") """ return self._stats def insert(self, data: pa.Table) -> TTable: """ - Inserts data into the Iceberg table. The data is expressed as a PyArrow table. + Inserts new rows into the Iceberg table. + + This method appends the provided data to the table. The data must be provided as a + PyArrow table with a schema that matches the table's schema. The operation is + tracked in the table's statistics, incrementing the insert count. Args: - data (pa.Table): The data to insert into the table. + data (pa.Table): The data to insert into the table. The schema of this table + must match the schema of the target table. Returns: - TTable: The table with the inserted rows. + TTable: The table instance with the newly inserted rows, allowing for method chaining. + + Example: + >>> table = tables("my_table").load() + >>> # Create a PyArrow table with new data + >>> new_data = pa.table({ + ... "id": [1, 2, 3], + ... "name": ["Alice", "Bob", "Charlie"], + ... "age": [25, 30, 35] + ... }) + >>> # Insert the data + >>> table.insert(new_data) + >>> # Verify the insertion + >>> stats = table.rows_affected() + >>> print(f"Inserted {stats.inserts} rows") """ self._table.append(data) self._stats.inserts += data.num_rows @@ -83,14 +182,42 @@ def insert(self, data: pa.Table) -> TTable: def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable: """ - Upserts data into the Iceberg table. The data is expressed as a PyArrow table. + Performs an upsert operation (update or insert) on the Iceberg table. + + This method will: + - Update existing rows if they match the join columns + - Insert new rows if no match is found + All operations are case-sensitive by default. Args: - data (pa.Table): The data to upsert into the table. 
- join_cols (Optional[list[str]]): The columns that form the key to match rows on + data (pa.Table): The data to upsert into the table. The schema of this table + must match the schema of the target table. + join_cols (Optional[list[str]]): The columns that form the key to match rows on. + If not provided, all columns will be used for matching. Returns: - TTable: The table with the upserted rows. + TTable: The table instance with the upserted rows, allowing for method chaining. + + Note: + - The operation is always case-sensitive + - When a match is found, all columns are updated + - When no match is found, the row is inserted + - The operation is tracked in the table's statistics + + Example: + >>> table = tables("my_table").load() + >>> # Create a PyArrow table with data to upsert + >>> data = pa.table({ + ... "id": [1, 2, 3], + ... "name": ["Alice", "Bob", "Charlie"], + ... "age": [26, 31, 36] # Updated ages + ... }) + >>> # Upsert the data using 'id' as the key + >>> table.upsert(data, join_cols=["id"]) + >>> # Verify the operation + >>> stats = table.rows_affected() + >>> print(f"Updated {stats.updates} rows") + >>> print(f"Inserted {stats.inserts} rows") """ res = self._table.upsert( data, @@ -114,17 +241,40 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTabl def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable: """ - Deletes data from the Iceberg table. The filters are expressed as a - PyArrow expression. The filters are applied to the table and the - matching rows are deleted. + Deletes rows from the Iceberg table that match the specified filter conditions. + + This method removes rows from the table based on the provided filter expressions. + The operation is always case-sensitive. Note that the number of deleted rows + cannot be tracked due to limitations in the underlying Iceberg implementation. Args: - filters (Union[str, List[pc.Expression]]): The filters to apply to the table. - This can be a string or a list of PyArrow expressions. + filters (Union[str, List[pc.Expression]]): The filter conditions to apply. + Can be either: + - A single PyArrow compute expression + - A list of PyArrow compute expressions (combined with AND) + - A string expression Returns: - TTable: The table with the deleted rows. + TTable: The table instance with the deleted rows, allowing for method chaining. + + Note: + - The operation is always case-sensitive + - The number of deleted rows cannot be tracked in the table statistics + - To get the number of deleted rows, you would need to compare snapshots + + Example: + >>> table = tables("my_table").load() + >>> # Delete rows where age is greater than 30 + >>> table.delete(table.column("age") > 30) + >>> # Delete rows matching multiple conditions + >>> table.delete([ + ... table.column("age") > 30, + ... table.column("department") == "IT" + ... ]) + >>> # Delete rows using a string expression + >>> table.delete("age > 30 AND department = 'IT'") """ + if isinstance(filters, list): # We need to convert the pc.Expression into PyIceberg next_filters = convert_pyarrow_expressions(filters) @@ -144,15 +294,45 @@ def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable: def schema(self) -> pa.Schema: - # We take an Iceberg Schema and we need to convert it into a PyArrow Schema + """ + Returns the schema of the table as a PyArrow schema. + + This method converts the underlying Iceberg table schema into a PyArrow schema, + which can be used for type information and schema validation. 
+ + Returns: + pa.Schema: The PyArrow schema representation of the table's structure. + Example: + >>> table = tables("my_table").load() + >>> schema = table.schema() + """ iceberg_schema = self._table.schema() return iceberg_schema.as_arrow() def column(self, name: str) -> pa.compute.Expression: """ - Returns a column from the table. This is useful when you want to - perform some operations on the column. + Returns a column from the table as a PyArrow compute expression. + + This method is useful for creating column-based expressions that can be used in + operations like filtering, sorting, or aggregating data. The returned expression + can be used with PyArrow's compute functions. + + Args: + name (str): The name of the column to retrieve from the table schema. + + Returns: + pa.compute.Expression: A PyArrow compute expression representing the column. + + Raises: + ValueError: If the specified column name is not found in the table schema. + + Example: + >>> table = tables("my_table").load() + >>> # Create a filter expression for rows where age > 30 + >>> age_expr = table.column("age") > 30 + >>> # Use the expression in a delete operation + >>> table.delete(age_expr) """ field = self.schema().field(name) @@ -172,6 +352,26 @@ def __init__(self, ctx: TowerContext, catalog: Catalog, name: str, namespace: Op def load(self) -> Table: + """ + Loads an existing Iceberg table from the catalog. + + This method resolves the table's namespace and name, then loads the table + from the catalog. If the table doesn't exist, this will raise an error. + Use `create()` or `create_if_not_exists()` to create new tables. + + Returns: + Table: A new Table instance wrapping the loaded Iceberg table. + + Raises: + TableNotFoundError: If the table doesn't exist in the catalog. + + Example: + >>> # Load the existing table + >>> table = tables("my_table", namespace="my_namespace").load() + >>> # Now you can use the table + >>> df = table.read() + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) table = self._catalog.load_table(table_name) @@ -179,6 +379,40 @@ def load(self) -> Table: def create(self, schema: pa.Schema) -> Table: + + """ + Creates a new Iceberg table with the specified schema. + + This method will: + 1. Resolve the table's namespace (using default if not specified) + 2. Create the namespace if it doesn't exist + 3. Create a new table with the provided schema + 4. Return a Table instance for the newly created table + + Args: + schema (pa.Schema): The PyArrow schema defining the structure of the table. + This will be converted to an Iceberg schema internally. + + Returns: + Table: A new Table instance wrapping the created Iceberg table. + + Raises: + TableAlreadyExistsError: If a table with the same name already exists in the namespace. + NamespaceError: If there are issues creating or accessing the namespace. + + Example: + >>> # Define the table schema + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()), + ... pa.field("age", pa.int32()) + ... 
]) + >>> # Create the table + >>> table = tables("my_table", namespace="my_namespace").create(schema) + >>> # Now you can use the table + >>> table.insert(new_data) + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) @@ -198,6 +432,43 @@ def create(self, schema: pa.Schema) -> Table: def create_if_not_exists(self, schema: pa.Schema) -> Table: + """ + Creates a new Iceberg table if it doesn't exist, or returns the existing table. + + This method will: + 1. Resolve the table's namespace (using default if not specified) + 2. Create the namespace if it doesn't exist + 3. Create a new table with the provided schema if it doesn't exist + 4. Return the existing table if it already exists + 5. Return a Table instance for the table + + Unlike `create()`, this method will not raise an error if the table already exists. + Instead, it will return the existing table, making it safe for idempotent operations. + + Args: + schema (pa.Schema): The PyArrow schema defining the structure of the table. + This will be converted to an Iceberg schema internally. Note that this + schema is only used if the table needs to be created. + + Returns: + Table: A Table instance wrapping either the newly created or existing Iceberg table. + + Raises: + NamespaceError: If there are issues creating or accessing the namespace. + + Example: + >>> # Define the table schema + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()), + ... pa.field("age", pa.int32()) + ... ]) + >>> # Create the table if it doesn't exist + >>> table = tables("my_table", namespace="my_namespace").create_if_not_exists(schema) + >>> # This is safe to call multiple times + >>> table = tables("my_table", namespace="my_namespace").create_if_not_exists(schema) + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) @@ -222,18 +493,50 @@ def tables( namespace: Optional[str] = None ) -> TableReference: """ - `tables` creates a reference to an Iceberg table with the name `name` from - the catalog with name `catalog_name`. + Creates a reference to an Iceberg table that can be used to load or create tables. + + This function is the main entry point for working with Iceberg tables in Tower. It returns + a TableReference object that can be used to either load an existing table or create a new one. + The actual table operations (read, write, etc.) are performed through the Table instance + obtained by calling `load()` or `create()` on the returned reference. Args: - `name` (str): The name of the table to load. - `catalog` (Union[str, Catalog]): The name of the catalog or the actual - catalog to use. "default" is the default value. You can pass in an - actual catalog object for testing purposes. - `namespace` (Optional[str]): The namespace in which to load the table. + name (str): The name of the table to reference. This will be used to either load + an existing table or create a new one. + catalog (Union[str, Catalog], optional): The catalog to use. Can be either: + - A string name of the catalog (defaults to "default") + - A Catalog instance (useful for testing or custom catalog implementations) + Defaults to "default". + namespace (Optional[str], optional): The namespace in which the table exists or + should be created. If not provided, a default namespace will be used. 
Returns: - TableReference: A reference to a table to be resolved with `create` or `load` + TableReference: A reference object that can be used to: + - Load an existing table using `load()` + - Create a new table using `create()` + - Create a table if it doesn't exist using `create_if_not_exists()` + + Raises: + CatalogError: If there are issues accessing or loading the specified catalog. + TableNotFoundError: When trying to load a non-existent table (only if `load()` is called). + + Examples: + >>> # Load an existing table from the default catalog + >>> table = tables("my_table").load() + >>> df = table.read() + + >>> # Create a new table in a specific namespace + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()) + ... ]) + >>> table = tables("new_table", namespace="my_namespace").create(schema) + + >>> # Use a specific catalog + >>> table = tables("my_table", catalog="my_catalog").load() + + >>> # Create a table if it doesn't exist + >>> table = tables("my_table").create_if_not_exists(schema) """ if isinstance(catalog, str): catalog = load_catalog(catalog) diff --git a/src/tower/exceptions.py b/src/tower/exceptions.py new file mode 100644 index 00000000..e2a7295f --- /dev/null +++ b/src/tower/exceptions.py @@ -0,0 +1,29 @@ +class NotFoundException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnauthorizedException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnknownException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnhandledRunStateException(Exception): + def __init__(self, state: str): + message = f"Run state '{state}' was unexpected. Maybe you need to upgrade to the latest Tower SDK." + super().__init__(message) + + +class TimeoutException(Exception): + def __init__(self, time: float): + super().__init__(f"A timeout occurred after {time} seconds.") + + +class RunFailedError(RuntimeError): + def __init__(self, app_name: str, number: int, state: str): + super().__init__(f"Run {app_name}#{number} failed with status '{state}'") diff --git a/src/tower/utils/pyarrow.py b/src/tower/utils/pyarrow.py index a0ad414c..9b8c1111 100644 --- a/src/tower/utils/pyarrow.py +++ b/src/tower/utils/pyarrow.py @@ -3,93 +3,202 @@ import pyarrow as pa import pyarrow.compute as pc -import pyiceberg.types as types +from pyiceberg import types as iceberg_types from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.expressions import ( BooleanExpression, - And, Or, Not, - EqualTo, NotEqualTo, - GreaterThan, GreaterThanOrEqual, - LessThan, LessThanOrEqual, - Literal, Reference + And, + Or, + Not, + EqualTo, + NotEqualTo, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual, + Reference, ) -def arrow_to_iceberg_type(arrow_type): + +class FieldIdManager: + """ + Manages the assignment of unique field IDs. + Field IDs in Iceberg start from 1. + """ + + def __init__(self, start_id=1): + # Initialize current_id to start_id - 1 so the first call to get_next_id() returns start_id + self.current_id = start_id - 1 + + def get_next_id(self) -> int: + """Returns the next available unique field ID.""" + self.current_id += 1 + return self.current_id + + +def arrow_to_iceberg_type_recursive( + arrow_type: pa.DataType, field_id_manager: FieldIdManager +) -> iceberg_types.IcebergType: """ - Convert a PyArrow type to a PyIceberg type. Special thanks to Claude for - the help on this. 
+ Recursively convert a PyArrow DataType to a PyIceberg type, + managing field IDs for nested structures. """ - - if pa.types.is_boolean(arrow_type): - return types.BooleanType() + # Primitive type mappings (most remain the same) + if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + return iceberg_types.StringType() elif pa.types.is_integer(arrow_type): - # Check the bit width to determine the appropriate Iceberg integer type - bit_width = arrow_type.bit_width - if bit_width <= 32: - return types.IntegerType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.IntegerType() else: - return types.LongType() + return iceberg_types.LongType() elif pa.types.is_floating(arrow_type): - if arrow_type.bit_width == 32: - return types.FloatType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.FloatType() else: - return types.DoubleType() - elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): - return types.StringType() - elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): - return types.BinaryType() + return iceberg_types.DoubleType() + elif pa.types.is_boolean(arrow_type): + return iceberg_types.BooleanType() elif pa.types.is_date(arrow_type): - return types.DateType() - elif pa.types.is_timestamp(arrow_type): - return types.TimestampType() + return iceberg_types.DateType() elif pa.types.is_time(arrow_type): - return types.TimeType() + return iceberg_types.TimeType() + elif pa.types.is_timestamp(arrow_type): + if arrow_type.tz is not None: # type: ignore + return iceberg_types.TimestamptzType() + else: + return iceberg_types.TimestampType() + elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): + return iceberg_types.BinaryType() + elif pa.types.is_fixed_size_binary(arrow_type): + return iceberg_types.FixedType(length=arrow_type.byte_width) # type: ignore elif pa.types.is_decimal(arrow_type): - precision = arrow_type.precision - scale = arrow_type.scale - return types.DecimalType(precision, scale) - elif pa.types.is_list(arrow_type): - element_type = arrow_to_iceberg_type(arrow_type.value_type) - return types.ListType(element_type) + return iceberg_types.DecimalType(arrow_type.precision, arrow_type.scale) # type: ignore + + # Nested type mappings + elif ( + pa.types.is_list(arrow_type) + or pa.types.is_large_list(arrow_type) + or pa.types.is_fixed_size_list(arrow_type) + ): + # The element field itself in Iceberg needs an ID. + element_id = field_id_manager.get_next_id() + + # Recursively convert the list's element type. + # arrow_type.value_type is the DataType of the elements. + # arrow_type.value_field is the Field of the elements (contains name, type, nullability). + element_pyarrow_type = arrow_type.value_type # type: ignore + element_iceberg_type = arrow_to_iceberg_type_recursive( + element_pyarrow_type, field_id_manager + ) + + # Determine if the elements themselves are required (not nullable). + element_is_required = not arrow_type.value_field.nullable # type: ignore + + return iceberg_types.ListType( + element_id=element_id, + element_type=element_iceberg_type, + element_required=element_is_required, + ) elif pa.types.is_struct(arrow_type): - fields = [] - for i, field in enumerate(arrow_type): - name = field.name - field_type = arrow_to_iceberg_type(field.type) - fields.append(types.NestedField(i + 1, name, field_type, required=not field.nullable)) - return types.StructType(*fields) + struct_iceberg_fields = [] + # arrow_type is a StructType. 
Iterate through its fields. + for i in range(arrow_type.num_fields): # type: ignore + pyarrow_child_field = arrow_type.field(i) # This is a pyarrow.Field + + # Each field within the struct needs its own unique ID. + nested_field_id = field_id_manager.get_next_id() + nested_iceberg_type = arrow_to_iceberg_type_recursive( + pyarrow_child_field.type, field_id_manager + ) + + doc = None + if pyarrow_child_field.metadata and b"doc" in pyarrow_child_field.metadata: + doc = pyarrow_child_field.metadata[b"doc"].decode("utf-8") + + struct_iceberg_fields.append( + iceberg_types.NestedField( + field_id=nested_field_id, + name=pyarrow_child_field.name, + field_type=nested_iceberg_type, + required=not pyarrow_child_field.nullable, + doc=doc, + ) + ) + return iceberg_types.StructType(*struct_iceberg_fields) elif pa.types.is_map(arrow_type): - key_type = arrow_to_iceberg_type(arrow_type.key_type) - value_type = arrow_to_iceberg_type(arrow_type.item_type) - return types.MapType(key_type, value_type) + # Iceberg MapType requires IDs for key and value fields. + key_id = field_id_manager.get_next_id() + value_id = field_id_manager.get_next_id() + + key_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.key_type, field_id_manager + ) # type: ignore + value_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.item_type, field_id_manager + ) # type: ignore + + # PyArrow map keys are always non-nullable by Arrow specification. + # Nullability of map values comes from the item_field. + value_is_required = not arrow_type.item_field.nullable # type: ignore + + return iceberg_types.MapType( + key_id=key_id, + key_type=key_iceberg_type, + value_id=value_id, + value_type=value_iceberg_type, + value_required=value_is_required, + ) else: raise ValueError(f"Unsupported Arrow type: {arrow_type}") -def convert_pyarrow_field(num, field) -> types.NestedField: - name = field.name - field_type = arrow_to_iceberg_type(field.type) - field_id = num + 1 # Iceberg requires field IDs +def convert_pyarrow_schema( + arrow_schema: pa.Schema, schema_id: int = 1, start_field_id: int = 1 +) -> IcebergSchema: + """ + Convert a PyArrow schema to a PyIceberg schema. + + Args: + arrow_schema: The input PyArrow.Schema. + schema_id: The schema ID for the Iceberg schema. + start_field_id: The starting ID for field ID assignment. + Returns: + An IcebergSchema object. + """ + field_id_manager = FieldIdManager(start_id=start_field_id) + iceberg_fields = [] + + for pyarrow_field in arrow_schema: # pyarrow_field is a pa.Field object + # Assign a unique ID for this top-level field. + top_level_field_id = field_id_manager.get_next_id() - return types.NestedField( - field_id, - name, - field_type, - required=not field.nullable - ) + # Recursively convert the field's type. This will handle ID assignment + # for any nested structures using the same field_id_manager. 
+        iceberg_field_type = arrow_to_iceberg_type_recursive(
+            pyarrow_field.type, field_id_manager
+        )
+        doc = None
+        if pyarrow_field.metadata and b"doc" in pyarrow_field.metadata:
+            doc = pyarrow_field.metadata[b"doc"].decode("utf-8")

-def convert_pyarrow_schema(arrow_schema: pa.Schema) -> IcebergSchema:
-    """Convert a PyArrow schema to a PyIceberg schema."""
-    fields = [convert_pyarrow_field(i, field) for i, field in enumerate(arrow_schema)]
-    return IcebergSchema(*fields)
+        iceberg_fields.append(
+            iceberg_types.NestedField(
+                field_id=top_level_field_id,
+                name=pyarrow_field.name,
+                field_type=iceberg_field_type,
+                required=not pyarrow_field.nullable,  # Top-level field nullability
+                doc=doc,
+            )
+        )
+    return IcebergSchema(*iceberg_fields, schema_id=schema_id)


 def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]:
     """Extract field name and literal value from a comparison expression."""
     # First, convert the expression to a string and parse it
     expr_str = str(expr)
-
+
     # PyArrow expression strings look like: "(field_name == literal)" or similar
     # Need to determine the operator and then split accordingly
     operators = ["==", "!=", ">", ">=", "<", "<="]
@@ -98,36 +207,38 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]:
         if op in expr_str:
             op_used = op
             break
-
+
     if not op_used:
-        raise ValueError(f"Could not find comparison operator in expression: {expr_str}")
-
+        raise ValueError(
+            f"Could not find comparison operator in expression: {expr_str}"
+        )
+
     # Remove parentheses and split by operator
     expr_clean = expr_str.strip("()")
     parts = expr_clean.split(op_used)
     if len(parts) != 2:
         raise ValueError(f"Expected binary comparison in expression: {expr_str}")
-
+
     # Determine which part is the field and which is the literal
     field_name = None
     literal_value = None
-
+
     # Clean up the parts
     left = parts[0].strip()
     right = parts[1].strip()
-
+
     # Typically field name doesn't have quotes, literals (strings) do
     if left.startswith('"') or left.startswith("'"):
         # Right side is the field
         field_name = right
         # Extract the literal value - this is a simplification
-        literal_value = left.strip('"\'')
+        literal_value = left.strip("\"'")
    else:
         # Left side is the field
         field_name = left
         # Extract the literal value - this is a simplification
-        literal_value = right.strip('"\'')
-
+        literal_value = right.strip("\"'")
+
     # Try to convert numeric literals
     try:
         if "."
in literal_value: @@ -137,17 +248,18 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: except ValueError: # Keep as string if not numeric pass - + return field_name, literal_value + def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpression]: """Convert a PyArrow compute expression to a PyIceberg boolean expression.""" if expr is None: return None - + # Handle the expression based on its string representation expr_str = str(expr) - + # Handle logical operations if "and" in expr_str.lower() and isinstance(expr, pc.Expression): # This is a simplification - in real code, you'd need to parse the expression @@ -156,7 +268,7 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return And( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "or" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification @@ -164,13 +276,13 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return Or( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "not" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification inner_expr = None # You'd need to extract this return Not(convert_pyarrow_expression(inner_expr)) - + # Handle comparison operations try: if "==" in expr_str: @@ -204,13 +316,13 @@ def convert_pyarrow_expressions(exprs: List[pc.Expression]) -> BooleanExpression """ if not exprs: raise ValueError("No expressions provided") - + if len(exprs) == 1: return convert_pyarrow_expression(exprs[0]) - + # Combine multiple expressions with AND result = convert_pyarrow_expression(exprs[0]) for expr in exprs[1:]: result = And(result, convert_pyarrow_expression(expr)) - + return result diff --git a/tests/tower/test_client.py b/tests/tower/test_client.py index e9a75c70..55174841 100644 --- a/tests/tower/test_client.py +++ b/tests/tower/test_client.py @@ -1,121 +1,324 @@ - import os -import httpx import pytest +from datetime import datetime +from typing import List, Dict, Any, Optional -from tower.tower_api_client.models import ( - Run, -) +from tower.tower_api_client.models import Run +from tower.exceptions import RunFailedError -def test_running_apps(httpx_mock): - # Mock the response from the API - httpx_mock.add_response( - method="POST", - url="https://api.example.com/v1/apps/my-app/runs", - json={ + +@pytest.fixture +def mock_api_config(): + """Configure the Tower API client to use mock server.""" + os.environ["TOWER_URL"] = "https://api.example.com" + os.environ["TOWER_API_KEY"] = "abc123" + + # Only import after environment is configured + import tower + # Set WAIT_TIMEOUT to 0 to avoid actual waiting in tests + tower._client.WAIT_TIMEOUT = 0 + + return tower + + +@pytest.fixture +def mock_run_response_factory(): + """Factory to create consistent run response objects.""" + def _create_run_response( + app_slug: str = "my-app", + app_version: str = "v6", + number: int = 0, + run_id: str = "50ac9bc1-c783-4359-9917-a706f20dc02c", + status: str = "pending", + status_group: str = "", + parameters: Optional[List[Dict[str, Any]]] = None + ) -> Dict[str, Any]: + """Create a mock run response with the given parameters.""" + if parameters is None: + parameters = [] + + return { "run": { - "app_slug": "my-app", - 
"app_version": "v6", + "app_slug": app_slug, + "app_version": app_version, "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 0, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", + "created_at": "2025-04-25T20:54:58.762547Z", + "ended_at": "2025-04-25T20:55:35.220295Z", + "environment": "default", + "number": number, + "run_id": run_id, "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "pending", - "status_group": "", - "parameters": [] + "started_at": "2025-04-25T20:54:59.366937Z", + "status": status, + "status_group": status_group, + "parameters": parameters } - }, + } + + return _create_run_response + + +@pytest.fixture +def create_run_object(): + """Factory to create Run objects for testing.""" + def _create_run( + app_slug: str = "my-app", + app_version: str = "v6", + number: int = 0, + run_id: str = "50ac9bc1-c783-4359-9917-a706f20dc02c", + status: str = "running", + status_group: str = "failed", + parameters: Optional[List[Dict[str, Any]]] = None + ) -> Run: + """Create a Run object with the given parameters.""" + if parameters is None: + parameters = [] + + return Run( + app_slug=app_slug, + app_version=app_version, + cancelled_at=None, + created_at="2025-04-25T20:54:58.762547Z", + ended_at="2025-04-25T20:55:35.220295Z", + environment="default", + number=number, + run_id=run_id, + scheduled_at="2025-04-25T20:54:58.761867Z", + started_at="2025-04-25T20:54:59.366937Z", + status=status, + status_group=status_group, + parameters=parameters + ) + + return _create_run + + +def test_running_apps(httpx_mock, mock_api_config, mock_run_response_factory): + # Mock the response from the API + httpx_mock.add_response( + method="POST", + url="https://api.example.com/v1/apps/my-app/runs", + json=mock_run_response_factory(), status_code=200, ) - # We tell the client to use the mock server. - os.environ["TOWER_URL"] = "https://api.example.com" - os.environ["TOWER_API_KEY"] = "abc123" - # Call the function that makes the API request - import tower + tower = mock_api_config run: Run = tower.run_app("my-app", environment="production") # Assert the response assert run is not None + assert run.app_slug == "my-app" + assert run.status == "pending" -def test_waiting_for_runs(httpx_mock): - # Mock the response from the API + +def test_waiting_for_a_run(httpx_mock, mock_api_config, mock_run_response_factory, create_run_object): + run_number = 3 + + # First response: pending status httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "pending", - "status_group": "", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="pending"), status_code=200, ) - # Second request, will indicate that it's done. 
+ # Second response: completed status httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "exited", - "status_group": "successful", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="exited", status_group="successful"), status_code=200, ) - # We tell the client to use the mock server. - os.environ["TOWER_URL"] = "https://api.example.com" - os.environ["TOWER_API_KEY"] = "abc123" + tower = mock_api_config + run = create_run_object(number=run_number, status="pending") - import tower + # Now actually wait for the run + final_run = tower.wait_for_run(run) + + # Verify the final state + assert final_run.status == "exited" + assert final_run.status_group == "successful" - run = Run( - app_slug="my-app", - app_version="v6", - cancelled_at=None, - created_at="2025-04-25T20:54:58.762547Z", - ended_at="2025-04-25T20:55:35.220295Z", - environment="default", - number=3, - run_id="50ac9bc1-c783-4359-9917-a706f20dc02c", - scheduled_at="2025-04-25T20:54:58.761867Z", - started_at="2025-04-25T20:54:59.366937Z", - status="crashed", - status_group="failed", - parameters=[] + +@pytest.mark.parametrize("run_numbers", [(3, 4)]) +def test_waiting_for_multiple_runs( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object, + run_numbers +): + tower = mock_api_config + runs = [] + + # Setup mocks for each run + for run_number in run_numbers: + # First response: pending status + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="pending"), + status_code=200, + ) + + # Second response: completed status + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="exited", status_group="successful"), + status_code=200, + ) + + # Create the Run object + runs.append(create_run_object(number=run_number)) + + # Now actually wait for the runs + successful_runs, failed_runs = tower.wait_for_runs(runs) + + assert len(failed_runs) == 0 + + # Verify all runs completed successfully + for run in successful_runs: + assert run.status == "exited" + assert run.status_group == "successful" + + +def test_failed_runs_in_the_list( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object +): + tower = mock_api_config + runs = [] + + # For the first run, we're going to simulate a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="pending"), + status_code=200, ) - # Set WAIT_TIMEOUT to 0 so we don't have to...wait. - tower._client.WAIT_TIMEOUT = 0 + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=1)) + + # Second run will have been a failure. 
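+    # As with the first run, each mocked run gets a "pending" response followed
+    # by a terminal one, because the client keeps polling until the run is done.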
+ httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="pending"), + status_code=200, + ) + + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="crashed", status_group="failed"), + status_code=200, + ) + + runs.append(create_run_object(number=2)) + + # Third run was a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/3", + json=mock_run_response_factory(number=3, status="pending"), + status_code=200, + ) + + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/3", + json=mock_run_response_factory(number=3, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=3)) + + + # Now actually wait for the runs + successful_runs, failed_runs = tower.wait_for_runs(runs) + + assert len(failed_runs) == 1 + + # Verify all successful runs + for run in successful_runs: + assert run.status == "exited" + assert run.status_group == "successful" + + # Verify all failed + for run in failed_runs: + assert run.status == "crashed" + assert run.status_group == "failed" + + +def test_raising_an_error_during_partial_failure( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object +): + tower = mock_api_config + runs = [] + + # For the first run, we're going to simulate a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="pending"), + status_code=200, + ) + + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=1)) + + # Second run will have been a failure. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="pending"), + status_code=200, + ) + + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="crashed", status_group="failed"), + status_code=200, + ) + + runs.append(create_run_object(number=2)) + + # Third run was a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/3", + json=mock_run_response_factory(number=3, status="pending"), + status_code=200, + ) + + # NOTE: We don't have a second response for this run because we'll never + # get to it. + + runs.append(create_run_object(number=3)) + - # Now actually wait for the run. 
- tower.wait_for_run(run) + # Now actually wait for the runs + with pytest.raises(RunFailedError) as excinfo: + tower.wait_for_runs(runs, raise_on_failure=True) diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py index 3ef0b027..bf044b25 100644 --- a/tests/tower/test_tables.py +++ b/tests/tower/test_tables.py @@ -14,13 +14,14 @@ # Imports the library under test import tower + def get_temp_dir(): """Create a temporary directory and return its file:// URL.""" # Create a temporary directory that will be automatically cleaned up temp_dir = tempfile.TemporaryDirectory() abs_path = pathlib.Path(temp_dir.name).absolute() - file_url = urljoin('file:', pathname2url(str(abs_path))) - + file_url = urljoin("file:", pathname2url(str(abs_path))) + # Return both the URL and the path to the temporary directory return file_url, abs_path @@ -38,21 +39,41 @@ def in_memory_catalog(): def test_reading_and_writing_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # If we write some data to the table, that should be...OK. table = table.insert(data_with_schema) @@ -66,24 +87,48 @@ def test_reading_and_writing_to_tables(in_memory_catalog): avg_age = df.select(pl.mean("age").alias("mean_age")).collect().item() assert avg_age == 30.0 + def test_upsert_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. 
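     # The "username" column doubles as the join key for the upsert further down.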
ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -91,13 +136,22 @@ def test_upsert_to_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Now we'll update records in the table. - data_with_schema = pa.Table.from_pylist([ - {"id": 2, "username": "bobb", "name": "Bob", "age": 26, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 26, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + ], + schema=schema, + ) # And make sure we can upsert the data. table = table.upsert(data_with_schema, join_cols=["username"]) - assert table.rows_affected().updates == 1 + assert table.rows_affected().updates == 1 # Now let's read from the table and see what we get back out. df = table.to_polars() @@ -107,24 +161,48 @@ def test_upsert_to_tables(in_memory_catalog): # The age should match what we updated the relevant record to assert res["age"].item() == 26 + def test_delete_from_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. 
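     # Three rows go in; the delete below filters on "username" and should leave two.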
ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -132,9 +210,7 @@ def test_delete_from_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Perform the underlying delete from the table... - table.delete(filters=[ - table.column("username") == "bobb" - ]) + table.delete(filters=[table.column("username") == "bobb"]) # ...and let's make sure that record is actually gone. df = table.to_polars() @@ -143,14 +219,17 @@ def test_delete_from_tables(in_memory_catalog): all_rows = df.collect() assert all_rows.height == 2 + def test_getting_schemas_for_tables(in_memory_catalog): - original_schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + original_schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) @@ -163,3 +242,317 @@ def test_getting_schemas_for_tables(in_memory_catalog): assert new_schema.field("id") is not None assert new_schema.field("age") is not None assert new_schema.field("created_at") is not None + + +def test_list_of_structs(in_memory_catalog): + """Tests writing and reading a list of non-nullable structs.""" + table_name = "test_list_of_structs_table" + # Define a pyarrow schema with a list of structs + # The list 'tags' can be null, but its elements (structs) are not nullable. + # Inside the struct, 'key' is non-nullable, 'value' is nullable. + item_struct_type = pa.struct( + [ + pa.field("key", pa.string(), nullable=False), # Non-nullable key + pa.field("value", pa.int64(), nullable=True), # Nullable value + ] + ) + # The 'item' field represents the elements of the list. It's non-nullable. + # This means each element in the list must be a valid struct, not a Python None. 
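+    # Given the converter above, this schema maps to an Iceberg list whose
+    # element is a required struct (element_required=True) containing a
+    # required "key" and an optional "value".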
+    pa_schema = pa.schema(
+        [
+            pa.field("doc_id", pa.int32(), nullable=False),
+            pa.field(
+                "tags",
+                pa.list_(pa.field("item", item_struct_type, nullable=False)),
+                nullable=True,
+            ),
+        ]
+    )
+
+    ref = tower.tables(table_name, catalog=in_memory_catalog)
+    table = ref.create_if_not_exists(pa_schema)
+    assert table is not None, f"Table '{table_name}' should have been created"
+
+    data_to_write = [
+        {
+            "doc_id": 1,
+            "tags": [{"key": "user", "value": 100}, {"key": "priority", "value": 1}],
+        },
+        {
+            "doc_id": 2,
+            "tags": [
+                {"key": "source", "value": 200},
+                {"key": "reviewed", "value": None},
+            ],
+        },  # Null value for a struct field
+        {"doc_id": 3, "tags": []},  # Empty list
+        {"doc_id": 4, "tags": None},  # Null list
+    ]
+    arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema)
+
+    op_result = table.insert(arrow_table_write)
+    assert op_result is not None
+    assert op_result.rows_affected().inserts == 4
+
+    # Read back and verify
+    df_read = table.to_polars().collect()  # Collect to get a Polars DataFrame
+
+    assert df_read.shape[0] == 4
+    assert df_read["doc_id"].to_list() == [1, 2, 3, 4]
+
+    # Verify nested data (Polars handles structs and lists well)
+    # For doc_id = 1
+    tags_doc1 = (
+        df_read.filter(pl.col("doc_id") == 1).select("tags").row(0)[0]
+    )  # Get the list of structs
+    assert len(tags_doc1) == 2
+    assert tags_doc1[0]["key"] == "user"
+    assert tags_doc1[0]["value"] == 100
+    assert tags_doc1[1]["key"] == "priority"
+    assert tags_doc1[1]["value"] == 1
+
+    # For doc_id = 2 (with a null inside a struct)
+    tags_doc2 = df_read.filter(pl.col("doc_id") == 2).select("tags").row(0)[0]
+    assert len(tags_doc2) == 2
+    assert tags_doc2[0]["key"] == "source"
+    assert tags_doc2[0]["value"] == 200
+    assert tags_doc2[1]["key"] == "reviewed"
+    assert tags_doc2[1]["value"] is None
+
+    # For doc_id = 3 (empty list)
+    tags_doc3 = df_read.filter(pl.col("doc_id") == 3).select("tags").row(0)[0]
+    assert len(tags_doc3) == 0
+
+    # For doc_id = 4 (a null list is read back as an empty list)
+    tags_doc4 = df_read.filter(pl.col("doc_id") == 4).select("tags").row(0)[0]
+    assert tags_doc4 == []
+
+
+def test_nested_structs(in_memory_catalog):
+    """Tests writing and reading a table with nested structs."""
+    table_name = "test_nested_structs_table"
+    # Define a pyarrow schema with nested structs:
+    # config: struct<name: string, details: struct<retries: int8, timeout: int32, active: bool>>
+    settings_struct_type = pa.struct(
+        [
+            pa.field("retries", pa.int8(), nullable=False),
+            pa.field("timeout", pa.int32(), nullable=True),
+            pa.field("active", pa.bool_(), nullable=False),
+        ]
+    )
+    pa_schema = pa.schema(
+        [
+            pa.field("record_id", pa.string(), nullable=False),
+            pa.field(
+                "config",
+                pa.struct(
+                    [
+                        pa.field("name", pa.string(), nullable=True),
+                        pa.field(
+                            "details", settings_struct_type, nullable=True
+                        ),  # This inner struct can be null
+                    ]
+                ),
+                nullable=True,
+            ),  # The outer 'config' struct can also be null
+        ]
+    )
+
+    ref = tower.tables(table_name, catalog=in_memory_catalog)
+    table = ref.create_if_not_exists(pa_schema)
+    assert table is not None, f"Table '{table_name}' should have been created"
+
+    data_to_write = [
+        {
+            "record_id": "rec1",
+            "config": {
+                "name": "Default",
+                "details": {"retries": 3, "timeout": 1000, "active": True},
+            },
+        },
+        {
+            "record_id": "rec2",
+            "config": {
+                "name": "Fast",
+                "details": {"retries": 1, "timeout": None, "active": True},
+            },
+        },  # Null timeout
+        {
+            "record_id": "rec3",
+            "config": {"name": "Inactive", "details": None},
+        },  # Null inner struct
+        {"record_id": "rec4", "config": None},  # Null outer struct
+    ]
+    arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema)
+
+    op_result = table.insert(arrow_table_write)
+    assert op_result is not None
+    assert op_result.rows_affected().inserts == 4
+
+    # Read back and verify
+    df_read = table.to_polars().collect()
+
+    assert df_read.shape[0] == 4
+    assert df_read["record_id"].to_list() == ["rec1", "rec2", "rec3", "rec4"]
+
+    # Verify nested data for rec1
+    config_rec1 = (
+        df_read.filter(pl.col("record_id") == "rec1").select("config").row(0)[0]
+    )
+    assert config_rec1["name"] == "Default"
+    details_rec1 = config_rec1["details"]
+    assert details_rec1["retries"] == 3
+    assert details_rec1["timeout"] == 1000
+    assert details_rec1["active"] is True
+
+    # Verify nested data for rec2 (null timeout)
+    config_rec2 = (
+        df_read.filter(pl.col("record_id") == "rec2").select("config").row(0)[0]
+    )
+    assert config_rec2["name"] == "Fast"
+    details_rec2 = config_rec2["details"]
+    assert details_rec2["retries"] == 1
+    assert details_rec2["timeout"] is None
+    assert details_rec2["active"] is True
+
+    # Verify nested data for rec3 (null inner struct 'details')
+    config_rec3 = (
+        df_read.filter(pl.col("record_id") == "rec3").select("config").row(0)[0]
+    )
+    assert config_rec3["name"] == "Inactive"
+    assert config_rec3["details"] is None  # The 'details' struct itself is null
+
+    # Verify nested data for rec4 (null outer struct 'config')
+    config_rec4 = (
+        df_read.filter(pl.col("record_id") == "rec4").select("config").row(0)[0]
+    )
+    assert config_rec4 is None  # The 'config' struct is null
+
+
+def test_list_of_primitive_types(in_memory_catalog):
+    """Tests writing and reading a list of primitive types."""
+    table_name = "test_list_of_primitives_table"
+    pa_schema = pa.schema(
+        [
+            pa.field("event_id", pa.int32(), nullable=False),
+            pa.field(
+                "scores",
+                pa.list_(pa.field("score", pa.float32(), nullable=False)),
+                nullable=True,
+            ),  # List of non-nullable floats
+            pa.field(
+                "keywords",
+                pa.list_(pa.field("keyword", pa.string(), nullable=True)),
+                nullable=True,
+            ),  # List of nullable strings
+        ]
+    )
+
+    ref = tower.tables(table_name, catalog=in_memory_catalog)
+    table = ref.create_if_not_exists(pa_schema)
+    assert table is not None, f"Table '{table_name}' should have been created"
+
+    data_to_write = [
+        {"event_id": 1, "scores": [1.0, 2.5, 3.0], "keywords": ["alpha", "beta", None]},
+        {"event_id": 2, "scores": [], "keywords": ["gamma"]},
+        {"event_id": 3, "scores": None, "keywords": None},
+        {"event_id": 4, "scores": [4.2], "keywords": []},
+    ]
+    arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema)
+
+    op_result = table.insert(arrow_table_write)
+    assert op_result is not None
+    assert op_result.rows_affected().inserts == 4
+
+    df_read = table.to_polars().collect()
+
+    assert df_read.shape[0] == 4
+
+    # Event 1
+    row1 = df_read.filter(pl.col("event_id") == 1)
+    assert row1.select("scores").to_series()[0].to_list() == [1.0, 2.5, 3.0]
+    assert row1.select("keywords").to_series()[0].to_list() == ["alpha", "beta", None]
+
+    # Event 2
+    row2 = df_read.filter(pl.col("event_id") == 2)
+    assert row2.select("scores").to_series()[0].to_list() == []
+    assert row2.select("keywords").to_series()[0].to_list() == ["gamma"]
+
+    # Event 3
+    row3 = df_read.filter(pl.col("event_id") == 3)
+    assert row3.select("scores").to_series()[0] is None
+    assert row3.select("keywords").to_series()[0] is None
+
+
+def test_map_type_simple(in_memory_catalog):
+    """Tests writing and reading a simple map type."""
+    table_name = "test_map_type_simple_table"
+    # Map from string to string. Keys are non-nullable, values can be nullable.
+    pa_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field(
+                "properties",
+                pa.map_(pa.string(), pa.string(), keys_sorted=False),
+                nullable=True,
+            ),
+            # Note: PyArrow map values are nullable by default unless the item
+            # field is explicitly declared with nullable=False.
+        ]
+    )
+
+    ref = tower.tables(table_name, catalog=in_memory_catalog)
+    table = ref.create_if_not_exists(pa_schema)
+    assert table is not None, f"Table '{table_name}' should have been created"
+
+    # PyArrow represents maps as a list of structs with 'key' and 'value' fields
+    data_to_write = [
+        {"id": 1, "properties": [("color", "blue"), ("size", "large")]},
+        {
+            "id": 2,
+            "properties": [("status", "pending"), ("owner", None)],
+        },  # Null value in map
+        {"id": 3, "properties": []},  # Empty map
+        {"id": 4, "properties": None},  # Null map field
+    ]
+    arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema)
+
+    op_result = table.insert(arrow_table_write)
+    assert op_result is not None
+    assert op_result.rows_affected().inserts == 4
+
+    df_read = table.to_polars().collect()
+    assert df_read.shape[0] == 4
+
+    # Verify map data
+    # Polars represents a map as a list of structs: struct<key, value>
+    # Row 1
+    props1_series = df_read.filter(pl.col("id") == 1).select("properties").to_series()
+    # The series item is already a list of dictionaries
+    props1_list = props1_series[0]
+    expected_props1 = [
+        {"key": "color", "value": "blue"},
+        {"key": "size", "value": "large"},
+    ]
+    # Sort by key for consistent comparison if order is not guaranteed
+    assert sorted(props1_list, key=lambda x: x["key"]) == sorted(
+        expected_props1, key=lambda x: x["key"]
+    )
+
+    # Row 2
+    props2_series = df_read.filter(pl.col("id") == 2).select("properties").to_series()
+    props2_list = props2_series[0]
+    expected_props2 = [
+        {"key": "status", "value": "pending"},
+        {"key": "owner", "value": None},
+    ]
+    assert sorted(props2_list, key=lambda x: x["key"]) == sorted(
+        expected_props2, key=lambda x: x["key"]
+    )
+
+    # Row 3 (empty map)
+    props3_series = df_read.filter(pl.col("id") == 3).select("properties").to_series()
+    assert props3_series[0].to_list() == []
+
+    # Row 4 (null map)
+    props4_series = df_read.filter(pl.col("id") == 4).select("properties").to_series()
+    assert props4_series[0] is None
diff --git a/uv.lock b/uv.lock
index e8c5a210..5953f46c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1176,7 +1176,7 @@ wheels = [
 
 [[package]]
 name = "tower"
-version = "0.3.13"
+version = "0.3.14"
 source = { editable = "." }
 dependencies = [
     { name = "attrs" },