From 9327c8a2ab65660ca368915cd444b06320f0c919 Mon Sep 17 00:00:00 2001 From: Sheldon Frith Date: Fri, 7 Feb 2025 22:10:00 -0500 Subject: [PATCH] Accomplished branch objective, also slightly cleaned up recently added feature related to "on_cpu" version of shader functions. --- .../src/wgsl/shader_module/derived_portion.rs | 8 +- bevy_gpu_compute_macro/src/lib.rs | 5 +- bevy_gpu_compute_macro/src/transformer/mod.rs | 1 - .../src/transformer/module_parser/lib.rs | 4 +- .../module_parser/main_function.rs | 2 +- .../erroneous_usage_finder.rs | 39 ++++ .../helper_method.rs | 8 +- .../transform_wgsl_helper_methods/mod.rs | 4 + .../transform_wgsl_helper_methods/parse.rs | 73 ++++++++ .../transform_wgsl_helper_methods/run.rs | 171 ++++++++---------- .../transform_wgsl_helper_methods/test.rs | 124 +++++++++++-- .../test_for_cpu.rs} | 16 +- .../to_expanded_format.rs | 24 ++- .../to_expanded_format_for_cpu.rs} | 26 +-- .../category.rs | 19 -- .../helper_method.rs | 17 -- .../matcher.rs | 96 ---------- .../method_name.rs | 26 --- .../mod.rs | 7 - .../run.rs | 123 ------------- bevy_gpu_compute_macro/tests/tt.rs | 141 --------------- 21 files changed, 349 insertions(+), 585 deletions(-) create mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/erroneous_usage_finder.rs create mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/parse.rs rename bevy_gpu_compute_macro/src/transformer/{transform_wgsl_helper_methods_for_cpu/test.rs => transform_wgsl_helper_methods/test_for_cpu.rs} (89%) rename bevy_gpu_compute_macro/src/transformer/{transform_wgsl_helper_methods_for_cpu/to_expanded_format.rs => transform_wgsl_helper_methods/to_expanded_format_for_cpu.rs} (84%) delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/category.rs delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/helper_method.rs delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/matcher.rs delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/method_name.rs delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/mod.rs delete mode 100644 bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/run.rs delete mode 100644 bevy_gpu_compute_macro/tests/tt.rs diff --git a/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs b/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs index f2e3f08..8de42dc 100644 --- a/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs +++ b/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs @@ -82,9 +82,11 @@ mod tests { use crate::{ IterSpaceDimmension, - wgsl::shader_custom_type_name::ShaderCustomTypeName, - wgsl::shader_module::complete_shader_module::WgslShaderModule, - wgsl::shader_sections::{WgslInputArray, WgslOutputArray, WgslShaderModuleSectionCode}, + wgsl::{ + shader_custom_type_name::ShaderCustomTypeName, + shader_module::complete_shader_module::WgslShaderModule, + shader_sections::{WgslInputArray, WgslOutputArray, WgslShaderModuleSectionCode}, + }, }; use super::*; diff --git a/bevy_gpu_compute_macro/src/lib.rs b/bevy_gpu_compute_macro/src/lib.rs index 02eb219..c01a4d5 100644 --- a/bevy_gpu_compute_macro/src/lib.rs +++ b/bevy_gpu_compute_macro/src/lib.rs @@ -8,7 +8,6 @@ use transformer::{ custom_types::get_all_custom_types::get_custom_types, module_parser::lib::parse_shader_module, output::produce_expanded_output, remove_doc_comments::DocCommentRemover, transform_wgsl_helper_methods::run::transform_wgsl_helper_methods, - transform_wgsl_helper_methods_for_cpu::run::transform_wgsl_helper_methods_for_cpu, }; mod state; mod transformer; @@ -71,8 +70,8 @@ pub fn wgsl_shader_module(_attr: TokenStream, item: TokenStream) -> TokenStream DocCommentRemover {}.visit_item_mod(&module); let mut state = ModuleTransformState::empty(module, content); get_custom_types(&mut state); - transform_wgsl_helper_methods(&mut state); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); parse_shader_module(&mut state); let output = produce_expanded_output(&mut state); output.into() diff --git a/bevy_gpu_compute_macro/src/transformer/mod.rs b/bevy_gpu_compute_macro/src/transformer/mod.rs index a212ce5..ccc2f12 100644 --- a/bevy_gpu_compute_macro/src/transformer/mod.rs +++ b/bevy_gpu_compute_macro/src/transformer/mod.rs @@ -5,4 +5,3 @@ pub mod output; pub mod remove_doc_comments; pub mod to_wgsl_syntax; pub mod transform_wgsl_helper_methods; -pub mod transform_wgsl_helper_methods_for_cpu; diff --git a/bevy_gpu_compute_macro/src/transformer/module_parser/lib.rs b/bevy_gpu_compute_macro/src/transformer/module_parser/lib.rs index 836e2b1..d65fd8a 100644 --- a/bevy_gpu_compute_macro/src/transformer/module_parser/lib.rs +++ b/bevy_gpu_compute_macro/src/transformer/module_parser/lib.rs @@ -6,7 +6,7 @@ use crate::state::ModuleTransformState; use super::constants::find_constants; use super::divide_custom_types::divide_custom_types_by_category; use super::helper_functions::find_helper_functions; -use super::main_function::find_main_function; +use super::main_function::parse_main_function; use super::use_statements::handle_use_statements; use super::validate_no_global_id_assignments::check_module_for_global_id_assignment; @@ -17,7 +17,7 @@ pub fn parse_shader_module(state: &mut ModuleTransformState) { "Shader module must have a body" ); } - find_main_function(state); + parse_main_function(state); handle_use_statements(state); state.module_ident = Some(state.rust_module.ident.to_string()); state.module_visibility = Some(state.rust_module.vis.to_token_stream().to_string()); diff --git a/bevy_gpu_compute_macro/src/transformer/module_parser/main_function.rs b/bevy_gpu_compute_macro/src/transformer/module_parser/main_function.rs index f9a7c11..457297e 100644 --- a/bevy_gpu_compute_macro/src/transformer/module_parser/main_function.rs +++ b/bevy_gpu_compute_macro/src/transformer/module_parser/main_function.rs @@ -5,7 +5,7 @@ use proc_macro_error::abort; use quote::ToTokens; use syn::{ItemFn, spanned::Spanned, visit::Visit}; -pub fn find_main_function(state: &mut ModuleTransformState) { +pub fn parse_main_function(state: &mut ModuleTransformState) { let module = state.rust_module.clone(); let mut extractor = MainFunctionsExtractor::new(state, false); extractor.visit_item_mod(&module); diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/erroneous_usage_finder.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/erroneous_usage_finder.rs new file mode 100644 index 0000000..18ae21c --- /dev/null +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/erroneous_usage_finder.rs @@ -0,0 +1,39 @@ +use syn::{ + Expr, + visit::{self, Visit}, +}; + +use crate::transformer::custom_types::custom_type::CustomType; + +use super::parse::parse_possible_wgsl_helper; + +pub struct ErroneousUsageFinder { + custom_types: Vec, +} +impl ErroneousUsageFinder { + pub fn new(custom_types: &[CustomType]) -> Self { + Self { + custom_types: custom_types.to_vec(), + } + } +} +impl Visit<'_> for ErroneousUsageFinder { + /// This error message relies on `WgslVecInput` being in `bevy_gpu_compute_core::wgsl_helpers` + /** + ```rust + // ensure that the crate structure is what we expect, otherwise the error message will be incorrect + use bevy_gpu_compute_core::wgsl_helpers::WgslVecInput; + ``` + */ + fn visit_expr(&mut self, expr: &Expr) { + visit::visit_expr(self, expr); + if let Expr::Call(call) = expr { + let helper_method = parse_possible_wgsl_helper(call, &self.custom_types); + if helper_method.is_some() { + panic!( + "WGSL Helpers (`bevy_gpu_compute_core::wgsl_helpers`) not allowed outside of functions." + ); + } + } + } +} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/helper_method.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/helper_method.rs index bbc7617..47114b6 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/helper_method.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/helper_method.rs @@ -7,11 +7,11 @@ use super::{ to_expanded_format::ToExpandedFormatMethodKind, }; -pub struct WgslHelperMethod<'a> { +pub struct WgslHelperMethod { pub category: WgslHelperCategory, pub method: WgslHelperMethodName, - pub t_def: &'a CustomType, - pub arg1: Option<&'a Expr>, - pub arg2: Option<&'a Expr>, + pub t_def: CustomType, + pub arg1: Option, + pub arg2: Option, pub method_expander_kind: Option, } diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/mod.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/mod.rs index 82eff9c..14894a6 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/mod.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/mod.rs @@ -1,7 +1,11 @@ pub mod category; +mod erroneous_usage_finder; pub mod helper_method; pub mod matcher; pub mod method_name; +mod parse; pub mod run; pub mod test; +mod test_for_cpu; pub mod to_expanded_format; +mod to_expanded_format_for_cpu; diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/parse.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/parse.rs new file mode 100644 index 0000000..d600c63 --- /dev/null +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/parse.rs @@ -0,0 +1,73 @@ +use syn::{Expr, ExprCall, GenericArgument, PathArguments, Type}; + +use crate::transformer::{ + custom_types::custom_type::CustomType, + transform_wgsl_helper_methods::helper_method::WgslHelperMethod, +}; + +use super::{ + category::WgslHelperCategory, matcher::WgslHelperMethodMatcher, + method_name::WgslHelperMethodName, +}; +fn get_special_function_category(call: &ExprCall) -> Option { + if let Expr::Path(path) = &*call.func { + if let Some(first_seg) = path.path.segments.first() { + return WgslHelperCategory::from_ident(first_seg.ident.clone()); + } + } + None +} +fn get_special_function_method(call: &ExprCall) -> Option { + if let Expr::Path(path) = &*call.func { + if let Some(last_seg) = path.path.segments.last() { + return WgslHelperMethodName::from_ident(last_seg.ident.clone()); + } + } + None +} +fn get_special_function_generic_type<'a>( + call: &'a ExprCall, + custom_types: &'a [CustomType], +) -> Option<&'a CustomType> { + if let Expr::Path(path) = &*call.func { + if let Some(last_seg) = path.path.segments.last() { + if let PathArguments::AngleBracketed(args) = &last_seg.arguments { + if let Some(GenericArgument::Type(Type::Path(type_path))) = args.args.first() { + if let Some(last_seg) = type_path.path.segments.last() { + return custom_types.iter().find(|t| t.name.eq(&last_seg.ident)); + } + } + } + } + } + None +} + +pub fn parse_possible_wgsl_helper<'a>( + call: &'a ExprCall, + custom_types: &'a [CustomType], +) -> Option { + let category = get_special_function_category(call); + let method = get_special_function_method(call); + let type_name = get_special_function_generic_type(call, custom_types); + if let Some(cat) = category { + if let Some(met) = method { + if let Some(ty) = type_name { + let args = call.args.clone(); + let mut method = WgslHelperMethod { + category: cat, + method: met, + t_def: ty.clone(), + arg1: args.first().cloned(), + arg2: args.get(1).cloned(), + method_expander_kind: None, + }; + WgslHelperMethodMatcher::choose_expand_format(&mut method); + if method.method_expander_kind.is_some() { + return Some(method); + } + } + } + } + None +} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/run.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/run.rs index a4094ab..d16eb16 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/run.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/run.rs @@ -1,123 +1,110 @@ -use proc_macro_error::abort; use proc_macro2::TokenStream; use syn::{ - Expr, ExprCall, GenericArgument, PathArguments, Type, parse_quote, + Expr, ItemFn, ItemMod, parse_quote, + visit::Visit, visit_mut::{self, VisitMut}, }; -use crate::{ - state::ModuleTransformState, - transformer::{ - custom_types::custom_type::CustomType, - transform_wgsl_helper_methods::{ - helper_method::WgslHelperMethod, to_expanded_format::ToExpandedFormat, - }, +use crate::transformer::{ + custom_types::custom_type::CustomType, + transform_wgsl_helper_methods::{ + helper_method::WgslHelperMethod, to_expanded_format::ToExpandedFormat, }, }; use super::{ - category::WgslHelperCategory, matcher::WgslHelperMethodMatcher, - method_name::WgslHelperMethodName, + erroneous_usage_finder::ErroneousUsageFinder, parse::parse_possible_wgsl_helper, + to_expanded_format_for_cpu::ToExpandedFormatForCpu, }; -fn get_special_function_category(call: &ExprCall) -> Option { - if let Expr::Path(path) = &*call.func { - if let Some(first_seg) = path.path.segments.first() { - return WgslHelperCategory::from_ident(first_seg.ident.clone()); - } - } - None -} -fn get_special_function_method(call: &ExprCall) -> Option { - if let Expr::Path(path) = &*call.func { - if let Some(last_seg) = path.path.segments.last() { - return WgslHelperMethodName::from_ident(last_seg.ident.clone()); - } - } - None -} -fn get_special_function_generic_type<'a>( - call: &'a ExprCall, - custom_types: &'a [CustomType], -) -> Option<&'a CustomType> { - if let Expr::Path(path) = &*call.func { - if let Some(last_seg) = path.path.segments.last() { - if let PathArguments::AngleBracketed(args) = &last_seg.arguments { - if let Some(GenericArgument::Type(Type::Path(type_path))) = args.args.first() { - if let Some(last_seg) = type_path.path.segments.last() { - return custom_types.iter().find(|t| t.name.eq(&last_seg.ident)); - } - } - } - } - } - None -} - -fn replace(call: ExprCall, custom_types: &[CustomType]) -> Option { - let category = get_special_function_category(&call); - let method = get_special_function_method(&call); - let type_name = get_special_function_generic_type(&call, custom_types); - if let Some(cat) = category { - if let Some(met) = method { - if let Some(ty) = type_name { - let mut method = WgslHelperMethod { - category: cat, - method: met, - t_def: ty, - arg1: call.args.first(), - arg2: call.args.get(1), - method_expander_kind: None, - }; - WgslHelperMethodMatcher::choose_expand_format(&mut method); - if method.method_expander_kind.is_some() { - let t = ToExpandedFormat::run(&method); - return Some(t); - } - } - } +/// Rust's normal type checking will ensure that these helper functions are using correctly defined types +pub fn transform_wgsl_helper_methods( + custom_types: &Option>, + rust_module: &mut ItemMod, + for_cpu: bool, +) { + assert!(custom_types.is_some(), "Allowed types must be defined"); + let custom_types = if let Some(ct) = &custom_types { + ct + } else { + panic!("Allowed types must be set before transforming helper functions"); + }; + let mut converter = WgslHelperExpressionConverter::new(custom_types, for_cpu); + converter.visit_item_mod_mut(rust_module); + if !for_cpu { + let mut error_finder = ErroneousUsageFinder::new(custom_types); + error_finder.visit_item_mod(rust_module); } - None } -struct HelperFunctionConverter { +struct WgslHelperExpressionConverter { custom_types: Vec, + in_main_func: bool, + nesting_level: u32, + for_cpu: bool, } -impl VisitMut for HelperFunctionConverter { +impl VisitMut for WgslHelperExpressionConverter { + fn visit_item_fn_mut(&mut self, node: &mut ItemFn) { + if node.sig.ident == "main" && self.nesting_level == 0 { + self.in_main_func = true; + self.nesting_level += 1; + visit_mut::visit_item_fn_mut(self, node); + self.nesting_level -= 1; + self.in_main_func = false; + } else { + // For any other function, just increment nesting level and continue + self.nesting_level += 1; + visit_mut::visit_item_fn_mut(self, node); + self.nesting_level -= 1; + } + } fn visit_expr_mut(&mut self, expr: &mut Expr) { - visit_mut::visit_expr_mut(self, expr); - if let Expr::Call(call) = expr { - let replacement = replace(call.clone(), &self.custom_types); - if let Some(r) = replacement { - *expr = parse_quote!(#r); + if self.nesting_level > 0 { + let in_main = self.in_main_func && self.nesting_level == 1; + if let Expr::Call(call) = expr { + let helper_method = parse_possible_wgsl_helper(call, &self.custom_types); + if let Some(method) = helper_method { + if self.for_cpu { + let replacement = process_wgsl_helper_for_cpu(method); + *expr = parse_quote!(#replacement); + } else { + let replacement = process_wgsl_helper(method, in_main); + *expr = parse_quote!(#replacement); + } + } } } + // Continue visiting child nodes + visit_mut::visit_expr_mut(self, expr); } } -impl HelperFunctionConverter { - pub fn new(custom_types: &[CustomType]) -> Self { +impl WgslHelperExpressionConverter { + pub fn new(custom_types: &[CustomType], for_cpu: bool) -> Self { Self { custom_types: custom_types.to_vec(), + in_main_func: false, + nesting_level: 0, + for_cpu, } } } -/// Rust's normal type checking will ensure that these helper functions are using correctly defined types -pub fn transform_wgsl_helper_methods(state: &mut ModuleTransformState) { - assert!( - state.custom_types.is_some(), - "Allowed types must be defined" - ); - let custom_types = if let Some(ct) = &state.custom_types { - ct - } else { - abort!( - state.rust_module.ident.span(), - "Allowed types must be set before transforming helper functions" +fn process_wgsl_helper(helper_method: WgslHelperMethod, in_main_func: bool) -> TokenStream { + if !helper_method + .method_expander_kind + .as_ref() + .unwrap() + .valid_outside_main() + && !in_main_func + { + panic!( + "WGSL helpers that read from inputs or write to outputs (`bevy_gpu_compute_core::wgsl_helpers`) can only be used inside the main function. It is technically possible to pass in entire input arrays, configs, or output arrays to helper functions, but considering the performance implications, it is not recommended. Instead interact with your inputs and outputs in the main function and pass in only the necessary data to the helper functions." ); - }; - let mut converter = HelperFunctionConverter::new(custom_types); - converter.visit_item_mod_mut(&mut state.rust_module); + } + ToExpandedFormat::run(&helper_method) +} +fn process_wgsl_helper_for_cpu(helper_method: WgslHelperMethod) -> TokenStream { + ToExpandedFormatForCpu::run(&helper_method) } diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test.rs index 3725a7a..e94d7b0 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test.rs @@ -16,13 +16,12 @@ mod tests { fn test_vec_len() { let input: ItemMod = parse_quote! { mod test { - fn example() { + fn main() { let x = WgslVecInput::vec_len::(); } } }; - let expected_output = - "mod test { fn example () { let x = POSITION_INPUT_ARRAY_LENGTH ; } }"; + let expected_output = "mod test { fn main () { let x = POSITION_INPUT_ARRAY_LENGTH ; } }"; let mut state = ModuleTransformState::empty(input, "".to_string()); let custom_types = vec![CustomType::new( &format_ident!("Position"), @@ -30,7 +29,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); assert_eq!( @@ -41,7 +40,10 @@ mod tests { } #[test] - fn test_vec_val() { + #[should_panic( + expected = "WGSL helpers that read from inputs or write to outputs (`bevy_gpu_compute_core::wgsl_helpers`) can only be used inside the main function. It is technically possible to pass in entire input arrays, configs, or output arrays to helper functions, but considering the performance implications, it is not recommended. Instead interact with your inputs and outputs in the main function and pass in only the necessary data to the helper functions." + )] + fn test_vec_val_only_in_main() { let input: ItemMod = parse_quote! { mod test { @@ -50,7 +52,27 @@ mod tests { } } }; - let expected_output = "mod test { fn example () { let x = radius_input_array [5] ; } }"; + let mut state = ModuleTransformState::empty(input, "".to_string()); + let custom_types = vec![CustomType::new( + &format_ident!("Radius"), + CustomTypeKind::InputArray, + TokenStream::new(), + )]; + state.custom_types = Some(custom_types); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); + } + + #[test] + fn test_vec_val() { + let input: ItemMod = parse_quote! { + mod test { + + fn main() { + let x = WgslVecInput::vec_val::(5); + } + } + }; + let expected_output = "mod test { fn main () { let x = radius_input_array [5] ; } }"; let mut state = ModuleTransformState::empty(input, "".to_string()); let custom_types = vec![CustomType::new( @@ -59,7 +81,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); assert_eq!( @@ -70,7 +92,10 @@ mod tests { } #[test] - fn test_push() { + #[should_panic( + expected = "WGSL helpers that read from inputs or write to outputs (`bevy_gpu_compute_core::wgsl_helpers`) can only be used inside the main function. It is technically possible to pass in entire input arrays, configs, or output arrays to helper functions, but considering the performance implications, it is not recommended. Instead interact with your inputs and outputs in the main function and pass in only the necessary data to the helper functions." + )] + fn test_push_only_in_main() { let input: ItemMod = parse_quote! { mod test { fn example() { @@ -79,7 +104,6 @@ mod tests { } }; - let expected_output = "mod test { fn example () { { let collisionresult_output_array_index = atomicAdd (& collisionresult_counter , 1u) ; if collisionresult_output_array_index < COLLISIONRESULT_OUTPUT_ARRAY_LENGTH { collisionresult_output_array [collisionresult_output_array_index] = value ; } } ; } }"; let mut state = ModuleTransformState::empty(input, "".to_string()); let custom_types = vec![CustomType::new( &format_ident!("CollisionResult"), @@ -87,7 +111,27 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); + } + #[test] + fn test_push() { + let input: ItemMod = parse_quote! { + mod test { + fn main() { + WgslOutput::push::(value); + } + } + }; + + let expected_output = "mod test { fn main () { { let collisionresult_output_array_index = atomicAdd (& collisionresult_counter , 1u) ; if collisionresult_output_array_index < COLLISIONRESULT_OUTPUT_ARRAY_LENGTH { collisionresult_output_array [collisionresult_output_array_index] = value ; } } ; } }"; + let mut state = ModuleTransformState::empty(input, "".to_string()); + let custom_types = vec![CustomType::new( + &format_ident!("CollisionResult"), + CustomTypeKind::OutputVec, + TokenStream::new(), + )]; + state.custom_types = Some(custom_types); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); @@ -117,7 +161,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); @@ -146,7 +190,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); @@ -158,7 +202,10 @@ mod tests { } #[test] - fn test_output_set() { + #[should_panic( + expected = "WGSL helpers that read from inputs or write to outputs (`bevy_gpu_compute_core::wgsl_helpers`) can only be used inside the main function. It is technically possible to pass in entire input arrays, configs, or output arrays to helper functions, but considering the performance implications, it is not recommended. Instead interact with your inputs and outputs in the main function and pass in only the necessary data to the helper functions." + )] + fn test_output_set_not_in_main() { let input: ItemMod = parse_quote! { mod test { fn example() { @@ -166,8 +213,27 @@ mod tests { } } }; + let mut state = ModuleTransformState::empty(input, "".to_string()); + let custom_types = vec![CustomType::new( + &format_ident!("CollisionResult"), + CustomTypeKind::OutputArray, + TokenStream::new(), + )]; + state.custom_types = Some(custom_types); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); + } + #[test] + + fn test_output_set() { + let input: ItemMod = parse_quote! { + mod test { + fn main() { + WgslOutput::set::(idx, val); + } + } + }; let expected_output = - "mod test { fn example () { collisionresult_output_array [idx] = val ; } }"; + "mod test { fn main () { collisionresult_output_array [idx] = val ; } }"; let mut state = ModuleTransformState::empty(input, "".to_string()); let custom_types = vec![CustomType::new( @@ -176,7 +242,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); @@ -187,7 +253,10 @@ mod tests { ); } #[test] - fn test_config_get() { + #[should_panic( + expected = "WGSL helpers that read from inputs or write to outputs (`bevy_gpu_compute_core::wgsl_helpers`) can only be used inside the main function. It is technically possible to pass in entire input arrays, configs, or output arrays to helper functions, but considering the performance implications, it is not recommended. Instead interact with your inputs and outputs in the main function and pass in only the necessary data to the helper functions." + )] + fn test_config_get_outside_main() { let input: ItemMod = parse_quote! { mod test { fn example() { @@ -195,7 +264,6 @@ mod tests { } } }; - let expected_output = "mod test { fn example () { let t = position ; } }"; let mut state = ModuleTransformState::empty(input, "".to_string()); let custom_types = vec![CustomType::new( @@ -204,7 +272,27 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); + } + #[test] + fn test_config_get() { + let input: ItemMod = parse_quote! { + mod test { + fn main() { + let t = WgslConfigInput::get::(); + } + } + }; + let expected_output = "mod test { fn main () { let t = position ; } }"; + + let mut state = ModuleTransformState::empty(input, "".to_string()); + let custom_types = vec![CustomType::new( + &format_ident!("Position"), + CustomTypeKind::Uniform, + TokenStream::new(), + )]; + state.custom_types = Some(custom_types); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module, false); let result = state.rust_module.to_token_stream().to_string(); println!("{}", result); diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/test.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test_for_cpu.rs similarity index 89% rename from bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/test.rs rename to bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test_for_cpu.rs index 356b681..be1a491 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/test.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/test_for_cpu.rs @@ -4,7 +4,7 @@ mod tests { state::ModuleTransformState, transformer::{ custom_types::custom_type::{CustomType, CustomTypeKind}, - transform_wgsl_helper_methods_for_cpu::run::transform_wgsl_helper_methods_for_cpu, + transform_wgsl_helper_methods::run::transform_wgsl_helper_methods, }, }; @@ -30,7 +30,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); assert_eq!( @@ -60,7 +60,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); assert_eq!( @@ -89,7 +89,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); @@ -119,7 +119,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); @@ -149,7 +149,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); @@ -179,7 +179,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); @@ -207,7 +207,7 @@ mod tests { TokenStream::new(), )]; state.custom_types = Some(custom_types); - transform_wgsl_helper_methods_for_cpu(&mut state); + transform_wgsl_helper_methods(&state.custom_types, &mut state.rust_module_for_cpu, true); let result = state.rust_module_for_cpu.to_token_stream().to_string(); println!("{}", result); diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format.rs index b3c9224..b540d3d 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format.rs @@ -12,10 +12,22 @@ pub enum ToExpandedFormatMethodKind { InputVal, OutputPush, OutputLen, - OutputMaxLen, OutputSet, } +impl ToExpandedFormatMethodKind { + pub fn valid_outside_main(&self) -> bool { + match self { + ToExpandedFormatMethodKind::ConfigGet + | ToExpandedFormatMethodKind::InputVal + | ToExpandedFormatMethodKind::OutputPush + | ToExpandedFormatMethodKind::OutputSet => false, + ToExpandedFormatMethodKind::OutputLen + | ToExpandedFormatMethodKind::OutputMaxLen + | ToExpandedFormatMethodKind::InputLen => true, + } + } +} pub struct ToExpandedFormat {} impl ToExpandedFormat { @@ -32,7 +44,7 @@ impl ToExpandedFormat { } Some(ToExpandedFormatMethodKind::InputVal) => { let name = method.t_def.name.input_array(); - let index = if let Some(a1) = method.arg1 { + let index = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for input value method") @@ -42,12 +54,12 @@ impl ToExpandedFormat { } } Some(ToExpandedFormatMethodKind::OutputPush) => { - let t_def = method.t_def; + let t_def = &method.t_def; let counter = t_def.name.counter(); let arr = t_def.name.output_array(); let len = t_def.name.output_array_length(); let index = t_def.name.index(); - let value = if let Some(a1) = method.arg1 { + let value = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for output push method") @@ -73,12 +85,12 @@ impl ToExpandedFormat { } Some(ToExpandedFormatMethodKind::OutputSet) => { let arr = method.t_def.name.output_array().to_token_stream(); - let index = if let Some(a1) = method.arg1 { + let index = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for output set method") }; - let value = if let Some(a2) = method.arg2 { + let value = if let Some(a2) = &method.arg2 { a2 } else { abort!(Span::call_site(), "arg2 is None for output set method") diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/to_expanded_format.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format_for_cpu.rs similarity index 84% rename from bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/to_expanded_format.rs rename to bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format_for_cpu.rs index f16d8d0..c38167c 100644 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/to_expanded_format.rs +++ b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods/to_expanded_format_for_cpu.rs @@ -5,20 +5,10 @@ use quote::ToTokens; use quote::quote; use super::helper_method::WgslHelperMethod; +use super::to_expanded_format::ToExpandedFormatMethodKind; -pub enum ToExpandedFormatMethodKind { - ConfigGet, - InputLen, - InputVal, - OutputPush, - OutputLen, - - OutputMaxLen, - OutputSet, -} - -pub struct ToExpandedFormat {} -impl ToExpandedFormat { +pub struct ToExpandedFormatForCpu {} +impl ToExpandedFormatForCpu { pub fn run(method: &WgslHelperMethod) -> TokenStream { match method.method_expander_kind { Some(ToExpandedFormatMethodKind::ConfigGet) => { @@ -35,7 +25,7 @@ impl ToExpandedFormat { } Some(ToExpandedFormatMethodKind::InputVal) => { let name = method.t_def.name.input_array(); - let index = if let Some(a1) = method.arg1 { + let index = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for input value method") @@ -45,9 +35,9 @@ impl ToExpandedFormat { } } Some(ToExpandedFormatMethodKind::OutputPush) => { - let t_def = method.t_def; + let t_def = &method.t_def; let arr = t_def.name.output_array(); - let value = if let Some(a1) = method.arg1 { + let value = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for output push method") @@ -69,12 +59,12 @@ impl ToExpandedFormat { } Some(ToExpandedFormatMethodKind::OutputSet) => { let arr = method.t_def.name.output_array().to_token_stream(); - let index = if let Some(a1) = method.arg1 { + let index = if let Some(a1) = &method.arg1 { a1 } else { abort!(Span::call_site(), "arg1 is None for output set method") }; - let value = if let Some(a2) = method.arg2 { + let value = if let Some(a2) = &method.arg2 { a2 } else { abort!(Span::call_site(), "arg2 is None for output set method") diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/category.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/category.rs deleted file mode 100644 index caa24f2..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/category.rs +++ /dev/null @@ -1,19 +0,0 @@ -use syn::Ident; - -pub enum WgslHelperCategory { - VecInput, - Output, - ConfigInput, - _Invalid, -} -// from ident -impl WgslHelperCategory { - pub fn from_ident(ident: Ident) -> Option { - match ident.to_string().as_str() { - "WgslVecInput" => Some(WgslHelperCategory::VecInput), - "WgslOutput" => Some(WgslHelperCategory::Output), - "WgslConfigInput" => Some(WgslHelperCategory::ConfigInput), - _ => None, - } - } -} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/helper_method.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/helper_method.rs deleted file mode 100644 index bbc7617..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/helper_method.rs +++ /dev/null @@ -1,17 +0,0 @@ -use syn::Expr; - -use crate::transformer::custom_types::custom_type::CustomType; - -use super::{ - category::WgslHelperCategory, method_name::WgslHelperMethodName, - to_expanded_format::ToExpandedFormatMethodKind, -}; - -pub struct WgslHelperMethod<'a> { - pub category: WgslHelperCategory, - pub method: WgslHelperMethodName, - pub t_def: &'a CustomType, - pub arg1: Option<&'a Expr>, - pub arg2: Option<&'a Expr>, - pub method_expander_kind: Option, -} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/matcher.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/matcher.rs deleted file mode 100644 index 612accc..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/matcher.rs +++ /dev/null @@ -1,96 +0,0 @@ -use crate::transformer::{ - custom_types::custom_type::CustomTypeKind, - transform_wgsl_helper_methods_for_cpu::to_expanded_format::ToExpandedFormatMethodKind, -}; - -use super::{ - category::WgslHelperCategory, helper_method::WgslHelperMethod, - method_name::WgslHelperMethodName, -}; - -pub struct WgslHelperMethodMatcher {} -impl WgslHelperMethodMatcher { - pub fn choose_expand_format(method: &mut WgslHelperMethod) { - match (&method.category, &method.method) { - (WgslHelperCategory::ConfigInput, WgslHelperMethodName::Get) => { - assert!( - method.t_def.kind == CustomTypeKind::Uniform, - "Expected {} to be an input config type, since WgslConfigInput::get is called, instead found it was of type {:?}. Put #[wgsl_config] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both a config and a input array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::ConfigGet); - } - (WgslHelperCategory::VecInput, WgslHelperMethodName::VecLen) => { - assert!( - method.t_def.kind == CustomTypeKind::InputArray, - "Expected {} to be an input array type, since WgslVecInput::vec_len is called, instead found it was of type {:?}. Put #[wgsl_input_array] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::InputLen); - } - (WgslHelperCategory::VecInput, WgslHelperMethodName::VecVal) => { - assert!( - method.arg1.is_some(), - "Expected an argument for input vec value getter" - ); - assert!( - method.t_def.kind == CustomTypeKind::InputArray, - "Expected {} to be an input array type, since WgslVecInput::vec_val is called, instead found it was of type {:?}. Put #[wgsl_input_array] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::InputVal); - } - (WgslHelperCategory::Output, WgslHelperMethodName::Push) => { - assert!( - method.arg1.is_some(), - "Expected an argument for output push" - ); - assert!( - method.t_def.kind == CustomTypeKind::OutputVec, - "Expected {} to be an output vec type, since WgslOutput::push is called, instead found it was of type {:?}. Put #[wgsl_output_vec] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array, an output cannot be both a vec and an array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::OutputPush); - } - (WgslHelperCategory::Output, WgslHelperMethodName::MaxLen) => { - assert!( - method.t_def.kind == CustomTypeKind::OutputArray - || method.t_def.kind == CustomTypeKind::OutputVec, - "Expected {} to be an output array or vec type, since WgslOutput::max_len is called, instead found it was of type {:?}. Put #[wgsl_output_array] or #[wgsl_output_vec] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::OutputMaxLen); - } - (WgslHelperCategory::Output, WgslHelperMethodName::Len) => { - assert!( - method.t_def.kind == CustomTypeKind::OutputVec, - "Expected {} to be an output vec type, since WgslOutput::len is called, instead found it was of type {:?}. Put #[wgsl_output_vec] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array.", - method.t_def.name.name, - method.t_def.kind - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::OutputLen); - } - (WgslHelperCategory::Output, WgslHelperMethodName::Set) => { - assert!( - method.t_def.kind == CustomTypeKind::OutputArray, - "Expected {} to be an output array type, since WgslOutput::set is called, instead found it was of type {:?}. Put #[wgsl_output_array] above your type declaration to fix this. A given type cannot be used for multiple purposes, for example a type T cannot be both an input array and an output array.", - method.t_def.name.name, - method.t_def.kind - ); - assert!( - method.arg1.is_some() && method.arg2.is_some(), - "Expected two arguments for output set" - ); - method.method_expander_kind = Some(ToExpandedFormatMethodKind::OutputSet); - } - _ => { - method.method_expander_kind = None; - } - } - } -} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/method_name.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/method_name.rs deleted file mode 100644 index 2875406..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/method_name.rs +++ /dev/null @@ -1,26 +0,0 @@ -use syn::Ident; - -pub enum WgslHelperMethodName { - VecLen, - VecVal, - Push, - Len, - MaxLen, - Set, - Get, - _Invalid, -} -impl WgslHelperMethodName { - pub fn from_ident(ident: Ident) -> Option { - match ident.to_string().as_str() { - "vec_len" => Some(WgslHelperMethodName::VecLen), - "vec_val" => Some(WgslHelperMethodName::VecVal), - "push" => Some(WgslHelperMethodName::Push), - "len" => Some(WgslHelperMethodName::Len), - "max_len" => Some(WgslHelperMethodName::MaxLen), - "set" => Some(WgslHelperMethodName::Set), - "get" => Some(WgslHelperMethodName::Get), - _ => None, - } - } -} diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/mod.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/mod.rs deleted file mode 100644 index 82eff9c..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod category; -pub mod helper_method; -pub mod matcher; -pub mod method_name; -pub mod run; -pub mod test; -pub mod to_expanded_format; diff --git a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/run.rs b/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/run.rs deleted file mode 100644 index b6105ef..0000000 --- a/bevy_gpu_compute_macro/src/transformer/transform_wgsl_helper_methods_for_cpu/run.rs +++ /dev/null @@ -1,123 +0,0 @@ -use proc_macro_error::abort; -use proc_macro2::TokenStream; - -use syn::{ - Expr, ExprCall, GenericArgument, PathArguments, Type, parse_quote, - visit_mut::{self, VisitMut}, -}; - -use crate::{ - state::ModuleTransformState, - transformer::{ - custom_types::custom_type::CustomType, - transform_wgsl_helper_methods_for_cpu::{ - helper_method::WgslHelperMethod, to_expanded_format::ToExpandedFormat, - }, - }, -}; - -use super::{ - category::WgslHelperCategory, matcher::WgslHelperMethodMatcher, - method_name::WgslHelperMethodName, -}; - -fn get_special_function_category(call: &ExprCall) -> Option { - if let Expr::Path(path) = &*call.func { - if let Some(first_seg) = path.path.segments.first() { - return WgslHelperCategory::from_ident(first_seg.ident.clone()); - } - } - None -} -fn get_special_function_method(call: &ExprCall) -> Option { - if let Expr::Path(path) = &*call.func { - if let Some(last_seg) = path.path.segments.last() { - return WgslHelperMethodName::from_ident(last_seg.ident.clone()); - } - } - None -} -fn get_special_function_generic_type<'a>( - call: &'a ExprCall, - custom_types: &'a [CustomType], -) -> Option<&'a CustomType> { - if let Expr::Path(path) = &*call.func { - if let Some(last_seg) = path.path.segments.last() { - if let PathArguments::AngleBracketed(args) = &last_seg.arguments { - if let Some(GenericArgument::Type(Type::Path(type_path))) = args.args.first() { - if let Some(last_seg) = type_path.path.segments.last() { - return custom_types.iter().find(|t| t.name.eq(&last_seg.ident)); - } - } - } - } - } - None -} - -fn replace(call: ExprCall, custom_types: &[CustomType]) -> Option { - let category = get_special_function_category(&call); - let method = get_special_function_method(&call); - let type_name = get_special_function_generic_type(&call, custom_types); - if let Some(cat) = category { - if let Some(met) = method { - if let Some(ty) = type_name { - let mut method = WgslHelperMethod { - category: cat, - method: met, - t_def: ty, - arg1: call.args.first(), - arg2: call.args.get(1), - method_expander_kind: None, - }; - WgslHelperMethodMatcher::choose_expand_format(&mut method); - if method.method_expander_kind.is_some() { - let t = ToExpandedFormat::run(&method); - return Some(t); - } - } - } - } - None -} - -struct HelperFunctionConverterForCpu { - custom_types: Vec, -} - -impl VisitMut for HelperFunctionConverterForCpu { - fn visit_expr_mut(&mut self, expr: &mut Expr) { - visit_mut::visit_expr_mut(self, expr); - if let Expr::Call(call) = expr { - let replacement = replace(call.clone(), &self.custom_types); - if let Some(r) = replacement { - *expr = parse_quote!(#r); - } - } - } -} -impl HelperFunctionConverterForCpu { - pub fn new(custom_types: &[CustomType]) -> Self { - Self { - custom_types: custom_types.to_vec(), - } - } -} - -/// Rust's normal type checking will ensure that these helper functions are using correctly defined types -pub fn transform_wgsl_helper_methods_for_cpu(state: &mut ModuleTransformState) { - assert!( - state.custom_types.is_some(), - "Allowed types must be defined" - ); - let custom_types = if let Some(ct) = &state.custom_types { - ct - } else { - abort!( - state.rust_module_for_cpu.ident.span(), - "Allowed types must be set before transforming helper functions" - ); - }; - let mut converter = HelperFunctionConverterForCpu::new(custom_types); - converter.visit_item_mod_mut(&mut state.rust_module_for_cpu); -} diff --git a/bevy_gpu_compute_macro/tests/tt.rs b/bevy_gpu_compute_macro/tests/tt.rs deleted file mode 100644 index 77b1664..0000000 --- a/bevy_gpu_compute_macro/tests/tt.rs +++ /dev/null @@ -1,141 +0,0 @@ -// pub fn parsed() -> WgslShaderModuleUserPortion { -// WgslShaderModuleUserPortion { -// static_consts: [ -// WgslConstAssignment { -// code: WgslShaderModuleSectionCode { -// rust_code: ("const example_module_const : u32 = 42;") -// .to_string(), -// wgsl_code: ("const example_module_const : u32 = 42;") -// .to_string(), -// }, -// }, -// ] -// .into(), -// helper_types: [].into(), -// uniforms: Vec::from([ -// WgslType { -// name: ShaderCustomTypeName::new("Config"), -// code: WgslShaderModuleSectionCode { -// rust_code: ("#[wgsl_config] struct Config { pub example_value : f32, }") -// .to_string(), -// wgsl_code: ("struct Config { example_value : f32, }") -// .to_string(), -// }, -// }, -// ]), -// input_arrays: [ -// WgslInputArray { -// item_type: WgslType { -// name: ShaderCustomTypeName::new("Position"), -// code: WgslShaderModuleSectionCode { -// rust_code: ("#[wgsl_input_array] type Position = [f32; 2];") -// .to_string(), -// wgsl_code: ("alias Position = array < f32, 2 > ;") -// .to_string(), -// }, -// }, -// }, -// WgslInputArray { -// item_type: WgslType { -// name: ShaderCustomTypeName::new("Radius"), -// code: WgslShaderModuleSectionCode { -// rust_code: ("#[wgsl_input_array] type Radius = f32;") -// .to_string(), -// wgsl_code: ("alias Radius = f32;").to_string(), -// }, -// }, -// }, -// ] -// .into(), -// output_arrays: [ -// WgslOutputArray { -// item_type: WgslType { -// name: ShaderCustomTypeName::new("CollisionResult"), -// code: WgslShaderModuleSectionCode { -// rust_code: ("#[wgsl_output_vec] struct CollisionResult { entity1 : u32, entity2 : u32, }") -// .to_string(), -// wgsl_code: ("struct CollisionResult { entity1 : u32, entity2 : u32, }") -// .to_string(), -// }, -// }, -// atomic_counter_name: Some("collisionresult_counter".to_string()), -// }, -// ] -// .into(), -// helper_functions: [ -// WgslFunction { -// name: ("calculate_distance_squared").to_string(), -// code: WgslShaderModuleSectionCode { -// rust_code: ("pub fn calculate_distance_squared(p1 : [f32; 2], p2 : [f32; 2]) -> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}") -// .to_string(), -// wgsl_code: ("pub fn\ncalculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >) ->\nf32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}") -// .to_string(), -// }, -// }, -// ] -// .into(), -// main_function: Some(WgslFunction { -// name: ("main").to_string(), -// code: WgslShaderModuleSectionCode { -// rust_code: ("pub fn main(iter_pos : WgslIterationPosition)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity] + example_module_const as f32; let\n other_radius = radius_input_array [other_entity] + config.example_value as\n f32; if current_radius <= 0.0 || other_radius <= 0.0 { return; } let\n current_pos = position_input_array [current_entity]; let other_pos =\n position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}") -// .to_string(), -// wgsl_code: ("pub fn main(@builtin(global_invocation_id) iter_pos: vec3)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity] + f32(example_module_const); let\n other_radius = radius_input_array [other_entity] +\n f32(config.example_value); if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] =\n CollisionResult(current_entity, other_entity);\n }\n };\n }\n}") -// .to_string(), -// }, -// }), -// binding_numbers_by_variable_name: Some( -// HashMap::from([ -// ("radius_input_array".to_string(), 3u32), -// ("collisionresult_counter".to_string(), 5u32), -// ("config".to_string(), 1u32), -// ("position_input_array".to_string(), 2u32), -// ("collisionresult_output_array".to_string(), 4u32), -// ]), -// ), -// } -// } -// pub mod on_cpu { -// use super::*; -// use bevy_gpu_compute_core::wgsl_helpers::*; -// const example_module_const: u32 = 42; -// pub fn calculate_distance_squared(p1: [f32; 2], p2: [f32; 2]) -> f32 { -// let dx = p1[0] - p2[0]; -// let dy = p1[1] - p2[1]; -// return dx * dx + dy * dy; -// } -// pub fn main( -// iter_pos: WgslIterationPosition, -// config: Config, -// position_input_array: Vec, -// radius_input_array: Vec, -// mut collisionresult_output_array: &mut Vec, -// ) { -// let COLLISIONRESULT_OUTPUT_ARRAY_LENGTH = collisionresult_output_array.len(); -// let RADIUS_INPUT_ARRAY_LENGTH = radius_input_array.len(); -// let POSITION_INPUT_ARRAY_LENGTH = position_input_array.len(); -// let current_entity = iter_pos.x; -// let other_entity = iter_pos.y; -// if current_entity >= position_input_array.len() -// || other_entity >= position_input_array.len() -// || current_entity == other_entity -// || current_entity >= other_entity -// { -// return; -// } -// let current_radius = radius_input_array[current_entity] + example_module_const as f32; -// let other_radius = radius_input_array[other_entity] + config.example_value as f32; -// if current_radius <= 0.0 || other_radius <= 0.0 { -// return; -// } -// let current_pos = position_input_array[current_entity]; -// let other_pos = position_input_array[other_entity]; -// let dist_squared = calculate_distance_squared(current_pos, other_pos); -// let radius_sum = current_radius + other_radius; -// if dist_squared < radius_sum * radius_sum { -// collisionresult_output_array.push(CollisionResult { -// entity1: current_entity, -// entity2: other_entity, -// }); -// } -// } -// }