Hc/broadcast #190
Hc/broadcast #190
Conversation
Summary of Changes
Hello @hczphn, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the system's data broadcasting capabilities by moving from a simple binary (on/off) broadcasting model to a more flexible partial broadcasting mechanism. This allows for fine-grained control over data replication during parallel computations, improving the adaptability of the framework for various computational patterns. The changes are integrated across the core data structures and processing logic, ensuring consistent behavior with the new broadcasting paradigm.
Highlights
- Enhanced Broadcasting Mechanism: The core change introduces support for partial broadcasting by changing the `is_broadcast` flag from a boolean to a `usize` integer. This allows data to be broadcast a fixed number of times (e.g., 2, 3, etc.) instead of just an all-or-nothing approach.
- Type System and Logic Updates: The change from `bool` to `usize` for broadcasting required extensive updates across the codebase, including struct definitions (`KernelCall`, `ProofTemplate`), function signatures, and the underlying logic for data handling, shape compatibility checks, and input preparation within the `zkcuda` context and proving systems.
- Dependency and Test Infrastructure Changes: New dependencies (`stacker`, `psm`) were added to `Cargo.lock` and `expander_compiler/Cargo.toml`. The test suite was updated to reflect the new broadcasting logic, including changes to kernel definitions and the use of `BN254Config` and `ExpanderNoOverSubscribe` in `zkcuda_matmul.rs`.
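The old boolean semantics map onto the new `usize` factor at its two extremes. A minimal sketch of that correspondence (the helper `to_factor` is hypothetical, not from this PR):

```rust
// Hypothetical helper (not from this PR) relating the old bool flag to the
// new usize factor: full broadcast stacks the data parallel_count times,
// no broadcast corresponds to a factor of 1.
fn to_factor(is_broadcast: bool, parallel_count: usize) -> usize {
    if is_broadcast { parallel_count } else { 1 }
}

fn main() {
    assert_eq!(to_factor(true, 8), 8);  // old `true`: full broadcast
    assert_eq!(to_factor(false, 8), 1); // old `false`: no broadcast
    println!("ok");
}
```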
Code Review
This pull request introduces support for partial broadcasting, a significant enhancement. The core logic change from a boolean `is_broadcast` flag to a `usize` factor is implemented across many files. While the changes are generally correct, there are several areas for improvement regarding code clarity, consistency, and robustness. I've identified some potentially incorrect logic involving `next_power_of_two`, an unused function parameter, and several opportunities to remove dead or obscure code to improve maintainability.
```diff
     values: &[SIMDField<C>],
     s: &mut [SIMDField<C>],
-    is_broadcast: bool,
+    is_broadcast: usize,
```
```rust
    let local_val_len = vals.as_ref().len() / parallel_num;
    &vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)]
}
let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
```
The use of `is_broadcast.next_power_of_two()` seems incorrect. It implies that the broadcast factor is rounded up to the next power of two, which could lead to an incorrect `local_val_len` if `is_broadcast` is not a power of two. This is also inconsistent with other parts of the code that use `is_broadcast` directly. Please use `*is_broadcast` directly or clarify the intent with a comment.
```diff
-let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
+let is_broadcast_next_power_of_two = *is_broadcast;
```
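To illustrate the concern (the numbers below are illustrative, not taken from the PR): rounding the broadcast factor up to a power of two changes the divisor used to compute a local length whenever the factor is not already a power of two.

```rust
// Illustration of the next_power_of_two concern with made-up numbers:
// for a non-power-of-two factor, the rounded divisor gives a different
// local length than the factor itself.
fn main() {
    let is_broadcast: usize = 3;
    assert_eq!(is_broadcast.next_power_of_two(), 4); // rounded up

    let total_len = 24;
    assert_eq!(total_len / is_broadcast, 8); // using the factor directly
    assert_eq!(total_len / is_broadcast.next_power_of_two(), 6); // after rounding
    println!("ok");
}
```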
```diff
     (challenge, component_idx_vars)
 } else {
-    let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize;
+    let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize;
```
The use of `broadcast_num.next_power_of_two()` is inconsistent with other parts of the code (e.g., `partition_challenge_and_location_for_pcs_no_mpi`) and may be incorrect. It should likely be `broadcast_num` directly.
```diff
-let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize;
+let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num)).ilog2() as usize;
```
```rust
//TODO: what is challenge.r_mpi, why need it when broadcast is false?
challenge.rz.extend_from_slice(&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize]);
```
```diff
     .unwrap()
     .shape_history
-    .get_initial_split_list(!ib);
+    .get_initial_split_list(ib/num_parallel+1);
```
The expression `ib/num_parallel+1` is functionally correct but obscure. It distinguishes between full broadcast (`ib == num_parallel`) and other cases. Using a more explicit `if` expression would improve readability.
```diff
-.get_initial_split_list(ib/num_parallel+1);
+.get_initial_split_list(if ib == num_parallel { 0 } else { 1 });
```
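A quick check of the integer-division trick discussed here: for `1 <= ib <= num_parallel`, the expression `ib / num_parallel + 1` takes only two values, 2 for a full broadcast (`ib == num_parallel`) and 1 otherwise. Since the callee only tests `split_first_dim == 1`, the suggested explicit encoding is equivalent at that check.

```rust
// Demonstrates that ib / num_parallel + 1 collapses to two values over
// the valid range of ib (1..=num_parallel).
fn obscure(ib: usize, num_parallel: usize) -> usize {
    ib / num_parallel + 1
}

fn main() {
    let num_parallel = 8;
    for ib in 1..=num_parallel {
        let explicit = if ib == num_parallel { 2 } else { 1 };
        assert_eq!(obscure(ib, num_parallel), explicit);
    }
    println!("ok");
}
```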
```rust
.map(|dm| {
    let shape = prefix_products_to_shape(&dm.required_shape_products);
    let im = shape_padded_mapping(&shape);
    let tmp = im.map_inputs(&dm.values);
```
expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs (outdated, resolved)
expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs (resolved)
```rust
pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec<usize> {
    let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1);
```
The function `get_initial_split_list` takes a `usize` for `split_first_dim` and then checks `split_first_dim==1`. This is not idiomatic Rust. It would be clearer to use a `bool` parameter and adjust call sites to pass a boolean, which would make their logic more explicit and readable.
```diff
-pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec<usize> {
-    let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1);
+pub fn get_initial_split_list(&self, split_first_dim: bool) -> Vec<usize> {
+    let last_entry = self.entries.last().unwrap().minimize(split_first_dim);
```
```rust
// #[test]
// fn zkcuda_matmul_sum() {
//     let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap();
//     let kernel_sum_8_elements: KernelPrimitive<M31Config> = compile_sum_8_elements().unwrap();
//
//     let mut ctx: Context<M31Config> = Context::default();
//
//     let mut mat_a: Vec<Vec<M31>> = vec![];
//     for i in 0..64 {
//         mat_a.push(vec![]);
//         for j in 0..32 {
//             mat_a[i].push(M31::from((i * 233 + j + 1) as u32));
//         }
//     }
//     let mut mat_b: Vec<Vec<M31>> = vec![];
//     for i in 0..32 {
//         mat_b.push(vec![]);
//         for j in 0..64 {
//             mat_b[i].push(M31::from((i * 2333 + j + 1) as u32));
//         }
//     }
//     let mut expected_result = M31::zero();
//     for i in 0..64 {
//         for j in 0..64 {
//             for k in 0..32 {
//                 expected_result += mat_a[i][k] * mat_b[k][j];
//             }
//         }
//     }
//
//     let a = ctx.copy_to_device(&mat_a);
//     let b = ctx.copy_to_device(&mat_b);
//     let mut c = None;
//     call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap();
//
//     let c = c.reshape(&[512, 8]);
//     let mut d = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap();
//
//     let d = d.reshape(&[64, 8]);
//     let mut e = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap();
//
//     let e = e.reshape(&[8, 8]);
//     let mut f = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap();
//
//     let f = f.reshape(&[1, 8]);
//     let mut g = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap();
//
//     let g = g.reshape(&[]);
//     let result: M31 = ctx.copy_to_host(g);
//     assert_eq!(result, expected_result);
//
//     type P = Expander<M31Config>;
//     let computation_graph = ctx.compile_computation_graph().unwrap();
//     ctx.solve_witness().unwrap();
//     let (prover_setup, verifier_setup) = P::setup(&computation_graph);
//     let proof = P::prove(
//         &prover_setup,
//         &computation_graph,
//         ctx.export_device_memories(),
//     );
//     assert!(P::verify(&verifier_setup, &computation_graph, &proof));
// }
```
/gemini review |
Code Review
This pull request introduces support for partial broadcasting by changing `is_broadcast` from a boolean to a `usize` factor. This is a valuable enhancement for flexibility. The implementation is largely consistent across the codebase. However, I've identified a critical bug in the variable remapping logic within `reorder_ir_inputs` that could cause panics or incorrect circuit compilation. Additionally, there's a `TODO` comment that points to unclear logic, which should be addressed to improve maintainability. Other changes look good and align with the PR's objective.
```diff
     // println!("{:?}", im.mapping());
     for (j, &k) in im.mapping().iter().enumerate() {
-        var_new_id[prev + k + 1] = var_max + j + 1;
+        var_new_id[prev + j + 1] = var_max + k + 1;
```
This line appears to have incorrect logic for remapping variable IDs. `var_new_id` should map from the old variable ID to the new one. The `im.mapping()` provides a mapping from a padded index `j` to an original index `k`. The old variable ID is based on `k`, and the new ID is based on `j`.
The current implementation `var_new_id[prev + j + 1] = var_max + k + 1;` attempts to use a padded index `j` to index `var_new_id`, which is sized for original, unpadded inputs. This will likely lead to an out-of-bounds panic when padding is active (`padded_len > original_len`).
The logic should be `var_new_id[old_id] = new_id`, which translates to `var_new_id[prev + k + 1] = var_max + j + 1;`.
```diff
-var_new_id[prev + j + 1] = var_max + k + 1;
+var_new_id[prev + k + 1] = var_max + j + 1;
```
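A toy reproduction of the indexing concern (the sizes and mapping below are made up, not from the PR): `mapping[j] = k` sends padded index `j` to original index `k`, while `var_new_id` is sized for the original inputs only, so indexing it by `j` can run past the end once `padded_len > original_len`.

```rust
// Toy model: padded_len = 4 but var_new_id only covers original_len = 3
// inputs (plus the offset slot), so indexing by the padded index j would
// go out of bounds, while indexing by the original index k stays in range.
fn main() {
    let original_len = 3;
    let mapping = [0usize, 2, 1, 0]; // mapping[j] = k; padded_len = 4
    let prev = 0;
    let var_max = 10;
    let mut var_new_id = vec![0usize; original_len + 1];

    for (j, &k) in mapping.iter().enumerate() {
        // Correct direction: old id (derived from k) -> new id (derived from j).
        var_new_id[prev + k + 1] = var_max + j + 1;
        // The flagged version would write var_new_id[prev + j + 1], which is
        // var_new_id[4] at j = 3 -- past the vec's length of 4.
    }
    assert!(mapping.iter().all(|&k| prev + k + 1 < var_new_id.len()));
    println!("ok");
}
```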
```rust
//TODO: what is challenge.r_mpi, why need it when broadcast is false?
challenge.rz.extend_from_slice(
    &challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize],
);
```
The `TODO` comment here indicates uncertainty about the logic, which can be a maintainability issue. It's important to clarify this. If the logic is correct, please replace the `TODO` with a comment explaining why `challenge.r_mpi` is used here and how the slicing logic works for the non-broadcast case. If the logic is incorrect or not fully understood, it should be revisited.
```rust
    &vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)]
}
let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
let local_val_len =
```
The partition does not seem to be correct if the original `is_broadcast` is not a power of two. We may need to enforce that `parallel_count` is a multiple of `is_broadcast`.
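A sketch of the suggested guard (the function name and exact formula are assumptions for illustration, not code from the PR): reject broadcast factors that do not divide `parallel_count` before the partition is computed.

```rust
// Hypothetical guard: each distinct chunk is shared by `is_broadcast`
// parallel instances, so there are parallel_count / is_broadcast distinct
// chunks, and that division must be exact.
fn local_val_len(total_len: usize, parallel_count: usize, is_broadcast: usize) -> usize {
    assert_eq!(
        parallel_count % is_broadcast,
        0,
        "parallel_count must be a multiple of is_broadcast"
    );
    total_len / (parallel_count / is_broadcast)
}

fn main() {
    assert_eq!(local_val_len(16, 8, 2), 4);  // 4 distinct chunks of length 4
    assert_eq!(local_val_len(16, 8, 8), 16); // full broadcast: one shared chunk
    println!("ok");
}
```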
Force-pushed from a0b2df2 to 70acfca
Support partial broadcast
Previously, only boolean broadcasting was supported: if `broadcast = true`, the data would be broadcast `parallel_count` times; otherwise, no broadcasting was applied.
However, in some cases we need to broadcast the data a fixed number of times (e.g., 2, 3, …).
This PR adds support for such partial broadcasts.
Concretely, when the data has shape `[k, kernel_shape]` and the parallel count is `N`, if `N / k = d`, then the data is broadcast along the first dimension by stacking it `d` times, resulting in a final shape `[N, kernel_shape]`.
Example:
```text
data: [[a,a], [b,b]]                          --> shape (2, 2)
kernel_shape = [2], parallel_count = 4
broadcast_data: [[a,a], [b,b], [a,a], [b,b]]  --> shape (4, 2)
```
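The stacking described above can be modeled in a few lines (a standalone sketch, not the PR's actual implementation):

```rust
// Minimal model of partial broadcast: stack the data d = parallel_count / k
// times along the first dimension, where k is the data's first dimension.
fn partial_broadcast<T: Clone>(data: &[Vec<T>], parallel_count: usize) -> Vec<Vec<T>> {
    let k = data.len();
    assert_eq!(parallel_count % k, 0, "parallel_count must be a multiple of k");
    let d = parallel_count / k;
    let mut out = Vec::with_capacity(parallel_count);
    for _ in 0..d {
        out.extend_from_slice(data);
    }
    out
}

fn main() {
    // The example from the PR description: shape (2, 2) -> (4, 2).
    let data = vec![vec!["a", "a"], vec!["b", "b"]];
    let out = partial_broadcast(&data, 4);
    assert_eq!(
        out,
        vec![vec!["a", "a"], vec!["b", "b"], vec!["a", "a"], vec!["b", "b"]]
    );
    println!("ok");
}
```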