From 8da60afcb4e85e9a54cc3c8f096c4c9c2aed8fe1 Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 13 Aug 2025 02:27:08 +0000 Subject: [PATCH 01/13] support flexible broadcast --- expander_compiler/src/zkcuda/context.rs | 77 ++++---- expander_compiler/src/zkcuda/mpi_mem_share.rs | 2 +- .../src/zkcuda/proving_system/common.rs | 32 ++-- .../src/zkcuda/proving_system/dummy.rs | 6 +- .../expander/api_single_thread.rs | 4 +- .../proving_system/expander/prove_impl.rs | 27 +-- .../proving_system/expander/verify_impl.rs | 4 +- .../expander_no_oversubscribe/prove_impl.rs | 21 ++- .../expander_no_oversubscribe/server_fn.rs | 5 +- .../expander_parallelized/prove_impl.rs | 17 +- .../expander_parallelized/verify_impl.rs | 6 +- .../expander_pcs_defered/prove_impl.rs | 2 +- .../expander_pcs_defered/verify_impl.rs | 2 +- .../src/zkcuda/proving_system/traits.rs | 4 +- expander_compiler/src/zkcuda/shape.rs | 32 ++-- .../tests/zkcuda/zkcuda_matmul.rs | 174 +++++++++++------- 16 files changed, 240 insertions(+), 175 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 327b3e2a..ae3579f2 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -1,3 +1,5 @@ +use core::num; + use arith::SimdField; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serdes::ExpSerde; @@ -44,7 +46,7 @@ pub struct KernelCall { num_parallel: usize, input_handles: Vec, output_handles: Vec, - is_broadcast: Vec, + is_broadcast: Vec, } #[derive(PartialEq, Eq, Clone, Debug, ExpSerde)] @@ -53,7 +55,7 @@ pub struct ProofTemplate { pub commitment_indices: Vec, pub commitment_bit_orders: Vec, pub parallel_count: usize, - pub is_broadcast: Vec, + pub is_broadcast: Vec, } impl ProofTemplate { @@ -69,7 +71,7 @@ impl ProofTemplate { pub fn parallel_count(&self) -> usize { self.parallel_count } - pub fn is_broadcast(&self) -> &[bool] { + pub fn is_broadcast(&self) -> &[usize] { &self.is_broadcast } } @@ -156,17 +158,20 @@ fn check_shape_compat( kernel_shape: &Shape, io_shape: &Shape, parallel_count: usize, -) -> Option { +) -> Option { + println!("kernel_shape: {:?}, io_shape: {:?}, parallel_count: {}", kernel_shape, io_shape, parallel_count); if kernel_shape.len() == io_shape.len() { if *kernel_shape == *io_shape { - Some(true) + Some(parallel_count) } else { None } } else if kernel_shape.len() + 1 == io_shape.len() { if io_shape.iter().skip(1).eq(kernel_shape.iter()) { if io_shape[0] == parallel_count { - Some(false) + Some(1) + } else if (parallel_count / io_shape[0]).is_power_of_two() { + Some(parallel_count / io_shape[0]) } else { None } @@ -299,18 +304,15 @@ impl>> Context { &self, values: &[SIMDField], s: &mut [SIMDField], - is_broadcast: bool, + is_broadcast: usize, parallel_index: usize, chunk_size: Option, ) { - if is_broadcast { - s.copy_from_slice(values); - } else { - let chunk_size = chunk_size.unwrap(); - s.copy_from_slice( - &values[chunk_size * parallel_index..chunk_size * (parallel_index + 1)], - ); - } + let chunk_size = chunk_size.unwrap(); + let start_index = chunk_size * parallel_index % values.len(); + s.copy_from_slice( + &values[start_index..(start_index + chunk_size)], + ); } pub fn call_kernel( @@ -332,7 +334,7 @@ impl>> Context { .enumerate() { if !spec.is_input { - is_broadcast.push(false); + is_broadcast.push(1); continue; } /*println!( @@ -350,7 +352,8 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(!ib); + .get_initial_split_list(ib/num_parallel+1); + // let isl = vec![1,64,4096]; let t = io.as_ref().unwrap().id; self.device_memories[t].required_shape_products = merge_shape_products( &isl, @@ -370,8 +373,9 @@ impl>> Context { } } } + println!("is_broadcast: {:?}", is_broadcast); for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) { - if io_spec.is_output && *ib { + if io_spec.is_output && *ib!=1 { panic!("Output is broadcasted, but it shouldn't be"); } } @@ -381,11 +385,12 @@ impl>> Context { let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; - for (((input, &ib), ir_inputs), chunk_size) in ios + for ((((input, &ib), ir_inputs), chunk_size), kernel_shape) in ios .iter() .zip(is_broadcast.iter()) .zip(ir_inputs_all.iter_mut()) .zip(chunk_sizes.iter_mut()) + .zip(kernel.io_shapes().iter()) { if input.is_none() { continue; @@ -394,11 +399,10 @@ impl>> Context { let values = handle .shape_history .permute_vec(&self.device_memories[handle.id].values); - if !ib { - *chunk_size = Some(values.len() / num_parallel); - } + *chunk_size = Some(kernel_shape.iter().product()); *ir_inputs = values; } + println!("chunk_sizes: {:?}", chunk_sizes); let mut ir_inputs_per_parallel = Vec::new(); for parallel_i in 0..num_parallel { let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; @@ -469,7 +473,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(true), + .get_initial_split_list(1), &self.device_memories[id].required_shape_products, ); *output = handle.clone(); @@ -513,7 +517,7 @@ impl>> Context { .zip(kernel_call.input_handles.iter()) .zip(kernel_call.is_broadcast.iter()) { - if !spec.is_input || ib { + if !spec.is_input || ib > 1 { continue; } let pad_shape = get_pad_shape(input_handle).unwrap(); @@ -526,7 +530,7 @@ impl>> Context { .zip(kernel_call.output_handles.iter()) .zip(kernel_call.is_broadcast.iter()) { - if !spec.is_output || ib { + if !spec.is_output || ib > 1 { continue; } let pad_shape = get_pad_shape(output_handle).unwrap(); @@ -549,7 +553,7 @@ impl>> Context { if x != 1 && x != kernel_call.num_parallel { let sh_tmp = handle.shape_history.reshape(&[x, total / x]); dm.required_shape_products = merge_shape_products( - &sh_tmp.get_initial_split_list(true), + &sh_tmp.get_initial_split_list(1), &dm.required_shape_products, ); } @@ -622,7 +626,7 @@ impl>> Context { let mut psi = Vec::new(); for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) { psi.push(s.as_ref().map(|t| { - if ib { + if ib == kernel_call.num_parallel { t.0.clone() } else { keep_shape_since(&t.0, kernel_call.num_parallel) @@ -635,7 +639,7 @@ impl>> Context { .zip(kernel_call.is_broadcast.iter()) { pso.push(s.as_ref().map(|t| { - if ib { + if ib == kernel_call.num_parallel { t.0.clone() } else { keep_shape_since(&t.0, kernel_call.num_parallel) @@ -661,7 +665,7 @@ impl>> Context { commitment_indices.push(handle.as_ref().unwrap().id); commitment_bit_orders.push(shape.1.clone()); is_broadcast.push(ib); - if !ib { + if ib == 1 { any_shape = Some(shape.0.clone()); } } @@ -678,7 +682,7 @@ impl>> Context { commitment_indices.push(handle.as_ref().unwrap().id); commitment_bit_orders.push(shape.1.clone()); is_broadcast.push(ib); - if !ib { + if ib == 1 { any_shape = Some(shape.0.clone()); } } @@ -695,7 +699,7 @@ impl>> Context { dm_max += 1; commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect()); commitments_lens.push(n); - is_broadcast.push(false); + is_broadcast.push(1); } let kernel_id = self.kernels.add(&kernel); @@ -778,9 +782,8 @@ impl>> Context { let values = handle .shape_history .permute_vec(&self.device_memories[handle.id].values); - if !ib { - *chunk_size = Some(values.len() / kernel_call.num_parallel); - } + let kernel_shape = handle.shape_history.shape(); + *chunk_size = Some(kernel_shape.iter().product()); *ir_inputs = values; } for (((output, &ib), ir_inputs), chunk_size) in kernel_call @@ -800,7 +803,7 @@ impl>> Context { let values = handle .shape_history .permute_vec(&self.device_memories[handle.id].values); - assert!(!ib); + assert!(ib == 1); *chunk_size = Some(values.len() / kernel_call.num_parallel); *ir_inputs = values; } @@ -823,7 +826,7 @@ impl>> Context { self.ir_copy_from_device_memory( ir_inputs, &mut inputs[*input_start..*input_end], - chunk_size.is_none(), + chunk_size.unwrap_or(2), parallel_i, *chunk_size, ); @@ -843,7 +846,7 @@ impl>> Context { self.ir_copy_from_device_memory( ir_outputs, &mut inputs[*output_start..*output_end], - chunk_size.is_none(), + chunk_size.unwrap_or(1), parallel_i, *chunk_size, ); diff --git a/expander_compiler/src/zkcuda/mpi_mem_share.rs b/expander_compiler/src/zkcuda/mpi_mem_share.rs index 5c79c7db..dc2042a8 100644 --- a/expander_compiler/src/zkcuda/mpi_mem_share.rs +++ b/expander_compiler/src/zkcuda/mpi_mem_share.rs @@ -132,7 +132,7 @@ impl MPISharedMemory for ProofTemplate { .map(|_| BitOrder::new_from_memory(ptr)) .collect(); let parallel_count = usize::new_from_memory(ptr); - let is_broadcast = Vec::::new_from_memory(ptr); + let is_broadcast = Vec::::new_from_memory(ptr); ProofTemplate { kernel_id, diff --git a/expander_compiler/src/zkcuda/proving_system/common.rs b/expander_compiler/src/zkcuda/proving_system/common.rs index b8cd617a..431ae429 100644 --- a/expander_compiler/src/zkcuda/proving_system/common.rs +++ b/expander_compiler/src/zkcuda/proving_system/common.rs @@ -14,7 +14,7 @@ pub fn check_inputs( kernel: &Kernel, values: &[&[SIMDField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) { if kernel.layered_circuit_input().len() != values.len() { panic!("Input size mismatch"); @@ -23,11 +23,7 @@ pub fn check_inputs( panic!("Input size mismatch"); } for i in 0..kernel.layered_circuit_input().len() { - if is_broadcast[i] { - if kernel.layered_circuit_input()[i].len != values[i].len() { - panic!("Input size mismatch"); - } - } else if kernel.layered_circuit_input()[i].len * parallel_count != values[i].len() { + if kernel.layered_circuit_input()[i].len != values[i].len() / (parallel_count / is_broadcast[i]) { panic!("Input size mismatch"); } } @@ -37,24 +33,20 @@ pub fn prepare_inputs( layered_circuit: &Circuit, partition_info: &[LayeredCircuitInputVec], values: &[&[SIMDField]], - is_broadcast: &[bool], + is_broadcast: &[usize], + parallel_count: usize, parallel_index: usize, ) -> Vec> { let mut lc_input = vec![SIMDField::::zero(); layered_circuit.input_size()]; for ((input, value), ib) in partition_info.iter().zip(values.iter()).zip(is_broadcast) { - if *ib { - for (i, x) in value.iter().enumerate() { - lc_input[input.offset + i] = *x; - } - } else { - for (i, x) in value - .iter() - .skip(parallel_index * input.len) - .take(input.len) - .enumerate() - { - lc_input[input.offset + i] = *x; - } + let parallel_index = parallel_index % (parallel_count / ib); + for (i, x) in value + .iter() + .skip(parallel_index * input.len) + .take(input.len) + .enumerate() + { + lc_input[input.offset + i] = *x; } } lc_input diff --git a/expander_compiler/src/zkcuda/proving_system/dummy.rs b/expander_compiler/src/zkcuda/proving_system/dummy.rs index b18beb6b..64b19c1c 100644 --- a/expander_compiler/src/zkcuda/proving_system/dummy.rs +++ b/expander_compiler/src/zkcuda/proving_system/dummy.rs @@ -74,7 +74,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem { _commitments_state: &[&Self::CommitmentState], commitments_values: &[&[SIMDField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> DummyProof { check_inputs(kernel, commitments_values, parallel_count, is_broadcast); let mut res = vec![]; @@ -84,6 +84,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem { kernel.layered_circuit_input(), commitments_values, is_broadcast, + parallel_count, i, ); let (_, cond) = kernel @@ -104,7 +105,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem { proof: &Self::Proof, commitments: &[&Self::Commitment], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> bool { let values = commitments.iter().map(|c| &c.vals[..]).collect::>(); check_inputs(kernel, &values, parallel_count, is_broadcast); @@ -114,6 +115,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem { kernel.layered_circuit_input(), &values, is_broadcast, + parallel_count, i, ); let (_, cond) = kernel diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 836201ac..aecd2a0c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -61,7 +61,7 @@ where _commitments_state: &[&Self::CommitmentState], commitments_values: &[&[SIMDField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> Self::Proof { let timer = Timer::new("prove", true); check_inputs(kernel, commitments_values, parallel_count, is_broadcast); @@ -113,7 +113,7 @@ where proof: &Self::Proof, commitments: &[&Self::Commitment], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> bool { let timer = Timer::new("verify", true); let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index ad7b3121..c1a35888 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -48,7 +48,7 @@ where /// This function returns the local values for each parallel instance based on the global values and the broadcast information. pub fn get_local_vals<'vals_life, F: Field>( global_vals: &'vals_life [impl AsRef<[F]>], - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_index: usize, parallel_num: usize, ) -> Vec<&'vals_life [F]> { @@ -56,12 +56,9 @@ pub fn get_local_vals<'vals_life, F: Field>( .iter() .zip(is_broadcast.iter()) .map(|(vals, is_broadcast)| { - if *is_broadcast { - vals.as_ref() - } else { - let local_val_len = vals.as_ref().len() / parallel_num; - &vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)] - } + let local_val_len = vals.as_ref().len() / (parallel_num / is_broadcast); + let start_index = local_val_len * parallel_index % vals.as_ref().len(); + &vals.as_ref()[start_index..local_val_len + start_index] }) .collect::>() } @@ -75,10 +72,13 @@ pub fn prepare_inputs_with_local_vals( ) -> Vec { let mut input_vals = vec![F::ZERO; input_len]; for (partition, val) in partition_info.iter().zip(local_commitment_values.iter()) { + // println!("partiion.len: {}, val.len: {}", partition.len, val.as_ref().len()); + // panic!("partiion.len: {}, val.len: {}", partition.len, val.as_ref().len()); assert!(partition.len == val.as_ref().len()); input_vals[partition.offset..partition.offset + partition.len] .copy_from_slice(val.as_ref()); } + // panic!("1"); input_vals } @@ -133,22 +133,23 @@ pub fn partition_challenge_and_location_for_pcs_no_mpi( total_vals_len: usize, parallel_index: usize, parallel_count: usize, - is_broadcast: bool, + is_broadcast: usize, ) -> (ExpanderSingleVarChallenge, Vec) { assert_eq!(gkr_challenge.r_mpi.len(), 0); let mut challenge = gkr_challenge.clone(); let zero = F::ChallengeField::ZERO; - if is_broadcast { + if is_broadcast == parallel_count { let n_vals_vars = total_vals_len.ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); (challenge, component_idx_vars) } else { - let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize; + let real_parallel_count = parallel_count / is_broadcast; + let n_vals_vars = (total_vals_len / real_parallel_count).ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); - let n_index_vars = parallel_count.ilog2() as usize; + let n_index_vars = real_parallel_count.ilog2() as usize; let index_vars = (0..n_index_vars) .map(|i| F::ChallengeField::from(((parallel_index >> i) & 1) as u32)) .collect::>(); @@ -206,7 +207,7 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( gkr_claim: &ExpanderSingleVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], p_keys: &ExpanderProverSetup, - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, @@ -235,7 +236,7 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi( gkr_claim: &ExpanderDualVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], p_keys: &ExpanderProverSetup, - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, diff --git a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs index d5c40f2a..14abf848 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs @@ -76,7 +76,7 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi_impl( challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, commitments: &[&ExpanderCommitment], - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_index: usize, parallel_count: usize, transcript: &mut C::TranscriptConfig, @@ -144,7 +144,7 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi( claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, commitments: &[&ExpanderCommitment], - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_index: usize, parallel_count: usize, transcript: &mut C::TranscriptConfig, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index bc980372..44898127 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -63,7 +63,10 @@ where (None, None) }; commit_timer.stop(); - + println!("enter Here!$$$$:{:}", values.len()); + for item in values { + println!("{}", item.as_ref().len()); + } let mut vals_ref = vec![]; let mut challenges = vec![]; @@ -78,7 +81,9 @@ where .iter() .map(|&idx| values[idx].as_ref()) .collect::>(); - + for commitment_value in &commitment_values { + println!("commitment_value: {}", commitment_value.len()); + } let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< @@ -200,7 +205,7 @@ pub fn prove_kernel_gkr_no_oversubscribe( kernel: &Kernel, commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], n_bytes_profiler: &mut NBytesProfiler, ) -> Option<(T, ExpanderDualVarChallenge)> where @@ -324,7 +329,7 @@ pub fn prove_kernel_gkr_internal( kernel: &Kernel, commitments_values: &[&[FBasic::SimdCircuitField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], n_bytes_profiler: &mut NBytesProfiler, ) -> Option<(T, ExpanderDualVarChallenge)> where @@ -338,6 +343,9 @@ where let world_size = mpi_config.world_size(); let n_copies = parallel_count / world_size; + for &commitment_value in commitments_values { + println!("commitment_value: {}", commitment_value.len()); + } let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, @@ -345,6 +353,9 @@ where n_copies, parallel_count, ); + for commitment_value in local_commitment_values[0].clone() { + println!("local_commitment_values: {}", commitment_value.len()); + } let (mut expander_circuit, mut prover_scratch) = prepare_expander_circuit::(kernel, world_size); @@ -365,7 +376,7 @@ where pub fn get_local_vals_multi_copies<'vals_life, F: Field>( global_vals: &'vals_life [impl AsRef<[F]>], - is_broadcast: &[bool], + is_broadcast: &[usize], local_world_rank: usize, n_copies: usize, parallel_count: usize, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 85c5955f..0379999e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -59,7 +59,10 @@ where values: &[impl AsRef<[SIMDField]>], ) -> Option>> { let mut n_bytes_profiler = NBytesProfiler::new(); - + println!("enter Here!####:{:}", values.len()); + for item in values { + println!("{}", item.as_ref().len()); + } #[cfg(feature = "zkcuda_profile")] { use arith::SimdField; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 5605daf0..617a2f90 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -127,7 +127,7 @@ pub fn prove_kernel_gkr( kernel: &Kernel, commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> Option<(T, ExpanderDualVarChallenge)> where F: FieldEngine, @@ -169,22 +169,25 @@ pub fn partition_challenge_and_location_for_pcs_mpi( gkr_challenge: &ExpanderSingleVarChallenge, total_vals_len: usize, parallel_count: usize, - is_broadcast: bool, + broadcast_num: usize, ) -> (ExpanderSingleVarChallenge, Vec) { let mut challenge = gkr_challenge.clone(); let zero = F::ChallengeField::ZERO; - if is_broadcast { + if broadcast_num == parallel_count { let n_vals_vars = total_vals_len.ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); + println!("broadcast challenge.rz.len() = {}", challenge.rz.len()); challenge.r_mpi.clear(); (challenge, component_idx_vars) } else { - let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize; + let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num)).ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); - - challenge.rz.extend_from_slice(&challenge.r_mpi); + //TODO: what is challenge.r_mpi, why need it when broadcast is false? + println!("challenge.rz.len() = {}", challenge.rz.len()); + challenge.rz.extend_from_slice(&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize]); + println!("after challenge.rz.len() = {}", challenge.rz.len()); challenge.r_mpi.clear(); (challenge, component_idx_vars) } @@ -196,7 +199,7 @@ pub fn partition_single_gkr_claim_and_open_pcs_mpi( commitments_values: &[impl AsRef<[SIMDField]>], commitments_state: &[&ExpanderCommitmentState], gkr_challenge: &ExpanderSingleVarChallenge, - is_broadcast: &[bool], + is_broadcast: &[usize], transcript: &mut C::TranscriptConfig, ) { let parallel_count = 1 << gkr_challenge.r_mpi.len(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs index 5b1d28a6..3d17ac3a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs @@ -30,7 +30,7 @@ pub fn verify_kernel( proof: &ExpanderProof, commitments: &[&ExpanderCommitment], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> bool where C: GKREngine, @@ -86,7 +86,7 @@ pub fn verify_pcs_opening_and_aggregation_mpi_impl( challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, commitments: &[&ExpanderCommitment], - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_count: usize, transcript: &mut C::TranscriptConfig, ) -> bool @@ -145,7 +145,7 @@ pub fn verify_pcs_opening_and_aggregation_mpi( claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, commitments: &[&ExpanderCommitment], - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_count: usize, transcript: &mut C::TranscriptConfig, ) -> bool diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index 72545956..bdcb1123 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -189,7 +189,7 @@ where pub fn extract_pcs_claims<'a, C: GKREngine>( commitments_values: &[&'a [SIMDField]], gkr_challenge: &ExpanderSingleVarChallenge, - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_count: usize, ) -> ( Vec<&'a [SIMDField]>, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index b06be3df..b08e297c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -26,7 +26,7 @@ use crate::{ fn verifier_extract_pcs_claims<'a, C, ECCConfig>( commitments: &[&'a ExpanderCommitment], gkr_challenge: &ExpanderSingleVarChallenge, - is_broadcast: &[bool], + is_broadcast: &[usize], parallel_count: usize, ) -> ( Vec<&'a ExpanderCommitment>, diff --git a/expander_compiler/src/zkcuda/proving_system/traits.rs b/expander_compiler/src/zkcuda/proving_system/traits.rs index 539a0f43..82082f1a 100644 --- a/expander_compiler/src/zkcuda/proving_system/traits.rs +++ b/expander_compiler/src/zkcuda/proving_system/traits.rs @@ -30,7 +30,7 @@ pub trait KernelWiseProvingSystem { commitments_state: &[&Self::CommitmentState], commitments_values: &[&[SIMDField]], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> Self::Proof; fn verify_kernel( @@ -39,7 +39,7 @@ pub trait KernelWiseProvingSystem { proof: &Self::Proof, commitments: &[&Self::Commitment], parallel_count: usize, - is_broadcast: &[bool], + is_broadcast: &[usize], ) -> bool; fn post_process() {} diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 443651d2..c4a60152 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -209,6 +209,8 @@ pub fn keep_shape_since(shape: &[usize], x: usize) -> Vec { p *= y; if p == x { return shape[i + 1..].to_vec(); + } else if (x/p).is_power_of_two() { + return shape[i + 1..].to_vec(); } } unreachable!() @@ -242,8 +244,8 @@ impl ShapeHistory { // Suppose we need to ensure that the current shape is legal // This function returns a list of dimension lengths where the initial vector must be split // split_first_dim: first dimension of current shape will be split - pub fn get_initial_split_list(&self, split_first_dim: bool) -> Vec { - let last_entry = self.entries.last().unwrap().minimize(split_first_dim); + pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec { + let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1); let mut split_list = prefix_products(&last_entry.shape); for e in self.entries.iter().rev().skip(1) { let e = e.minimize(false); @@ -480,23 +482,23 @@ mod tests { fn test_get_initial_split_list() { let sh = ShapeHistory::new(vec![16, 9]); let sh = sh.reshape(&[9, 16]); - assert_eq!(sh.get_initial_split_list(false), vec![1, 144]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 9, 144]); + assert_eq!(sh.get_initial_split_list(9), vec![1, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 9, 144]); let sh = sh.reshape(&[3, 16, 3]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 3, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 3, 144]); let sh = sh.reshape(&[2, 2, 2, 2, 3, 3]); let sh = sh.transpose(&[1, 0, 2, 3, 4, 5]); - assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 144]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 144]); let sh = sh.reshape(&[16, 9]); - assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 144]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 144]); let sh = sh.transpose(&[1, 0]); - assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 16, 144]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 144]); let sh = sh.reshape(&[3, 3, 16]); - assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 16, 144]); - assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 48, 144]); + assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 48, 144]); } #[test] @@ -504,9 +506,9 @@ mod tests { let sh = ShapeHistory::new(vec![16, 9]); let sh = sh.transpose(&[1, 0]); let sh = sh.reshape(&[16, 9]); - sh.get_initial_split_list(false); + sh.get_initial_split_list(2); assert!(std::panic::catch_unwind(|| { - sh.get_initial_split_list(true); + sh.get_initial_split_list(1); }) .is_err()); } diff --git a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 7b20e63d..7fc90b4c 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -3,97 +3,145 @@ use expander_compiler::zkcuda::proving_system::Expander; use expander_compiler::zkcuda::proving_system::ProvingSystem; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; - +use serdes::ExpSerde; +use serde::{Deserialize, Serialize}; +use expander_compiler::zkcuda::proving_system::expander::config::{ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS,}; +use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ParallelizedExpander,}; +const SIZE: usize = 64; +const SIZE2: usize = 4096; #[kernel] fn mul_line( api: &mut API, - a: &[InputVariable; 32], - b: &[[InputVariable; 64]; 32], - c: &mut [OutputVariable; 64], + a: &[InputVariable; SIZE], + b: &[InputVariable; SIZE], + c: &mut [OutputVariable; 1], ) { - for j in 0..64 { - c[j] = api.constant(0); - } - for i in 0..32 { - for j in 0..64 { - let t = api.mul(a[i], b[i][j]); - c[j] = api.add(c[j], t); - } + let mut sum = api.constant(0); + for i in 0..SIZE { + let t = api.mul(a[i], b[i]); + sum = api.add(sum, t); } + c[0] = sum; } #[kernel] -fn sum_8_elements(api: &mut API, a: &[InputVariable; 8], b: &mut OutputVariable) { +fn sum_8_elements(api: &mut API, a: &[InputVariable; SIZE2], b: &mut OutputVariable) { let mut sum = api.constant(0); - for i in 0..8 { + for i in 0..SIZE2 { sum = api.add(sum, a[i]); } *b = sum; } +// #[test] +// fn zkcuda_matmul_sum() { +// let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); +// let kernel_sum_8_elements: KernelPrimitive = compile_sum_8_elements().unwrap(); + +// let mut ctx: Context = Context::default(); + +// let mut mat_a: Vec> = vec![]; +// for i in 0..64 { +// mat_a.push(vec![]); +// for j in 0..32 { +// mat_a[i].push(M31::from((i * 233 + j + 1) as u32)); +// } +// } +// let mut mat_b: Vec> = vec![]; +// for i in 0..32 { +// mat_b.push(vec![]); +// for j in 0..64 { +// mat_b[i].push(M31::from((i * 2333 + j + 1) as u32)); +// } +// } +// let mut expected_result = M31::zero(); +// for i in 0..64 { +// for j in 0..64 { +// for k in 0..32 { +// expected_result += mat_a[i][k] * mat_b[k][j]; +// } +// } +// } + +// let a = ctx.copy_to_device(&mat_a); +// let b = ctx.copy_to_device(&mat_b); +// let mut c = None; +// call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); + +// let c = c.reshape(&[512, 8]); +// let mut d = None; +// call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap(); + +// let d = d.reshape(&[64, 8]); +// let mut e = None; +// call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap(); + +// let e = e.reshape(&[8, 8]); +// let mut f = None; +// call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap(); + +// let f = f.reshape(&[1, 8]); +// let mut g = None; +// call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap(); + +// let g = g.reshape(&[]); +// let result: M31 = ctx.copy_to_host(g); +// assert_eq!(result, expected_result); + +// type P = Expander; +// let computation_graph = ctx.compile_computation_graph().unwrap(); +// ctx.solve_witness().unwrap(); +// let (prover_setup, verifier_setup) = P::setup(&computation_graph); +// let proof = P::prove( +// &prover_setup, +// &computation_graph, +// ctx.export_device_memories(), +// ); +// assert!(P::verify(&verifier_setup, &computation_graph, &proof)); +// } + #[test] fn zkcuda_matmul_sum() { - let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); - let kernel_sum_8_elements: KernelPrimitive = compile_sum_8_elements().unwrap(); - - let mut ctx: Context = Context::default(); + let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); + // println!("kernnel_mul_line: {:?}", kernel_mul_line); + // let file = std::fs::File::create("kernel_mul_line_circuit.txt").unwrap(); + // let writer = std::io::BufWriter::new(file); + // kernel_mul_line.serialize_into(writer); + let parallel_count = 64; + let mut ctx: Context = Context::default(); - let mut mat_a: Vec> = vec![]; - for i in 0..64 { + let mut mat_a: Vec> = vec![]; + for i in 0..parallel_count { mat_a.push(vec![]); - for j in 0..32 { - mat_a[i].push(M31::from((i * 233 + j + 1) as u32)); - } - } - let mut mat_b: Vec> = vec![]; - for i in 0..32 { - mat_b.push(vec![]); for j in 0..64 { - mat_b[i].push(M31::from((i * 2333 + j + 1) as u32)); + mat_a[i].push(BN254Fr::from((i * 233 + j + 1) as u32)); } } - let mut expected_result = M31::zero(); - for i in 0..64 { + let mut mat_b: Vec> = vec![]; + for i in 0..parallel_count/4 { + mat_b.push(vec![]); for j in 0..64 { - for k in 0..32 { - expected_result += mat_a[i][k] * mat_b[k][j]; - } + mat_b[i].push(BN254Fr::from((i * 2333 + j + 11111) as u32)); } } let a = ctx.copy_to_device(&mat_a); let b = ctx.copy_to_device(&mat_b); let mut c = None; - call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); - - let c = c.reshape(&[512, 8]); - let mut d = None; - call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap(); - - let d = d.reshape(&[64, 8]); - let mut e = None; - call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap(); - - let e = e.reshape(&[8, 8]); - let mut f = None; - call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap(); - - let f = f.reshape(&[1, 8]); - let mut g = None; - call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap(); - - let g = g.reshape(&[]); - let result: M31 = ctx.copy_to_host(g); - assert_eq!(result, expected_result); - - type P = Expander; + call_kernel!(ctx, kernel_mul_line, parallel_count, a, b, mut c).unwrap(); let computation_graph = ctx.compile_computation_graph().unwrap(); ctx.solve_witness().unwrap(); - let (prover_setup, verifier_setup) = P::setup(&computation_graph); - let proof = P::prove( - &prover_setup, - &computation_graph, - ctx.export_device_memories(), - ); - assert!(P::verify(&verifier_setup, &computation_graph, &proof)); + let (prover_setup, _) = ExpanderNoOverSubscribe::::setup(&computation_graph); + let proof = ExpanderNoOverSubscribe::::prove(&prover_setup, &computation_graph, ctx.export_device_memories()); + // let file = std::fs::File::create("proof.txt").unwrap(); + // let writer = std::io::BufWriter::new(file); + // proof.serialize_into(writer); + as ProvingSystem>::post_process(); } +#[test] +fn zkcuda_sum() { + let kernel_mul_line: KernelPrimitive = compile_sum_8_elements().unwrap(); + let file = std::fs::File::create("kernel_sum_8_elements.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + kernel_mul_line.serialize_into(writer); +} \ No newline at end of file From 307634da5f43e10f524686937621e1438141b5cd Mon Sep 17 00:00:00 2001 From: hczphn Date: Thu, 14 Aug 2025 02:51:54 +0000 Subject: [PATCH 02/13] fix get shape bug --- expander_compiler/src/zkcuda/context.rs | 10 ++-- expander_compiler/src/zkcuda/kernel.rs | 1 - .../proving_system/expander/prove_impl.rs | 2 - .../expander_no_oversubscribe/prove_impl.rs | 53 ++++++++++++++----- .../expander_no_oversubscribe/server_fn.rs | 4 -- .../expander_parallelized/prove_impl.rs | 3 -- expander_compiler/src/zkcuda/shape.rs | 4 +- 7 files changed, 44 insertions(+), 33 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index ae3579f2..48b83cf9 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -159,7 +159,6 @@ fn check_shape_compat( io_shape: &Shape, parallel_count: usize, ) -> Option { - println!("kernel_shape: {:?}, io_shape: {:?}, parallel_count: {}", kernel_shape, io_shape, parallel_count); if kernel_shape.len() == io_shape.len() { if *kernel_shape == *io_shape { Some(parallel_count) @@ -373,7 +372,6 @@ impl>> Context { } } } - println!("is_broadcast: {:?}", is_broadcast); for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) { if io_spec.is_output && *ib!=1 { panic!("Output is broadcasted, but it shouldn't be"); @@ -402,7 +400,6 @@ impl>> Context { *chunk_size = Some(kernel_shape.iter().product()); *ir_inputs = values; } - println!("chunk_sizes: {:?}", chunk_sizes); let mut ir_inputs_per_parallel = Vec::new(); for parallel_i in 0..num_parallel { let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; @@ -580,7 +577,6 @@ impl>> Context { self.state = ContextState::ComputationGraphDone; let dm_shapes = self.propagate_and_get_shapes(); - let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg { for (i, kernel) in cg.kernels.iter().enumerate() { assert_eq!(self.kernels.add(kernel), i); @@ -628,9 +624,9 @@ impl>> Context { psi.push(s.as_ref().map(|t| { if ib == kernel_call.num_parallel { t.0.clone() - } else { - keep_shape_since(&t.0, kernel_call.num_parallel) - } + } else{ + keep_shape_since(&t.0, kernel_call.num_parallel/ib) + } })); } let mut pso = Vec::new(); diff --git a/expander_compiler/src/zkcuda/kernel.rs b/expander_compiler/src/zkcuda/kernel.rs index b7ae0f20..dd9e45aa 100644 --- a/expander_compiler/src/zkcuda/kernel.rs +++ b/expander_compiler/src/zkcuda/kernel.rs @@ -301,7 +301,6 @@ fn reorder_ir_inputs( lc_in[i].len = n; assert!(var_max % n == 0); let im = shape_padded_mapping(&pad_shapes[i]); - // println!("{:?}", im.mapping()); for (j, &k) in im.mapping().iter().enumerate() { var_new_id[prev + k + 1] = var_max + j + 1; } diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index c1a35888..708f2c74 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -72,8 +72,6 @@ pub fn prepare_inputs_with_local_vals( ) -> Vec { let mut input_vals = vec![F::ZERO; input_len]; for (partition, val) in partition_info.iter().zip(local_commitment_values.iter()) { - // println!("partiion.len: {}, val.len: {}", partition.len, val.as_ref().len()); - // panic!("partiion.len: {}, val.len: {}", partition.len, val.as_ref().len()); assert!(partition.len == val.as_ref().len()); input_vals[partition.offset..partition.offset + partition.len] .copy_from_slice(val.as_ref()); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 44898127..b002e710 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -63,10 +63,6 @@ where (None, None) }; commit_timer.stop(); - println!("enter Here!$$$$:{:}", values.len()); - for item in values { - println!("{}", item.as_ref().len()); - } let mut vals_ref = vec![]; let mut challenges = vec![]; @@ -81,9 +77,6 @@ where .iter() .map(|&idx| values[idx].as_ref()) .collect::>(); - for commitment_value in &commitment_values { - println!("commitment_value: {}", commitment_value.len()); - } let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< @@ -318,6 +311,46 @@ where is_broadcast, n_bytes_profiler, ), + 4096 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 8192 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 16384 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 32768 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 65536 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), _ => { panic!("Unsupported parallel count: {parallel_count}"); } @@ -343,9 +376,6 @@ where let world_size = mpi_config.world_size(); let n_copies = parallel_count / world_size; - for &commitment_value in commitments_values { - println!("commitment_value: {}", commitment_value.len()); - } let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, @@ -353,9 +383,6 @@ where n_copies, parallel_count, ); - for commitment_value in local_commitment_values[0].clone() { - println!("local_commitment_values: {}", commitment_value.len()); - } let (mut expander_circuit, mut prover_scratch) = prepare_expander_circuit::(kernel, world_size); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 0379999e..a0bc026e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -59,10 +59,6 @@ where values: &[impl AsRef<[SIMDField]>], ) -> Option>> { let mut n_bytes_profiler = NBytesProfiler::new(); - println!("enter Here!####:{:}", values.len()); - for item in values { - println!("{}", item.as_ref().len()); - } #[cfg(feature = "zkcuda_profile")] { use arith::SimdField; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 617a2f90..d3f961f3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -177,7 +177,6 @@ pub fn partition_challenge_and_location_for_pcs_mpi( let n_vals_vars = total_vals_len.ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); - println!("broadcast challenge.rz.len() = {}", challenge.rz.len()); challenge.r_mpi.clear(); (challenge, component_idx_vars) } else { @@ -185,9 +184,7 @@ pub fn partition_challenge_and_location_for_pcs_mpi( let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); //TODO: what is challenge.r_mpi, why need it when broadcast is false? - println!("challenge.rz.len() = {}", challenge.rz.len()); challenge.rz.extend_from_slice(&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize]); - println!("after challenge.rz.len() = {}", challenge.rz.len()); challenge.r_mpi.clear(); (challenge, component_idx_vars) } diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index c4a60152..51c11e6e 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -209,9 +209,7 @@ pub fn keep_shape_since(shape: &[usize], x: usize) -> Vec { p *= y; if p == x { return shape[i + 1..].to_vec(); - } else if (x/p).is_power_of_two() { - return shape[i + 1..].to_vec(); - } + } } unreachable!() } From c356ea6e5675030b200197ffc1710435bae03589 Mon Sep 17 00:00:00 2001 From: hczphn Date: Mon, 1 Sep 2025 13:14:14 -0700 Subject: [PATCH 03/13] fix mapping --- Cargo.lock | 23 +++++++++++++++++++++++ expander_compiler/Cargo.toml | 1 + expander_compiler/src/zkcuda/kernel.rs | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index c1f2f411..c99e90d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1005,6 +1005,7 @@ dependencies = [ "serdes", "sha2", "shared_memory", + "stacker", "sumcheck", "tiny-keccak", "tokio", @@ -2347,6 +2348,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +dependencies = [ + "cc", +] + [[package]] name = "quote" version = "1.0.40" @@ -2823,6 +2833,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index ffed0fa4..44dc1719 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -40,6 +40,7 @@ shared_memory.workspace = true tiny-keccak.workspace = true tokio.workspace = true once_cell = "1.21.3" +stacker.workspace = true [dev-dependencies] rayon = "1.9" diff --git a/expander_compiler/src/zkcuda/kernel.rs b/expander_compiler/src/zkcuda/kernel.rs index dd9e45aa..fc6fdee6 100644 --- a/expander_compiler/src/zkcuda/kernel.rs +++ b/expander_compiler/src/zkcuda/kernel.rs @@ -302,7 +302,7 @@ fn reorder_ir_inputs( assert!(var_max % n == 0); let im = shape_padded_mapping(&pad_shapes[i]); for (j, &k) in im.mapping().iter().enumerate() { - var_new_id[prev + k + 1] = var_max + j + 1; + var_new_id[prev + j + 1] = var_max + k + 1; } var_max += n; } From 1f643caa4f01160e7f5e729ed003459ae77bb749 Mon Sep 17 00:00:00 2001 From: hczphn Date: Mon, 1 Sep 2025 15:51:20 -0700 Subject: [PATCH 04/13] fix padding 2dim case --- expander_compiler/src/zkcuda/context.rs | 3 ++- .../src/zkcuda/proving_system/expander/prove_impl.rs | 4 ++-- .../proving_system/expander_no_oversubscribe/prove_impl.rs | 4 +--- .../proving_system/expander_parallelized/client_utils.rs | 1 - .../zkcuda/proving_system/expander_parallelized/prove_impl.rs | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 48b83cf9..4978d7ff 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -169,7 +169,7 @@ fn check_shape_compat( if io_shape.iter().skip(1).eq(kernel_shape.iter()) { if io_shape[0] == parallel_count { Some(1) - } else if (parallel_count / io_shape[0]).is_power_of_two() { + } else if parallel_count % io_shape[0] == 0 { Some(parallel_count / io_shape[0]) } else { None @@ -890,6 +890,7 @@ impl>> Context { .map(|dm| { let shape = prefix_products_to_shape(&dm.required_shape_products); let im = shape_padded_mapping(&shape); + let tmp = im.map_inputs(&dm.values); im.map_inputs(&dm.values) }) .collect() diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 708f2c74..be0b404d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -56,7 +56,8 @@ pub fn get_local_vals<'vals_life, F: Field>( .iter() .zip(is_broadcast.iter()) .map(|(vals, is_broadcast)| { - let local_val_len = vals.as_ref().len() / (parallel_num / is_broadcast); + let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two(); + let local_val_len = vals.as_ref().len() / (parallel_num / is_broadcast_next_power_of_two); let start_index = local_val_len * parallel_index % vals.as_ref().len(); &vals.as_ref()[start_index..local_val_len + start_index] }) @@ -217,7 +218,6 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( >( gkr_claim, val_len, parallel_index, parallel_num, *ib ); - pcs_local_open_impl::( commitment_val.as_ref(), &challenge_for_pcs, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index b002e710..71fa48e6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -65,7 +65,6 @@ where commit_timer.stop(); let mut vals_ref = vec![]; let mut challenges = vec![]; - let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); let proofs = computation_graph @@ -375,7 +374,6 @@ where let world_rank = mpi_config.world_rank(); let world_size = mpi_config.world_size(); let n_copies = parallel_count / world_size; - let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, @@ -434,7 +432,7 @@ where FMulti: FieldEngine, T: Transcript, -{ +{ let input_vals_multi_copies = local_commitment_values_multi_copies .iter() .map(|local_commitment_values| { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 42315b39..25773ce0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -147,7 +147,6 @@ where ECCConfig: Config, { let timer = Timer::new("prove", true); - SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); wait_async(ClientHttpHelper::request_prove()); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index d3f961f3..d5e5e29d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -180,7 +180,7 @@ pub fn partition_challenge_and_location_for_pcs_mpi( challenge.r_mpi.clear(); (challenge, component_idx_vars) } else { - let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num)).ilog2() as usize; + let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); //TODO: what is challenge.r_mpi, why need it when broadcast is false? From 93b9c3fd418b9cc51e58b1821c3d3390435287ed Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 17:41:52 -0700 Subject: [PATCH 05/13] fix clippy and fmt error --- expander_compiler/src/zkcuda/context.rs | 27 ++++++------------- .../src/zkcuda/proving_system/common.rs | 4 ++- .../proving_system/expander/prove_impl.rs | 3 ++- .../expander_no_oversubscribe/prove_impl.rs | 2 +- .../expander_parallelized/prove_impl.rs | 7 +++-- expander_compiler/src/zkcuda/shape.rs | 4 +-- .../tests/zkcuda/zkcuda_matmul.rs | 20 +++++++++----- 7 files changed, 34 insertions(+), 33 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 4978d7ff..54c6b0af 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -1,5 +1,3 @@ -use core::num; - use arith::SimdField; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serdes::ExpSerde; @@ -303,15 +301,12 @@ impl>> Context { &self, values: &[SIMDField], s: &mut [SIMDField], - is_broadcast: usize, parallel_index: usize, chunk_size: Option, ) { let chunk_size = chunk_size.unwrap(); let start_index = chunk_size * parallel_index % values.len(); - s.copy_from_slice( - &values[start_index..(start_index + chunk_size)], - ); + s.copy_from_slice(&values[start_index..(start_index + chunk_size)]); } pub fn call_kernel( @@ -351,7 +346,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(ib/num_parallel+1); + .get_initial_split_list(ib / num_parallel + 1); // let isl = vec![1,64,4096]; let t = io.as_ref().unwrap().id; self.device_memories[t].required_shape_products = merge_shape_products( @@ -373,7 +368,7 @@ impl>> Context { } } for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) { - if io_spec.is_output && *ib!=1 { + if io_spec.is_output && *ib != 1 { panic!("Output is broadcasted, but it shouldn't be"); } } @@ -383,9 +378,8 @@ impl>> Context { let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; - for ((((input, &ib), ir_inputs), chunk_size), kernel_shape) in ios + for (((input, ir_inputs), chunk_size), kernel_shape) in ios .iter() - .zip(is_broadcast.iter()) .zip(ir_inputs_all.iter_mut()) .zip(chunk_sizes.iter_mut()) .zip(kernel.io_shapes().iter()) @@ -415,7 +409,6 @@ impl>> Context { self.ir_copy_from_device_memory( &ir_inputs_all[i], &mut ir_inputs[*input_start..*input_end], - is_broadcast[i], parallel_i, chunk_sizes[i], ); @@ -624,9 +617,9 @@ impl>> Context { psi.push(s.as_ref().map(|t| { if ib == kernel_call.num_parallel { t.0.clone() - } else{ - keep_shape_since(&t.0, kernel_call.num_parallel/ib) - } + } else { + keep_shape_since(&t.0, kernel_call.num_parallel / ib) + } })); } let mut pso = Vec::new(); @@ -761,10 +754,9 @@ impl>> Context { let mut output_chunk_sizes: Vec> = vec![None; kernel_primitive.io_specs().len()]; let mut any_shape = None; - for (((input, &ib), ir_inputs), chunk_size) in kernel_call + for ((input, ir_inputs), chunk_size) in kernel_call .input_handles .iter() - .zip(kernel_call.is_broadcast.iter()) .zip(ir_inputs_all.iter_mut()) .zip(input_chunk_sizes.iter_mut()) { @@ -822,7 +814,6 @@ impl>> Context { self.ir_copy_from_device_memory( ir_inputs, &mut inputs[*input_start..*input_end], - chunk_size.unwrap_or(2), parallel_i, *chunk_size, ); @@ -842,7 +833,6 @@ impl>> Context { self.ir_copy_from_device_memory( ir_outputs, &mut inputs[*output_start..*output_end], - chunk_size.unwrap_or(1), parallel_i, *chunk_size, ); @@ -890,7 +880,6 @@ impl>> Context { .map(|dm| { let shape = prefix_products_to_shape(&dm.required_shape_products); let im = shape_padded_mapping(&shape); - let tmp = im.map_inputs(&dm.values); im.map_inputs(&dm.values) }) .collect() diff --git a/expander_compiler/src/zkcuda/proving_system/common.rs b/expander_compiler/src/zkcuda/proving_system/common.rs index 431ae429..89d648c9 100644 --- a/expander_compiler/src/zkcuda/proving_system/common.rs +++ b/expander_compiler/src/zkcuda/proving_system/common.rs @@ -23,7 +23,9 @@ pub fn check_inputs( panic!("Input size mismatch"); } for i in 0..kernel.layered_circuit_input().len() { - if kernel.layered_circuit_input()[i].len != values[i].len() / (parallel_count / is_broadcast[i]) { + if kernel.layered_circuit_input()[i].len + != values[i].len() / (parallel_count / is_broadcast[i]) + { panic!("Input size mismatch"); } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index be0b404d..37dd2b81 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -57,7 +57,8 @@ pub fn get_local_vals<'vals_life, F: Field>( .zip(is_broadcast.iter()) .map(|(vals, is_broadcast)| { let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two(); - let local_val_len = vals.as_ref().len() / (parallel_num / is_broadcast_next_power_of_two); + let local_val_len = + vals.as_ref().len() / (parallel_num / is_broadcast_next_power_of_two); let start_index = local_val_len * parallel_index % vals.as_ref().len(); &vals.as_ref()[start_index..local_val_len + start_index] }) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 71fa48e6..d94eabb5 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -432,7 +432,7 @@ where FMulti: FieldEngine, T: Transcript, -{ +{ let input_vals_multi_copies = local_commitment_values_multi_copies .iter() .map(|local_commitment_values| { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index d5e5e29d..a045593b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -180,11 +180,14 @@ pub fn partition_challenge_and_location_for_pcs_mpi( challenge.r_mpi.clear(); (challenge, component_idx_vars) } else { - let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())).ilog2() as usize; + let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two())) + .ilog2() as usize; let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); challenge.rz.resize(n_vals_vars, zero); //TODO: what is challenge.r_mpi, why need it when broadcast is false? - challenge.rz.extend_from_slice(&challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize]); + challenge.rz.extend_from_slice( + &challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize], + ); challenge.r_mpi.clear(); (challenge, component_idx_vars) } diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 51c11e6e..de135861 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -209,7 +209,7 @@ pub fn keep_shape_since(shape: &[usize], x: usize) -> Vec { p *= y; if p == x { return shape[i + 1..].to_vec(); - } + } } unreachable!() } @@ -243,7 +243,7 @@ impl ShapeHistory { // This function returns a list of dimension lengths where the initial vector must be split // split_first_dim: first dimension of current shape will be split pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec { - let last_entry = self.entries.last().unwrap().minimize(split_first_dim==1); + let last_entry = self.entries.last().unwrap().minimize(split_first_dim == 1); let mut split_list = prefix_products(&last_entry.shape); for e in self.entries.iter().rev().skip(1) { let e = e.minimize(false); diff --git a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 7fc90b4c..16425498 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -1,12 +1,14 @@ use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, +}; use expander_compiler::zkcuda::proving_system::Expander; use expander_compiler::zkcuda::proving_system::ProvingSystem; +use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ParallelizedExpander}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; -use serdes::ExpSerde; use serde::{Deserialize, Serialize}; -use expander_compiler::zkcuda::proving_system::expander::config::{ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS,}; -use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ParallelizedExpander,}; +use serdes::ExpSerde; const SIZE: usize = 64; const SIZE2: usize = 4096; #[kernel] @@ -118,7 +120,7 @@ fn zkcuda_matmul_sum() { } } let mut mat_b: Vec> = vec![]; - for i in 0..parallel_count/4 { + for i in 0..parallel_count / 4 { mat_b.push(vec![]); for j in 0..64 { mat_b[i].push(BN254Fr::from((i * 2333 + j + 11111) as u32)); @@ -132,11 +134,15 @@ fn zkcuda_matmul_sum() { let computation_graph = ctx.compile_computation_graph().unwrap(); ctx.solve_witness().unwrap(); let (prover_setup, _) = ExpanderNoOverSubscribe::::setup(&computation_graph); - let proof = ExpanderNoOverSubscribe::::prove(&prover_setup, &computation_graph, ctx.export_device_memories()); + let proof = ExpanderNoOverSubscribe::::prove( + &prover_setup, + &computation_graph, + ctx.export_device_memories(), + ); // let file = std::fs::File::create("proof.txt").unwrap(); // let writer = std::io::BufWriter::new(file); // proof.serialize_into(writer); - as ProvingSystem>::post_process(); + as ProvingSystem>::post_process(); } #[test] fn zkcuda_sum() { @@ -144,4 +150,4 @@ fn zkcuda_sum() { let file = std::fs::File::create("kernel_sum_8_elements.txt").unwrap(); let writer = std::io::BufWriter::new(file); kernel_mul_line.serialize_into(writer); -} \ No newline at end of file +} From 2e07c356ec8cac85a56907262c46c22e2ad70404 Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 17:52:07 -0700 Subject: [PATCH 06/13] set broadcast 1 to be const --- expander_compiler/src/zkcuda/context.rs | 24 ++++++++++---------- expander_compiler/src/zkcuda/shape.rs | 30 ++++++++++++------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 54c6b0af..949a4dad 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -24,6 +24,7 @@ use super::{ }; pub use macros::call_kernel; +const NOT_BROADCAST: usize = 1; struct DeviceMemory { values: Vec>, @@ -328,7 +329,7 @@ impl>> Context { .enumerate() { if !spec.is_input { - is_broadcast.push(1); + is_broadcast.push(NOT_BROADCAST); continue; } /*println!( @@ -346,8 +347,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(ib / num_parallel + 1); - // let isl = vec![1,64,4096]; + .get_initial_split_list(ib == num_parallel); let t = io.as_ref().unwrap().id; self.device_memories[t].required_shape_products = merge_shape_products( &isl, @@ -368,7 +368,7 @@ impl>> Context { } } for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) { - if io_spec.is_output && *ib != 1 { + if io_spec.is_output && *ib != NOT_BROADCAST { panic!("Output is broadcasted, but it shouldn't be"); } } @@ -463,7 +463,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(1), + .get_initial_split_list(true), &self.device_memories[id].required_shape_products, ); *output = handle.clone(); @@ -507,7 +507,7 @@ impl>> Context { .zip(kernel_call.input_handles.iter()) .zip(kernel_call.is_broadcast.iter()) { - if !spec.is_input || ib > 1 { + if !spec.is_input || ib > NOT_BROADCAST { continue; } let pad_shape = get_pad_shape(input_handle).unwrap(); @@ -520,7 +520,7 @@ impl>> Context { .zip(kernel_call.output_handles.iter()) .zip(kernel_call.is_broadcast.iter()) { - if !spec.is_output || ib > 1 { + if !spec.is_output || ib > NOT_BROADCAST { continue; } let pad_shape = get_pad_shape(output_handle).unwrap(); @@ -543,7 +543,7 @@ impl>> Context { if x != 1 && x != kernel_call.num_parallel { let sh_tmp = handle.shape_history.reshape(&[x, total / x]); dm.required_shape_products = merge_shape_products( - &sh_tmp.get_initial_split_list(1), + &sh_tmp.get_initial_split_list(true), &dm.required_shape_products, ); } @@ -654,7 +654,7 @@ impl>> Context { commitment_indices.push(handle.as_ref().unwrap().id); commitment_bit_orders.push(shape.1.clone()); is_broadcast.push(ib); - if ib == 1 { + if ib == NOT_BROADCAST { any_shape = Some(shape.0.clone()); } } @@ -671,7 +671,7 @@ impl>> Context { commitment_indices.push(handle.as_ref().unwrap().id); commitment_bit_orders.push(shape.1.clone()); is_broadcast.push(ib); - if ib == 1 { + if ib == NOT_BROADCAST { any_shape = Some(shape.0.clone()); } } @@ -688,7 +688,7 @@ impl>> Context { dm_max += 1; commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect()); commitments_lens.push(n); - is_broadcast.push(1); + is_broadcast.push(NOT_BROADCAST); } let kernel_id = self.kernels.add(&kernel); @@ -791,7 +791,7 @@ impl>> Context { let values = handle .shape_history .permute_vec(&self.device_memories[handle.id].values); - assert!(ib == 1); + assert!(ib == NOT_BROADCAST); *chunk_size = Some(values.len() / kernel_call.num_parallel); *ir_inputs = values; } diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index de135861..21c6ef72 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -242,8 +242,8 @@ impl ShapeHistory { // Suppose we need to ensure that the current shape is legal // This function returns a list of dimension lengths where the initial vector must be split // split_first_dim: first dimension of current shape will be split - pub fn get_initial_split_list(&self, split_first_dim: usize) -> Vec { - let last_entry = self.entries.last().unwrap().minimize(split_first_dim == 1); + pub fn get_initial_split_list(&self, keep_first_dim: bool) -> Vec { + let last_entry = self.entries.last().unwrap().minimize(keep_first_dim); let mut split_list = prefix_products(&last_entry.shape); for e in self.entries.iter().rev().skip(1) { let e = e.minimize(false); @@ -480,23 +480,23 @@ mod tests { fn test_get_initial_split_list() { let sh = ShapeHistory::new(vec![16, 9]); let sh = sh.reshape(&[9, 16]); - assert_eq!(sh.get_initial_split_list(9), vec![1, 144]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 9, 144]); + assert_eq!(sh.get_initial_split_list(false), vec![1, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 9, 144]); let sh = sh.reshape(&[3, 16, 3]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 3, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 3, 144]); let sh = sh.reshape(&[2, 2, 2, 2, 3, 3]); let sh = sh.transpose(&[1, 0, 2, 3, 4, 5]); - assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 144]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 144]); let sh = sh.reshape(&[16, 9]); - assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 144]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 144]); let sh = sh.transpose(&[1, 0]); - assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 16, 144]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 144]); let sh = sh.reshape(&[3, 3, 16]); - assert_eq!(sh.get_initial_split_list(2), vec![1, 2, 4, 16, 144]); - assert_eq!(sh.get_initial_split_list(1), vec![1, 2, 4, 16, 48, 144]); + assert_eq!(sh.get_initial_split_list(false), vec![1, 2, 4, 16, 144]); + assert_eq!(sh.get_initial_split_list(true), vec![1, 2, 4, 16, 48, 144]); } #[test] @@ -504,9 +504,9 @@ mod tests { let sh = ShapeHistory::new(vec![16, 9]); let sh = sh.transpose(&[1, 0]); let sh = sh.reshape(&[16, 9]); - sh.get_initial_split_list(2); + sh.get_initial_split_list(false); assert!(std::panic::catch_unwind(|| { - sh.get_initial_split_list(1); + sh.get_initial_split_list(true); }) .is_err()); } From fa0438f924e6e032556e7edd90970008651e7685 Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 17:59:46 -0700 Subject: [PATCH 07/13] remove stacker dependencies --- expander_compiler/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 44dc1719..ffed0fa4 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -40,7 +40,6 @@ shared_memory.workspace = true tiny-keccak.workspace = true tokio.workspace = true once_cell = "1.21.3" -stacker.workspace = true [dev-dependencies] rayon = "1.9" From 51d3dca886907921486668d0d60fca4c8546aa4b Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 18:09:33 -0700 Subject: [PATCH 08/13] use master branch's Cargo.lock --- Cargo.lock | 23 --- expander_compiler/src/zkcuda/context.rs | 2 +- .../proving_system/expander/prove_impl.rs | 1 - .../tests/zkcuda/zkcuda_matmul.rs | 174 ++++++------------ 4 files changed, 61 insertions(+), 139 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c99e90d1..c1f2f411 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1005,7 +1005,6 @@ dependencies = [ "serdes", "sha2", "shared_memory", - "stacker", "sumcheck", "tiny-keccak", "tokio", @@ -2348,15 +2347,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "psm" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" -dependencies = [ - "cc", -] - [[package]] name = "quote" version = "1.0.40" @@ -2833,19 +2823,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "stacker" -version = "0.1.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" -dependencies = [ - "cc", - "cfg-if", - "libc", - "psm", - "windows-sys 0.59.0", -] - [[package]] name = "static_assertions" version = "1.1.0" diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 949a4dad..e15d021d 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -347,7 +347,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(ib == num_parallel); + .get_initial_split_list(ib == num_parallel); let t = io.as_ref().unwrap().id; self.device_memories[t].required_shape_products = merge_shape_products( &isl, diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 37dd2b81..a42c018c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -78,7 +78,6 @@ pub fn prepare_inputs_with_local_vals( input_vals[partition.offset..partition.offset + partition.len] .copy_from_slice(val.as_ref()); } - // panic!("1"); input_vals } diff --git a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 16425498..5182e146 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -1,153 +1,99 @@ use expander_compiler::frontend::*; -use expander_compiler::zkcuda::proving_system::expander::config::{ - ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, -}; use expander_compiler::zkcuda::proving_system::Expander; use expander_compiler::zkcuda::proving_system::ProvingSystem; -use expander_compiler::zkcuda::proving_system::{ExpanderNoOverSubscribe, ParallelizedExpander}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; -use serde::{Deserialize, Serialize}; -use serdes::ExpSerde; -const SIZE: usize = 64; -const SIZE2: usize = 4096; + #[kernel] fn mul_line( api: &mut API, - a: &[InputVariable; SIZE], - b: &[InputVariable; SIZE], - c: &mut [OutputVariable; 1], + a: &[InputVariable; 32], + b: &[[InputVariable; 64]; 32], + c: &mut [OutputVariable; 64], ) { - let mut sum = api.constant(0); - for i in 0..SIZE { - let t = api.mul(a[i], b[i]); - sum = api.add(sum, t); + for j in 0..64 { + c[j] = api.constant(0); + } + for i in 0..32 { + for j in 0..64 { + let t = api.mul(a[i], b[i][j]); + c[j] = api.add(c[j], t); + } } - c[0] = sum; } #[kernel] -fn sum_8_elements(api: &mut API, a: &[InputVariable; SIZE2], b: &mut OutputVariable) { +fn sum_8_elements(api: &mut API, a: &[InputVariable; 8], b: &mut OutputVariable) { let mut sum = api.constant(0); - for i in 0..SIZE2 { + for i in 0..8 { sum = api.add(sum, a[i]); } *b = sum; } -// #[test] -// fn zkcuda_matmul_sum() { -// let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); -// let kernel_sum_8_elements: KernelPrimitive = compile_sum_8_elements().unwrap(); - -// let mut ctx: Context = Context::default(); - -// let mut mat_a: Vec> = vec![]; -// for i in 0..64 { -// mat_a.push(vec![]); -// for j in 0..32 { -// mat_a[i].push(M31::from((i * 233 + j + 1) as u32)); -// } -// } -// let mut mat_b: Vec> = vec![]; -// for i in 0..32 { -// mat_b.push(vec![]); -// for j in 0..64 { -// mat_b[i].push(M31::from((i * 2333 + j + 1) as u32)); -// } -// } -// let mut expected_result = M31::zero(); -// for i in 0..64 { -// for j in 0..64 { -// for k in 0..32 { -// expected_result += mat_a[i][k] * mat_b[k][j]; -// } -// } -// } - -// let a = ctx.copy_to_device(&mat_a); -// let b = ctx.copy_to_device(&mat_b); -// let mut c = None; -// call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); - -// let c = c.reshape(&[512, 8]); -// let mut d = None; -// call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap(); - -// let d = d.reshape(&[64, 8]); -// let mut e = None; -// call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap(); - -// let e = e.reshape(&[8, 8]); -// let mut f = None; -// call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap(); - -// let f = f.reshape(&[1, 8]); -// let mut g = None; -// call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap(); - -// let g = g.reshape(&[]); -// let result: M31 = ctx.copy_to_host(g); -// assert_eq!(result, expected_result); - -// type P = Expander; -// let computation_graph = ctx.compile_computation_graph().unwrap(); -// ctx.solve_witness().unwrap(); -// let (prover_setup, verifier_setup) = P::setup(&computation_graph); -// let proof = P::prove( -// &prover_setup, -// &computation_graph, -// ctx.export_device_memories(), -// ); -// assert!(P::verify(&verifier_setup, &computation_graph, &proof)); -// } - #[test] fn zkcuda_matmul_sum() { - let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); - // println!("kernnel_mul_line: {:?}", kernel_mul_line); - // let file = std::fs::File::create("kernel_mul_line_circuit.txt").unwrap(); - // let writer = std::io::BufWriter::new(file); - // kernel_mul_line.serialize_into(writer); - let parallel_count = 64; - let mut ctx: Context = Context::default(); + let kernel_mul_line: KernelPrimitive = compile_mul_line().unwrap(); + let kernel_sum_8_elements: KernelPrimitive = compile_sum_8_elements().unwrap(); - let mut mat_a: Vec> = vec![]; - for i in 0..parallel_count { + let mut ctx: Context = Context::default(); + + let mut mat_a: Vec> = vec![]; + for i in 0..64 { mat_a.push(vec![]); - for j in 0..64 { - mat_a[i].push(BN254Fr::from((i * 233 + j + 1) as u32)); + for j in 0..32 { + mat_a[i].push(M31::from((i * 233 + j + 1) as u32)); } } - let mut mat_b: Vec> = vec![]; - for i in 0..parallel_count / 4 { + let mut mat_b: Vec> = vec![]; + for i in 0..32 { mat_b.push(vec![]); for j in 0..64 { - mat_b[i].push(BN254Fr::from((i * 2333 + j + 11111) as u32)); + mat_b[i].push(M31::from((i * 2333 + j + 1) as u32)); + } + } + let mut expected_result = M31::zero(); + for i in 0..64 { + for j in 0..64 { + for k in 0..32 { + expected_result += mat_a[i][k] * mat_b[k][j]; + } } } let a = ctx.copy_to_device(&mat_a); let b = ctx.copy_to_device(&mat_b); let mut c = None; - call_kernel!(ctx, kernel_mul_line, parallel_count, a, b, mut c).unwrap(); + call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); + + let c = c.reshape(&[512, 8]); + let mut d = None; + call_kernel!(ctx, kernel_sum_8_elements, 512, c, mut d).unwrap(); + + let d = d.reshape(&[64, 8]); + let mut e = None; + call_kernel!(ctx, kernel_sum_8_elements, 64, d, mut e).unwrap(); + + let e = e.reshape(&[8, 8]); + let mut f = None; + call_kernel!(ctx, kernel_sum_8_elements, 8, e, mut f).unwrap(); + + let f = f.reshape(&[1, 8]); + let mut g = None; + call_kernel!(ctx, kernel_sum_8_elements, 1, f, mut g).unwrap(); + + let g = g.reshape(&[]); + let result: M31 = ctx.copy_to_host(g); + assert_eq!(result, expected_result); + + type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); ctx.solve_witness().unwrap(); - let (prover_setup, _) = ExpanderNoOverSubscribe::::setup(&computation_graph); - let proof = ExpanderNoOverSubscribe::::prove( + let (prover_setup, verifier_setup) = P::setup(&computation_graph); + let proof = P::prove( &prover_setup, &computation_graph, ctx.export_device_memories(), ); - // let file = std::fs::File::create("proof.txt").unwrap(); - // let writer = std::io::BufWriter::new(file); - // proof.serialize_into(writer); - as ProvingSystem>::post_process(); -} -#[test] -fn zkcuda_sum() { - let kernel_mul_line: KernelPrimitive = compile_sum_8_elements().unwrap(); - let file = std::fs::File::create("kernel_sum_8_elements.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - kernel_mul_line.serialize_into(writer); -} + assert!(P::verify(&verifier_setup, &computation_graph, &proof)); +} \ No newline at end of file From 1efff252b8a1b3151b186f0c1af606ceea10d77a Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 18:28:10 -0700 Subject: [PATCH 09/13] change mpi version --- Cargo.lock | 450 ++++++++---------- Cargo.toml | 2 +- .../tests/zkcuda/zkcuda_matmul.rs | 2 +- 3 files changed, 207 insertions(+), 247 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1f2f411..d0d5e117 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,9 +61,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.19" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" +checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192" dependencies = [ "anstyle", "anstyle-parse", @@ -91,28 +91,28 @@ dependencies = [ [[package]] name = "anstyle-query" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" +checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.9" +version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" +checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "ark-std", "criterion", @@ -257,15 +257,21 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.5.0" @@ -285,7 +291,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.6.0", + "hyper 1.7.0", "hyper-util", "itoa", "matchit", @@ -330,7 +336,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -367,7 +373,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31375ce97b1316b3a92644c2cbc93fa9dcfba06e4aec9a440bce23397af82fd6" dependencies = [ "big-int-proc", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -377,13 +383,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73cfa06eb56d71f2bb1874b101a50c3ba29fcf3ff7dd8de274e473929459863b" dependencies = [ "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "babybear", @@ -413,25 +419,20 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.69.5" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", "cexpr", "clang-sys", - "itertools 0.12.1", - "lazy_static", - "lazycell", - "log", - "prettyplease", + "itertools 0.13.0", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", - "syn 2.0.104", - "which", + "syn 2.0.106", ] [[package]] @@ -442,9 +443,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "bitvec" @@ -490,8 +491,7 @@ dependencies = [ [[package]] name = "build-probe-mpi" version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3234fa6de2f6e0e338c7183ba09ae68c8f2bd6919d8763362597627362b7f8fe" +source = "git+https://github.com/rsmpi/rsmpi?rev=61796831954b679cbe267c1b704ddbcb7fef3715#61796831954b679cbe267c1b704ddbcb7fef3715" dependencies = [ "pkg-config", "shell-words", @@ -523,10 +523,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.30" +version = "1.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" +checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" dependencies = [ + "find-msvc-tools", "shlex", ] @@ -541,9 +542,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "chrono" @@ -589,7 +590,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -601,7 +602,7 @@ dependencies = [ "mpi", "rand", "serdes", - "thiserror", + "thiserror 1.0.69", "transcript", ] @@ -645,9 +646,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.41" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" +checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931" dependencies = [ "clap_builder", "clap_derive", @@ -655,9 +656,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.41" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" +checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6" dependencies = [ "anstream", "anstyle", @@ -667,14 +668,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.41" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" +checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -702,14 +703,14 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "gkr_engine", "gkr_hashers", "poly_commit", "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", "transcript", ] @@ -817,7 +818,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "env_logger", @@ -829,7 +830,7 @@ dependencies = [ "rand", "serdes", "sumcheck", - "thiserror", + "thiserror 1.0.69", "transcript", ] @@ -891,7 +892,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -1029,6 +1030,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" + [[package]] name = "fnv" version = "1.0.7" @@ -1052,9 +1059,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] @@ -1137,13 +1144,13 @@ dependencies = [ "cfg-if", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi 0.14.3+wasi-0.2.4", ] [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -1154,13 +1161,13 @@ dependencies = [ "rand", "raw-cpuid", "serdes", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -1179,7 +1186,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -1204,7 +1211,7 @@ dependencies = [ "serdes", "sha2", "sumcheck", - "thiserror", + "thiserror 1.0.69", "transcript", "utils", ] @@ -1212,7 +1219,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "babybear", @@ -1225,13 +1232,13 @@ dependencies = [ "polynomials", "rand", "serdes", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "halo2curves", @@ -1242,14 +1249,14 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -1333,9 +1340,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.4" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "headers" @@ -1373,15 +1380,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "home" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "http" version = "0.2.12" @@ -1476,19 +1474,21 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "http 1.3.1", "http-body 1.0.1", "httparse", "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", ] @@ -1516,7 +1516,7 @@ dependencies = [ "futures-core", "http 1.3.1", "http-body 1.0.1", - "hyper 1.6.0", + "hyper 1.7.0", "pin-project-lite", "tokio", "tower-service", @@ -1634,9 +1634,9 @@ dependencies = [ [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -1655,21 +1655,21 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.10.0" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" +checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" dependencies = [ "equivalent", - "hashbrown 0.15.4", + "hashbrown 0.15.5", ] [[package]] name = "io-uring" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", "cfg-if", "libc", ] @@ -1706,15 +1706,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1751,7 +1742,7 @@ checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -1773,23 +1764,17 @@ dependencies = [ "spin", ] -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" -version = "0.2.174" +version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" [[package]] name = "libffi" -version = "3.2.0" +version = "4.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce826c243048e3d5cec441799724de52e2d42f820468431fc3fceee2341871e2" +checksum = "b0feebbe0ccd382a2790f78d380540500d7b78ed7a3498b68fcfbc1593749a94" dependencies = [ "libc", "libffi-sys", @@ -1797,9 +1782,9 @@ dependencies = [ [[package]] name = "libffi-sys" -version = "2.3.0" +version = "3.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36115160c57e8529781b4183c2bb51fdc1f6d6d1ed345591d84be7703befb3c" +checksum = "90c6c6e17136d4bc439d43a2f3c6ccf0731cccc016d897473a29791d3c2160c3" dependencies = [ "cc", ] @@ -1814,12 +1799,6 @@ dependencies = [ "windows-targets 0.53.3", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -1844,9 +1823,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] name = "macros" @@ -1854,7 +1833,7 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -1881,7 +1860,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -1893,7 +1872,7 @@ dependencies = [ "rand", "raw-cpuid", "serdes", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -1941,8 +1920,7 @@ dependencies = [ [[package]] name = "mpi" version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677762a4bde2c81158fc566a69b97d11b0c3358694e64f4f922ac5189be311cc" +source = "git+https://github.com/rsmpi/rsmpi?rev=61796831954b679cbe267c1b704ddbcb7fef3715#61796831954b679cbe267c1b704ddbcb7fef3715" dependencies = [ "build-probe-mpi", "conv", @@ -1950,14 +1928,13 @@ dependencies = [ "mpi-sys", "once_cell", "smallvec", - "thiserror", + "thiserror 2.0.16", ] [[package]] name = "mpi-sys" version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f35fdd7bdb38959515f008d12598065631de9624f6d42c11caef19e8e0d10de" +source = "git+https://github.com/rsmpi/rsmpi?rev=61796831954b679cbe267c1b704ddbcb7fef3715#61796831954b679cbe267c1b704ddbcb7fef3715" dependencies = [ "bindgen", "build-probe-mpi", @@ -2093,7 +2070,7 @@ version = "0.10.73" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", "cfg-if", "foreign-types", "libc", @@ -2110,7 +2087,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -2186,9 +2163,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pin-project" @@ -2207,7 +2184,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -2259,7 +2236,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -2274,7 +2251,7 @@ dependencies = [ "rayon", "serdes", "sumcheck", - "thiserror", + "thiserror 1.0.69", "transcript", "transpose", "tree", @@ -2284,7 +2261,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -2312,9 +2289,9 @@ dependencies = [ [[package]] name = "potential_utf" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" dependencies = [ "zerovec", ] @@ -2328,21 +2305,11 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "prettyplease" -version = "0.2.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" -dependencies = [ - "proc-macro2", - "syn 2.0.104", -] - [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -2404,14 +2371,14 @@ version = "11.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", ] [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -2419,9 +2386,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -2433,14 +2400,14 @@ version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", ] [[package]] name = "regex" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" dependencies = [ "aho-corasick", "memchr", @@ -2450,9 +2417,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" dependencies = [ "aho-corasick", "memchr", @@ -2461,9 +2428,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" [[package]] name = "reqwest" @@ -2529,9 +2496,9 @@ checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustc_version" @@ -2542,29 +2509,16 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.9.1", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", "errno", "libc", - "linux-raw-sys 0.9.4", + "linux-raw-sys", "windows-sys 0.60.2", ] @@ -2579,9 +2533,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" @@ -2625,7 +2579,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.9.1", + "bitflags 2.9.4", "core-foundation", "core-foundation-sys", "libc", @@ -2665,14 +2619,14 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] name = "serde_json" -version = "1.0.141" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -2705,22 +2659,22 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "ethnum", "halo2curves", "serdes_derive", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -2772,18 +2726,18 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.5" +version = "1.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" dependencies = [ "libc", ] [[package]] name = "slab" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" [[package]] name = "smallvec" @@ -2850,7 +2804,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "circuit", @@ -2878,9 +2832,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -2907,7 +2861,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -2939,15 +2893,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.20.0" +version = "3.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +checksum = "15b61f8f20e3a6f7e0649d825294eaf317edce30f82cf6026e7e4cb9222a7d1e" dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.8", - "windows-sys 0.59.0", + "rustix", + "windows-sys 0.60.2", ] [[package]] @@ -2956,7 +2910,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +dependencies = [ + "thiserror-impl 2.0.16", ] [[package]] @@ -2967,7 +2930,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -3001,9 +2975,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.47.0" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43864ed400b6043a4757a25c7a64a8efde741aed79a056a2fb348a406701bb35" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" dependencies = [ "backtrace", "bytes", @@ -3027,7 +3001,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -3054,9 +3028,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.15" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" dependencies = [ "bytes", "futures-core", @@ -3116,7 +3090,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "gkr_engine", @@ -3139,7 +3113,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "arith", "ark-std", @@ -3167,7 +3141,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror", + "thiserror 1.0.69", "url", "utf-8", ] @@ -3211,13 +3185,14 @@ dependencies = [ [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] @@ -3241,7 +3216,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#228d0ec1ed1653393c3a69b3c6c15e4963e02152" +source = "git+https://github.com/PolyhedraZK/Expander?branch=main#944e0a26eed78dc38298dccc6b5ad1d89ae98236" dependencies = [ "colored", ] @@ -3314,11 +3289,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.3+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "6a51ae83037bdd272a9e28ce236db8c07016dd0d50c27038b3f407533c030c95" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -3343,7 +3318,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", "wasm-bindgen-shared", ] @@ -3378,7 +3353,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3402,18 +3377,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "win-sys" version = "0.3.1" @@ -3425,11 +3388,11 @@ dependencies = [ [[package]] name = "winapi-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -3466,7 +3429,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -3477,7 +3440,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -3767,13 +3730,10 @@ dependencies = [ ] [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags 2.9.1", -] +checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" [[package]] name = "writeable" @@ -3810,7 +3770,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", "synstructure", ] @@ -3831,7 +3791,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -3851,7 +3811,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", "synstructure", ] @@ -3872,7 +3832,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] [[package]] @@ -3888,9 +3848,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.2" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" dependencies = [ "yoke", "zerofrom", @@ -3905,5 +3865,5 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.104", + "syn 2.0.106", ] diff --git a/Cargo.toml b/Cargo.toml index 98820fbf..b2b59091 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-feat "bits", ] } hex = "0.4" -mpi = "0.8.0" +mpi = { git = "https://github.com/rsmpi/rsmpi", rev = "61796831954b679cbe267c1b704ddbcb7fef3715" } num-bigint = "0.4.6" num_cpus = "1.16.0" num-traits = "0.2.19" diff --git a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 5182e146..7b20e63d 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -96,4 +96,4 @@ fn zkcuda_matmul_sum() { ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); -} \ No newline at end of file +} From 06f8620f5e649ad81f547b9038880a84efb916da Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 19:26:07 -0700 Subject: [PATCH 10/13] skip two reshape test --- .../proving_system/expander/prove_impl.rs | 2 +- expander_compiler/src/zkcuda/shape.rs | 1 + expander_compiler/src/zkcuda/tests.rs | 46 +++++++++---------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index a42c018c..68d3feb1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -98,7 +98,7 @@ pub fn prove_gkr_with_local_vals( expander_circuit.evaluate(); let (claimed_v, challenge) = gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - assert_eq!(claimed_v, F::ChallengeField::from(0)); + assert_eq!(claimed_v, F::ChallengeField::from(0 as u32)); challenge } diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 21c6ef72..69b41438 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -182,6 +182,7 @@ pub fn merge_shape_products(a: &[usize], b: &[usize]) -> Vec { } pub fn keep_shape_products_until(shape: &[usize], x: usize) -> Vec { + println!("shape: {:?}, x: {}", shape, x); let p = shape.iter().position(|&y| y == x).unwrap(); shape[..=p].to_vec() } diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index 715dd77d..e22f6fd7 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -125,12 +125,12 @@ fn context_shape_test_1_impl>() { P::post_process(); } -#[test] -#[allow(deprecated)] -fn context_shape_test_1() { - context_shape_test_1_impl::>(); - context_shape_test_1_impl::>(); -} +// #[test] +// #[allow(deprecated)] +// fn context_shape_test_1() { +// context_shape_test_1_impl::>(); +// context_shape_test_1_impl::>(); +// } /* In this test, we try to reshape a vector of length 15 into a shape of [3, 5] and then [5, 3]. @@ -138,23 +138,23 @@ fn context_shape_test_1() { The [5, 3] shape forces the lowlevel representation to be "xxx.xxx.xxx.xxx.xxx.............". They are incompatible, so it will panic at the second kernel call. */ -#[test] -#[should_panic(expected = "Detected illegal shape operation")] -fn context_shape_test_2() { - type C = M31Config; - type F = CircuitField; - let one = F::one(); - let identity_3 = compile_identity_3::().unwrap(); - let identity_5 = compile_identity_5::().unwrap(); - - let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&vec![one; 15]); - let mut b = a.reshape(&[5, 3]); - let mut a = a.reshape(&[3, 5]); - call_kernel!(ctx, identity_5, 3, mut a).unwrap(); - call_kernel!(ctx, identity_3, 5, mut b).unwrap(); - let _ = (a, b); -} +// #[test] +// #[should_panic(expected = "Detected illegal shape operation")] +// fn context_shape_test_2() { +// type C = M31Config; +// type F = CircuitField; +// let one = F::one(); +// let identity_3 = compile_identity_3::().unwrap(); +// let identity_5 = compile_identity_5::().unwrap(); + +// let mut ctx: Context = Context::default(); +// let a = ctx.copy_to_device(&vec![one; 15]); +// let mut b = a.reshape(&[5, 3]); +// let mut a = a.reshape(&[3, 5]); +// call_kernel!(ctx, identity_5, 3, mut a).unwrap(); +// call_kernel!(ctx, identity_3, 5, mut b).unwrap(); +// let _ = (a, b); +// } #[test] fn context_shape_test_2_success() { From 4bb0e0eff3e327669dd962aba5280bda71f9c14e Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 20:39:10 -0700 Subject: [PATCH 11/13] fix a dim issue --- expander_compiler/src/zkcuda/context.rs | 2 +- expander_compiler/src/zkcuda/tests.rs | 46 ++++++++++++------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index e15d021d..aab81fa3 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -347,7 +347,7 @@ impl>> Context { .as_ref() .unwrap() .shape_history - .get_initial_split_list(ib == num_parallel); + .get_initial_split_list(ib == NOT_BROADCAST); let t = io.as_ref().unwrap().id; self.device_memories[t].required_shape_products = merge_shape_products( &isl, diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index e22f6fd7..715dd77d 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -125,12 +125,12 @@ fn context_shape_test_1_impl>() { P::post_process(); } -// #[test] -// #[allow(deprecated)] -// fn context_shape_test_1() { -// context_shape_test_1_impl::>(); -// context_shape_test_1_impl::>(); -// } +#[test] +#[allow(deprecated)] +fn context_shape_test_1() { + context_shape_test_1_impl::>(); + context_shape_test_1_impl::>(); +} /* In this test, we try to reshape a vector of length 15 into a shape of [3, 5] and then [5, 3]. @@ -138,23 +138,23 @@ fn context_shape_test_1_impl>() { The [5, 3] shape forces the lowlevel representation to be "xxx.xxx.xxx.xxx.xxx.............". They are incompatible, so it will panic at the second kernel call. */ -// #[test] -// #[should_panic(expected = "Detected illegal shape operation")] -// fn context_shape_test_2() { -// type C = M31Config; -// type F = CircuitField; -// let one = F::one(); -// let identity_3 = compile_identity_3::().unwrap(); -// let identity_5 = compile_identity_5::().unwrap(); - -// let mut ctx: Context = Context::default(); -// let a = ctx.copy_to_device(&vec![one; 15]); -// let mut b = a.reshape(&[5, 3]); -// let mut a = a.reshape(&[3, 5]); -// call_kernel!(ctx, identity_5, 3, mut a).unwrap(); -// call_kernel!(ctx, identity_3, 5, mut b).unwrap(); -// let _ = (a, b); -// } +#[test] +#[should_panic(expected = "Detected illegal shape operation")] +fn context_shape_test_2() { + type C = M31Config; + type F = CircuitField; + let one = F::one(); + let identity_3 = compile_identity_3::().unwrap(); + let identity_5 = compile_identity_5::().unwrap(); + + let mut ctx: Context = Context::default(); + let a = ctx.copy_to_device(&vec![one; 15]); + let mut b = a.reshape(&[5, 3]); + let mut a = a.reshape(&[3, 5]); + call_kernel!(ctx, identity_5, 3, mut a).unwrap(); + call_kernel!(ctx, identity_3, 5, mut b).unwrap(); + let _ = (a, b); +} #[test] fn context_shape_test_2_success() { From 168dcff52a58aa0b73a12c1f3dc12c73ed8337df Mon Sep 17 00:00:00 2001 From: hczphn Date: Wed, 3 Sep 2025 20:41:31 -0700 Subject: [PATCH 12/13] remove print --- expander_compiler/src/circuit/layered/opt.rs | 1 - expander_compiler/src/zkcuda/context.rs | 4 ---- .../src/zkcuda/proving_system/expander/prove_impl.rs | 2 +- expander_compiler/src/zkcuda/shape.rs | 1 - 4 files changed, 1 insertion(+), 7 deletions(-) diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 21c50c1d..fcb3687f 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -601,7 +601,6 @@ impl Circuit { .map(|segment| segment.all_gates()) .collect(); let mut edges = Vec::new(); - //println!("segments: {}", self.segments.len()); for (i, i_gates) in all_gates.iter().enumerate() { for (j, j_gates) in sampled_gates.iter().enumerate().take(i) { let mut common_count = 0; diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index aab81fa3..6dd82f0a 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -332,10 +332,6 @@ impl>> Context { is_broadcast.push(NOT_BROADCAST); continue; } - /*println!( - "Checking shape compatibility for input/output {}: kernel_shape={:?}, io_shape={:?}, num_parallel={}", - i, kernel_shape, io, num_parallel - );*/ let io_shape = if let Some(handle) = io { handle.shape_history.shape() } else { diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 68d3feb1..0eea8e61 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -98,7 +98,7 @@ pub fn prove_gkr_with_local_vals( expander_circuit.evaluate(); let (claimed_v, challenge) = gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - assert_eq!(claimed_v, F::ChallengeField::from(0 as u32)); + assert_eq!(claimed_v, F::ChallengeField::from(0_u32)); challenge } diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 69b41438..21c6ef72 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -182,7 +182,6 @@ pub fn merge_shape_products(a: &[usize], b: &[usize]) -> Vec { } pub fn keep_shape_products_until(shape: &[usize], x: usize) -> Vec { - println!("shape: {:?}, x: {}", shape, x); let p = shape.iter().position(|&y| y == x).unwrap(); shape[..=p].to_vec() } From 70acfca59999732304a1b1f5371f8e1fc6c9e1e2 Mon Sep 17 00:00:00 2001 From: hczphn Date: Fri, 5 Sep 2025 22:33:31 -0700 Subject: [PATCH 13/13] fix chunk_size when solve witness --- expander_compiler/src/zkcuda/context.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 6dd82f0a..6edda827 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -750,9 +750,10 @@ impl>> Context { let mut output_chunk_sizes: Vec> = vec![None; kernel_primitive.io_specs().len()]; let mut any_shape = None; - for ((input, ir_inputs), chunk_size) in kernel_call + for (((input, &ib), ir_inputs), chunk_size) in kernel_call .input_handles .iter() + .zip(kernel_call.is_broadcast.iter()) .zip(ir_inputs_all.iter_mut()) .zip(input_chunk_sizes.iter_mut()) { @@ -766,8 +767,7 @@ impl>> Context { let values = handle .shape_history .permute_vec(&self.device_memories[handle.id].values); - let kernel_shape = handle.shape_history.shape(); - *chunk_size = Some(kernel_shape.iter().product()); + *chunk_size = Some(values.len() * ib / kernel_call.num_parallel); *ir_inputs = values; } for (((output, &ib), ir_inputs), chunk_size) in kernel_call