
Conversation

@hczphn (Collaborator) commented Sep 3, 2025

Support partial broadcast
Previously, only boolean broadcasting was supported: if broadcast = true, the data was broadcast parallel_count times; otherwise, no broadcasting was applied.

However, in some cases we need to broadcast the data a fixed number of times (e.g., 2, 3, …).
This PR adds support for such partial broadcasts.

Concretely, when the data has shape [k, kernel_shape] and the parallel count is N with k dividing N, let d = N / k; the data is then broadcast along the first dimension by stacking it d times, resulting in a final shape [N, kernel_shape].
Example:
data: [[a,a], [b,b]] --> shape (2,2)
kernel_shape = [2], parallel_count = 4
broadcast_data: [[a,a], [b,b], [a,a], [b,b]] --> shape (4,2)
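
A minimal sketch of this stacking (the helper broadcast_partial is hypothetical, not this crate's API):

fn broadcast_partial<T: Clone>(data: &[T], chunk_len: usize, parallel_count: usize) -> Vec<T> {
    // data holds k chunks of chunk_len elements each; k must divide parallel_count.
    let k = data.len() / chunk_len;
    assert_eq!(data.len(), k * chunk_len);
    assert_eq!(parallel_count % k, 0, "parallel count must be a multiple of k");
    let d = parallel_count / k;
    // Stack the whole buffer d times along the first dimension.
    let mut out = Vec::with_capacity(parallel_count * chunk_len);
    for _ in 0..d {
        out.extend_from_slice(data);
    }
    out
}

// Example above: k = 2, kernel_shape = [2], parallel_count = 4, so d = 2:
// broadcast_partial(&[a, a, b, b], 2, 4) == [a, a, b, b, a, a, b, b]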

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @hczphn, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the system's data broadcasting capabilities by moving from a simple binary (on/off) broadcasting model to a more flexible partial broadcasting mechanism. This allows for fine-grained control over data replication during parallel computations, improving the adaptability of the framework for various computational patterns. The changes are integrated across the core data structures and processing logic, ensuring consistent behavior with the new broadcasting paradigm.

Highlights

  • Enhanced Broadcasting Mechanism: The core change introduces support for 'partial broadcasting' by changing the is_broadcast flag from a boolean to a usize integer. This allows data to be broadcast a fixed number of times (e.g., 2, 3, etc.) instead of the previous all-or-nothing approach.
  • Type System and Logic Updates: The change from bool to usize for broadcasting required extensive updates across the codebase, including struct definitions (KernelCall, ProofTemplate), function signatures, and the underlying logic for data handling, shape compatibility checks, and input preparation within the zkcuda context and proving systems.
  • Dependency and Test Infrastructure Changes: New dependencies (stacker, psm) were added to Cargo.lock and expander_compiler/Cargo.toml. The test suite was updated to reflect the new broadcasting logic, including changes to kernel definitions and the use of BN254Config and ExpanderNoOverSubscribe in zkcuda_matmul.rs.
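
As a rough sketch of the new convention (a plain summary, not the actual struct definitions), the usize factor subsumes the old bool:

// Hypothetical illustration of the flag change, not the real KernelCall layout.
// Old: is_broadcast: bool   -- all-or-nothing replication.
// New: is_broadcast: usize  -- replication factor d, where
//   d == 1                  => no broadcast (old false),
//   d == parallel_count     => full broadcast (old true),
//   1 < d < parallel_count  => partial broadcast (d should divide parallel_count).
fn legacy_flag_to_factor(is_broadcast: bool, parallel_count: usize) -> usize {
    if is_broadcast { parallel_count } else { 1 }
}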

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for partial broadcasting, a significant enhancement. The core logic change from a boolean is_broadcast flag to a usize factor is implemented across many files. While the changes are generally correct, there are several areas for improvement regarding code clarity, consistency, and robustness. I've identified some potentially incorrect logic involving next_power_of_two, an unused function parameter, and several opportunities to remove dead or obscure code to improve maintainability.

values: &[SIMDField<C>],
s: &mut [SIMDField<C>],
- is_broadcast: bool,
+ is_broadcast: usize,

high

The is_broadcast parameter is unused within ir_copy_from_device_memory and should be removed to improve code clarity. The function signature and all call sites (e.g., lines 418, 825, and 845) should be updated accordingly.

let local_val_len = vals.as_ref().len() / parallel_num;
&vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)]
}
let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();

high

The use of is_broadcast.next_power_of_two() seems incorrect. It implies that the broadcast factor is rounded up to the next power of two, which could lead to an incorrect local_val_len if is_broadcast is not a power of two. This is also inconsistent with other parts of the code that use is_broadcast directly. Please use *is_broadcast directly or clarify the intent with a comment.

Suggested change
- let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
+ let is_broadcast_next_power_of_two = *is_broadcast;
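
A small standalone check of the concern, assuming local_val_len is derived from the number of distinct (pre-broadcast) chunks: rounding the factor up to a power of two changes the partition whenever the factor is not already a power of two.

fn main() {
    let is_broadcast: usize = 3;   // not a power of two
    let parallel_num: usize = 12;  // a multiple of is_broadcast
    let total_len: usize = 48;
    // Using the factor directly: 12 / 3 = 4 distinct chunks, each of length 12.
    let len_direct = total_len / (parallel_num / is_broadcast);
    // Rounding up first: 3.next_power_of_two() = 4, giving 12 / 4 = 3 chunks of length 16.
    let len_rounded = total_len / (parallel_num / is_broadcast.next_power_of_two());
    assert_eq!(len_direct, 12);
    assert_eq!(len_rounded, 16);
}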

(challenge, component_idx_vars)
} else {
- let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize;
+ let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize;

high

The use of broadcast_num.next_power_of_two() is inconsistent with other parts of the code (e.g., partition_challenge_and_location_for_pcs_no_mpi) and may be incorrect. It should likely be broadcast_num directly.

Suggested change
- let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize;
+ let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num)).ilog2() as usize;
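
Concretely, a toy evaluation of the two formulas diverges (numbers chosen only to show the gap; real lengths would be powers of two):

let (total_vals_len, parallel_count, broadcast_num) = (48usize, 12usize, 3usize);
// Direct: 48 / (12 / 3) = 12, ilog2 -> 3.
let direct = (total_vals_len / (parallel_count / broadcast_num)).ilog2() as usize;
// Rounded: 48 / (12 / 4) = 16, ilog2 -> 4.
let rounded = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize;
assert_ne!(direct, rounded);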

Comment on lines 186 to 187
//TODO: what is challenge.r_mpi, why need it when broadcast is false?
challenge.rz.extend_from_slice(&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize]);

high

A TODO comment is present and should be addressed. Additionally, (parallel_count / broadcast_num).ilog2() panics when the quotient is zero and silently truncates (floor of log2) when the quotient is not a power of two, which would produce a wrong slice length. Consider adding an assertion that parallel_count is a multiple of broadcast_num and that the quotient is a power of two.
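
A hedged sketch of such an assertion around the quoted lines (names as in the snippet above):

// Guard against a zero or truncated quotient before slicing r_mpi.
assert!(
    broadcast_num > 0 && parallel_count % broadcast_num == 0,
    "parallel_count must be a multiple of broadcast_num"
);
let quotient = parallel_count / broadcast_num;
assert!(
    quotient.is_power_of_two(),
    "parallel_count / broadcast_num must be a power of two for ilog2-based slicing"
);
challenge.rz.extend_from_slice(&challenge.r_mpi[..quotient.ilog2() as usize]);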

.unwrap()
.shape_history
- .get_initial_split_list(!ib);
+ .get_initial_split_list(ib/num_parallel+1);

medium

The expression ib/num_parallel+1 is functionally correct but obscure. It distinguishes between full broadcast (ib == num_parallel) and other cases. Using a more explicit if expression would improve readability.

Suggested change
- .get_initial_split_list(ib/num_parallel+1);
+ .get_initial_split_list(if ib == num_parallel { 0 } else { 1 });

.map(|dm| {
let shape = prefix_products_to_shape(&dm.required_shape_products);
let im = shape_padded_mapping(&shape);
let tmp = im.map_inputs(&dm.values);

medium

The variable tmp is assigned but never used. It should be removed.

Comment on lines 245 to 246
pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec<usize> {
let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1);

medium

The function get_initial_split_list takes a usize for split_first_dim and then checks split_first_dim==1. This is not idiomatic Rust. It would be clearer to use a bool parameter and adjust call sites to pass a boolean, which would make their logic more explicit and readable.

Suggested change
- pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec<usize> {
-     let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1);
+ pub fn get_initial_split_list(&self, split_first_dim: bool) -> Vec<usize> {
+     let last_entry = self.entries.last().unwrap().minimize(split_first_dim);
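
With a bool parameter, the call site shown earlier could then pass the comparison directly (a sketch, assuming ib and num_parallel as in that snippet):

.get_initial_split_list(ib != num_parallel); // true unless this is a full broadcast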

Comment on lines 36 to 101
// #[test]
// fn zkcuda_matmul_sum() {
//     let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap();
//     let kernel_sum_8_elements: KernelPrimitive<M31Config> = compile_sum_8_elements().unwrap();

//     let mut ctx: Context<M31Config> = Context::default();

//     let mut mat_a: Vec<Vec<M31>> = vec![];
//     for i in 0..64 {
//         mat_a.push(vec![]);
//         for j in 0..32 {
//             mat_a[i].push(M31::from((i * 233 + j + 1) as u32));
//         }
//     }
//     let mut mat_b: Vec<Vec<M31>> = vec![];
//     for i in 0..32 {
//         mat_b.push(vec![]);
//         for j in 0..64 {
//             mat_b[i].push(M31::from((i * 2333 + j + 1) as u32));
//         }
//     }
//     let mut expected_result = M31::zero();
//     for i in 0..64 {
//         for j in 0..64 {
//             for k in 0..32 {
//                 expected_result += mat_a[i][k] * mat_b[k][j];
//             }
//         }
//     }

//     let a = ctx.copy_to_device(&mat_a);
//     let b = ctx.copy_to_device(&mat_b);
//     let mut c = None;
//     call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap();

//     let c = c.reshape(&[512, 8]);
//     let mut d = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap();

//     let d = d.reshape(&[64, 8]);
//     let mut e = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap();

//     let e = e.reshape(&[8, 8]);
//     let mut f = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap();

//     let f = f.reshape(&[1, 8]);
//     let mut g = None;
//     call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap();

//     let g = g.reshape(&[]);
//     let result: M31 = ctx.copy_to_host(g);
//     assert_eq!(result, expected_result);

//     type P = Expander<M31Config>;
//     let computation_graph = ctx.compile_computation_graph().unwrap();
//     ctx.solve_witness().unwrap();
//     let (prover_setup, verifier_setup) = P::setup(&computation_graph);
//     let proof = P::prove(
//         &prover_setup,
//         &computation_graph,
//         ctx.export_device_memories(),
//     );
//     assert!(P::verify(&verifier_setup, &computation_graph, &proof));
// }

medium

This large block of commented-out test code should be removed to improve readability.

@hczphn (Collaborator, Author) commented Sep 4, 2025

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for partial broadcasting by changing is_broadcast from a boolean to a usize factor. This is a valuable enhancement for flexibility. The implementation is largely consistent across the codebase. However, I've identified a critical bug in the variable remapping logic within reorder_ir_inputs that could cause panics or incorrect circuit compilation. Additionally, there's a TODO comment that points to unclear logic, which should be addressed to improve maintainability. Other changes look good and align with the PR's objective.

// println!("{:?}", im.mapping());
for (j, &k) in im.mapping().iter().enumerate() {
- var_new_id[prev + k + 1] = var_max + j + 1;
+ var_new_id[prev + j + 1] = var_max + k + 1;

critical

This line appears to have incorrect logic for remapping variable IDs. var_new_id should map from the old variable ID to the new one. The im.mapping() provides a mapping from a padded index j to an original index k. The old variable ID is based on k, and the new ID is based on j.

The current implementation var_new_id[prev + j + 1] = var_max + k + 1; attempts to use a padded index j to index var_new_id, which is sized for original, unpadded inputs. This will likely lead to an out-of-bounds panic when padding is active (padded_len > original_len).

The logic should be var_new_id[old_id] = new_id, which translates to var_new_id[prev + k + 1] = var_max + j + 1;.

Suggested change
- var_new_id[prev + j + 1] = var_max + k + 1;
+ var_new_id[prev + k + 1] = var_max + j + 1;
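
A toy illustration of the intended direction (standalone sketch; the mapping, prev, and var_max values are made up):

// Suppose im.mapping() = [2, 0, 1]: padded slot j holds original input k = mapping[j].
let mapping = [2usize, 0, 1];
let (prev, var_max) = (10usize, 100usize);
let mut var_new_id = vec![0usize; prev + mapping.len() + 1];
for (j, &k) in mapping.iter().enumerate() {
    // Index by the OLD id (based on k), store the NEW id (based on j).
    var_new_id[prev + k + 1] = var_max + j + 1;
}
// Original input 2 now lives at padded slot 0.
assert_eq!(var_new_id[10 + 2 + 1], 100 + 0 + 1);

If, as the comment above states, each k is a valid original index, then indexing by prev + k + 1 stays in bounds even when padded_len > original_len.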

Comment on lines +187 to +190
//TODO: what is challenge.r_mpi, why need it when broadcast is false?
challenge.rz.extend_from_slice(
&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize],
);

medium

The TODO comment here indicates uncertainty about the logic, which can be a maintainability issue. It's important to clarify this. If the logic is correct, please replace the TODO with a comment explaining why challenge.r_mpi is used here and how the slicing logic works for the non-broadcast case. If the logic is incorrect or not fully understood, it should be revisited.

&vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)]
}
let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
let local_val_len =
@zhiyong1997 (Collaborator) commented Sep 9, 2025

The partition does not seem to be correct if the original is_broadcast is not a power of two. We may need to enforce that parallel_count is a multiple of is_broadcast.
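
One way to enforce that, as a sketch (the function name and error type are illustrative, not the crate's API):

fn check_broadcast_factor(parallel_count: usize, is_broadcast: usize) -> Result<(), String> {
    // Reject factors that would make the chunk partition ambiguous.
    if is_broadcast == 0 || parallel_count % is_broadcast != 0 {
        return Err(format!(
            "broadcast factor {is_broadcast} must be a nonzero divisor of parallel_count {parallel_count}"
        ));
    }
    Ok(())
}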
