
Conversation


@chonpsk chonpsk commented Jul 18, 2025

No description provided.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @chonpsk, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant refactoring to the ZK-CUDA context, decoupling the compilation of the computation graph from the concrete input data. This allows for a more flexible workflow where memory is allocated and its shape defined upfront, while the actual input values are provided later, specifically before the witness generation phase. This change enhances the system's ability to compile circuits independently of specific data, enabling the same compiled circuit to be used with varying inputs for proof generation.

Highlights

  • Decoupled Data Handling: The Context now provides a new_device_memory function to allocate memory and define its shape without immediately copying data. This separates memory allocation from data population.
  • Deferred Data Copying: The copy_to_device function has been updated to populate pre-allocated device memory. Its usage across the codebase has been shifted to occur after the compile_computation_graph call, but before solve_witness, allowing the computation graph to be compiled independently of concrete input values (the sketch after this list illustrates the resulting flow).
  • Relocated Kernel Evaluation: The core logic for evaluating kernel primitives, which was previously executed within the call_kernel! macro's underlying implementation, has been moved to the solve_witness method. This ensures that kernel computations are performed when concrete witness data is available.
  • Test Case Adaptations: All relevant test cases and examples have been updated to reflect this new data handling paradigm, ensuring compatibility and correctness with the detached compilation flow.
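
To make the new flow concrete, here is a minimal sketch pieced together from the test snippets quoted later in this review; the kernel name, shapes, and input values are placeholders, and the exact signatures are assumptions rather than the crate's documented API.

```rust
// Hypothetical end-to-end flow under the new API; names follow the snippets
// quoted in this review, but context construction and signatures are assumed.
let mut ctx: Context<M31Config> = Context::default();

// Allocate device memory by shape only; no data is copied yet.
let (a, a_id) = ctx.new_device_memory(vec![16, 2]);

// Record the kernel call against the still-empty handle.
let mut b = None;
call_kernel!(ctx, kernel_add_16, 1, a, mut b).unwrap();

// Compile the computation graph without any concrete inputs.
let computation_graph = ctx.compile_computation_graph().unwrap();

// Only now provide concrete values, then solve the witness.
let a_values: Vec<M31> = (0u32..32).map(M31::from).collect();
ctx.copy_to_device(&a_values, a_id);
ctx.solve_witness().unwrap();
```
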
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the ZK-CUDA context to detach the compilation of the computation graph from the witness generation. This is a significant architectural improvement that allows compiling a circuit once and then running it with different inputs. The changes look solid and the new API is applied consistently across the tests. I've found a few minor issues:

  • A typo in an assertion message.
  • Some commented-out code that should be removed.
  • An opportunity to refactor duplicated logic in the solve_witness function to improve maintainability.
  • A test with a large amount of commented-out code that should be cleaned up.

Comment on lines 401 to 407
```rust
// let handle = make_device_mem(
//     &mut self.device_memories,
//     ov,
//     shape_prepend(shape, num_parallel),
// );
// let id = handle.as_ref().unwrap().id;
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
```

Severity: medium

This block of commented-out code should be removed to improve code clarity and maintainability.

Suggested change

```diff
- // let handle = make_device_mem(
- //     &mut self.device_memories,
- //     ov,
- //     shape_prepend(shape, num_parallel),
- // );
- // let id = handle.as_ref().unwrap().id;
  let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
```

```rust
        self.state = ContextState::WitnessDone;

        for kernel_call in self.kernel_calls.iter() {
            let kernel =  self.kernel_primitives.get(kernel_call.kernel_id);
```

Severity: medium

There's a minor formatting issue here with an extra space before self. Please remove it to maintain consistent code style.

Suggested change

```diff
-            let kernel =  self.kernel_primitives.get(kernel_call.kernel_id);
+            let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
```

Comment on lines 688 to 772
```rust
for kernel_call in self.kernel_calls.iter() {
    let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
    let num_parallel = kernel_call.num_parallel;
    let is_broadcast = &kernel_call.is_broadcast;

    let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
    let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
    for (((input, &ib), ir_inputs), chunk_size) in kernel_call
        .input_handles
        .iter()
        .zip(is_broadcast.iter())
        .zip(ir_inputs_all.iter_mut())
        .zip(chunk_sizes.iter_mut())
    {
        if input.is_none() {
            continue;
        }
        let handle = ensure_handle(input.clone());
        let values = handle
            .shape_history
            .permute_vec(&self.device_memories[handle.id].values);
        if !ib {
            *chunk_size = Some(values.len() / num_parallel);
        }
        *ir_inputs = values;
    }
    let mut ir_inputs_per_parallel = Vec::new();
    for parallel_i in 0..num_parallel {
        let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
        for (i, ((input, input_start), input_end)) in kernel_call
            .input_handles
            .iter()
            .zip(kernel.ir_input_offsets().iter())
            .zip(kernel.ir_input_offsets().iter().skip(1))
            .enumerate()
        {
            if input.is_none() {
                continue;
            }
            self.ir_copy_from_device_memory(
                &ir_inputs_all[i],
                &mut ir_inputs[*input_start..*input_end],
                is_broadcast[i],
                parallel_i,
                chunk_sizes[i],
            );
        }
        ir_inputs_per_parallel.push(ir_inputs);
    }
    let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
        .into_par_iter()
        .map(|ir_inputs| {
            kernel
                .ir_for_calling()
                .eval_safe_simd(ir_inputs, &[], &self.hint_caller)
        })
        .collect();

    let mut outputs_tmp: Vec<Vec<SIMDField<C>>> = vec![Vec::new(); kernel.io_specs().len()];
    for ir_outputs in ir_outputs_per_parallel {
        let ir_outputs = ir_outputs?;
        for (((spec, output_start), output_end), out) in kernel
            .io_specs()
            .iter()
            .zip(kernel.ir_output_offsets().iter())
            .zip(kernel.ir_output_offsets().iter().skip(1))
            .zip(outputs_tmp.iter_mut())
        {
            if !spec.is_output {
                continue;
            }
            out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
        }
    }

    for ((output, spec), ov) in kernel_call
        .output_handles
        .iter()
        .zip(kernel.io_specs().iter())
        .zip(outputs_tmp.into_iter())
    {
        if !spec.is_output {
            continue;
        }
        let output_id = output.as_ref().unwrap().id;
        self.device_memories[output_id].values = ov;
    }
}
```

Severity: medium

The code for kernel execution, which was moved into solve_witness, shares a lot of similar logic with the hint-solving part that follows (starting from line 693).

Specifically, the preparation of inputs for parallel execution is duplicated. To improve maintainability and reduce code duplication, consider refactoring the common logic for preparing inputs into a separate helper function. This would make the solve_witness function cleaner and easier to understand.
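
One possible shape for such a helper, sketched under the assumption that the receiver and the KernelPrimitive/KernelCall types match the quoted code; this is not the crate's actual API.

```rust
// Sketch only: a shared helper for the input-preparation phases duplicated
// between the kernel-execution and hint-solving loops. Type names and field
// names are assumptions inferred from the quoted loop body.
fn prepare_ir_inputs_per_parallel(
    &self,
    kernel: &KernelPrimitive<C>,
    kernel_call: &KernelCall,
) -> Vec<Vec<SIMDField<C>>> {
    // Phase 1: flatten each input's device memory (shape_history.permute_vec)
    // and record per-parallel chunk sizes for non-broadcast inputs.
    // Phase 2: scatter the flattened values into one IR input vector per
    // parallel instance via ir_copy_from_device_memory.
    // (Bodies identical to the first two loops quoted above.)
    todo!()
}
```

Both loops could then start with a single call such as self.prepare_ir_inputs_per_parallel(kernel, kernel_call), keeping the per-parallel layout logic in one place.
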

```diff
 }
-let a = ctx.copy_to_device(&a);
+let a_shape = vec![16, 2];
+let (a , a_id) = ctx.new_device_memory(a_shape);
```

Severity: medium

There's a minor formatting issue here with an extra space in let (a , a_id).

Suggested change

```diff
-let (a , a_id) = ctx.new_device_memory(a_shape);
+let (a, a_id) = ctx.new_device_memory(a_shape);
```

Comment on lines 257 to 286
```rust
    // let result: Vec<M31> = ctx.copy_to_host(b);
    // assert_eq!(
    //     result,
    //     vec![
    //         M31::from(1),
    //         M31::from(0),
    //         M31::from(1),
    //         M31::from(0),
    //         M31::from(1),
    //         M31::from(0),
    //         M31::from(1),
    //         M31::from(0)
    //     ]
    // );

    type P = Expander<M31Config>;
    let computation_graph = ctx.compile_computation_graph().unwrap();
    ctx.solve_witness().unwrap();
    let a_value = M31::from(0x55);
    ctx.copy_to_device(&a_value, a_id);
    // ctx.solve_witness().unwrap();
    println!("{:?}", computation_graph);
    println!("{:?}", ctx.export_device_memories());
    let (prover_setup, verifier_setup) = P::setup(&computation_graph);
    let proof = P::prove(
        &prover_setup,
        &computation_graph,
        &ctx.export_device_memories(),
    );
    assert!(P::verify(&verifier_setup, &computation_graph, &proof));
    // println!("{:?}", ctx.export_device_memories());
    // let (prover_setup, verifier_setup) = P::setup(&computation_graph);
    // let proof = P::prove(
    //     &prover_setup,
    //     &computation_graph,
    //     &ctx.export_device_memories(),
    // );
    // assert!(P::verify(&verifier_setup, &computation_graph, &proof));
}
```

Severity: medium

A significant portion of the zkcuda_to_binary test function is commented out. Commented-out code can reduce readability and lead to confusion. If this test is temporarily disabled or a work-in-progress, it would be better to add a #[ignore] attribute with a TODO comment explaining why. If the code is no longer needed, it should be removed.
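
For reference, a minimal sketch of the attribute-based alternative; the reason string is illustrative, not a prescription.

```rust
// Sketch of the suggested pattern for a temporarily disabled test.
#[test]
#[ignore = "TODO: re-enable once the assertions match the deferred execution flow"]
fn zkcuda_to_binary() {
    // ...
}
```
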

@chonpsk chonpsk (Author) commented Jul 23, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully detaches the compilation of the computation graph from the witness solving phase by introducing a new API for device memory management. The core logic changes in expander_compiler/src/zkcuda/context.rs are sound.

However, this refactoring has broken several tests (cg_mpi_share.rs, zkcuda_keccak.rs, zkcuda_matmul.rs) because their result assertions are no longer valid with the deferred execution model. These tests need to be updated to check results after solve_witness() is called.

I've also pointed out some minor issues like typos, commented-out code, and a potential redundant computation in one of the tests.

```rust
    let out: Vec<Vec<M31>> = ctx.copy_to_host(out);
    println!("copy to host ok");
    assert_eq!(out, expected_res);
    assert_eq!(out[0][0], expected_res[0][0]);
```

Severity: high

The refactoring to deferred execution has broken the logic of this test. The assert_eq! calls will now fail because out is checked before solve_witness is called, meaning it contains uninitialized data.

These assertions should be moved to after the ctx.solve_witness().unwrap() call on line 342.
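
A sketch of the suggested reordering, using the handles and expected values from the quoted test:

```rust
    // Solve the witness first, then read back and assert.
    ctx.solve_witness().unwrap();
    let out: Vec<Vec<M31>> = ctx.copy_to_host(out);
    assert_eq!(out, expected_res);
    assert_eq!(out[0][0], expected_res[0][0]);
```
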

```rust
    ) {
        assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
        let (flat, shape) = flatten_shaped(host_memory);
        // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match.");
```

Severity: medium

This assertion is commented out. If it's no longer necessary, it should be removed to improve code clarity. If it is still needed, it should be fixed and re-enabled.

The current logic shape_vec_len(&self.device_memories[device_memory_id].required_shape_products) seems incorrect, as required_shape_products is already a vector of products, not a shape. A correct check might involve comparing against the last element of required_shape_products, which represents the total size.
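
One way the re-enabled assertion could read, assuming the last element of required_shape_products is the total flattened length:

```rust
        // Assumes required_shape_products ends with the total flattened size.
        assert_eq!(
            flat.len(),
            *self.device_memories[device_memory_id]
                .required_shape_products
                .last()
                .unwrap(),
            "The len of values doesn't match."
        );
```
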

```diff
-    ) -> DeviceMemoryHandle {
+        device_memory_id: usize,
+    ) {
+        assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist.");
```

Severity: medium

There's a typo in the assertion message. It should be "doesn't" instead of "dosen't". This also applies to the assertion on line 260.

Suggested change

```diff
-        assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist.");
+        assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
```

```rust
            let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
            let num_parallel = kernel_call.num_parallel;
            let is_broadcast = &kernel_call.is_broadcast;

```

Severity: medium

This empty line can be removed to improve code conciseness.

```rust
            }
        }
        self.state = ContextState::WitnessDone;

```

Severity: medium

This empty line can be removed.

```diff
-        let a = ctx.copy_to_device(&a);
+        let a_shape = vec![16, 2];
+        let (a, a_id) = ctx.new_device_memory(a_shape);
+        // let a = ctx.copy_to_device(&a);
```

Severity: medium

This line of code is commented out and can be removed to improve readability.

```rust
        let mut c = None;
        call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap();
        ctx.copy_simd_to_device(&a_value, a_id);
        ctx.solve_result().unwrap();
```

Severity: medium

solve_result() is called here, and then it's called again within solve_witness() in the zkcuda_test_simd function that calls zkcuda_test_simd_prepare_ctx. This results in redundant computation.

Consider restructuring the test to avoid this double execution. For example, the assertions could be moved to the zkcuda_test_simd function after solve_witness is called, and this call to solve_result could be removed.
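
A sketch of that restructuring; function names come from this thread, and the bodies are elided.

```rust
// Sketch: run the kernels exactly once, inside solve_witness.
fn zkcuda_test_simd() {
    // Assumes solve_result() has been removed from the prepare function.
    let mut ctx = zkcuda_test_simd_prepare_ctx();
    ctx.solve_witness().unwrap(); // the single execution happens here
    // result assertions move here, after solve_witness
}
```
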

