diff --git a/Cargo.toml b/Cargo.toml index 6ad0c653b..a96702ca3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ quickcheck = "1" rand = "0.8" slab = "0.4" sync_wrapper = "1" -tokio = "1.6.2" +tokio = "1.32.1" tokio-stream = "0.1.0" tokio-test = "0.4" tokio-util = { version = "0.7.0", default-features = false } diff --git a/tower/Cargo.toml b/tower/Cargo.toml index a04e182a4..fd99bc036 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -58,6 +58,7 @@ reconnect = ["make", "tokio/io-std", "tracing"] retry = ["__common", "tokio/time", "util"] spawn-ready = ["__common", "futures-util", "tokio/sync", "tokio/rt", "util", "tracing"] steer = [] +task-local = ["tokio/rt"] timeout = ["pin-project-lite", "tokio/time"] util = ["__common", "futures-util", "pin-project-lite", "sync_wrapper"] diff --git a/tower/src/lib.rs b/tower/src/lib.rs index ce911e9d8..c3f926e98 100644 --- a/tower/src/lib.rs +++ b/tower/src/lib.rs @@ -191,6 +191,8 @@ pub mod retry; pub mod spawn_ready; #[cfg(feature = "steer")] pub mod steer; +#[cfg(feature = "task-local")] +pub mod task_local; #[cfg(feature = "timeout")] pub mod timeout; #[cfg(feature = "util")] diff --git a/tower/src/task_local.rs b/tower/src/task_local.rs new file mode 100644 index 000000000..a423da248 --- /dev/null +++ b/tower/src/task_local.rs @@ -0,0 +1,80 @@ +//! Middleware to set tokio task-local data. + +use tokio::task::{futures::TaskLocalFuture, LocalKey}; +use tower_layer::Layer; +use tower_service::Service; + +/// A [`Layer`] that produces a [`SetTaskLocal`] service. +#[derive(Clone, Copy, Debug)] +pub struct SetTaskLocalLayer { + key: &'static LocalKey, + value: T, +} + +impl SetTaskLocalLayer +where + T: Clone + Send + Sync + 'static, +{ + /// Create a new [`SetTaskLocalLayer`]. + pub const fn new(key: &'static LocalKey, value: T) -> Self { + SetTaskLocalLayer { key, value } + } +} + +impl Layer for SetTaskLocalLayer +where + T: Clone + Send + Sync + 'static, +{ + type Service = SetTaskLocal; + + fn layer(&self, inner: S) -> Self::Service { + SetTaskLocal::new(inner, self.key, self.value.clone()) + } +} + +/// Service returned by the [`set_task_local`] combinator. +/// +/// [`set_task_local`]: crate::util::ServiceExt::set_task_local +#[derive(Clone, Copy, Debug)] +pub struct SetTaskLocal { + inner: S, + key: &'static LocalKey, + value: T, +} + +impl SetTaskLocal +where + T: Clone + Send + Sync + 'static, +{ + /// Create a new [`SetTaskLocal`] service. + pub const fn new(inner: S, key: &'static LocalKey, value: T) -> Self { + Self { inner, key, value } + } +} + +impl Service for SetTaskLocal +where + S: Service, + T: Clone + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = TaskLocalFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: R) -> Self::Future { + // This is not great. I don't want to clone the value twice. + // Probably need to introduce a custom Future that delays calling + // inner.call until the local-key is set? + let fut = self + .key + .sync_scope(self.value.clone(), || self.inner.call(req)); + self.key.scope(self.value.clone(), fut) + } +} diff --git a/tower/src/util/mod.rs b/tower/src/util/mod.rs index 4c56de813..a03c22875 100644 --- a/tower/src/util/mod.rs +++ b/tower/src/util/mod.rs @@ -944,6 +944,21 @@ pub trait ServiceExt: tower_service::Service { MapFuture::new(self, f) } + /// Set the given tokio [`task_local!`][tokio::task_local] to a clone of the given value for + /// every [`call`][crate::Service::call] of the underlying service. + #[cfg(feature = "task-local")] + fn set_task_local( + self, + key: &'static tokio::task::LocalKey, + value: T, + ) -> crate::task_local::SetTaskLocal + where + Self: Sized, + T: Clone + Send + Sync + 'static, + { + crate::task_local::SetTaskLocal::new(self, key, value) + } + /// Convert the service into a [`Service`] + [`Send`] trait object. /// /// See [`BoxService`] for more details. diff --git a/tower/tests/task_local.rs b/tower/tests/task_local.rs new file mode 100644 index 000000000..3879df2bc --- /dev/null +++ b/tower/tests/task_local.rs @@ -0,0 +1,36 @@ +#![cfg(all(feature = "task-local", feature = "util"))] + +use futures::pin_mut; +use tower::{util::ServiceExt, Service as _}; +use tower_test::{assert_request_eq, mock}; + +mod support; + +tokio::task_local! { + static NUM: i32; +} + +#[tokio::test] +async fn set_task_local() { + let _t = support::trace_init(); + + let (service, handle) = mock::pair(); + pin_mut!(handle); + + let mut client = service + .map_request(|()| { + assert_eq!(NUM.get(), 9000); + }) + .map_response(|()| { + assert_eq!(NUM.get(), 9000); + }) + .set_task_local(&NUM, 9000); + + // allow a request through + handle.allow(1); + + let ready = client.ready().await.unwrap(); + let fut = ready.call(()); + assert_request_eq!(handle, ()).send_response(()); + fut.await.unwrap(); +}