diff --git a/crates/bevy_input/src/keyboard.rs b/crates/bevy_input/src/keyboard.rs index 83b35fe47e0a8..35e3465534a0e 100644 --- a/crates/bevy_input/src/keyboard.rs +++ b/crates/bevy_input/src/keyboard.rs @@ -4,6 +4,7 @@ use bevy_reflect::{FromReflect, Reflect}; #[cfg(feature = "serialize")] use bevy_reflect::{ReflectDeserialize, ReflectSerialize}; +use bevy_utils::IterableEnum; /// A keyboard input event. /// @@ -71,7 +72,9 @@ pub fn keyboard_input_system( /// ## Updating /// /// The resource is updated inside of the [`keyboard_input_system`](crate::keyboard::keyboard_input_system). -#[derive(Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Clone, Copy, Reflect, FromReflect)] +#[derive( + Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Clone, Copy, Reflect, FromReflect, IterableEnum, +)] #[reflect(Debug, Hash, PartialEq)] #[cfg_attr( feature = "serialize", diff --git a/crates/bevy_utils/Cargo.toml b/crates/bevy_utils/Cargo.toml index 5a019dbc38455..25c26b2e88533 100644 --- a/crates/bevy_utils/Cargo.toml +++ b/crates/bevy_utils/Cargo.toml @@ -16,6 +16,7 @@ uuid = { version = "1.1", features = ["v4", "serde"] } hashbrown = { version = "0.12", features = ["serde"] } petgraph = "0.6" thiserror = "1.0" +bevy_utils_macros = { version = "0.9.0", path = "./macros" } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = {version = "0.2.0", features = ["js"]} diff --git a/crates/bevy_utils/macros/Cargo.toml b/crates/bevy_utils/macros/Cargo.toml new file mode 100644 index 0000000000000..fed53367d0cbb --- /dev/null +++ b/crates/bevy_utils/macros/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "bevy_utils_macros" +version = "0.9.0" +edition = "2021" +description = "Derive implementations for bevy_utils" +homepage = "https://bevyengine.org" +repository = "https://github.com/bevyengine/bevy" +license = "MIT OR Apache-2.0" +keywords = ["bevy"] + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "1.0", features = ["full", "parsing", "extra-traits"] } +quote = "1.0" +bevy_macro_utils = { version = "0.9.0", path = "../../bevy_macro_utils" } diff --git a/crates/bevy_utils/macros/src/iterable_enum.rs b/crates/bevy_utils/macros/src/iterable_enum.rs new file mode 100644 index 0000000000000..4c353c5c02a0e --- /dev/null +++ b/crates/bevy_utils/macros/src/iterable_enum.rs @@ -0,0 +1,63 @@ +use proc_macro::TokenStream; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{__private::Span, spanned::Spanned, DataEnum}; + +use crate::paths; + +pub fn parse_iterable_enum_derive(input: TokenStream) -> TokenStream { + let ast = syn::parse_macro_input!(input as syn::DeriveInput); + + let span = ast.span(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let get_at = match ast.data { + syn::Data::Enum(d) => get_at_impl(name, span, d), + _ => quote_spanned! { + span => compile_error!("`IterableEnum` can only be applied to `enum`") + }, + }; + + let iterable_enum = paths::iterable_enum_path(); + + quote! { + impl #impl_generics #iterable_enum for #name #ty_generics #where_clause { + #get_at + } + } + .into() +} + +fn get_at_impl(name: impl ToTokens, span: Span, d: DataEnum) -> quote::__private::TokenStream { + let mut arms = quote!(); + let mut index: usize = 0; + + for variant in d.variants { + match variant.fields { + syn::Fields::Unit => { + let ident = variant.ident; + arms = quote! { #arms #index => Some(#name::#ident), }; + index += 1; + } + _ => { + return quote_spanned! { + span => compile_error!("All Fields should be Units!"); + } + .into(); + } + }; + } + + quote! { + #[inline] + fn get_at(index: usize) -> Option { + match index { + #arms + _ => None, + } + } + } + .into() +} diff --git a/crates/bevy_utils/macros/src/lib.rs b/crates/bevy_utils/macros/src/lib.rs new file mode 100644 index 0000000000000..6ec2be74629a3 --- /dev/null +++ b/crates/bevy_utils/macros/src/lib.rs @@ -0,0 +1,12 @@ +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +use proc_macro::TokenStream; + +mod iterable_enum; +mod paths; + +#[proc_macro_derive(IterableEnum)] +pub fn iterable_enum_derive(input: TokenStream) -> TokenStream { + iterable_enum::parse_iterable_enum_derive(input) +} diff --git a/crates/bevy_utils/macros/src/paths.rs b/crates/bevy_utils/macros/src/paths.rs new file mode 100644 index 0000000000000..b07bcf0db071c --- /dev/null +++ b/crates/bevy_utils/macros/src/paths.rs @@ -0,0 +1,16 @@ +use bevy_macro_utils::BevyManifest; +use quote::format_ident; + +#[inline] +pub(crate) fn bevy_utils_path() -> syn::Path { + BevyManifest::default().get_path("bevy_utils") +} + +#[inline] +pub(crate) fn iterable_enum_path() -> syn::Path { + let mut utils_path = bevy_utils_path(); + utils_path + .segments + .push(format_ident!("IterableEnum").into()); + utils_path +} diff --git a/crates/bevy_utils/src/iterable_enum.rs b/crates/bevy_utils/src/iterable_enum.rs new file mode 100644 index 0000000000000..c6061d99caa7c --- /dev/null +++ b/crates/bevy_utils/src/iterable_enum.rs @@ -0,0 +1,38 @@ +use std::marker::PhantomData; + +/// A trait for enums to get a `Unit`-enum-field by a `usize` +pub trait IterableEnum: Sized { + /// Gets an `Unit`-enum-field by the given `usize` index + fn get_at(index: usize) -> Option; + + /// Creates a new [`EnumIterator`] which will numerically return every `Unit` of this enum + #[inline] + fn enum_iter() -> EnumIterator { + EnumIterator { + accelerator: 0, + phantom: PhantomData, + } + } +} + +/// An iterator over `IterableEnum`s +/// +/// Iterates all `Unit` fields in numeric order +pub struct EnumIterator { + accelerator: usize, + phantom: PhantomData, +} + +impl Iterator for EnumIterator { + type Item = E; + + #[inline] + fn next(&mut self) -> Option { + if let Some(unit) = E::get_at(self.accelerator) { + self.accelerator += 1; + Some(unit) + } else { + None + } + } +} diff --git a/crates/bevy_utils/src/lib.rs b/crates/bevy_utils/src/lib.rs index a6a4aebcb254c..22972547f1253 100644 --- a/crates/bevy_utils/src/lib.rs +++ b/crates/bevy_utils/src/lib.rs @@ -10,6 +10,8 @@ pub mod prelude { pub use crate::default; } +pub use bevy_utils_macros::*; + pub mod futures; pub mod label; mod short_names; @@ -19,12 +21,14 @@ pub mod syncunsafecell; mod default; mod float_ord; +mod iterable_enum; pub use ahash::AHasher; pub use default::default; pub use float_ord::*; pub use hashbrown; pub use instant::{Duration, Instant}; +pub use iterable_enum::IterableEnum; pub use petgraph; pub use thiserror; pub use tracing;