From 11d86516891687d9e3a7699fc248e60e2ede24e5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 13:26:44 -0800 Subject: [PATCH 01/10] Add Push Loop, Task Pusher, and Worker Pool --- Cargo.lock | 5 +- Cargo.toml | 4 +- src/config.rs | 8 ++ src/lib.rs | 2 + src/main.rs | 34 ++++++- src/pool.rs | 114 ++++++++++++++++++++++ src/push.rs | 161 +++++++++++++++++++++++++++++++ src/store/inflight_activation.rs | 64 ++++++++++++ src/test_utils.rs | 4 +- 9 files changed, 386 insertions(+), 10 deletions(-) create mode 100644 src/pool.rs create mode 100644 src/push.rs diff --git a/Cargo.lock b/Cargo.lock index cf3f6400..c6d00174 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2431,8 +2431,7 @@ dependencies = [ [[package]] name = "sentry_protos" version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6b0a26106f3a2fae5791618daafbdde92502c09dcbf48006db07c7fa0ba733" +source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-broker-worker#c3015d4807f208c22e448a17ff1b3e42e066ba5c" dependencies = [ "prost", "prost-types", @@ -2890,7 +2889,7 @@ dependencies = [ "metrics-exporter-statsd", "prost", "prost-types", - "rand 0.8.5", + "rand 0.9.2", "rdkafka", "sentry", "sentry_protos", diff --git a/Cargo.toml b/Cargo.toml index 4ca8308f..28e28bb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" prost = "0.13" prost-types = "0.13.3" -rand = "0.8.5" +rand = "0.9.2" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } sentry = { version = "0.41.0", default-features = false, features = [ # default features, except `release-health` is disabled @@ -39,7 +39,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "tracing", "logs" ] } -sentry_protos = "0.4.10" +sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-broker-worker" } serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/src/config.rs b/src/config.rs index d4bbaf2b..d2b94804 100644 --- a/src/config.rs +++ b/src/config.rs @@ -214,6 +214,12 @@ pub struct Config { /// Enable additional metrics for the sqlite. pub enable_sqlite_status_metrics: bool, + + /// Enable push mode. + pub push: bool, + + /// Worker addresses. 
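+    /// gRPC endpoints of taskworkers (for example `http://127.0.0.1:50052`)
+    /// used to seed the push-mode worker pool.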
+ pub workers: Vec, } impl Default for Config { @@ -279,6 +285,8 @@ impl Default for Config { full_vacuum_on_upkeep: true, vacuum_interval_ms: 30000, enable_sqlite_status_metrics: true, + push: false, + workers: vec![], } } } diff --git a/src/lib.rs b/src/lib.rs index 33567944..6798fd05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,8 @@ pub mod grpc; pub mod kafka; pub mod logging; pub mod metrics; +pub mod pool; +pub mod push; pub mod runtime_config; pub mod store; pub mod test_utils; diff --git a/src/main.rs b/src/main.rs index 3998e116..956f7455 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,11 @@ use std::{sync::Arc, time::Duration}; use taskbroker::kafka::inflight_activation_batcher::{ ActivationBatcherConfig, InflightActivationBatcher, }; +use taskbroker::pool::WorkerPool; +use taskbroker::push::TaskPusher; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; +use tokio::sync::RwLock; use tokio::task::JoinHandle; use tokio::{select, time}; use tonic::transport::Server; @@ -57,6 +60,8 @@ async fn main() -> Result<(), Error> { let runtime_config_manager = Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); + let pool = Arc::new(RwLock::new(WorkerPool::new(config.workers.clone()))); + println!("taskbroker starting"); println!("version: {}", get_version().trim()); @@ -177,6 +182,23 @@ async fn main() -> Result<(), Error> { } }); + // Push task loop (conditionally enabled) + let push_task = if config.push { + info!("Running in PUSH mode"); + + let push_task_store = store.clone(); + let push_task_config = config.clone(); + let push_task_pool = pool.clone(); + + Some(tokio::spawn(async move { + let pusher = TaskPusher::new(push_task_store, push_task_config, push_task_pool); + pusher.start().await + })) + } else { + info!("Running in PULL mode"); + None + }; + // GRPC server let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); @@ -225,7 +247,7 @@ async fn main() -> Result<(), Error> { } }); - elegant_departure::tokio::depart() + let mut depart = elegant_departure::tokio::depart() .on_termination() .on_sigint() .on_signal(SignalKind::hangup()) @@ -233,8 +255,14 @@ async fn main() -> Result<(), Error> { .on_completion(log_task_completion("consumer", consumer_task)) .on_completion(log_task_completion("grpc_server", grpc_server_task)) .on_completion(log_task_completion("upkeep_task", upkeep_task)) - .on_completion(log_task_completion("maintenance_task", maintenance_task)) - .await; + .on_completion(log_task_completion("maintenance_task", maintenance_task)); + + // Only register push_task if it was spawned + if let Some(task) = push_task { + depart = depart.on_completion(log_task_completion("push_task", task)); + } + + depart.await; Ok(()) } diff --git a/src/pool.rs b/src/pool.rs new file mode 100644 index 00000000..6ba6b147 --- /dev/null +++ b/src/pool.rs @@ -0,0 +1,114 @@ +use std::collections::HashMap; + +use anyhow::Result; +use rand::seq::IteratorRandom; +use sentry_protos::taskworker::v1::{PushTaskRequest, worker_service_client::WorkerServiceClient}; +use tonic::transport::{Channel, Error}; +use tracing::{info, warn}; + +#[derive(Clone)] +pub struct WorkerPool { + /// Maps every worker address to its client. + clients: HashMap, + + /// All available worker addresses. + addresses: Vec, +} + +#[derive(Clone)] +struct WorkerClient { + /// The actual RPC client connection. + connection: WorkerServiceClient, + + /// The worker address. + address: String, + + /// The worker's ESTIMATED queue size. 
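+    /// Only an estimate: it is refreshed after this broker pushes to the worker,
+    /// and it is what the push path compares when choosing between two sampled workers.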
+ queue_size: u32, +} + +impl WorkerClient { + pub fn new(connection: WorkerServiceClient, address: String, queue_size: u32) -> Self { + Self { + connection, + address, + queue_size, + } + } +} + +impl WorkerPool { + /// Create a new `WorkerPool` instance. + pub fn new(addresses: Vec) -> Self { + Self { + addresses, + clients: HashMap::new(), + } + } + + /// Call this function over and over again in another thread to keep the pool of active connections updated. + pub async fn update(&mut self) { + for address in self.addresses.clone() { + if !self.clients.contains_key(&address) { + match connect(&address).await { + Ok(connection) => { + info!("Connected to {address}"); + + let client = WorkerClient::new(connection, address.clone(), 0); + self.clients.insert(address, client); + } + + Err(e) => { + warn!("Couldn't connect to {address} - {:?}", e); + } + } + } + } + } + + /// Try pushing a task to the best worker using P2C (Power of Two Choices). + pub async fn push(&mut self, request: PushTaskRequest) -> Result<()> { + let candidate = { + let mut rng = rand::rng(); + + self.clients + .values() + .choose_multiple(&mut rng, 2) + .into_iter() + .min_by_key(|client| client.queue_size) + .cloned() + }; + + let Some(mut client) = candidate else { + return Err(anyhow::anyhow!("No connected workers")); + }; + + let address = client.address.clone(); + + match client.connection.push_task(request).await { + Ok(response) => { + let response = response.into_inner(); + + if !response.added { + return Err(anyhow::anyhow!("Selected worker was full")); + } + + // Update this worker's queue size + client.queue_size = 5; + self.clients.insert(address, client); + Ok(()) + } + + Err(e) => { + // Remove this unhealthy worker from the active connection pool + self.clients.remove(&address); + Err(e.into()) + } + } + } +} + +#[inline] +async fn connect>(address: T) -> Result, Error> { + WorkerServiceClient::connect(address.into()).await +} diff --git a/src/push.rs b/src/push.rs new file mode 100644 index 00000000..b6167304 --- /dev/null +++ b/src/push.rs @@ -0,0 +1,161 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use prost::Message; +use sentry_protos::{taskbroker::v1::TaskActivation, taskworker::v1::PushTaskRequest}; +use tokio::{sync::RwLock, time::sleep}; +use tracing::{debug, error, info, warn}; + +use crate::config::Config; +use crate::pool::WorkerPool; +use crate::store::inflight_activation::{InflightActivation, InflightActivationStore}; + +pub struct TaskPusher { + /// Pool of workers through which we will push tasks. + pool: Arc>, + + /// Broker configuration. + config: Arc, + + /// Inflight activation store. + store: Arc, +} + +impl TaskPusher { + /// Create a new `TaskPusher` instance. + pub fn new( + store: Arc, + config: Arc, + pool: Arc>, + ) -> Self { + Self { + store, + config, + pool, + } + } + + /// Start the worker update and push task loops. + pub async fn start(self) -> Result<()> { + info!("Push task loop starting..."); + + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + let pool = self.pool.clone(); + + // Spawn a separate task to update the candidate worker collection + tokio::spawn(async move { + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + let mut beep_interval = tokio::time::interval(Duration::from_secs(1)); + + loop { + tokio::select! { + _ = guard.wait() => { + break; + } + + _ = beep_interval.tick() => { + pool.write().await.update().await; + } + } + } + }); + + loop { + tokio::select! 
{ + _ = guard.wait() => { + info!("Push task loop received shutdown signal"); + break; + } + + _ = async { + self.process_next_task().await; + } => {} + } + } + + info!("Push task loop shutting down..."); + Ok(()) + } +} + +impl TaskPusher { + /// Grab the next pending task from the store. + async fn process_next_task(&self) { + match self.store.peek_pending_activation().await { + Ok(Some(inflight)) => { + if let Err(e) = self.handle_task_push(inflight).await { + warn!("Task push resulted in error - {:?}", e); + } + } + + Ok(None) => { + debug!("No pending tasks, sleeping briefly"); + sleep(milliseconds(100)).await; + } + + Err(e) => { + error!("Failed to fetch pending activation - {:?}", e); + sleep(milliseconds(100)).await; + } + } + } + + /// Decode task activation and push it to a worker. + async fn handle_task_push(&self, inflight: InflightActivation) -> Result<()> { + let task_id = inflight.id.clone(); + + let activation = TaskActivation::decode(&inflight.activation as &[u8]).map_err(|e| { + error!("Failed to decode activation {task_id}: {:?}", e); + e + })?; + + self.push_to_worker(activation, &task_id).await + } + + /// Build an RPC request and send it to the worker pool to be pushed. + async fn push_to_worker(&self, activation: TaskActivation, task_id: &str) -> Result<()> { + let request = PushTaskRequest { + task: Some(activation), + callback_url: format!("{}:{}", self.config.grpc_addr, self.config.grpc_port), + }; + + let result = self.pool.write().await.push(request).await; + + match result { + Ok(()) => { + info!("Pushed task {task_id}"); + self.mark_task_as_processing(task_id).await; + Ok(()) + } + + Err(e) => { + warn!("Could not push task {task_id} - {:?}", e); + sleep(milliseconds(100)).await; + Err(e) + } + } + } + + /// Mark task with id `task_id` as processing if it's still pending. 
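+    /// The store only flips the row to `Processing` if it is still `Pending`,
+    /// so an activation claimed elsewhere between the peek and the push is not double-marked.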
+ async fn mark_task_as_processing(&self, task_id: &str) { + match self.store.mark_as_processing_if_pending(task_id).await { + Ok(true) => { + info!("Task {} pushed and marked as processing", task_id); + } + + Ok(false) => { + warn!("Task {task_id} was already taken by another process (race condition)"); + } + + Err(e) => { + error!("Failed to mark task {task_id} as processing - {:?}", e); + sleep(milliseconds(100)).await; + } + } + } +} + +#[inline] +fn milliseconds(i: u64) -> Duration { + Duration::from_millis(i) +} diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 5c911a19..900e5270 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -604,6 +604,70 @@ impl InflightActivationStore { meta_result } + #[instrument(skip_all)] + pub async fn peek_pending_activation(&self) -> Result, Error> { + let now = Utc::now(); + + let row_result: Option = sqlx::query_as( + " + SELECT id, + activation, + partition, + offset, + added_at, + received_at, + processing_attempts, + expires_at, + delay_until, + processing_deadline_duration, + processing_deadline, + status, + at_most_once, + namespace, + taskname, + on_attempts_exceeded + FROM inflight_taskactivations + WHERE status = $1 + AND (expires_at IS NULL OR expires_at > $2) + ORDER BY added_at + LIMIT 1 + ", + ) + .bind(InflightActivationStatus::Pending) + .bind(now.timestamp()) + .fetch_optional(&self.read_pool) + .await?; + + Ok(row_result.map(|row| row.into())) + } + + #[instrument(skip_all)] + pub async fn mark_as_processing_if_pending(&self, id: &str) -> Result { + let grace_period = self.config.processing_deadline_grace_sec; + let mut conn = self + .acquire_write_conn_metric("mark_as_processing_if_pending") + .await?; + + let result: Option = sqlx::query_as(&format!( + "UPDATE inflight_taskactivations + SET + processing_deadline = unixepoch( + 'now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds' + ), + status = $1 + WHERE id = $2 + AND status = $3 + RETURNING *" + )) + .bind(InflightActivationStatus::Processing) + .bind(id) + .bind(InflightActivationStatus::Pending) + .fetch_optional(&mut *conn) + .await?; + + Ok(result.is_some()) + } + #[instrument(skip_all)] pub async fn get_pending_activation( &self, diff --git a/src/test_utils.rs b/src/test_utils.rs index 9ae23de5..4f480a68 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -22,8 +22,8 @@ use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivati /// Generate a unique filename for isolated SQLite databases. pub fn generate_temp_filename() -> String { - let mut rng = rand::thread_rng(); - format!("/var/tmp/{}-{}.sqlite", Utc::now(), rng.r#gen::()) + let mut rng = rand::rng(); + format!("/var/tmp/{}-{}.sqlite", Utc::now(), rng.random::()) } /// Generate a unique alphanumeric string for namespaces (and possibly other purposes). 
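A minimal sketch (not part of the patch) of how the new `mark_as_processing_if_pending` method could be exercised next to the existing store tests, reusing the `create_test_store` and `make_activations` helpers; the test name is illustrative, and the `"id_0"` identifier follows the convention used in `server_tests.rs`:

use crate::test_utils::{create_test_store, make_activations};

#[tokio::test]
async fn mark_as_processing_if_pending_only_claims_pending_rows() {
    let store = create_test_store().await;
    store.store(make_activations(1)).await.unwrap();

    // The first caller wins the race: the row flips from Pending to Processing
    // and a processing deadline is stamped onto it.
    assert!(store.mark_as_processing_if_pending("id_0").await.unwrap());

    // A second caller finds the row is no longer Pending and gets `false` back,
    // which the pusher reports as a lost race.
    assert!(!store.mark_as_processing_if_pending("id_0").await.unwrap());
}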
From bb198670ee8d8d41991dc0336b2a713d1bd1d80b Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 13:45:45 -0800 Subject: [PATCH 02/10] Implement Add and Remove Worker gRPC Methods --- Cargo.lock | 2 +- src/grpc/server.rs | 26 +++++++++++++++++++++++++- src/grpc/server_tests.rs | 22 ++++++++++++++++------ src/main.rs | 3 +++ src/pool.rs | 18 ++++++++++++++---- src/test_utils.rs | 7 +++++++ 6 files changed, 66 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c6d00174..7faf0ac4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2431,7 +2431,7 @@ dependencies = [ [[package]] name = "sentry_protos" version = "0.4.10" -source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-broker-worker#c3015d4807f208c22e448a17ff1b3e42e066ba5c" +source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-broker-worker#5392cee7d1aeeb91f7501542e105bb3e391a2acd" dependencies = [ "prost", "prost-types", diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 99fe03d8..6e1bbea7 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -2,22 +2,46 @@ use chrono::Utc; use prost::Message; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; use sentry_protos::taskbroker::v1::{ - FetchNextTask, GetTaskRequest, GetTaskResponse, SetTaskStatusRequest, SetTaskStatusResponse, + AddWorkerRequest, AddWorkerResponse, FetchNextTask, GetTaskRequest, GetTaskResponse, + RemoveWorkerRequest, RemoveWorkerResponse, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, TaskActivationStatus, }; use std::sync::Arc; use std::time::Instant; +use tokio::sync::RwLock; use tonic::{Request, Response, Status}; +use crate::pool::WorkerPool; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; use tracing::{error, instrument}; pub struct TaskbrokerServer { pub store: Arc, + pub pool: Arc>, } #[tonic::async_trait] impl ConsumerService for TaskbrokerServer { + #[instrument(skip_all)] + async fn add_worker( + &self, + request: Request, + ) -> Result, Status> { + let address = &request.get_ref().address; + self.pool.write().await.add_worker(address); + Ok(Response::new(AddWorkerResponse {})) + } + + #[instrument(skip_all)] + async fn remove_worker( + &self, + request: Request, + ) -> Result, Status> { + let address = &request.get_ref().address; + self.pool.write().await.remove_worker(address); + Ok(Response::new(RemoveWorkerResponse {})) + } + #[instrument(skip_all)] async fn get_task( &self, diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index 6387b44d..edfea9d4 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -4,12 +4,14 @@ use tonic::{Code, Request}; use crate::grpc::server::TaskbrokerServer; -use crate::test_utils::{create_test_store, make_activations}; +use crate::test_utils::{create_pool, create_test_store, make_activations}; #[tokio::test] async fn test_get_task() { let store = create_test_store().await; - let service = TaskbrokerServer { store }; + let pool = create_pool(); + + let service = TaskbrokerServer { store, pool }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -22,7 +24,9 @@ async fn test_get_task() { #[allow(deprecated)] async fn test_set_task_status() { let store = create_test_store().await; - let service = TaskbrokerServer { store }; + let pool = create_pool(); + + let service = TaskbrokerServer { store, pool }; let request = SetTaskStatusRequest 
{ id: "test_task".to_string(), status: 5, // Complete @@ -38,7 +42,9 @@ async fn test_set_task_status() { #[allow(deprecated)] async fn test_set_task_status_invalid() { let store = create_test_store().await; - let service = TaskbrokerServer { store }; + let pool = create_pool(); + + let service = TaskbrokerServer { store, pool }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -58,10 +64,12 @@ async fn test_set_task_status_invalid() { #[allow(deprecated)] async fn test_get_task_success() { let store = create_test_store().await; + let pool = create_pool(); + let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, pool }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); @@ -75,10 +83,12 @@ async fn test_get_task_success() { #[allow(deprecated)] async fn test_set_task_status_success() { let store = create_test_store().await; + let pool = create_pool(); + let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store }; + let service = TaskbrokerServer { store, pool }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; diff --git a/src/main.rs b/src/main.rs index 956f7455..f15b9d34 100644 --- a/src/main.rs +++ b/src/main.rs @@ -203,6 +203,8 @@ async fn main() -> Result<(), Error> { let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); + let grpc_pool = pool.clone(); + async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) .parse() @@ -217,6 +219,7 @@ async fn main() -> Result<(), Error> { .layer(layers) .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, + pool: grpc_pool, })) .add_service(health_service.clone()) .serve(addr); diff --git a/src/pool.rs b/src/pool.rs index 6ba6b147..9bff05ff 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use anyhow::Result; use rand::seq::IteratorRandom; @@ -12,7 +12,7 @@ pub struct WorkerPool { clients: HashMap, /// All available worker addresses. - addresses: Vec, + addresses: HashSet, } #[derive(Clone)] @@ -39,13 +39,23 @@ impl WorkerClient { impl WorkerPool { /// Create a new `WorkerPool` instance. - pub fn new(addresses: Vec) -> Self { + pub fn new>(addresses: T) -> Self { Self { - addresses, + addresses: addresses.into_iter().collect(), clients: HashMap::new(), } } + /// Register worker address during execution. + pub fn add_worker>(&mut self, address: T) { + self.addresses.insert(address.into()); + } + + /// Unregister worker address during execution. + pub fn remove_worker(&mut self, address: &String) { + self.addresses.remove(address); + } + /// Call this function over and over again in another thread to keep the pool of active connections updated. 
pub async fn update(&mut self) { for address in self.addresses.clone() { diff --git a/src/test_utils.rs b/src/test_utils.rs index 4f480a68..6763af85 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -8,10 +8,12 @@ use rdkafka::{ producer::FutureProducer, }; use std::{collections::HashMap, sync::Arc}; +use tokio::sync::RwLock; use uuid::Uuid; use crate::{ config::Config, + pool::WorkerPool, store::inflight_activation::{ InflightActivation, InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, @@ -87,6 +89,11 @@ pub fn create_config() -> Arc { Arc::new(Config::default()) } +/// Create a basic [`WorkerPool`]. +pub fn create_pool() -> Arc> { + Arc::new(RwLock::new(WorkerPool::new(["127.0.0.1:50052".into()]))) +} + /// Create an InflightActivationStore instance pub async fn create_test_store() -> Arc { Arc::new( From 8a5accba15de1f8be907fc7b52131f58b1e969fe Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 15:27:35 -0800 Subject: [PATCH 03/10] Randomize Worker Selection --- Cargo.lock | 1 + Cargo.toml | 1 + config.yaml | 6 ++++++ src/pool.rs | 37 +++++++++++++++++++++++++++++++++---- 4 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 config.yaml diff --git a/Cargo.lock b/Cargo.lock index 7faf0ac4..4a9c587f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2884,6 +2884,7 @@ dependencies = [ "hmac", "http", "http-body-util", + "itertools 0.14.0", "libsqlite3-sys", "metrics", "metrics-exporter-statsd", diff --git a/Cargo.toml b/Cargo.toml index 28e28bb2..1ba7368b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ sentry = { version = "0.41.0", default-features = false, features = [ "logs" ] } sentry_protos = { git = "https://github.com/getsentry/sentry-protos", branch = "george/push-broker-worker" } +itertools = "0.14.0" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..e522cc23 --- /dev/null +++ b/config.yaml @@ -0,0 +1,6 @@ +kafka_topic: "test-topic" +push: true +workers: + - "http://127.0.0.1:50052" + - "http://127.0.0.1:50053" + - "http://127.0.0.1:50054" diff --git a/src/pool.rs b/src/pool.rs index 9bff05ff..327e8d05 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,6 +1,11 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, +}; use anyhow::Result; +use itertools::Itertools; +use rand::Rng; use rand::seq::IteratorRandom; use sentry_protos::taskworker::v1::{PushTaskRequest, worker_service_client::WorkerServiceClient}; use tonic::transport::{Channel, Error}; @@ -48,12 +53,16 @@ impl WorkerPool { /// Register worker address during execution. pub fn add_worker>(&mut self, address: T) { - self.addresses.insert(address.into()); + let address = address.into(); + info!("Adding worker {address}"); + self.addresses.insert(address); } /// Unregister worker address during execution. pub fn remove_worker(&mut self, address: &String) { + info!("Removing worker {address}"); self.addresses.remove(address); + self.clients.remove(address); } /// Call this function over and over again in another thread to keep the pool of active connections updated. @@ -74,6 +83,14 @@ impl WorkerPool { } } } + + let pool = self + .clients + .iter() + .map(|(address, client)| format!("{address}:{}", client.queue_size)) + .join(","); + + info!(pool) } /// Try pushing a task to the best worker using P2C (Power of Two Choices). 
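    /// Two connected workers are sampled at random and the one with the smaller estimated
    /// queue is picked; ties are broken randomly so a single worker is not hit repeatedly.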
@@ -85,7 +102,19 @@ impl WorkerPool { .values() .choose_multiple(&mut rng, 2) .into_iter() - .min_by_key(|client| client.queue_size) + .min_by(|a, b| { + match a.queue_size.cmp(&b.queue_size) { + // When two workers are the same, we pick one randomly to avoid hammering one worker repeatedly + Ordering::Equal => { + if rng.random::() { + Ordering::Less + } else { + Ordering::Greater + } + } + other => other, + } + }) .cloned() }; @@ -104,7 +133,7 @@ impl WorkerPool { } // Update this worker's queue size - client.queue_size = 5; + client.queue_size = response.queue_size; self.clients.insert(address, client); Ok(()) } From 420508b85ffff9b666d335a815b27d8a62a97f4e Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 17:00:21 -0800 Subject: [PATCH 04/10] Improve Queue Size Counting Logic --- Cargo.lock | 2 +- src/grpc/server.rs | 6 ++++++ src/grpc/server_tests.rs | 3 +++ src/pool.rs | 7 +++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 4a9c587f..58cac758 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2431,7 +2431,7 @@ dependencies = [ [[package]] name = "sentry_protos" version = "0.4.10" -source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-broker-worker#5392cee7d1aeeb91f7501542e105bb3e391a2acd" +source = "git+https://github.com/getsentry/sentry-protos?branch=george%2Fpush-broker-worker#a4b97fe5ae594996e20509f4053db360203e9c3c" dependencies = [ "prost", "prost-types", diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 6e1bbea7..c92ffd8f 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -91,6 +91,12 @@ impl ConsumerService for TaskbrokerServer { let start_time = Instant::now(); let id = request.get_ref().id.clone(); + // Update worker queue size estimate + self.pool + .write() + .await + .decrement_queue_size(&request.get_ref().address); + let status: InflightActivationStatus = TaskActivationStatus::try_from(request.get_ref().status) .map_err(|e| { diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index edfea9d4..f435ebbd 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -31,6 +31,7 @@ async fn test_set_task_status() { id: "test_task".to_string(), status: 5, // Complete fetch_next_task: None, + address: "http://127.0.0.1:50052".into(), }; let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); @@ -49,6 +50,7 @@ async fn test_set_task_status_invalid() { id: "test_task".to_string(), status: 1, // Invalid fetch_next_task: None, + address: "http://127.0.0.1:50052".into(), }; let response = service.set_task_status(Request::new(request)).await; assert!(response.is_err()); @@ -102,6 +104,7 @@ async fn test_set_task_status_success() { id: "id_0".to_string(), status: 5, // Complete fetch_next_task: Some(FetchNextTask { namespace: None }), + address: "http://127.0.0.1:50052".into(), }; let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); diff --git a/src/pool.rs b/src/pool.rs index 327e8d05..fc1b6bf7 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -65,6 +65,13 @@ impl WorkerPool { self.clients.remove(address); } + /// Decrement `queue_size` for the worker with address `address`. Called when worker reports task status. + pub fn decrement_queue_size(&mut self, address: &String) { + if let Some(client) = self.clients.get_mut(address) { + client.queue_size -= 1; + } + } + /// Call this function over and over again in another thread to keep the pool of active connections updated. 
pub async fn update(&mut self) { for address in self.addresses.clone() { From 1b51f1489f3203e05ae6502897bc2e05c9fd978c Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 17:03:03 -0800 Subject: [PATCH 05/10] Make Clippy Happy --- src/pool.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index fc1b6bf7..8ebb0465 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,7 +1,5 @@ -use std::{ - cmp::Ordering, - collections::{HashMap, HashSet}, -}; +use std::cmp::Ordering; +use std::collections::{HashMap, HashSet, hash_map::Entry}; use anyhow::Result; use itertools::Itertools; @@ -74,14 +72,14 @@ impl WorkerPool { /// Call this function over and over again in another thread to keep the pool of active connections updated. pub async fn update(&mut self) { - for address in self.addresses.clone() { - if !self.clients.contains_key(&address) { - match connect(&address).await { + for address in &self.addresses { + if let Entry::Vacant(e) = self.clients.entry(address.into()) { + match connect(address).await { Ok(connection) => { info!("Connected to {address}"); let client = WorkerClient::new(connection, address.clone(), 0); - self.clients.insert(address, client); + e.insert(client); } Err(e) => { From 646661997e5d694172fca4dd8bdc77a04a2799cb Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 2 Jan 2026 17:11:02 -0800 Subject: [PATCH 06/10] More Clippy Appeasement --- benches/store_bench.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benches/store_bench.rs b/benches/store_bench.rs index af5ce5f5..0b767c50 100644 --- a/benches/store_bench.rs +++ b/benches/store_bench.rs @@ -15,11 +15,11 @@ use tokio::task::JoinSet; async fn get_pending_activations(num_activations: u32, num_workers: u32) { let url = if cfg!(feature = "bench-with-mnt-disk") { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); format!( "/mnt/disks/sqlite/{}-{}.sqlite", Utc::now(), - rng.r#gen::() + rng.random::() ) } else { generate_temp_filename() @@ -78,11 +78,11 @@ async fn set_status(num_activations: u32, num_workers: u32) { assert!(num_activations.is_multiple_of(num_workers)); let url = if cfg!(feature = "bench-with-mnt-disk") { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); format!( "/mnt/disks/sqlite/{}-{}.sqlite", Utc::now(), - rng.r#gen::() + rng.random::() ) } else { generate_temp_filename() From 46f8cece4ada6b898828ec7a95d9a57fa0f7eeac Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 5 Jan 2026 17:19:18 -0800 Subject: [PATCH 07/10] Add Queue Size Estimate Error Metrics --- config.yaml | 8 ++++---- src/grpc/server.rs | 8 ++++---- src/pool.rs | 31 +++++++++++++++++++++++++------ 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/config.yaml b/config.yaml index e522cc23..dcf846be 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ kafka_topic: "test-topic" push: true -workers: - - "http://127.0.0.1:50052" - - "http://127.0.0.1:50053" - - "http://127.0.0.1:50054" +# workers: +# - "http://127.0.0.1:50052" +# - "http://127.0.0.1:50053" +# - "http://127.0.0.1:50054" diff --git a/src/grpc/server.rs b/src/grpc/server.rs index c92ffd8f..c4d10f1c 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -92,10 +92,10 @@ impl ConsumerService for TaskbrokerServer { let id = request.get_ref().id.clone(); // Update worker queue size estimate - self.pool - .write() - .await - .decrement_queue_size(&request.get_ref().address); + // self.pool + // .write() + // .await + // 
.decrement_queue_size(&request.get_ref().address); let status: InflightActivationStatus = TaskActivationStatus::try_from(request.get_ref().status) diff --git a/src/pool.rs b/src/pool.rs index 8ebb0465..35e07d3b 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -64,11 +64,11 @@ impl WorkerPool { } /// Decrement `queue_size` for the worker with address `address`. Called when worker reports task status. - pub fn decrement_queue_size(&mut self, address: &String) { - if let Some(client) = self.clients.get_mut(address) { - client.queue_size -= 1; - } - } + // pub fn decrement_queue_size(&mut self, address: &String) { + // if let Some(client) = self.clients.get_mut(address) { + // client.queue_size = client.queue_size.saturating_sub(1); + // } + // } /// Call this function over and over again in another thread to keep the pool of active connections updated. pub async fn update(&mut self) { @@ -137,8 +137,27 @@ impl WorkerPool { return Err(anyhow::anyhow!("Selected worker was full")); } + // Calculate estimation error before updating + let estimated = client.queue_size; + let actual = response.queue_size; + let error = (estimated as i64) - (actual as i64); + + // Record the absolute error + metrics::histogram!("worker.queue_size.estimation_error", "worker" => address.clone()) + .record(error.abs() as f64); + + // Record the signed error to see if we're systematically over/under-estimating + metrics::histogram!("worker.queue_size.estimation_delta", "worker" => address.clone()) + .record(error as f64); + + // Record both values for reference + metrics::gauge!("worker.queue_size.estimated", "worker" => address.clone()) + .set(estimated as f64); + metrics::gauge!("worker.queue_size.actual", "worker" => address.clone()) + .set(actual as f64); + // Update this worker's queue size - client.queue_size = response.queue_size; + client.queue_size = actual; self.clients.insert(address, client); Ok(()) } From c7be82e6502a7080fe920de74c66151b1f6fed75 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 6 Jan 2026 14:30:18 -0800 Subject: [PATCH 08/10] More Metrics for Testing --- src/pool.rs | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index 35e07d3b..02424cdb 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet, hash_map::Entry}; use anyhow::Result; +use chrono::{DateTime, Utc}; use itertools::Itertools; use rand::Rng; use rand::seq::IteratorRandom; @@ -28,6 +29,12 @@ struct WorkerClient { /// The worker's ESTIMATED queue size. queue_size: u32, + + /// (TEMP) How many times has this worker been hit? + hits: u32, + + /// (TEMP) When was the last time this worker was hit? 
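+    /// Used to derive the `worker.queue_size.time_since_hit` gauge.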
+ last_hit: DateTime, } impl WorkerClient { @@ -36,6 +43,8 @@ impl WorkerClient { connection, address, queue_size, + hits: 0, + last_hit: Utc::now(), } } } @@ -143,16 +152,32 @@ impl WorkerPool { let error = (estimated as i64) - (actual as i64); // Record the absolute error - metrics::histogram!("worker.queue_size.estimation_error", "worker" => address.clone()) - .record(error.abs() as f64); + // metrics::histogram!("worker.queue_size.estimation_error", "worker" => address.clone()) + // .record(error.abs() as f64); // Record the signed error to see if we're systematically over/under-estimating - metrics::histogram!("worker.queue_size.estimation_delta", "worker" => address.clone()) - .record(error as f64); + // metrics::histogram!("worker.queue_size.estimation_delta", "worker" => address.clone()) + // .record(error as f64); + + client.hits += 1; + + metrics::gauge!("worker.queue_size.hits", "worker" => address.clone()) + .set(client.hits as f64); + + let now = Utc::now(); + let time_delta = (now - client.last_hit).as_seconds_f64(); + client.last_hit = now; + + metrics::gauge!("worker.queue_size.time_since_hit", "worker" => address.clone()) + .set(time_delta); + + metrics::gauge!("worker.queue_size.delta", "worker" => address.clone()) + .set(error as f64); // Record both values for reference metrics::gauge!("worker.queue_size.estimated", "worker" => address.clone()) .set(estimated as f64); + metrics::gauge!("worker.queue_size.actual", "worker" => address.clone()) .set(actual as f64); From 5934f64e3739dd54ba7d88b7139ce4bceaa299eb Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 9 Jan 2026 14:21:40 -0800 Subject: [PATCH 09/10] Fix Bugs --- config.yaml | 3 + src/grpc/server.rs | 23 +++++--- src/grpc/server_tests.rs | 30 ++++++++-- src/main.rs | 1 + src/pool.rs | 121 +++++++++++---------------------------- src/push.rs | 43 ++++++-------- 6 files changed, 96 insertions(+), 125 deletions(-) diff --git a/config.yaml b/config.yaml index dcf846be..f7015921 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,8 @@ kafka_topic: "test-topic" push: true +default_metrics_tags: + host: "127.0.0.1" +log_filter: "debug,sqlx=debug,librdkafka=warn,h2=off" # workers: # - "http://127.0.0.1:50052" # - "http://127.0.0.1:50053" diff --git a/src/grpc/server.rs b/src/grpc/server.rs index c4d10f1c..335b682d 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -13,11 +13,12 @@ use tonic::{Request, Response, Status}; use crate::pool::WorkerPool; use crate::store::inflight_activation::{InflightActivationStatus, InflightActivationStore}; -use tracing::{error, instrument}; +use tracing::{debug, error, instrument}; pub struct TaskbrokerServer { pub store: Arc, pub pool: Arc>, + pub push: bool, } #[tonic::async_trait] @@ -28,7 +29,7 @@ impl ConsumerService for TaskbrokerServer { request: Request, ) -> Result, Status> { let address = &request.get_ref().address; - self.pool.write().await.add_worker(address); + self.pool.write().await.add_worker(address).await; Ok(Response::new(AddWorkerResponse {})) } @@ -47,6 +48,12 @@ impl ConsumerService for TaskbrokerServer { &self, request: Request, ) -> Result, Status> { + if self.push { + return Err(Status::failed_precondition( + "get_task is not available in push mode", + )); + } + let start_time = Instant::now(); let namespace = &request.get_ref().namespace; let inflight = self @@ -91,11 +98,7 @@ impl ConsumerService for TaskbrokerServer { let start_time = Instant::now(); let id = request.get_ref().id.clone(); - // Update worker queue size estimate - // self.pool - // 
.write() - // .await - // .decrement_queue_size(&request.get_ref().address); + debug!("Received task status {} for {id}", request.get_ref().status); let status: InflightActivationStatus = TaskActivationStatus::try_from(request.get_ref().status) @@ -113,6 +116,8 @@ impl ConsumerService for TaskbrokerServer { metrics::counter!("grpc_server.set_status.failure").increment(1); } + debug!("Status of task {id} set to {:?}", status); + let update_result = self.store.set_status(&id, status).await; if let Err(e) = update_result { error!( @@ -131,6 +136,10 @@ impl ConsumerService for TaskbrokerServer { return Ok(Response::new(SetTaskStatusResponse { task: None })); }; + if self.push { + return Ok(Response::new(SetTaskStatusResponse { task: None })); + } + let start_time = Instant::now(); let res = match self .store diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index f435ebbd..19eb653d 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -11,7 +11,11 @@ async fn test_get_task() { let store = create_test_store().await; let pool = create_pool(); - let service = TaskbrokerServer { store, pool }; + let service = TaskbrokerServer { + store, + pool, + push: false, + }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -26,7 +30,11 @@ async fn test_set_task_status() { let store = create_test_store().await; let pool = create_pool(); - let service = TaskbrokerServer { store, pool }; + let service = TaskbrokerServer { + store, + pool, + push: false, + }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete @@ -45,7 +53,11 @@ async fn test_set_task_status_invalid() { let store = create_test_store().await; let pool = create_pool(); - let service = TaskbrokerServer { store, pool }; + let service = TaskbrokerServer { + store, + pool, + push: false, + }; let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid @@ -71,7 +83,11 @@ async fn test_get_task_success() { let activations = make_activations(1); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, pool }; + let service = TaskbrokerServer { + store, + pool, + push: false, + }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); @@ -90,7 +106,11 @@ async fn test_set_task_status_success() { let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, pool }; + let service = TaskbrokerServer { + store, + pool, + push: false, + }; let request = GetTaskRequest { namespace: None }; let response = service.get_task(Request::new(request)).await; diff --git a/src/main.rs b/src/main.rs index f15b9d34..4ae680b6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -220,6 +220,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, pool: grpc_pool, + push: config.push, })) .add_service(health_service.clone()) .serve(addr); diff --git a/src/pool.rs b/src/pool.rs index 02424cdb..56725af5 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,9 +1,7 @@ use std::cmp::Ordering; -use std::collections::{HashMap, HashSet, hash_map::Entry}; +use std::collections::{HashMap, HashSet}; use anyhow::Result; -use chrono::{DateTime, Utc}; -use itertools::Itertools; use rand::Rng; use rand::seq::IteratorRandom; use sentry_protos::taskworker::v1::{PushTaskRequest, 
worker_service_client::WorkerServiceClient}; @@ -27,14 +25,8 @@ struct WorkerClient { /// The worker address. address: String, - /// The worker's ESTIMATED queue size. + /// The worker's last known queue size. queue_size: u32, - - /// (TEMP) How many times has this worker been hit? - hits: u32, - - /// (TEMP) When was the last time this worker was hit? - last_hit: DateTime, } impl WorkerClient { @@ -43,8 +35,6 @@ impl WorkerClient { connection, address, queue_size, - hits: 0, - last_hit: Utc::now(), } } } @@ -58,11 +48,30 @@ impl WorkerPool { } } - /// Register worker address during execution. - pub fn add_worker>(&mut self, address: T) { + /// Register worker address and attempt to connect immediately. + /// Only adds the worker to the pool if the connection succeeds. + pub async fn add_worker>(&mut self, address: T) { let address = address.into(); info!("Adding worker {address}"); - self.addresses.insert(address); + + // Only add to the pool if we can connect + match connect(&address).await { + Ok(connection) => { + info!("Connected to {address}"); + + let client = WorkerClient::new(connection, address.clone(), 0); + + self.clients.insert(address.clone(), client); + self.addresses.insert(address); + } + + Err(e) => { + warn!( + "Did not register worker {address} due to connection error - {:?}", + e + ); + } + } } /// Unregister worker address during execution. @@ -72,41 +81,6 @@ impl WorkerPool { self.clients.remove(address); } - /// Decrement `queue_size` for the worker with address `address`. Called when worker reports task status. - // pub fn decrement_queue_size(&mut self, address: &String) { - // if let Some(client) = self.clients.get_mut(address) { - // client.queue_size = client.queue_size.saturating_sub(1); - // } - // } - - /// Call this function over and over again in another thread to keep the pool of active connections updated. - pub async fn update(&mut self) { - for address in &self.addresses { - if let Entry::Vacant(e) = self.clients.entry(address.into()) { - match connect(address).await { - Ok(connection) => { - info!("Connected to {address}"); - - let client = WorkerClient::new(connection, address.clone(), 0); - e.insert(client); - } - - Err(e) => { - warn!("Couldn't connect to {address} - {:?}", e); - } - } - } - } - - let pool = self - .clients - .iter() - .map(|(address, client)| format!("{address}:{}", client.queue_size)) - .join(","); - - info!(pool) - } - /// Try pushing a task to the best worker using P2C (Power of Two Choices). 
pub async fn push(&mut self, request: PushTaskRequest) -> Result<()> { let candidate = { @@ -146,50 +120,23 @@ impl WorkerPool { return Err(anyhow::anyhow!("Selected worker was full")); } - // Calculate estimation error before updating - let estimated = client.queue_size; - let actual = response.queue_size; - let error = (estimated as i64) - (actual as i64); - - // Record the absolute error - // metrics::histogram!("worker.queue_size.estimation_error", "worker" => address.clone()) - // .record(error.abs() as f64); - - // Record the signed error to see if we're systematically over/under-estimating - // metrics::histogram!("worker.queue_size.estimation_delta", "worker" => address.clone()) - // .record(error as f64); - - client.hits += 1; - - metrics::gauge!("worker.queue_size.hits", "worker" => address.clone()) - .set(client.hits as f64); - - let now = Utc::now(); - let time_delta = (now - client.last_hit).as_seconds_f64(); - client.last_hit = now; - - metrics::gauge!("worker.queue_size.time_since_hit", "worker" => address.clone()) - .set(time_delta); - - metrics::gauge!("worker.queue_size.delta", "worker" => address.clone()) - .set(error as f64); - - // Record both values for reference - metrics::gauge!("worker.queue_size.estimated", "worker" => address.clone()) - .set(estimated as f64); - - metrics::gauge!("worker.queue_size.actual", "worker" => address.clone()) - .set(actual as f64); - // Update this worker's queue size - client.queue_size = actual; + client.queue_size = response.queue_size; self.clients.insert(address, client); + Ok(()) } Err(e) => { - // Remove this unhealthy worker from the active connection pool + warn!( + "Removing worker {address} from pool due to RPC error - {:?}", + e + ); + + // Remove this unhealthy worker completely - from both clients and addresses self.clients.remove(&address); + self.addresses.remove(&address); + Err(e.into()) } } diff --git a/src/push.rs b/src/push.rs index b6167304..35babe2a 100644 --- a/src/push.rs +++ b/src/push.rs @@ -40,25 +40,6 @@ impl TaskPusher { info!("Push task loop starting..."); let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - let pool = self.pool.clone(); - - // Spawn a separate task to update the candidate worker collection - tokio::spawn(async move { - let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); - let mut beep_interval = tokio::time::interval(Duration::from_secs(1)); - - loop { - tokio::select! { - _ = guard.wait() => { - break; - } - - _ = beep_interval.tick() => { - pool.write().await.update().await; - } - } - } - }); loop { tokio::select! { @@ -68,6 +49,7 @@ impl TaskPusher { } _ = async { + debug!("About to process next task..."); self.process_next_task().await; } => {} } @@ -81,10 +63,17 @@ impl TaskPusher { impl TaskPusher { /// Grab the next pending task from the store. 
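    /// The activation is only peeked here; it is marked `Processing` once the push succeeds.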
async fn process_next_task(&self) { + debug!("Getting the next task..."); + match self.store.peek_pending_activation().await { Ok(Some(inflight)) => { + let id = inflight.id.clone(); + debug!("Found task {id} with status {:?}", inflight.status); + if let Err(e) = self.handle_task_push(inflight).await { - warn!("Task push resulted in error - {:?}", e); + warn!("Pushing task {id} resulted in error - {:?}", e); + } else { + debug!("Task {id} was pushed!"); } } @@ -116,21 +105,23 @@ impl TaskPusher { async fn push_to_worker(&self, activation: TaskActivation, task_id: &str) -> Result<()> { let request = PushTaskRequest { task: Some(activation), - callback_url: format!("{}:{}", self.config.grpc_addr, self.config.grpc_port), + callback_url: format!( + "{}:{}", + self.config.default_metrics_tags["host"], self.config.grpc_port + ), }; let result = self.pool.write().await.push(request).await; match result { Ok(()) => { - info!("Pushed task {task_id}"); + debug!("Pushed task {task_id}"); self.mark_task_as_processing(task_id).await; Ok(()) } Err(e) => { - warn!("Could not push task {task_id} - {:?}", e); - sleep(milliseconds(100)).await; + debug!("Could not push task {task_id} - {:?}", e); Err(e) } } @@ -140,11 +131,11 @@ impl TaskPusher { async fn mark_task_as_processing(&self, task_id: &str) { match self.store.mark_as_processing_if_pending(task_id).await { Ok(true) => { - info!("Task {} pushed and marked as processing", task_id); + debug!("Task {} pushed and marked as processing", task_id); } Ok(false) => { - warn!("Task {task_id} was already taken by another process (race condition)"); + error!("Task {task_id} was already taken by another process (race condition)"); } Err(e) => { From e627a9d784031101327a8f1d48d46e3faa4477ec Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 9 Jan 2026 15:48:52 -0800 Subject: [PATCH 10/10] Replace `HashMap` with `DashMap` --- Cargo.lock | 27 ++++++++++++++++++++++++--- Cargo.toml | 1 + src/grpc/server.rs | 7 +++---- src/main.rs | 3 +-- src/pool.rs | 29 +++++++++++------------------ src/push.rs | 15 ++++++++------- src/test_utils.rs | 5 ++--- 7 files changed, 50 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 58cac758..ca469317 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -552,6 +552,20 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "debugid" version = "0.8.0" @@ -919,6 +933,12 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -936,7 +956,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -1255,7 +1275,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.5", ] [[package]] @@ -2645,7 +2665,7 @@ dependencies = [ "futures-intrusive", 
"futures-io", "futures-util", - "hashbrown", + "hashbrown 0.15.5", "hashlink", "indexmap", "log", @@ -2876,6 +2896,7 @@ dependencies = [ "chrono", "clap", "criterion", + "dashmap", "elegant-departure", "figment", "futures", diff --git a/Cargo.toml b/Cargo.toml index 1ba7368b..a50b32d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ anyhow = "1.0.92" bytes = "1.10.0" chrono = { version = "0.4.26" } clap = { version = "4.5.20", features = ["derive"] } +dashmap = "6.1.0" elegant-departure = { version = "0.3.1", features = ["tokio"] } figment = { version = "0.10.19", features = ["env", "yaml", "test"] } futures = "0.3.31" diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 335b682d..94e6b0f4 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -8,7 +8,6 @@ use sentry_protos::taskbroker::v1::{ }; use std::sync::Arc; use std::time::Instant; -use tokio::sync::RwLock; use tonic::{Request, Response, Status}; use crate::pool::WorkerPool; @@ -17,7 +16,7 @@ use tracing::{debug, error, instrument}; pub struct TaskbrokerServer { pub store: Arc, - pub pool: Arc>, + pub pool: Arc, pub push: bool, } @@ -29,7 +28,7 @@ impl ConsumerService for TaskbrokerServer { request: Request, ) -> Result, Status> { let address = &request.get_ref().address; - self.pool.write().await.add_worker(address).await; + self.pool.add_worker(address).await; Ok(Response::new(AddWorkerResponse {})) } @@ -39,7 +38,7 @@ impl ConsumerService for TaskbrokerServer { request: Request, ) -> Result, Status> { let address = &request.get_ref().address; - self.pool.write().await.remove_worker(address); + self.pool.remove_worker(address); Ok(Response::new(RemoveWorkerResponse {})) } diff --git a/src/main.rs b/src/main.rs index 4ae680b6..483d6a0d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,6 @@ use taskbroker::pool::WorkerPool; use taskbroker::push::TaskPusher; use taskbroker::upkeep::upkeep; use tokio::signal::unix::SignalKind; -use tokio::sync::RwLock; use tokio::task::JoinHandle; use tokio::{select, time}; use tonic::transport::Server; @@ -60,7 +59,7 @@ async fn main() -> Result<(), Error> { let runtime_config_manager = Arc::new(RuntimeConfigManager::new(config.runtime_config_path.clone()).await); - let pool = Arc::new(RwLock::new(WorkerPool::new(config.workers.clone()))); + let pool = Arc::new(WorkerPool::new(config.workers.clone())); println!("taskbroker starting"); println!("version: {}", get_version().trim()); diff --git a/src/pool.rs b/src/pool.rs index 56725af5..242832d5 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,20 +1,17 @@ use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; use anyhow::Result; +use dashmap::DashMap; use rand::Rng; use rand::seq::IteratorRandom; use sentry_protos::taskworker::v1::{PushTaskRequest, worker_service_client::WorkerServiceClient}; use tonic::transport::{Channel, Error}; use tracing::{info, warn}; -#[derive(Clone)] pub struct WorkerPool { /// Maps every worker address to its client. - clients: HashMap, - - /// All available worker addresses. - addresses: HashSet, + /// Uses DashMap for concurrent access without external locking. + clients: DashMap, } #[derive(Clone)] @@ -41,16 +38,15 @@ impl WorkerClient { impl WorkerPool { /// Create a new `WorkerPool` instance. - pub fn new>(addresses: T) -> Self { + pub fn new>(_addresses: T) -> Self { Self { - addresses: addresses.into_iter().collect(), - clients: HashMap::new(), + clients: DashMap::new(), } } /// Register worker address and attempt to connect immediately. 
/// Only adds the worker to the pool if the connection succeeds. - pub async fn add_worker>(&mut self, address: T) { + pub async fn add_worker>(&self, address: T) { let address = address.into(); info!("Adding worker {address}"); @@ -62,7 +58,6 @@ impl WorkerPool { let client = WorkerClient::new(connection, address.clone(), 0); self.clients.insert(address.clone(), client); - self.addresses.insert(address); } Err(e) => { @@ -75,19 +70,19 @@ impl WorkerPool { } /// Unregister worker address during execution. - pub fn remove_worker(&mut self, address: &String) { + pub fn remove_worker(&self, address: &String) { info!("Removing worker {address}"); - self.addresses.remove(address); self.clients.remove(address); } /// Try pushing a task to the best worker using P2C (Power of Two Choices). - pub async fn push(&mut self, request: PushTaskRequest) -> Result<()> { + pub async fn push(&self, request: PushTaskRequest) -> Result<()> { let candidate = { let mut rng = rand::rng(); self.clients - .values() + .iter() + .map(|entry| entry.value().clone()) .choose_multiple(&mut rng, 2) .into_iter() .min_by(|a, b| { @@ -103,7 +98,6 @@ impl WorkerPool { other => other, } }) - .cloned() }; let Some(mut client) = candidate else { @@ -133,9 +127,8 @@ impl WorkerPool { e ); - // Remove this unhealthy worker completely - from both clients and addresses + // Remove this unhealthy worker self.clients.remove(&address); - self.addresses.remove(&address); Err(e.into()) } diff --git a/src/push.rs b/src/push.rs index 35babe2a..5fbb0135 100644 --- a/src/push.rs +++ b/src/push.rs @@ -1,10 +1,11 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; +use std::time::Duration; use anyhow::Result; use prost::Message; use sentry_protos::{taskbroker::v1::TaskActivation, taskworker::v1::PushTaskRequest}; -use tokio::{sync::RwLock, time::sleep}; -use tracing::{debug, error, info, warn}; +use tokio::time::sleep; +use tracing::{debug, error, info}; use crate::config::Config; use crate::pool::WorkerPool; @@ -12,7 +13,7 @@ use crate::store::inflight_activation::{InflightActivation, InflightActivationSt pub struct TaskPusher { /// Pool of workers through which we will push tasks. - pool: Arc>, + pool: Arc, /// Broker configuration. config: Arc, @@ -26,7 +27,7 @@ impl TaskPusher { pub fn new( store: Arc, config: Arc, - pool: Arc>, + pool: Arc, ) -> Self { Self { store, @@ -71,7 +72,7 @@ impl TaskPusher { debug!("Found task {id} with status {:?}", inflight.status); if let Err(e) = self.handle_task_push(inflight).await { - warn!("Pushing task {id} resulted in error - {:?}", e); + debug!("Pushing task {id} resulted in error - {:?}", e); } else { debug!("Task {id} was pushed!"); } @@ -111,7 +112,7 @@ impl TaskPusher { ), }; - let result = self.pool.write().await.push(request).await; + let result = self.pool.push(request).await; match result { Ok(()) => { diff --git a/src/test_utils.rs b/src/test_utils.rs index 6763af85..7b9a3436 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -8,7 +8,6 @@ use rdkafka::{ producer::FutureProducer, }; use std::{collections::HashMap, sync::Arc}; -use tokio::sync::RwLock; use uuid::Uuid; use crate::{ @@ -90,8 +89,8 @@ pub fn create_config() -> Arc { } /// Create a basic [`WorkerPool`]. -pub fn create_pool() -> Arc> { - Arc::new(RwLock::new(WorkerPool::new(["127.0.0.1:50052".into()]))) +pub fn create_pool() -> Arc { + Arc::new(WorkerPool::new(["127.0.0.1:50052".into()])) } /// Create an InflightActivationStore instance