From f70bfda92f7380aa5d147511950593c5a4d20b57 Mon Sep 17 00:00:00 2001
From: Evan Hicks
Date: Tue, 4 Nov 2025 15:59:33 -0500
Subject: [PATCH 1/2] feat(v2): Postgres storage adapter

This adds a Postgres storage adapter for the taskbroker, as well as a way to
choose between the adapters in the configuration. This adapter will also work
with AlloyDB.

In Postgres the keyword `offset` is reserved, so the column is named
`kafka_offset` in the PG tables and aliased back to `offset` when read.

The tests were updated to run with both the SQLite and Postgres adapters using
the rstest crate. The `create_test_store` function was updated to be the
standard for all tests, and to allow choosing between a SQLite and a Postgres
DB. A `remove_db` function was added to the trait and the existing adapters,
since the tests create a unique PG database on every run that should be
cleaned up.
---
 Cargo.lock                                  |  49 ++
 Cargo.toml                                  |   3 +-
 Dockerfile                                  |   3 +-
 default_migrations/0001_create_database.sql |   1 +
 .../0001_create_inflight_activations.sql    |  20 +
 src/config.rs                               |  11 +
 src/grpc/server_tests.rs                    |  73 +-
 src/kafka/deserialize_activation.rs         |   2 +-
 src/kafka/inflight_activation_writer.rs     | 157 ++--
 src/main.rs                                 |  26 +-
 src/store/inflight_activation.rs            |  97 ++-
 src/store/inflight_activation_tests.rs      | 375 ++++++---
 src/store/mod.rs                            |   1 +
 src/store/postgres_activation_store.rs      | 732 ++++++++++++++++++
 src/test_utils.rs                           |  71 +-
 src/upkeep.rs                               | 127 +--
 16 files changed, 1431 insertions(+), 317 deletions(-)
 create mode 100644 default_migrations/0001_create_database.sql
 create mode 100644 pg_migrations/0001_create_inflight_activations.sql
 create mode 100644 src/store/postgres_activation_store.rs

diff --git a/Cargo.lock b/Cargo.lock
index b933e361..c93a4d1f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -833,6 +833,12 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.31"
@@ -890,6 +896,12 @@ version = "0.31.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
 
+[[package]]
+name = "glob"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
+
 [[package]]
 name = "h2"
 version = "0.4.12"
@@ -2133,6 +2145,12 @@ version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
 
+[[package]]
+name = "relative-path"
+version = "1.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
+
 [[package]]
 name = "reqwest"
 version = "0.12.23"
@@ -2191,6 +2209,36 @@ dependencies = [
 "zeroize",
 ]
 
+[[package]]
+name = "rstest"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035"
+dependencies = [
+ "futures",
+ "futures-timer",
+ "rstest_macros",
+ "rustc_version",
+] + +[[package]] +name = "rstest_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.26" @@ -2884,6 +2932,7 @@ dependencies = [ "prost-types", "rand 0.8.5", "rdkafka", + "rstest", "sentry", "sentry_protos", "serde", diff --git a/Cargo.toml b/Cargo.toml index d785ec99..5d928197 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ sentry_protos = "0.4.11" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" -sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono"] } +sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "postgres"] } tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", features = ["full"] } tokio-util = "0.7.12" @@ -61,6 +61,7 @@ uuid = { version = "1.11.0", features = ["v4"] } [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } +rstest = "0.23" [[bench]] name = "store_bench" diff --git a/Dockerfile b/Dockerfile index 125447c4..700da096 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,7 @@ # recent enough version of protobuf-compiler FROM rust:1-bookworm AS build -RUN apt-get update && apt-get upgrade -y +RUN apt-get update && apt-get upgrade -y RUN apt-get install -y cmake pkg-config libssl-dev librdkafka-dev protobuf-compiler RUN USER=root cargo new --bin taskbroker @@ -17,6 +17,7 @@ ENV TASKBROKER_VERSION=$TASKBROKER_GIT_REVISION COPY ./Cargo.lock ./Cargo.lock COPY ./Cargo.toml ./Cargo.toml COPY ./migrations ./migrations +COPY ./pg_migrations ./pg_migrations COPY ./benches ./benches # Build dependencies in a way they can be cached diff --git a/default_migrations/0001_create_database.sql b/default_migrations/0001_create_database.sql new file mode 100644 index 00000000..00d61748 --- /dev/null +++ b/default_migrations/0001_create_database.sql @@ -0,0 +1 @@ +CREATE DATABASE taskbroker; diff --git a/pg_migrations/0001_create_inflight_activations.sql b/pg_migrations/0001_create_inflight_activations.sql new file mode 100644 index 00000000..80b552db --- /dev/null +++ b/pg_migrations/0001_create_inflight_activations.sql @@ -0,0 +1,20 @@ +-- PostgreSQL equivalent of the inflight_taskactivations table +CREATE TABLE IF NOT EXISTS inflight_taskactivations ( + id TEXT NOT NULL PRIMARY KEY, + activation BYTEA NOT NULL, + partition INTEGER NOT NULL, + kafka_offset BIGINT NOT NULL, + added_at TIMESTAMPTZ NOT NULL, + received_at TIMESTAMPTZ NOT NULL, + processing_attempts INTEGER NOT NULL, + expires_at TIMESTAMPTZ, + delay_until TIMESTAMPTZ, + processing_deadline_duration INTEGER NOT NULL, + processing_deadline TIMESTAMPTZ, + status TEXT NOT NULL, + at_most_once BOOLEAN NOT NULL DEFAULT FALSE, + application TEXT NOT NULL DEFAULT '', + namespace TEXT NOT NULL, + taskname TEXT NOT NULL, + on_attempts_exceeded INTEGER NOT NULL DEFAULT 1 +); diff --git a/src/config.rs b/src/config.rs index d4bbaf2b..af1d2fc1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -121,6 +121,14 @@ pub struct Config { /// The number of ms for timeouts when publishing messages to kafka. pub kafka_send_timeout_ms: u64, + pub database_adapter: &'static str, + + /// The url of the postgres database to use for the inflight activation store. 
+ pub pg_url: String, + + /// The name of the postgres database to use for the inflight activation store. + pub pg_database_name: String, + /// The path to the sqlite database pub db_path: String, @@ -256,6 +264,9 @@ impl Default for Config { kafka_auto_offset_reset: "latest".to_owned(), kafka_send_timeout_ms: 500, db_path: "./taskbroker-inflight.sqlite".to_owned(), + database_adapter: "sqlite", + pg_url: "postgres://postgres:password@sentry-postgres-1:5432/".to_owned(), + pg_database_name: "taskbroker".to_owned(), db_write_failure_backoff_ms: 4000, db_insert_batch_max_len: 256, db_insert_batch_max_size: 16_000_000, diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index 1c9c9279..c72b1b50 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -1,6 +1,6 @@ use crate::grpc::server::TaskbrokerServer; -use crate::store::inflight_activation::InflightActivationStore; use prost::Message; +use rstest::rstest; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; use sentry_protos::taskbroker::v1::{ FetchNextTask, GetTaskRequest, SetTaskStatusRequest, TaskActivation, @@ -10,8 +10,11 @@ use tonic::{Code, Request}; use crate::test_utils::{create_test_store, make_activations}; #[tokio::test] -async fn test_get_task() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_task(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = GetTaskRequest { namespace: None, @@ -25,9 +28,12 @@ async fn test_get_task() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status() { - let store = create_test_store().await; +async fn test_set_task_status(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = SetTaskStatusRequest { id: "test_task".to_string(), @@ -41,9 +47,12 @@ async fn test_set_task_status() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_invalid() { - let store = create_test_store().await; +async fn test_set_task_status_invalid(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = SetTaskStatusRequest { id: "test_task".to_string(), @@ -61,9 +70,12 @@ async fn test_set_task_status_invalid() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_success() { - let store = create_test_store().await; +async fn test_get_task_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(1); store.store(activations).await.unwrap(); @@ -81,9 +93,12 @@ async fn test_get_task_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_with_application_success() { - let store = create_test_store().await; +async fn test_get_task_with_application_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,9 +123,12 @@ async fn test_get_task_with_application_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] 
+#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_with_namespace_requires_application() { - let store = create_test_store().await; +async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); let namespace = activations[0].namespace.clone(); @@ -129,9 +147,12 @@ async fn test_get_task_with_namespace_requires_application() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_success() { - let store = create_test_store().await; +async fn test_set_task_status_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); store.store(activations).await.unwrap(); @@ -157,6 +178,7 @@ async fn test_set_task_status_success() { }), }; let response = service.set_task_status(Request::new(request)).await; + println!("response: {:?}", response); assert!(response.is_ok()); let resp = response.unwrap(); assert!(resp.get_ref().task.is_some()); @@ -165,9 +187,12 @@ async fn test_set_task_status_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_application() { - let store = create_test_store().await; +async fn test_set_task_status_with_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -199,9 +224,12 @@ async fn test_set_task_status_with_application() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_application_no_match() { - let store = create_test_store().await; +async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -228,9 +256,12 @@ async fn test_set_task_status_with_application_no_match() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_namespace_requires_application() { - let store = create_test_store().await; +async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); let namespace = activations[0].namespace.clone(); diff --git a/src/kafka/deserialize_activation.rs b/src/kafka/deserialize_activation.rs index 891bc4f8..d0ea263e 100644 --- a/src/kafka/deserialize_activation.rs +++ b/src/kafka/deserialize_activation.rs @@ -87,7 +87,7 @@ pub fn new( added_at: Utc::now(), received_at: activation_time, processing_deadline: None, - processing_deadline_duration: activation.processing_deadline_duration as u32, + processing_deadline_duration: activation.processing_deadline_duration as i32, processing_attempts: 0, expires_at, delay_until, diff --git a/src/kafka/inflight_activation_writer.rs b/src/kafka/inflight_activation_writer.rs index ff48d5c6..d89bdf24 100644 --- a/src/kafka/inflight_activation_writer.rs +++ b/src/kafka/inflight_activation_writer.rs @@ -80,7 +80,6 @@ impl Reducer for InflightActivationWriter { self.batch.take(); return Ok(Some(())); } - 
// Check if writing the batch would exceed the limits let exceeded_pending_limit = self .store @@ -145,7 +144,6 @@ impl Reducer for InflightActivationWriter { "reason" => reason, ) .increment(1); - return Ok(None); } @@ -206,22 +204,23 @@ mod tests { use chrono::{DateTime, Utc}; use prost::Message; use prost_types::Timestamp; + use rstest::rstest; use std::collections::HashMap; + use crate::test_utils::create_test_store; use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use sentry_protos::taskbroker::v1::TaskActivation; - use std::sync::Arc; - use crate::store::inflight_activation::{ - InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, - SqliteActivationStore, - }; + use crate::store::inflight_activation::InflightActivationStatus; use crate::test_utils::generate_unique_namespace; use crate::test_utils::make_activations; - use crate::test_utils::{create_integration_config, generate_temp_filename}; #[tokio::test] - async fn test_writer_flush_batch() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -230,17 +229,7 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -333,29 +322,24 @@ mod tests { .await .unwrap(); assert_eq!(count_pending + count_delay, 2); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_only_pending() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_only_pending(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, max_pending_activations: 10, max_processing_activations: 10, - max_delay_activations: 0, + max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -402,10 +386,15 @@ mod tests { writer.flush().await.unwrap(); let count_pending = writer.store.count_pending_activations().await.unwrap(); assert_eq!(count_pending, 1); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_only_delay() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_only_delay(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -414,17 +403,7 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - 
.await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -475,10 +454,15 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 1); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_pending_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_pending_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -487,17 +471,7 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -591,10 +565,17 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 0); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_only_delay_limit_reached_and_entire_batch_is_pending() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_only_delay_limit_reached_and_entire_batch_is_pending( + #[case] adapter: &str, + ) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -603,17 +584,7 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -707,10 +678,15 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 0); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_processing_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_processing_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -719,14 +695,6 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let received_at = Timestamp { seconds: 0, @@ -866,10 +834,17 @@ mod tests { .unwrap(); // Only the existing processing activation should remain, new ones should be blocked assert_eq!(count_processing, 1); + // TODO: Because the store and the writer both access the DB, both need to be cleaned up. + // Uncomment this when we figure out how to do that cleanly. 
+ // writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_db_size_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_db_size_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { // 200 rows is ~50KB db_max_size: Some(50_000), @@ -879,14 +854,6 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let first_round = make_activations(200); store.store(first_round).await.unwrap(); assert!(store.db_size().await.unwrap() > 50_000); @@ -901,10 +868,15 @@ mod tests { let count_pending = writer.store.count_pending_activations().await.unwrap(); assert_eq!(count_pending, 200); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_empty_batch() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_empty_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -913,17 +885,10 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let mut writer = InflightActivationWriter::new(store.clone(), writer_config); writer.reduce(vec![]).await.unwrap(); let flush_result = writer.flush().await.unwrap(); assert!(flush_result.is_some()); + writer.store.remove_db().await.unwrap(); } } diff --git a/src/main.rs b/src/main.rs index 5f0eee52..c1221d1c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,9 @@ use taskbroker::runtime_config::RuntimeConfigManager; use taskbroker::store::inflight_activation::{ InflightActivationStore, InflightActivationStoreConfig, SqliteActivationStore, }; +use taskbroker::store::postgres_activation_store::{ + PostgresActivationStore, PostgresActivationStoreConfig, +}; use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; @@ -62,13 +65,21 @@ async fn main() -> Result<(), Error> { logging::init(logging::LoggingConfig::from_config(&config)); metrics::init(metrics::MetricsConfig::from_config(&config)); - let store: Arc = Arc::new( - SqliteActivationStore::new( - &config.db_path, - InflightActivationStoreConfig::from_config(&config), - ) - .await?, - ); + + let store: Arc = match config.database_adapter { + "sqlite" => Arc::new( + SqliteActivationStore::new( + &config.db_path, + InflightActivationStoreConfig::from_config(&config), + ) + .await?, + ), + "postgres" => Arc::new( + PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(&config)) + .await?, + ), + _ => panic!("Invalid database adapter: {}", config.database_adapter), + }; // If this is an environment where the topics might not exist, check and create them. 
if config.create_missing_topics { @@ -80,6 +91,7 @@ async fn main() -> Result<(), Error> { ) .await?; } + if config.full_vacuum_on_start { info!("Running full vacuum on database"); match store.full_vacuum_db().await { diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 3668bc40..7d0034ab 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -1,6 +1,10 @@ +use anyhow::{Error, anyhow}; +use sqlx::postgres::PgQueryResult; +use std::fmt::Result as FmtResult; +use std::fmt::{Display, Formatter}; use std::{str::FromStr, time::Instant}; -use anyhow::{Error, anyhow}; +use crate::config::Config; use async_trait::async_trait; use chrono::{DateTime, Utc}; use libsqlite3_sys::{ @@ -23,8 +27,6 @@ use sqlx::{ }; use tracing::{instrument, warn}; -use crate::config::Config; - /// The members of this enum should be synced with the members /// of InflightActivationStatus in sentry_protos #[derive(Clone, Copy, Debug, PartialEq, Eq, Type)] @@ -39,6 +41,36 @@ pub enum InflightActivationStatus { Delay, } +impl Display for InflightActivationStatus { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{:?}", self) + } +} + +impl FromStr for InflightActivationStatus { + type Err = String; + + fn from_str(s: &str) -> Result { + if s == "Unspecified" { + Ok(InflightActivationStatus::Unspecified) + } else if s == "Pending" { + Ok(InflightActivationStatus::Pending) + } else if s == "Processing" { + Ok(InflightActivationStatus::Processing) + } else if s == "Failure" { + Ok(InflightActivationStatus::Failure) + } else if s == "Retry" { + Ok(InflightActivationStatus::Retry) + } else if s == "Complete" { + Ok(InflightActivationStatus::Complete) + } else if s == "Delay" { + Ok(InflightActivationStatus::Delay) + } else { + Err(format!("Unknown inflight activation status string: {}", s)) + } + } +} + impl InflightActivationStatus { /// Is the current value a 'conclusion' status that can be supplied over GRPC. pub fn is_conclusion(&self) -> bool { @@ -93,7 +125,7 @@ pub struct InflightActivation { /// The duration in seconds that a worker has to complete task execution. /// When an activation is moved from pending -> processing a result is expected /// in this many seconds. 
- pub processing_deadline_duration: u32, + pub processing_deadline_duration: i32, /// If the task has specified an expiry, this is the timestamp after which the task should be removed from inflight store pub expires_at: Option>, @@ -145,31 +177,39 @@ impl From for QueryResult { } } +impl From for QueryResult { + fn from(value: PgQueryResult) -> Self { + Self { + rows_affected: value.rows_affected(), + } + } +} + pub struct FailedTasksForwarder { pub to_discard: Vec<(String, Vec)>, pub to_deadletter: Vec<(String, Vec)>, } #[derive(Debug, FromRow)] -struct TableRow { - id: String, - activation: Vec, - partition: i32, - offset: i64, - added_at: DateTime, - received_at: DateTime, - processing_attempts: i32, - expires_at: Option>, - delay_until: Option>, - processing_deadline_duration: u32, - processing_deadline: Option>, - status: InflightActivationStatus, - at_most_once: bool, - application: String, - namespace: String, - taskname: String, +pub struct TableRow { + pub id: String, + pub activation: Vec, + pub partition: i32, + pub offset: i64, + pub added_at: DateTime, + pub received_at: DateTime, + pub processing_attempts: i32, + pub expires_at: Option>, + pub delay_until: Option>, + pub processing_deadline_duration: i32, + pub processing_deadline: Option>, + pub status: String, + pub at_most_once: bool, + pub application: String, + pub namespace: String, + pub taskname: String, #[sqlx(try_from = "i32")] - on_attempts_exceeded: OnAttemptsExceeded, + pub on_attempts_exceeded: OnAttemptsExceeded, } impl TryFrom for TableRow { @@ -188,7 +228,7 @@ impl TryFrom for TableRow { delay_until: value.delay_until, processing_deadline_duration: value.processing_deadline_duration, processing_deadline: value.processing_deadline, - status: value.status, + status: value.status.to_string(), at_most_once: value.at_most_once, application: value.application, namespace: value.namespace, @@ -203,7 +243,7 @@ impl From for InflightActivation { Self { id: value.id, activation: value.activation, - status: value.status, + status: InflightActivationStatus::from_str(&value.status).unwrap(), partition: value.partition, offset: value.offset, added_at: value.added_at, @@ -360,6 +400,9 @@ pub trait InflightActivationStore: Send + Sync { /// Remove killswitched tasks async fn remove_killswitched(&self, killswitched_tasks: Vec) -> Result; + + /// Remove the database, used only in tests + async fn remove_db(&self) -> Result<(), Error>; } pub struct SqliteActivationStore { @@ -656,6 +699,7 @@ impl InflightActivationStore for SqliteActivationStore { .into_iter() .map(TableRow::try_from) .collect::, _>>()?; + let query = query_builder .push_values(rows, |mut b, row| { b.push_bind(row.id); @@ -1180,4 +1224,9 @@ impl InflightActivationStore for SqliteActivationStore { Ok(query.rows_affected()) } + + // Used in tests + async fn remove_db(&self) -> Result<(), Error> { + Ok(()) + } } diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs index 83d71c65..657bb3e2 100644 --- a/src/store/inflight_activation_tests.rs +++ b/src/store/inflight_activation_tests.rs @@ -1,5 +1,8 @@ use prost::Message; +use rstest::rstest; +use sqlx::{QueryBuilder, Sqlite}; use std::collections::{HashMap, HashSet}; +use std::fs; use std::io::Error; use std::path::Path; use std::sync::Arc; @@ -19,8 +22,6 @@ use chrono::{DateTime, SubsecRound, TimeZone, Utc}; use sentry_protos::taskbroker::v1::{ OnAttemptsExceeded, RetryState, TaskActivation, TaskActivationStatus, }; -use sqlx::{QueryBuilder, Sqlite}; -use std::fs; use 
tokio::sync::broadcast; use tokio::task::JoinSet; @@ -64,7 +65,7 @@ fn test_inflightactivation_status_from() { } #[tokio::test] -async fn test_create_db() { +async fn test_sqlite_create_db() { assert!( SqliteActivationStore::new( &generate_temp_filename(), @@ -76,34 +77,50 @@ async fn test_create_db() { } #[tokio::test] -async fn test_store() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.count().await; assert_eq!(result.unwrap(), 2); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_store_duplicate_id_in_batch() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store_duplicate_id_in_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); // Coerce a conflict batch[0].id = "id_0".into(); batch[1].id = "id_0".into(); - assert!(store.store(batch).await.is_ok()); + let first_result = store.store(batch).await; + assert!( + first_result.is_ok(), + "{}", + first_result.err().unwrap().to_string() + ); let result = store.count().await; assert_eq!(result.unwrap(), 1); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_store_duplicate_id_between_batches() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store_duplicate_id_between_batches(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -118,11 +135,15 @@ async fn test_store_duplicate_id_between_batches() { let second_count = store.count().await; assert_eq!(second_count.unwrap(), 2); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -149,11 +170,15 @@ async fn test_get_pending_activation() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test(flavor = "multi_thread", worker_threads = 32)] -async fn test_get_pending_activation_with_race() { - let store = Arc::new(create_test_store().await); +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_race(#[case] adapter: &str) { + let store = Arc::new(create_test_store(adapter).await); let namespace = generate_unique_namespace(); const NUM_CONCURRENT_WRITES: u32 = 2000; @@ -192,11 +217,15 @@ async fn test_get_pending_activation_with_race() { .collect(); assert_eq!(res.len(), NUM_CONCURRENT_WRITES as usize); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_namespace() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_namespace(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].namespace = "other_namespace".into(); @@ -212,11 +241,15 @@ async fn 
test_get_pending_activation_with_namespace() { assert_eq!(result.status, InflightActivationStatus::Processing); assert!(result.processing_deadline.unwrap() > Utc::now()); assert_eq!(result.namespace, "other_namespace"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_from_multiple_namespaces() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_from_multiple_namespaces(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(4); batch[0].namespace = "ns1".into(); @@ -233,17 +266,21 @@ async fn test_get_pending_activation_from_multiple_namespaces() { .unwrap(); assert_eq!(result.len(), 2); - assert_eq!(result[0].id, "id_1"); - assert_eq!(result[0].namespace, "ns2"); - assert_eq!(result[0].status, InflightActivationStatus::Processing); assert_eq!(result[1].id, "id_2"); assert_eq!(result[1].namespace, "ns3"); assert_eq!(result[1].status, InflightActivationStatus::Processing); + assert_eq!(result[0].id, "id_1"); + assert_eq!(result[0].namespace, "ns2"); + assert_eq!(result[0].status, InflightActivationStatus::Processing); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_namespace_requires_application() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].namespace = "other_namespace".into(); @@ -268,11 +305,24 @@ async fn test_get_pending_activation_with_namespace_requires_application() { activations.len(), "should find 1 activation with a matching namespace" ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_skip_expires() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_skip_expires(#[case] adapter: &str) { + let store = create_test_store(adapter).await; + + assert_counts( + StatusCount { + pending: 0, + ..StatusCount::default() + }, + store.as_ref(), + ) + .await; let mut batch = make_activations(1); batch[0].expires_at = Some(Utc::now() - Duration::from_secs(100)); @@ -291,16 +341,21 @@ async fn test_get_pending_activation_skip_expires() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_earliest() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_earliest(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[0].added_at = Utc.with_ymd_and_hms(2024, 6, 24, 0, 0, 0).unwrap(); batch[1].added_at = Utc.with_ymd_and_hms(1998, 6, 24, 0, 0, 0).unwrap(); - assert!(store.store(batch.clone()).await.is_ok()); + let ret = store.store(batch.clone()).await; + assert!(ret.is_ok(), "{}", ret.err().unwrap().to_string()); let result = store .get_pending_activation(None, None) @@ -311,11 +366,15 @@ async fn test_get_pending_activation_earliest() { result.added_at, Utc.with_ymd_and_hms(1998, 6, 24, 0, 0, 0).unwrap() ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_fetches_application() { - let store = 
create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_fetches_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(1); batch[0].application = "hammers".into(); @@ -332,11 +391,15 @@ async fn test_get_pending_activation_fetches_application() { assert_eq!(result.status, InflightActivationStatus::Processing); assert!(result.processing_deadline.unwrap() > Utc::now()); assert_eq!(result.application, "hammers"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_application() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].application = "hammers".into(); @@ -364,11 +427,15 @@ async fn test_get_pending_activation_with_application() { let result_opt = store.get_pending_activation(None, None).await.unwrap(); assert!(result_opt.is_some(), "one pending activation in '' left"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_application_and_namespace() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_application_and_namespace(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].namespace = "target".into(); @@ -400,11 +467,15 @@ async fn test_get_pending_activation_with_application_and_namespace() { assert_eq!(result.id, "id_2"); assert_eq!(result.application, "hammers"); assert_eq!(result.namespace, "not-target"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_count_pending_activations() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_count_pending_activations(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].status = InflightActivationStatus::Processing; @@ -420,11 +491,15 @@ async fn test_count_pending_activations() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn set_activation_status() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_set_activation_status(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); @@ -514,37 +589,37 @@ async fn set_activation_status() { let inflight = result_opt.unwrap(); assert_eq!(inflight.id, "id_0"); assert_eq!(inflight.status, InflightActivationStatus::Complete); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_set_processing_deadline() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_set_processing_deadline(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(1); assert!(store.store(batch.clone()).await.is_ok()); - let deadline = Utc::now(); - assert!( - store - .set_processing_deadline("id_0", Some(deadline)) - .await - .is_ok() - ); + let deadline = Utc::now().round_subsecs(0); + let result = 
store.set_processing_deadline("id_0", Some(deadline)).await; + assert!(result.is_ok(), "query error: {:?}", result.err().unwrap()); let result = store.get_by_id("id_0").await.unwrap().unwrap(); assert_eq!( - result - .processing_deadline - .unwrap() - .round_subsecs(0) - .timestamp(), + result.processing_deadline.unwrap().timestamp(), deadline.timestamp() - ) + ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_delete_activation() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_delete_activation(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); @@ -563,11 +638,15 @@ async fn test_delete_activation() { assert!(store.delete_activation("id_1").await.is_ok()); let result = store.count().await; assert_eq!(result.unwrap(), 0); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_retry_activations() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_retry_activations(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -608,11 +687,15 @@ async fn test_get_retry_activations() { for record in retries.iter() { assert_eq!(record.status, InflightActivationStatus::Retry); } + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -648,11 +731,15 @@ async fn test_handle_processing_deadline() { let count = store.handle_processing_deadline().await; assert!(count.is_ok()); assert_eq!(count.unwrap(), 0); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_multiple_tasks() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_multiple_tasks(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[0].status = InflightActivationStatus::Processing; @@ -681,11 +768,15 @@ async fn test_handle_processing_deadline_multiple_tasks() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_at_most_once() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_at_most_once(#[case] adapter: &str) { + let store = create_test_store(adapter).await; // Both records are past processing deadlines let mut batch = make_activations(2); @@ -731,11 +822,15 @@ async fn test_handle_processing_at_most_once() { let task = store.get_by_id(&batch[1].id).await.unwrap().unwrap(); assert_eq!(task.status, InflightActivationStatus::Failure); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_discard_after() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_discard_after(#[case] adapter: &str) { + let 
store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -773,11 +868,15 @@ async fn test_handle_processing_deadline_discard_after() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_deadletter_after() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_deadletter_after(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -815,11 +914,15 @@ async fn test_handle_processing_deadline_deadletter_after() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_no_retries_remaining() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_no_retries_remaining(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -857,12 +960,16 @@ async fn test_handle_processing_deadline_no_retries_remaining() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_processing_attempts_exceeded() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_processing_attempts_exceeded(#[case] adapter: &str) { let config = create_integration_config(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].status = InflightActivationStatus::Pending; @@ -899,11 +1006,15 @@ async fn test_processing_attempts_exceeded() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_completed() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_completed(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(3); records[0].status = InflightActivationStatus::Complete; @@ -956,11 +1067,15 @@ async fn test_remove_completed() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_completed_multiple_gaps() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_completed_multiple_gaps(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(4); // only record 1 can be removed @@ -1027,11 +1142,15 @@ async fn test_remove_completed_multiple_gaps() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_failed_tasks() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_failed_tasks(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(4); // deadletter @@ -1113,11 +1232,15 @@ async fn test_handle_failed_tasks() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_mark_completed() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] 
+#[case::postgres("postgres")] +async fn test_mark_completed(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let records = make_activations(3); assert!(store.store(records.clone()).await.is_ok()); @@ -1145,11 +1268,15 @@ async fn test_mark_completed() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_expires_at() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_expires_at(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); // All expired tasks should be removed, regardless of order or other tasks. @@ -1168,7 +1295,11 @@ async fn test_handle_expires_at() { .await; let result = store.handle_expires_at().await; - assert!(result.is_ok()); + assert!( + result.is_ok(), + "handle_expires_at should be ok {:?}", + result + ); assert_eq!(result.unwrap(), 2); assert_counts( StatusCount { @@ -1178,11 +1309,15 @@ async fn test_handle_expires_at() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_killswitched() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_killswitched(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(6); batch[0].taskname = "task_to_be_killswitched_one".to_string(); @@ -1216,11 +1351,15 @@ async fn test_remove_killswitched() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_clear() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_clear(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let namespace = generate_unique_namespace(); #[allow(deprecated)] @@ -1275,28 +1414,37 @@ async fn test_clear() { assert!(store.clear().await.is_ok()); assert_eq!(store.count().await.unwrap(), 0); assert_counts(StatusCount::default(), store.as_ref()).await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_full_vacuum() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_full_vacuum(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.full_vacuum_db().await; assert!(result.is_ok()); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_vacuum_db_no_limit() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_vacuum_db_no_limit(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.vacuum_db().await; assert!(result.is_ok()); + store.remove_db().await.unwrap(); } #[tokio::test] @@ -1320,8 +1468,11 @@ async fn test_vacuum_db_incremental() { } #[tokio::test] -async fn test_db_size() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_db_size(#[case] adapter: &str) { + let store = create_test_store(adapter).await; assert!(store.db_size().await.is_ok()); let first_size = store.db_size().await.unwrap(); @@ -1333,12 +1484,16 @@ async fn test_db_size() { let second_size = 
store.db_size().await.unwrap(); assert!(second_size > first_size, "should have more bytes now"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_no_pending() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_no_pending(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; // No activations, max lag is 0 assert_eq!(0.0, store.pending_activation_max_lag(&now).await); @@ -1348,12 +1503,16 @@ async fn test_pending_activation_max_lag_no_pending() { // No pending activations, max lag is 0 assert_eq!(0.0, store.pending_activation_max_lag(&now).await); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_use_oldest() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_use_oldest(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut pending = make_activations(2); pending[0].received_at = now - Duration::from_secs(10); @@ -1363,12 +1522,16 @@ async fn test_pending_activation_max_lag_use_oldest() { let result = store.pending_activation_max_lag(&now).await; assert!(11.0 < result, "Should not get the small record"); assert!(result < 501.0, "Should not get an inflated value"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_ignore_processing_attempts() { - let now = Utc::now(); - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_ignore_processing_attempts(#[case] adapter: &str) { + let now = Utc::now().round_subsecs(0); + let store = create_test_store(adapter).await; let mut pending = make_activations(2); pending[0].received_at = now - Duration::from_secs(10); @@ -1377,14 +1540,17 @@ async fn test_pending_activation_max_lag_ignore_processing_attempts() { assert!(store.store(pending).await.is_ok()); let result = store.pending_activation_max_lag(&now).await; - assert!(10.00 < result); - assert!(result < 11.00); + assert_eq!(result, 10.0, "max lag: {result:?}"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_account_for_delayed() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_account_for_delayed(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut pending = make_activations(2); // delayed tasks are received well before they become pending @@ -1395,7 +1561,8 @@ async fn test_pending_activation_max_lag_account_for_delayed() { let result = store.pending_activation_max_lag(&now).await; assert!(22.00 < result, "result: {result}"); - assert!(result < 23.00, "result: {result}"); + assert!(result < 24.00, "result: {result}"); + store.remove_db().await.unwrap(); } #[tokio::test] diff --git a/src/store/mod.rs b/src/store/mod.rs index dcc0f255..deb05655 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -1,3 +1,4 @@ pub mod inflight_activation; #[cfg(test)] pub mod inflight_activation_tests; +pub mod postgres_activation_store; diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs new file mode 100644 index 00000000..363bbda1 --- /dev/null +++ 
b/src/store/postgres_activation_store.rs @@ -0,0 +1,732 @@ +use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivation, InflightActivationStatus, InflightActivationStore, + QueryResult, TableRow, +}; +use anyhow::{Error, anyhow}; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use sentry_protos::taskbroker::v1::OnAttemptsExceeded; +use sqlx::{ + Pool, Postgres, QueryBuilder, Row, + pool::PoolConnection, + postgres::{PgConnectOptions, PgPool, PgPoolOptions, PgRow}, +}; +use std::{str::FromStr, time::Instant}; +use tracing::{instrument, warn}; + +use crate::config::Config; + +pub async fn create_postgres_pool( + url: &str, + database_name: &str, +) -> Result<(Pool, Pool), Error> { + let conn_str = url.to_owned() + "/" + database_name; + let read_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + + let write_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + Ok((read_pool, write_pool)) +} + +pub async fn create_default_postgres_pool(url: &str) -> Result, Error> { + let conn_str = url.to_owned() + "/postgres"; + let read_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + Ok(read_pool) +} + +pub struct PostgresActivationStoreConfig { + pub pg_url: String, + pub pg_database_name: String, + pub max_processing_attempts: usize, + pub processing_deadline_grace_sec: u64, + pub vacuum_page_count: Option, + pub enable_sqlite_status_metrics: bool, +} + +impl PostgresActivationStoreConfig { + pub fn from_config(config: &Config) -> Self { + Self { + pg_url: config.pg_url.clone(), + pg_database_name: config.pg_database_name.clone(), + max_processing_attempts: config.max_processing_attempts, + vacuum_page_count: config.vacuum_page_count, + processing_deadline_grace_sec: config.processing_deadline_grace_sec, + enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + } + } +} + +pub struct PostgresActivationStore { + read_pool: PgPool, + write_pool: PgPool, + config: PostgresActivationStoreConfig, +} + +impl PostgresActivationStore { + async fn acquire_write_conn_metric( + &self, + caller: &'static str, + ) -> Result, Error> { + let start = Instant::now(); + let conn = self.write_pool.acquire().await?; + metrics::histogram!("postgres.write.acquire_conn", "fn" => caller).record(start.elapsed()); + Ok(conn) + } + + pub async fn new(config: PostgresActivationStoreConfig) -> Result { + let default_pool = create_default_postgres_pool(&config.pg_url).await?; + + // Create the database if it doesn't exist + let row: (bool,) = sqlx::query_as( + "SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )", + ) + .bind(&config.pg_database_name) + .fetch_one(&default_pool) + .await?; + + if !row.0 { + println!("Creating database {}", &config.pg_database_name); + sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) + .bind(&config.pg_database_name) + .execute(&default_pool) + .await?; + } + // Close the default pool + default_pool.close().await; + + let (read_pool, write_pool) = + create_postgres_pool(&config.pg_url, &config.pg_database_name).await?; + sqlx::migrate!("./pg_migrations").run(&write_pool).await?; + + Ok(Self { + read_pool, + write_pool, + config, + }) + } +} + +#[async_trait] +impl InflightActivationStore for PostgresActivationStore { + /// Trigger incremental vacuum to reclaim free pages in the database. 
+ /// Depending on config data, will either vacuum a set number of + /// pages or attempt to reclaim all free pages. + #[instrument(skip_all)] + async fn vacuum_db(&self) -> Result<(), Error> { + // TODO: Remove + Ok(()) + } + + /// Perform a full vacuum on the database. + async fn full_vacuum_db(&self) -> Result<(), Error> { + // TODO: Remove + Ok(()) + } + + /// Get the size of the database in bytes based on SQLite metadata queries. + async fn db_size(&self) -> Result { + let row_result: (i64,) = sqlx::query_as("SELECT pg_database_size($1) as size") + .bind(&self.config.pg_database_name) + .fetch_one(&self.read_pool) + .await?; + if row_result.0 < 0 { + return Ok(0); + } + Ok(row_result.0 as u64) + } + + /// Get an activation by id. Primarily used for testing + async fn get_by_id(&self, id: &str) -> Result, Error> { + let row_result: Option = sqlx::query_as( + " + SELECT id, + activation, + partition, + kafka_offset AS offset, + added_at, + received_at, + processing_attempts, + expires_at, + delay_until, + processing_deadline_duration, + processing_deadline, + status, + at_most_once, + application, + namespace, + taskname, + on_attempts_exceeded + FROM inflight_taskactivations + WHERE id = $1 + ", + ) + .bind(id) + .fetch_optional(&self.read_pool) + .await?; + + let Some(row) = row_result else { + return Ok(None); + }; + + Ok(Some(row.into())) + } + + #[instrument(skip_all)] + async fn store(&self, batch: Vec) -> Result { + if batch.is_empty() { + return Ok(QueryResult { rows_affected: 0 }); + } + let mut query_builder = QueryBuilder::::new( + " + INSERT INTO inflight_taskactivations + ( + id, + activation, + partition, + kafka_offset, + added_at, + received_at, + processing_attempts, + expires_at, + delay_until, + processing_deadline_duration, + processing_deadline, + status, + at_most_once, + application, + namespace, + taskname, + on_attempts_exceeded + ) + ", + ); + let rows = batch + .into_iter() + .map(TableRow::try_from) + .collect::, _>>()?; + let query = query_builder + .push_values(rows, |mut b, row| { + b.push_bind(row.id); + b.push_bind(row.activation); + b.push_bind(row.partition); + b.push_bind(row.offset); + b.push_bind(row.added_at); + b.push_bind(row.received_at); + b.push_bind(row.processing_attempts); + b.push_bind(Some(row.expires_at)); + b.push_bind(Some(row.delay_until)); + b.push_bind(row.processing_deadline_duration); + if let Some(deadline) = row.processing_deadline { + b.push_bind(deadline); + } else { + // Add a literal null + b.push("null"); + } + b.push_bind(row.status); + b.push_bind(row.at_most_once); + b.push_bind(row.application); + b.push_bind(row.namespace); + b.push_bind(row.taskname); + b.push_bind(row.on_attempts_exceeded as i32); + }) + .push(" ON CONFLICT(id) DO NOTHING") + .build(); + let mut conn = self.acquire_write_conn_metric("store").await?; + Ok(query.execute(&mut *conn).await?.into()) + } + + #[instrument(skip_all)] + async fn get_pending_activation( + &self, + application: Option<&str>, + namespace: Option<&str>, + ) -> Result, Error> { + // Convert single namespace to vector for internal use + let namespaces = namespace.map(|ns| vec![ns.to_string()]); + + // If a namespace filter is used, an application must also be used. + if namespaces.is_some() && application.is_none() { + warn!( + "Received request for namespaced task without application. 
namespaces = {namespaces:?}" + ); + return Ok(None); + } + let result = self + .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1)) + .await?; + if result.is_empty() { + return Ok(None); + } + Ok(Some(result[0].clone())) + } + + /// Get a pending activation from specified namespaces + /// If namespaces is None, gets from any namespace + /// If namespaces is Some(&[...]), gets from those namespaces + #[instrument(skip_all)] + async fn get_pending_activations_from_namespaces( + &self, + application: Option<&str>, + namespaces: Option<&[String]>, + limit: Option, + ) -> Result, Error> { + let now = Utc::now(); + + let grace_period = self.config.processing_deadline_grace_sec; + let mut query_builder = QueryBuilder::new( + "WITH selected_activations AS ( + SELECT id + FROM inflight_taskactivations + WHERE status = ", + ); + query_builder.push_bind(InflightActivationStatus::Pending.to_string()); + query_builder.push(" AND (expires_at IS NULL OR expires_at > "); + query_builder.push_bind(now); + query_builder.push(")"); + + // Handle application & namespace filtering + if let Some(value) = application { + query_builder.push(" AND application ="); + query_builder.push_bind(value); + } + if let Some(namespaces) = namespaces + && !namespaces.is_empty() + { + query_builder.push(" AND namespace IN ("); + let mut separated = query_builder.separated(", "); + for namespace in namespaces.iter() { + separated.push_bind(namespace); + } + query_builder.push(")"); + } + query_builder.push(" ORDER BY added_at"); + if let Some(limit) = limit { + query_builder.push(" LIMIT "); + query_builder.push_bind(limit); + } + query_builder.push(" FOR UPDATE SKIP LOCKED)"); + query_builder.push(format!( + "UPDATE inflight_taskactivations + SET + processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), + status = " + )); + query_builder.push_bind(InflightActivationStatus::Processing.to_string()); + query_builder.push(" FROM selected_activations "); + query_builder.push(" WHERE inflight_taskactivations.id = selected_activations.id"); + query_builder.push(" RETURNING *, kafka_offset AS offset"); + + let mut conn = self + .acquire_write_conn_metric("get_pending_activation") + .await?; + let rows: Vec = query_builder + .build_query_as::() + .fetch_all(&mut *conn) + .await?; + + Ok(rows.into_iter().map(|row| row.into()).collect()) + } + + /// Get the age of the oldest pending activation in seconds. + /// Only activations with status=pending and processing_attempts=0 are considered + /// as we are interested in latency to the *first* attempt. + /// Tasks with delay_until set, will have their age adjusted based on their + /// delay time. No tasks = 0 lag + async fn pending_activation_max_lag(&self, now: &DateTime) -> f64 { + let result = sqlx::query( + "SELECT received_at, delay_until + FROM inflight_taskactivations + WHERE status = $1 + AND processing_attempts = 0 + ORDER BY received_at ASC + LIMIT 1 + ", + ) + .bind(InflightActivationStatus::Pending.to_string()) + .fetch_one(&self.read_pool) + .await; + if let Ok(row) = result { + let received_at: DateTime = row.get("received_at"); + let delay_until: Option> = row.get("delay_until"); + let millis = now.signed_duration_since(received_at).num_milliseconds() + - delay_until.map_or(0, |delay_time| { + delay_time + .signed_duration_since(received_at) + .num_milliseconds() + }); + millis as f64 / 1000.0 + } else { + // If we couldn't find a row, there is no latency. 
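+            // fetch_one returns Err(RowNotFound) when there are no pending
+            // rows; any other error (for example a lost connection) also
+            // lands in this branch and is reported as zero lag.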
+ 0.0 + } + } + + #[instrument(skip_all)] + async fn count_pending_activations(&self) -> Result { + self.count_by_status(InflightActivationStatus::Pending) + .await + } + + #[instrument(skip_all)] + async fn count_by_status(&self, status: InflightActivationStatus) -> Result { + let result = + sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations WHERE status = $1") + .bind(status.to_string()) + .fetch_one(&self.read_pool) + .await?; + Ok(result.get::("count") as usize) + } + + async fn count(&self) -> Result { + let result = sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations") + .fetch_one(&self.read_pool) + .await?; + Ok(result.get::("count") as usize) + } + + /// Update the status of a specific activation + #[instrument(skip_all)] + async fn set_status( + &self, + id: &str, + status: InflightActivationStatus, + ) -> Result, Error> { + let mut conn = self.acquire_write_conn_metric("set_status").await?; + let result: Option = sqlx::query_as( + "UPDATE inflight_taskactivations SET status = $1 WHERE id = $2 RETURNING *, kafka_offset AS offset", + ) + .bind(status.to_string()) + .bind(id) + .fetch_optional(&mut *conn) + .await?; + println!("result: {:?}", result); + let Some(row) = result else { + return Ok(None); + }; + + Ok(Some(row.into())) + } + + #[instrument(skip_all)] + async fn set_processing_deadline( + &self, + id: &str, + deadline: Option>, + ) -> Result<(), Error> { + let mut conn = self + .acquire_write_conn_metric("set_processing_deadline") + .await?; + sqlx::query("UPDATE inflight_taskactivations SET processing_deadline = $1 WHERE id = $2") + .bind(deadline.unwrap()) + .bind(id) + .execute(&mut *conn) + .await?; + Ok(()) + } + + #[instrument(skip_all)] + async fn delete_activation(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("delete_activation").await?; + sqlx::query("DELETE FROM inflight_taskactivations WHERE id = $1") + .bind(id) + .execute(&mut *conn) + .await?; + Ok(()) + } + + #[instrument(skip_all)] + async fn get_retry_activations(&self) -> Result, Error> { + Ok(sqlx::query_as( + " + SELECT id, + activation, + partition, + kafka_offset AS offset, + added_at, + received_at, + processing_attempts, + expires_at, + delay_until, + processing_deadline_duration, + processing_deadline, + status, + at_most_once, + application, + namespace, + taskname, + on_attempts_exceeded + FROM inflight_taskactivations + WHERE status = $1 + ", + ) + .bind(InflightActivationStatus::Retry.to_string()) + .fetch_all(&self.read_pool) + .await? + .into_iter() + .map(|row: TableRow| row.into()) + .collect()) + } + + // Used in tests + async fn clear(&self) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("clear").await?; + sqlx::query("TRUNCATE TABLE inflight_taskactivations") + .execute(&mut *conn) + .await?; + + Ok(()) + } + + /// Update tasks that are in processing and have exceeded their processing deadline + /// Exceeding a processing deadline does not consume a retry as we don't know + /// if a worker took the task and was killed, or failed. + #[instrument(skip_all)] + async fn handle_processing_deadline(&self) -> Result { + let now = Utc::now(); + let mut atomic = self.write_pool.begin().await?; + + // Idempotent tasks that fail their processing deadlines go directly to failure + // there are no retries, as the worker will reject the task due to idempotency keys. 
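+        // This at-most-once update must run before the generic reset below:
+        // both statements match status = 'processing', so reversing the order
+        // would move at-most-once tasks back to pending and retry them.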
+ let most_once_result = sqlx::query( + "UPDATE inflight_taskactivations + SET processing_deadline = null, status = $1 + WHERE processing_deadline < $2 AND at_most_once = TRUE AND status = $3", + ) + .bind(InflightActivationStatus::Failure.to_string()) + .bind(now) + .bind(InflightActivationStatus::Processing.to_string()) + .execute(&mut *atomic) + .await; + + let mut processing_deadline_modified_rows = 0; + if let Ok(query_res) = most_once_result { + processing_deadline_modified_rows = query_res.rows_affected(); + } + + // Update non-idempotent tasks. + // Increment processing_attempts by 1 and reset processing_deadline to null. + let result = sqlx::query( + "UPDATE inflight_taskactivations + SET processing_deadline = null, status = $1, processing_attempts = processing_attempts + 1 + WHERE processing_deadline < $2 AND status = $3", + ) + .bind(InflightActivationStatus::Pending.to_string()) + .bind(now) + .bind(InflightActivationStatus::Processing.to_string()) + .execute(&mut *atomic) + .await; + + atomic.commit().await?; + + if let Ok(query_res) = result { + processing_deadline_modified_rows += query_res.rows_affected(); + return Ok(processing_deadline_modified_rows); + } + + Err(anyhow!("Could not update tasks past processing_deadline")) + } + + /// Update tasks that have exceeded their max processing attempts. + /// These tasks are set to status=failure and will be handled by handle_failed_tasks accordingly. + #[instrument(skip_all)] + async fn handle_processing_attempts(&self) -> Result { + let mut conn = self + .acquire_write_conn_metric("handle_processing_attempts") + .await?; + let processing_attempts_result = sqlx::query( + "UPDATE inflight_taskactivations + SET status = $1 + WHERE processing_attempts >= $2 AND status = $3", + ) + .bind(InflightActivationStatus::Failure.to_string()) + .bind(self.config.max_processing_attempts as i32) + .bind(InflightActivationStatus::Pending.to_string()) + .execute(&mut *conn) + .await; + + if let Ok(query_res) = processing_attempts_result { + return Ok(query_res.rows_affected()); + } + + Err(anyhow!("Could not update tasks past processing_deadline")) + } + + /// Perform upkeep work for tasks that are past expires_at deadlines + /// + /// Tasks that are pending and past their expires_at deadline are updated + /// to have status=failure so that they can be discarded/deadlettered by handle_failed_tasks + /// + /// The number of impacted records is returned in a Result. + #[instrument(skip_all)] + async fn handle_expires_at(&self) -> Result { + let now = Utc::now(); + let mut conn = self.acquire_write_conn_metric("handle_expires_at").await?; + let query = sqlx::query( + "DELETE FROM inflight_taskactivations WHERE status = $1 AND expires_at IS NOT NULL AND expires_at < $2", + ) + .bind(InflightActivationStatus::Pending.to_string()) + .bind(now) + .execute(&mut *conn) + .await?; + + Ok(query.rows_affected()) + } + + /// Perform upkeep work for tasks that are past delay_until deadlines + /// + /// Tasks that are delayed and past their delay_until deadline are updated + /// to have status=pending so that they can be executed by workers + /// + /// The number of impacted records is returned in a Result. 
+ #[instrument(skip_all)] + async fn handle_delay_until(&self) -> Result { + let now = Utc::now(); + let mut conn = self.acquire_write_conn_metric("handle_delay_until").await?; + let update_result = sqlx::query( + r#"UPDATE inflight_taskactivations + SET status = $1 + WHERE delay_until IS NOT NULL AND delay_until < $2 AND status = $3 + "#, + ) + .bind(InflightActivationStatus::Pending.to_string()) + .bind(now) + .bind(InflightActivationStatus::Delay.to_string()) + .execute(&mut *conn) + .await?; + + Ok(update_result.rows_affected()) + } + + /// Perform upkeep work related to status=failure + /// + /// Activations that are status=failure need to either be discarded by setting status=complete + /// or need to be moved to deadletter and are returned in the Result. + /// Once dead-lettered tasks have been added to Kafka those tasks can have their status set to + /// complete. + #[instrument(skip_all)] + async fn handle_failed_tasks(&self) -> Result { + let mut atomic = self.write_pool.begin().await?; + + let failed_tasks: Vec = + sqlx::query("SELECT id, activation, on_attempts_exceeded FROM inflight_taskactivations WHERE status = $1") + .bind(InflightActivationStatus::Failure.to_string()) + .fetch_all(&mut *atomic) + .await? + .into_iter() + .collect(); + + let mut forwarder = FailedTasksForwarder { + to_discard: vec![], + to_deadletter: vec![], + }; + + for record in failed_tasks.iter() { + let activation_data: &[u8] = record.get("activation"); + let id: String = record.get("id"); + // We could be deadlettering because of activation.expires + // when a task expires we still deadletter if configured. + let on_attempts_exceeded_val: i32 = record.get("on_attempts_exceeded"); + let on_attempts_exceeded: OnAttemptsExceeded = + on_attempts_exceeded_val.try_into().unwrap(); + if on_attempts_exceeded == OnAttemptsExceeded::Discard + || on_attempts_exceeded == OnAttemptsExceeded::Unspecified + { + forwarder.to_discard.push((id, activation_data.to_vec())) + } else if on_attempts_exceeded == OnAttemptsExceeded::Deadletter { + forwarder.to_deadletter.push((id, activation_data.to_vec())) + } + } + + if !forwarder.to_discard.is_empty() { + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations "); + query_builder + .push("SET status = ") + .push_bind(InflightActivationStatus::Complete.to_string()) + .push(" WHERE id IN ("); + + let mut separated = query_builder.separated(", "); + for (id, _body) in forwarder.to_discard.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + + query_builder.build().execute(&mut *atomic).await?; + } + + atomic.commit().await?; + + Ok(forwarder) + } + + /// Mark a collection of tasks as complete by id + #[instrument(skip_all)] + async fn mark_completed(&self, ids: Vec) -> Result { + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations "); + query_builder + .push("SET status = ") + .push_bind(InflightActivationStatus::Complete.to_string()) + .push(" WHERE id IN ("); + + let mut separated = query_builder.separated(", "); + for id in ids.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + let mut conn = self.acquire_write_conn_metric("mark_completed").await?; + let result = query_builder.build().execute(&mut *conn).await?; + + Ok(result.rows_affected()) + } + + /// Remove completed tasks. + /// This method is a garbage collector for the inflight task store. 
+ #[instrument(skip_all)] + async fn remove_completed(&self) -> Result { + let mut conn = self.acquire_write_conn_metric("remove_completed").await?; + let query = sqlx::query("DELETE FROM inflight_taskactivations WHERE status = $1") + .bind(InflightActivationStatus::Complete.to_string()) + .execute(&mut *conn) + .await?; + + Ok(query.rows_affected()) + } + + /// Remove killswitched tasks. + #[instrument(skip_all)] + async fn remove_killswitched(&self, killswitched_tasks: Vec) -> Result { + let mut query_builder = + QueryBuilder::new("DELETE FROM inflight_taskactivations WHERE taskname IN ("); + let mut separated = query_builder.separated(", "); + for taskname in killswitched_tasks.iter() { + separated.push_bind(taskname); + } + separated.push_unseparated(")"); + let mut conn = self + .acquire_write_conn_metric("remove_killswitched") + .await?; + let query = query_builder.build().execute(&mut *conn).await?; + + Ok(query.rows_affected()) + } + + // Used in tests + async fn remove_db(&self) -> Result<(), Error> { + self.read_pool.close().await; + self.write_pool.close().await; + let default_pool = create_default_postgres_pool(&self.config.pg_url).await?; + let _ = sqlx::query(format!("DROP DATABASE {}", &self.config.pg_database_name).as_str()) + .bind(&self.config.pg_database_name) + .execute(&default_pool) + .await; + let _ = default_pool.close().await; + Ok(()) + } +} diff --git a/src/test_utils.rs b/src/test_utils.rs index 9df8ba05..eefc7540 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,13 +1,12 @@ use futures::StreamExt; use prost::Message as ProstMessage; -use rand::Rng; use rdkafka::{ Message, admin::{AdminClient, AdminOptions, NewTopic, TopicReplication}, - consumer::{Consumer, StreamConsumer}, + consumer::{CommitMode, Consumer, StreamConsumer}, producer::FutureProducer, }; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, env::var, sync::Arc}; use uuid::Uuid; use crate::{ @@ -16,14 +15,25 @@ use crate::{ InflightActivation, InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, SqliteActivationStore, }, + store::postgres_activation_store::{PostgresActivationStore, PostgresActivationStoreConfig}, }; use chrono::{Timelike, Utc}; use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivation}; -/// Generate a unique filename for isolated SQLite databases. +pub fn get_pg_url() -> String { + var("TASKBROKER_PG_URL").unwrap_or("postgres://postgres:password@localhost:5432/".to_string()) +} + +pub fn get_pg_database_name() -> String { + let random_name = format!("a{}", Uuid::new_v4().to_string().replace("-", "")); + var("TASKBROKER_PG_DATABASE_NAME").unwrap_or(random_name) +} + pub fn generate_temp_filename() -> String { - let mut rng = rand::thread_rng(); - format!("/var/tmp/{}-{}.sqlite", Utc::now(), rng.r#gen::()) + format!( + "/tmp/taskbroker-test-{}", + Uuid::new_v4().to_string().replace("-", "") + ) } /// Generate a unique alphanumeric string for namespaces (and possibly other purposes). 
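For reference, a minimal parametrized test built on these helpers is expected to look like the sketch below; the test name and assertions are illustrative, while create_test_store, make_activations, and remove_db come from the changes in this patch:

    // Sketch of the test pattern used throughout this patch (illustrative only).
    #[tokio::test]
    #[rstest]
    #[case::sqlite("sqlite")]
    #[case::postgres("postgres")]
    async fn test_store_roundtrip(#[case] adapter: &str) {
        let store = create_test_store(adapter).await;

        let batch = make_activations(1);
        assert!(store.store(batch).await.is_ok());
        assert_eq!(store.count().await.unwrap(), 1);

        // Postgres tests create a uniquely named database per run,
        // so every test must drop it again when it finishes.
        store.remove_db().await.unwrap();
    }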
@@ -90,15 +100,25 @@ pub fn create_config() -> Arc { } /// Create an InflightActivationStore instance -pub async fn create_test_store() -> Arc { - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ) +pub async fn create_test_store(adapter: &str) -> Arc { + match adapter { + "sqlite" => Arc::new( + SqliteActivationStore::new( + &generate_temp_filename(), + InflightActivationStoreConfig::from_config(&create_integration_config()), + ) + .await + .unwrap(), + ) as Arc, + "postgres" => Arc::new( + PostgresActivationStore::new(PostgresActivationStoreConfig::from_config( + &create_integration_config(), + )) + .await + .unwrap(), + ) as Arc, + _ => panic!("Invalid adapter: {}", adapter), + } } /// Create a Config instance that uses a testing topic @@ -106,6 +126,8 @@ pub async fn create_test_store() -> Arc { /// with [`reset_topic`] pub fn create_integration_config() -> Arc { let config = Config { + pg_url: get_pg_url(), + pg_database_name: get_pg_database_name(), kafka_topic: "taskbroker-test".into(), kafka_auto_offset_reset: "earliest".into(), ..Config::default() @@ -114,6 +136,18 @@ pub fn create_integration_config() -> Arc { Arc::new(config) } +pub fn create_integration_config_with_topic(topic: String) -> Arc { + let config = Config { + pg_url: get_pg_url(), + pg_database_name: get_pg_database_name(), + kafka_topic: topic, + kafka_auto_offset_reset: "earliest".into(), + ..Config::default() + }; + + Arc::new(config) +} + /// Create a kafka producer for a given config pub fn create_producer(config: Arc) -> Arc { let producer: FutureProducer = config @@ -166,6 +200,7 @@ pub async fn consume_topic( let mut stream = consumer.stream(); let mut results: Vec = vec![]; + let mut last_message = None; let start = Utc::now(); loop { let current = Utc::now(); @@ -187,8 +222,12 @@ pub async fn consume_topic( let payload = message.payload().expect("Could not fetch message payload"); let activation = TaskActivation::decode(payload).unwrap(); results.push(activation); + last_message = Some(message); + } + // Commit the last message's offset so subsequent calls start from the next message + if let Some(msg) = last_message { + consumer.commit_message(&msg, CommitMode::Sync).unwrap(); } - results } diff --git a/src/upkeep.rs b/src/upkeep.rs index dbcfeb1e..fc485e13 100644 --- a/src/upkeep.rs +++ b/src/upkeep.rs @@ -510,6 +510,7 @@ mod tests { use chrono::{DateTime, TimeDelta, TimeZone, Utc}; use prost::Message; use prost_types::Timestamp; + use rstest::rstest; use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivation}; use std::sync::Arc; use std::time::Duration; @@ -520,29 +521,15 @@ mod tests { use crate::{ config::Config, runtime_config::RuntimeConfigManager, - store::inflight_activation::{ - InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, - SqliteActivationStore, - }, + store::inflight_activation::InflightActivationStatus, test_utils::{ StatusCount, assert_counts, consume_topic, create_config, create_integration_config, - create_producer, generate_temp_filename, make_activations, replace_retry_state, - reset_topic, + create_integration_config_with_topic, create_producer, create_test_store, + make_activations, replace_retry_state, reset_topic, }, upkeep::{create_retry_activation, do_upkeep}, }; - async fn create_inflight_store() -> Arc { - let url = generate_temp_filename(); - let config = create_integration_config(); - - Arc::new( 
- SqliteActivationStore::new(&url, InflightActivationStoreConfig::from_config(&config)) - .await - .unwrap(), - ) - } - #[tokio::test] async fn test_retry_activation_sets_delay_with_delay_on_retry() { let inflight = make_activations(1).remove(0); @@ -625,14 +612,17 @@ mod tests { } #[tokio::test] - async fn test_retry_activation_is_appended_to_kafka() { - let config = create_integration_config(); + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_retry_activation_is_appended_to_kafka(#[case] adapter: &str) { + let config = create_integration_config_with_topic(format!("taskbroker-test-{}", adapter)); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); reset_topic(config.clone()).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let mut records = make_activations(2); @@ -706,10 +696,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_retains_future_deadline() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_retains_future_deadline(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -741,10 +734,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_skip_past_deadline_after_startup() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_skip_past_deadline_after_startup(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let mut batch = make_activations(2); @@ -792,10 +788,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_updates_past_deadline() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_updates_past_deadline(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -805,6 +804,7 @@ mod tests { batch[1].status = InflightActivationStatus::Processing; batch[1].processing_deadline = Some(Utc.with_ymd_and_hms(2024, 11, 14, 21, 22, 23).unwrap()); + batch[1].processing_attempts = 0; assert!(store.store(batch.clone()).await.is_ok()); // Should start off with one in processing @@ -815,6 +815,13 @@ mod tests { .unwrap(), 1 ); + assert_eq!( + store + .count_by_status(InflightActivationStatus::Pending) + .await + .unwrap(), + 1 + ); let result_context = do_upkeep( config, @@ -826,6 +833,7 @@ mod tests { ) .await; + println!("result_context: {:?}", result_context); // 0 processing, 2 pending now assert_eq!(result_context.processing_deadline_reset, 1); assert_counts( @@ -840,10 +848,13 @@ mod tests { } #[tokio::test] - async fn 
test_processing_deadline_discard_at_most_once() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_discard_at_most_once(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -890,10 +901,13 @@ mod tests { } #[tokio::test] - async fn test_processing_attempts_exceeded_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_attempts_exceeded_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -941,12 +955,15 @@ mod tests { } #[tokio::test] - async fn test_remove_at_remove_failed_publish_to_kafka() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_at_remove_failed_publish_to_kafka(#[case] adapter: &str) { let config = create_integration_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); reset_topic(config.clone()).await; - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -992,10 +1009,13 @@ mod tests { } #[tokio::test] - async fn test_remove_failed_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_failed_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1033,10 +1053,13 @@ mod tests { } #[tokio::test] - async fn test_expired_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_expired_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1100,10 +1123,13 @@ mod tests { } #[tokio::test] - async fn test_delay_elapsed() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_delay_elapsed(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1195,7 +1221,10 @@ mod tests { } #[tokio::test] - async fn test_forward_demoted_namespaces() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_forward_demoted_namespaces(#[case] adapter: &str) { // 
Create runtime config with demoted namespaces let config = create_config(); let test_yaml = r#" @@ -1209,7 +1238,7 @@ demoted_namespaces: fs::write(test_path, test_yaml).await.unwrap(); let runtime_config = Arc::new(RuntimeConfigManager::new(Some(test_path.to_string())).await); let producer = create_producer(config.clone()); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1246,11 +1275,14 @@ demoted_namespaces: 2, "two tasks should be marked as complete" ); - fs::remove_file(test_path).await.unwrap(); + let _ = fs::remove_file(test_path).await; } #[tokio::test] - async fn test_remove_killswitched() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_killswitched(#[case] adapter: &str) { let config = create_config(); let test_yaml = r#" drop_task_killswitch: @@ -1263,7 +1295,7 @@ demoted_namespaces: let runtime_config = Arc::new(RuntimeConfigManager::new(Some(test_path.to_string())).await); let producer = create_producer(config.clone()); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1294,11 +1326,14 @@ demoted_namespaces: 3 ); - fs::remove_file(test_path).await.unwrap(); + let _ = fs::remove_file(test_path).await; } #[tokio::test] - async fn test_full_vacuum_on_upkeep() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_full_vacuum_on_upkeep(#[case] adapter: &str) { let raw_config = Config { full_vacuum_on_start: true, ..Default::default() @@ -1306,7 +1341,7 @@ demoted_namespaces: let config = Arc::new(raw_config); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now() - Duration::from_secs(60); From 04b13bd36825a32579e2b5f1e608b029f2158662 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Tue, 13 Jan 2026 16:45:17 -0500 Subject: [PATCH 2/2] remove unnecessary Option --- src/store/postgres_activation_store.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index 363bbda1..dd86697b 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -221,8 +221,8 @@ impl InflightActivationStore for PostgresActivationStore { b.push_bind(row.added_at); b.push_bind(row.received_at); b.push_bind(row.processing_attempts); - b.push_bind(Some(row.expires_at)); - b.push_bind(Some(row.delay_until)); + b.push_bind(row.expires_at); + b.push_bind(row.delay_until); b.push_bind(row.processing_deadline_duration); if let Some(deadline) = row.processing_deadline { b.push_bind(deadline);
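Assuming TableRow already stores expires_at and delay_until as Option<DateTime<Utc>> (which the "remove unnecessary Option" wording suggests), the Some() wrappers only produced a redundant Option<Option<_>>; sqlx encodes both forms the same way, so dropping the wrappers simplifies the bind calls without changing what is written to the table.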