## Background
Currently, ZkCuda generates a layered circuit directly when the `compile_xx_kernel` function is called. Determining the layered circuit's input layout this early leads to the following issue:
Consider a kernel A with an input of `[Field; 5]` and no output. Its input occupies 5 of 8 slots in the layered circuit, padded to the next power of two; schematically, `[xxxxx...]`.
Now consider another kernel B with an input of `[[Field; 5]; 3]` and no output. Its input occupies 15 of 16 slots, padded similarly: `[xxxxxxxxxxxxxxx.]`.
When kernel A is called with `num_parallel=3`, the committed input is `[xxxxx...xxxxx...xxxxx...........]`. This layout is incompatible with kernel B's input layout, so the same commitment cannot be used for both kernels.
This is why the `reshape` function currently requires every dimension to be a power of two.
This breaks user expectations: the input shapes of kernel A with `num_parallel=3` and kernel B with `num_parallel=1` are `[3, 5]` and `[1, 3, 5]` respectively. They differ only by a leading 1 and should be interconvertible; supporting this is one of ZkCuda's goals.
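The slot arithmetic behind the mismatch can be sketched as follows. This is a minimal illustration, not part of the ZkCuda API; `slot` and `next_pow2` are hypothetical helpers.

```rust
// Hypothetical helpers illustrating the layout mismatch described above.
fn next_pow2(n: usize) -> usize {
    n.next_power_of_two()
}

/// Slot of element `i` of parallel copy `p`, when each copy's `len` input
/// fields are padded to the next power of two.
fn slot(p: usize, i: usize, len: usize) -> usize {
    p * next_pow2(len) + i
}

fn main() {
    // Kernel A, num_parallel = 3: each [Field; 5] copy is padded to 8 slots,
    // so copy 1, element 0 lands at slot 8.
    assert_eq!(slot(1, 0, 5), 8);

    // Kernel B: the whole [[Field; 5]; 3] input is one copy of 15 fields,
    // padded to 16 slots; the same logical element (row 1, column 0) sits at
    // flat index 1 * 5 + 0 = 5.
    assert_eq!(slot(0, 1 * 5 + 0, 15), 5);

    // The same data sits at slot 8 in A's layout but slot 5 in B's, so one
    // commitment cannot serve both kernels.
}
```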
## Solution
I propose modifying the `Kernel` structure (the output of the `compile_xx_kernel` function) to store the dest IR, or the context from the layering process, instead of the layered circuit. This retains more information and allows the input layer to be rearranged.
Users can call kernels through these modified `Kernel` structures, but the corresponding layered circuit is not determined yet. When the user calls `to_computation_graph`, the layered circuits are compiled and all input connections are recorded. The compiled computation graph can be dumped to a file for later use.
Overall, most APIs remain unchanged, with the following modifications:
- Arbitrary reshaping is allowed.
- Kernel serialization is still possible, but loading a saved `Kernel` only saves part of the compilation time.
- A new function will be added to generate a complete computation graph, which can automatically run and prove each kernel given the inputs.
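To make the first point concrete, arbitrary reshaping would keep only the usual constraint that the total element count is preserved, with no power-of-two requirement on individual dimensions. A hedged sketch (`can_reshape` is a hypothetical helper, not the proposed API):

```rust
// Hypothetical check: a reshape is valid iff the element counts match;
// no dimension needs to be a power of two under the proposal.
fn can_reshape(from: &[usize], to: &[usize]) -> bool {
    from.iter().product::<usize>() == to.iter().product::<usize>()
}

fn main() {
    // [3, 5] and [1, 3, 5] differ only by a leading 1, so they convert.
    assert!(can_reshape(&[3, 5], &[1, 3, 5]));
    // Mismatched element counts are still rejected.
    assert!(!can_reshape(&[3, 5], &[4, 4]));
}
```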
This changes only the user-facing interface; the prover-side interface is unchanged.
## New API Proposal
The following is my vision for the functions that generate and run the complete computation graph:
```rust
impl Context {
    // Compile the recorded kernel calls into a computation graph.
    fn generate_computation_graph(
        &self,
        inputs: &[DeviceMemoryHandle],
        outputs: &[DeviceMemoryHandle],
    ) -> (ProvingComputationGraph, VerifyingComputationGraph);

    // Run the given computation graph. If no kernels have been called and
    // device memory contains only the input variables, the generated proof
    // will match the corresponding VerifyingComputationGraph.
    fn run_proving_computation_graph(
        &self,
        graph: &ProvingComputationGraph,
        inputs: &[DeviceMemoryHandle],
    ) -> Vec<DeviceMemoryHandle>;
}
```

Example:
```rust
fn main() {
    // ...
    // Previous kernel calls
    let a = ctx.copy_to_device(&a_src);
    let mut b: DeviceMemoryHandle = None;
    call_kernel!(ctx, kernel_add_2, a, mut b);
    let b = b.reshape(&[1, 16]);
    let mut c: DeviceMemoryHandle = None;
    call_kernel!(ctx, kernel_add_16, b, mut c);
    let c = c.reshape(&[]);
    let result: M31 = ctx.copy_to_host(c);
    assert_eq!(result, M31::from(32 * 33 / 2));

    // Export the computation graph for proving
    let (pg, vg) = ctx.generate_computation_graph(&[a], &[c]);

    // Apply the computation graph
    let ctx2 = Context::default();
    let a2 = ctx2.copy_to_device(&a_src);
    let c2 = ctx2.run_proving_computation_graph(&pg, &[a2])[0];

    // The result should be consistent with the previous run
    let c2 = c2.reshape(&[]);
    let result2: M31 = ctx2.copy_to_host(c2);
    assert_eq!(result2, M31::from(32 * 33 / 2));

    // The generated proof can be verified using the previous verifying
    // computation graph
    let proof = ctx2.proof();
    assert!(vg.verify(&proof, &verifier_setup));

    // Both computation graphs can be serialized
    pg.serialize_into(...);
    vg.serialize_into(...);
    // ...
}
```

## Feedback Request
I would appreciate any feedback on the overall design and the specific API proposal.
@niconiconi @zhenfeizhang @hczphn @Tao-Lu-X @stirlingx001 @DreamWuGit