From 3218f1a7f86745bb2d74cc3df6536e8ff3602d68 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Fri, 9 May 2025 13:57:45 +0200 Subject: [PATCH 01/15] chore: Fix encryption, inject catalogs to local runtime enviro --- Cargo.lock | 105 ++++++++++++++++++++++++++++++ Cargo.toml | 4 +- crates/crypto/Cargo.toml | 2 + crates/crypto/src/errors.rs | 40 ++++++++++++ crates/crypto/src/lib.rs | 110 +++++++++++++++++++++----------- crates/tower-cmd/Cargo.toml | 1 + crates/tower-cmd/src/api.rs | 29 +++++++++ crates/tower-cmd/src/error.rs | 20 ++++++ crates/tower-cmd/src/lib.rs | 3 + crates/tower-cmd/src/run.rs | 80 ++++++++++++++++++----- crates/tower-cmd/src/secrets.rs | 4 +- 11 files changed, 339 insertions(+), 59 deletions(-) create mode 100644 crates/crypto/src/errors.rs diff --git a/Cargo.lock b/Cargo.lock index 03039ce9..9b864c92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,41 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -224,6 +259,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.36" @@ -410,11 +455,13 @@ dependencies = [ name = "crypto" version = "0.3.13" dependencies = [ + "aes-gcm", "base64", "pem", "rand", "rsa", "sha2", + "snafu", "testutils", ] @@ -425,6 +472,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] @@ -449,6 +497,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.20.11" @@ -836,6 +893,16 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.31.1" @@ -1223,6 +1290,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1573,6 +1649,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl" version = "0.10.72" @@ -1720,6 +1802,18 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.11.0" @@ -2748,6 +2842,7 @@ dependencies = [ "serde", "serde_json", "simple_logger", + "snafu", "spinners", "tokio", "tokio-util", @@ -2873,6 +2968,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index 0894f68b..66a2529f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,9 +5,6 @@ resolver = "2" [workspace.package] edition = "2021" version = "0.3.13" - - - description = "Tower is the best way to host Python data apps in production" rust-version = "1.77" authors = ["Brad Heller "] @@ -15,6 +12,7 @@ license = "MIT" repository = "https://github.com/tower/tower-cli" [workspace.dependencies] +aes-gcm = "0.10" anyhow = "1.0.95" async-compression = { version = "0.4", features = ["tokio", "gzip"] } base64 = "0.22" diff --git a/crates/crypto/Cargo.toml b/crates/crypto/Cargo.toml index dffea22f..87977f89 100644 --- a/crates/crypto/Cargo.toml +++ b/crates/crypto/Cargo.toml @@ -4,9 +4,11 @@ version = { workspace = true } edition = "2021" [dependencies] +aes-gcm = { workspace = true } base64 = { workspace = true } pem = { workspace = true } rand = { workspace = true } rsa = { workspace = true } sha2 = { workspace = true } +snafu = { workspace = true } testutils = { workspace = true } diff --git a/crates/crypto/src/errors.rs b/crates/crypto/src/errors.rs new file mode 100644 index 00000000..5edb1796 --- /dev/null +++ b/crates/crypto/src/errors.rs @@ -0,0 +1,40 @@ +use snafu::prelude::*; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("invalid message"))] + InvalidMessage, + + #[snafu(display("invalid encoding"))] + InvalidEncoding, + + #[snafu(display("cryptography error"))] + CryptographyError, + + #[snafu(display("base64 error"))] + Base64Error, +} + +impl From for Error { + fn from(_error: std::string::FromUtf8Error) -> Self { + Self::InvalidEncoding + } +} + +impl From for Error { + fn from(_error: aes_gcm::Error) -> Self { + Self::CryptographyError + } +} + +impl From for Error { + fn from(_error: rsa::Error) -> Self { + Self::CryptographyError + 
} +} + +impl From<base64::DecodeError> for Error { + fn from(_error: base64::DecodeError) -> Self { + Self::Base64Error + } +} diff --git a/crates/crypto/src/lib.rs b/crates/crypto/src/lib.rs index 98a38b0d..e891c2fe 100644 --- a/crates/crypto/src/lib.rs +++ b/crates/crypto/src/lib.rs @@ -1,47 +1,81 @@ -use sha2::{Sha256, Digest, digest::DynDigest}; -use rand::rngs::OsRng; +use sha2::Sha256; use base64::prelude::*; +use rand::rngs::OsRng; +use rand::RngCore; +use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce}; // Or Aes256GcmSiv, Aes256GcmHs +use aes_gcm::aead::Aead; use rsa::{ - RsaPrivateKey, RsaPublicKey, Oaep, + Oaep, RsaPrivateKey, RsaPublicKey, traits::PublicKeyParts, - pkcs1::EncodeRsaPublicKey, + pkcs8::EncodePublicKey, }; -/// encrypt manages the process of encrypting long messages using the RSA algorithm and OAEP -/// padding. It takes a public key and a plaintext message and returns the ciphertext. -pub fn encrypt(key: RsaPublicKey, plaintext: String) -> String { - let mut rng = OsRng; - let hash = Sha256::new(); - let bytes = key.n().bits() / 8; - let step = bytes - 2*hash.output_size() - 2; - let chunks = plaintext.as_bytes().chunks(step); - let mut res = vec![]; - - for chunk in chunks { - let padding = Oaep::new::<Sha256>(); - let encrypted = key.encrypt(&mut rng, padding, chunk).unwrap(); - res.extend(encrypted); - } - - BASE64_STANDARD.encode(res) +mod errors; +pub use errors::Error; + +/// encrypt encrypts plaintext with a randomly-generated AES-256 key and IV, then encrypts the AES +/// key with RSA-OAEP using the provided public key. The result is a non-URL-safe base64-encoded +/// string. +pub fn encrypt( + key: RsaPublicKey, + plaintext: String +) -> Result<String, Error> { + // Generate a random 32-byte AES key + let mut aes_key = [0u8; 32]; + OsRng.fill_bytes(&mut aes_key); + + // Generate a random 12-byte IV + let mut iv = [0u8; 12]; + OsRng.fill_bytes(&mut iv); + + // Create AES cipher (GCM mode) + let aes_cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key)); + + // Encrypt the message + let nonce = Nonce::from_slice(&iv); // 12 bytes; unique per message + let ciphertext = aes_cipher.encrypt(nonce, plaintext.as_bytes())?; + + // Encrypt the AES key with RSA-OAEP + let padding = Oaep::new::<Sha256>(); + let encrypted_key = key.encrypt(&mut OsRng, padding, &aes_key)?; + + // Combine encrypted key + IV + ciphertext + let mut result = Vec::new(); + result.extend_from_slice(&encrypted_key); + result.extend_from_slice(&iv); + result.extend_from_slice(&ciphertext); + + // Encode the result as base64 + Ok(BASE64_STANDARD.encode(&result)) } -/// decrypt takes a given RSA Private Key and the relevant ciphertext and decrypts it into -/// plaintext. It's expected that the message was encrypted using OAEP padding and SHA256 digest. -pub fn decrypt(key: RsaPrivateKey, ciphertext: String) -> String { - let decoded = BASE64_STANDARD.decode(ciphertext.as_bytes()).unwrap(); +/// decrypt uses `key` to decrypt an AES-256 key that's prepended to the ciphertext. The decrypted +/// key is then used to decrypt the suffix of `ciphertext` which contains the relevant message. +/// It's expected that the message was encrypted using OAEP padding and SHA256 digest.
+pub fn decrypt(key: RsaPrivateKey, ciphertext: String) -> Result<String, Error> { + let decoded = BASE64_STANDARD.decode(ciphertext)?; - let step = key.n().bits() / 8; - let chunks: Vec<&[u8]> = decoded.chunks(step).collect(); - let mut res = vec![]; + let n = key.size(); + let (ciphered_key, suffix) = decoded.split_at(n); - for (_, chunk) in chunks.iter().enumerate() { - let padding = Oaep::new::<Sha256>(); - let decrypted = key.decrypt(padding, chunk).unwrap(); - res.extend(decrypted); + let key = key.decrypt( + Oaep::new::<Sha256>(), + ciphered_key, + )?; + + let aes_key = Key::<Aes256Gcm>::from_slice(&key); + let cipher = Aes256Gcm::new(aes_key); + + // Check if the suffix is at least 12 bytes (96 bits) for the IV + if suffix.len() < 12 { + return Err(Error::InvalidMessage); } - String::from_utf8(res).unwrap() + let (iv, ciphertext) = suffix.split_at(12); + let nonce = Nonce::from_slice(iv); + + let plaintext = cipher.decrypt(nonce, ciphertext)?; + Ok(String::from_utf8(plaintext)?) } /// generate_key_pair creates a new 2048-bit public and private key for use in @@ -55,7 +89,7 @@ pub fn generate_key_pair() -> (RsaPrivateKey, RsaPublicKey) { /// serialize_public_key takes an RSA public key and serializes it into a PEM-encoded string. pub fn serialize_public_key(key: RsaPublicKey) -> String { - key.to_pkcs1_pem(rsa::pkcs1::LineEnding::LF).unwrap() + key.to_public_key_pem(rsa::pkcs8::LineEnding::LF).unwrap() } #[cfg(test)] @@ -69,8 +103,8 @@ mod test { let (private_key, public_key) = testutils::crypto::get_test_keys(); let plaintext = "Hello, World!".to_string(); - let ciphertext = encrypt(public_key, plaintext.clone()); - let decrypted = decrypt(private_key, ciphertext); + let ciphertext = encrypt(public_key, plaintext.clone()).unwrap(); + let decrypted = decrypt(private_key, ciphertext).unwrap(); assert_eq!(plaintext, decrypted); } @@ -85,8 +119,8 @@ mod test { .map(char::from) .collect(); - let ciphertext = encrypt(public_key, plaintext.clone()); - let decrypted = decrypt(private_key, ciphertext); + let ciphertext = encrypt(public_key, plaintext.clone()).unwrap(); + let decrypted = decrypt(private_key, ciphertext).unwrap(); assert_eq!(plaintext, decrypted); } diff --git a/crates/tower-cmd/Cargo.toml b/crates/tower-cmd/Cargo.toml index 5afbdf9d..9f57eb04 100644 --- a/crates/tower-cmd/Cargo.toml +++ b/crates/tower-cmd/Cargo.toml @@ -23,6 +23,7 @@ rsa = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } simple_logger = { workspace = true } +snafu = { workspace = true } spinners = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } diff --git a/crates/tower-cmd/src/api.rs b/crates/tower-cmd/src/api.rs index f4b11c42..819b8efe 100644 --- a/crates/tower-cmd/src/api.rs +++ b/crates/tower-cmd/src/api.rs @@ -108,6 +108,24 @@ pub async fn export_secrets(config: &Config, env: &str, all: bool, public_key: r unwrap_api_response(tower_api::apis::default_api::export_secrets(api_config, params)).await } +pub async fn export_catalogs(config: &Config, env: &str, all: bool, public_key: rsa::RsaPublicKey) -> Result<tower_api::models::ExportCatalogsResponse> { + let api_config = &config.into(); + + let params = tower_api::apis::default_api::ExportCatalogsParams { + export_catalogs_params: tower_api::models::ExportCatalogsParams { + schema: None, + all, + public_key: crypto::serialize_public_key(public_key), + environment: env.to_string(), + page: 1, + page_size: 100, + }, + }; + + unwrap_api_response(tower_api::apis::default_api::export_catalogs(api_config, params)).await +} + + pub async fn list_secrets(config: &Config, env: 
&str, all: bool) -> Result<tower_api::models::ListSecretsResponse> { let api_config = &config.into(); @@ -247,6 +265,17 @@ impl ResponseEntity for tower_api::apis::default_api::ExportSecretsSuccess { } } +impl ResponseEntity for tower_api::apis::default_api::ExportCatalogsSuccess { + type Data = tower_api::models::ExportCatalogsResponse; + + fn extract_data(self) -> Option<Self::Data> { + match self { + Self::Status200(data) => Some(data), + Self::UnknownValue(_) => None, + } + } +} + impl ResponseEntity for tower_api::apis::default_api::CreateSecretSuccess { type Data = tower_api::models::CreateSecretResponse; diff --git a/crates/tower-cmd/src/error.rs b/crates/tower-cmd/src/error.rs index e69de29b..7ba35b0a 100644 --- a/crates/tower-cmd/src/error.rs +++ b/crates/tower-cmd/src/error.rs @@ -0,0 +1,20 @@ +use snafu::prelude::*; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("fetching catalogs failed"))] + FetchingCatalogsFailed, + + #[snafu(display("fetching secrets failed"))] + FetchingSecretsFailed, + + #[snafu(display("cryptography error"))] + CryptographyError, +} + +impl From<crypto::Error> for Error { + fn from(err: crypto::Error) -> Self { + log::debug!("cryptography error: {:?}", err); + Self::CryptographyError + } +} diff --git a/crates/tower-cmd/src/lib.rs b/crates/tower-cmd/src/lib.rs index 3fe628c9..76906b5c 100644 --- a/crates/tower-cmd/src/lib.rs +++ b/crates/tower-cmd/src/lib.rs @@ -5,6 +5,7 @@ mod apps; mod deploy; pub mod output; pub mod api; +pub mod error; mod run; mod secrets; mod session; @@ -12,6 +13,8 @@ mod teams; mod util; mod version; +pub use error::Error; + pub struct App { session: Option<Session>, cmd: Command, diff --git a/crates/tower-cmd/src/run.rs b/crates/tower-cmd/src/run.rs index 7e78ca7f..e46b7c12 100644 --- a/crates/tower-cmd/src/run.rs +++ b/crates/tower-cmd/src/run.rs @@ -8,6 +8,7 @@ use tower_runtime::{local::LocalApp, App, AppLauncher, OutputReceiver}; use crate::{ output, api, + Error, }; pub fn run_cmd() -> Command { @@ -85,8 +86,22 @@ async fn do_run_local(config: Config, path: PathBuf, mut params: HashMap = AppLauncher::default(); if let Err(err) = launcher - .launch(package, env, secrets, params, HashMap::new()) + .launch(package, env, secrets, params, catalogs) .await { output::runtime_error(err); @@ -236,26 +251,51 @@ fn get_app_name(cmd: Option<(&str, &ArgMatches)>) -> Option<String> { /// get_secrets manages the process of getting secrets from the Tower server in a way that can be /// used by the local runtime during local app execution. -async fn get_secrets(config: &Config, env: &str) -> HashMap<String, String> { +async fn get_secrets(config: &Config, env: &str) -> Result<HashMap<String, String>, Error> { let (private_key, public_key) = crypto::generate_key_pair(); - let mut spinner = output::spinner("Getting secrets..."); - match api::export_secrets(&config, env, false, public_key).await { Ok(res) => { - spinner.success(); - res.secrets - .into_iter() - .map(|secret| { - let decrypted_value = crypto::decrypt(private_key.clone(), secret.encrypted_value.to_string()); - (secret.name, decrypted_value) - }) - .collect() + let mut secrets = HashMap::new(); + + for secret in res.secrets { + // we will decrypt each secret and inject it into the secrets map.
+ let decrypted_value = crypto::decrypt(private_key.clone(), secret.encrypted_value.to_string())?; + secrets.insert(secret.name, decrypted_value); + } + + Ok(secrets) + }, + Err(err) => { + output::tower_error(err); + Err(Error::FetchingSecretsFailed) + } + } +} + +/// get_catalogs manages the process of exporting catalogs, decrypting their properties, and +/// preparing them for injection into the environment during app execution. +async fn get_catalogs(config: &Config, env: &str) -> Result<HashMap<String, String>, Error> { + let (private_key, public_key) = crypto::generate_key_pair(); + + match api::export_catalogs(&config, env, false, public_key).await { + Ok(res) => { + let mut vals = HashMap::new(); + + for catalog in res.catalogs { + // we will decrypt each property and inject it into the vals map. + for property in catalog.properties { + let decrypted_value = crypto::decrypt(private_key.clone(), property.encrypted_value.to_string())?; + let name = create_pyiceberg_catalog_property_name(&catalog.name, &property.name); + vals.insert(name, decrypted_value); + } + } + + Ok(vals) }, Err(err) => { - spinner.failure(); output::tower_error(err); - HashMap::new() + Err(Error::FetchingCatalogsFailed) } } } @@ -323,3 +363,11 @@ async fn monitor_status(mut app: LocalApp) { } } } + +fn create_pyiceberg_catalog_property_name(catalog_name: &str, property_name: &str) -> String { + let catalog_name = catalog_name.replace('.', "_").replace(':', "_").to_uppercase(); + let property_name = property_name.replace('.', "_").replace(':', "_").to_uppercase(); + + format!("PYICEBERG_CATALOG__{}__{}", catalog_name, property_name) +} + diff --git a/crates/tower-cmd/src/secrets.rs b/crates/tower-cmd/src/secrets.rs index 614a8595..1cfb16ca 100644 --- a/crates/tower-cmd/src/secrets.rs +++ b/crates/tower-cmd/src/secrets.rs @@ -103,7 +103,7 @@ pub async fn do_list(config: Config, args: &ArgMatches) { let decrypted_value = crypto::decrypt( private_key.clone(), secret.encrypted_value.clone(), - ); + ).unwrap(); vec![ secret.name.clone(), @@ -203,7 +203,7 @@ async fn encrypt_and_create_secret( output::die("Failed to parse public key"); }); - let encrypted_value = encrypt(public_key, value.to_string()); + let encrypted_value = encrypt(public_key, value.to_string()).unwrap(); let preview = create_preview(value); api::create_secret(&config, name, environment, &encrypted_value, &preview).await From 40a44b739ddc9aa6e179d918734aa2a06013b82c Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Fri, 9 May 2025 16:10:51 +0200 Subject: [PATCH 02/15] chore: Fix broken test --- crates/crypto/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/crypto/src/lib.rs b/crates/crypto/src/lib.rs index e891c2fe..92bc1fac 100644 --- a/crates/crypto/src/lib.rs +++ b/crates/crypto/src/lib.rs @@ -96,7 +96,7 @@ pub fn serialize_public_key(key: RsaPublicKey) -> String { mod test { use super::*; use rand::{distributions::Alphanumeric, Rng}; - use rsa::pkcs1::DecodeRsaPublicKey; + use rsa::pkcs8::DecodePublicKey; #[test] fn test_encrypt_decrypt() { @@ -129,7 +129,7 @@ fn test_serialize_public_key() { let (_private_key, public_key) = testutils::crypto::get_test_keys(); let serialized = serialize_public_key(public_key.clone()); - let deserialized = RsaPublicKey::from_pkcs1_pem(&serialized).unwrap(); + let deserialized = RsaPublicKey::from_public_key_pem(&serialized).unwrap(); assert_eq!(public_key, deserialized); } From 025a8535b8cd2fc89b9db817b04794df6cadec3c Mon Sep 17 00:00:00 2001 From: Serhii Sokolenko Date: Fri, 9 May 2025 
22:46:34 +0200 Subject: [PATCH 03/15] Updated docstrings in _tables --- src/tower/_tables.py | 363 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 333 insertions(+), 30 deletions(-) diff --git a/src/tower/_tables.py b/src/tower/_tables.py index 63d0294c..8c5ea8fd 100644 --- a/src/tower/_tables.py +++ b/src/tower/_tables.py @@ -36,6 +36,30 @@ class Table: """ def __init__(self, context: TowerContext, table: IcebergTable): + """ + Initialize a new Table instance that wraps an Iceberg table. + + This constructor creates a Table object that provides a high-level interface + for interacting with an Iceberg table. It initializes the table statistics + tracking and stores the necessary context and table references. + + Args: + context (TowerContext): The context in which the table operates, providing + configuration and environment settings. + table (IcebergTable): The underlying Iceberg table instance to be wrapped. + + Attributes: + _stats (RowsAffectedInformation): Tracks the number of rows affected by + insert and update operations. Initialized with zero counts. + _context (TowerContext): The context in which the table operates. + _table (IcebergTable): The underlying Iceberg table instance. + + Example: + >>> # Create a table reference and load it + >>> table_ref = tables("my_table") + >>> table = table_ref.load() # This internally calls Table.__init__ + """ + self._stats = RowsAffectedInformation(0, 0) self._context = context self._table = table @@ -43,7 +67,23 @@ def __init__(self, context: TowerContext, table: IcebergTable): def read(self) -> pl.DataFrame: """ - Reads from the Iceberg tables. Returns the results as a Polars DataFrame. + Reads all data from the Iceberg table and returns it as a Polars DataFrame. + + This method executes a full table scan and materializes the results into memory + as a Polars DataFrame. For large tables, consider using `to_polars()` to get a + LazyFrame that can be processed incrementally. + + Returns: + pl.DataFrame: A Polars DataFrame containing all rows from the table. + + Example: + >>> table = tables("my_table").load() + >>> # Read all data into a DataFrame + >>> df = table.read() + >>> # Perform operations on the DataFrame + >>> filtered_df = df.filter(pl.col("age") > 30) + >>> # Get basic statistics + >>> print(df.describe()) """ # We call `collect` here to force the execution of the query and get # the result as a DataFrame. @@ -52,29 +92,88 @@ def read(self) -> pl.DataFrame: def to_polars(self) -> pl.LazyFrame: """ - Converts the table to a Polars LazyFrame. This is useful when you - understand Polars and you want to do something more complicated. - """ + Converts the table to a Polars LazyFrame for efficient, lazy evaluation. + + This method returns a LazyFrame that allows for building complex query plans + without immediately executing them. This is particularly useful for: + - Processing large tables that don't fit in memory + - Building complex transformations and aggregations + - Optimizing query performance through Polars' query optimizer + + Returns: + pl.LazyFrame: A Polars LazyFrame representing the table data. + + Example: + >>> table = tables("my_table").load() + >>> # Create a lazy query plan + >>> lazy_df = table.to_polars() + >>> # Build complex transformations + >>> result = (lazy_df + ... .filter(pl.col("age") > 30) + ... .groupby("department") + ... .agg(pl.col("salary").mean()) + ... 
.sort("department")) + >>> # Execute the plan + >>> final_df = result.collect() + """ return pl.scan_iceberg(self._table) def rows_affected(self) -> RowsAffectedInformation: """ - Returns the stats for the table. This includes the number of inserts, - updates, and deletes. + Returns statistics about the number of rows affected by write operations on the table. + + This method tracks the cumulative number of rows that have been inserted or updated + through operations like `insert()` and `upsert()`. Note that delete operations are + not currently tracked due to limitations in the underlying Iceberg implementation. + + Returns: + RowsAffectedInformation: An object containing: + - inserts (int): Total number of rows inserted + - updates (int): Total number of rows updated + + Example: + >>> table = tables("my_table").load() + >>> # Insert some data + >>> table.insert(new_data) + >>> # Upsert some data + >>> table.upsert(updated_data, join_cols=["id"]) + >>> # Check the impact of our operations + >>> stats = table.rows_affected() + >>> print(f"Inserted {stats.inserts} rows") + >>> print(f"Updated {stats.updates} rows") """ return self._stats def insert(self, data: pa.Table) -> TTable: """ - Inserts data into the Iceberg table. The data is expressed as a PyArrow table. + Inserts new rows into the Iceberg table. + + This method appends the provided data to the table. The data must be provided as a + PyArrow table with a schema that matches the table's schema. The operation is + tracked in the table's statistics, incrementing the insert count. Args: - data (pa.Table): The data to insert into the table. + data (pa.Table): The data to insert into the table. The schema of this table + must match the schema of the target table. Returns: - TTable: The table with the inserted rows. + TTable: The table instance with the newly inserted rows, allowing for method chaining. + + Example: + >>> table = tables("my_table").load() + >>> # Create a PyArrow table with new data + >>> new_data = pa.table({ + ... "id": [1, 2, 3], + ... "name": ["Alice", "Bob", "Charlie"], + ... "age": [25, 30, 35] + ... }) + >>> # Insert the data + >>> table.insert(new_data) + >>> # Verify the insertion + >>> stats = table.rows_affected() + >>> print(f"Inserted {stats.inserts} rows") """ self._table.append(data) self._stats.inserts += data.num_rows @@ -83,14 +182,42 @@ def insert(self, data: pa.Table) -> TTable: def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable: """ - Upserts data into the Iceberg table. The data is expressed as a PyArrow table. + Performs an upsert operation (update or insert) on the Iceberg table. + + This method will: + - Update existing rows if they match the join columns + - Insert new rows if no match is found + All operations are case-sensitive by default. Args: - data (pa.Table): The data to upsert into the table. - join_cols (Optional[list[str]]): The columns that form the key to match rows on + data (pa.Table): The data to upsert into the table. The schema of this table + must match the schema of the target table. + join_cols (Optional[list[str]]): The columns that form the key to match rows on. + If not provided, all columns will be used for matching. Returns: - TTable: The table with the upserted rows. + TTable: The table instance with the upserted rows, allowing for method chaining. 
+ + Note: + - The operation is always case-sensitive + - When a match is found, all columns are updated + - When no match is found, the row is inserted + - The operation is tracked in the table's statistics + + Example: + >>> table = tables("my_table").load() + >>> # Create a PyArrow table with data to upsert + >>> data = pa.table({ + ... "id": [1, 2, 3], + ... "name": ["Alice", "Bob", "Charlie"], + ... "age": [26, 31, 36] # Updated ages + ... }) + >>> # Upsert the data using 'id' as the key + >>> table.upsert(data, join_cols=["id"]) + >>> # Verify the operation + >>> stats = table.rows_affected() + >>> print(f"Updated {stats.updates} rows") + >>> print(f"Inserted {stats.inserts} rows") """ res = self._table.upsert( data, @@ -114,17 +241,40 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTabl def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable: """ - Deletes data from the Iceberg table. The filters are expressed as a - PyArrow expression. The filters are applied to the table and the - matching rows are deleted. + Deletes rows from the Iceberg table that match the specified filter conditions. + + This method removes rows from the table based on the provided filter expressions. + The operation is always case-sensitive. Note that the number of deleted rows + cannot be tracked due to limitations in the underlying Iceberg implementation. Args: - filters (Union[str, List[pc.Expression]]): The filters to apply to the table. - This can be a string or a list of PyArrow expressions. + filters (Union[str, List[pc.Expression]]): The filter conditions to apply. + Can be either: + - A single PyArrow compute expression + - A list of PyArrow compute expressions (combined with AND) + - A string expression Returns: - TTable: The table with the deleted rows. + TTable: The table instance with the deleted rows, allowing for method chaining. + + Note: + - The operation is always case-sensitive + - The number of deleted rows cannot be tracked in the table statistics + - To get the number of deleted rows, you would need to compare snapshots + + Example: + >>> table = tables("my_table").load() + >>> # Delete rows where age is greater than 30 + >>> table.delete(table.column("age") > 30) + >>> # Delete rows matching multiple conditions + >>> table.delete([ + ... table.column("age") > 30, + ... table.column("department") == "IT" + ... ]) + >>> # Delete rows using a string expression + >>> table.delete("age > 30 AND department = 'IT'") """ + if isinstance(filters, list): # We need to convert the pc.Expression into PyIceberg next_filters = convert_pyarrow_expressions(filters) @@ -144,15 +294,45 @@ def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable: def schema(self) -> pa.Schema: - # We take an Iceberg Schema and we need to convert it into a PyArrow Schema + """ + Returns the schema of the table as a PyArrow schema. + + This method converts the underlying Iceberg table schema into a PyArrow schema, + which can be used for type information and schema validation. + + Returns: + pa.Schema: The PyArrow schema representation of the table's structure. + Example: + >>> table = tables("my_table").load() + >>> schema = table.schema() + """ iceberg_schema = self._table.schema() return iceberg_schema.as_arrow() def column(self, name: str) -> pa.compute.Expression: """ - Returns a column from the table. This is useful when you want to - perform some operations on the column. + Returns a column from the table as a PyArrow compute expression. 
+ + This method is useful for creating column-based expressions that can be used in + operations like filtering, sorting, or aggregating data. The returned expression + can be used with PyArrow's compute functions. + + Args: + name (str): The name of the column to retrieve from the table schema. + + Returns: + pa.compute.Expression: A PyArrow compute expression representing the column. + + Raises: + ValueError: If the specified column name is not found in the table schema. + + Example: + >>> table = tables("my_table").load() + >>> # Create a filter expression for rows where age > 30 + >>> age_expr = table.column("age") > 30 + >>> # Use the expression in a delete operation + >>> table.delete(age_expr) """ field = self.schema().field(name) @@ -172,6 +352,26 @@ def __init__(self, ctx: TowerContext, catalog: Catalog, name: str, namespace: Op def load(self) -> Table: + """ + Loads an existing Iceberg table from the catalog. + + This method resolves the table's namespace and name, then loads the table + from the catalog. If the table doesn't exist, this will raise an error. + Use `create()` or `create_if_not_exists()` to create new tables. + + Returns: + Table: A new Table instance wrapping the loaded Iceberg table. + + Raises: + TableNotFoundError: If the table doesn't exist in the catalog. + + Example: + >>> # Load the existing table + >>> table = tables("my_table", namespace="my_namespace").load() + >>> # Now you can use the table + >>> df = table.read() + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) table = self._catalog.load_table(table_name) @@ -179,6 +379,40 @@ def load(self) -> Table: def create(self, schema: pa.Schema) -> Table: + + """ + Creates a new Iceberg table with the specified schema. + + This method will: + 1. Resolve the table's namespace (using default if not specified) + 2. Create the namespace if it doesn't exist + 3. Create a new table with the provided schema + 4. Return a Table instance for the newly created table + + Args: + schema (pa.Schema): The PyArrow schema defining the structure of the table. + This will be converted to an Iceberg schema internally. + + Returns: + Table: A new Table instance wrapping the created Iceberg table. + + Raises: + TableAlreadyExistsError: If a table with the same name already exists in the namespace. + NamespaceError: If there are issues creating or accessing the namespace. + + Example: + >>> # Define the table schema + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()), + ... pa.field("age", pa.int32()) + ... ]) + >>> # Create the table + >>> table = tables("my_table", namespace="my_namespace").create(schema) + >>> # Now you can use the table + >>> table.insert(new_data) + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) @@ -198,6 +432,43 @@ def create(self, schema: pa.Schema) -> Table: def create_if_not_exists(self, schema: pa.Schema) -> Table: + """ + Creates a new Iceberg table if it doesn't exist, or returns the existing table. + + This method will: + 1. Resolve the table's namespace (using default if not specified) + 2. Create the namespace if it doesn't exist + 3. Create a new table with the provided schema if it doesn't exist + 4. Return the existing table if it already exists + 5. Return a Table instance for the table + + Unlike `create()`, this method will not raise an error if the table already exists. 
+ Instead, it will return the existing table, making it safe for idempotent operations. + + Args: + schema (pa.Schema): The PyArrow schema defining the structure of the table. + This will be converted to an Iceberg schema internally. Note that this + schema is only used if the table needs to be created. + + Returns: + Table: A Table instance wrapping either the newly created or existing Iceberg table. + + Raises: + NamespaceError: If there are issues creating or accessing the namespace. + + Example: + >>> # Define the table schema + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()), + ... pa.field("age", pa.int32()) + ... ]) + >>> # Create the table if it doesn't exist + >>> table = tables("my_table", namespace="my_namespace").create_if_not_exists(schema) + >>> # This is safe to call multiple times + >>> table = tables("my_table", namespace="my_namespace").create_if_not_exists(schema) + """ + namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) @@ -222,18 +493,50 @@ def tables( namespace: Optional[str] = None ) -> TableReference: """ - `tables` creates a reference to an Iceberg table with the name `name` from - the catalog with name `catalog_name`. + Creates a reference to an Iceberg table that can be used to load or create tables. + + This function is the main entry point for working with Iceberg tables in Tower. It returns + a TableReference object that can be used to either load an existing table or create a new one. + The actual table operations (read, write, etc.) are performed through the Table instance + obtained by calling `load()` or `create()` on the returned reference. Args: - `name` (str): The name of the table to load. - `catalog` (Union[str, Catalog]): The name of the catalog or the actual - catalog to use. "default" is the default value. You can pass in an - actual catalog object for testing purposes. - `namespace` (Optional[str]): The namespace in which to load the table. + name (str): The name of the table to reference. This will be used to either load + an existing table or create a new one. + catalog (Union[str, Catalog], optional): The catalog to use. Can be either: + - A string name of the catalog (defaults to "default") + - A Catalog instance (useful for testing or custom catalog implementations) + Defaults to "default". + namespace (Optional[str], optional): The namespace in which the table exists or + should be created. If not provided, a default namespace will be used. Returns: - TableReference: A reference to a table to be resolved with `create` or `load` + TableReference: A reference object that can be used to: + - Load an existing table using `load()` + - Create a new table using `create()` + - Create a table if it doesn't exist using `create_if_not_exists()` + + Raises: + CatalogError: If there are issues accessing or loading the specified catalog. + TableNotFoundError: When trying to load a non-existent table (only if `load()` is called). + + Examples: + >>> # Load an existing table from the default catalog + >>> table = tables("my_table").load() + >>> df = table.read() + + >>> # Create a new table in a specific namespace + >>> schema = pa.schema([ + ... pa.field("id", pa.int64()), + ... pa.field("name", pa.string()) + ... 
]) + >>> table = tables("new_table", namespace="my_namespace").create(schema) + + >>> # Use a specific catalog + >>> table = tables("my_table", catalog="my_catalog").load() + + >>> # Create a table if it doesn't exist + >>> table = tables("my_table").create_if_not_exists(schema) """ if isinstance(catalog, str): catalog = load_catalog(catalog) From f302ccac0948081a373f73567c70904556647893 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Sat, 10 May 2025 22:00:20 +0200 Subject: [PATCH 04/15] chore: A simple implementation and test for `wait_for_runs` --- src/tower/__init__.py | 1 + src/tower/_client.py | 20 ++++++- tests/tower/test_client.py | 119 ++++++++++++++++++++++++++++++++++++- 3 files changed, 138 insertions(+), 2 deletions(-) diff --git a/src/tower/__init__.py b/src/tower/__init__.py index 129529ba..4dbeca74 100644 --- a/src/tower/__init__.py +++ b/src/tower/__init__.py @@ -13,6 +13,7 @@ from ._client import ( run_app, wait_for_run, + wait_for_runs, ) from ._features import override_get_attr, get_available_features, is_feature_enabled diff --git a/src/tower/_client.py b/src/tower/_client.py index d30c6ada..77c38eea 100644 --- a/src/tower/_client.py +++ b/src/tower/_client.py @@ -1,6 +1,6 @@ import os import time -from typing import Dict, Optional +from typing import List, Dict, Optional from ._context import TowerContext from .tower_api_client import AuthenticatedClient @@ -119,3 +119,21 @@ def wait_for_run(run: Run) -> None: return else: time.sleep(WAIT_TIMEOUT) + + +def wait_for_runs(runs: List[Run]) -> None: + """ + `wait_for_runs` waits for a list of runs to reach a terminal state by + polling the Tower API every 2 seconds for the latest status. Once all of + the runs have reached a terminal status (`exited`, `errored`, `cancelled`, + or `crashed`) this function returns. + + Args: + runs (List[Run]): A list of runs to wait for. + + Raises: + RuntimeError: If there is an error fetching the status of any of the + runs. + """ + for run in runs: + wait_for_run(run) diff --git a/tests/tower/test_client.py b/tests/tower/test_client.py index e9a75c70..156ab2bd 100644 --- a/tests/tower/test_client.py +++ b/tests/tower/test_client.py @@ -43,7 +43,7 @@ def test_running_apps(httpx_mock): # Assert the response assert run is not None -def test_waiting_for_runs(httpx_mock): +def test_waiting_for_a_run(httpx_mock): # Mock the response from the API httpx_mock.add_response( method="GET", @@ -119,3 +119,120 @@ def test_waiting_for_runs(httpx_mock): # Now actually wait for the run. tower.wait_for_run(run) + +def test_waiting_for_multiple_runs(httpx_mock): + # Mock the response from the API + httpx_mock.add_response( + method="GET", + url="https://api.example.com/v1/apps/my-app/runs/3", + json={ + "run": { + "app_slug": "my-app", + "app_version": "v6", + "cancelled_at": None, + "created_at": "2025-04-25T20:54:58.762547Z", + "ended_at": "2025-04-25T20:55:35.220295Z", + "environment": "default", + "number": 3, + "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", + "scheduled_at": "2025-04-25T20:54:58.761867Z", + "started_at": "2025-04-25T20:54:59.366937Z", + "status": "pending", + "status_group": "", + "parameters": [] + } + }, + status_code=200, + ) + + # Second request, will indicate that it's done.
+ httpx_mock.add_response( + method="GET", + url="https://api.example.com/v1/apps/my-app/runs/3", + json={ + "run": { + "app_slug": "my-app", + "app_version": "v6", + "cancelled_at": None, + "created_at": "2025-04-25T20:54:58.762547Z", + "ended_at": "2025-04-25T20:55:35.220295Z", + "environment": "default", + "number": 3, + "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", + "scheduled_at": "2025-04-25T20:54:58.761867Z", + "started_at": "2025-04-25T20:54:59.366937Z", + "status": "exited", + "status_group": "successful", + "parameters": [] + } + }, + status_code=200, + ) + + # Mock the response for the second run, which is already done. + httpx_mock.add_response( + method="GET", + url="https://api.example.com/v1/apps/my-app/runs/4", + json={ + "run": { + "app_slug": "my-app", + "app_version": "v6", + "cancelled_at": None, + "created_at": "2025-04-25T20:54:58.762547Z", + "ended_at": "2025-04-25T20:55:35.220295Z", + "environment": "default", + "number": 4, + "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", + "scheduled_at": "2025-04-25T20:54:58.761867Z", + "started_at": "2025-04-25T20:54:59.366937Z", + "status": "exited", + "status_group": "successful", + "parameters": [] + } + }, + status_code=200, + ) + + # We tell the client to use the mock server. + os.environ["TOWER_URL"] = "https://api.example.com" + os.environ["TOWER_API_KEY"] = "abc123" + + import tower + + run1 = Run( + app_slug="my-app", + app_version="v6", + cancelled_at=None, + created_at="2025-04-25T20:54:58.762547Z", + ended_at="2025-04-25T20:55:35.220295Z", + environment="default", + number=3, + run_id="50ac9bc1-c783-4359-9917-a706f20dc02c", + scheduled_at="2025-04-25T20:54:58.761867Z", + started_at="2025-04-25T20:54:59.366937Z", + status="running", + status_group="failed", + parameters=[] + ) + + run2 = Run( + app_slug="my-app", + app_version="v6", + cancelled_at=None, + created_at="2025-04-25T20:54:58.762547Z", + ended_at="2025-04-25T20:55:35.220295Z", + environment="default", + number=4, + run_id="50ac9bc1-c783-4359-9917-a706f20dc02c", + scheduled_at="2025-04-25T20:54:58.761867Z", + started_at="2025-04-25T20:54:59.366937Z", + status="running", + status_group="failed", + parameters=[] + ) + + # Set WAIT_TIMEOUT to 0 so we don't have to...wait. + tower._client.WAIT_TIMEOUT = 0 + + # Now actually wait for the runs. + tower.wait_for_runs([run1, run2]) From e59f1f0d4f84bb5c06262101758d4843824c5f22 Mon Sep 17 00:00:00 2001 From: Serhii Sokolenko Date: Sat, 10 May 2025 22:24:09 +0200 Subject: [PATCH 05/15] Updated docstrings in the orchestration client --- src/tower/_client.py | 51 +++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/src/tower/_client.py b/src/tower/_client.py index 77c38eea..efb05638 100644 --- a/src/tower/_client.py +++ b/src/tower/_client.py @@ -52,9 +52,26 @@ def run_app( parameters: Optional[Dict[str, str]] = None, ) -> Run: """ - `run_app` invokes an app based on the configured environment. You can - supply an optional `environment` override, and an optional dict - `parameters` to pass into the app. + Run a Tower application with specified parameters and environment. + + This function initiates a new run of a Tower application identified by its slug. + The run can be configured with an optional environment override and runtime parameters. + If no environment is specified, the default environment from the Tower context is used. + + Args: + slug (str): The unique identifier of the application to run. + environment (Optional[str]): The environment to run the application in. 
+ If not provided, uses the default environment from the Tower context. + parameters (Optional[Dict[str, str]]): A dictionary of key-value pairs + to pass as parameters to the application run. + + Returns: + Run: A Run object containing information about the initiated application run, + including the app_slug and run number. + + Raises: + RuntimeError: If there is an error initiating the run or if the Tower API + returns an error response. """ ctx = TowerContext.build() client = _env_client(ctx) @@ -86,10 +103,25 @@ def run_app( def wait_for_run(run: Run) -> None: """ - `wait_for_run` waits for a run to reach a terminal state by polling the - Tower API every 2 seconds for the latest status. If the app returns a - terminal status (`exited`, `errored`, `cancelled`, or `crashed`) then this - function returns. + Wait for a Tower app run to reach a terminal state by polling the Tower API. + + This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT) + to check the status of the specified run. The function returns when the run reaches + any of the following terminal states: + - exited: The run completed successfully + - failed: The run failed during execution + - canceled: The run was manually canceled + - errored: The run encountered an error + + Args: + run (Run): The Run object containing the app_slug and number of the run to monitor. + + Returns: + None: This function does not return any value. + + Raises: + RuntimeError: If there is an error fetching the run status from the Tower API + or if the API returns an error response. """ ctx = TowerContext.build() client = _env_client(ctx) @@ -129,7 +161,10 @@ def wait_for_runs(runs: List[Run]) -> None: `crashed`) then this function returns. Args: - runs (List[Run]): A list of runs to wait for. + runs (List[Run]): A list of Run objects to monitor. + + Returns: + None: This function does not return any value. Raises: RuntimeError: If there is an error fetching the run status or if any From 07111d72f12cf492a9400c9ec128831e4fba3f98 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Mon, 12 May 2025 14:40:45 +0100 Subject: [PATCH 06/15] chore: Feedback from @datancoffee --- src/tower/_client.py | 129 ++++++++++++++++++++++++++++++++++++------- src/tower/_errors.py | 29 ++++++++++ 2 files changed, 137 insertions(+), 21 deletions(-) create mode 100644 src/tower/_errors.py diff --git a/src/tower/_client.py b/src/tower/_client.py index efb05638..658f4112 100644 --- a/src/tower/_client.py +++ b/src/tower/_client.py @@ -3,6 +3,15 @@ from typing import List, Dict, Optional from ._context import TowerContext +from ._errors import ( + NotFoundException, + UnauthorizedException, + UnknownException, + UnhandledRunStateException, + RunFailedError, + TimeoutException, +) + from .tower_api_client import AuthenticatedClient from .tower_api_client.api.default import describe_run as describe_run_api from .tower_api_client.api.default import run_app as run_app_api @@ -101,20 +110,63 @@ def run_app( return output.run -def wait_for_run(run: Run) -> None: +def _is_failed_run(run: Run) -> bool: + """ + Check if the given run has failed. + + Args: + run (Run): The Run object containing the status to check. + + Returns: + bool: True if the run has failed, False otherwise. + """ + return run.status in ["crashed", "cancelled", "errored"] + + +def _is_successful_run(run: Run) -> bool: + """ + Check if a given run was successful. + + Args: + run (Run): The Run object containing the status to check. 
+ + Returns: + bool: True if the run was successful, False otherwise. + """ + return run.status in ["exited"] + + +def _is_run_awaiting_completion(run: Run) -> bool: + """ + Check if a given run is either running or expected to run in the near future. + + Args: + run (Run): The Run object containing the status to check. + + Returns: + bool: True if the run is awaiting execution or currently running, False otherwise. + """ + return run.status in ["pending", "scheduled", "running"] + + +def wait_for_run( + run: Run, + timeout: Optional[float] = 86_400.0, # one day + raise_on_failure: bool = False, +) -> bool: """ Wait for a Tower app run to reach a terminal state by polling the Tower API. This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT) to check the status of the specified run. The function returns when the run reaches - any of the following terminal states: - - exited: The run completed successfully - - failed: The run failed during execution - - canceled: The run was manually canceled - - errored: The run encountered an error + any of the defined terminal states. Args: run (Run): The Run object containing the app_slug and number of the run to monitor. + timeout (Optional[float]): An optional timeout for this wait. Defaults + to one day (86,400 seconds). + raise_on_failure (bool): Whether to raise an exception when a failure + occurs. Defaults to False. Returns: - None: This function does not return any value. + bool: True if the run completed successfully, False if it failed and + `raise_on_failure` is False. Raises: RuntimeError: If there is an error fetching the run status from the Tower API or if the API returns an error response. """ ctx = TowerContext.build() client = _env_client(ctx) + # We use this to track the timeout, if one is defined. + start_time = time.time() + while True: output: Optional[Union[DescribeRunResponse, ErrorModel]] = describe_run_api.sync( slug=run.app_slug, number=str(run.number), client=client, ) if output is None: - raise RuntimeError("Error fetching run") + raise UnknownException("Error fetching run") else: if isinstance(output, ErrorModel): - raise RuntimeError(f"Error fetching run: {output.title}") + # If it was a 404 error, that means that we couldn't find this + # app for some reason. This is really only relevant on the + # first time that we check--if we could find the run, but then + # suddenly couldn't that's a really big problem I'd say. + if output.status == 404: + raise NotFoundException(output.detail) + elif output.status == 401: + # NOTE: Most of the time, this shouldn't happen? + raise UnauthorizedException(output.detail) + else: + raise UnknownException(output.detail) else: desc = output.run - if desc.status == "exited": - return - elif desc.status == "failed": - return - elif desc.status == "canceled": - return - elif desc.status == "errored": - return - else: - time.sleep(WAIT_TIMEOUT) + if _is_successful_run(desc): + return True + elif _is_failed_run(desc): + if raise_on_failure: + raise RunFailedError(desc.app_slug, desc.number, desc.status) + else: + return False + + elif _is_run_awaiting_completion(desc): + time.sleep(WAIT_TIMEOUT) + else: + raise UnhandledRunStateException(desc.status) + + # Before we head back to the top of the loop, let's see if we + # should timeout + if timeout is not None: + # The user defined a timeout, so let's actually see if we + # reached it.
+ t = time.time() - start_time + if t > timeout: + raise TimeoutException(t) + + +def wait_for_runs( + runs: List[Run], + timeout: Optional[float] = 86_400.0, # one day + raise_on_failure: bool = False, +) -> None: """ `wait_for_runs` waits for a list of runs to reach a terminal state by polling the Tower API every 2 seconds for the latest status. Once all of the runs have reached a terminal status (`exited`, `errored`, `cancelled`, or `crashed`) this function returns. Args: - runs (List[Run]): A list of runs to wait for. + runs (List[Run]): A list of Run objects to monitor. + timeout (Optional[float]): Timeout for each awaited run, in seconds. + raise_on_failure (bool): If true, raises an exception when + any one of the awaited runs fails. Defaults to False. + + Returns: + None: This function does not return any value. Raises: RuntimeError: If there is an error fetching the status of any of the runs. """ for run in runs: - wait_for_run(run) + wait_for_run( + run, + timeout=timeout, + raise_on_failure=raise_on_failure, + ) diff --git a/src/tower/_errors.py b/src/tower/_errors.py new file mode 100644 index 00000000..09752dc6 --- /dev/null +++ b/src/tower/_errors.py @@ -0,0 +1,29 @@ +class NotFoundException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnauthorizedException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnknownException(Exception): + def __init__(self, message: str): + super().__init__(message) + + +class UnhandledRunStateException(Exception): + def __init__(self, state: str): + message = f"Run state '{state}' was unexpected. Maybe you need to upgrade to the latest Tower SDK." + super().__init__(message) + + +class TimeoutException(Exception): + def __init__(self, time: float): + super().__init__(f"A timeout occurred after {time} seconds.") + + +class RunFailedError(RuntimeError): + def __init__(self, app_name: str, number: int, state: str): + super().__init__(f"Run {app_name}#{number} failed with status '{state}'") From 6dd8ab45d46b9f1d1e5b55d27cc93e82d1d098d3 Mon Sep 17 00:00:00 2001 From: Rohit Sankaran Date: Tue, 13 May 2025 17:55:26 +0800 Subject: [PATCH 07/15] Handle field type mappings correctly (#49) * fix: Fix type mappings, format code * feat: Handle nested types correctly and set ids * chore: Fix formatting * chore: Add how to run tests * feat: Add tests for types --- README.md | 7 + pyproject.toml | 12 +- src/tower/utils/pyarrow.py | 264 ++++++++++++++------ tests/tower/test_tables.py | 495 +++++++++++++++++++++++++++++++++---- 4 files changed, 641 insertions(+), 137 deletions(-) diff --git a/README.md b/README.md index 9b236a83..f50450aa 100644 --- a/README.md +++ b/README.md @@ -123,5 +123,12 @@ easily. Then you can `import tower` and you're off to the races! uv run python ``` +To run tests: + +```bash +uv sync --locked --all-extras --dev +uv run pytest tests +``` + If you need to get the latest OpenAPI SDK, you can run `./scripts/generate-python-api-client.sh`. diff --git a/pyproject.toml b/pyproject.toml index 71efe0ab..7c7e3531 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,6 @@ build-backend = "maturin" [project] name = "tower" version = "0.3.13" - - - - - - - - description = "Tower CLI and runtime environment for Tower." 
authors = [{ name = "Tower Computing Inc.", email = "brad@tower.dev" }] readme = "README.md" @@ -68,8 +60,8 @@ tower = { workspace = true } [dependency-groups] dev = [ - "openapi-python-client>=0.12.1", + "openapi-python-client>=0.12.1", "pytest>=8.3.5", "pytest-httpx>=0.35.0", - "pyiceberg[sql-sqlite]>=0.9.0", + "pyiceberg[sql-sqlite]>=0.9.0", ] diff --git a/src/tower/utils/pyarrow.py b/src/tower/utils/pyarrow.py index a0ad414c..9b8c1111 100644 --- a/src/tower/utils/pyarrow.py +++ b/src/tower/utils/pyarrow.py @@ -3,93 +3,202 @@ import pyarrow as pa import pyarrow.compute as pc -import pyiceberg.types as types +from pyiceberg import types as iceberg_types from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.expressions import ( BooleanExpression, - And, Or, Not, - EqualTo, NotEqualTo, - GreaterThan, GreaterThanOrEqual, - LessThan, LessThanOrEqual, - Literal, Reference + And, + Or, + Not, + EqualTo, + NotEqualTo, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual, + Reference, ) -def arrow_to_iceberg_type(arrow_type): + +class FieldIdManager: + """ + Manages the assignment of unique field IDs. + Field IDs in Iceberg start from 1. + """ + + def __init__(self, start_id=1): + # Initialize current_id to start_id - 1 so the first call to get_next_id() returns start_id + self.current_id = start_id - 1 + + def get_next_id(self) -> int: + """Returns the next available unique field ID.""" + self.current_id += 1 + return self.current_id + + +def arrow_to_iceberg_type_recursive( + arrow_type: pa.DataType, field_id_manager: FieldIdManager +) -> iceberg_types.IcebergType: """ - Convert a PyArrow type to a PyIceberg type. Special thanks to Claude for - the help on this. + Recursively convert a PyArrow DataType to a PyIceberg type, + managing field IDs for nested structures. 
""" - - if pa.types.is_boolean(arrow_type): - return types.BooleanType() + # Primitive type mappings (most remain the same) + if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + return iceberg_types.StringType() elif pa.types.is_integer(arrow_type): - # Check the bit width to determine the appropriate Iceberg integer type - bit_width = arrow_type.bit_width - if bit_width <= 32: - return types.IntegerType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.IntegerType() else: - return types.LongType() + return iceberg_types.LongType() elif pa.types.is_floating(arrow_type): - if arrow_type.bit_width == 32: - return types.FloatType() + if arrow_type.bit_width <= 32: # type: ignore + return iceberg_types.FloatType() else: - return types.DoubleType() - elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): - return types.StringType() - elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): - return types.BinaryType() + return iceberg_types.DoubleType() + elif pa.types.is_boolean(arrow_type): + return iceberg_types.BooleanType() elif pa.types.is_date(arrow_type): - return types.DateType() - elif pa.types.is_timestamp(arrow_type): - return types.TimestampType() + return iceberg_types.DateType() elif pa.types.is_time(arrow_type): - return types.TimeType() + return iceberg_types.TimeType() + elif pa.types.is_timestamp(arrow_type): + if arrow_type.tz is not None: # type: ignore + return iceberg_types.TimestamptzType() + else: + return iceberg_types.TimestampType() + elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): + return iceberg_types.BinaryType() + elif pa.types.is_fixed_size_binary(arrow_type): + return iceberg_types.FixedType(length=arrow_type.byte_width) # type: ignore elif pa.types.is_decimal(arrow_type): - precision = arrow_type.precision - scale = arrow_type.scale - return types.DecimalType(precision, scale) - elif pa.types.is_list(arrow_type): - element_type = arrow_to_iceberg_type(arrow_type.value_type) - return types.ListType(element_type) + return iceberg_types.DecimalType(arrow_type.precision, arrow_type.scale) # type: ignore + + # Nested type mappings + elif ( + pa.types.is_list(arrow_type) + or pa.types.is_large_list(arrow_type) + or pa.types.is_fixed_size_list(arrow_type) + ): + # The element field itself in Iceberg needs an ID. + element_id = field_id_manager.get_next_id() + + # Recursively convert the list's element type. + # arrow_type.value_type is the DataType of the elements. + # arrow_type.value_field is the Field of the elements (contains name, type, nullability). + element_pyarrow_type = arrow_type.value_type # type: ignore + element_iceberg_type = arrow_to_iceberg_type_recursive( + element_pyarrow_type, field_id_manager + ) + + # Determine if the elements themselves are required (not nullable). + element_is_required = not arrow_type.value_field.nullable # type: ignore + + return iceberg_types.ListType( + element_id=element_id, + element_type=element_iceberg_type, + element_required=element_is_required, + ) elif pa.types.is_struct(arrow_type): - fields = [] - for i, field in enumerate(arrow_type): - name = field.name - field_type = arrow_to_iceberg_type(field.type) - fields.append(types.NestedField(i + 1, name, field_type, required=not field.nullable)) - return types.StructType(*fields) + struct_iceberg_fields = [] + # arrow_type is a StructType. Iterate through its fields. 
+ for i in range(arrow_type.num_fields): # type: ignore + pyarrow_child_field = arrow_type.field(i) # This is a pyarrow.Field + + # Each field within the struct needs its own unique ID. + nested_field_id = field_id_manager.get_next_id() + nested_iceberg_type = arrow_to_iceberg_type_recursive( + pyarrow_child_field.type, field_id_manager + ) + + doc = None + if pyarrow_child_field.metadata and b"doc" in pyarrow_child_field.metadata: + doc = pyarrow_child_field.metadata[b"doc"].decode("utf-8") + + struct_iceberg_fields.append( + iceberg_types.NestedField( + field_id=nested_field_id, + name=pyarrow_child_field.name, + field_type=nested_iceberg_type, + required=not pyarrow_child_field.nullable, + doc=doc, + ) + ) + return iceberg_types.StructType(*struct_iceberg_fields) elif pa.types.is_map(arrow_type): - key_type = arrow_to_iceberg_type(arrow_type.key_type) - value_type = arrow_to_iceberg_type(arrow_type.item_type) - return types.MapType(key_type, value_type) + # Iceberg MapType requires IDs for key and value fields. + key_id = field_id_manager.get_next_id() + value_id = field_id_manager.get_next_id() + + key_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.key_type, field_id_manager + ) # type: ignore + value_iceberg_type = arrow_to_iceberg_type_recursive( + arrow_type.item_type, field_id_manager + ) # type: ignore + + # PyArrow map keys are always non-nullable by Arrow specification. + # Nullability of map values comes from the item_field. + value_is_required = not arrow_type.item_field.nullable # type: ignore + + return iceberg_types.MapType( + key_id=key_id, + key_type=key_iceberg_type, + value_id=value_id, + value_type=value_iceberg_type, + value_required=value_is_required, + ) else: raise ValueError(f"Unsupported Arrow type: {arrow_type}") -def convert_pyarrow_field(num, field) -> types.NestedField: - name = field.name - field_type = arrow_to_iceberg_type(field.type) - field_id = num + 1 # Iceberg requires field IDs +def convert_pyarrow_schema( + arrow_schema: pa.Schema, schema_id: int = 1, start_field_id: int = 1 +) -> IcebergSchema: + """ + Convert a PyArrow schema to a PyIceberg schema. + + Args: + arrow_schema: The input PyArrow.Schema. + schema_id: The schema ID for the Iceberg schema. + start_field_id: The starting ID for field ID assignment. + Returns: + An IcebergSchema object. + """ + field_id_manager = FieldIdManager(start_id=start_field_id) + iceberg_fields = [] + + for pyarrow_field in arrow_schema: # pyarrow_field is a pa.Field object + # Assign a unique ID for this top-level field. + top_level_field_id = field_id_manager.get_next_id() - return types.NestedField( - field_id, - name, - field_type, - required=not field.nullable - ) + # Recursively convert the field's type. This will handle ID assignment + # for any nested structures using the same field_id_manager. 
+ iceberg_field_type = arrow_to_iceberg_type_recursive( + pyarrow_field.type, field_id_manager + ) + doc = None + if pyarrow_field.metadata and b"doc" in pyarrow_field.metadata: + doc = pyarrow_field.metadata[b"doc"].decode("utf-8") -def convert_pyarrow_schema(arrow_schema: pa.Schema) -> IcebergSchema: - """Convert a PyArrow schema to a PyIceberg schema.""" - fields = [convert_pyarrow_field(i, field) for i, field in enumerate(arrow_schema)] - return IcebergSchema(*fields) + iceberg_fields.append( + iceberg_types.NestedField( + field_id=top_level_field_id, + name=pyarrow_field.name, + field_type=iceberg_field_type, + required=not pyarrow_field.nullable, # Top-level field nullability + doc=doc, + ) + ) + return IcebergSchema(*iceberg_fields, schema_id=schema_id) def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: """Extract field name and literal value from a comparison expression.""" # First, convert the expression to a string and parse it expr_str = str(expr) - + # PyArrow expression strings look like: "(field_name == literal)" or similar # Need to determine the operator and then split accordingly operators = ["==", "!=", ">", ">=", "<", "<="] @@ -98,36 +207,38 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: if op in expr_str: op_used = op break - + if not op_used: - raise ValueError(f"Could not find comparison operator in expression: {expr_str}") - + raise ValueError( + f"Could not find comparison operator in expression: {expr_str}" + ) + # Remove parentheses and split by operator expr_clean = expr_str.strip("()") parts = expr_clean.split(op_used) if len(parts) != 2: raise ValueError(f"Expected binary comparison in expression: {expr_str}") - + # Determine which part is the field and which is the literal field_name = None literal_value = None - + # Clean up the parts left = parts[0].strip() right = parts[1].strip() - + # Typically field name doesn't have quotes, literals (strings) do if left.startswith('"') or left.startswith("'"): # Right side is the field field_name = right # Extract the literal value - this is a simplification - literal_value = left.strip('"\'') + literal_value = left.strip("\"'") else: # Left side is the field field_name = left # Extract the literal value - this is a simplification - literal_value = right.strip('"\'') - + literal_value = right.strip("\"'") + # Try to convert numeric literals try: if "." 
in literal_value: @@ -137,17 +248,18 @@ def extract_field_and_literal(expr: pc.Expression) -> tuple[str, Any]: except ValueError: # Keep as string if not numeric pass - + return field_name, literal_value + def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpression]: """Convert a PyArrow compute expression to a PyIceberg boolean expression.""" if expr is None: return None - + # Handle the expression based on its string representation expr_str = str(expr) - + # Handle logical operations if "and" in expr_str.lower() and isinstance(expr, pc.Expression): # This is a simplification - in real code, you'd need to parse the expression @@ -156,7 +268,7 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return And( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "or" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification @@ -164,13 +276,13 @@ def convert_pyarrow_expression(expr: pc.Expression) -> Optional[BooleanExpressio right_expr = None # You'd need to extract this return Or( convert_pyarrow_expression(left_expr), - convert_pyarrow_expression(right_expr) + convert_pyarrow_expression(right_expr), ) elif "not" in expr_str.lower() and isinstance(expr, pc.Expression): # Similar simplification inner_expr = None # You'd need to extract this return Not(convert_pyarrow_expression(inner_expr)) - + # Handle comparison operations try: if "==" in expr_str: @@ -204,13 +316,13 @@ def convert_pyarrow_expressions(exprs: List[pc.Expression]) -> BooleanExpression """ if not exprs: raise ValueError("No expressions provided") - + if len(exprs) == 1: return convert_pyarrow_expression(exprs[0]) - + # Combine multiple expressions with AND result = convert_pyarrow_expression(exprs[0]) for expr in exprs[1:]: result = And(result, convert_pyarrow_expression(expr)) - + return result diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py index 3ef0b027..bf044b25 100644 --- a/tests/tower/test_tables.py +++ b/tests/tower/test_tables.py @@ -14,13 +14,14 @@ # Imports the library under test import tower + def get_temp_dir(): """Create a temporary directory and return its file:// URL.""" # Create a temporary directory that will be automatically cleaned up temp_dir = tempfile.TemporaryDirectory() abs_path = pathlib.Path(temp_dir.name).absolute() - file_url = urljoin('file:', pathname2url(str(abs_path))) - + file_url = urljoin("file:", pathname2url(str(abs_path))) + # Return both the URL and the path to the temporary directory return file_url, abs_path @@ -38,21 +39,41 @@ def in_memory_catalog(): def test_reading_and_writing_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], 
schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # If we write some data to the table, that should be...OK. table = table.insert(data_with_schema) @@ -66,24 +87,48 @@ def test_reading_and_writing_to_tables(in_memory_catalog): avg_age = df.select(pl.mean("age").alias("mean_age")).collect().item() assert avg_age == 30.0 + def test_upsert_to_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -91,13 +136,22 @@ def test_upsert_to_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Now we'll update records in the table. - data_with_schema = pa.Table.from_pylist([ - {"id": 2, "username": "bobb", "name": "Bob", "age": 26, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 26, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + ], + schema=schema, + ) # And make sure we can upsert the data. table = table.upsert(data_with_schema, join_cols=["username"]) - assert table.rows_affected().updates == 1 + assert table.rows_affected().updates == 1 # Now let's read from the table and see what we get back out. 
df = table.to_polars() @@ -107,24 +161,48 @@ def test_upsert_to_tables(in_memory_catalog): # The age should match what we updated the relevant record to assert res["age"].item() == 26 + def test_delete_from_tables(in_memory_catalog): - schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) table = ref.create_if_not_exists(schema) - data_with_schema = pa.Table.from_pylist([ - {"id": 1, "username": "alicea", "name": "Alice", "age": 30, "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0)}, - {"id": 2, "username": "bobb", "name": "Bob", "age": 25, "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0)}, - {"id": 3, "username": "charliec", "name": "Charlie", "age": 35, "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0)}, - ], schema=schema) + data_with_schema = pa.Table.from_pylist( + [ + { + "id": 1, + "username": "alicea", + "name": "Alice", + "age": 30, + "created_at": datetime.datetime(2023, 1, 1, 0, 0, 0), + }, + { + "id": 2, + "username": "bobb", + "name": "Bob", + "age": 25, + "created_at": datetime.datetime(2023, 1, 2, 0, 0, 0), + }, + { + "id": 3, + "username": "charliec", + "name": "Charlie", + "age": 35, + "created_at": datetime.datetime(2023, 1, 3, 0, 0, 0), + }, + ], + schema=schema, + ) # Make sure that we can actually insert the data into the table. table = table.insert(data_with_schema) @@ -132,9 +210,7 @@ def test_delete_from_tables(in_memory_catalog): assert table.rows_affected().inserts == 3 # Perform the underlying delete from the table... - table.delete(filters=[ - table.column("username") == "bobb" - ]) + table.delete(filters=[table.column("username") == "bobb"]) # ...and let's make sure that record is actually gone. df = table.to_polars() @@ -143,14 +219,17 @@ def test_delete_from_tables(in_memory_catalog): all_rows = df.collect() assert all_rows.height == 2 + def test_getting_schemas_for_tables(in_memory_catalog): - original_schema = pa.schema([ - pa.field("id", pa.int64()), - pa.field("username", pa.string()), - pa.field("name", pa.string()), - pa.field("age", pa.int32()), - pa.field("created_at", pa.timestamp("ms")), - ]) + original_schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("username", pa.string()), + pa.field("name", pa.string()), + pa.field("age", pa.int32()), + pa.field("created_at", pa.timestamp("ms")), + ] + ) # First we'll insert some data into the relevant table. ref = tower.tables("users", catalog=in_memory_catalog) @@ -163,3 +242,317 @@ def test_getting_schemas_for_tables(in_memory_catalog): assert new_schema.field("id") is not None assert new_schema.field("age") is not None assert new_schema.field("created_at") is not None + + +def test_list_of_structs(in_memory_catalog): + """Tests writing and reading a list of non-nullable structs.""" + table_name = "test_list_of_structs_table" + # Define a pyarrow schema with a list of structs + # The list 'tags' can be null, but its elements (structs) are not nullable. + # Inside the struct, 'key' is non-nullable, 'value' is nullable. 
+ item_struct_type = pa.struct( + [ + pa.field("key", pa.string(), nullable=False), # Non-nullable key + pa.field("value", pa.int64(), nullable=True), # Nullable value + ] + ) + # The 'item' field represents the elements of the list. It's non-nullable. + # This means each element in the list must be a valid struct, not a Python None. + pa_schema = pa.schema( + [ + pa.field("doc_id", pa.int32(), nullable=False), + pa.field( + "tags", + pa.list_(pa.field("item", item_struct_type, nullable=False)), + nullable=True, + ), + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + { + "doc_id": 1, + "tags": [{"key": "user", "value": 100}, {"key": "priority", "value": 1}], + }, + { + "doc_id": 2, + "tags": [ + {"key": "source", "value": 200}, + {"key": "reviewed", "value": None}, + ], + }, # Null value for a struct field + {"doc_id": 3, "tags": []}, # Empty list + {"doc_id": 4, "tags": None}, # Null list + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + # Read back and verify + df_read = table.to_polars().collect() # Collect to get a Polars DataFrame + + assert df_read.shape[0] == 4 + assert df_read["doc_id"].to_list() == [1, 2, 3, 4] + + # Verify nested data (Polars handles structs and lists well) + # For doc_id = 1 + tags_doc1 = ( + df_read.filter(pl.col("doc_id") == 1).select("tags").row(0)[0] + ) # Get the list of structs + assert len(tags_doc1) == 2 + assert tags_doc1[0]["key"] == "user" + assert tags_doc1[0]["value"] == 100 + assert tags_doc1[1]["key"] == "priority" + assert tags_doc1[1]["value"] == 1 + + # For doc_id = 2 (with a null inside a struct) + tags_doc2 = df_read.filter(pl.col("doc_id") == 2).select("tags").row(0)[0] + assert len(tags_doc2) == 2 + assert tags_doc2[0]["key"] == "source" + assert tags_doc2[0]["value"] == 200 + assert tags_doc2[1]["key"] == "reviewed" + assert tags_doc2[1]["value"] is None + + # For doc_id = 3 (empty list) + tags_doc3 = df_read.filter(pl.col("doc_id") == 3).select("tags").row(0)[0] + assert len(tags_doc3) == 0 + + # For doc_id = 4 (null list should also be an empty list) + tags_doc4 = df_read.filter(pl.col("doc_id") == 4).select("tags").row(0)[0] + assert tags_doc4 == [] + + +def test_nested_structs(in_memory_catalog): + """Tests writing and reading a table with nested structs.""" + table_name = "test_nested_structs_table" + # Define a pyarrow schema with nested structs + # config: struct> + settings_struct_type = pa.struct( + [ + pa.field("retries", pa.int8(), nullable=False), + pa.field("timeout", pa.int32(), nullable=True), + pa.field("active", pa.bool_(), nullable=False), + ] + ) + pa_schema = pa.schema( + [ + pa.field("record_id", pa.string(), nullable=False), + pa.field( + "config", + pa.struct( + [ + pa.field("name", pa.string(), nullable=True), + pa.field( + "details", settings_struct_type, nullable=True + ), # This inner struct can be null + ] + ), + nullable=True, + ), # The outer 'config' struct can also be null + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + { + "record_id": "rec1", + "config": { + "name": "Default", + "details": {"retries": 3, "timeout": 
1000, "active": True}, + }, + }, + { + "record_id": "rec2", + "config": { + "name": "Fast", + "details": {"retries": 1, "timeout": None, "active": True}, + }, + }, # Null timeout + { + "record_id": "rec3", + "config": {"name": "Inactive", "details": None}, + }, # Null inner struct + {"record_id": "rec4", "config": None}, # Null outer struct + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + # Read back and verify + df_read = table.to_polars().collect() + + assert df_read.shape[0] == 4 + assert df_read["record_id"].to_list() == ["rec1", "rec2", "rec3", "rec4"] + + # Verify nested data for rec1 + config_rec1 = ( + df_read.filter(pl.col("record_id") == "rec1").select("config").row(0)[0] + ) + assert config_rec1["name"] == "Default" + details_rec1 = config_rec1["details"] + assert details_rec1["retries"] == 3 + assert details_rec1["timeout"] == 1000 + assert details_rec1["active"] is True + + # Verify nested data for rec2 (null timeout) + config_rec2 = ( + df_read.filter(pl.col("record_id") == "rec2").select("config").row(0)[0] + ) + assert config_rec2["name"] == "Fast" + details_rec2 = config_rec2["details"] + assert details_rec2["retries"] == 1 + assert details_rec2["timeout"] is None + assert details_rec2["active"] is True + + # Verify nested data for rec3 (null inner struct 'details') + config_rec3 = ( + df_read.filter(pl.col("record_id") == "rec3").select("config").row(0)[0] + ) + assert config_rec3["name"] == "Inactive" + assert config_rec3["details"] is None # The 'details' struct itself is null + + # Verify nested data for rec4 (null outer struct 'config') + config_rec4 = ( + df_read.filter(pl.col("record_id") == "rec4").select("config").row(0)[0] + ) + assert config_rec4 is None # The 'config' struct is null + + +def test_list_of_primitive_types(in_memory_catalog): + """Tests writing and reading a list of primitive types.""" + table_name = "test_list_of_primitives_table" + pa_schema = pa.schema( + [ + pa.field("event_id", pa.int32(), nullable=False), + pa.field( + "scores", + pa.list_(pa.field("score", pa.float32(), nullable=False)), + nullable=True, + ), # List of non-nullable floats + pa.field( + "keywords", + pa.list_(pa.field("keyword", pa.string(), nullable=True)), + nullable=True, + ), # List of nullable strings + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + data_to_write = [ + {"event_id": 1, "scores": [1.0, 2.5, 3.0], "keywords": ["alpha", "beta", None]}, + {"event_id": 2, "scores": [], "keywords": ["gamma"]}, + {"event_id": 3, "scores": None, "keywords": None}, + {"event_id": 4, "scores": [4.2], "keywords": []}, + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + df_read = table.to_polars().collect() + + assert df_read.shape[0] == 4 + + # Event 1 + row1 = df_read.filter(pl.col("event_id") == 1) + assert row1.select("scores").to_series()[0].to_list() == [1.0, 2.5, 3.0] + assert row1.select("keywords").to_series()[0].to_list() == ["alpha", "beta", None] + + # Event 2 + row2 = df_read.filter(pl.col("event_id") == 2) + assert row2.select("scores").to_series()[0].to_list() == [] + assert 
row2.select("keywords").to_series()[0].to_list() == ["gamma"] + + # Event 3 + row3 = df_read.filter(pl.col("event_id") == 3) + assert row3.select("scores").to_series()[0] is None + assert row3.select("keywords").to_series()[0] is None + + +def test_map_type_simple(in_memory_catalog): + """Tests writing and reading a simple map type.""" + table_name = "test_map_type_simple_table" + # Map from string to string. Keys are non-nullable, values can be nullable. + pa_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field( + "properties", + pa.map_(pa.string(), pa.string(), keys_sorted=False), + nullable=True, + ), + # Note: PyArrow map values are nullable by default if item_field is not specified with nullable=False + ] + ) + + ref = tower.tables(table_name, catalog=in_memory_catalog) + table = ref.create_if_not_exists(pa_schema) + assert table is not None, f"Table '{table_name}' should have been created" + + # PyArrow represents maps as a list of structs with 'key' and 'value' fields + data_to_write = [ + {"id": 1, "properties": [("color", "blue"), ("size", "large")]}, + { + "id": 2, + "properties": [("status", "pending"), ("owner", None)], + }, # Null value in map + {"id": 3, "properties": []}, # Empty map + {"id": 4, "properties": None}, # Null map field + ] + arrow_table_write = pa.Table.from_pylist(data_to_write, schema=pa_schema) + + op_result = table.insert(arrow_table_write) + assert op_result is not None + assert op_result.rows_affected().inserts == 4 + + df_read = table.to_polars().collect() + assert df_read.shape[0] == 4 + + # Verify map data + # Polars represents map as list of structs: struct + # Row 1 + props1_series = df_read.filter(pl.col("id") == 1).select("properties").to_series() + # The series item is already a list of dictionaries + props1_list = props1_series[0] + expected_props1 = [ + {"key": "color", "value": "blue"}, + {"key": "size", "value": "large"}, + ] + # Sort by key for consistent comparison if order is not guaranteed + assert sorted(props1_list, key=lambda x: x["key"]) == sorted( + expected_props1, key=lambda x: x["key"] + ) + + # Row 2 + props2_series = df_read.filter(pl.col("id") == 2).select("properties").to_series() + props2_list = props2_series[0] + expected_props2 = [ + {"key": "status", "value": "pending"}, + {"key": "owner", "value": None}, + ] + assert sorted(props2_list, key=lambda x: x["key"]) == sorted( + expected_props2, key=lambda x: x["key"] + ) + + # Row 3 (empty map) + props3_series = df_read.filter(pl.col("id") == 3).select("properties").to_series() + assert props3_series[0].to_list() == [] + + # Row 4 (null map) + props4_series = df_read.filter(pl.col("id") == 4).select("properties").to_series() + assert props4_series[0] is None From e676f259f53ad2009753951aef6feb6a7b61f15b Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Tue, 13 May 2025 12:46:56 +0100 Subject: [PATCH 08/15] chore: Update `wait_for_run` and `wait_for_runs` implementations Updated implementations ensure that we equally check runs to detect failures part way through executions. Likewise, we add timeouts while talking to the Tower API in case there are some operational problems on that side of things. 
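
As a rough illustration of how the reworked helpers are meant to be used
(this is a sketch, not code from this change; the app slug and the
one-hour timeout below are made up for the example):

    import tower
    from tower.exceptions import RunFailedError, TimeoutException

    # Start a couple of runs of the same app in different environments.
    runs = [
        tower.run_app("my-app", environment="production"),
        tower.run_app("my-app", environment="default"),
    ]

    try:
        # Polls the Tower API until every run reaches a terminal state,
        # giving up after an hour instead of the one-day default.
        successful_runs, failed_runs = tower.wait_for_runs(runs, timeout=3_600.0)
    except TimeoutException:
        print("Runs did not finish within an hour")

    # Or wait on a single run and raise as soon as it fails.
    try:
        run = tower.wait_for_run(runs[0], raise_on_failure=True)
    except RunFailedError as err:
        print(f"Run failed: {err}")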
---
 src/tower/_client.py                    | 313 ++++++++++------
 src/tower/{_errors.py => exceptions.py} |   0
 tests/tower/test_client.py              | 458 ++++++++++++++----------
 3 files changed, 473 insertions(+), 298 deletions(-)
 rename src/tower/{_errors.py => exceptions.py} (100%)

diff --git a/src/tower/_client.py b/src/tower/_client.py
index 658f4112..2530a39a 100644
--- a/src/tower/_client.py
+++ b/src/tower/_client.py
@@ -3,7 +3,7 @@
 from typing import List, Dict, Optional
 
 from ._context import TowerContext
-from ._errors import (
+from .exceptions import (
     NotFoundException,
     UnauthorizedException,
     UnknownException,
@@ -36,23 +36,9 @@
 # app somewhere.
 DEFAULT_TOWER_ENVIRONMENT = "default"
 
-
-def _env_client(ctx: TowerContext) -> AuthenticatedClient:
-    tower_url = ctx.tower_url
-
-    if not tower_url.endswith("/v1"):
-        if tower_url.endswith("/"):
-            tower_url += "v1"
-        else:
-            tower_url += "/v1"
-
-    return AuthenticatedClient(
-        verify_ssl=False,
-        base_url=tower_url,
-        token=ctx.api_key,
-        auth_header_name="X-API-Key",
-        prefix="",
-    )
+# DEFAULT_RETIRES_ON_FAILURE is the number of times to retry querying the Tower
+# API before we just give up entirely.
+DEFAULT_RETIRES_ON_FAILURE = 5
 
 
 def run_app(
@@ -110,6 +96,157 @@ def run_app(
     return output.run
 
 
+def wait_for_run(
+    run: Run,
+    timeout: Optional[float] = 86_400.0,  # one day
+    raise_on_failure: bool = False,
+) -> Run:
+    """
+    Wait for a Tower app run to reach a terminal state by polling the Tower API.
+
+    This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT)
+    to check the status of the specified run. The function returns when the run reaches
+    a terminal state (exited, errored, cancelled, or crashed).
+
+    Args:
+        run (Run): The Run object containing the app_slug and number of the run to monitor.
+        timeout (Optional[float]): Maximum time to wait in seconds before raising a
+            TimeoutException. Defaults to one day (86,400 seconds).
+        raise_on_failure (bool): If True, raises a RunFailedError when the run fails.
+            If False, returns the failed run object. Defaults to False.
+
+    Returns:
+        Run: The final state of the run after completion or failure.
+
+    Raises:
+        TimeoutException: If the specified timeout is reached before the run completes.
+        RunFailedError: If raise_on_failure is True and the run fails.
+        UnhandledRunStateException: If the run enters an unexpected state.
+        UnknownException: If there are persistent problems communicating with the Tower API.
+        NotFoundException: If the run cannot be found.
+        UnauthorizedException: If the API key is invalid or unauthorized.
+    """
+    ctx = TowerContext.build()
+    retries = 0
+
+    # We use this to track the timeout, if one is defined.
+    start_time = time.time()
+
+    while True:
+        # We check for a timeout at the top of the loop because we want to
+        # avoid waiting unnecessarily for the timeout hitting the Tower API if
+        # we've encountered some sort of operational problem there.
+        if timeout is not None:
+            if _time_since(start_time) > timeout:
+                raise TimeoutException(t)
+
+        # We time this out to avoid waiting forever on the API.
+        try:
+            desc = _check_run_status(ctx, run, timeout=2.0)
+            retries = 0
+
+            if _is_successful_run(desc):
+                return desc
+            elif _is_failed_run(desc):
+                if raise_on_failure:
+                    raise RunFailedError(desc.app_slug, desc.number, desc.status)
+                else:
+                    return desc
+
+            elif _is_run_awaiting_completion(desc):
+                time.sleep(WAIT_TIMEOUT)
+            else:
+                raise UnhandledRunStateException(desc.status)
+        except TimeoutException:
+            # timed out in the API, we want to keep trying this for a while
+            # (assuming we didn't hit the global timeout limit) until we give
+            # up entirely.
+            retries += 1
+
+            if retries >= DEFAULT_RETRIES_ON_FAILURE:
+                raise UnknownException("There was a problem with the Tower API.")
+
+
+def wait_for_runs(
+    runs: List[Run],
+    timeout: Optional[float] = 86_400.0,  # one day
+    raise_on_failure: bool = False,
+) -> tuple[List[Run], List[Run]]:
+    """
+    Wait for multiple Tower app runs to reach terminal states by polling the Tower API.
+
+    This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT)
+    to check the status of all specified runs. The function returns when all runs reach
+    terminal states (`exited`, `errored`, `cancelled`, or `crashed`).
+
+    Args:
+        runs (List[Run]): A list of Run objects to monitor.
+        timeout (Optional[float]): Maximum time to wait in seconds before raising a
+            TimeoutException. Defaults to one day (86,400 seconds).
+        raise_on_failure (bool): If True, raises a RunFailedError when any run fails.
+            If False, failed runs are returned in the failed_runs list. Defaults to False.
+
+    Returns:
+        tuple[List[Run], List[Run]]: A tuple containing two lists:
+            - successful_runs: List of runs that completed successfully (status: 'exited')
+            - failed_runs: List of runs that failed (status: 'crashed', 'cancelled', or 'errored')
+
+    Raises:
+        TimeoutException: If the specified timeout is reached before all runs complete.
+        RunFailedError: If raise_on_failure is True and any run fails.
+        UnhandledRunStateException: If a run enters an unexpected state.
+        UnknownException: If there are persistent problems communicating with the Tower API.
+        NotFoundException: If any run cannot be found.
+        UnauthorizedException: If the API key is invalid or unauthorized.
+    """
+    ctx = TowerContext.build()
+    retries = 0
+
+    # We use this to track the timeout, if one is defined.
+    start_time = time.time()
+
+    awaiting_runs = runs
+    successful_runs = []
+    failed_runs = []
+
+    while awaiting_runs:
+        for run in awaiting_runs:
+            # Check the overall timeout at the top of the loop in case we've
+            # spent a load of time deeper inside the loop on retries, etc.
+            if timeout is not None:
+                if _time_since(start_time) > timeout:
+                    raise TimeoutException(t)
+
+            try:
+                desc = _check_run_status(ctx, run, timeout=2.0)
+                retries = 0
+
+                if _is_successful_run(desc):
+                    successful_runs.append(desc)
+                    awaiting_runs.remove(run)
+                elif _is_failed_run(desc):
+                    if raise_on_failure:
+                        raise RunFailedError(desc.app_slug, desc.number, desc.status)
+                    else:
+                        failed_runs.append(desc)
+                        awaiting_runs.remove(run)
+
+                elif _is_run_awaiting_completion(desc):
+                    time.sleep(WAIT_TIMEOUT)
+                else:
+                    raise UnhandledRunStateException(desc.status)
+            except TimeoutException:
+                # timed out in the API, we want to keep trying this for a while
+                # (assuming we didn't hit the global timeout limit) until we give
+                # up entirely.
+ retries += 1 + + if retries >= DEFAULT_RETRIES_ON_FAILURE: + raise UnknownException("There was a problem with the Tower API.") + + return (successful_runs, failed_runs) + + def _is_failed_run(run: Run) -> bool: """ Check if the given run has failed. @@ -149,39 +286,37 @@ def _is_run_awaiting_completion(run: Run) -> bool: return run.status in ["pending", "scheduled", "running"] -def wait_for_run( - run: Run, - timeout: Optional[float] = 86_400.0, # one day - raise_on_failure: bool = False, -) -> Run: - """ - Wait for a Tower app run to reach a terminal state by polling the Tower API. +def _env_client(ctx: TowerContext, timeout: Optional[float] = None) -> AuthenticatedClient: + tower_url = ctx.tower_url - This function continuously polls the Tower API every 2 seconds (defined by WAIT_TIMEOUT) - to check the status of the specified run. The function returns when the run reaches - any of the defined terminal states. + if not tower_url.endswith("/v1"): + if tower_url.endswith("/"): + tower_url += "v1" + else: + tower_url += "/v1" - Args: - run (Run): The Run object containing the app_slug and number of the run to monitor. - timeout (Optional[float]): An optional timeout for this wait. Defaults - to one day (86,000 seconds). - raise_on_failure (bool): Whether to raise an exception when a failure - occurs. Defaults to False. + return AuthenticatedClient( + verify_ssl=False, + base_url=tower_url, + token=ctx.api_key, + auth_header_name="X-API-Key", + prefix="", + timeout=timeout, + ) - Returns: - None: This function does not return any value. - Raises: - RuntimeError: If there is an error fetching the run status from the Tower API - or if the API returns an error response. - """ - ctx = TowerContext.build() - client = _env_client(ctx) +def _time_since(start_time: float) -> float: + return time.time() - start_time - # We use this to track the timeout, if one is defined. - start_time = time.time() - while True: +def _check_run_status( + ctx: TowerContext, + run: Run, + timeout: Optional[float] = 2.0, # one day +) -> Run: + client = _env_client(ctx, timeout=timeout) + + try: output: Optional[Union[DescribeRunResponse, ErrorModel]] = describe_run_api.sync( slug=run.app_slug, seq=run.number, @@ -189,73 +324,23 @@ def wait_for_run( ) if output is None: - raise UnknownException("Error fetching run") - else: - if isinstance(output, ErrorModel): - # If it was a 404 error, that means that we couldn't find this - # app for some reason. This is really only relevant on the - # first time that we check--if we could find the run, but then - # suddenly couldn't that's a really big problem I'd say. - if output.status == 404: - raise NotFoundException(output.detail) - elif output.status == 401: - # NOTE: Most of the time, this shouldn't happen? - raise UnauthorizedException(output.detail) - else: - raise UnknownException(output.detail) + raise UnknownException("Failed to fetch run") + elif isinstance(output, ErrorModel): + # If it was a 404 error, that means that we couldn't find this + # app for some reason. This is really only relevant on the + # first time that we check--if we could find the run, but then + # suddenly couldn't that's a really big problem I'd say. + if output.status == 404: + raise NotFoundException(output.detail) + elif output.status == 401: + # NOTE: Most of the time, this shouldn't happen? 
+ raise UnauthorizedException(output.detail) else: - desc = output.run - - if _is_successful_run(desc): - return True - elif _is_failed_run(desc): - if raise_on_failure: - raise RunFailedError(desc.app_slug, desc.number) - else: - return False - - elif _is_run_awaiting_completion(desc): - time.sleep(WAIT_TIMEOUT) - else: - raise UnhandledRunStateException(desc.status) - - # Before we head back to the top of the loop, let's see if we - # should timeout - if timeout is not None: - # The user defined a timeout, so let's actually see if we - # reached it. - t = time.time() - start_time - if t > timeout: - raise TimeoutException(t) - - -def wait_for_runs( - runs: List[Run], - timeout: Optional[float] = 86_400.0, # one day - raise_on_failure: bool = False, -) -> tuple[List[Run], List[Run]]: - """ - `wait_for_runs` waits for a list of runs to reach a terminal state by - polling the Tower API every 2 seconds for the latest status. If any of the - runs return a terminal status (`exited`, `errored`, `cancelled`, or - `crashed`) then this function returns. - - Args: - runs (List[Run]): A list of Run objects to monitor. - timeout (Optional[float]): Timeout to wait. - raise_on_failure (bool): If true, raises an exception when - any one of the awaited runs fails. Defaults to False. - - Returns: - None: This function does not return any value. - - Raises: - RuntimeError: If there is an error fetching the run status or if any - of the runs fail. - """ - for run in runs: - wait_for_run( - run, - timeout=timeout, - raise_on_failure=raise_on_failure, - ) + raise UnknownException(output.detail) + else: + # There was a run object, so let's return that. + return output.run + except httpx.TimeoutException: + # If we received a timeout from the API then we should raise our own + # timeout type. 
+ raise TimeoutException("Timeout while waiting for run status") diff --git a/src/tower/_errors.py b/src/tower/exceptions.py similarity index 100% rename from src/tower/_errors.py rename to src/tower/exceptions.py diff --git a/tests/tower/test_client.py b/tests/tower/test_client.py index 156ab2bd..8c7d1ea2 100644 --- a/tests/tower/test_client.py +++ b/tests/tower/test_client.py @@ -1,238 +1,328 @@ - import os -import httpx import pytest +from datetime import datetime +from typing import List, Dict, Any, Optional -from tower.tower_api_client.models import ( - Run, -) +from tower.tower_api_client.models import Run +from tower.exceptions import RunFailedError -def test_running_apps(httpx_mock): - # Mock the response from the API - httpx_mock.add_response( - method="POST", - url="https://api.example.com/v1/apps/my-app/runs", - json={ + +@pytest.fixture +def mock_api_config(): + """Configure the Tower API client to use mock server.""" + os.environ["TOWER_URL"] = "https://api.example.com" + os.environ["TOWER_API_KEY"] = "abc123" + + # Only import after environment is configured + import tower + # Set WAIT_TIMEOUT to 0 to avoid actual waiting in tests + tower._client.WAIT_TIMEOUT = 0 + + return tower + + +@pytest.fixture +def mock_run_response_factory(): + """Factory to create consistent run response objects.""" + def _create_run_response( + app_slug: str = "my-app", + app_version: str = "v6", + number: int = 0, + run_id: str = "50ac9bc1-c783-4359-9917-a706f20dc02c", + status: str = "pending", + status_group: str = "", + parameters: Optional[List[Dict[str, Any]]] = None + ) -> Dict[str, Any]: + """Create a mock run response with the given parameters.""" + if parameters is None: + parameters = [] + + return { "run": { - "app_slug": "my-app", - "app_version": "v6", + "app_slug": app_slug, + "app_version": app_version, "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 0, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", + "created_at": "2025-04-25T20:54:58.762547Z", + "ended_at": "2025-04-25T20:55:35.220295Z", + "environment": "default", + "number": number, + "run_id": run_id, "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "pending", - "status_group": "", - "parameters": [] + "started_at": "2025-04-25T20:54:59.366937Z", + "status": status, + "status_group": status_group, + "parameters": parameters } - }, + } + + return _create_run_response + + +@pytest.fixture +def create_run_object(): + """Factory to create Run objects for testing.""" + def _create_run( + app_slug: str = "my-app", + app_version: str = "v6", + number: int = 0, + run_id: str = "50ac9bc1-c783-4359-9917-a706f20dc02c", + status: str = "running", + status_group: str = "failed", + parameters: Optional[List[Dict[str, Any]]] = None + ) -> Run: + """Create a Run object with the given parameters.""" + if parameters is None: + parameters = [] + + return Run( + app_slug=app_slug, + app_version=app_version, + cancelled_at=None, + created_at="2025-04-25T20:54:58.762547Z", + ended_at="2025-04-25T20:55:35.220295Z", + environment="default", + number=number, + run_id=run_id, + scheduled_at="2025-04-25T20:54:58.761867Z", + started_at="2025-04-25T20:54:59.366937Z", + status=status, + status_group=status_group, + parameters=parameters + ) + + return _create_run + + +def test_running_apps(httpx_mock, mock_api_config, mock_run_response_factory): + # Mock the response from the API + 
httpx_mock.add_response( + method="POST", + url="https://api.example.com/v1/apps/my-app/runs", + json=mock_run_response_factory(), status_code=200, ) - # We tell the client to use the mock server. - os.environ["TOWER_URL"] = "https://api.example.com" - os.environ["TOWER_API_KEY"] = "abc123" - # Call the function that makes the API request - import tower + tower = mock_api_config run: Run = tower.run_app("my-app", environment="production") # Assert the response assert run is not None + assert run.app_slug == "my-app" + assert run.status == "pending" -def test_waiting_for_a_run(httpx_mock): - # Mock the response from the API + +def test_waiting_for_a_run(httpx_mock, mock_api_config, mock_run_response_factory, create_run_object): + run_number = 3 + + # First response: pending status httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "pending", - "status_group": "", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="pending"), status_code=200, ) - # Second request, will indicate that it's done. + # Second response: completed status httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "exited", - "status_group": "successful", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="exited", status_group="successful"), status_code=200, ) - # We tell the client to use the mock server. - os.environ["TOWER_URL"] = "https://api.example.com" - os.environ["TOWER_API_KEY"] = "abc123" + tower = mock_api_config + run = create_run_object(number=run_number, status="crashed") - import tower + # Now actually wait for the run + final_run = tower.wait_for_run(run) + + # Verify the final state + assert final_run.status == "exited" + assert final_run.status_group == "successful" - run = Run( - app_slug="my-app", - app_version="v6", - cancelled_at=None, - created_at="2025-04-25T20:54:58.762547Z", - ended_at="2025-04-25T20:55:35.220295Z", - environment="default", - number=3, - run_id="50ac9bc1-c783-4359-9917-a706f20dc02c", - scheduled_at="2025-04-25T20:54:58.761867Z", - started_at="2025-04-25T20:54:59.366937Z", - status="crashed", - status_group="failed", - parameters=[] - ) - # Set WAIT_TIMEOUT to 0 so we don't have to...wait. 
- tower._client.WAIT_TIMEOUT = 0 +@pytest.mark.parametrize("run_numbers", [(3, 4)]) +def test_waiting_for_multiple_runs( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object, + run_numbers +): + tower = mock_api_config + runs = [] + + # Setup mocks for each run + for run_number in run_numbers: + # First response: pending status + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="pending"), + status_code=200, + ) - # Now actually wait for the run. - tower.wait_for_run(run) + # Second response: completed status + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/{run_number}", + json=mock_run_response_factory(number=run_number, status="exited", status_group="successful"), + status_code=200, + ) + + # Create the Run object + runs.append(create_run_object(number=run_number)) + + # Now actually wait for the runs + successful_runs, failed_runs = tower.wait_for_runs(runs) -def test_waiting_for_multiple_runs(httpx_mock): - # Mock the response from the API + assert len(failed_runs) == 0 + + # Verify all runs completed successfully + for run in successful_runs: + assert run.status == "exited" + assert run.status_group == "successful" + + +def test_failed_runs_in_the_list( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object +): + tower = mock_api_config + runs = [] + + # For the first run, we're going to simulate a success. httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "pending", - "status_group": "", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="pending"), status_code=200, ) - # Second request, will indicate that it's done. httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/3", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "exited", - "status_group": "successful", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=1)) + + # Second run will have been a failure. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="pending"), status_code=200, ) - # Second request, will indicate that it's done. 
httpx_mock.add_response( method="GET", - url="https://api.example.com/v1/apps/my-app/runs/4", - json={ - "run": { - "app_slug": "my-app", - "app_version": "v6", - "cancelled_at": None, - "created_at": "2025-04-25T20:54:58.762547Z", - "ended_at": "2025-04-25T20:55:35.220295Z", - "environment": "default", - "number": 3, - "run_id": "50ac9bc1-c783-4359-9917-a706f20dc02c", - "scheduled_at": "2025-04-25T20:54:58.761867Z", - "started_at": "2025-04-25T20:54:59.366937Z", - "status": "exited", - "status_group": "successful", - "parameters": [] - } - }, + url=f"https://api.example.com/v1/apps/my-app/runs/2", + json=mock_run_response_factory(number=2, status="crashed", status_group="failed"), status_code=200, ) + + runs.append(create_run_object(number=2)) - # We tell the client to use the mock server. - os.environ["TOWER_URL"] = "https://api.example.com" - os.environ["TOWER_API_KEY"] = "abc123" + # Third run was a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/3", + json=mock_run_response_factory(number=3, status="pending"), + status_code=200, + ) - import tower + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/3", + json=mock_run_response_factory(number=3, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=3)) + + + # Now actually wait for the runs + successful_runs, failed_runs = tower.wait_for_runs(runs) + + assert len(failed_runs) == 1 + + # Verify all successful runs + for run in successful_runs: + assert run.status == "exited" + assert run.status_group == "successful" + + # Verify all failed + for run in failed_runs: + assert run.status == "crashed" + assert run.status_group == "failed" + + +def test_raising_an_error_during_partial_failure( + httpx_mock, + mock_api_config, + mock_run_response_factory, + create_run_object +): + tower = mock_api_config + runs = [] + + # For the first run, we're going to simulate a success. + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="pending"), + status_code=200, + ) + + httpx_mock.add_response( + method="GET", + url=f"https://api.example.com/v1/apps/my-app/runs/1", + json=mock_run_response_factory(number=1, status="exited", status_group="successful"), + status_code=200, + ) + + runs.append(create_run_object(number=1)) + + # Second run will have been a failure. 
+    httpx_mock.add_response(
+        method="GET",
+        url=f"https://api.example.com/v1/apps/my-app/runs/2",
+        json=mock_run_response_factory(number=2, status="pending"),
+        status_code=200,
+    )
 
-    run1 = Run(
-        app_slug="my-app",
-        app_version="v6",
-        cancelled_at=None,
-        created_at="2025-04-25T20:54:58.762547Z",
-        ended_at="2025-04-25T20:55:35.220295Z",
-        environment="default",
-        number=3,
-        run_id="50ac9bc1-c783-4359-9917-a706f20dc02c",
-        scheduled_at="2025-04-25T20:54:58.761867Z",
-        started_at="2025-04-25T20:54:59.366937Z",
-        status="running",
-        status_group="failed",
-        parameters=[]
+    httpx_mock.add_response(
+        method="GET",
+        url=f"https://api.example.com/v1/apps/my-app/runs/2",
+        json=mock_run_response_factory(number=2, status="crashed", status_group="failed"),
         status_code=200,
     )
+
+    runs.append(create_run_object(number=2))
 
-    run2 = Run(
-        app_slug="my-app",
-        app_version="v6",
-        cancelled_at=None,
-        created_at="2025-04-25T20:54:58.762547Z",
-        ended_at="2025-04-25T20:55:35.220295Z",
-        environment="default",
-        number=4,
-        run_id="50ac9bc1-c783-4359-9917-a706f20dc02c",
-        scheduled_at="2025-04-25T20:54:58.761867Z",
-        started_at="2025-04-25T20:54:59.366937Z",
-        status="running",
-        status_group="failed",
-        parameters=[]
+    # Third run was a success.
+    httpx_mock.add_response(
+        method="GET",
+        url=f"https://api.example.com/v1/apps/my-app/runs/3",
+        json=mock_run_response_factory(number=3, status="pending"),
+        status_code=200,
     )
 
-    # Set WAIT_TIMEOUT to 0 so we don't have to...wait.
-    tower._client.WAIT_TIMEOUT = 0
+    httpx_mock.add_response(
+        method="GET",
+        url=f"https://api.example.com/v1/apps/my-app/runs/3",
+        json=mock_run_response_factory(number=3, status="exited", status_group="successful"),
+        status_code=200,
+    )
+
+    runs.append(create_run_object(number=3))
+
 
-    # Now actually wait for the run.
-    tower.wait_for_runs([run1, run2])
+    # Now actually wait for the runs
+    with pytest.raises(RunFailedError) as excinfo:
+        tower.wait_for_runs(runs, raise_on_failure=True)

From 0c0e01956b5adfbf32341f13d4994e1a4840b867 Mon Sep 17 00:00:00 2001
From: Brad Heller
Date: Tue, 13 May 2025 12:57:01 +0100
Subject: [PATCH 09/15] chore: Fix a few typos and some missing data

---
 src/tower/_client.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/tower/_client.py b/src/tower/_client.py
index 2530a39a..02d0bc99 100644
--- a/src/tower/_client.py
+++ b/src/tower/_client.py
@@ -36,9 +36,9 @@
 # app somewhere.
 DEFAULT_TOWER_ENVIRONMENT = "default"
 
-# DEFAULT_RETIRES_ON_FAILURE is the number of times to retry querying the Tower
+# DEFAULT_NUM_TIMEOUT_RETRIES is the number of times to retry querying the Tower
 # API before we just give up entirely.
-DEFAULT_RETIRES_ON_FAILURE = 5
+DEFAULT_NUM_TIMEOUT_RETRIES = 5
 
 
 def run_app(
@@ -138,7 +138,7 @@ def wait_for_run(
         # we've encountered some sort of operational problem there.
         if timeout is not None:
             if _time_since(start_time) > timeout:
-                raise TimeoutException(t)
+                raise TimeoutException(_time_since(start_time))
 
         # We time this out to avoid waiting forever on the API.
         try:
@@ -163,7 +163,7 @@ def wait_for_run(
             # up entirely.
             retries += 1
 
-            if retries >= DEFAULT_RETRIES_ON_FAILURE:
+            if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
                 raise UnknownException("There was a problem with the Tower API.")
 
 
@@ -215,7 +215,7 @@ def wait_for_runs(
             # spent a load of time deeper inside the loop on retries, etc.
             if timeout is not None:
                 if _time_since(start_time) > timeout:
-                    raise TimeoutException(t)
+                    raise TimeoutException(_time_since(start_time))
 
             try:
                 desc = _check_run_status(ctx, run, timeout=2.0)
@@ -241,7 +241,7 @@ def wait_for_runs(
                 # up entirely.
                 retries += 1
 
-                if retries >= DEFAULT_RETRIES_ON_FAILURE:
+                if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
                     raise UnknownException("There was a problem with the Tower API.")
 
     return (successful_runs, failed_runs)
@@ -343,4 +343,4 @@ def _check_run_status(
     except httpx.TimeoutException:
         # If we received a timeout from the API then we should raise our own
         # timeout type.
-        raise TimeoutException("Timeout while waiting for run status")
+        raise TimeoutException(timeout)

From 556f458b48b73893ac5203e796c632bf971f8ef8 Mon Sep 17 00:00:00 2001
From: Brad Heller
Date: Tue, 13 May 2025 12:58:30 +0100
Subject: [PATCH 10/15] chore: Import `httpx` when detecting timeouts

Thanks @copilot for the recommendation!
---
 src/tower/_client.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/tower/_client.py b/src/tower/_client.py
index 02d0bc99..c5940720 100644
--- a/src/tower/_client.py
+++ b/src/tower/_client.py
@@ -1,5 +1,6 @@
 import os
 import time
+import httpx
 
 from typing import List, Dict, Optional
 
 from ._context import TowerContext

From bc03e0ce901d384c9132f99b71b65129d93906b4 Mon Sep 17 00:00:00 2001
From: Brad Heller
Date: Tue, 13 May 2025 14:27:08 +0200
Subject: [PATCH 11/15] Update src/tower/exceptions.py

Thanks @copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 src/tower/exceptions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tower/exceptions.py b/src/tower/exceptions.py
index 09752dc6..e2a7295f 100644
--- a/src/tower/exceptions.py
+++ b/src/tower/exceptions.py
@@ -21,7 +21,7 @@ def __init__(self, state: str):
 
 class TimeoutException(Exception):
     def __init__(self, time: float):
-        super().__init__("A timeout occured after {time} seconds.")
+        super().__init__(f"A timeout occurred after {time} seconds.")
 
 
 class RunFailedError(RuntimeError):

From 3de504e0491a8d1fd98bb941d3a97d181ace320d Mon Sep 17 00:00:00 2001
From: Brad Heller
Date: Tue, 13 May 2025 13:34:08 +0100
Subject: [PATCH 12/15] chore: Don't modify a list while it's being iterated
 over

---
 src/tower/_client.py       | 73 +++++++++++++++++++++-----------------
 tests/tower/test_client.py | 10 ++----
 2 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/src/tower/_client.py b/src/tower/_client.py
index c5940720..1e371a85 100644
--- a/src/tower/_client.py
+++ b/src/tower/_client.py
@@ -210,40 +210,47 @@ def wait_for_runs(
     successful_runs = []
     failed_runs = []
 
-    while awaiting_runs:
-        for run in awaiting_runs:
-            # Check the overall timeout at the top of the loop in case we've
-            # spent a load of time deeper inside the loop on retries, etc.
-            if timeout is not None:
-                if _time_since(start_time) > timeout:
-                    raise TimeoutException(_time_since(start_time))
-
-            try:
-                desc = _check_run_status(ctx, run, timeout=2.0)
-                retries = 0
-
-                if _is_successful_run(desc):
-                    successful_runs.append(desc)
-                    awaiting_runs.remove(run)
-                elif _is_failed_run(desc):
-                    if raise_on_failure:
-                        raise RunFailedError(desc.app_slug, desc.number, desc.status)
-                    else:
-                        failed_runs.append(desc)
-                        awaiting_runs.remove(run)
-
-                elif _is_run_awaiting_completion(desc):
-                    time.sleep(WAIT_TIMEOUT)
+    while len(awaiting_runs) > 0:
+        run = awaiting_runs.pop(0)
+
+        # Check the overall timeout at the top of the loop in case we've
+        # spent a load of time deeper inside the loop on retries, etc.
+        if timeout is not None:
+            if _time_since(start_time) > timeout:
+                raise TimeoutException(_time_since(start_time))
+
+        try:
+            desc = _check_run_status(ctx, run, timeout=2.0)
+            retries = 0
+
+            if _is_successful_run(desc):
+                successful_runs.append(desc)
+            elif _is_failed_run(desc):
+                if raise_on_failure:
+                    raise RunFailedError(desc.app_slug, desc.number, desc.status)
                 else:
-                    raise UnhandledRunStateException(desc.status)
-            except TimeoutException:
-                # timed out in the API, we want to keep trying this for a while
-                # (assuming we didn't hit the global timeout limit) until we give
-                # up entirely.
-                retries += 1
-
-                if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
-                    raise UnknownException("There was a problem with the Tower API.")
+                    failed_runs.append(desc)
+
+            elif _is_run_awaiting_completion(desc):
+                time.sleep(WAIT_TIMEOUT)
+
+                # We need to re-add this run to the list so we check it again
+                # in the future. We add it to the back since we took it off the
+                # front, effectively moving to the next run.
+                awaiting_runs.append(run)
+            else:
+                raise UnhandledRunStateException(desc.status)
+        except TimeoutException:
+            # timed out in the API, we want to keep trying this for a while
+            # (assuming we didn't hit the global timeout limit) until we give
+            # up entirely.
+            retries += 1
+
+            if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
+                raise UnknownException("There was a problem with the Tower API.")
+            else:
+                # Add the item back on the list for retry later on.
+                awaiting_runs.append(run)
 
     return (successful_runs, failed_runs)
 
diff --git a/tests/tower/test_client.py b/tests/tower/test_client.py
index 8c7d1ea2..734a17d0 100644
--- a/tests/tower/test_client.py
+++ b/tests/tower/test_client.py
@@ -312,14 +312,10 @@ def test_raising_an_error_during_partial_failure(
         json=mock_run_response_factory(number=3, status="pending"),
         status_code=200,
     )
-
-    httpx_mock.add_response(
-        method="GET",
-        url=f"https://api.example.com/v1/apps/my-app/runs/3",
-        json=mock_run_response_factory(number=3, status="exited", status_group="successful"),
-        status_code=200,
-    )
+    # NOTE: We don't have a second response for this run because we'll never
+    # get to it.
+ runs.append(create_run_object(number=3)) From 840458dcfe03678c31118c1bc252a1b1dcb37cc2 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Tue, 13 May 2025 14:08:55 +0100 Subject: [PATCH 13/15] chore: Create v0.3.14-rc1 prerelease --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 3 ++- pyproject.toml | 3 ++- uv.lock | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9b864c92..a84617ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,7 +371,7 @@ dependencies = [ [[package]] name = "config" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "chrono", "clap", @@ -453,7 +453,7 @@ dependencies = [ [[package]] name = "crypto" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "aes-gcm", "base64", @@ -2580,7 +2580,7 @@ dependencies = [ [[package]] name = "testutils" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "pem", "rsa", @@ -2784,7 +2784,7 @@ dependencies = [ [[package]] name = "tower" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "tokio", "tower-api", @@ -2808,7 +2808,7 @@ dependencies = [ [[package]] name = "tower-api" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "log", "reqwest", @@ -2821,7 +2821,7 @@ dependencies = [ [[package]] name = "tower-cmd" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "anyhow", "bytes", @@ -2861,7 +2861,7 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-package" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "async-compression", "config", @@ -2880,7 +2880,7 @@ dependencies = [ [[package]] name = "tower-runtime" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "chrono", "log", @@ -2897,7 +2897,7 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tower-version" -version = "0.3.13" +version = "0.3.14-rc.1" dependencies = [ "anyhow", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 66a2529f..4da356e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,8 @@ resolver = "2" [workspace.package] edition = "2021" -version = "0.3.13" +version = "0.3.14-rc.1" + description = "Tower is the best way to host Python data apps in production" rust-version = "1.77" authors = ["Brad Heller "] diff --git a/pyproject.toml b/pyproject.toml index 7c7e3531..0041670a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ build-backend = "maturin" [project] name = "tower" -version = "0.3.13" +version = "0.3.14rc1" + description = "Tower CLI and runtime environment for Tower." authors = [{ name = "Tower Computing Inc.", email = "brad@tower.dev" }] readme = "README.md" diff --git a/uv.lock b/uv.lock index e8c5a210..2a126e04 100644 --- a/uv.lock +++ b/uv.lock @@ -1176,7 +1176,7 @@ wheels = [ [[package]] name = "tower" -version = "0.3.13" +version = "0.3.14rc1" source = { editable = "." } dependencies = [ { name = "attrs" }, From 3d4dac9cafd7cad66ff6842892a3d5d6405c2e36 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Tue, 13 May 2025 14:13:08 +0100 Subject: [PATCH 14/15] chore: Feedback from @copilot --- src/tower/_client.py | 2 +- tests/tower/test_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tower/_client.py b/src/tower/_client.py index 1e371a85..0d3a0b0b 100644 --- a/src/tower/_client.py +++ b/src/tower/_client.py @@ -237,7 +237,7 @@ def wait_for_runs( # We need to re-add this run to the list so we check it again # in the future. 
We add it to the back since we took it off the # front, effectively moving to the next run. - awaiting_runs.append(run) + awaiting_runs.append(desc) else: raise UnhandledRunStateException(desc.status) except TimeoutException: diff --git a/tests/tower/test_client.py b/tests/tower/test_client.py index 734a17d0..55174841 100644 --- a/tests/tower/test_client.py +++ b/tests/tower/test_client.py @@ -132,7 +132,7 @@ def test_waiting_for_a_run(httpx_mock, mock_api_config, mock_run_response_factor ) tower = mock_api_config - run = create_run_object(number=run_number, status="crashed") + run = create_run_object(number=run_number, status="pending") # Now actually wait for the run final_run = tower.wait_for_run(run) From f80ad331fe11e81d073ec4fd7b9e6c060d5399e4 Mon Sep 17 00:00:00 2001 From: Brad Heller Date: Tue, 13 May 2025 14:39:40 +0100 Subject: [PATCH 15/15] chore: Bump version to v0.3.14 --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 3 ++- pyproject.toml | 3 ++- uv.lock | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a84617ed..02eb08da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,7 +371,7 @@ dependencies = [ [[package]] name = "config" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "chrono", "clap", @@ -453,7 +453,7 @@ dependencies = [ [[package]] name = "crypto" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "aes-gcm", "base64", @@ -2580,7 +2580,7 @@ dependencies = [ [[package]] name = "testutils" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "pem", "rsa", @@ -2784,7 +2784,7 @@ dependencies = [ [[package]] name = "tower" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "tokio", "tower-api", @@ -2808,7 +2808,7 @@ dependencies = [ [[package]] name = "tower-api" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "log", "reqwest", @@ -2821,7 +2821,7 @@ dependencies = [ [[package]] name = "tower-cmd" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "anyhow", "bytes", @@ -2861,7 +2861,7 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-package" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "async-compression", "config", @@ -2880,7 +2880,7 @@ dependencies = [ [[package]] name = "tower-runtime" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "chrono", "log", @@ -2897,7 +2897,7 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tower-version" -version = "0.3.14-rc.1" +version = "0.3.14" dependencies = [ "anyhow", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 4da356e5..f28e5351 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,8 @@ resolver = "2" [workspace.package] edition = "2021" -version = "0.3.14-rc.1" +version = "0.3.14" + description = "Tower is the best way to host Python data apps in production" rust-version = "1.77" diff --git a/pyproject.toml b/pyproject.toml index 0041670a..13a48d6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ build-backend = "maturin" [project] name = "tower" -version = "0.3.14rc1" +version = "0.3.14" + description = "Tower CLI and runtime environment for Tower." authors = [{ name = "Tower Computing Inc.", email = "brad@tower.dev" }] diff --git a/uv.lock b/uv.lock index 2a126e04..5953f46c 100644 --- a/uv.lock +++ b/uv.lock @@ -1176,7 +1176,7 @@ wheels = [ [[package]] name = "tower" -version = "0.3.14rc1" +version = "0.3.14" source = { editable = "." 
} dependencies = [ { name = "attrs" },
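
Notes on the Python-side changes in this series, with illustrative sketches. None of the code below ships in the SDK.

PATCH 11/15 corrects two defects at once in TimeoutException.__init__: the spelling of "occurred" and a missing f-string prefix. Without the leading f, Python emits the braces literally instead of interpolating the value:

# Hypothetical demonstration of the bug PATCH 11/15 fixes.
time = 30.0
print("A timeout occured after {time} seconds.")    # prints the braces verbatim
print(f"A timeout occurred after {time} seconds.")  # prints: A timeout occurred after 30.0 seconds.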
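
PATCH 12/15 exists because removing elements from a list that a for loop is iterating over silently skips entries: each removal shifts the remaining elements left while the iterator keeps advancing by index. A minimal sketch of the failure mode and of the pop-from-the-front, append-to-the-back queue pattern that wait_for_runs now uses (buggy_drain and queue_drain are hypothetical names, not SDK functions):

def buggy_drain(items):
    # Mutating `items` while the for loop iterates it: every removal
    # shifts the remaining elements left, so the iterator skips one.
    seen = []
    for item in items:
        seen.append(item)
        items.remove(item)
    return seen

def queue_drain(items):
    # The pattern the patch adopts: treat the list as a queue and only
    # mutate it through pop/append, never during a for-loop iteration.
    seen = []
    while len(items) > 0:
        item = items.pop(0)
        seen.append(item)
        # A still-pending run would be re-queued here with items.append(item).
    return seen

print(buggy_drain([1, 2, 3, 4]))  # [1, 3] -- 2 and 4 were silently skipped
print(queue_drain([1, 2, 3, 4]))  # [1, 2, 3, 4]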
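
The same patch preserves the consecutive-timeout budget behind the DEFAULT_NUM_TIMEOUT_RETRIES rename: retries is reset to 0 after every successful status check, so only back-to-back API timeouts exhaust the budget. A sketch of that shape, using stand-in exception classes and an assumed budget of 5 rather than the SDK's real definitions:

class TimeoutException(Exception): ...
class UnknownException(Exception): ...

DEFAULT_NUM_TIMEOUT_RETRIES = 5  # assumed value, for illustration only

def check_with_budget(check_status, run):
    # Retry check_status until it succeeds or the timeout budget is spent.
    retries = 0
    while True:
        try:
            return check_status(run)  # a success ends the loop; the budget resets
        except TimeoutException:
            retries += 1  # only consecutive timeouts accumulate
            if retries >= DEFAULT_NUM_TIMEOUT_RETRIES:
                raise UnknownException("There was a problem with the Tower API.")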
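
Finally, PATCH 13/15 and PATCH 15/15 spell the same prerelease two ways: Cargo.toml uses the SemVer form 0.3.14-rc.1 while pyproject.toml uses the PEP 440 form 0.3.14rc1. PEP 440 treats the dashed spelling as an unnormalized alias of the compact one, and release candidates sort before the final release. A quick check with the third-party packaging library (assumed available; it is not a dependency shown in these diffs):

from packaging.version import Version

rc = Version("0.3.14rc1")
print(rc.is_prerelease)              # True
print(rc < Version("0.3.14"))        # True: rc1 sorts before the final release
print(Version("0.3.14-rc.1") == rc)  # True: PEP 440 normalizes the dashed form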