diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ca0182abeb17..bce2f347c7be8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,6 +91,13 @@ jobs: miri: runs-on: macos-latest timeout-minutes: 60 + env: + # -Zrandomize-layout makes sure we don't rely on the layout of anything that might change + RUSTFLAGS: -Zrandomize-layout + # https://github.com/rust-lang/miri#miri--z-flags-and-environment-variables + # -Zmiri-disable-isolation is needed because our executor uses `fastrand` which accesses system time. + # -Zmiri-ignore-leaks is necessary because a bunch of tests don't join all threads before finishing. + MIRIFLAGS: -Zmiri-ignore-leaks -Zmiri-disable-isolation steps: - uses: actions/checkout@v4 - uses: actions/cache/restore@v4 @@ -111,17 +118,14 @@ jobs: with: toolchain: ${{ env.NIGHTLY_TOOLCHAIN }} components: miri - - name: CI job + - name: CI job (Tasks) + # To run the tests one item at a time for troubleshooting, use + # cargo --quiet test --lib -- --list | sed 's/: test$//' | MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-weak-memory-emulation" xargs -n1 cargo miri test -p bevy_tasks --lib -- --exact + run: cargo miri test -p bevy_tasks --features bevy_executor --features multi_threaded + - name: CI job (ECS) # To run the tests one item at a time for troubleshooting, use # cargo --quiet test --lib -- --list | sed 's/: test$//' | MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-weak-memory-emulation" xargs -n1 cargo miri test -p bevy_ecs --lib -- --exact run: cargo miri test -p bevy_ecs --features bevy_utils/debug - env: - # -Zrandomize-layout makes sure we dont rely on the layout of anything that might change - RUSTFLAGS: -Zrandomize-layout - # https://github.com/rust-lang/miri#miri--z-flags-and-environment-variables - # -Zmiri-disable-isolation is needed because our executor uses `fastrand` which accesses system time. - # -Zmiri-ignore-leaks is necessary because a bunch of tests don't join all threads before finishing. - MIRIFLAGS: -Zmiri-ignore-leaks -Zmiri-disable-isolation check-compiles: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 7d4dd19cef186..4527f955e7422 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,7 +131,6 @@ unused_qualifications = "warn" [features] default = [ "std", - "async_executor", "android-game-activity", "android_shared_stdcxx", "animation", @@ -570,9 +569,6 @@ custom_cursor = ["bevy_internal/custom_cursor"] # Experimental support for nodes that are ignored for UI layouting ghost_nodes = ["bevy_internal/ghost_nodes"] -# Uses `async-executor` as a task execution backend. -async_executor = ["std", "bevy_internal/async_executor"] - # Allows access to the `std` crate.
std = ["bevy_internal/std"] @@ -1986,13 +1982,13 @@ wasm = true # Async Tasks [[example]] -name = "async_compute" -path = "examples/async_tasks/async_compute.rs" +name = "blocking_compute" +path = "examples/async_tasks/blocking_compute.rs" doc-scrape-examples = true -[package.metadata.example.async_compute] -name = "Async Compute" -description = "How to use `AsyncComputeTaskPool` to complete longer running tasks" +[package.metadata.example.blocking_compute] +name = "Blocking Compute" +description = "How to use `TaskPool` to complete longer running tasks" category = "Async Tasks" wasm = false diff --git a/benches/benches/bevy_ecs/iteration/heavy_compute.rs b/benches/benches/bevy_ecs/iteration/heavy_compute.rs index e057b20a431be..72141e4010fe4 100644 --- a/benches/benches/bevy_ecs/iteration/heavy_compute.rs +++ b/benches/benches/bevy_ecs/iteration/heavy_compute.rs @@ -1,5 +1,5 @@ use bevy_ecs::prelude::*; -use bevy_tasks::{ComputeTaskPool, TaskPool}; +use bevy_tasks::{TaskPool, TaskPoolBuilder}; use criterion::Criterion; use glam::*; @@ -20,7 +20,7 @@ pub fn heavy_compute(c: &mut Criterion) { group.warm_up_time(core::time::Duration::from_millis(500)); group.measurement_time(core::time::Duration::from_secs(4)); group.bench_function("base", |b| { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::default(); diff --git a/benches/benches/bevy_ecs/iteration/par_iter_simple.rs b/benches/benches/bevy_ecs/iteration/par_iter_simple.rs index 92259cb98fecf..aa7976c2f9730 100644 --- a/benches/benches/bevy_ecs/iteration/par_iter_simple.rs +++ b/benches/benches/bevy_ecs/iteration/par_iter_simple.rs @@ -1,5 +1,5 @@ use bevy_ecs::prelude::*; -use bevy_tasks::{ComputeTaskPool, TaskPool}; +use bevy_tasks::{TaskPool, TaskPoolBuilder}; use glam::*; #[derive(Component, Copy, Clone)] @@ -26,7 +26,7 @@ fn insert_if_bit_enabled(entity: &mut EntityWorldMut, i: u16) { impl<'w> Benchmark<'w> { pub fn new(fragment: u16) -> Self { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::new(); diff --git a/benches/benches/bevy_ecs/iteration/par_iter_simple_foreach_hybrid.rs b/benches/benches/bevy_ecs/iteration/par_iter_simple_foreach_hybrid.rs index 9dbcba87852f7..729a3dec89ec9 100644 --- a/benches/benches/bevy_ecs/iteration/par_iter_simple_foreach_hybrid.rs +++ b/benches/benches/bevy_ecs/iteration/par_iter_simple_foreach_hybrid.rs @@ -1,5 +1,5 @@ use bevy_ecs::prelude::*; -use bevy_tasks::{ComputeTaskPool, TaskPool}; +use bevy_tasks::{TaskPool, TaskPoolBuilder}; use rand::{prelude::SliceRandom, SeedableRng}; use rand_chacha::ChaCha8Rng; @@ -18,7 +18,7 @@ pub struct Benchmark<'w>(World, QueryState<(&'w mut TableData, &'w SparseData)>) impl<'w> Benchmark<'w> { pub fn new() -> Self { let mut world = World::new(); - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut v = vec![]; for _ in 0..100000 { diff --git a/crates/bevy_a11y/Cargo.toml b/crates/bevy_a11y/Cargo.toml index 262b8e5b823fe..8c1ed0438e66f 100644 --- a/crates/bevy_a11y/Cargo.toml +++ b/crates/bevy_a11y/Cargo.toml @@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0" keywords = ["bevy", "accessibility", "a11y"] [features] -default = ["std", "bevy_reflect", "bevy_ecs/async_executor"] +default = ["std", "bevy_reflect"] # Functionality diff --git a/crates/bevy_app/src/task_pool_plugin.rs b/crates/bevy_app/src/task_pool_plugin.rs index 8014790f07772..9b889719ba90c 100644 --- 
a/crates/bevy_app/src/task_pool_plugin.rs +++ b/crates/bevy_app/src/task_pool_plugin.rs @@ -1,8 +1,8 @@ use crate::{App, Plugin}; -use alloc::string::ToString; -use bevy_platform::sync::Arc; -use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; +use alloc::{string::ToString, vec::Vec}; +use bevy_platform::{collections::HashMap, sync::Arc}; +use bevy_tasks::{TaskPool, TaskPoolBuilder, TaskPriority}; use core::fmt::Debug; use log::trace; @@ -21,7 +21,7 @@ cfg_if::cfg_if! { } } -/// Setup of default task pools: [`AsyncComputeTaskPool`], [`ComputeTaskPool`], [`IoTaskPool`]. +/// Setup of the default task pool: [`TaskPool`]. #[derive(Default)] pub struct TaskPoolPlugin { /// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start. @@ -40,7 +40,7 @@ impl Plugin for TaskPoolPlugin { /// Defines a simple way to determine how many threads to use given the number of remaining cores /// and number of total cores -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct TaskPoolThreadAssignmentPolicy { /// Force using at least this many threads pub min_threads: usize, @@ -49,22 +49,6 @@ pub struct TaskPoolThreadAssignmentPolicy { /// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is /// permitted to use 1.0 to try to use all remaining threads pub percent: f32, - /// Callback that is invoked once for every created thread as it starts. - /// This configuration will be ignored under wasm platform. - pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>, - /// Callback that is invoked once for every created thread as it terminates - /// This configuration will be ignored under wasm platform. - pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>, -} - -impl Debug for TaskPoolThreadAssignmentPolicy { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("TaskPoolThreadAssignmentPolicy") - .field("min_threads", &self.min_threads) - .field("max_threads", &self.max_threads) - .field("percent", &self.percent) - .finish() - } } impl TaskPoolThreadAssignmentPolicy { @@ -92,7 +76,7 @@ impl TaskPoolThreadAssignmentPolicy { /// Helper for configuring and creating the default task pools. For end-users who want full control, /// set up [`TaskPoolPlugin`] -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TaskPoolOptions { /// If the number of physical cores is less than `min_total_threads`, force using /// `min_total_threads` pub min_total_threads: usize, @@ -101,47 +85,70 @@ pub struct TaskPoolOptions { /// `max_total_threads` pub max_total_threads: usize, - /// Used to determine number of IO threads to allocate - pub io: TaskPoolThreadAssignmentPolicy, - /// Used to determine number of async compute threads to allocate - pub async_compute: TaskPoolThreadAssignmentPolicy, - /// Used to determine number of compute threads to allocate - pub compute: TaskPoolThreadAssignmentPolicy, + /// Callback that is invoked once for every created thread as it starts. + /// This configuration will be ignored under wasm platform. + pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>, + /// Callback that is invoked once for every created thread as it terminates + /// This configuration will be ignored under wasm platform.
+ pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>, + + /// Used to determine the number of threads to provide to each [`TaskPriority`] + pub priority_assignment_policies: HashMap<TaskPriority, TaskPoolThreadAssignmentPolicy>, +} + +impl Debug for TaskPoolOptions { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("TaskPoolOptions") + .field("min_total_threads", &self.min_total_threads) + .field("max_total_threads", &self.max_total_threads) + .field( + "priority_assignment_policies", + &self.priority_assignment_policies, + ) + .finish() + } } impl Default for TaskPoolOptions { fn default() -> Self { - TaskPoolOptions { - // By default, use however many cores are available on the system - min_total_threads: 1, - max_total_threads: usize::MAX, - - // Use 25% of cores for IO, at least 1, no more than 4 - io: TaskPoolThreadAssignmentPolicy { + let mut priority_assignment_policies = HashMap::new(); + // Use 25% of cores for IO, at least 1, no more than 4 + priority_assignment_policies.insert( + TaskPriority::BlockingIO, + TaskPoolThreadAssignmentPolicy { min_threads: 1, max_threads: 4, percent: 0.25, - on_thread_spawn: None, - on_thread_destroy: None, }, - - // Use 25% of cores for async compute, at least 1, no more than 4 - async_compute: TaskPoolThreadAssignmentPolicy { + ); + // Use 25% of cores for blocking compute, at least 1, no more than 4 + priority_assignment_policies.insert( + TaskPriority::BlockingCompute, + TaskPoolThreadAssignmentPolicy { min_threads: 1, max_threads: 4, percent: 0.25, - on_thread_spawn: None, - on_thread_destroy: None, }, - - // Use all remaining cores for compute (at least 1) - compute: TaskPoolThreadAssignmentPolicy { + ); + // Use 25% of cores for async IO, at least 1, no more than 4 + priority_assignment_policies.insert( + TaskPriority::AsyncIO, + TaskPoolThreadAssignmentPolicy { min_threads: 1, - max_threads: usize::MAX, - percent: 1.0, // This 1.0 here means "whatever is left over" - on_thread_spawn: None, - on_thread_destroy: None, + max_threads: 4, + percent: 0.25, }, + ); + + TaskPoolOptions { + // By default, use however many cores are available on the system + min_total_threads: 1, + max_total_threads: usize::MAX, + + on_thread_spawn: None, + on_thread_destroy: None, + + priority_assignment_policies, } } } @@ -164,133 +171,55 @@ impl TaskPoolOptions { let mut remaining_threads = total_threads; - { - // Determine the number of IO threads we will use - let io_threads = self - .io - .get_number_of_threads(remaining_threads, total_threads); - - trace!("IO Threads: {io_threads}"); - remaining_threads = remaining_threads.saturating_sub(io_threads); - - IoTaskPool::get_or_init(|| { - let builder = TaskPoolBuilder::default() - .num_threads(io_threads) - .thread_name("IO Task Pool".to_string()); - - #[cfg(not(all(target_arch = "wasm32", feature = "web")))] - let builder = { - let mut builder = builder; - if let Some(f) = self.io.on_thread_spawn.clone() { - builder = builder.on_thread_spawn(move || f()); - } - if let Some(f) = self.io.on_thread_destroy.clone() { - builder = builder.on_thread_destroy(move || f()); - } - builder - }; - - builder.build() - }); - } + let mut builder = TaskPoolBuilder::default() + .num_threads(total_threads) + .thread_name("Task Pool".to_string()); - { - // Determine the number of async compute threads we will use - let async_compute_threads = self - .async_compute - .get_number_of_threads(remaining_threads, total_threads); - - trace!("Async Compute Threads: {async_compute_threads}"); - remaining_threads = remaining_threads.saturating_sub(async_compute_threads); -
AsyncComputeTaskPool::get_or_init(|| { - let builder = TaskPoolBuilder::default() - .num_threads(async_compute_threads) - .thread_name("Async Compute Task Pool".to_string()); - - #[cfg(not(all(target_arch = "wasm32", feature = "web")))] - let builder = { - let mut builder = builder; - if let Some(f) = self.async_compute.on_thread_spawn.clone() { - builder = builder.on_thread_spawn(move || f()); - } - if let Some(f) = self.async_compute.on_thread_destroy.clone() { - builder = builder.on_thread_destroy(move || f()); - } - builder - }; - - builder.build() - }); - } + let mut ordered = self.priority_assignment_policies.iter().collect::<Vec<_>>(); + ordered.sort_by_key(|(prio, _)| **prio); + for (priority, policy) in ordered { + let priority_threads = policy.get_number_of_threads(remaining_threads, total_threads); + builder = builder.priority_limit(*priority, Some(priority_threads)); - { - // Determine the number of compute threads we will use - // This is intentionally last so that an end user can specify 1.0 as the percent - let compute_threads = self - .compute - .get_number_of_threads(remaining_threads, total_threads); - - trace!("Compute Threads: {compute_threads}"); - - ComputeTaskPool::get_or_init(|| { - let builder = TaskPoolBuilder::default() - .num_threads(compute_threads) - .thread_name("Compute Task Pool".to_string()); - - #[cfg(not(all(target_arch = "wasm32", feature = "web")))] - let builder = { - let mut builder = builder; - if let Some(f) = self.compute.on_thread_spawn.clone() { - builder = builder.on_thread_spawn(move || f()); - } - if let Some(f) = self.compute.on_thread_destroy.clone() { - builder = builder.on_thread_destroy(move || f()); - } - builder - }; - - builder.build() - }); + remaining_threads = remaining_threads.saturating_sub(priority_threads); + trace!("{:?} Threads: {priority_threads}", *priority); } + + #[cfg(not(all(target_arch = "wasm32", feature = "web")))] + let builder = { + let mut builder = builder; + if let Some(f) = self.on_thread_spawn.clone() { + builder = builder.on_thread_spawn(move || f()); + } + if let Some(f) = self.on_thread_destroy.clone() { + builder = builder.on_thread_destroy(move || f()); + } + builder + }; + + TaskPool::get_or_init(move || builder); } } #[cfg(test)] mod tests { use super::*; - use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; + use bevy_tasks::prelude::TaskPool; #[test] fn runs_spawn_local_tasks() { let mut app = App::new(); app.add_plugins(TaskPoolPlugin::default()); - let (async_tx, async_rx) = crossbeam_channel::unbounded(); - AsyncComputeTaskPool::get() - .spawn_local(async move { - async_tx.send(()).unwrap(); - }) - .detach(); - - let (compute_tx, compute_rx) = crossbeam_channel::unbounded(); - ComputeTaskPool::get() - .spawn_local(async move { - compute_tx.send(()).unwrap(); - }) - .detach(); - - let (io_tx, io_rx) = crossbeam_channel::unbounded(); - IoTaskPool::get() + let (tx, rx) = crossbeam_channel::unbounded(); + TaskPool::get_or_init(Default::default) .spawn_local(async move { - io_tx.send(()).unwrap(); + tx.send(()).unwrap(); }) .detach(); app.run(); - async_rx.try_recv().unwrap(); - compute_rx.try_recv().unwrap(); - io_rx.try_recv().unwrap(); + rx.try_recv().unwrap(); } } diff --git a/crates/bevy_asset/Cargo.toml b/crates/bevy_asset/Cargo.toml index 810299fc7349a..1d55005e83578 100644 --- a/crates/bevy_asset/Cargo.toml +++ b/crates/bevy_asset/Cargo.toml @@ -27,9 +27,7 @@ bevy_ecs = { path = "../bevy_ecs", version = "0.17.0-dev", default-features = fa bevy_reflect = { path =
"../bevy_reflect", version = "0.17.0-dev", default-features = false, features = [ "uuid", ] } -bevy_tasks = { path = "../bevy_tasks", version = "0.17.0-dev", default-features = false, features = [ - "async_executor", -] } +bevy_tasks = { path = "../bevy_tasks", version = "0.17.0-dev", default-features = false } bevy_utils = { path = "../bevy_utils", version = "0.17.0-dev", default-features = false } bevy_platform = { path = "../bevy_platform", version = "0.17.0-dev", default-features = false, features = [ "std", diff --git a/crates/bevy_asset/src/processor/mod.rs b/crates/bevy_asset/src/processor/mod.rs index 7b3d36e686b02..258271808f9f7 100644 --- a/crates/bevy_asset/src/processor/mod.rs +++ b/crates/bevy_asset/src/processor/mod.rs @@ -59,7 +59,7 @@ use crate::{ use alloc::{borrow::ToOwned, boxed::Box, collections::VecDeque, sync::Arc, vec, vec::Vec}; use bevy_ecs::prelude::*; use bevy_platform::collections::{HashMap, HashSet}; -use bevy_tasks::IoTaskPool; +use bevy_tasks::{TaskPool, TaskPriority}; use futures_io::ErrorKind; use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt}; use parking_lot::RwLock; @@ -219,15 +219,18 @@ impl AssetProcessor { pub fn process_assets(&self) { let start_time = std::time::Instant::now(); debug!("Processing Assets"); - IoTaskPool::get().scope(|scope| { - scope.spawn(async move { - self.initialize().await.unwrap(); - for source in self.sources().iter_processed() { - self.process_assets_internal(scope, source, PathBuf::from("")) - .await - .unwrap(); - } - }); + TaskPool::get().scope(|scope| { + scope + .builder() + .with_priority(TaskPriority::BlockingCompute) + .spawn(async move { + self.initialize().await.unwrap(); + for source in self.sources().iter_processed() { + self.process_assets_internal(scope, source, PathBuf::from("")) + .await + .unwrap(); + } + }); }); // This must happen _after_ the scope resolves or it will happen "too early" // Don't move this into the async scope above! process_assets is a blocking/sync function this is fine @@ -421,12 +424,15 @@ impl AssetProcessor { #[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))] error!("AddFolder event cannot be handled in single threaded mode (or Wasm) yet."); #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] - IoTaskPool::get().scope(|scope| { - scope.spawn(async move { - self.process_assets_internal(scope, source, path) - .await - .unwrap(); - }); + TaskPool::get().scope(|scope| { + scope + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + self.process_assets_internal(scope, source, path) + .await + .unwrap(); + }); }); } @@ -563,13 +569,16 @@ impl AssetProcessor { loop { let mut check_reprocess_queue = core::mem::take(&mut self.data.asset_infos.write().await.check_reprocess_queue); - IoTaskPool::get().scope(|scope| { + TaskPool::get().scope(|scope| { for path in check_reprocess_queue.drain(..) 
{ let processor = self.clone(); let source = self.get_source(path.source()).unwrap(); - scope.spawn(async move { - processor.process_asset(source, path.into()).await; - }); + scope + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + processor.process_asset(source, path.into()).await; + }); } }); let infos = self.data.asset_infos.read().await; diff --git a/crates/bevy_asset/src/server/loaders.rs b/crates/bevy_asset/src/server/loaders.rs index 9c13c861bd986..0be6c3786c5b5 100644 --- a/crates/bevy_asset/src/server/loaders.rs +++ b/crates/bevy_asset/src/server/loaders.rs @@ -5,7 +5,7 @@ use crate::{ use alloc::{boxed::Box, sync::Arc, vec::Vec}; use async_broadcast::RecvError; use bevy_platform::collections::HashMap; -use bevy_tasks::IoTaskPool; +use bevy_tasks::{TaskPool, TaskPriority}; use bevy_utils::TypeIdMap; use core::any::TypeId; use thiserror::Error; @@ -91,7 +91,9 @@ impl AssetLoaders { match maybe_loader { MaybeAssetLoader::Ready(_) => unreachable!(), MaybeAssetLoader::Pending { sender, .. } => { - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) .spawn(async move { let _ = sender.broadcast(loader).await; }) diff --git a/crates/bevy_asset/src/server/mod.rs b/crates/bevy_asset/src/server/mod.rs index 641952a67150e..9bfbb19890071 100644 --- a/crates/bevy_asset/src/server/mod.rs +++ b/crates/bevy_asset/src/server/mod.rs @@ -27,7 +27,7 @@ use alloc::{ use atomicow::CowArc; use bevy_ecs::prelude::*; use bevy_platform::collections::HashSet; -use bevy_tasks::IoTaskPool; +use bevy_tasks::{TaskPool, TaskPriority}; use core::{any::TypeId, future::Future, panic::AssertUnwindSafe, task::Poll}; use crossbeam_channel::{Receiver, Sender}; use either::Either; @@ -524,15 +524,18 @@ impl AssetServer { let owned_handle = handle.clone(); let server = self.clone(); - let task = IoTaskPool::get().spawn(async move { - if let Err(err) = server - .load_internal(Some(owned_handle), path, false, None) - .await - { - error!("{}", err); - } - drop(guard); - }); + let task = TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + if let Err(err) = server + .load_internal(Some(owned_handle), path, false, None) + .await + { + error!("{}", err); + } + drop(guard); + }); #[cfg(not(any(target_arch = "wasm32", not(feature = "multi_threaded"))))] { @@ -587,24 +590,29 @@ impl AssetServer { let id = handle.id().untyped(); let server = self.clone(); - let task = IoTaskPool::get().spawn(async move { - let path_clone = path.clone(); - match server.load_untyped_async(path).await { - Ok(handle) => server.send_asset_event(InternalAssetEvent::Loaded { - id, - loaded_asset: LoadedAsset::new_with_dependencies(LoadedUntypedAsset { handle }) - .into(), - }), - Err(err) => { - error!("{err}"); - server.send_asset_event(InternalAssetEvent::Failed { + let task = TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + let path_clone = path.clone(); + match server.load_untyped_async(path).await { + Ok(handle) => server.send_asset_event(InternalAssetEvent::Loaded { id, - path: path_clone, - error: err, - }); + loaded_asset: LoadedAsset::new_with_dependencies(LoadedUntypedAsset { + handle, + }) + .into(), + }), + Err(err) => { + error!("{err}"); + server.send_asset_event(InternalAssetEvent::Failed { + id, + path: path_clone, + error: err, + }); + } } - } - }); + }); #[cfg(not(any(target_arch = "wasm32", not(feature = "multi_threaded"))))] infos.pending_tasks.insert(handle.id().untyped(), 
task); @@ -827,7 +835,9 @@ impl AssetServer { pub fn reload<'a>(&self, path: impl Into<AssetPath<'a>>) { let server = self.clone(); let path = path.into().into_owned(); - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) .spawn(async move { let mut reloaded = false; @@ -922,29 +932,32 @@ impl AssetServer { let event_sender = self.data.asset_event_sender.clone(); - let task = IoTaskPool::get().spawn(async move { - match future.await { - Ok(asset) => { - let loaded_asset = LoadedAsset::new_with_dependencies(asset).into(); - event_sender - .send(InternalAssetEvent::Loaded { id, loaded_asset }) - .unwrap(); - } - Err(error) => { - let error = AddAsyncError { - error: Arc::new(error), - }; - error!("{error}"); - event_sender - .send(InternalAssetEvent::Failed { - id, - path: Default::default(), - error: AssetLoadError::AddAsyncError(error), - }) - .unwrap(); + let task = TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + match future.await { + Ok(asset) => { + let loaded_asset = LoadedAsset::new_with_dependencies(asset).into(); + event_sender + .send(InternalAssetEvent::Loaded { id, loaded_asset }) + .unwrap(); + } + Err(error) => { + let error = AddAsyncError { + error: Arc::new(error), + }; + error!("{error}"); + event_sender + .send(InternalAssetEvent::Failed { + id, + path: Default::default(), + error: AssetLoadError::AddAsyncError(error), + }) + .unwrap(); + } } - } - }); + }); #[cfg(not(any(target_arch = "wasm32", not(feature = "multi_threaded"))))] infos.pending_tasks.insert(id, task); @@ -1025,7 +1038,9 @@ impl AssetServer { let path = path.into_owned(); let server = self.clone(); - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) .spawn(async move { let Ok(source) = server.get_source(path.source()) else { error!( diff --git a/crates/bevy_diagnostic/src/system_information_diagnostics_plugin.rs b/crates/bevy_diagnostic/src/system_information_diagnostics_plugin.rs index 83d3663895ca5..94297767b9e56 100644 --- a/crates/bevy_diagnostic/src/system_information_diagnostics_plugin.rs +++ b/crates/bevy_diagnostic/src/system_information_diagnostics_plugin.rs @@ -80,7 +80,7 @@ pub mod internal { use bevy_ecs::resource::Resource; use bevy_ecs::{prelude::ResMut, system::Local}; use bevy_platform::time::Instant; - use bevy_tasks::{available_parallelism, block_on, poll_once, AsyncComputeTaskPool, Task}; + use bevy_tasks::{available_parallelism, block_on, poll_once, Task, TaskPool, TaskPriority}; use log::info; use std::sync::Mutex; use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; @@ -143,7 +143,7 @@ pub mod internal { let last_refresh = last_refresh.get_or_insert_with(Instant::now); - let thread_pool = AsyncComputeTaskPool::get(); + let thread_pool = TaskPool::get(); // Only queue a new system refresh task when necessary // Queuing earlier than that will not give new data @@ -153,35 +153,38 @@ pub mod internal { && tasks.tasks.len() * 2 < available_parallelism() { let sys = Arc::clone(sysinfo); - let task = thread_pool.spawn(async move { - let mut sys = sys.lock().unwrap(); - let pid = sysinfo::get_current_pid().expect("Failed to get current process ID"); - sys.refresh_processes(sysinfo::ProcessesToUpdate::Some(&[pid]), true); + let task = thread_pool + .builder() + .with_priority(TaskPriority::BlockingCompute) + .spawn(async move { + let mut sys = sys.lock().unwrap(); + let pid = sysinfo::get_current_pid().expect("Failed to get current process ID");
sys.refresh_processes(sysinfo::ProcessesToUpdate::Some(&[pid]), true); - sys.refresh_cpu_specifics(CpuRefreshKind::nothing().with_cpu_usage()); - sys.refresh_memory(); - let system_cpu_usage = sys.global_cpu_usage().into(); - let total_mem = sys.total_memory() as f64; - let used_mem = sys.used_memory() as f64; - let system_mem_usage = used_mem / total_mem * 100.0; + sys.refresh_cpu_specifics(CpuRefreshKind::nothing().with_cpu_usage()); + sys.refresh_memory(); + let system_cpu_usage = sys.global_cpu_usage().into(); + let total_mem = sys.total_memory() as f64; + let used_mem = sys.used_memory() as f64; + let system_mem_usage = used_mem / total_mem * 100.0; - let process_mem_usage = sys - .process(pid) - .map(|p| p.memory() as f64 * BYTES_TO_GIB) - .unwrap_or(0.0); + let process_mem_usage = sys + .process(pid) + .map(|p| p.memory() as f64 * BYTES_TO_GIB) + .unwrap_or(0.0); - let process_cpu_usage = sys - .process(pid) - .map(|p| p.cpu_usage() as f64 / sys.cpus().len() as f64) - .unwrap_or(0.0); + let process_cpu_usage = sys + .process(pid) + .map(|p| p.cpu_usage() as f64 / sys.cpus().len() as f64) + .unwrap_or(0.0); - SysinfoRefreshData { - system_cpu_usage, - system_mem_usage, - process_cpu_usage, - process_mem_usage, - } - }); + SysinfoRefreshData { + system_cpu_usage, + system_mem_usage, + process_cpu_usage, + process_mem_usage, + } + }); tasks.tasks.push(task); *last_refresh = Instant::now(); } diff --git a/crates/bevy_ecs/Cargo.toml b/crates/bevy_ecs/Cargo.toml index c4d685f86662e..38452417a7e99 100644 --- a/crates/bevy_ecs/Cargo.toml +++ b/crates/bevy_ecs/Cargo.toml @@ -11,7 +11,7 @@ categories = ["game-engines", "data-structures"] rust-version = "1.86.0" [features] -default = ["std", "bevy_reflect", "async_executor", "backtrace"] +default = ["std", "bevy_reflect", "backtrace"] # Functionality @@ -49,12 +49,6 @@ bevy_debug_stepping = [] ## This will often provide more detailed error messages. track_location = [] -# Executor Backend - -## Uses `async-executor` as a task execution backend. -## This backend is incompatible with `no_std` targets. -async_executor = ["std", "bevy_tasks/async_executor"] - # Platform Compatibility ## Allows access to the `std` crate. Enabling this feature will prevent compilation diff --git a/crates/bevy_ecs/src/batching.rs b/crates/bevy_ecs/src/batching.rs index ab9f2d582c781..c4ee552eca53a 100644 --- a/crates/bevy_ecs/src/batching.rs +++ b/crates/bevy_ecs/src/batching.rs @@ -29,13 +29,13 @@ pub struct BatchingStrategy { /// /// Defaults to `[1, usize::MAX]`. pub batch_size_limits: Range<usize>, - /// The number of batches per thread in the [`ComputeTaskPool`]. + /// The number of batches per thread in the [`TaskPool`]. /// Increasing this value will decrease the batch size, which may /// increase the scheduling overhead for the iteration. /// /// Defaults to 1. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool pub batches_per_thread: usize, } diff --git a/crates/bevy_ecs/src/event/iterators.rs b/crates/bevy_ecs/src/event/iterators.rs index c90aed2a19d5e..ca4031f39ca29 100644 --- a/crates/bevy_ecs/src/event/iterators.rs +++ b/crates/bevy_ecs/src/event/iterators.rs @@ -189,10 +189,10 @@ impl<'a, E: BufferedEvent> EventParIter<'a, E> { /// Unlike normal iteration, the event order is not guaranteed in any form. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from an event reader that is being + /// If the [`TaskPool`] is not initialized.
If using this from an event reader that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool pub fn for_each<FN: Fn(&E) + Send + Sync + Clone>(self, func: FN) { self.for_each_with_id(move |e, _| func(e)); } @@ -203,10 +203,10 @@ impl<'a, E: BufferedEvent> EventParIter<'a, E> { /// Note that the order of iteration is not guaranteed, but `EventId`s are ordered by send order. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from an event reader that is being + /// If the [`TaskPool`] is not initialized. If using this from an event reader that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[cfg_attr( target_arch = "wasm32", expect(unused_mut, reason = "not mutated on this target") )] @@ -219,7 +219,7 @@ impl<'a, E: BufferedEvent> EventParIter<'a, E> { #[cfg(not(target_arch = "wasm32"))] { - let pool = bevy_tasks::ComputeTaskPool::get(); + let pool = bevy_tasks::TaskPool::get(); let thread_count = pool.thread_num(); if thread_count <= 1 { return self.into_iter().for_each(|(e, i)| func(e, i)); diff --git a/crates/bevy_ecs/src/event/mut_iterators.rs b/crates/bevy_ecs/src/event/mut_iterators.rs index 3fa8378f23c17..b90b4cff9f6ef 100644 --- a/crates/bevy_ecs/src/event/mut_iterators.rs +++ b/crates/bevy_ecs/src/event/mut_iterators.rs @@ -190,10 +190,10 @@ impl<'a, E: BufferedEvent> EventMutParIter<'a, E> { /// Unlike normal iteration, the event order is not guaranteed in any form. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from an event reader that is being + /// If the [`TaskPool`] is not initialized. If using this from an event reader that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool pub fn for_each<FN: Fn(&mut E) + Send + Sync + Clone>(self, func: FN) { self.for_each_with_id(move |e, _| func(e)); } @@ -204,10 +204,10 @@ impl<'a, E: BufferedEvent> EventMutParIter<'a, E> { /// Note that the order of iteration is not guaranteed, but `EventId`s are ordered by send order. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from an event reader that is being + /// If the [`TaskPool`] is not initialized. If using this from an event reader that is being /// initialized and run from the ECS scheduler, this should never panic.
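The `# Panics` contract above hinges on the unified pool being initialized up front. A minimal sketch of doing that eagerly, mirroring the `TaskPool::get_or_init(TaskPoolBuilder::default)` calls in the tests and benches elsewhere in this diff:

```rust
use bevy_tasks::{TaskPool, TaskPoolBuilder};

fn main() {
    // The first call builds the pool from the builder; later calls return the
    // existing pool, so repeating this in every test is harmless.
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);
    // With the pool in place, the parallel event/query APIs won't panic.
    assert!(pool.thread_num() >= 1);
}
```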
/// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[cfg_attr( target_arch = "wasm32", expect(unused_mut, reason = "not mutated on this target") )] @@ -223,7 +223,7 @@ impl<'a, E: BufferedEvent> EventMutParIter<'a, E> { #[cfg(not(target_arch = "wasm32"))] { - let pool = bevy_tasks::ComputeTaskPool::get(); + let pool = bevy_tasks::TaskPool::get(); let thread_count = pool.thread_num(); if thread_count <= 1 { return self.into_iter().for_each(|(e, i)| func(e, i)); diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 974c371bf31d0..15de1194d1c2b 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -170,7 +170,7 @@ mod tests { }; use alloc::{string::String, sync::Arc, vec, vec::Vec}; use bevy_platform::collections::HashSet; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::{TaskPool, TaskPoolBuilder}; use core::{ any::TypeId, marker::PhantomData, @@ -495,7 +495,7 @@ mod tests { #[test] fn par_for_each_dense() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::new(); let e1 = world.spawn(A(1)).id(); let e2 = world.spawn(A(2)).id(); @@ -517,7 +517,7 @@ mod tests { #[test] fn par_for_each_sparse() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::new(); let e1 = world.spawn(SparseStored(1)).id(); let e2 = world.spawn(SparseStored(2)).id(); diff --git a/crates/bevy_ecs/src/query/par_iter.rs b/crates/bevy_ecs/src/query/par_iter.rs index b8d8618fa5bf1..a2559fa23972d 100644 --- a/crates/bevy_ecs/src/query/par_iter.rs +++ b/crates/bevy_ecs/src/query/par_iter.rs @@ -34,10 +34,10 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { /// Runs `func` on each query result in parallel. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each<FN: Fn(QueryItem<'w, 's, D>) + Send + Sync + Clone>(self, func: FN) { self.for_each_init(|| {}, |_, item| func(item)); } @@ -69,10 +69,10 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { /// ``` /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN) where @@ -101,7 +101,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { } #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] { - let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num(); + let thread_count = bevy_tasks::TaskPool::get().thread_num(); if thread_count <= 1 { let init = init(); // SAFETY: See the safety comment above. @@ -185,10 +185,10 @@ impl<'w, 's, D: ReadOnlyQueryData, F: QueryFilter, E: EntityEquivalent + Sync> /// Runs `func` on each query result in parallel. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized.
If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each<FN: Fn(ROQueryItem<'w, 's, D>) + Send + Sync + Clone>(self, func: FN) { self.for_each_init(|| {}, |_, item| func(item)); } @@ -240,10 +240,10 @@ impl<'w, 's, D: ReadOnlyQueryData, F: QueryFilter, E: EntityEquivalent + Sync> /// ``` /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN) where @@ -272,7 +272,7 @@ impl<'w, 's, D: ReadOnlyQueryData, F: QueryFilter, E: EntityEquivalent + Sync> } #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] { - let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num(); + let thread_count = bevy_tasks::TaskPool::get().thread_num(); if thread_count <= 1 { let init = init(); // SAFETY: See the safety comment above. @@ -340,10 +340,10 @@ impl<'w, 's, D: QueryData, F: QueryFilter, E: EntityEquivalent + Sync> /// Runs `func` on each query result in parallel. /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each<FN: Fn(QueryItem<'w, 's, D>) + Send + Sync + Clone>(self, func: FN) { self.for_each_init(|| {}, |_, item| func(item)); } @@ -395,10 +395,10 @@ impl<'w, 's, D: QueryData, F: QueryFilter, E: EntityEquivalent + Sync> /// ``` /// /// # Panics - /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// If the [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN) where @@ -427,7 +427,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter, E: EntityEquivalent + Sync> } #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] { - let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num(); + let thread_count = bevy_tasks::TaskPool::get().thread_num(); if thread_count <= 1 { let init = init(); // SAFETY: See the safety comment above. diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 09821a718c668..f825530128dea 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -1345,7 +1345,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// #[derive(Component, PartialEq, Debug)] /// struct A(usize); /// - /// # bevy_tasks::ComputeTaskPool::get_or_init(|| bevy_tasks::TaskPool::new()); + /// # bevy_tasks::TaskPool::get_or_init(|| bevy_tasks::TaskPoolBuilder::default()); /// /// let mut world = World::new(); /// @@ -1371,11 +1371,11 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// ``` /// /// # Panics - /// The [`ComputeTaskPool`] is not initialized.
If using this from a query that is being + /// The [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// /// [`par_iter`]: Self::par_iter - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[inline] pub fn par_iter_mut<'w, 's>(&'s mut self, world: &'w mut World) -> QueryParIter<'w, 's, D, F> { self.query_mut(world).par_iter_inner() } @@ -1386,7 +1386,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// `iter()` method, but cannot be chained like a normal [`Iterator`]. /// /// # Panics - /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// The [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety @@ -1396,7 +1396,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` /// with a mismatched [`WorldId`] is unsound. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] pub(crate) unsafe fn par_fold_init_unchecked_manual<'w, 's, T, FN, INIT>( &'s self, @@ -1415,7 +1415,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { // QueryState::par_many_fold_init_unchecked_manual, QueryState::par_many_unique_fold_init_unchecked_manual use arrayvec::ArrayVec; - bevy_tasks::ComputeTaskPool::get().scope(|scope| { + bevy_tasks::TaskPool::get().scope(|scope| { // SAFETY: We only access table data that has been registered in `self.component_access`. let tables = unsafe { &world.storages().tables }; let archetypes = world.archetypes(); @@ -1500,7 +1500,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// equivalent `iter_many_unique()` method, but cannot be chained like a normal [`Iterator`]. /// /// # Panics - /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// The [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety @@ -1510,7 +1510,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` /// with a mismatched [`WorldId`] is unsound. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] pub(crate) unsafe fn par_many_unique_fold_init_unchecked_manual<'w, 's, T, FN, INIT, E>( &'s self, @@ -1530,7 +1530,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter,QueryState::par_fold_init_unchecked_manual // QueryState::par_many_fold_init_unchecked_manual, QueryState::par_many_unique_fold_init_unchecked_manual - bevy_tasks::ComputeTaskPool::get().scope(|scope| { + bevy_tasks::TaskPool::get().scope(|scope| { let chunks = entity_list.chunks_exact(batch_size as usize); let remainder = chunks.remainder(); @@ -1563,7 +1563,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// `iter_many()` method, but cannot be chained like a normal [`Iterator`]. /// /// # Panics - /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// The [`TaskPool`] is not initialized. If using this from a query that is being /// initialized and run from the ECS scheduler, this should never panic.
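The scoped form used throughout these `QueryState` internals follows the same pattern as plain spawning. A sketch of a prioritized scoped task, inferred from the `scope.builder().with_priority(..).spawn(..)` call sites in the `bevy_asset` and `bevy_gltf` changes above (treat the exact builder surface as an assumption of this sketch):

```rust
use bevy_tasks::{TaskPool, TaskPoolBuilder, TaskPriority};

fn main() {
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);
    // `scope` blocks until every spawned task finishes and collects results.
    let results = pool.scope(|scope| {
        // Tag latency-tolerant work instead of routing it to a separate pool.
        scope
            .builder()
            .with_priority(TaskPriority::BlockingIO)
            .spawn(async { 40 + 2 });
    });
    assert_eq!(results, vec![42]);
}
```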
/// /// # Safety @@ -1573,7 +1573,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { /// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world` /// with a mismatched [`WorldId`] is unsound. /// - /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + /// [`TaskPool`]: bevy_tasks::TaskPool #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] pub(crate) unsafe fn par_many_fold_init_unchecked_manual<'w, 's, T, FN, INIT, E>( &'s self, @@ -1593,7 +1593,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> { // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual // QueryState::par_many_fold_init_unchecked_manual, QueryState::par_many_unique_fold_init_unchecked_manual - bevy_tasks::ComputeTaskPool::get().scope(|scope| { + bevy_tasks::TaskPool::get().scope(|scope| { let chunks = entity_list.chunks_exact(batch_size as usize); let remainder = chunks.remainder(); diff --git a/crates/bevy_ecs/src/schedule/executor/mod.rs b/crates/bevy_ecs/src/schedule/executor/mod.rs index 08fdf2374c9fd..7219bbb3ab801 100644 --- a/crates/bevy_ecs/src/schedule/executor/mod.rs +++ b/crates/bevy_ecs/src/schedule/executor/mod.rs @@ -11,7 +11,7 @@ use core::any::TypeId; pub use self::{simple::SimpleExecutor, single_threaded::SingleThreadedExecutor}; #[cfg(feature = "std")] -pub use self::multi_threaded::{MainThreadExecutor, MultiThreadedExecutor}; +pub use self::multi_threaded::{MainThreadSpawner, MultiThreadedExecutor}; use fixedbitset::FixedBitSet; diff --git a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs index bd99344f498fa..342033586a301 100644 --- a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs +++ b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs @@ -1,7 +1,6 @@ use alloc::{boxed::Box, vec::Vec}; use bevy_platform::cell::SyncUnsafeCell; -use bevy_platform::sync::Arc; -use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor}; +use bevy_tasks::{Scope, ScopeTaskTarget, TaskPool, TaskPoolBuilder, ThreadSpawner}; use concurrent_queue::ConcurrentQueue; use core::{any::Any, panic::AssertUnwindSafe}; use fixedbitset::FixedBitSet; @@ -270,14 +269,12 @@ impl SystemExecutor for MultiThreadedExecutor { } let thread_executor = world - .get_resource::<MainThreadExecutor>() + .get_resource::<MainThreadSpawner>() .map(|e| e.0.clone()); - let thread_executor = thread_executor.as_deref(); let environment = &Environment::new(self, schedule, world); - ComputeTaskPool::get_or_init(TaskPool::default).scope_with_executor( - false, + TaskPool::get_or_init(TaskPoolBuilder::default).scope_with_executor( thread_executor, |scope| { let context = Context { @@ -703,7 +700,11 @@ impl ExecutorState { context.scope.spawn(task); } else { self.local_thread_running = true; - context.scope.spawn_on_external(task); + context + .scope + .builder() + .with_target(ScopeTaskTarget::External) + .spawn(task); } } @@ -727,7 +728,11 @@ impl ExecutorState { context.system_completed(system_index, res, system); }; - context.scope.spawn_on_scope(task); + context + .scope + .builder() + .with_target(ScopeTaskTarget::Scope) + .spawn(task); } else { let task = async move { // SAFETY: `can_run` returned true for this system, which means @@ -749,7 +754,11 @@ impl ExecutorState { context.system_completed(system_index, res, system); }; - context.scope.spawn_on_scope(task); + context + .scope + .builder() + .with_target(ScopeTaskTarget::Scope) + .spawn(task); } self.exclusive_running = true; @@ -864,20 +873,20 @@ unsafe fn evaluate_and_fold_conditions( .fold(true,
|acc, res| acc && res) } -/// New-typed [`ThreadExecutor`] [`Resource`] that is used to run systems on the main thread +/// New-typed [`ThreadSpawner`] [`Resource`] that is used to run systems on the main thread #[derive(Resource, Clone)] -pub struct MainThreadExecutor(pub Arc<ThreadExecutor<'static>>); +pub struct MainThreadSpawner(pub ThreadSpawner); -impl Default for MainThreadExecutor { +impl Default for MainThreadSpawner { fn default() -> Self { Self::new() } } -impl MainThreadExecutor { +impl MainThreadSpawner { /// Creates a new executor that can be used to run systems on the main thread. pub fn new() -> Self { - MainThreadExecutor(TaskPool::get_thread_executor()) + MainThreadSpawner(TaskPool::get().current_thread_spawner()) } } diff --git a/crates/bevy_ecs/src/schedule/mod.rs b/crates/bevy_ecs/src/schedule/mod.rs index 1b01e031ef978..f120817d58edd 100644 --- a/crates/bevy_ecs/src/schedule/mod.rs +++ b/crates/bevy_ecs/src/schedule/mod.rs @@ -110,12 +110,12 @@ mod tests { #[cfg(not(miri))] fn parallel_execution() { use alloc::sync::Arc; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::{TaskPool, TaskPoolBuilder}; use std::sync::Barrier; let mut world = World::default(); let mut schedule = Schedule::default(); - let thread_count = ComputeTaskPool::get_or_init(TaskPool::default).thread_num(); + let thread_count = TaskPool::get_or_init(TaskPoolBuilder::default).thread_num(); let barrier = Arc::new(Barrier::new(thread_count)); diff --git a/crates/bevy_ecs/src/system/commands/parallel_scope.rs b/crates/bevy_ecs/src/system/commands/parallel_scope.rs index bee491017d500..76ba6163c71b6 100644 --- a/crates/bevy_ecs/src/system/commands/parallel_scope.rs +++ b/crates/bevy_ecs/src/system/commands/parallel_scope.rs @@ -29,7 +29,7 @@ struct ParallelCommandQueue { /// /// ``` /// # use bevy_ecs::prelude::*; -/// # use bevy_tasks::ComputeTaskPool; +/// # use bevy_tasks::TaskPool; /// # /// # #[derive(Component)] /// # struct Velocity; diff --git a/crates/bevy_gltf/src/loader/mod.rs b/crates/bevy_gltf/src/loader/mod.rs index 3eed903cca8ee..5ddd5e2b68eb8 100644 --- a/crates/bevy_gltf/src/loader/mod.rs +++ b/crates/bevy_gltf/src/loader/mod.rs @@ -44,7 +44,7 @@ use bevy_platform::collections::{HashMap, HashSet}; use bevy_render::render_resource::Face; use bevy_scene::Scene; #[cfg(not(target_arch = "wasm32"))] -use bevy_tasks::IoTaskPool; +use bevy_tasks::{TaskPool, TaskPriority}; use bevy_transform::components::Transform; use gltf::{ @@ -620,24 +620,27 @@ impl GltfLoader { } } else { #[cfg(not(target_arch = "wasm32"))] - IoTaskPool::get() + TaskPool::get() .scope(|scope| { gltf.textures().for_each(|gltf_texture| { let parent_path = load_context.path().parent().unwrap(); let linear_textures = &linear_textures; let buffer_data = &buffer_data; - scope.spawn(async move { - load_image( - gltf_texture, - buffer_data, - linear_textures, - parent_path, - loader.supported_compressed_formats, - default_sampler, - settings, - ) - .await - }); + scope + .builder() + .with_priority(TaskPriority::BlockingIO) + .spawn(async move { + load_image( + gltf_texture, + buffer_data, + linear_textures, + parent_path, + loader.supported_compressed_formats, + default_sampler, + settings, + ) + .await + }); }); }) .into_iter() diff --git a/crates/bevy_input/Cargo.toml b/crates/bevy_input/Cargo.toml index c32a87a52d9dc..2fe037a58fba4 100644 --- a/crates/bevy_input/Cargo.toml +++ b/crates/bevy_input/Cargo.toml @@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0" keywords = ["bevy"] [features] -default = ["std", "bevy_reflect",
"bevy_ecs/async_executor", "smol_str"] +default = ["std", "bevy_reflect", "smol_str"] # Functionality diff --git a/crates/bevy_input_focus/Cargo.toml b/crates/bevy_input_focus/Cargo.toml index 60b824258df31..2cd9d18771b8c 100644 --- a/crates/bevy_input_focus/Cargo.toml +++ b/crates/bevy_input_focus/Cargo.toml @@ -10,7 +10,7 @@ keywords = ["bevy"] rust-version = "1.85.0" [features] -default = ["std", "bevy_reflect", "bevy_ecs/async_executor"] +default = ["std", "bevy_reflect"] # Functionality diff --git a/crates/bevy_internal/Cargo.toml b/crates/bevy_internal/Cargo.toml index f7b98bef606ac..9e3607a719a1c 100644 --- a/crates/bevy_internal/Cargo.toml +++ b/crates/bevy_internal/Cargo.toml @@ -369,15 +369,6 @@ libm = [ "bevy_window?/libm", ] -# Uses `async-executor` as a task execution backend. -# This backend is incompatible with `no_std` targets. -async_executor = [ - "std", - "bevy_tasks/async_executor", - "bevy_ecs/async_executor", - "bevy_transform/async_executor", -] - # Enables use of browser APIs. # Note this is currently only applicable on `wasm32` architectures. web = ["bevy_app/web", "bevy_platform/web", "bevy_reflect/web"] diff --git a/crates/bevy_pbr/src/meshlet/from_mesh.rs b/crates/bevy_pbr/src/meshlet/from_mesh.rs index 141f4da0238ed..5c6b84098bd18 100644 --- a/crates/bevy_pbr/src/meshlet/from_mesh.rs +++ b/crates/bevy_pbr/src/meshlet/from_mesh.rs @@ -10,7 +10,7 @@ use bevy_math::{ use bevy_mesh::{Indices, Mesh}; use bevy_platform::collections::HashMap; use bevy_render::render_resource::PrimitiveTopology; -use bevy_tasks::{AsyncComputeTaskPool, ParallelSlice}; +use bevy_tasks::{ParallelSlice, TaskPool}; use bitvec::{order::Lsb0, vec::BitVec, view::BitView}; use core::{f32, ops::Range}; use itertools::Itertools; @@ -136,7 +136,7 @@ impl MeshletMesh { position_only_vertex_count, ); - let simplified = groups.par_chunk_map(AsyncComputeTaskPool::get(), 1, |_, groups| { + let simplified = groups.par_chunk_map(AsyncTaskPool::get(), 1, |_, groups| { let mut group = groups[0].clone(); // If the group only has a single meshlet we can't simplify it diff --git a/crates/bevy_remote/src/http.rs b/crates/bevy_remote/src/http.rs index 4e36e4a0bfe94..519af66bb317d 100644 --- a/crates/bevy_remote/src/http.rs +++ b/crates/bevy_remote/src/http.rs @@ -17,7 +17,7 @@ use async_io::Async; use bevy_app::{App, Plugin, Startup}; use bevy_ecs::resource::Resource; use bevy_ecs::system::Res; -use bevy_tasks::{futures_lite::StreamExt, IoTaskPool}; +use bevy_tasks::{futures_lite::StreamExt, TaskPool}; use core::{ convert::Infallible, net::{IpAddr, Ipv4Addr}, @@ -201,7 +201,9 @@ fn start_http_server( remote_port: Res, headers: Res, ) { - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(bevy_tasks::TaskPriority::AsyncIO) .spawn(server_main( address.0, remote_port.0, @@ -236,7 +238,9 @@ async fn listen( let request_sender = request_sender.clone(); let headers = headers.clone(); - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(bevy_tasks::TaskPriority::AsyncIO) .spawn(async move { let _ = handle_client(client, request_sender, headers).await; }) diff --git a/crates/bevy_render/src/lib.rs b/crates/bevy_render/src/lib.rs index 99666e2dff061..d3e1f1f3be325 100644 --- a/crates/bevy_render/src/lib.rs +++ b/crates/bevy_render/src/lib.rs @@ -353,7 +353,8 @@ impl Plugin for RenderPlugin { // In wasm, spawn a task and detach it for execution #[cfg(target_arch = "wasm32")] - bevy_tasks::IoTaskPool::get() + bevy_tasks::TaskPool::get() + .with_priority(bevy_tasks::TaskPriority::BlockingIO) 
.spawn_local(async_renderer) .detach(); // Otherwise, just block for it to complete diff --git a/crates/bevy_render/src/pipelined_rendering.rs b/crates/bevy_render/src/pipelined_rendering.rs index 00dfc4ba0e19c..f837d71121896 100644 --- a/crates/bevy_render/src/pipelined_rendering.rs +++ b/crates/bevy_render/src/pipelined_rendering.rs @@ -3,10 +3,10 @@ use async_channel::{Receiver, Sender}; use bevy_app::{App, AppExit, AppLabel, Plugin, SubApp}; use bevy_ecs::{ resource::Resource, - schedule::MainThreadExecutor, + schedule::MainThreadSpawner, world::{Mut, World}, }; -use bevy_tasks::ComputeTaskPool; +use bevy_tasks::TaskPool; use crate::RenderApp; @@ -114,7 +114,7 @@ impl Plugin for PipelinedRenderingPlugin { if app.get_sub_app(RenderApp).is_none() { return; } - app.insert_resource(MainThreadExecutor::new()); + app.insert_resource(MainThreadSpawner::new()); let mut sub_app = SubApp::new(); sub_app.set_extract(renderer_extract); @@ -136,7 +136,7 @@ impl Plugin for PipelinedRenderingPlugin { .expect("Unable to get RenderApp. Another plugin may have removed the RenderApp before PipelinedRenderingPlugin"); // clone main thread executor to render world - let executor = app.world().get_resource::<MainThreadExecutor>().unwrap(); + let executor = app.world().get_resource::<MainThreadSpawner>().unwrap(); render_app.world_mut().insert_resource(executor.clone()); render_to_app_sender.send_blocking(render_app).unwrap(); @@ -150,10 +150,10 @@ impl Plugin for PipelinedRenderingPlugin { #[cfg(feature = "trace")] let _span = tracing::info_span!("render thread").entered(); - let compute_task_pool = ComputeTaskPool::get(); + let task_pool = TaskPool::get(); loop { // run a scope here to allow main world to use this thread while it's waiting for the render app - let sent_app = compute_task_pool + let sent_app = task_pool .scope(|s| { s.spawn(async { app_to_render_receiver.recv().await }); }) @@ -181,12 +181,12 @@ impl Plugin for PipelinedRenderingPlugin { // This function waits for the rendering world to be received, // runs extract, and then sends the rendering world back to the render thread. fn renderer_extract(app_world: &mut World, _world: &mut World) { - app_world.resource_scope(|world, main_thread_executor: Mut<MainThreadExecutor>| { + app_world.resource_scope(|world, main_thread_executor: Mut<MainThreadSpawner>| { world.resource_scope(|world, mut render_channels: Mut<RenderAppChannels>| { // we use a scope here to run any main thread tasks that the render world still needs to run // while we wait for the render world to be received.
- if let Some(mut render_app) = ComputeTaskPool::get() - .scope_with_executor(true, Some(&*main_thread_executor.0), |s| { + if let Some(mut render_app) = TaskPool::get() + .scope_with_executor(Some(main_thread_executor.0.clone()), |s| { s.spawn(async { render_channels.recv().await }); }) .pop() diff --git a/crates/bevy_render/src/render_resource/pipeline_cache.rs b/crates/bevy_render/src/render_resource/pipeline_cache.rs index 1224e1998fb0d..187430b5861e7 100644 --- a/crates/bevy_render/src/render_resource/pipeline_cache.rs +++ b/crates/bevy_render/src/render_resource/pipeline_cache.rs @@ -806,8 +806,15 @@ fn create_pipeline_task( task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static, sync: bool, ) -> CachedPipelineState { + use bevy_tasks::{TaskPool, TaskPriority}; + if !sync { - return CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task)); + return CachedPipelineState::Creating( + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingCompute) + .spawn(task), + ); } match bevy_tasks::block_on(task) { diff --git a/crates/bevy_render/src/renderer/mod.rs b/crates/bevy_render/src/renderer/mod.rs index 019d5f50e2422..69d05afcebd7e 100644 --- a/crates/bevy_render/src/renderer/mod.rs +++ b/crates/bevy_render/src/renderer/mod.rs @@ -579,24 +579,22 @@ impl<'w> RenderContext<'w> { #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] { - let mut task_based_command_buffers = - bevy_tasks::ComputeTaskPool::get().scope(|task_pool| { - for (i, queued_command_buffer) in - self.command_buffer_queue.into_iter().enumerate() - { - match queued_command_buffer { - QueuedCommandBuffer::Ready(command_buffer) => { - command_buffers.push((i, command_buffer)); - } - QueuedCommandBuffer::Task(command_buffer_generation_task) => { - let render_device = self.render_device.clone(); - task_pool.spawn(async move { - (i, command_buffer_generation_task(render_device)) - }); - } + let mut task_based_command_buffers = bevy_tasks::TaskPool::get().scope(|task_pool| { + for (i, queued_command_buffer) in self.command_buffer_queue.into_iter().enumerate() + { + match queued_command_buffer { + QueuedCommandBuffer::Ready(command_buffer) => { + command_buffers.push((i, command_buffer)); + } + QueuedCommandBuffer::Task(command_buffer_generation_task) => { + let render_device = self.render_device.clone(); + task_pool.spawn(async move { + (i, command_buffer_generation_task(render_device)) + }); } } - }); + } + }); command_buffers.append(&mut task_based_command_buffers); } diff --git a/crates/bevy_render/src/view/window/screenshot.rs b/crates/bevy_render/src/view/window/screenshot.rs index b87d76252c557..da2ca3a2f212f 100644 --- a/crates/bevy_render/src/view/window/screenshot.rs +++ b/crates/bevy_render/src/view/window/screenshot.rs @@ -25,7 +25,7 @@ use bevy_image::{Image, TextureFormatPixelInfo, ToExtents}; use bevy_platform::collections::HashSet; use bevy_reflect::Reflect; use bevy_shader::Shader; -use bevy_tasks::AsyncComputeTaskPool; +use bevy_tasks::{TaskPool, TaskPriority}; use bevy_utils::default; use bevy_window::{PrimaryWindow, WindowRef}; use core::ops::Deref; @@ -678,6 +678,10 @@ pub(crate) fn collect_screenshots(world: &mut World) { } }; - AsyncComputeTaskPool::get().spawn(finish).detach(); + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingCompute) + .spawn(finish) + .detach(); } } diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 7974e0c82cb23..fb40c1c8ad563 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@
-9,20 +9,24 @@ license = "MIT OR Apache-2.0" keywords = ["bevy"] [features] -default = ["async_executor", "futures-lite"] +default = ["futures-lite"] # Enables multi-threading support. # Without this feature, all tasks will be run on a single thread. -multi_threaded = [ - "bevy_platform/std", - "dep:async-channel", - "dep:concurrent-queue", - "async_executor", -] +multi_threaded = ["bevy_platform/std", "dep:async-channel", "bevy_executor"] -# Uses `async-executor` as a task execution backend. +# Uses a Bevy-specific fork of `async-executor` as a task execution backend. # This backend is incompatible with `no_std` targets. -async_executor = ["bevy_platform/std", "dep:async-executor", "futures-lite"] +bevy_executor = [ + "dep:fastrand", + "dep:slab", + "dep:thread_local", + "dep:crossbeam-utils", + "dep:pin-project-lite", + "futures-lite", + "async-task/std", + "concurrent-queue/std", +] # Provide an implementation of `block_on` from `futures-lite`. futures-lite = ["bevy_platform/std", "futures-lite/std"] @@ -44,18 +48,20 @@ derive_more = { version = "2", default-features = false, features = [ "deref", "deref_mut", ] } -async-executor = { version = "1.11", optional = true } +slab = { version = "0.4", optional = true } +pin-project-lite = { version = "0.2", optional = true } +thread_local = { version = "1.1", optional = true } +fastrand = { version = "2.3", optional = true, default-features = false } async-channel = { version = "2.3.0", optional = true } async-io = { version = "2.0.0", optional = true } -concurrent-queue = { version = "2.0.0", optional = true } atomic-waker = { version = "1", default-features = false } -crossbeam-queue = { version = "0.3", default-features = false, features = [ - "alloc", -] } +concurrent-queue = { version = "2.5", default-features = false } +crossbeam-utils = { version = "0.8", default-features = false, optional = true } +log = "0.4" [target.'cfg(target_arch = "wasm32")'.dependencies] -pin-project = "1" async-channel = "2.3.0" +pin-project-lite = "0.2" [target.'cfg(not(all(target_has_atomic = "8", target_has_atomic = "16", target_has_atomic = "32", target_has_atomic = "64", target_has_atomic = "ptr")))'.dependencies] async-task = { version = "4.4.0", default-features = false, features = [ @@ -73,6 +79,7 @@ futures-lite = { version = "2.0.1", default-features = false, features = [ "std", ] } async-channel = "2.3.0" +async-io = "2.0.0" [lints] workspace = true diff --git a/crates/bevy_tasks/README.md b/crates/bevy_tasks/README.md index 04815df35e6ed..2c18b8861623f 100644 --- a/crates/bevy_tasks/README.md +++ b/crates/bevy_tasks/README.md @@ -14,24 +14,24 @@ a single thread and having that thread await the completion of those tasks. This generating the tasks from a slice of data. This library is intended for games and makes no attempt to ensure fairness or ordering of spawned tasks. -It is based on [`async-executor`][async-executor], a lightweight executor that allows the end user to manage their own threads. -`async-executor` is based on async-task, a core piece of async-std. +It is based on a fork of [`async-executor`][async-executor], a lightweight executor that allows the end user to manage their own threads. +`async-executor` is based on [`async-task`][async-task], a core piece of [`smol`][smol]. ## Usage In order to be able to optimize task execution in multi-threaded environments, -bevy provides three different thread pools via which tasks of different kinds can be spawned. 
+Bevy supports a thread pool via which tasks of different priorities can be spawned.
+(The same API is used in single-threaded environments, even if execution is limited to a
+single thread. This currently applies to Wasm targets.)
 
-The determining factor for what kind of work should go in each pool is latency requirements:
+The determining factor for how work should be prioritized is its latency requirements:
 
 * For CPU-intensive work (tasks that generally spin until completion) we have a standard
-  [`ComputeTaskPool`] and an [`AsyncComputeTaskPool`]. Work that does not need to be completed to
-  present the next frame should go to the [`AsyncComputeTaskPool`].
+  `Compute` priority, the default. Work that does not need to be completed to present the
+  next frame should be set to the `BlockingCompute` priority.
 
 * For IO-intensive work (tasks that spend very little time in a "woken" state) we have an
-  [`IoTaskPool`] whose tasks are expected to complete very quickly. Generally speaking, they should just
-  await receiving data from somewhere (i.e. disk) and signal other systems when the data is ready
+  [`AsyncIO`] priority whose tasks are expected to complete very quickly. Generally speaking, they should just
+  await receiving data from somewhere (e.g. network) and signal other systems when the data is ready
   for consumption. (likely via channels)
 
 ## `no_std` Support
 
@@ -40,4 +40,6 @@ To enable `no_std` support in this crate, you will need to disable default featu
 
 [bevy]: https://bevy.org
 [rayon]: https://github.com/rayon-rs/rayon
-[async-executor]: https://github.com/stjepang/async-executor
+[async-executor]: https://github.com/smol-rs/async-executor
+[smol]: https://github.com/smol-rs/smol
+[async-task]: https://github.com/smol-rs/async-task
diff --git a/crates/bevy_tasks/src/bevy_executor.rs b/crates/bevy_tasks/src/bevy_executor.rs
new file mode 100644
index 0000000000000..ce91e4e98a0b3
--- /dev/null
+++ b/crates/bevy_tasks/src/bevy_executor.rs
@@ -0,0 +1,1204 @@
+#![expect(
+    unsafe_code,
+    reason = "Executor code requires unsafe code for dealing with non-'static lifetimes"
+)]
+#![allow(
+    dead_code,
+    reason = "Not all functions are used with every feature combination"
+)]
+
+use core::panic::{RefUnwindSafe, UnwindSafe};
+use core::pin::Pin;
+use core::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
+use core::task::{Context, Poll, Waker};
+use core::cell::UnsafeCell;
+use core::mem;
+use std::thread::{AccessError, ThreadId};
+
+use crate::{Metadata, TaskPriority};
+use alloc::collections::VecDeque;
+use alloc::fmt;
+use core::num::NonZeroUsize;
+use async_task::Builder;
+use bevy_platform::prelude::Vec;
+use bevy_platform::sync::{Mutex, PoisonError, RwLock, TryLockError};
+use concurrent_queue::ConcurrentQueue;
+use futures_lite::{future, FutureExt};
+use slab::Slab;
+use thread_local::ThreadLocal;
+use crossbeam_utils::CachePadded;
+
+type Runnable = async_task::Runnable<Metadata>;
+type Task<T> = async_task::Task<T, Metadata>;
+
+// ThreadLocalState *must* stay `Sync` due to a currently existing soundness hole.
+// See: https://github.com/Amanieu/thread_local-rs/issues/75
+static THREAD_LOCAL_STATE: ThreadLocal<ThreadLocalState> = ThreadLocal::new();
+
+pub(crate) fn install_runtime_into_current_thread(executor: &'static Executor) {
+    // Use LOCAL_QUEUE here to set the thread destructor
+    LOCAL_QUEUE.with(|_| {
+        let tls = THREAD_LOCAL_STATE.get_or_default();
+        let state_ptr: *const State = &executor.state;
+        tls.executor.swap(state_ptr.cast_mut(), Ordering::Relaxed);
+    });
+}
+
+std::thread_local! {
+    static LOCAL_QUEUE: CachePadded<UnsafeCell<LocalQueue>> = const {
+        CachePadded::new(UnsafeCell::new(LocalQueue {
+            local_queue: VecDeque::new(),
+            local_active: Slab::new(),
+        }))
+    };
+}
+
+/// # Safety
+/// This must not be accessed at the same time as `LOCAL_QUEUE` in any way.
+#[inline(always)]
+unsafe fn try_with_local_queue<T>(f: impl FnOnce(&mut LocalQueue) -> T) -> Result<T, AccessError> {
+    LOCAL_QUEUE.try_with(|tls| {
+        // SAFETY: This value is in thread local storage and thus can only be accessed
+        // from one thread. The caller guarantees that this function is not used with
+        // LOCAL_QUEUE in any way.
+        f(unsafe { &mut *tls.get() })
+    })
+}
+
+struct LocalQueue {
+    local_queue: VecDeque<Runnable>,
+    local_active: Slab<Waker>,
+}
+
+impl Drop for LocalQueue {
+    fn drop(&mut self) {
+        for waker in self.local_active.drain() {
+            waker.wake();
+        }
+
+        while self.local_queue.pop_front().is_some() {}
+    }
+}
+
+struct ThreadLocalState {
+    executor: AtomicPtr<State>,
+    stealable_queue: ConcurrentQueue<Runnable>,
+    thread_locked_queue: ConcurrentQueue<Runnable>,
+}
+
+impl Default for ThreadLocalState {
+    fn default() -> Self {
+        Self {
+            executor: AtomicPtr::new(core::ptr::null_mut()),
+            stealable_queue: ConcurrentQueue::bounded(512),
+            thread_locked_queue: ConcurrentQueue::unbounded(),
+        }
+    }
+}
+
+/// A task spawner for a specific thread. Must be created by calling [`TaskPool::current_thread_spawner`]
+/// from the target thread.
+///
+/// [`TaskPool::current_thread_spawner`]: crate::TaskPool::current_thread_spawner
+#[derive(Clone, Debug)]
+pub struct ThreadSpawner {
+    thread_id: ThreadId,
+    target_queue: &'static ConcurrentQueue<Runnable>,
+    state: &'static State,
+}
+
+impl ThreadSpawner {
+    /// Spawns a task onto the specific target thread.
+    pub fn spawn<T: Send + 'static>(
+        &self,
+        future: impl Future<Output = T> + Send + 'static,
+    ) -> crate::Task<T> {
+        // SAFETY: T and `future` are both 'static, so the Task is guaranteed to not outlive it.
+        unsafe { self.spawn_scoped(future) }
+    }
+
+    /// Spawns a task onto the executor.
+    ///
+    /// # Safety
+    /// The caller must ensure that the returned Task does not outlive 'a.
+    pub unsafe fn spawn_scoped<'a, T: Send + 'a>(
+        &self,
+        future: impl Future<Output = T> + Send + 'a,
+    ) -> crate::Task<T> {
+        let builder = Builder::new()
+            .propagate_panic(true)
+            .metadata(Metadata {
+                priority: TaskPriority::Compute,
+                is_send: false,
+            });
+
+        // Create the task and register it in the set of active tasks.
+        //
+        // SAFETY:
+        //
+        // - `future` is `Send`. Therefore we do not need to worry about what thread
+        //   the produced `Runnable` is used and dropped from.
+        // - `future` is not `'static`, but the caller must make sure that the Task
+        //   and thus the `Runnable` will not outlive `'a`.
+        // - `self.schedule()` is `Send`, `Sync` and `'static`, as checked below.
+        //   Therefore we do not need to worry about what is done with the
+        //   `Waker`.
+        let (runnable, task) = unsafe {
+            builder.spawn_unchecked(|_| future, self.schedule())
+        };
+
+        // Instead of directly scheduling this task, it is pushed onto the
+        // thread-locked queue to be moved to the target thread, where it will
+        // either be run immediately or flushed into the thread's local queue.
+        let result = self.target_queue.push(runnable);
+        debug_assert!(result.is_ok());
+        crate::Task::new(task)
+    }
+
+    /// Returns a function that schedules a runnable task when it gets woken up.
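+    ///
+    /// The returned closure pushes the woken task onto the current thread's
+    /// local queue (when thread-local storage is still accessible) and then
+    /// notifies the spawner's target thread.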
+    fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
+        let thread_id = self.thread_id;
+        let state = self.state;
+
+        move |runnable| {
+            // SAFETY: This value is in thread local storage and thus can only be accessed
+            // from one thread. There are no instances where the value is accessed mutably
+            // from multiple locations simultaneously.
+            if unsafe { try_with_local_queue(|tls| tls.local_queue.push_back(runnable)) }.is_ok() {
+                state.notify_specific_thread(thread_id, false);
+            }
+        }
+    }
+}
+
+/// An async executor.
+pub struct Executor {
+    /// The executor state.
+    state: State,
+}
+
+impl UnwindSafe for Executor {}
+impl RefUnwindSafe for Executor {}
+
+impl fmt::Debug for Executor {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        debug_executor(self, "Executor", f)
+    }
+}
+
+impl Executor {
+    /// Creates a new executor.
+    pub const fn new() -> Executor {
+        Executor {
+            state: State::new()
+        }
+    }
+
+    /// # Safety
+    /// Must ensure that no other thread can call into the Executor from another
+    /// thread while this function is running.
+    pub unsafe fn set_priority_limits(&self, limits: [Option<NonZeroUsize>; TaskPriority::MAX]) {
+        let executor_limits = self.state.priority_limits.iter();
+        for (i, (limit, executor_limit)) in limits.into_iter().zip(executor_limits).enumerate() {
+            // SAFETY: The caller is required to ensure that no other thread can call into the Executor from another
+            // thread while this function is running.
+            unsafe { executor_limit.set_limit(limit) };
+            if let Some(limit) = limit {
+                log::debug!("{:?} tasks now limited to {:?} simultaneous tasks.", TaskPriority::from_index(i).unwrap(), limit);
+            } else {
+                log::debug!("{:?} tasks are now not limited.", TaskPriority::from_index(i).unwrap());
+            }
+        }
+    }
+
+    /// Spawns a `'static` and `Send` task onto the executor.
+    pub fn spawn<T: Send + 'static>(&'static self, future: impl Future<Output = T> + Send + 'static, metadata: Metadata) -> Task<T> {
+        // SAFETY: Both `T` and `future` are 'static.
+        unsafe { self.spawn_scoped(future, metadata) }
+    }
+
+    /// Spawns a non-'static Send task onto the executor.
+    ///
+    /// # Safety
+    /// The caller must ensure that the returned Task does not outlive 'a.
+    pub unsafe fn spawn_scoped<'a, T: Send + 'a>(&'static self, future: impl Future<Output = T> + Send + 'a, mut metadata: Metadata) -> Task<T> {
+        metadata.is_send = true;
+        let builder = Builder::new().propagate_panic(true).metadata(metadata);
+        // Create the task and register it in the set of active tasks.
+        //
+        // SAFETY:
+        //
+        // - `future` is `Send`. Therefore we do not need to worry about what thread
+        //   the produced `Runnable` is used and dropped from.
+        // - `future` is not `'static`, but we make sure that the `Runnable` does
+        //   not outlive `'a`. When the executor is dropped, the `active` field is
+        //   drained and all of the `Waker`s are woken. Then, the queue inside of
+        //   the `Executor` is drained of all of its runnables. This ensures that
+        //   runnables are dropped and this precondition is satisfied.
+        // - `self.schedule()` is `Send`, `Sync` and `'static`, as checked below.
+        //   Therefore we do not need to worry about what is done with the
+        //   `Waker`.
+        let (runnable, task) = unsafe {
+            builder.spawn_unchecked(|_| future, self.schedule())
+        };
+
+        runnable.schedule();
+        task
+    }
+
+    /// Spawns a non-Send task onto the executor.
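+    ///
+    /// A minimal usage sketch (illustrative only; `EX` stands in for a
+    /// `'static` executor, and `Metadata` is the crate-internal task metadata):
+    ///
+    /// ```ignore
+    /// static EX: Executor = Executor::new();
+    ///
+    /// // The future never leaves this thread, so it does not need to be `Send`.
+    /// let task = EX.spawn_local(async { 1 + 2 }, Metadata::default());
+    /// assert_eq!(futures_lite::future::block_on(EX.run(task)), 3);
+    /// ```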
+    pub fn spawn_local<T: 'static>(&'static self, future: impl Future<Output = T> + 'static, metadata: Metadata) -> Task<T> {
+        // SAFETY: future is 'static
+        unsafe { self.spawn_local_scoped(future, metadata) }
+    }
+
+    /// Spawns a non-'static and non-Send task onto the executor.
+    ///
+    /// # Safety
+    /// The caller must ensure that the returned Task does not outlive 'a.
+    pub unsafe fn spawn_local_scoped<'a, T: 'a>(
+        &'static self,
+        future: impl Future<Output = T> + 'a,
+        mut metadata: Metadata,
+    ) -> Task<T> {
+        metadata.is_send = false;
+        // Remove the task from the set of active tasks when the future finishes.
+        //
+        // SAFETY: There are no instances where the value is accessed mutably
+        // from multiple locations simultaneously.
+        let (runnable, task) = unsafe {
+            try_with_local_queue(|tls| {
+                let entry = tls.local_active.vacant_entry();
+                let index = entry.key();
+                let builder = Builder::new().propagate_panic(true).metadata(metadata);
+
+                // SAFETY: There are no instances where the value is accessed mutably
+                // from multiple locations simultaneously. This AsyncCallOnDrop will be
+                // invoked after the surrounding scope has exited in either a
+                // `try_tick_local` or `run` call.
+                let future = AsyncCallOnDrop::new(future, move || {
+                    try_with_local_queue(|tls| drop(tls.local_active.try_remove(index))).ok();
+                });
+
+                // This is a critical section which will result in UB by aliasing active
+                // if the AsyncCallOnDrop is called while still in this function.
+                //
+                // To avoid this, this guard will abort the process if it does
+                // panic. Rust's drop order will ensure that this will run before
+                // executor, and thus before the above AsyncCallOnDrop is dropped.
+                let _panic_guard = AbortOnPanic;
+
+                // Create the task and register it in the set of active tasks.
+                //
+                // SAFETY:
+                //
+                // - `future` is not `Send`, but the produced `Runnable` is bound
+                //   to thread-local storage and thus cannot leave this thread of execution.
+                // - `future` may not be `'static`, but the caller is required to ensure that
+                //   the future does not outlive the borrowed non-metadata variables of the
+                //   task.
+                // - `self.schedule_local()` is not `Send` or `Sync`, so all instances
+                //   must not leave the current thread of execution; all of them are
+                //   bound to this thread by use of thread-local storage.
+                // - `self.schedule_local()` is `'static`, as checked below.
+                let (runnable, task) = builder
+                    .spawn_unchecked(|_| future, self.schedule_local());
+                entry.insert(runnable.waker());
+
+                mem::forget(_panic_guard);
+
+                (runnable, task)
+            }).unwrap()
+        };
+
+        runnable.schedule();
+        task
+    }
+
+    /// Returns a [`ThreadSpawner`] that spawns tasks locked to the current thread.
+    pub fn current_thread_spawner(&'static self) -> ThreadSpawner {
+        ThreadSpawner {
+            thread_id: std::thread::current().id(),
+            target_queue: &THREAD_LOCAL_STATE.get_or_default().thread_locked_queue,
+            state: &self.state,
+        }
+    }
+
+    /// Runs one queued task from the current thread's local queue, if any.
+    pub fn try_tick_local() -> bool {
+        // SAFETY: There are no instances where the value is accessed mutably
+        // from multiple locations simultaneously. As the Runnable is run after
+        // this scope closes, the AsyncCallOnDrop around the future will be invoked
+        // without overlapping mutable accesses.
+        unsafe { try_with_local_queue(|tls| tls.local_queue.pop_front()) }
+            .ok()
+            .flatten()
+            .map(Runnable::run)
+            .is_some()
+    }
+
+    /// Runs the executor until the given future completes.
+    pub fn run<'b, T>(&'static self, future: impl Future<Output = T> + 'b) -> impl Future<Output = T> + 'b {
+        const MAX_CONSECUTIVE_FAILURES: usize = 5;
+        let mut runner = Runner::new(&self.state);
+
+        // A future that runs tasks forever.
+        let run_forever = async move {
+            let mut rng = fastrand::Rng::new();
+            loop {
+                let mut failed = 0;
+                for _ in 0..200 {
+                    let runnable = runner.runnable(&mut rng).await;
+
+                    if !Self::execute(&self.state, runnable) {
+                        failed += 1;
+                    } else {
+                        failed = 0;
+                    }
+
+                    if failed >= MAX_CONSECUTIVE_FAILURES {
+                        break;
+                    }
+                }
+                future::yield_now().await;
+            }
+        };
+
+        // Run `future` and `run_forever` concurrently until `future` completes.
+        future.or(run_forever)
+    }
+
+    fn execute(state: &'static State, runnable: Runnable) -> bool {
+        let metadata = runnable.metadata();
+        // SAFETY: This can never be out of bounds.
+        let semaphore = unsafe { state.priority_limits.get_unchecked(metadata.priority.to_index()) };
+        match semaphore.acquire() {
+            Permit::Unrestricted | Permit::Held(_) => {
+                runnable.run();
+                true
+            },
+            Permit::Blocked => if metadata.is_send {
+                Self::queue_send(state, runnable);
+                false
+            } else {
+                Self::queue_local(state, runnable);
+                false
+            },
+        }
+    }
+
+    /// Returns a function that schedules a runnable task when it gets woken up.
+    fn schedule(&'static self) -> impl Fn(Runnable) + Send + Sync + 'static {
+        let state = &self.state;
+
+        move |runnable| {
+            Self::queue_send(state, runnable);
+        }
+    }
+
+    /// Returns a function that schedules a runnable task when it gets woken up.
+    fn schedule_local(&'static self) -> impl Fn(Runnable) + 'static {
+        let state = &self.state;
+        move |runnable| {
+            Self::queue_local(state, runnable);
+        }
+    }
+
+    fn queue_send(state: &'static State, runnable: Runnable) {
+        debug_assert!(runnable.metadata().is_send);
+        if runnable.metadata().priority == TaskPriority::RunNow {
+            // SAFETY: This value is in thread local storage and thus can only be accessed
+            // from one thread. There are no instances where the value is accessed mutably
+            // from multiple locations simultaneously.
+            if unsafe { try_with_local_queue(|tls| tls.local_queue.push_front(runnable)) }.is_ok() {
+                state.notify_specific_thread(std::thread::current().id(), false);
+            }
+            return;
+        }
+
+        // Attempt to push onto the local queue first in dedicated executor threads,
+        // because we know that this thread is awake and always processing new tasks.
+        let runnable = if let Some(local_state) = THREAD_LOCAL_STATE.get() {
+            if core::ptr::eq(local_state.executor.load(Ordering::Relaxed), state) {
+                match local_state.stealable_queue.push(runnable) {
+                    Ok(()) => {
+                        state.notify_specific_thread(std::thread::current().id(), true);
+                        return;
+                    }
+                    Err(r) => r.into_inner(),
+                }
+            } else {
+                runnable
+            }
+        } else {
+            runnable
+        };
+        // Otherwise push onto the global queue instead.
+        let result = state.queue.push(runnable);
+        debug_assert!(result.is_ok());
+        state.notify();
+    }
+
+    fn queue_local(state: &'static State, runnable: Runnable) {
+        debug_assert!(!runnable.metadata().is_send);
+        let result = if runnable.metadata().priority == TaskPriority::RunNow {
+            // SAFETY: This value is in thread local storage and thus can only be accessed
+            // from one thread. There are no instances where the value is accessed mutably
+            // from multiple locations simultaneously.
+            unsafe { try_with_local_queue(|tls| tls.local_queue.push_front(runnable)) }
+        } else {
+            // SAFETY: This value is in thread local storage and thus can only be accessed
+            // from one thread. There are no instances where the value is accessed mutably
+            // from multiple locations simultaneously.
+            unsafe { try_with_local_queue(|tls| tls.local_queue.push_back(runnable)) }
+        };
+        if result.is_ok() {
+            state.notify_specific_thread(std::thread::current().id(), false);
+        }
+    }
+}
+
+/// The state of an executor.
+struct State {
+    /// The global queue.
+    queue: ConcurrentQueue<Runnable>,
+
+    /// Local queues created by runners.
+    stealer_queues: RwLock<Vec<&'static ConcurrentQueue<Runnable>>>,
+
+    /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
+    notified: AtomicBool,
+
+    /// A list of sleeping tickers.
+    sleepers: Mutex<Sleepers>,
+
+    // Semaphores for each priority level.
+    priority_limits: [CachePadded<AtomicSemaphore>; TaskPriority::MAX]
+}
+
+impl State {
+    /// Creates state for a new executor.
+    const fn new() -> State {
+        State {
+            queue: ConcurrentQueue::unbounded(),
+            stealer_queues: RwLock::new(Vec::new()),
+            notified: AtomicBool::new(true),
+            sleepers: Mutex::new(Sleepers {
+                count: 0,
+                wakers: Vec::new(),
+                free_ids: Vec::new(),
+            }),
+            priority_limits: [
+                CachePadded::new(AtomicSemaphore::new()),
+                CachePadded::new(AtomicSemaphore::new()),
+                CachePadded::new(AtomicSemaphore::new()),
+                CachePadded::new(AtomicSemaphore::new()),
+                CachePadded::new(AtomicSemaphore::new())
+            ],
+        }
+    }
+
+    /// Notifies a sleeping ticker.
+    #[inline]
+    fn notify(&self) {
+        if self
+            .notified
+            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+            .is_ok()
+        {
+            let waker = self.sleepers.lock().unwrap_or_else(PoisonError::into_inner).notify();
+            if let Some(w) = waker {
+                w.wake();
+            }
+        }
+    }
+
+    /// Notifies a sleeping ticker on a specific thread, optionally falling back
+    /// to waking any sleeping ticker so the task can be stolen.
+    #[inline]
+    fn notify_specific_thread(&self, thread_id: ThreadId, allow_stealing: bool) {
+        let mut sleepers = self.sleepers.lock().unwrap_or_else(PoisonError::into_inner);
+        let mut waker = sleepers.notify_specific_thread(thread_id);
+        if waker.is_none()
+            && allow_stealing
+            && self
+                .notified
+                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+                .is_ok()
+        {
+            waker = sleepers.notify();
+        }
+        if let Some(w) = waker {
+            w.wake();
+        }
+    }
+}
+
+impl fmt::Debug for State {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        debug_state(self, "State", f)
+    }
+}
+
+/// A list of sleeping tickers.
+struct Sleepers {
+    /// Number of sleeping tickers (both notified and unnotified).
+    count: usize,
+
+    /// IDs and wakers of sleeping unnotified tickers.
+    ///
+    /// A sleeping ticker is notified when its waker is missing from this list.
+    wakers: Vec<(usize, ThreadId, Waker)>,
+
+    /// Reclaimed IDs.
+    free_ids: Vec<usize>,
+}
+
+impl Sleepers {
+    /// Inserts a new sleeping ticker.
+    fn insert(&mut self, waker: &Waker) -> usize {
+        let id = match self.free_ids.pop() {
+            Some(id) => id,
+            None => self.count + 1,
+        };
+        self.count += 1;
+        self.wakers
+            .push((id, std::thread::current().id(), waker.clone()));
+        id
+    }
+
+    /// Re-inserts a sleeping ticker's waker if it was notified.
+    ///
+    /// Returns `true` if the ticker was notified.
+    fn update(&mut self, id: usize, waker: &Waker) -> bool {
+        for item in &mut self.wakers {
+            if item.0 == id {
+                item.2.clone_from(waker);
+                return false;
+            }
+        }
+
+        self.wakers
+            .push((id, std::thread::current().id(), waker.clone()));
+        true
+    }
+
+    /// Removes a previously inserted sleeping ticker.
+    ///
+    /// Returns `true` if the ticker was notified.
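+    ///
+    /// Handing the notified state back to the caller lets a dropped ticker
+    /// (see `Ticker::drop` below) pass a pending wakeup on to another sleeping
+    /// ticker instead of losing it.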
+    fn remove(&mut self, id: usize) -> bool {
+        self.count -= 1;
+        self.free_ids.push(id);
+
+        for i in (0..self.wakers.len()).rev() {
+            if self.wakers[i].0 == id {
+                self.wakers.remove(i);
+                return false;
+            }
+        }
+        true
+    }
+
+    /// Returns `true` if a sleeping ticker is notified or no tickers are sleeping.
+    fn is_notified(&self) -> bool {
+        self.count == 0 || self.count > self.wakers.len()
+    }
+
+    /// Returns the notification waker for a sleeping ticker.
+    ///
+    /// If a ticker was notified already or there are no tickers, `None` will be returned.
+    fn notify(&mut self) -> Option<Waker> {
+        if self.wakers.len() == self.count {
+            self.wakers.pop().map(|item| item.2)
+        } else {
+            None
+        }
+    }
+
+    /// Returns the notification waker for a sleeping ticker on the given thread.
+    ///
+    /// If no ticker on that thread is sleeping unnotified, `None` will be returned.
+    fn notify_specific_thread(&mut self, thread_id: ThreadId) -> Option<Waker> {
+        for i in (0..self.wakers.len()).rev() {
+            if self.wakers[i].1 == thread_id {
+                let (_, _, waker) = self.wakers.remove(i);
+                return Some(waker);
+            }
+        }
+        None
+    }
+}
+
+/// Runs tasks one by one.
+struct Ticker<'a> {
+    /// The executor state.
+    state: &'a State,
+
+    /// Set to a non-zero sleeper ID when in sleeping state.
+    ///
+    /// States a ticker can be in:
+    /// 1) Woken.
+    /// 2a) Sleeping and unnotified.
+    /// 2b) Sleeping and notified.
+    sleeping: usize,
+}
+
+impl Ticker<'_> {
+    /// Creates a ticker.
+    fn new(state: &State) -> Ticker<'_> {
+        Ticker { state, sleeping: 0 }
+    }
+
+    /// Moves the ticker into sleeping and unnotified state.
+    ///
+    /// Returns `false` if the ticker was already sleeping and unnotified.
+    fn sleep(&mut self, waker: &Waker) -> bool {
+        let mut sleepers = self.state.sleepers.lock().unwrap_or_else(PoisonError::into_inner);
+
+        match self.sleeping {
+            // Move to sleeping state.
+            0 => {
+                self.sleeping = sleepers.insert(waker);
+            }
+
+            // Already sleeping, check if notified.
+            id => {
+                if !sleepers.update(id, waker) {
+                    return false;
+                }
+            }
+        }
+
+        self.state
+            .notified
+            .store(sleepers.is_notified(), Ordering::Release);
+
+        true
+    }
+
+    /// Moves the ticker into woken state.
+    fn wake(&mut self) {
+        if self.sleeping != 0 {
+            let mut sleepers = self.state.sleepers.lock().unwrap_or_else(PoisonError::into_inner);
+            sleepers.remove(self.sleeping);
+
+            self.state
+                .notified
+                .store(sleepers.is_notified(), Ordering::Release);
+        }
+        self.sleeping = 0;
+    }
+
+    /// Waits for the next runnable task to run, given a function that searches for a task.
+    ///
+    /// # Safety
+    /// Caller must not access `LOCAL_QUEUE` either directly or with `try_with_local_queue` in any way inside `search`.
+    unsafe fn runnable_with(&mut self, mut search: impl FnMut(&mut LocalQueue) -> Option<Runnable>) -> impl Future<Output = Runnable> {
+        future::poll_fn(move |cx| {
+            // SAFETY: Caller must ensure that there's no instances where LOCAL_QUEUE is accessed mutably
+            // from multiple locations simultaneously.
+            unsafe {
+                try_with_local_queue(|tls| {
+                    loop {
+                        match search(tls) {
+                            None => {
+                                // Move to sleeping and unnotified state.
+                                if !self.sleep(cx.waker()) {
+                                    // If already sleeping and unnotified, return.
+                                    return Poll::Pending;
+                                }
+                            }
+                            Some(r) => {
+                                // Wake up.
+                                self.wake();
+
+                                // Notify another ticker now to pick up where this ticker left off, just in
+                                // case running the task takes a long time.
+                                self.state.notify();
+
+                                return Poll::Ready(r);
+                            }
+                        }
+                    }
+                }).unwrap_or(Poll::Pending)
+            }
+        })
+    }
+}
+
+impl Drop for Ticker<'_> {
+    fn drop(&mut self) {
+        // If this ticker is in sleeping state, it must be removed from the sleepers list.
+        if self.sleeping != 0 {
+            let mut sleepers = self.state.sleepers.lock().unwrap_or_else(PoisonError::into_inner);
+            let notified = sleepers.remove(self.sleeping);
+
+            self.state
+                .notified
+                .store(sleepers.is_notified(), Ordering::Release);
+
+            // If this ticker was notified, then notify another ticker.
+            if notified {
+                drop(sleepers);
+                self.state.notify();
+            }
+        }
+    }
+}
+
+/// A worker in a work-stealing executor.
+///
+/// This is just a ticker that also has an associated local queue for improved cache locality.
+struct Runner<'a> {
+    /// The executor state.
+    state: &'a State,
+
+    /// Inner ticker.
+    ticker: Ticker<'a>,
+
+    /// Bumped every time a runnable task is found.
+    ticks: usize,
+
+    // The thread local state of the executor for the current thread.
+    local_state: &'a ThreadLocalState,
+}
+
+impl Runner<'_> {
+    /// Creates a runner and registers it in the executor state.
+    fn new(state: &State) -> Runner<'_> {
+        let local_state = THREAD_LOCAL_STATE.get_or_default();
+        let runner = Runner {
+            state,
+            ticker: Ticker::new(state),
+            ticks: 0,
+            local_state,
+        };
+        state
+            .stealer_queues
+            .write()
+            .unwrap_or_else(PoisonError::into_inner)
+            .push(&local_state.stealable_queue);
+        runner
+    }
+
+    /// Waits for the next runnable task to run.
+    fn runnable(&mut self, _rng: &mut fastrand::Rng) -> impl Future<Output = Runnable> {
+        // SAFETY: The provided search function does not access LOCAL_QUEUE in any way, and thus cannot
+        // alias.
+        let runnable = unsafe {
+            self
+                .ticker
+                .runnable_with(|tls| {
+                    if let Some(r) = tls.local_queue.pop_back() {
+                        return Some(r);
+                    }
+
+                    crate::cfg::multi_threaded! {
+                        if {
+                            // Try the local queue.
+                            if let Ok(r) = self.local_state.stealable_queue.pop() {
+                                return Some(r);
+                            }
+
+                            // Try stealing from the global queue.
+                            if let Ok(r) = self.state.queue.pop() {
+                                steal(&self.state.queue, &self.local_state.stealable_queue);
+                                return Some(r);
+                            }
+
+                            // Try stealing from other runners.
+                            if let Ok(stealer_queues) = self.state.stealer_queues.try_read() {
+                                // Pick a random starting point in the iterator list and rotate the list.
+                                let n = stealer_queues.len();
+                                let start = _rng.usize(..n);
+                                let iter = stealer_queues
+                                    .iter()
+                                    .chain(stealer_queues.iter())
+                                    .skip(start)
+                                    .take(n);
+
+                                // Remove this runner's local queue.
+                                let iter =
+                                    iter.filter(|local| !core::ptr::eq(**local, &self.local_state.stealable_queue));
+
+                                // Try stealing from each local queue in the list.
+                                for local in iter {
+                                    steal(*local, &self.local_state.stealable_queue);
+                                    if let Ok(r) = self.local_state.stealable_queue.pop() {
+                                        return Some(r);
+                                    }
+                                }
+                            }
+
+                            if let Ok(r) = self.local_state.thread_locked_queue.pop() {
+                                // Do not steal from this queue. If other threads steal
+                                // from this current thread, the task will be moved.
+                                //
+                                // Instead, flush all queued tasks into the local queue to
+                                // minimize the effort required to scan for these tasks.
+                                flush_to_local(&self.local_state.thread_locked_queue, tls);
+                                return Some(r);
+                            }
+                        } else {}
+                    }
+
+                    None
+                })
+        };
+
+        // Bump the tick counter.
+        self.ticks = self.ticks.wrapping_add(1);
+
+        if self.ticks.is_multiple_of(64) {
+            // Steal tasks from the global queue to ensure fair task scheduling.
+            steal(&self.state.queue, &self.local_state.stealable_queue);
+        }
+
+        runnable
+    }
+}
+
+impl Drop for Runner<'_> {
+    fn drop(&mut self) {
+        // Remove the local queue.
+        {
+            let mut stealer_queues = self.state.stealer_queues.write().unwrap();
+            if let Some((idx, _)) = stealer_queues
+                .iter()
+                .enumerate()
+                .rev()
+                .find(|(_, local)| core::ptr::eq(**local, &self.local_state.stealable_queue))
+            {
+                stealer_queues.remove(idx);
+            }
+        }
+
+        // Re-schedule remaining tasks in the local queue.
+        while let Ok(r) = self.local_state.stealable_queue.pop() {
+            r.schedule();
+        }
+    }
+}
+
+/// Steals some items from one queue into another.
+fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
+    // Steal everything currently in `src`, capped by what fits into `dest`.
+    let mut count = src.len();
+
+    if count > 0 {
+        if let Some(capacity) = dest.capacity() {
+            // Don't steal more than fits into the queue.
+            count = count.min(capacity - dest.len());
+        }
+
+        // Steal tasks.
+        for _ in 0..count {
+            let Ok(val) = src.pop() else { break };
+            assert!(dest.push(val).is_ok());
+        }
+    }
+}
+
+fn flush_to_local(src: &ConcurrentQueue<Runnable>, dst: &mut LocalQueue) {
+    let count = src.len();
+
+    if count > 0 {
+        // Steal tasks.
+        for _ in 0..count {
+            let Ok(val) = src.pop() else { break };
+            dst.local_queue.push_front(val);
+        }
+    }
+}
+
+struct AtomicSemaphore {
+    available: AtomicUsize,
+    limit: UnsafeCell<Option<NonZeroUsize>>,
+}
+
+// SAFETY: The safety invariants on `set_limit` ensure no aliasing occurs.
+unsafe impl Send for AtomicSemaphore {}
+// SAFETY: The safety invariants on `set_limit` ensure no aliasing occurs.
+unsafe impl Sync for AtomicSemaphore {}
+
+impl AtomicSemaphore {
+    pub const fn new() -> Self {
+        Self {
+            available: AtomicUsize::new(0),
+            limit: UnsafeCell::new(None),
+        }
+    }
+
+    /// # Safety
+    /// Must not be called while another thread might call `acquire`.
+    pub unsafe fn set_limit(&self, limit: Option<NonZeroUsize>) {
+        // SAFETY: The caller must make sure that this does not alias.
+        unsafe { *self.limit.get() = limit; }
+        self.available.store(limit.map(NonZeroUsize::get).unwrap_or(0), Ordering::Relaxed);
+    }
+
+    pub fn acquire<'a>(&'a self) -> Permit<'a> {
+        // SAFETY: `set_limit` is not reentrant, and is required not to be called
+        // while another thread might call `acquire`, so this read cannot alias a write.
+        if unsafe { &*self.limit.get() }.is_none() {
+            return Permit::Unrestricted;
+        }
+        let mut current = self.available.load(Ordering::Acquire);
+        if current == 0 {
+            return Permit::Blocked;
+        }
+        loop {
+            match self.available.compare_exchange_weak(current, current - 1, Ordering::AcqRel, Ordering::Relaxed) {
+                Ok(_) => return Permit::Held(self),
+                Err(0) => return Permit::Blocked,
+                Err(actual) => current = actual,
+            }
+        }
+    }
+}
+
+enum Permit<'a> {
+    Unrestricted,
+    Held(&'a AtomicSemaphore),
+    Blocked,
+}
+
+impl<'a> Drop for Permit<'a> {
+    fn drop(&mut self) {
+        if let Permit::Held(semaphore) = self {
+            semaphore.available.fetch_add(1, Ordering::AcqRel);
+        }
+    }
+}
+
+/// Debug implementation for `Executor`.
+fn debug_executor(executor: &Executor, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    debug_state(&executor.state, name, f)
+}
+
+/// Debug implementation for `State`.
+fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+    /// Debug wrapper for the number of active tasks.
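+    /// Uses `try_lock` so that formatting never blocks on (or deadlocks with)
+    /// a concurrently running executor thread.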
+    struct ActiveTasks<'a>(&'a Mutex<Slab<Waker>>);
+
+    impl fmt::Debug for ActiveTasks<'_> {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            match self.0.try_lock() {
+                Ok(lock) => fmt::Debug::fmt(&lock.len(), f),
+                Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
+                Err(TryLockError::Poisoned(err)) => fmt::Debug::fmt(&err.into_inner().len(), f),
+            }
+        }
+    }
+
+    /// Debug wrapper for the local runners.
+    struct LocalRunners<'a>(&'a RwLock<Vec<&'static ConcurrentQueue<Runnable>>>);
+
+    impl fmt::Debug for LocalRunners<'_> {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            match self.0.try_read() {
+                Ok(lock) => f
+                    .debug_list()
+                    .entries(lock.iter().map(|queue| queue.len()))
+                    .finish(),
+                Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
+                Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
+            }
+        }
+    }
+
+    /// Debug wrapper for the sleepers.
+    struct SleepCount<'a>(&'a Mutex<Sleepers>);
+
+    impl fmt::Debug for SleepCount<'_> {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            match self.0.try_lock() {
+                Ok(lock) => fmt::Debug::fmt(&lock.count, f),
+                Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
+                Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
+            }
+        }
+    }
+
+    f.debug_struct(name)
+        .field("global_tasks", &state.queue.len())
+        .field("stealer_queues", &LocalRunners(&state.stealer_queues))
+        .field("sleepers", &SleepCount(&state.sleepers))
+        .finish()
+}
+
+struct AbortOnPanic;
+
+impl Drop for AbortOnPanic {
+    fn drop(&mut self) {
+        // Panicking while unwinding will force an abort.
+        panic!("Aborting due to allocator error");
+    }
+}
+
+/// Runs a closure when dropped.
+struct CallOnDrop<F: Fn()>(F);
+
+impl<F: Fn()> Drop for CallOnDrop<F> {
+    fn drop(&mut self) {
+        (self.0)();
+    }
+}
+
+pin_project_lite::pin_project! {
+    /// A wrapper around a future, running a closure when dropped.
+    struct AsyncCallOnDrop<Fut, Cleanup: Fn()> {
+        #[pin]
+        future: Fut,
+        cleanup: CallOnDrop<Cleanup>,
+    }
+}
+
+impl<Fut, Cleanup: Fn()> AsyncCallOnDrop<Fut, Cleanup> {
+    fn new(future: Fut, cleanup: Cleanup) -> Self {
+        Self {
+            future,
+            cleanup: CallOnDrop(cleanup),
+        }
+    }
+}
+
+impl<Fut: Future, Cleanup: Fn()> Future for AsyncCallOnDrop<Fut, Cleanup> {
+    type Output = Fut::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.project().future.poll(cx)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use super::THREAD_LOCAL_STATE;
+    use alloc::{string::String, boxed::Box};
+    use futures_lite::{future, pin};
+    use async_task::Task;
+    use core::time::Duration;
+
+    static EX: Executor = Executor::new();
+
+    fn _ensure_send_and_sync() {
+        fn is_send<T: Send>(_: T) {}
+        fn is_sync<T: Sync>(_: T) {}
+        fn is_static<T: 'static>(_: T) {}
+
+        is_send::<Executor>(Executor::new());
+        is_sync::<Executor>(Executor::new());
+
+        is_send(EX.schedule());
+        is_sync(EX.schedule());
+        is_static(EX.schedule());
+        is_send(EX.current_thread_spawner());
+        is_sync(EX.current_thread_spawner());
+        is_send(THREAD_LOCAL_STATE.get_or_default());
+        is_sync(THREAD_LOCAL_STATE.get_or_default());
+    }
+
+    #[test]
+    fn await_task_after_dropping_executor() {
+        let s: String = "hello".into();
+
+        // SAFETY: We make sure that the task does not outlive the borrow on `s`.
+        let task: Task<&str, Metadata> = unsafe { EX.spawn_scoped(async { &*s }, Metadata::default()) };
+        future::block_on(EX.run(async {
+            for _ in 0..10 {
+                future::yield_now().await;
+            }
+        }));
+
+        assert_eq!(future::block_on(task), "hello");
+        drop(s);
+    }
+
+    fn do_run<Fut: Future<Output = ()>>(mut f: impl FnMut(&'static Executor) -> Fut) {
+        // This should not run for longer than two minutes.
+        #[cfg(not(miri))]
+        let _stop_timeout = {
+            let (stop_timeout, stopper) = async_channel::bounded::<()>(1);
+            std::thread::spawn(move || {
+                future::block_on(async move {
+                    #[expect(clippy::print_stderr, reason = "Explicitly used to warn about timed out tests")]
+                    let timeout = async {
+                        async_io::Timer::after(Duration::from_secs(2 * 60)).await;
+                        std::eprintln!("test timed out after 2m");
+                        std::process::exit(1)
+                    };
+
+                    let _ = stopper.recv().or(timeout).await;
+                });
+            });
+            stop_timeout
+        };
+
+        // Test 1: Use the `run` method.
+        future::block_on(EX.run(f(&EX)));
+
+        // Test 2: Run on many threads.
+        std::thread::scope(|scope| {
+            let (_signal, shutdown) = async_channel::bounded::<()>(1);
+
+            for _ in 0..16 {
+                let shutdown = shutdown.clone();
+                let ex = &EX;
+                scope.spawn(move || future::block_on(ex.run(shutdown.recv())));
+            }
+
+            future::block_on(f(&EX));
+        });
+    }
+
+    #[test]
+    fn smoke() {
+        do_run(|ex| async move { ex.spawn(async {}, Metadata::default()).await });
+    }
+
+    #[test]
+    fn yield_now() {
+        do_run(|ex| async move { ex.spawn(future::yield_now(), Metadata::default()).await });
+    }
+
+    #[test]
+    fn timer() {
+        do_run(|ex| async move {
+            ex.spawn(async_io::Timer::after(Duration::from_millis(5)), Metadata::default())
+                .await;
+        });
+    }
+
+    #[test]
+    fn test_panic_propagation() {
+        let task = EX.spawn(async { panic!("should be caught by the task") }, Metadata::default());
+
+        // Running the executor should not panic.
+        future::block_on(EX.run(async {
+            for _ in 0..10 {
+                future::yield_now().await;
+            }
+        }));
+
+        // Polling the task should.
+        assert!(future::block_on(task.catch_unwind()).is_err());
+    }
+
+    #[test]
+    fn two_queues() {
+        future::block_on(async {
+            // Create an executor with two runners.
+            let (run1, run2) = (
+                EX.run(future::pending::<()>()),
+                EX.run(future::pending::<()>()),
+            );
+            let mut run1 = Box::pin(run1);
+            pin!(run2);
+
+            // Poll them both.
+            assert!(future::poll_once(run1.as_mut()).await.is_none());
+            assert!(future::poll_once(run2.as_mut()).await.is_none());
+
+            // Drop the first one, which should leave the local queue in the `None` state.
+            drop(run1);
+            assert!(future::poll_once(run2.as_mut()).await.is_none());
+        });
+    }
+}
diff --git a/crates/bevy_tasks/src/edge_executor.rs b/crates/bevy_tasks/src/edge_executor.rs
index a8c80725cafe9..7215b0b7bb8e7 100644
--- a/crates/bevy_tasks/src/edge_executor.rs
+++ b/crates/bevy_tasks/src/edge_executor.rs
@@ -13,7 +13,6 @@
 
 // TODO: Create a more tailored replacement, possibly integrating [Forte](https://github.com/NthTensor/Forte)
 
-use alloc::rc::Rc;
 use core::{
     future::{poll_fn, Future},
     marker::PhantomData,
@@ -49,12 +48,11 @@ use futures_lite::FutureExt;
 /// drop(signal);
 /// }));
 /// ```
-pub struct Executor<'a, const C: usize = 64> {
-    state: LazyLock<Arc<State<C>>>,
-    _invariant: PhantomData<core::cell::UnsafeCell<&'a ()>>,
+pub struct Executor<const C: usize = 64> {
+    state: LazyLock<State<C>>,
 }
 
-impl<'a, const C: usize> Executor<'a, C> {
+impl<const C: usize> Executor<C> {
     /// Creates a new executor.
     ///
     /// # Examples
@@ -66,8 +64,7 @@ impl<'a, const C: usize> Executor<'a, C> {
     /// ```
     pub const fn new() -> Self {
         Self {
-            state: LazyLock::new(|| Arc::new(State::new())),
-            _invariant: PhantomData,
+            state: LazyLock::new(|| State::new()),
         }
     }
 
@@ -88,7 +85,16 @@ impl<'a, const C: usize> Executor<'a, C> {
     /// Note that if the executor's queue size is equal to the number of currently
     /// spawned and running tasks, spawning this additional task might cause the executor to panic
    /// later, when the task is scheduled for polling.
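+    ///
+    /// A short illustrative sketch of the new `'static` receiver (hypothetical
+    /// `EX` static, mirroring the existing `run` example in this module):
+    ///
+    /// ```ignore
+    /// use bevy_tasks::block_on;
+    ///
+    /// static EX: Executor = Executor::new();
+    ///
+    /// let task = EX.spawn(async { 1 + 2 });
+    /// let res = block_on(EX.run(async { task.await * 2 }));
+    /// assert_eq!(res, 6);
+    /// ```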
-    pub fn spawn<F>(&self, fut: F) -> Task<F::Output>
+    pub fn spawn<F>(&'static self, fut: F) -> Task<F::Output>
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        // SAFETY: Original implementation missing safety documentation
+        unsafe { self.spawn_unchecked(fut) }
+    }
+
+    pub unsafe fn spawn_scoped<'a, F>(&'static self, fut: F) -> Task<F::Output>
     where
         F: Future + Send + 'a,
         F::Output: Send + 'a,
@@ -97,6 +103,24 @@
         unsafe { self.spawn_unchecked(fut) }
     }
 
+    pub fn spawn_local<F>(&'static self, fut: F) -> Task<F::Output>
+    where
+        F: Future + 'static,
+        F::Output: 'static,
+    {
+        // SAFETY: Original implementation missing safety documentation
+        unsafe { self.spawn_unchecked(fut) }
+    }
+
+    pub unsafe fn spawn_local_scoped<'a, F>(&'static self, fut: F) -> Task<F::Output>
+    where
+        F: Future + 'a,
+        F::Output: 'a,
+    {
+        // SAFETY: Original implementation missing safety documentation
+        unsafe { self.spawn_unchecked(fut) }
+    }
+
     /// Attempts to run a task if at least one is scheduled.
     ///
     /// Running a scheduled task means simply polling its future once.
@@ -160,9 +184,9 @@
     ///
     /// assert_eq!(res, 6);
     /// ```
-    pub async fn run<F>(&self, fut: F) -> F::Output
+    pub async fn run<'a, F>(&'static self, fut: F) -> F::Output
     where
-        F: Future + Send + 'a,
+        F: Future + 'a,
     {
         // SAFETY: Original implementation missing safety documentation
        unsafe { self.run_unchecked(fut).await }
@@ -175,7 +199,7 @@
 
     /// Polls the first task scheduled for execution by the executor.
     fn poll_runnable(&self, ctx: &Context<'_>) -> Poll<Runnable> {
-        self.state().waker.register(ctx.waker());
+        self.state.waker.register(ctx.waker());
 
         if let Some(runnable) = self.try_runnable() {
             Poll::Ready(runnable)
@@ -201,7 +225,7 @@
             target_has_atomic = "ptr"
         ))]
         {
-            runnable = self.state().queue.pop();
+            runnable = self.state.queue.pop().ok();
         }
 
         #[cfg(not(all(
@@ -212,7 +236,7 @@
             target_has_atomic = "ptr"
         )))]
         {
-            runnable = self.state().queue.dequeue();
+            runnable = self.state.queue.dequeue();
         }
 
         runnable
@@ -221,12 +245,12 @@
     /// # Safety
     ///
     /// Original implementation missing safety documentation
-    unsafe fn spawn_unchecked<F>(&self, fut: F) -> Task<F::Output>
+    unsafe fn spawn_unchecked<F>(&'static self, fut: F) -> Task<F::Output>
     where
         F: Future,
     {
         let schedule = {
-            let state = self.state().clone();
+            let state = &self.state;
 
             move |runnable| {
                 #[cfg(all(
@@ -280,158 +304,18 @@
 
         run_forever.or(fut).await
     }
-
-    /// Returns a reference to the inner state.
-    fn state(&self) -> &Arc<State<C>> {
-        &self.state
-    }
 }
 
-impl<'a, const C: usize> Default for Executor<'a, C> {
+impl<const C: usize> Default for Executor<C> {
     fn default() -> Self {
         Self::new()
     }
 }
 
 // SAFETY: Original implementation missing safety documentation
-unsafe impl<'a, const C: usize> Send for Executor<'a, C> {}
+unsafe impl<const C: usize> Send for Executor<C> {}
 // SAFETY: Original implementation missing safety documentation
-unsafe impl<'a, const C: usize> Sync for Executor<'a, C> {}
-
-/// A thread-local executor.
-///
-/// The executor can only be run on the thread that created it.
-/// -/// # Examples -/// -/// ```ignore -/// use edge_executor::{LocalExecutor, block_on}; -/// -/// let local_ex: LocalExecutor = Default::default(); -/// -/// block_on(local_ex.run(async { -/// println!("Hello world!"); -/// })); -/// ``` -pub struct LocalExecutor<'a, const C: usize = 64> { - executor: Executor<'a, C>, - _not_send: PhantomData>>, -} - -impl<'a, const C: usize> LocalExecutor<'a, C> { - /// Creates a single-threaded executor. - /// - /// # Examples - /// - /// ```ignore - /// use edge_executor::LocalExecutor; - /// - /// let local_ex: LocalExecutor = Default::default(); - /// ``` - pub const fn new() -> Self { - Self { - executor: Executor::::new(), - _not_send: PhantomData, - } - } - - /// Spawns a task onto the executor. - /// - /// # Examples - /// - /// ```ignore - /// use edge_executor::LocalExecutor; - /// - /// let local_ex: LocalExecutor = Default::default(); - /// - /// let task = local_ex.spawn(async { - /// println!("Hello world"); - /// }); - /// ``` - /// - /// Note that if the executor's queue size is equal to the number of currently - /// spawned and running tasks, spawning this additional task might cause the executor to panic - /// later, when the task is scheduled for polling. - pub fn spawn(&self, fut: F) -> Task - where - F: Future + 'a, - F::Output: 'a, - { - // SAFETY: Original implementation missing safety documentation - unsafe { self.executor.spawn_unchecked(fut) } - } - - /// Attempts to run a task if at least one is scheduled. - /// - /// Running a scheduled task means simply polling its future once. - /// - /// # Examples - /// - /// ```ignore - /// use edge_executor::LocalExecutor; - /// - /// let local_ex: LocalExecutor = Default::default(); - /// assert!(!local_ex.try_tick()); // no tasks to run - /// - /// let task = local_ex.spawn(async { - /// println!("Hello world"); - /// }); - /// assert!(local_ex.try_tick()); // a task was found - /// ``` - pub fn try_tick(&self) -> bool { - self.executor.try_tick() - } - - /// Runs a single task asynchronously. - /// - /// Running a task means simply polling its future once. - /// - /// If no tasks are scheduled when this method is called, it will wait until one is scheduled. - /// - /// # Examples - /// - /// ```ignore - /// use edge_executor::{LocalExecutor, block_on}; - /// - /// let local_ex: LocalExecutor = Default::default(); - /// - /// let task = local_ex.spawn(async { - /// println!("Hello world"); - /// }); - /// block_on(local_ex.tick()); // runs the task - /// ``` - pub async fn tick(&self) { - self.executor.tick().await; - } - - /// Runs the executor asynchronously until the given future completes. 
- /// - /// # Examples - /// - /// ```ignore - /// use edge_executor::{LocalExecutor, block_on}; - /// - /// let local_ex: LocalExecutor = Default::default(); - /// - /// let task = local_ex.spawn(async { 1 + 2 }); - /// let res = block_on(local_ex.run(async { task.await * 2 })); - /// - /// assert_eq!(res, 6); - /// ``` - pub async fn run(&self, fut: F) -> F::Output - where - F: Future, - { - // SAFETY: Original implementation missing safety documentation - unsafe { self.executor.run_unchecked(fut) }.await - } -} - -impl<'a, const C: usize> Default for LocalExecutor<'a, C> { - fn default() -> Self { - Self::new() - } -} +unsafe impl Sync for Executor {} struct State { #[cfg(all( @@ -441,7 +325,7 @@ struct State { target_has_atomic = "64", target_has_atomic = "ptr" ))] - queue: crossbeam_queue::ArrayQueue, + queue: concurrent_queue::ConcurrentQueue, #[cfg(not(all( target_has_atomic = "8", target_has_atomic = "16", @@ -463,7 +347,7 @@ impl State { target_has_atomic = "64", target_has_atomic = "ptr" ))] - queue: crossbeam_queue::ArrayQueue::new(C), + queue: concurrent_queue::ConcurrentQueue::bounded(C), #[cfg(not(all( target_has_atomic = "8", target_has_atomic = "16", @@ -477,46 +361,6 @@ impl State { } } -#[cfg(test)] -mod different_executor_tests { - use core::cell::Cell; - - use bevy_tasks::{block_on, futures_lite::{pending, poll_once}}; - use futures_lite::pin; - - use super::LocalExecutor; - - #[test] - fn shared_queue_slot() { - block_on(async { - let was_polled = Cell::new(false); - let future = async { - was_polled.set(true); - pending::<()>().await; - }; - - let ex1: LocalExecutor = Default::default(); - let ex2: LocalExecutor = Default::default(); - - // Start the futures for running forever. - let (run1, run2) = (ex1.run(pending::<()>()), ex2.run(pending::<()>())); - pin!(run1); - pin!(run2); - assert!(poll_once(run1.as_mut()).await.is_none()); - assert!(poll_once(run2.as_mut()).await.is_none()); - - // Spawn the future on executor one and then poll executor two. - ex1.spawn(future).detach(); - assert!(poll_once(run2).await.is_none()); - assert!(!was_polled.get()); - - // Poll the first one. - assert!(poll_once(run1).await.is_none()); - assert!(was_polled.get()); - }); - } -} - #[cfg(test)] mod drop_tests { use alloc::string::String; @@ -533,7 +377,7 @@ mod drop_tests { #[test] fn leaked_executor_leaks_everything() { static DROP: AtomicUsize = AtomicUsize::new(0); - static WAKER: LazyLock>> = LazyLock::new(Default::default); + static WAKER: Mutex> = Mutex::new(None); let ex: Executor = Default::default(); diff --git a/crates/bevy_tasks/src/executor.rs b/crates/bevy_tasks/src/executor.rs deleted file mode 100644 index fcce0e2985536..0000000000000 --- a/crates/bevy_tasks/src/executor.rs +++ /dev/null @@ -1,86 +0,0 @@ -//! Provides a fundamental executor primitive appropriate for the target platform -//! and feature set selected. -//! By default, the `async_executor` feature will be enabled, which will rely on -//! [`async-executor`] for the underlying implementation. This requires `std`, -//! so is not suitable for `no_std` contexts. Instead, you must use `edge_executor`, -//! which relies on the alternate [`edge-executor`] backend. -//! -//! [`async-executor`]: https://crates.io/crates/async-executor -//! [`edge-executor`]: https://crates.io/crates/edge-executor - -use core::{ - fmt, - panic::{RefUnwindSafe, UnwindSafe}, -}; -use derive_more::{Deref, DerefMut}; - -crate::cfg::async_executor! 
{ - if { - type ExecutorInner<'a> = async_executor::Executor<'a>; - type LocalExecutorInner<'a> = async_executor::LocalExecutor<'a>; - } else { - type ExecutorInner<'a> = crate::edge_executor::Executor<'a, 64>; - type LocalExecutorInner<'a> = crate::edge_executor::LocalExecutor<'a, 64>; - } -} - -crate::cfg::multi_threaded! { - pub use async_task::FallibleTask; -} - -/// Wrapper around a multi-threading-aware async executor. -/// Spawning will generally require tasks to be `Send` and `Sync` to allow multiple -/// threads to send/receive/advance tasks. -/// -/// If you require an executor _without_ the `Send` and `Sync` requirements, consider -/// using [`LocalExecutor`] instead. -#[derive(Deref, DerefMut, Default)] -pub struct Executor<'a>(ExecutorInner<'a>); - -/// Wrapper around a single-threaded async executor. -/// Spawning wont generally require tasks to be `Send` and `Sync`, at the cost of -/// this executor itself not being `Send` or `Sync`. This makes it unsuitable for -/// global statics. -/// -/// If need to store an executor in a global static, or send across threads, -/// consider using [`Executor`] instead. -#[derive(Deref, DerefMut, Default)] -pub struct LocalExecutor<'a>(LocalExecutorInner<'a>); - -impl Executor<'_> { - /// Construct a new [`Executor`] - #[expect(clippy::allow_attributes, reason = "This lint may not always trigger.")] - #[allow(dead_code, reason = "not all feature flags require this function")] - pub const fn new() -> Self { - Self(ExecutorInner::new()) - } -} - -impl LocalExecutor<'_> { - /// Construct a new [`LocalExecutor`] - #[expect(clippy::allow_attributes, reason = "This lint may not always trigger.")] - #[allow(dead_code, reason = "not all feature flags require this function")] - pub const fn new() -> Self { - Self(LocalExecutorInner::new()) - } -} - -impl UnwindSafe for Executor<'_> {} - -impl RefUnwindSafe for Executor<'_> {} - -impl UnwindSafe for LocalExecutor<'_> {} - -impl RefUnwindSafe for LocalExecutor<'_> {} - -impl fmt::Debug for Executor<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Executor").finish() - } -} - -impl fmt::Debug for LocalExecutor<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("LocalExecutor").finish() - } -} diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index ddb014bb9867b..ad8525e01dca7 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -13,9 +13,9 @@ pub mod cfg { pub use bevy_platform::cfg::{alloc, std, web}; define_alias! { - #[cfg(feature = "async_executor")] => { - /// Indicates `async_executor` is used as the future execution backend. - async_executor + #[cfg(feature = "bevy_executor")] => { + /// Indicates `bevy_executor` is used as the future execution backend. + bevy_executor } #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] => { @@ -66,21 +66,24 @@ pub trait ConditionalSendFuture: Future + ConditionalSend {} impl ConditionalSendFuture for T {} +use core::marker::PhantomData; + use alloc::boxed::Box; /// An owned and dynamically typed Future used when you can't statically type your result or need to add some indirection. pub type BoxedFuture<'a, T> = core::pin::Pin + 'a>>; // Modules -mod executor; pub mod futures; mod iter; mod slice; mod task; mod usages; -cfg::async_executor! { - if {} else { +cfg::bevy_executor! { + if { + mod bevy_executor; + } else { mod edge_executor; } } @@ -89,7 +92,6 @@ cfg::async_executor! 
{
 pub use iter::ParallelIterator;
 pub use slice::{ParallelSlice, ParallelSliceMut};
 pub use task::Task;
-pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
 
 pub use futures_lite;
 pub use futures_lite::future::poll_once;
 
@@ -103,14 +105,12 @@ cfg::web! {
 cfg::multi_threaded! {
     if {
         mod task_pool;
-        mod thread_executor;
 
-        pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};
-        pub use thread_executor::{ThreadExecutor, ThreadExecutorTicker};
+        pub use task_pool::{Scope, TaskPool, TaskPoolBuilder, ThreadSpawner};
     } else {
         mod single_threaded_task_pool;
-        pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder, ThreadExecutor};
+        pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder, ThreadSpawner};
     }
 }
 
@@ -155,7 +155,7 @@ pub mod prelude {
         block_on,
         iter::ParallelIterator,
         slice::{ParallelSlice, ParallelSliceMut},
-        usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool},
+        TaskPool,
     };
 }
 
@@ -177,3 +177,146 @@ pub fn available_parallelism() -> usize {
         }
     }}
 }
+
+/// The priority of a task scheduled onto the [`TaskPool`].
+///
+/// Using [`TaskPoolBuilder::priority_limit`], the `TaskPool` will limit how many tasks can
+/// execute in parallel. This is *not* a limit on the number of tasks that can be scheduled
+/// onto the task pool, but rather the number of them that can execute in parallel, and is
+/// used to keep lower-priority groups from starving higher-priority groups of parallelism.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[repr(u8)]
+pub enum TaskPriority {
+    /// Intended for blocking IO operations (e.g. `File::read`).
+    BlockingIO,
+    /// Intended for blocking CPU-bound tasks (e.g. shader compilation, building terrain)
+    BlockingCompute,
+    /// Intended for non-blocking async IO (e.g. HTTP servers/clients, network IO, io-uring file IO).
+    /// These jobs generally should do very little compute-bound work and then yield immediately upon
+    /// there being no more work to do.
+    AsyncIO,
+    /// Intended for short-lived CPU-bound jobs. These jobs are expected to do a small amount of work
+    /// and quickly terminate. This is the default.
+    #[default]
+    Compute,
+    /// Intended for short-lived CPU-bound jobs with tight realtime requirements. These jobs are expected
+    /// to do a small amount of work and quickly terminate or yield.
+    ///
+    /// Unlike the other priorities, this group forces tasks to immediately schedule onto the thread
+    /// where the task is awoken, and will start as soon as the currently executing task terminates
+    /// or yields.
+    RunNow,
+}
+
+impl TaskPriority {
+    const MAX: usize = TaskPriority::RunNow as u8 as usize + 1;
+
+    #[inline]
+    fn to_index(self) -> usize {
+        self as u8 as usize
+    }
+
+    #[inline]
+    fn from_index(index: usize) -> Option<Self> {
+        Some(match index {
+            0 => Self::BlockingIO,
+            1 => Self::BlockingCompute,
+            2 => Self::AsyncIO,
+            3 => Self::Compute,
+            4 => Self::RunNow,
+            _ => return None,
+        })
+    }
+}
+
+#[derive(Debug, Default)]
+pub(crate) struct Metadata {
+    pub priority: TaskPriority,
+    pub is_send: bool,
+}
+
+/// A builder for a [`Task`] to be scheduled onto a [`TaskPool`].
+pub struct TaskBuilder<'a, T> {
+    pub(crate) task_pool: &'a TaskPool,
+    pub(crate) priority: TaskPriority,
+    marker_: PhantomData<*const T>,
+}
+
+impl<'a, T> TaskBuilder<'a, T> {
+    pub(crate) fn new(task_pool: &'a TaskPool) -> Self {
+        Self {
+            task_pool,
+            priority: TaskPriority::default(),
+            marker_: PhantomData,
+        }
+    }
+
+    /// Sets the priority of the spawned task. See [`TaskPriority`] for more details.
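+    ///
+    /// A sketch of the intended call pattern (mirroring the updated call sites
+    /// in `bevy_render`; `expensive_work` is a placeholder):
+    ///
+    /// ```ignore
+    /// let task = TaskPool::get()
+    ///     .builder()
+    ///     .with_priority(TaskPriority::BlockingCompute)
+    ///     .spawn(async move { expensive_work() });
+    /// ```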
+ pub fn with_priority(mut self, priority: TaskPriority) -> Self { + self.priority = priority; + self + } + + pub(crate) fn build_metadata(self) -> Metadata { + Metadata { + priority: self.priority, + is_send: false, + } + } +} + +/// Configuration for which thread to schedule a [`Task`] within a [`Scope`] onto. +#[derive(Clone, Copy, Default, Debug)] +pub enum ScopeTaskTarget { + /// Spawns the future onto any thread in the [`TaskPool`]. + #[default] + Any, + + /// Spawns a scoped future onto the thread the scope is run on. + /// + /// For more information, see [`TaskPool::scope`]. + Scope, + + /// Spawns a scoped future onto the thread of the external thread executor. + /// This is typically the main thread. + /// + /// For more information, see [`TaskPool::scope`]. + External, +} + +/// A builder for a [`Task`] within a [`Scope`]. +pub struct ScopeTaskBuilder<'a, 'scope, 'env: 'scope, T> { + scope: &'a Scope<'scope, 'env, T>, + priority: TaskPriority, + target: ScopeTaskTarget, +} + +impl<'a, 'scope, 'env, T> ScopeTaskBuilder<'a, 'scope, 'env, T> { + pub(crate) fn new(scope: &'a Scope<'scope, 'env, T>) -> Self { + Self { + scope, + priority: TaskPriority::default(), + target: ScopeTaskTarget::default(), + } + } + + /// Sets the priority of the spawned task. See [`TaskPriority`] for more details. + pub fn with_priority(mut self, priority: TaskPriority) -> Self { + self.priority = priority; + self + } + + /// Sets which thread the spawned task will be scheduled onto. + /// See [`ScopeTaskTarget`] for more details. + pub fn with_target(mut self, target: ScopeTaskTarget) -> Self { + self.target = target; + self + } + + pub(crate) fn build_metadata(self) -> Metadata { + Metadata { + priority: self.priority, + is_send: false, + } + } +} diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index d81e43b4e91b9..23c1fe4fd6f26 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -1,28 +1,18 @@ -use alloc::{string::String, vec::Vec}; -use bevy_platform::sync::Arc; -use core::{cell::{RefCell, Cell}, future::Future, marker::PhantomData, mem}; +use alloc::{string::String, vec::Vec, fmt}; +use core::{cell::{RefCell, Cell}, future::Future, marker::PhantomData, mem, task::{Poll, Context, Waker}, pin::Pin}; -use crate::executor::LocalExecutor; use crate::{block_on, Task}; -crate::cfg::std! { - if { - use std::thread_local; - - use crate::executor::LocalExecutor as Executor; - - thread_local! { - static LOCAL_EXECUTOR: Executor<'static> = const { Executor::new() }; - } +crate::cfg::bevy_executor! { + if { + use crate::bevy_executor::Executor; } else { - - // Because we do not have thread-locals without std, we cannot use LocalExecutor here. - use crate::executor::Executor; - - static LOCAL_EXECUTOR: Executor<'static> = const { Executor::new() }; + use crate::edge_executor::Executor; } } +static EXECUTOR: Executor = const { Executor::new() }; + /// Used to create a [`TaskPool`]. #[derive(Debug, Default, Clone)] pub struct TaskPoolBuilder {} @@ -31,15 +21,9 @@ pub struct TaskPoolBuilder {} /// task pool. In the case of the multithreaded task pool this struct is used to spawn /// tasks on a specific thread. But the wasm task pool just calls /// `wasm_bindgen_futures::spawn_local` for spawning which just runs tasks on the main thread -/// and so the [`ThreadExecutor`] does nothing.
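For the scope-side builder, the targets compose like this. A sketch using only the `Scope::builder`, `with_target`, and `ScopeTaskTarget` APIs shown in this diff:

```rust
use bevy_tasks::{ScopeTaskTarget, TaskPool, TaskPoolBuilder};

fn main() {
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);
    let results = pool.scope(|scope| {
        // `Any` (the default): may run on any worker thread in the pool.
        scope.spawn(async { 1 });
        // `Scope`: pinned to the thread that called `scope`.
        scope
            .builder()
            .with_target(ScopeTaskTarget::Scope)
            .spawn(async { 2 });
    });
    // `scope` blocks until both futures finish and returns their outputs.
    assert_eq!(results.into_iter().sum::<i32>(), 3);
}
```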
-#[derive(Default)] -pub struct ThreadExecutor<'a>(PhantomData<&'a ()>); -impl<'a> ThreadExecutor<'a> { - /// Creates a new `ThreadExecutor` - pub fn new() -> Self { - Self::default() - } -} +/// and so the [`ThreadSpawner`] does nothing. +#[derive(Clone)] +pub struct ThreadSpawner; impl TaskPoolBuilder { /// Creates a new `TaskPoolBuilder` instance @@ -76,6 +60,10 @@ impl TaskPoolBuilder { pub fn build(self) -> TaskPool { TaskPool::new_internal() } + + pub(crate) fn build_static(self, _executor: &'static Executor) -> TaskPool { + TaskPool::new_internal() + } } /// A thread pool for executing tasks. Tasks are futures that are being automatically driven by @@ -85,8 +73,8 @@ pub struct TaskPool {} impl TaskPool { - /// Just create a new `ThreadExecutor` for wasm - pub fn get_thread_executor() -> Arc<ThreadExecutor<'static>> { - Arc::new(ThreadExecutor::new()) + /// Creates a [`ThreadSpawner`] for the current thread. On this single-threaded pool it is a no-op handle. + pub fn current_thread_spawner(&self) -> ThreadSpawner { + ThreadSpawner } /// Create a `TaskPool` with the default configuration. @@ -113,7 +101,7 @@ impl TaskPool { F: for<'scope> FnOnce(&'scope mut Scope<'scope, 'env, T>), T: Send + 'static, { - self.scope_with_executor(false, None, f) + self.scope_with_executor(None, f) } /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback, @@ -124,8 +112,7 @@ impl TaskPool { #[expect(unsafe_code, reason = "Required to transmute lifetimes.")] pub fn scope_with_executor<'env, F, T>( &self, - _tick_task_pool_executor: bool, - _thread_executor: Option<&ThreadExecutor>, + _thread_executor: Option<ThreadSpawner>, f: F, ) -> Vec<T> where @@ -140,22 +127,19 @@ impl TaskPool { // Any usages of the references passed into `Scope` must be accessed through // the transmuted reference for the rest of this function. - let executor = LocalExecutor::new(); + // Kept around to ensure that, in the case of an unwinding panic, all scheduled Tasks are cancelled.
+ let tasks: RefCell<Vec<Task<T>>> = RefCell::new(Vec::new()); // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let executor_ref: &'env LocalExecutor<'env> = unsafe { mem::transmute(&executor) }; - - let results: RefCell<Vec<Option<T>>> = RefCell::new(Vec::new()); - // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let results_ref: &'env RefCell<Vec<Option<T>>> = unsafe { mem::transmute(&results) }; + let tasks_ref: &'env RefCell<Vec<Task<T>>> = unsafe { mem::transmute(&tasks) }; let pending_tasks: Cell<usize> = Cell::new(0); // SAFETY: As above, all futures must complete in this function so we can change the lifetime let pending_tasks: &'env Cell<usize> = unsafe { mem::transmute(&pending_tasks) }; let mut scope = Scope { - executor_ref, + executor_ref: &EXECUTOR, + tasks_ref, pending_tasks, - results_ref, scope: PhantomData, env: PhantomData, }; @@ -166,19 +150,41 @@ impl TaskPool { f(scope_ref); // Wait until the scope is complete - block_on(executor.run(async { + block_on(EXECUTOR.run(async { while pending_tasks.get() != 0 { futures_lite::future::yield_now().await; } })); - results + let mut context = Context::from_waker(Waker::noop()); + tasks .take() .into_iter() - .map(|result| result.unwrap()) + .map(|mut task| match Pin::new(&mut task).poll(&mut context) { + Poll::Ready(result) => result, + Poll::Pending => unreachable!(), + }) .collect() } + /// Creates a builder for a new [`Task`] to schedule onto the [`TaskPool`]. + /// + /// # Example + /// + /// ```no_run + /// # async fn my_cool_task() {} + /// # use bevy_tasks::{TaskPool, TaskPriority}; + /// let task_pool = TaskPool::get(); + /// let task = task_pool.builder() + /// .with_priority(TaskPriority::BlockingIO) + /// .spawn(async { + /// my_cool_task().await; + /// }); + /// ``` + pub fn builder<T>(&self) -> TaskBuilder<'_, T> { + TaskBuilder::new(self) + } + /// Spawns a static future onto the thread pool. The returned Task is a future, which can be polled /// to retrieve the output of the original future. Dropping the task will attempt to cancel it. /// It can also be "detached", allowing it to continue running without having to be polled by the @@ -189,26 +195,59 @@ impl TaskPool { &self, future: impl Future<Output = T> + 'static + MaybeSend + MaybeSync, ) -> Task<T> + where + T: 'static + MaybeSend + MaybeSync, + { + self.builder().spawn(future) + } + + /// Spawns a static future on the JS event loop. This is exactly the same as [`TaskPool::spawn`]. + pub fn spawn_local<T>( + &self, + future: impl Future<Output = T> + 'static + MaybeSend + MaybeSync, + ) -> Task<T> + where + T: 'static + MaybeSend + MaybeSync, + { + self.spawn(future) + } + + crate::cfg::web! { + if {} else { + pub(crate) fn try_tick_local() -> bool { + crate::cfg::bevy_executor! { + if { + Executor::try_tick_local() + } else { + EXECUTOR.try_tick() + } + } + } + } + } +} + +impl<'a, T> TaskBuilder<'a, T> { + /// Spawns a static future onto the thread pool. The returned Task is a future, which can be polled + /// to retrieve the output of the original future. Dropping the task will attempt to cancel it. + /// It can also be "detached", allowing it to continue running without having to be polled by the + /// end-user. + /// + /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should be used instead. + pub fn spawn( + &self, + future: impl Future<Output = T> + 'static + MaybeSend + MaybeSync, + ) -> Task<T> where T: 'static + MaybeSend + MaybeSync, { crate::cfg::switch!
{{ crate::cfg::web => { Task::wrap_future(future) - } - crate::cfg::std => { - LOCAL_EXECUTOR.with(|executor| { - let task = executor.spawn(future); - // Loop until all tasks are done - while executor.try_tick() {} - - Task::new(task) - }) - } + } _ => { - let task = LOCAL_EXECUTOR.spawn(future); + let task = EXECUTOR.spawn_local(future); // Loop until all tasks are done - while LOCAL_EXECUTOR.try_tick() {} + while TaskPool::try_tick_local() {} Task::new(task) } @@ -225,43 +264,65 @@ impl TaskPool { { self.spawn(future) } +} - /// Runs a function with the local executor. Typically used to tick - /// the local executor on the main thread as it needs to share time with - /// other things. +impl<'a, 'scope, 'env, T: Send + 'scope> ScopeTaskBuilder<'a, 'scope, 'env, T> { + #[expect( + unsafe_code, + reason = "Executor::spawn and ThreadSpawner::spawn_scoped otherwise requires 'static Futures" + )] + /// Spawns a scoped future onto the thread pool. The scope *must* outlive + /// the provided future. The results of the future will be returned as a part of + /// [`TaskPool::scope`]'s return value. /// - /// ``` - /// use bevy_tasks::TaskPool; + /// For futures that should run on the thread `scope` is called on, use [`ScopeTaskBuilder::with_target`] + /// with [`ScopeTaskTarget::Scope`] instead. /// - /// TaskPool::new().with_local_executor(|local_executor| { - /// local_executor.try_tick(); - /// }); - /// ``` - pub fn with_local_executor<F, R>(&self, f: F) -> R - where - F: FnOnce(&Executor) -> R, - { - crate::cfg::switch! {{ - crate::cfg::std => { - LOCAL_EXECUTOR.with(f) - } - _ => { - f(&LOCAL_EXECUTOR) - } - }} + /// For more information, see [`TaskPool::scope`]. + pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(self, f: Fut) { + let task = match self.target { + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::Any => unsafe { + self.scope + .executor + .spawn_scoped(AssertUnwindSafe(f).catch_unwind(), Metadata::default()) + .fallible() + }, + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::Scope => unsafe { + self.scope + .scope_spawner + .spawn_scoped(AssertUnwindSafe(f).catch_unwind()) + .into_inner() + .fallible() + }, + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::External => unsafe { + self.scope + .external_spawner + .spawn_scoped(AssertUnwindSafe(f).catch_unwind()) + .into_inner() + .fallible() + }, + }; + let result = self.scope.spawned.push(task); + debug_assert!(result.is_ok()); } } + /// A `TaskPool` scope for running one or more non-`'static` futures. /// /// For more information, see [`TaskPool::scope`].
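Worth spelling out the panic plumbing used by all three arms above: wrapping the user future in `AssertUnwindSafe(f).catch_unwind()` turns a panic into an `Err(payload)` that the scope later re-throws with `resume_unwind`. A standalone sketch using only `futures_lite`:

```rust
use futures_lite::{future::block_on, FutureExt};
use std::panic::AssertUnwindSafe;

fn main() {
    // The panic payload is captured inside the task rather than unwinding
    // through the executor; the scope re-raises it on the caller's thread.
    let caught = block_on(AssertUnwindSafe(async { panic!("boom") }).catch_unwind());
    assert!(caught.is_err());
    // A scope would now call `std::panic::resume_unwind(payload)`.
}
```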
-#[derive(Debug)] pub struct Scope<'scope, 'env: 'scope, T> { - executor_ref: &'scope LocalExecutor<'scope>, + executor_ref: &'static Executor, // The number of pending tasks spawned on the scope pending_tasks: &'scope Cell, // Vector to gather results of all futures spawned during scope run - results_ref: &'env RefCell>>, + tasks_ref: &'env RefCell>>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, @@ -301,28 +362,33 @@ impl<'scope, 'env, T: Send + 'env> Scope<'scope, 'env, T> { let pending_tasks = self.pending_tasks; pending_tasks.update(|i| i + 1); - // add a spot to keep the result, and record the index - let results_ref = self.results_ref; - let mut results = results_ref.borrow_mut(); - let task_number = results.len(); - results.push(None); - drop(results); - // create the job closure let f = async move { let result = f.await; - // store the result in the allocated slot - let mut results = results_ref.borrow_mut(); - results[task_number] = Some(result); - drop(results); - // decrement the pending tasks count pending_tasks.update(|i| i - 1); + + result }; - // spawn the job itself - self.executor_ref.spawn(f).detach(); + let mut tasks = self.tasks_ref.borrow_mut(); + + #[expect(unsafe_code, reason = "Executor::spawn_local_scoped is unsafe")] + // SAFETY: The surrounding scope will not terminate until all local tasks are done + // ensuring that the borrowed variables do not outlive the detached task. + tasks.push(unsafe { self.executor_ref.spawn_local_scoped(f) }); + } +} + +impl <'scope, 'env: 'scope, T> fmt::Debug for Scope<'scope, 'env, T> +where T: fmt::Debug +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scope") + .field("pending_tasks", &self.pending_tasks) + .field("tasks_ref", &self.tasks_ref) + .finish() } } diff --git a/crates/bevy_tasks/src/slice.rs b/crates/bevy_tasks/src/slice.rs index a705314a34502..512bea0421b54 100644 --- a/crates/bevy_tasks/src/slice.rs +++ b/crates/bevy_tasks/src/slice.rs @@ -220,7 +220,7 @@ mod tests { #[test] fn test_par_chunks_map() { let v = vec![42; 1000]; - let task_pool = TaskPool::new(); + let task_pool = TaskPoolBuilder::new().build(); let outputs = v.par_splat_map(&task_pool, None, |_, numbers| -> i32 { numbers.iter().sum() }); @@ -236,7 +236,7 @@ mod tests { #[test] fn test_par_chunks_map_mut() { let mut v = vec![42; 1000]; - let task_pool = TaskPool::new(); + let task_pool = TaskPoolBuilder::new().build(); let outputs = v.par_splat_map_mut(&task_pool, None, |_, numbers| -> i32 { for number in numbers.iter_mut() { @@ -257,7 +257,7 @@ mod tests { #[test] fn test_par_chunks_map_index() { let v = vec![1; 1000]; - let task_pool = TaskPool::new(); + let task_pool = TaskPoolBuilder::new().build(); let outputs = v.par_chunk_map(&task_pool, 100, |index, numbers| -> i32 { numbers.iter().sum::() * index as i32 }); diff --git a/crates/bevy_tasks/src/task.rs b/crates/bevy_tasks/src/task.rs index dd649ba47dca3..c28e4170928e8 100644 --- a/crates/bevy_tasks/src/task.rs +++ b/crates/bevy_tasks/src/task.rs @@ -7,6 +7,19 @@ use core::{ use crate::cfg; +crate::cfg::switch! { + crate::cfg::web => { + type TaskInner = async_channel::Receiver>; + } + crate::cfg::bevy_executor => { + use crate::Metadata; + type TaskInner = async_task::Task; + } + _ => { + type TaskInner = async_task::Task; + } +} + /// Wraps `async_executor::Task`, a spawned future. /// /// Tasks are also futures themselves and yield the output of the spawned future. 
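Since the doc comment above states the key contract of this file, that tasks are futures and yield the output of the spawned future, here is a quick usage sketch against the global-pool API from this diff:

```rust
use bevy_tasks::{block_on, TaskPool, TaskPoolBuilder};

fn main() {
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);

    // A `Task` is itself a future: blocking on (or awaiting) it yields the
    // output of the spawned future.
    let task = pool.spawn(async { 21 * 2 });
    assert_eq!(block_on(task), 42);

    // Dropping a `Task` cancels it; `detach()` lets it keep running in the
    // background without being polled by the caller.
    pool.spawn(async {}).detach();
}
```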
@@ -16,15 +29,7 @@ use crate::cfg; /// /// Tasks that panic get immediately canceled. Awaiting a canceled task also causes a panic. #[must_use = "Tasks are canceled when dropped, use `.detach()` to run them in the background."] -pub struct Task( - cfg::web! { - if { - async_channel::Receiver> - } else { - async_task::Task - } - }, -); +pub struct Task(TaskInner); // Custom constructors for web and non-web platforms cfg::web! { @@ -38,7 +43,9 @@ cfg::web! { spawn_local(async move { // Catch any panics that occur when polling the future so they can // be propagated back to the task handle. - let value = CatchUnwind(AssertUnwindSafe(future)).await; + let value = CatchUnwind { + inner: AssertUnwindSafe(future) + }.await; let _ = sender.send(value); }); Self(receiver) @@ -47,9 +54,15 @@ cfg::web! { } else { impl Task { /// Creates a new task from a given `async_executor::Task` - pub(crate) fn new(task: async_task::Task) -> Self { + #[inline] + pub(crate) fn new(task: TaskInner) -> Self { Self(task) } + + #[inline] + pub(crate) fn into_inner(self) -> TaskInner { + self.0 + } } } } @@ -173,13 +186,17 @@ cfg::web! { type Panic = Box; - #[pin_project::pin_project] - struct CatchUnwind(#[pin] F); + pin_project_lite::pin_project! { + struct CatchUnwind { + #[pin] + inner: F + } + } impl Future for CatchUnwind { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = AssertUnwindSafe(|| self.project().0.poll(cx)); + let f = AssertUnwindSafe(|| self.project().inner.poll(cx)); let result = cfg::std! { if { diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index eb6f8502b3d80..aff21ac782653 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -1,20 +1,19 @@ use alloc::{boxed::Box, format, string::String, vec::Vec}; -use core::{future::Future, marker::PhantomData, mem, panic::AssertUnwindSafe}; -use std::{ - thread::{self, JoinHandle}, - thread_local, -}; +use core::{future::Future, marker::PhantomData, mem, num::NonZeroUsize, panic::AssertUnwindSafe}; +use std::{sync::OnceLock, thread::{self, JoinHandle}}; -use crate::executor::FallibleTask; +use crate::{bevy_executor::Executor, Metadata, ScopeTaskBuilder, ScopeTaskTarget, TaskBuilder, TaskPriority}; +use async_task::FallibleTask; use bevy_platform::sync::Arc; use concurrent_queue::ConcurrentQueue; use futures_lite::FutureExt; -use crate::{ - block_on, - thread_executor::{ThreadExecutor, ThreadExecutorTicker}, - Task, -}; +use crate::{block_on, Task}; + +pub use crate::bevy_executor::ThreadSpawner; + +static EXECUTOR: Executor = Executor::new(); +static TASK_POOL: OnceLock = OnceLock::new(); struct CallOnDrop(Option>); @@ -41,6 +40,8 @@ pub struct TaskPoolBuilder { on_thread_spawn: Option>, on_thread_destroy: Option>, + + priority_limits: [Option; TaskPriority::MAX], } impl TaskPoolBuilder { @@ -62,6 +63,12 @@ impl TaskPoolBuilder { self } + /// Sets the limit of how many active threads for a given priority. + pub fn priority_limit(mut self, priority: TaskPriority, limit: Option) -> Self { + self.priority_limits[priority.to_index()] = limit.and_then(NonZeroUsize::new); + self + } + /// Override the name of the threads created for the pool. If set, threads will /// be named ` ()`, i.e. `MyThreadPool (2)` pub fn thread_name(mut self, thread_name: String) -> Self { @@ -117,7 +124,12 @@ impl TaskPoolBuilder { /// Creates a new [`TaskPool`] based on the current options. 
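On the `CatchUnwind` rewrite above: `pin_project_lite` only supports structs with named fields, hence the rename of the tuple field `0` to `inner`. The macro's pattern in isolation, as a sketch with a hypothetical pass-through adapter:

```rust
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

pin_project_lite::pin_project! {
    // A pass-through future adapter; `#[pin]` marks `inner` as structurally
    // pinned, exactly like the `CatchUnwind` struct in the diff.
    struct PassThrough<F> {
        #[pin]
        inner: F,
    }
}

impl<F: Future> Future for PassThrough<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `project()` yields `Pin<&mut F>` for the pinned field.
        self.project().inner.poll(cx)
    }
}

fn main() {
    let fut = PassThrough { inner: async { 1 + 1 } };
    assert_eq!(futures_lite::future::block_on(fut), 2);
}
```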
pub fn build(self) -> TaskPool { - TaskPool::new_internal(self) + #[expect( + unsafe_code, + reason = "Required for priority limit initialization to be both performant and safe." + )] + // SAFETY: The box is unique and is otherwise never going to be called from any other place. + unsafe { TaskPool::new_internal(self, Box::leak(Box::new(Executor::new()))) } } } @@ -134,7 +146,7 @@ impl TaskPoolBuilder { #[derive(Debug)] pub struct TaskPool { /// The executor for the pool. - executor: Arc>, + executor: &'static Executor, // The inner state of the pool. threads: Vec>, @@ -142,14 +154,36 @@ pub struct TaskPool { } impl TaskPool { - thread_local! { - static LOCAL_EXECUTOR: crate::executor::LocalExecutor<'static> = const { crate::executor::LocalExecutor::new() }; - static THREAD_EXECUTOR: Arc> = Arc::new(ThreadExecutor::new()); + /// Creates a [`ThreadSpawner`] for this current thread of execution. + /// Can be used to spawn new tasks to execute exclusively on this thread. + pub fn current_thread_spawner(&self) -> ThreadSpawner { + self.executor.current_thread_spawner() } - /// Each thread should only create one `ThreadExecutor`, otherwise, there are good chances they will deadlock - pub fn get_thread_executor() -> Arc> { - Self::THREAD_EXECUTOR.with(Clone::clone) + /// Attempts to get the global [`TaskPool`] instance, or returns `None` if it is not initialized. + pub fn try_get() -> Option<&'static TaskPool> { + TASK_POOL.get() + } + + /// Gets the global [`TaskPool`] instance. + /// + /// # Panics + /// + /// Panics if the global instance has not been initialized yet. + pub fn get() -> &'static TaskPool { + Self::try_get() + .expect("The TaskPool has not been initialized yet. Please call TaskPool::get_or_init beforehand.") + } + + /// Gets the global [`TaskPool`] instance, or initializes it with `f`. + pub fn get_or_init(f: impl FnOnce() -> TaskPoolBuilder) -> &'static TaskPool { + #[expect( + unsafe_code, + reason = "Required for priority limit initialization to be both performant and safe." + )] + // SAFETY: TASK_POOL is never reset and the OnceLock ensures it's only ever initialized + // once. + TASK_POOL.get_or_init(|| unsafe { Self::new_internal(f(), &EXECUTOR) }) } /// Create a `TaskPool` with the default configuration. @@ -157,10 +191,19 @@ impl TaskPool { TaskPoolBuilder::new().build() } - fn new_internal(builder: TaskPoolBuilder) -> Self { - let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>(); + #[expect( + unsafe_code, + reason = "Required for priority limit initialization to be both performant and safe." + )] + /// # Safety + /// This should only be called once over the lifetime of the application. + unsafe fn new_internal(builder: TaskPoolBuilder, executor: &'static Executor) -> Self { + // SAFETY: The caller is required to ensure that this is only called once per application + // and no threads accessing the Executor are started until later in this very function. + // Thus it's impossible for there to be any aliasing access done here. 
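A reviewer-side sketch of the new global accessors defined above (`try_get`/`get`/`get_or_init`); non-global pools instead lean on the `Box::leak(Box::new(Executor::new()))` in `build()` to obtain their `&'static Executor`:

```rust
use bevy_tasks::{TaskPool, TaskPoolBuilder};

fn main() {
    // Before initialization the global pool is absent; `get()` would panic.
    assert!(TaskPool::try_get().is_none());

    // `get_or_init` builds the pool exactly once; every later call returns
    // the same `&'static TaskPool`, even with a different builder closure.
    let a = TaskPool::get_or_init(TaskPoolBuilder::default) as *const TaskPool;
    let b = TaskPool::get_or_init(TaskPoolBuilder::default) as *const TaskPool;
    assert_eq!(a, b);
}
```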
+ unsafe { executor.set_priority_limits(builder.priority_limits); } - let executor = Arc::new(crate::executor::Executor::new()); + let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>(); let num_threads = builder .num_threads @@ -168,7 +211,6 @@ let threads = (0..num_threads) .map(|i| { - let ex = Arc::clone(&executor); let shutdown_rx = shutdown_rx.clone(); let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() { @@ -187,28 +229,22 @@ thread_builder .spawn(move || { - TaskPool::LOCAL_EXECUTOR.with(|local_executor| { - if let Some(on_thread_spawn) = on_thread_spawn { - on_thread_spawn(); - drop(on_thread_spawn); - } - let _destructor = CallOnDrop(on_thread_destroy); - loop { - let res = std::panic::catch_unwind(|| { - let tick_forever = async move { - loop { - local_executor.tick().await; - } - }; - block_on(ex.run(tick_forever.or(shutdown_rx.recv()))) - }); - if let Ok(value) = res { - // Use unwrap_err because we expect a Closed error - value.unwrap_err(); - break; - } + crate::bevy_executor::install_runtime_into_current_thread(executor); + + if let Some(on_thread_spawn) = on_thread_spawn { + on_thread_spawn(); + drop(on_thread_spawn); + } + let _destructor = CallOnDrop(on_thread_destroy); + loop { + let res = + std::panic::catch_unwind(|| block_on(executor.run(shutdown_rx.recv()))); + if let Ok(value) = res { + // Use unwrap_err because we expect a Closed error + value.unwrap_err(); + break; } - }); + } }) .expect("Failed to spawn thread.") }) @@ -312,57 +348,39 @@ impl TaskPool { F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>), T: Send + 'static, { - Self::THREAD_EXECUTOR.with(|scope_executor| { - self.scope_with_executor_inner(true, scope_executor, scope_executor, f) - }) + let scope_spawner = self.current_thread_spawner(); + self.scope_with_executor_inner(scope_spawner.clone(), scope_spawner, f) } - /// This allows passing an external executor to spawn tasks on. When you pass an external executor - /// [`Scope::spawn_on_scope`] spawns is then run on the thread that [`ThreadExecutor`] is being ticked on. - /// If [`None`] is passed the scope will use a [`ThreadExecutor`] that is ticked on the current thread. - /// - /// When `tick_task_pool_executor` is set to `true`, the multithreaded task stealing executor is ticked on the scope - /// thread. Disabling this can be useful when finishing the scope is latency sensitive. Pulling tasks from - /// global executor can run tasks unrelated to the scope and delay when the scope returns. + /// This allows passing an external [`ThreadSpawner`] to spawn tasks to. When an external spawner is + /// passed, tasks spawned with [`ScopeTaskTarget::External`] run on the thread that the [`ThreadSpawner`] + /// originated from. + /// If [`None`] is passed, the scope will use a [`ThreadSpawner`] for the current thread. /// /// See [`Self::scope`] for more details in general about how scopes work. pub fn scope_with_executor<'env, F, T>( &self, - tick_task_pool_executor: bool, - external_executor: Option<&ThreadExecutor>, + external_spawner: Option<ThreadSpawner>, f: F, ) -> Vec<T> where F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>), T: Send + 'static, { - Self::THREAD_EXECUTOR.with(|scope_executor| { - // If an `external_executor` is passed, use that. Otherwise, get the executor stored - // in the `THREAD_EXECUTOR` thread local.
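The worker loop above replaces per-thread tickers with a single pattern: block on the shared executor until the shutdown channel closes, restarting after a task panic. The channel mechanics in isolation (the executor is elided; this mirrors the diff's own `catch_unwind` loop):

```rust
fn main() {
    let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
    let worker = std::thread::spawn(move || loop {
        // Mirrors `block_on(executor.run(shutdown_rx.recv()))`: the recv
        // future resolves with `Err(Closed)` once every sender is dropped.
        let res = std::panic::catch_unwind(|| {
            futures_lite::future::block_on(shutdown_rx.recv())
        });
        if let Ok(value) = res {
            // Use unwrap_err because we expect a Closed error.
            value.unwrap_err();
            break;
        }
        // An `Err` here means a task panicked; loop and keep serving.
    });
    drop(shutdown_tx); // closing the channel shuts the worker down
    worker.join().unwrap();
}
```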
- if let Some(external_executor) = external_executor { - self.scope_with_executor_inner( - tick_task_pool_executor, - external_executor, - scope_executor, - f, - ) - } else { - self.scope_with_executor_inner( - tick_task_pool_executor, - scope_executor, - scope_executor, - f, - ) - } - }) + let scope_spawner = self.executor.current_thread_spawner(); + // If an `external_spawner` is passed, use that. Otherwise, use a spawner + // for the current thread. + if let Some(external_spawner) = external_spawner { + self.scope_with_executor_inner(external_spawner, scope_spawner, f) + } else { + self.scope_with_executor_inner(scope_spawner.clone(), scope_spawner, f) + } } #[expect(unsafe_code, reason = "Required to transmute lifetimes.")] fn scope_with_executor_inner<'env, F, T>( &self, - tick_task_pool_executor: bool, - external_executor: &ThreadExecutor, - scope_executor: &ThreadExecutor, + external_spawner: ThreadSpawner, + scope_spawner: ThreadSpawner, f: F, ) -> Vec<T> where @@ -376,26 +394,16 @@ impl TaskPool { // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety. // Any usages of the references passed into `Scope` must be accessed through // the transmuted reference for the rest of this function. - let executor: &crate::executor::Executor = &self.executor; - // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let executor: &'env crate::executor::Executor = unsafe { mem::transmute(executor) }; - // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let external_executor: &'env ThreadExecutor<'env> = - unsafe { mem::transmute(external_executor) }; - // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) }; - let spawned: ConcurrentQueue<FallibleTask<Result<T, Box<dyn Any + Send>>>> = - ConcurrentQueue::unbounded(); + let spawned: ConcurrentQueue<ScopeTask<T>> = ConcurrentQueue::unbounded(); // shadow the variable so that the owned value cannot be used for the rest of the function // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let spawned: &'env ConcurrentQueue< - FallibleTask<Result<T, Box<dyn Any + Send>>>, - > = unsafe { mem::transmute(&spawned) }; + let spawned: &'env ConcurrentQueue<ScopeTask<T>> = + unsafe { mem::transmute(&spawned) }; let scope = Scope { - executor, - external_executor, - scope_executor, + executor: self.executor, + external_spawner, + scope_spawner, spawned, scope: PhantomData, env: PhantomData, }; @@ -410,142 +418,36 @@ if spawned.is_empty() { Vec::new() } else { - block_on(async move { - let get_results = async { - let mut results = Vec::with_capacity(spawned.len()); - while let Ok(task) = spawned.pop() { - if let Some(res) = task.await { - match res { - Ok(res) => results.push(res), - Err(payload) => std::panic::resume_unwind(payload), - } - } else { - panic!("Failed to catch panic!"); - } + block_on(self.executor.run(async move { + let mut results = Vec::with_capacity(spawned.len()); + while let Ok(task) = spawned.pop() { + match task.await { + Some(Ok(res)) => results.push(res), + Some(Err(payload)) => std::panic::resume_unwind(payload), + None => panic!("Failed to catch panic!"), } - results - }; - - let tick_task_pool_executor = tick_task_pool_executor || self.threads.is_empty(); - - // we get this from a thread local so we should always be on the scope executors thread.
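Usage of the rewritten entry points, for reference: `current_thread_spawner` replaces the old `ThreadExecutor` plumbing, and the external spawner is where `ScopeTaskTarget::External` tasks land. A sketch against this diff's API:

```rust
use bevy_tasks::{TaskPool, TaskPoolBuilder};

fn main() {
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);

    // Capture a spawner for this thread (typically the main thread); tasks
    // targeted at `External` will then run here.
    let main_spawner = pool.current_thread_spawner();
    let results = pool.scope_with_executor(Some(main_spawner), |scope| {
        scope.spawn(async { 40 + 2 });
    });
    assert_eq!(results, vec![42]);
}
```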
- // note: it is possible `scope_executor` and `external_executor` is the same executor, - // in that case, we should only tick one of them, otherwise, it may cause deadlock. - let scope_ticker = scope_executor.ticker().unwrap(); - let external_ticker = if !external_executor.is_same(scope_executor) { - external_executor.ticker() - } else { - None - }; - - match (external_ticker, tick_task_pool_executor) { - (Some(external_ticker), true) => { - Self::execute_global_external_scope( - executor, - external_ticker, - scope_ticker, - get_results, - ) - .await - } - (Some(external_ticker), false) => { - Self::execute_external_scope(external_ticker, scope_ticker, get_results) - .await - } - // either external_executor is none or it is same as scope_executor - (None, true) => { - Self::execute_global_scope(executor, scope_ticker, get_results).await - } - (None, false) => Self::execute_scope(scope_ticker, get_results).await, } - }) + results + })) } } - #[inline] - async fn execute_global_external_scope<'scope, 'ticker, T>( - executor: &'scope crate::executor::Executor<'scope>, - external_ticker: ThreadExecutorTicker<'scope, 'ticker>, - scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, - get_results: impl Future<Output = Vec<T>>, - ) -> Vec<T> { - // we restart the executors if a task errors. if a scoped - // task errors it will panic the scope on the call to get_results - let execute_forever = async move { - loop { - let tick_forever = async { - loop { - external_ticker.tick().or(scope_ticker.tick()).await; - } - }; - // we don't care if it errors. If a scoped task errors it will propagate - // to get_results - let _result = AssertUnwindSafe(executor.run(tick_forever)) - .catch_unwind() - .await - .is_ok(); - } - }; - get_results.or(execute_forever).await - } - - #[inline] - async fn execute_external_scope<'scope, 'ticker, T>( - external_ticker: ThreadExecutorTicker<'scope, 'ticker>, - scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, - get_results: impl Future<Output = Vec<T>>, - ) -> Vec<T> { - let execute_forever = async { - loop { - let tick_forever = async { - loop { - external_ticker.tick().or(scope_ticker.tick()).await; - } - }; - let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok(); - } - }; - get_results.or(execute_forever).await - } - - #[inline] - async fn execute_global_scope<'scope, 'ticker, T>( - executor: &'scope crate::executor::Executor<'scope>, - scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, - get_results: impl Future<Output = Vec<T>>, - ) -> Vec<T> { - let execute_forever = async { - loop { - let tick_forever = async { - loop { - scope_ticker.tick().await; - } - }; - let _result = AssertUnwindSafe(executor.run(tick_forever)) - .catch_unwind() - .await - .is_ok(); - } - }; - get_results.or(execute_forever).await - } - - #[inline] - async fn execute_scope<'scope, 'ticker, T>( - scope_ticker: ThreadExecutorTicker<'scope, 'ticker>, - get_results: impl Future<Output = Vec<T>>, - ) -> Vec<T> { - let execute_forever = async { - loop { - let tick_forever = async { - loop { - scope_ticker.tick().await; - } - }; - let _result = AssertUnwindSafe(tick_forever).catch_unwind().await.is_ok(); - } - }; - get_results.or(execute_forever).await + /// Creates a builder for a new [`Task`] to schedule onto the [`TaskPool`]. + /// + /// # Example + /// + /// ```no_run + /// # async fn my_cool_task() {} + /// # use bevy_tasks::{TaskPool, TaskPriority}; + /// let task_pool = TaskPool::get(); + /// let task = task_pool.builder() + /// .with_priority(TaskPriority::BlockingIO) + /// .spawn(async { + /// my_cool_task().await; + /// }); + /// ``` + pub fn
builder<T>(&self) -> TaskBuilder<'_, T> { + TaskBuilder::new(self) } /// Spawns a static future onto the thread pool. The returned [`Task`] is a /// future that can be polled for the result. It can also be canceled and /// "detached", allowing the task to continue running even if dropped. In /// any case, the pool will execute the task even without polling by the /// end-user. /// /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should /// be used instead. + /// + /// This is a shorthand for `self.builder().spawn(future)`. pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T> where T: Send + 'static, { - Task::new(self.executor.spawn(future)) + self.builder().spawn(future) } /// Spawns a static future on the thread-local async executor for the /// current thread. The task will run entirely on the thread the task was /// spawned on. /// /// The returned [`Task`] is a future that can be polled for the /// result. It can also be canceled and "detached", allowing the task to /// continue running even if dropped. In any case, the pool will execute the /// task even without polling by the end-user. /// /// Users should generally prefer to use [`TaskPool::spawn`] instead, /// unless the provided future is not `Send`. + /// + /// This is a shorthand for `self.builder().spawn_local(future)`. pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T> where T: 'static, { - Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future))) + self.builder().spawn_local(future) } - /// Runs a function with the local executor. Typically used to tick - /// the local executor on the main thread as it needs to share time with - /// other things. - /// - /// ``` - /// use bevy_tasks::TaskPool; - /// - /// TaskPool::new().with_local_executor(|local_executor| { - /// local_executor.try_tick(); - /// }); - /// ``` - pub fn with_local_executor<F, R>(&self, f: F) -> R - where - F: FnOnce(&crate::executor::LocalExecutor) -> R, - { - Self::LOCAL_EXECUTOR.with(f) - } -} - -impl Default for TaskPool { - fn default() -> Self { - Self::new() + pub(crate) fn try_tick_local() -> bool { + Executor::try_tick_local() } } @@ -620,70 +506,42 @@ impl Drop for TaskPool { } } +type ScopeTask<T> = FallibleTask<Result<T, Box<dyn Any + Send>>, Metadata>; + /// A [`TaskPool`] scope for running one or more non-`'static` futures. /// /// For more information, see [`TaskPool::scope`]. #[derive(Debug)] pub struct Scope<'scope, 'env: 'scope, T> { - executor: &'scope crate::executor::Executor<'scope>, - external_executor: &'scope ThreadExecutor<'scope>, - scope_executor: &'scope ThreadExecutor<'scope>, - spawned: &'scope ConcurrentQueue<FallibleTask<Result<T, Box<dyn Any + Send>>>>, + executor: &'static Executor, + external_spawner: ThreadSpawner, + scope_spawner: ThreadSpawner, + spawned: &'scope ConcurrentQueue<ScopeTask<T>>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, env: PhantomData<&'env mut &'env ()>, } impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { + /// Creates a builder to spawn a scoped future to schedule onto the [`TaskPool`]. + /// The scope *must* outlive the provided future. The results of the future will + /// be returned as a part of [`TaskPool::scope`]'s return value. + pub fn builder(&self) -> ScopeTaskBuilder<'_, 'scope, 'env, T> { + ScopeTaskBuilder::new(self) + } + /// Spawns a scoped future onto the thread pool. The scope *must* outlive /// the provided future. The results of the future will be returned as a part of /// [`TaskPool::scope`]'s return value. /// - /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used - /// instead. + /// For futures that should run on the thread `scope` is called on, [`Scope::builder`] should + /// be used instead, with [`ScopeTaskBuilder::with_target`] to target a specific thread. /// /// For more information, see [`TaskPool::scope`]. + /// + /// This is a shorthand for `scope.builder().spawn(f)`.
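Since `Scope::spawn` is now sugar for the builder, the two forms mix freely inside one scope, and the builder layers per-task priority on top. A sketch under the assumption that the priority set here feeds the scheduler as `TaskPriority` documents:

```rust
use bevy_tasks::{TaskPool, TaskPoolBuilder, TaskPriority};

fn main() {
    let pool = TaskPool::get_or_init(TaskPoolBuilder::default);
    let results = pool.scope(|scope| {
        // Shorthand: default priority, any worker thread.
        scope.spawn(async { 1 });
        // Builder form: the same task, tagged for the blocking-compute group.
        scope
            .builder()
            .with_priority(TaskPriority::BlockingCompute)
            .spawn(async { 2 });
    });
    assert_eq!(results.into_iter().sum::<i32>(), 3);
}
```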
pub fn spawn + 'scope + Send>(&self, f: Fut) { - let task = self - .executor - .spawn(AssertUnwindSafe(f).catch_unwind()) - .fallible(); - // ConcurrentQueue only errors when closed or full, but we never - // close and use an unbounded queue, so it is safe to unwrap - self.spawned.push(task).unwrap(); - } - - /// Spawns a scoped future onto the thread the scope is run on. The scope *must* outlive - /// the provided future. The results of the future will be returned as a part of - /// [`TaskPool::scope`]'s return value. Users should generally prefer to use - /// [`Scope::spawn`] instead, unless the provided future needs to run on the scope's thread. - /// - /// For more information, see [`TaskPool::scope`]. - pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { - let task = self - .scope_executor - .spawn(AssertUnwindSafe(f).catch_unwind()) - .fallible(); - // ConcurrentQueue only errors when closed or full, but we never - // close and use an unbounded queue, so it is safe to unwrap - self.spawned.push(task).unwrap(); - } - - /// Spawns a scoped future onto the thread of the external thread executor. - /// This is typically the main thread. The scope *must* outlive - /// the provided future. The results of the future will be returned as a part of - /// [`TaskPool::scope`]'s return value. Users should generally prefer to use - /// [`Scope::spawn`] instead, unless the provided future needs to run on the external thread. - /// - /// For more information, see [`TaskPool::scope`]. - pub fn spawn_on_external + 'scope + Send>(&self, f: Fut) { - let task = self - .external_executor - .spawn(AssertUnwindSafe(f).catch_unwind()) - .fallible(); - // ConcurrentQueue only errors when closed or full, but we never - // close and use an unbounded queue, so it is safe to unwrap - self.spawned.push(task).unwrap(); + self.builder().spawn(f); } } @@ -700,6 +558,88 @@ where } } +impl<'a, T> TaskBuilder<'a, T> { + /// Spawns a static future onto the thread pool. The returned [`Task`] is a + /// future that can be polled for the result. It can also be canceled and + /// "detached", allowing the task to continue running even if dropped. In + /// any case, the pool will execute the task even without polling by the + /// end-user. + /// + /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should + /// be used instead. + pub fn spawn(self, future: impl Future + Send + 'static) -> Task + where + T: Send + 'static, + { + Task::new(self.task_pool.executor.spawn(future, self.build_metadata())) + } + + /// Spawns a static future on the thread-local async executor for the + /// current thread. The task will run entirely on the thread the task was + /// spawned on. + /// + /// The returned [`Task`] is a future that can be polled for the + /// result. It can also be canceled and "detached", allowing the task to + /// continue running even if dropped. In any case, the pool will execute the + /// task even without polling by the end-user. + /// + /// Users should generally prefer to use [`TaskPool::spawn`] instead, + /// unless the provided future is not `Send`. + pub fn spawn_local(self, future: impl Future + 'static) -> Task + where + T: 'static, + { + Task::new(self.task_pool.executor.spawn_local(future, self.build_metadata())) + } +} + +impl<'a, 'scope, 'env, T: Send + 'scope> ScopeTaskBuilder<'a, 'scope, 'env, T> { + #[expect( + unsafe_code, + reason = "Executor::spawn and ThreadSpawner::spawn_scoped otherwise requires 'static Futures" + )] + /// Spawns a scoped future onto the thread pool. 
The scope *must* outlive + /// the provided future. The results of the future will be returned as a part of + /// [`TaskPool::scope`]'s return value. + /// + /// For futures that should run on the thread `scope` is called on [`Scope::spawn_on_scope`] should be used + /// instead. + /// + /// For more information, see [`TaskPool::scope`]. + pub fn spawn + 'scope + Send>(self, f: Fut) { + let task = match self.target { + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::Any => unsafe { + self.scope + .executor + .spawn_scoped(AssertUnwindSafe(f).catch_unwind(), Metadata::default()) + .fallible() + }, + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::Scope => unsafe { + self.scope + .scope_spawner + .spawn_scoped(AssertUnwindSafe(f).catch_unwind()) + .into_inner() + .fallible() + }, + // SAFETY: The scope call that generated this `Scope` ensures that the created + // Task does not outlive 'scope. + ScopeTaskTarget::External => unsafe { + self.scope + .external_spawner + .spawn_scoped(AssertUnwindSafe(f).catch_unwind()) + .into_inner() + .fallible() + }, + }; + let result = self.scope.spawned.push(task); + debug_assert!(result.is_ok()); + } +} + #[cfg(test)] mod tests { use super::*; @@ -708,7 +648,7 @@ mod tests { #[test] fn test_spawn() { - let pool = TaskPool::new(); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let foo = Box::new(42); let foo = &*foo; @@ -740,6 +680,7 @@ mod tests { #[test] fn test_thread_callbacks() { let counter = Arc::new(AtomicI32::new(0)); + static EX: Executor = Executor::new(); let start_counter = counter.clone(); { let barrier = Arc::new(Barrier::new(11)); @@ -790,7 +731,7 @@ mod tests { #[test] fn test_mixed_spawn_on_scope_and_spawn() { - let pool = TaskPool::new(); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let foo = Box::new(42); let foo = &*foo; @@ -812,7 +753,10 @@ mod tests { }); } else { let count_clone = local_count.clone(); - scope.spawn_on_scope(async move { + scope + .builder() + .with_target(ScopeTaskTarget::Scope) + .spawn(async move { if *foo != 42 { panic!("not 42!?!?") } else { @@ -835,7 +779,7 @@ mod tests { #[test] fn test_thread_locality() { - let pool = Arc::new(TaskPool::new()); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let count = Arc::new(AtomicI32::new(0)); let barrier = Arc::new(Barrier::new(101)); let thread_check_failed = Arc::new(AtomicBool::new(false)); @@ -843,17 +787,18 @@ mod tests { for _ in 0..100 { let inner_barrier = barrier.clone(); let count_clone = count.clone(); - let inner_pool = pool.clone(); let inner_thread_check_failed = thread_check_failed.clone(); thread::spawn(move || { - inner_pool.scope(|scope| { + pool.scope(|scope| { let inner_count_clone = count_clone.clone(); scope.spawn(async move { inner_count_clone.fetch_add(1, Ordering::Release); }); let spawner = thread::current().id(); let inner_count_clone = count_clone.clone(); - scope.spawn_on_scope(async move { + scope.builder() + .with_target(ScopeTaskTarget::Scope) + .spawn(async move { inner_count_clone.fetch_add(1, Ordering::Release); if thread::current().id() != spawner { // NOTE: This check is using an atomic rather than simply panicking the @@ -872,7 +817,7 @@ mod tests { #[test] fn test_nested_spawn() { - let pool = TaskPool::new(); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let foo = Box::new(42); let foo = &*foo; @@ 
-910,7 +855,7 @@ mod tests { #[test] fn test_nested_locality() { - let pool = Arc::new(TaskPool::new()); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let count = Arc::new(AtomicI32::new(0)); let barrier = Arc::new(Barrier::new(101)); let thread_check_failed = Arc::new(AtomicBool::new(false)); @@ -918,17 +863,18 @@ mod tests { for _ in 0..100 { let inner_barrier = barrier.clone(); let count_clone = count.clone(); - let inner_pool = pool.clone(); let inner_thread_check_failed = thread_check_failed.clone(); thread::spawn(move || { - inner_pool.scope(|scope| { + pool.scope(|scope| { let spawner = thread::current().id(); let inner_count_clone = count_clone.clone(); scope.spawn(async move { inner_count_clone.fetch_add(1, Ordering::Release); // spawning on the scope from another thread runs the futures on the scope's thread - scope.spawn_on_scope(async move { + scope.builder() + .with_target(ScopeTaskTarget::Scope) + .spawn(async move { inner_count_clone.fetch_add(1, Ordering::Release); if thread::current().id() != spawner { // NOTE: This check is using an atomic rather than simply panicking the @@ -949,7 +895,7 @@ mod tests { // This test will often freeze on other executors. #[test] fn test_nested_scopes() { - let pool = TaskPool::new(); + let pool = TaskPool::get_or_init(TaskPoolBuilder::default); let count = Arc::new(AtomicI32::new(0)); pool.scope(|scope| { diff --git a/crates/bevy_tasks/src/thread_executor.rs b/crates/bevy_tasks/src/thread_executor.rs deleted file mode 100644 index 86d2ab280d87c..0000000000000 --- a/crates/bevy_tasks/src/thread_executor.rs +++ /dev/null @@ -1,133 +0,0 @@ -use core::marker::PhantomData; -use std::thread::{self, ThreadId}; - -use crate::executor::Executor; -use async_task::Task; -use futures_lite::Future; - -/// An executor that can only be ticked on the thread it was instantiated on. But -/// can spawn `Send` tasks from other threads. 
-/// -/// # Example -/// ``` -/// # use std::sync::{Arc, atomic::{AtomicI32, Ordering}}; -/// use bevy_tasks::ThreadExecutor; -/// -/// let thread_executor = ThreadExecutor::new(); -/// let count = Arc::new(AtomicI32::new(0)); -/// -/// // create some owned values that can be moved into another thread -/// let count_clone = count.clone(); -/// -/// std::thread::scope(|scope| { -/// scope.spawn(|| { -/// // we cannot get the ticker from another thread -/// let not_thread_ticker = thread_executor.ticker(); -/// assert!(not_thread_ticker.is_none()); -/// -/// // but we can spawn tasks from another thread -/// thread_executor.spawn(async move { -/// count_clone.fetch_add(1, Ordering::Relaxed); -/// }).detach(); -/// }); -/// }); -/// -/// // the tasks do not make progress unless the executor is manually ticked -/// assert_eq!(count.load(Ordering::Relaxed), 0); -/// -/// // tick the ticker until task finishes -/// let thread_ticker = thread_executor.ticker().unwrap(); -/// thread_ticker.try_tick(); -/// assert_eq!(count.load(Ordering::Relaxed), 1); -/// ``` -#[derive(Debug)] -pub struct ThreadExecutor<'task> { - executor: Executor<'task>, - thread_id: ThreadId, -} - -impl<'task> Default for ThreadExecutor<'task> { - fn default() -> Self { - Self { - executor: Executor::new(), - thread_id: thread::current().id(), - } - } -} - -impl<'task> ThreadExecutor<'task> { - /// create a new [`ThreadExecutor`] - pub fn new() -> Self { - Self::default() - } - - /// Spawn a task on the thread executor - pub fn spawn( - &self, - future: impl Future + Send + 'task, - ) -> Task { - self.executor.spawn(future) - } - - /// Gets the [`ThreadExecutorTicker`] for this executor. - /// Use this to tick the executor. - /// It only returns the ticker if it's on the thread the executor was created on - /// and returns `None` otherwise. - pub fn ticker<'ticker>(&'ticker self) -> Option> { - if thread::current().id() == self.thread_id { - return Some(ThreadExecutorTicker { - executor: self, - _marker: PhantomData, - }); - } - None - } - - /// Returns true if `self` and `other`'s executor is same - pub fn is_same(&self, other: &Self) -> bool { - core::ptr::eq(self, other) - } -} - -/// Used to tick the [`ThreadExecutor`]. The executor does not -/// make progress unless it is manually ticked on the thread it was -/// created on. -#[derive(Debug)] -pub struct ThreadExecutorTicker<'task, 'ticker> { - executor: &'ticker ThreadExecutor<'task>, - // make type not send or sync - _marker: PhantomData<*const ()>, -} - -impl<'task, 'ticker> ThreadExecutorTicker<'task, 'ticker> { - /// Tick the thread executor. - pub async fn tick(&self) { - self.executor.executor.tick().await; - } - - /// Synchronously try to tick a task on the executor. - /// Returns false if does not find a task to tick. - pub fn try_tick(&self) -> bool { - self.executor.executor.try_tick() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use alloc::sync::Arc; - - #[test] - fn test_ticker() { - let executor = Arc::new(ThreadExecutor::new()); - let ticker = executor.ticker(); - assert!(ticker.is_some()); - - thread::scope(|s| { - s.spawn(|| { - let ticker = executor.ticker(); - assert!(ticker.is_none()); - }); - }); - } -} diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs index 40cabd76ac5e6..2150395b2eae0 100644 --- a/crates/bevy_tasks/src/usages.rs +++ b/crates/bevy_tasks/src/usages.rs @@ -1,79 +1,4 @@ use super::TaskPool; -use bevy_platform::sync::OnceLock; -use core::ops::Deref; - -macro_rules! 
taskpool { - ($(#[$attr:meta])* ($static:ident, $type:ident)) => { - static $static: OnceLock<$type> = OnceLock::new(); - - $(#[$attr])* - #[derive(Debug)] - pub struct $type(TaskPool); - - impl $type { - #[doc = concat!(" Gets the global [`", stringify!($type), "`] instance, or initializes it with `f`.")] - pub fn get_or_init(f: impl FnOnce() -> TaskPool) -> &'static Self { - $static.get_or_init(|| Self(f())) - } - - #[doc = concat!(" Attempts to get the global [`", stringify!($type), "`] instance, \ - or returns `None` if it is not initialized.")] - pub fn try_get() -> Option<&'static Self> { - $static.get() - } - - #[doc = concat!(" Gets the global [`", stringify!($type), "`] instance.")] - #[doc = ""] - #[doc = " # Panics"] - #[doc = " Panics if the global instance has not been initialized yet."] - pub fn get() -> &'static Self { - $static.get().expect( - concat!( - "The ", - stringify!($type), - " has not been initialized yet. Please call ", - stringify!($type), - "::get_or_init beforehand." - ) - ) - } - } - - impl Deref for $type { - type Target = TaskPool; - - fn deref(&self) -> &Self::Target { - &self.0 - } - } - }; -} - -taskpool! { - /// A newtype for a task pool for CPU-intensive work that must be completed to - /// deliver the next frame - /// - /// See [`TaskPool`] documentation for details on Bevy tasks. - /// [`AsyncComputeTaskPool`] should be preferred if the work does not have to be - /// completed before the next frame. - (COMPUTE_TASK_POOL, ComputeTaskPool) -} - -taskpool! { - /// A newtype for a task pool for CPU-intensive work that may span across multiple frames - /// - /// See [`TaskPool`] documentation for details on Bevy tasks. - /// Use [`ComputeTaskPool`] if the work must be complete before advancing to the next frame. - (ASYNC_COMPUTE_TASK_POOL, AsyncComputeTaskPool) -} - -taskpool! { - /// A newtype for a task pool for IO-intensive work (i.e. tasks that spend very little time in a - /// "woken" state) - /// - /// See [`TaskPool`] documentation for details on Bevy tasks. - (IO_TASK_POOL, IoTaskPool) -} crate::cfg::web! { if {} else { @@ -84,26 +9,11 @@ crate::cfg::web! { /// /// This function *must* be called on the main thread, or the task pools will not be updated appropriately. pub fn tick_global_task_pools_on_main_thread() { - COMPUTE_TASK_POOL - .get() - .unwrap() - .with_local_executor(|compute_local_executor| { - ASYNC_COMPUTE_TASK_POOL - .get() - .unwrap() - .with_local_executor(|async_local_executor| { - IO_TASK_POOL - .get() - .unwrap() - .with_local_executor(|io_local_executor| { - for _ in 0..100 { - compute_local_executor.try_tick(); - async_local_executor.try_tick(); - io_local_executor.try_tick(); - } - }); - }); - }); + for _ in 0..100 { + if !TaskPool::try_tick_local() { + break; + } + } } } } diff --git a/crates/bevy_transform/Cargo.toml b/crates/bevy_transform/Cargo.toml index 4dc23e6d881ef..284b5e9634218 100644 --- a/crates/bevy_transform/Cargo.toml +++ b/crates/bevy_transform/Cargo.toml @@ -33,7 +33,7 @@ approx = "0.5.1" [features] # Turning off default features leaves you with a barebones # definition of transform. -default = ["std", "bevy-support", "bevy_reflect", "async_executor"] +default = ["std", "bevy-support", "bevy_reflect"] # Functionality @@ -55,12 +55,6 @@ bevy_reflect = [ "bevy_app/bevy_reflect", ] -# Executor Backend - -## Uses `async-executor` as a task execution backend. -## This backend is incompatible with `no_std` targets. 
-async_executor = ["std", "bevy_tasks/async_executor"] - # Platform Compatibility ## Allows access to the `std` crate. Enabling this feature will prevent compilation diff --git a/crates/bevy_transform/src/systems.rs b/crates/bevy_transform/src/systems.rs index 62038b37ed90f..1781481225da1 100644 --- a/crates/bevy_transform/src/systems.rs +++ b/crates/bevy_transform/src/systems.rs @@ -252,7 +252,7 @@ mod parallel { // TODO: this implementation could be used in no_std if there are equivalents of these. use alloc::{sync::Arc, vec::Vec}; use bevy_ecs::{entity::UniqueEntityIter, prelude::*, system::lifetimeless::Read}; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::{TaskPool, TaskPoolBuilder}; use bevy_utils::Parallel; use core::sync::atomic::{AtomicI32, Ordering}; use std::sync::{ @@ -320,7 +320,7 @@ mod parallel { } // Spawn workers on the task pool to recursively propagate the hierarchy in parallel. - let task_pool = ComputeTaskPool::get_or_init(TaskPool::default); + let task_pool = TaskPool::get_or_init(TaskPoolBuilder::default); task_pool.scope(|s| { (1..task_pool.thread_num()) // First worker is run locally instead of the task pool. .for_each(|_| s.spawn(async { propagation_worker(&queue, &nodes) })); @@ -559,13 +559,13 @@ mod test { use bevy_app::prelude::*; use bevy_ecs::{prelude::*, world::CommandQueue}; use bevy_math::{vec3, Vec3}; - use bevy_tasks::{ComputeTaskPool, TaskPool}; + use bevy_tasks::{TaskPool, TaskPoolBuilder}; use crate::systems::*; #[test] fn correct_parent_removed() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::default(); let offset_global_transform = |offset| GlobalTransform::from(Transform::from_xyz(offset, offset, offset)); @@ -626,7 +626,7 @@ mod test { #[test] fn did_propagate() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::default(); let mut schedule = Schedule::default(); @@ -702,7 +702,7 @@ mod test { #[test] fn correct_children() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); let mut world = World::default(); let mut schedule = Schedule::default(); @@ -783,7 +783,7 @@ mod test { #[test] fn correct_transforms_when_no_children() { let mut app = App::new(); - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); app.add_systems( Update, @@ -834,7 +834,7 @@ mod test { #[test] #[should_panic] fn panic_when_hierarchy_cycle() { - ComputeTaskPool::get_or_init(TaskPool::default); + TaskPool::get_or_init(TaskPoolBuilder::default); // We cannot directly edit ChildOf and Children, so we use a temp world to break the // hierarchy's invariants. let mut temp = World::new(); diff --git a/docs/cargo_features.md b/docs/cargo_features.md index 967f1eaf5cd8f..b2bd202e70c56 100644 --- a/docs/cargo_features.md +++ b/docs/cargo_features.md @@ -14,7 +14,6 @@ The default feature set enables most of the expected features of a game engine, |android-game-activity|Android GameActivity support. 
Default, choose between this and `android-native-activity`.| |android_shared_stdcxx|Enable using a shared stdlib for cxx on Android| |animation|Enable animation support, and glTF animation loading| -|async_executor|Uses `async-executor` as a task execution backend.| |bevy_animation|Provides animation functionality| |bevy_anti_aliasing|Provides various anti aliasing solutions| |bevy_asset|Provides asset functionality| diff --git a/examples/README.md b/examples/README.md index 51f59f0a9ebe3..7da8f9af96489 100644 --- a/examples/README.md +++ b/examples/README.md @@ -265,7 +265,7 @@ Example | Description Example | Description --- | --- -[Async Compute](../examples/async_tasks/async_compute.rs) | How to use `AsyncComputeTaskPool` to complete longer running tasks +[Blocking Compute](../examples/async_tasks/blocking_compute.rs) | How to use `TaskPool` to complete longer running tasks [External Source of Data on an External Thread](../examples/async_tasks/external_source_external_thread.rs) | How to use an external thread to run an infinite task and communicate with a channel ### Audio diff --git a/examples/animation/animation_graph.rs b/examples/animation/animation_graph.rs index e511a1bb7faa4..cd7bc001cccd4 100644 --- a/examples/animation/animation_graph.rs +++ b/examples/animation/animation_graph.rs @@ -16,7 +16,10 @@ use argh::FromArgs; #[cfg(not(target_arch = "wasm32"))] use { - bevy::{asset::io::file::FileAssetReader, tasks::IoTaskPool}, + bevy::{ + asset::io::file::FileAssetReader, + tasks::{TaskPool, TaskPriority}, + }, ron::ser::PrettyConfig, std::{fs::File, path::Path}, }; @@ -176,9 +179,13 @@ fn setup_assets_programmatically( // If asked to save, do so. #[cfg(not(target_arch = "wasm32"))] if _save { + use bevy::tasks::TaskPriority; + let animation_graph = animation_graph.clone(); - IoTaskPool::get() + TaskPool::get() + .builder() + .with_priority(TaskPriority::BlockingIO) .spawn(async move { use std::io::Write; diff --git a/examples/asset/multi_asset_sync.rs b/examples/asset/multi_asset_sync.rs index 83add4ba3c016..86d1f611fed42 100644 --- a/examples/asset/multi_asset_sync.rs +++ b/examples/asset/multi_asset_sync.rs @@ -9,7 +9,7 @@ use std::{ }, }; -use bevy::{gltf::Gltf, prelude::*, tasks::AsyncComputeTaskPool}; +use bevy::{gltf::Gltf, prelude::*, tasks::TaskPool}; use event_listener::Event; use futures_lite::Future; @@ -29,7 +29,7 @@ fn main() { // This approach polls a value in a system. .add_systems(Update, wait_on_load.run_if(assets_loaded)) // This showcases how to wait for assets using async - // by spawning a `Future` in `AsyncComputeTaskPool`. + // by spawning a `Future` in `TaskPool`. .add_systems( Update, get_async_loading_state.run_if(in_state(LoadingState::Loading)), @@ -158,7 +158,7 @@ fn setup_assets(mut commands: Commands, asset_server: Res) { commands.insert_resource(AsyncLoadingState(loading_state.clone())); // await the `AssetBarrierFuture`. - AsyncComputeTaskPool::get() + TaskPool::get() .spawn(async move { future.await; // Notify via `AsyncLoadingState` diff --git a/examples/async_tasks/async_compute.rs b/examples/async_tasks/blocking_compute.rs similarity index 62% rename from examples/async_tasks/async_compute.rs rename to examples/async_tasks/blocking_compute.rs index 7e24525cb6230..bd1d5f88e3a03 100644 --- a/examples/async_tasks/async_compute.rs +++ b/examples/async_tasks/blocking_compute.rs @@ -1,10 +1,10 @@ -//! This example shows how to use the ECS and the [`AsyncComputeTaskPool`] +//! This example shows how to use the ECS and the [`TaskPool`] //! 
diff --git a/examples/async_tasks/async_compute.rs b/examples/async_tasks/blocking_compute.rs
similarity index 62%
rename from examples/async_tasks/async_compute.rs
rename to examples/async_tasks/blocking_compute.rs
index 7e24525cb6230..bd1d5f88e3a03 100644
--- a/examples/async_tasks/async_compute.rs
+++ b/examples/async_tasks/blocking_compute.rs
@@ -1,10 +1,10 @@
-//! This example shows how to use the ECS and the [`AsyncComputeTaskPool`]
+//! This example shows how to use the ECS and the [`TaskPool`]
 //! to spawn, poll, and complete tasks across systems and system ticks.

 use bevy::{
     ecs::{system::SystemState, world::CommandQueue},
     prelude::*,
-    tasks::{block_on, futures_lite::future, AsyncComputeTaskPool, Task},
+    tasks::{block_on, futures_lite::future, Task, TaskPool, TaskPriority},
 };
 use rand::Rng;
 use std::time::Duration;
@@ -50,51 +50,56 @@ struct ComputeTransform(Task<CommandQueue>);
 /// system, [`handle_tasks`], will poll the spawned tasks on subsequent
 /// frames/ticks, and use the results to spawn cubes
 fn spawn_tasks(mut commands: Commands) {
-    let thread_pool = AsyncComputeTaskPool::get();
+    let thread_pool = TaskPool::get();
     for x in 0..NUM_CUBES {
         for y in 0..NUM_CUBES {
             for z in 0..NUM_CUBES {
-                // Spawn new task on the AsyncComputeTaskPool; the task will be
+                // Spawn new task on the TaskPool; the task will be
                 // executed in the background, and the Task future returned by
                 // spawn() can be used to poll for the result
                 let entity = commands.spawn_empty().id();
-                let task = thread_pool.spawn(async move {
-                    let duration = Duration::from_secs_f32(rand::rng().random_range(0.05..5.0));
-
-                    // Pretend this is a time-intensive function. :)
-                    async_std::task::sleep(duration).await;
-
-                    // Such hard work, all done!
-                    let transform = Transform::from_xyz(x as f32, y as f32, z as f32);
-                    let mut command_queue = CommandQueue::default();
-
-                    // we use a raw command queue to pass a FnOnce(&mut World) back to be
-                    // applied in a deferred manner.
-                    command_queue.push(move |world: &mut World| {
-                        let (box_mesh_handle, box_material_handle) = {
-                            let mut system_state = SystemState::<(
-                                Res<BoxMeshHandle>,
-                                Res<BoxMaterialHandle>,
-                            )>::new(world);
-                            let (box_mesh_handle, box_material_handle) =
-                                system_state.get_mut(world);
-
-                            (box_mesh_handle.clone(), box_material_handle.clone())
-                        };
-
-                        world
-                            .entity_mut(entity)
-                            // Add our new `Mesh3d` and `MeshMaterial3d` to our tagged entity
-                            .insert((
-                                Mesh3d(box_mesh_handle),
-                                MeshMaterial3d(box_material_handle),
-                                transform,
-                            ));
+                let task = thread_pool
+                    .builder()
+                    .with_priority(TaskPriority::BlockingCompute)
+                    .spawn(async move {
+                        let duration =
+                            Duration::from_secs_f32(rand::rng().random_range(0.05..5.0));
+
+                        // Pretend this is a time-intensive function. :)
+                        async_std::task::sleep(duration).await;
+
+                        // Such hard work, all done!
+                        let transform = Transform::from_xyz(x as f32, y as f32, z as f32);
+                        let mut command_queue = CommandQueue::default();
+
+                        // we use a raw command queue to pass a FnOnce(&mut World) back to be
+                        // applied in a deferred manner.
+                        command_queue.push(move |world: &mut World| {
+                            let (box_mesh_handle, box_material_handle) = {
+                                let mut system_state = SystemState::<(
+                                    Res<BoxMeshHandle>,
+                                    Res<BoxMaterialHandle>,
+                                )>::new(
+                                    world
+                                );
+                                let (box_mesh_handle, box_material_handle) =
+                                    system_state.get_mut(world);
+
+                                (box_mesh_handle.clone(), box_material_handle.clone())
+                            };
+
+                            world
+                                .entity_mut(entity)
+                                // Add our new `Mesh3d` and `MeshMaterial3d` to our tagged entity
+                                .insert((
+                                    Mesh3d(box_mesh_handle),
+                                    MeshMaterial3d(box_material_handle),
+                                    transform,
+                                ));
+                        });
+
+                        command_queue
                     });
-
-                    command_queue
-                });
-
                 // Add our new task as a component
                 commands.entity(entity).insert(ComputeTransform(task));
             }
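Note: the hunk above only shows the spawning half of the renamed `blocking_compute` example. Assuming the polling half is untouched by this diff, the consuming system would look roughly like this sketch (the entity-cleanup line is an illustrative addition, not read from the diff):

```rust
use bevy::{
    ecs::world::CommandQueue,
    prelude::*,
    tasks::{block_on, futures_lite::future, Task},
};

#[derive(Component)]
struct ComputeTransform(Task<CommandQueue>);

/// Poll each in-flight task once per frame; when one finishes, apply the
/// `CommandQueue` it produced and remove the bookkeeping component.
fn handle_tasks(mut commands: Commands, mut tasks: Query<(Entity, &mut ComputeTransform)>) {
    for (entity, mut task) in &mut tasks {
        if let Some(mut queue) = block_on(future::poll_once(&mut task.0)) {
            commands.append(&mut queue);
            commands.entity(entity).remove::<ComputeTransform>();
        }
    }
}
```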
diff --git a/examples/ecs/parallel_query.rs b/examples/ecs/parallel_query.rs
index 6ebd28ea5065e..245380935e64f 100644
--- a/examples/ecs/parallel_query.rs
+++ b/examples/ecs/parallel_query.rs
@@ -27,7 +27,7 @@ fn spawn_system(mut commands: Commands, asset_server: Res<AssetServer>) {
 // Move sprites according to their velocity
 fn move_system(mut sprites: Query<(&mut Transform, &Velocity)>) {
     // Compute the new location of each sprite in parallel on the
-    // ComputeTaskPool
+    // TaskPool
     //
     // This example is only for demonstrative purposes. Using a
     // ParallelIterator for an inexpensive operation like addition on only 128
diff --git a/examples/scene/scene.rs b/examples/scene/scene.rs
index 2fba727a82f9a..1778da0956427 100644
--- a/examples/scene/scene.rs
+++ b/examples/scene/scene.rs
@@ -20,11 +20,11 @@
 //! # Note on working with files
 //!
 //! The saving behavior uses the standard filesystem APIs, which are blocking, so it
-//! utilizes a thread pool (`IoTaskPool`) to avoid stalling the main thread. This
+//! utilizes a thread pool (`TaskPool`) to avoid stalling the main thread. This
 //! won't work on WASM because WASM typically doesn't have direct filesystem access.
 //!

-use bevy::{asset::LoadState, prelude::*, tasks::IoTaskPool};
+use bevy::{asset::LoadState, prelude::*, tasks::{TaskPool, TaskPriority}};
 use core::time::Duration;
 use std::{fs::File, io::Write};

@@ -195,7 +195,9 @@ fn save_scene_system(world: &mut World) {
     //
     // This can't work in Wasm as there is no filesystem access.
     #[cfg(not(target_arch = "wasm32"))]
-    IoTaskPool::get()
+    TaskPool::get()
+        .builder()
+        .with_priority(TaskPriority::BlockingIO)
         .spawn(async move {
             // Write the scene RON data to file
             File::create(format!("assets/{NEW_SCENE_FILE_PATH}"))
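Note: the diff is truncated at `File::create` above. For orientation, a self-contained version of the pattern scene.rs is being migrated to might look like the following sketch (the path, error handling, and `detach()` call are illustrative, not read from the truncated hunk):

```rust
use bevy::tasks::{TaskPool, TaskPriority};
use std::{fs::File, io::Write};

fn save_blocking(serialized_scene: String) {
    TaskPool::get()
        .builder()
        .with_priority(TaskPriority::BlockingIO)
        .spawn(async move {
            // Blocking filesystem write, kept off the main thread by the pool.
            File::create("assets/scenes/example.scn.ron")
                .and_then(|mut file| file.write(serialized_scene.as_bytes()))
                .expect("error while writing scene to file");
        })
        .detach(); // run to completion without keeping a handle
}
```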