Zq/detach compile #179
base: master
Conversation
Summary of Changes
Hello @chonpsk, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant refactoring to the ZK-CUDA context, decoupling the compilation of the computation graph from the concrete input data. This allows for a more flexible workflow where memory is allocated and its shape defined upfront, while the actual input values are provided later, specifically before the witness generation phase. This change enhances the system's ability to compile circuits independently of specific data, enabling the same compiled circuit to be used with varying inputs for proof generation.
Highlights
- Decoupled Data Handling: The `Context` now provides a `new_device_memory` function to allocate memory and define its shape without immediately copying data. This separates memory allocation from data population (see the workflow sketch after this list).
- Deferred Data Copying: The `copy_to_device` function has been updated to populate pre-allocated device memory. Its usage across the codebase has been shifted to occur after the `compile_computation_graph` call, but before `solve_witness`, allowing the computation graph to be compiled independently of concrete input values.
- Relocated Kernel Evaluation: The core logic for evaluating kernel primitives, which was previously executed within the `call_kernel!` macro's underlying implementation, has been moved to the `solve_witness` method. This ensures that kernel computations are performed when concrete witness data is available.
- Test Case Adaptations: All relevant test cases and examples have been updated to reflect this new data handling paradigm, ensuring compatibility and correctness with the detached compilation flow.
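A minimal end-to-end sketch of the new flow, assembled from the API calls visible in this PR (the kernel name and host values are illustrative, not taken verbatim from the changes):

```rust
// 1. Allocate device memory and fix its shape up front; no data is
//    copied yet.
let (a, a_id) = ctx.new_device_memory(vec![16, 2]);

// 2. Record the computation against the handles.
let mut b = None;
call_kernel!(ctx, kernel_add_16, 1, a, mut b).unwrap();

// 3. Compile the computation graph without any concrete inputs.
let computation_graph = ctx.compile_computation_graph().unwrap();

// 4. Populate the pre-allocated memory with real values only now...
ctx.copy_to_device(&host_values, a_id);

// 5. ...then solve the witness; kernel primitives are evaluated here.
ctx.solve_witness().unwrap();
```

The same compiled `computation_graph` can then be reused with different inputs for successive proofs.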
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
¹ Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request refactors the ZK-CUDA context to detach the compilation of the computation graph from the witness generation. This is a significant architectural improvement that allows compiling a circuit once and then running it with different inputs. The changes look solid and the new API is applied consistently across the tests. I've found a few minor issues:
- A typo in an assertion message.
- Some commented-out code that should be removed.
- An opportunity to refactor duplicated logic in the `solve_witness` function to improve maintainability.
- A test with a large amount of commented-out code that should be cleaned up.
```rust
// let handle = make_device_mem(
//     &mut self.device_memories,
//     ov,
//     shape_prepend(shape, num_parallel),
// );
// let id = handle.as_ref().unwrap().id;
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
```
This block of commented-out code should be removed to improve code clarity and maintainability.
Suggested change:

```diff
-// let handle = make_device_mem(
-//     &mut self.device_memories,
-//     ov,
-//     shape_prepend(shape, num_parallel),
-// );
-// let id = handle.as_ref().unwrap().id;
 let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
```
```rust
self.state = ContextState::WitnessDone;

for kernel_call in self.kernel_calls.iter() {
    let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
```
```rust
for kernel_call in self.kernel_calls.iter() {
    let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
    let num_parallel = kernel_call.num_parallel;
    let is_broadcast = &kernel_call.is_broadcast;

    let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
    let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
    for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles
        .iter()
        .zip(is_broadcast.iter())
        .zip(ir_inputs_all.iter_mut())
        .zip(chunk_sizes.iter_mut())
    {
        if input.is_none() {
            continue;
        }
        let handle = ensure_handle(input.clone());
        let values = handle
            .shape_history
            .permute_vec(&self.device_memories[handle.id].values);
        if !ib {
            *chunk_size = Some(values.len() / num_parallel);
        }
        *ir_inputs = values;
    }
    let mut ir_inputs_per_parallel = Vec::new();
    for parallel_i in 0..num_parallel {
        let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
        for (i, ((input, input_start), input_end)) in kernel_call.input_handles
            .iter()
            .zip(kernel.ir_input_offsets().iter())
            .zip(kernel.ir_input_offsets().iter().skip(1))
            .enumerate()
        {
            if input.is_none() {
                continue;
            }
            self.ir_copy_from_device_memory(
                &ir_inputs_all[i],
                &mut ir_inputs[*input_start..*input_end],
                is_broadcast[i],
                parallel_i,
                chunk_sizes[i],
            );
        }
        ir_inputs_per_parallel.push(ir_inputs);
    }
    let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
        .into_par_iter()
        .map(|ir_inputs| {
            kernel
                .ir_for_calling()
                .eval_safe_simd(ir_inputs, &[], &self.hint_caller)
        })
        .collect();

    let mut outputs_tmp: Vec<Vec<SIMDField<C>>> = vec![Vec::new(); kernel.io_specs().len()];
    for ir_outputs in ir_outputs_per_parallel {
        let ir_outputs = ir_outputs?;
        for (((spec, output_start), output_end), out) in kernel
            .io_specs()
            .iter()
            .zip(kernel.ir_output_offsets().iter())
            .zip(kernel.ir_output_offsets().iter().skip(1))
            .zip(outputs_tmp.iter_mut())
        {
            if !spec.is_output {
                continue;
            }
            out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
        }
    }

    for ((output, spec), ov) in kernel_call.output_handles
        .iter()
        .zip(kernel.io_specs().iter())
        .zip(outputs_tmp.into_iter())
    {
        if !spec.is_output {
            continue;
        }
        let output_id = output.as_ref().unwrap().id;
        self.device_memories[output_id].values = ov;
    }
}
```
The code for kernel execution, which was moved into `solve_witness`, shares much of its logic with the hint-solving part that follows (starting from line 693).
Specifically, the preparation of inputs for parallel execution is duplicated. To improve maintainability and reduce code duplication, consider refactoring the common input-preparation logic into a separate helper function. This would make `solve_witness` cleaner and easier to understand.
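As a rough illustration, the shared preparation could be pulled into a method along these lines; the name `prepare_ir_inputs_per_parallel` and the `KernelCall`/`KernelPrimitive` type names are hypothetical stand-ins, not the crate's actual items:

```rust
// Hypothetical extraction of the duplicated input-preparation logic.
// Both the kernel-execution loop above and the hint-solving code could
// call this instead of repeating the gather/chunk/assemble steps.
fn prepare_ir_inputs_per_parallel(
    &self,
    kernel_call: &KernelCall,     // hypothetical type name
    kernel: &KernelPrimitive,     // hypothetical type name
) -> Vec<Vec<SIMDField<C>>> {
    // 1. Flatten each input's device memory through its shape history.
    // 2. Record chunk sizes for non-broadcast inputs.
    // 3. Assemble one IR input vector per parallel instance via
    //    ir_copy_from_device_memory, exactly as both call sites do today.
    todo!("shared body extracted from solve_witness")
}
```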
```diff
 }
-let a = ctx.copy_to_device(&a);
+let a_shape = vec![16, 2];
+let (a, a_id) = ctx.new_device_memory(a_shape);
```
```rust
// let result: Vec<M31> = ctx.copy_to_host(b);
// assert_eq!(
//     result,
//     vec![
//         M31::from(1),
//         M31::from(0),
//         M31::from(1),
//         M31::from(0),
//         M31::from(1),
//         M31::from(0),
//         M31::from(1),
//         M31::from(0)
//     ]
// );

type P = Expander<M31Config>;
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.solve_witness().unwrap();
let a_value = M31::from(0x55);
ctx.copy_to_device(&a_value, a_id);
// ctx.solve_witness().unwrap();
println!("{:?}", computation_graph);
println!("{:?}", ctx.export_device_memories());
let (prover_setup, verifier_setup) = P::setup(&computation_graph);
let proof = P::prove(
    &prover_setup,
    &computation_graph,
    &ctx.export_device_memories(),
);
assert!(P::verify(&verifier_setup, &computation_graph, &proof));
// println!("{:?}", ctx.export_device_memories());
// let (prover_setup, verifier_setup) = P::setup(&computation_graph);
// let proof = P::prove(
//     &prover_setup,
//     &computation_graph,
//     &ctx.export_device_memories(),
// );
// assert!(P::verify(&verifier_setup, &computation_graph, &proof));
}
```
A significant portion of the `zkcuda_to_binary` test function is commented out. Commented-out code can reduce readability and lead to confusion. If this test is temporarily disabled or a work in progress, it would be better to add an `#[ignore]` attribute with a TODO comment explaining why. If the code is no longer needed, it should be removed.
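For reference, the pattern suggested here would look like the following (the reason string and TODO text are illustrative):

```rust
// TODO: re-enable once the assertions are updated for the detached
// compilation flow.
#[test]
#[ignore = "pending update for deferred witness solving"]
fn zkcuda_to_binary() {
    // ...
}
```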
/gemini review
Code Review
This pull request successfully detaches the compilation of the computation graph from the witness-solving phase by introducing a new API for device memory management. The core logic changes in `expander_compiler/src/zkcuda/context.rs` are sound.
However, this refactoring has broken several tests (`cg_mpi_share.rs`, `zkcuda_keccak.rs`, `zkcuda_matmul.rs`) because their result assertions are no longer valid with the deferred execution model. These tests need to be updated to check results after `solve_witness()` is called.
I've also pointed out some minor issues like typos, commented-out code, and a potential redundant computation in one of the tests.
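The shape of the fix is sketched below; `input_id`, `host_input`, and `out_handle` are illustrative names, not the tests' actual variables:

```rust
// Sketch: under the deferred model, kernels only execute inside
// solve_witness, so results must be read back after that call.
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&host_input, input_id); // populate inputs post-compile
ctx.solve_witness().unwrap();              // kernels actually run here
let out: Vec<Vec<M31>> = ctx.copy_to_host(out_handle);
assert_eq!(out, expected_res);             // assertions now see real data
```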
```rust
let out: Vec<Vec<M31>> = ctx.copy_to_host(out);
println!("copy to host ok");
assert_eq!(out, expected_res);
assert_eq!(out[0][0], expected_res[0][0]);
```
```rust
) {
    assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
    let (flat, shape) = flatten_shaped(host_memory);
    // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match.");
```
This assertion is commented out. If it's no longer necessary, it should be removed to improve code clarity. If it is still needed, it should be fixed and re-enabled.
The current logic, `shape_vec_len(&self.device_memories[device_memory_id].required_shape_products)`, seems incorrect, as `required_shape_products` is already a vector of products, not a shape. A correct check might involve comparing against the last element of `required_shape_products`, which represents the total size.
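If the check is still wanted, a possible fix along the lines this comment suggests (a sketch, assuming the last entry of `required_shape_products` is the total element count) would be:

```rust
// Sketch: compare the flattened host data length against the total
// size implied by the device memory's shape. Assumes the last entry of
// `required_shape_products` is the total element count.
let total = *self.device_memories[device_memory_id]
    .required_shape_products
    .last()
    .expect("device memory has a recorded shape");
assert_eq!(flat.len(), total, "The len of values doesn't match.");
```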
```diff
-) -> DeviceMemoryHandle {
+    device_memory_id: usize,
+) {
     assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist.");
```
There's a typo in the assertion message. It should be "doesn't" instead of "dosen't". This also applies to the assertion on line 260.
Suggested change:

```diff
-assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist.");
+assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
```
```rust
let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
let num_parallel = kernel_call.num_parallel;
let is_broadcast = &kernel_call.is_broadcast;
```
```rust
    }
}
self.state = ContextState::WitnessDone;
```
```diff
-let a = ctx.copy_to_device(&a);
+let a_shape = vec![16, 2];
+let (a, a_id) = ctx.new_device_memory(a_shape);
+// let a = ctx.copy_to_device(&a);
```
```rust
let mut c = None;
call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap();
ctx.copy_simd_to_device(&a_value, a_id);
ctx.solve_result().unwrap();
```
`solve_result()` is called here, and then it's called again within `solve_witness()` in the `zkcuda_test_simd` function that calls `zkcuda_test_simd_prepare_ctx`. This results in redundant computation.
Consider restructuring the test to avoid this double execution. For example, the assertions could be moved to the `zkcuda_test_simd` function after `solve_witness` is called, and this call to `solve_result` could be removed (see the sketch below).
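A possible restructuring, sketched under the assumption that `zkcuda_test_simd_prepare_ctx` only needs to build the context and record kernel calls:

```rust
// Sketch: keep context construction side-effect free, and run the
// computation exactly once in the calling test.
fn zkcuda_test_simd_prepare_ctx() -> Context<M31Config> {
    // build kernels, allocate device memory, record kernel calls;
    // no solve_result() here
    unimplemented!("construction only; no execution")
}

#[test]
fn zkcuda_test_simd() {
    let mut ctx = zkcuda_test_simd_prepare_ctx();
    ctx.compile_computation_graph().unwrap();
    // copy concrete inputs here, e.g. ctx.copy_simd_to_device(&a_value, a_id);
    ctx.solve_witness().unwrap();
    // move the result assertions here, after the witness is solved
}
```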
No description provided.