diff --git a/Cargo.lock b/Cargo.lock
index 8baf7066..ca286138 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3853,4 +3853,4 @@ dependencies = [
  "proc-macro2",
  "quote",
  "syn 2.0.104",
-]
+]
\ No newline at end of file
diff --git a/Cargo.toml b/Cargo.toml
index b2b59091..2c667106 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -65,4 +65,4 @@ sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
 serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "serdes" }
 gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
 gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
-expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" }
+expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" }
\ No newline at end of file
diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs
index 21c50c1d..fcb3687f 100644
--- a/expander_compiler/src/circuit/layered/opt.rs
+++ b/expander_compiler/src/circuit/layered/opt.rs
@@ -601,7 +601,6 @@ impl Circuit {
             .map(|segment| segment.all_gates())
             .collect();
         let mut edges = Vec::new();
-        //println!("segments: {}", self.segments.len());
         for (i, i_gates) in all_gates.iter().enumerate() {
             for (j, j_gates) in sampled_gates.iter().enumerate().take(i) {
                 let mut common_count = 0;
diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs
index 327b3e2a..6edda827 100644
--- a/expander_compiler/src/zkcuda/context.rs
+++ b/expander_compiler/src/zkcuda/context.rs
@@ -24,6 +24,7 @@ use super::{
 };
 
 pub use macros::call_kernel;
+const NOT_BROADCAST: usize = 1;
 
 struct DeviceMemory {
     values: Vec>,
@@ -44,7 +45,7 @@ pub struct KernelCall {
     num_parallel: usize,
     input_handles: Vec,
     output_handles: Vec,
-    is_broadcast: Vec<bool>,
+    is_broadcast: Vec<usize>,
 }
 
 #[derive(PartialEq, Eq, Clone, Debug, ExpSerde)]
@@ -53,7 +54,7 @@ pub struct ProofTemplate {
     pub commitment_indices: Vec,
     pub commitment_bit_orders: Vec,
     pub parallel_count: usize,
-    pub is_broadcast: Vec<bool>,
+    pub is_broadcast: Vec<usize>,
 }
 
 impl ProofTemplate {
@@ -69,7 +70,7 @@ impl ProofTemplate {
     pub fn parallel_count(&self) -> usize {
         self.parallel_count
     }
-    pub fn is_broadcast(&self) -> &[bool] {
+    pub fn is_broadcast(&self) -> &[usize] {
         &self.is_broadcast
     }
 }
@@ -156,17 +157,19 @@ fn check_shape_compat(
     kernel_shape: &Shape,
     io_shape: &Shape,
     parallel_count: usize,
-) -> Option<bool> {
+) -> Option<usize> {
     if kernel_shape.len() == io_shape.len() {
         if *kernel_shape == *io_shape {
-            Some(true)
+            Some(parallel_count)
         } else {
             None
         }
     } else if kernel_shape.len() + 1 == io_shape.len() {
         if io_shape.iter().skip(1).eq(kernel_shape.iter()) {
             if io_shape[0] == parallel_count {
-                Some(false)
+                Some(1)
+            } else if parallel_count % io_shape[0] == 0 {
+                Some(parallel_count / io_shape[0])
             } else {
                 None
             }
@@ -299,18 +302,12 @@ impl>> Context {
         &self,
         values: &[SIMDField],
         s: &mut [SIMDField],
-        is_broadcast: bool,
         parallel_index: usize,
        chunk_size: Option,
     ) {
-        if is_broadcast {
-            s.copy_from_slice(values);
-        } else {
-            let chunk_size = chunk_size.unwrap();
-            s.copy_from_slice(
-                &values[chunk_size * parallel_index..chunk_size * (parallel_index + 1)],
-            );
-        }
+        let chunk_size = chunk_size.unwrap();
+        let start_index = chunk_size * parallel_index % values.len();
+        s.copy_from_slice(&values[start_index..(start_index + chunk_size)]);
     }
 
     pub fn call_kernel(
@@ -332,13 +329,9 @@ impl>> Context {
             .enumerate()
         {
             if !spec.is_input {
-                is_broadcast.push(false);
+                is_broadcast.push(NOT_BROADCAST);
                 continue;
             }
-            /*println!(
-                "Checking shape compatibility for input/output {}: kernel_shape={:?}, io_shape={:?}, num_parallel={}",
-                i, kernel_shape, io, num_parallel
-            );*/
             let io_shape = if let Some(handle) = io {
                 handle.shape_history.shape()
             } else {
@@ -350,7 +343,7 @@ impl>> Context {
                 .as_ref()
                 .unwrap()
                 .shape_history
-                .get_initial_split_list(!ib);
+                .get_initial_split_list(ib == NOT_BROADCAST);
             let t = io.as_ref().unwrap().id;
             self.device_memories[t].required_shape_products = merge_shape_products(
                 &isl,
@@ -371,7 +364,7 @@ impl>> Context {
             }
         }
         for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) {
-            if io_spec.is_output && *ib {
+            if io_spec.is_output && *ib != NOT_BROADCAST {
                 panic!("Output is broadcasted, but it shouldn't be");
             }
         }
@@ -381,11 +374,11 @@ impl>> Context {
         let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()];
         let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
         let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()];
-        for (((input, &ib), ir_inputs), chunk_size) in ios
+        for (((input, ir_inputs), chunk_size), kernel_shape) in ios
             .iter()
-            .zip(is_broadcast.iter())
             .zip(ir_inputs_all.iter_mut())
             .zip(chunk_sizes.iter_mut())
+            .zip(kernel.io_shapes().iter())
         {
             if input.is_none() {
                 continue;
             }
@@ -394,9 +387,7 @@ impl>> Context {
             let values = handle
                 .shape_history
                 .permute_vec(&self.device_memories[handle.id].values);
-            if !ib {
-                *chunk_size = Some(values.len() / num_parallel);
-            }
+            *chunk_size = Some(kernel_shape.iter().product());
             *ir_inputs = values;
         }
         let mut ir_inputs_per_parallel = Vec::new();
@@ -414,7 +405,6 @@ impl>> Context {
                 self.ir_copy_from_device_memory(
                     &ir_inputs_all[i],
                     &mut ir_inputs[*input_start..*input_end],
-                    is_broadcast[i],
                     parallel_i,
                     chunk_sizes[i],
                 );
@@ -513,7 +503,7 @@ impl>> Context {
             .zip(kernel_call.input_handles.iter())
             .zip(kernel_call.is_broadcast.iter())
         {
-            if !spec.is_input || ib {
+            if !spec.is_input || ib > NOT_BROADCAST {
                 continue;
             }
             let pad_shape = get_pad_shape(input_handle).unwrap();
@@ -526,7 +516,7 @@ impl>> Context {
             .zip(kernel_call.output_handles.iter())
             .zip(kernel_call.is_broadcast.iter())
         {
-            if !spec.is_output || ib {
+            if !spec.is_output || ib > NOT_BROADCAST {
                 continue;
             }
             let pad_shape = get_pad_shape(output_handle).unwrap();
@@ -576,7 +566,6 @@ impl>> Context {
         self.state = ContextState::ComputationGraphDone;
 
         let dm_shapes = self.propagate_and_get_shapes();
-
         let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
             for (i, kernel) in cg.kernels.iter().enumerate() {
                 assert_eq!(self.kernels.add(kernel), i);
@@ -622,10 +611,10 @@ impl>> Context {
             let mut psi = Vec::new();
             for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
                 psi.push(s.as_ref().map(|t| {
-                    if ib {
+                    if ib == kernel_call.num_parallel {
                         t.0.clone()
                     } else {
-                        keep_shape_since(&t.0, kernel_call.num_parallel)
+                        keep_shape_since(&t.0, kernel_call.num_parallel / ib)
                     }
                 }));
             }
@@ -635,7 +624,7 @@ impl>> Context {
                 .zip(kernel_call.is_broadcast.iter())
             {
                 pso.push(s.as_ref().map(|t| {
-                    if ib {
+                    if ib == kernel_call.num_parallel {
                         t.0.clone()
                     } else {
                         keep_shape_since(&t.0, kernel_call.num_parallel)
@@ -661,7 +650,7 @@ impl>> Context {
                 commitment_indices.push(handle.as_ref().unwrap().id);
                 commitment_bit_orders.push(shape.1.clone());
                 is_broadcast.push(ib);
-                if !ib {
+                if ib == NOT_BROADCAST {
                     any_shape = Some(shape.0.clone());
                 }
             }
@@ -678,7 +667,7 @@ impl>> Context {
                 commitment_indices.push(handle.as_ref().unwrap().id);
                 commitment_bit_orders.push(shape.1.clone());
                 is_broadcast.push(ib);
-                if !ib {
+                if ib == NOT_BROADCAST {
                     any_shape = Some(shape.0.clone());
                 }
             }
@@ -695,7 +684,7 @@ impl>> Context {
                 dm_max += 1;
                 commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect());
                 commitments_lens.push(n);
-                is_broadcast.push(false);
+                is_broadcast.push(NOT_BROADCAST);
             }
 
             let kernel_id = self.kernels.add(&kernel);
@@ -778,9 +767,7 @@ impl>> Context {
             let values = handle
                 .shape_history
                 .permute_vec(&self.device_memories[handle.id].values);
-            if !ib {
-                *chunk_size = Some(values.len() / kernel_call.num_parallel);
-            }
+            *chunk_size = Some(values.len() * ib / kernel_call.num_parallel);
             *ir_inputs = values;
         }
         for (((output, &ib), ir_inputs), chunk_size) in kernel_call
@@ -800,7 +787,7 @@ impl>> Context {
             let values = handle
                 .shape_history
                 .permute_vec(&self.device_memories[handle.id].values);
-            assert!(!ib);
+            assert!(ib == NOT_BROADCAST);
             *chunk_size = Some(values.len() / kernel_call.num_parallel);
             *ir_inputs = values;
         }
@@ -823,7 +810,6 @@ impl>> Context {
                 self.ir_copy_from_device_memory(
                     ir_inputs,
                     &mut inputs[*input_start..*input_end],
-                    chunk_size.is_none(),
                     parallel_i,
                     *chunk_size,
                 );
@@ -843,7 +829,6 @@ impl>> Context {
                 self.ir_copy_from_device_memory(
                     ir_outputs,
                     &mut inputs[*output_start..*output_end],
-                    chunk_size.is_none(),
                     parallel_i,
                     *chunk_size,
                 );
diff --git a/expander_compiler/src/zkcuda/kernel.rs b/expander_compiler/src/zkcuda/kernel.rs
index b7ae0f20..fc6fdee6 100644
--- a/expander_compiler/src/zkcuda/kernel.rs
+++ b/expander_compiler/src/zkcuda/kernel.rs
@@ -301,9 +301,8 @@ fn reorder_ir_inputs(
         lc_in[i].len = n;
         assert!(var_max % n == 0);
         let im = shape_padded_mapping(&pad_shapes[i]);
-        // println!("{:?}", im.mapping());
         for (j, &k) in im.mapping().iter().enumerate() {
-            var_new_id[prev + k + 1] = var_max + j + 1;
+            var_new_id[prev + j + 1] = var_max + k + 1;
         }
         var_max += n;
     }
diff --git a/expander_compiler/src/zkcuda/mpi_mem_share.rs b/expander_compiler/src/zkcuda/mpi_mem_share.rs
index 5c79c7db..dc2042a8 100644
--- a/expander_compiler/src/zkcuda/mpi_mem_share.rs
+++ b/expander_compiler/src/zkcuda/mpi_mem_share.rs
@@ -132,7 +132,7 @@ impl MPISharedMemory for ProofTemplate {
             .map(|_| BitOrder::new_from_memory(ptr))
             .collect();
         let parallel_count = usize::new_from_memory(ptr);
-        let is_broadcast = Vec::<bool>::new_from_memory(ptr);
+        let is_broadcast = Vec::<usize>::new_from_memory(ptr);
 
         ProofTemplate {
             kernel_id,
diff --git a/expander_compiler/src/zkcuda/proving_system/common.rs b/expander_compiler/src/zkcuda/proving_system/common.rs
index b8cd617a..89d648c9 100644
--- a/expander_compiler/src/zkcuda/proving_system/common.rs
+++ b/expander_compiler/src/zkcuda/proving_system/common.rs
@@ -14,7 +14,7 @@ pub fn check_inputs(
     kernel: &Kernel,
     values: &[&[SIMDField]],
     parallel_count: usize,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
 ) {
     if kernel.layered_circuit_input().len() != values.len() {
         panic!("Input size mismatch");
@@ -23,11 +23,9 @@ pub fn check_inputs(
         panic!("Input size mismatch");
     }
     for i in 0..kernel.layered_circuit_input().len() {
-        if is_broadcast[i] {
-            if kernel.layered_circuit_input()[i].len != values[i].len() {
-                panic!("Input size mismatch");
-            }
-        } else if kernel.layered_circuit_input()[i].len * parallel_count != values[i].len() {
+        if kernel.layered_circuit_input()[i].len
+            != values[i].len() / (parallel_count / is_broadcast[i])
+        {
             panic!("Input size mismatch");
         }
     }
@@ -37,24 +35,20 @@ pub fn prepare_inputs(
     layered_circuit: &Circuit,
     partition_info: &[LayeredCircuitInputVec],
     values: &[&[SIMDField]],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
+    parallel_count: usize,
     parallel_index: usize,
 ) -> Vec> {
     let mut lc_input = vec![SIMDField::::zero(); layered_circuit.input_size()];
     for ((input, value), ib) in partition_info.iter().zip(values.iter()).zip(is_broadcast) {
-        if *ib {
-            for (i, x) in value.iter().enumerate() {
-                lc_input[input.offset + i] = *x;
-            }
-        } else {
-            for (i, x) in value
-                .iter()
-                .skip(parallel_index * input.len)
-                .take(input.len)
-                .enumerate()
-            {
-                lc_input[input.offset + i] = *x;
-            }
+        let parallel_index = parallel_index % (parallel_count / ib);
+        for (i, x) in value
+            .iter()
+            .skip(parallel_index * input.len)
+            .take(input.len)
+            .enumerate()
+        {
+            lc_input[input.offset + i] = *x;
         }
     }
     lc_input
diff --git a/expander_compiler/src/zkcuda/proving_system/dummy.rs b/expander_compiler/src/zkcuda/proving_system/dummy.rs
index b18beb6b..64b19c1c 100644
--- a/expander_compiler/src/zkcuda/proving_system/dummy.rs
+++ b/expander_compiler/src/zkcuda/proving_system/dummy.rs
@@ -74,7 +74,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem {
         _commitments_state: &[&Self::CommitmentState],
         commitments_values: &[&[SIMDField]],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> DummyProof {
         check_inputs(kernel, commitments_values, parallel_count, is_broadcast);
         let mut res = vec![];
@@ -84,6 +84,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem {
                 kernel.layered_circuit_input(),
                 commitments_values,
                 is_broadcast,
+                parallel_count,
                 i,
             );
             let (_, cond) = kernel
@@ -104,7 +105,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem {
         proof: &Self::Proof,
         commitments: &[&Self::Commitment],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> bool {
         let values = commitments.iter().map(|c| &c.vals[..]).collect::>();
         check_inputs(kernel, &values, parallel_count, is_broadcast);
@@ -114,6 +115,7 @@ impl KernelWiseProvingSystem for DummyProvingSystem {
                 kernel.layered_circuit_input(),
                 &values,
                 is_broadcast,
+                parallel_count,
                 i,
             );
             let (_, cond) = kernel
diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs
index 836201ac..aecd2a0c 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs
@@ -61,7 +61,7 @@ where
         _commitments_state: &[&Self::CommitmentState],
         commitments_values: &[&[SIMDField]],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> Self::Proof {
         let timer = Timer::new("prove", true);
         check_inputs(kernel, commitments_values, parallel_count, is_broadcast);
@@ -113,7 +113,7 @@ where
         proof: &Self::Proof,
         commitments: &[&Self::Commitment],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> bool {
         let timer = Timer::new("verify", true);
         let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten();
diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs
index ad7b3121..0eea8e61 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs
@@ -48,7 +48,7 @@ where
 /// This function returns the local values for each parallel instance based on the global values and the broadcast information.
 pub fn get_local_vals<'vals_life, F: Field>(
     global_vals: &'vals_life [impl AsRef<[F]>],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_index: usize,
     parallel_num: usize,
 ) -> Vec<&'vals_life [F]> {
@@ -56,12 +56,11 @@ pub fn get_local_vals<'vals_life, F: Field>(
         .iter()
         .zip(is_broadcast.iter())
         .map(|(vals, is_broadcast)| {
-            if *is_broadcast {
-                vals.as_ref()
-            } else {
-                let local_val_len = vals.as_ref().len() / parallel_num;
-                &vals.as_ref()[local_val_len * parallel_index..local_val_len * (parallel_index + 1)]
-            }
+            let is_broadcast_next_power_of_two = is_broadcast.next_power_of_two();
+            let local_val_len =
+                vals.as_ref().len() / (parallel_num / is_broadcast_next_power_of_two);
+            let start_index = local_val_len * parallel_index % vals.as_ref().len();
+            &vals.as_ref()[start_index..local_val_len + start_index]
         })
         .collect::>()
 }
@@ -99,7 +98,7 @@ pub fn prove_gkr_with_local_vals(
     expander_circuit.evaluate();
     let (claimed_v, challenge) =
         gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config);
-    assert_eq!(claimed_v, F::ChallengeField::from(0));
+    assert_eq!(claimed_v, F::ChallengeField::from(0_u32));
 
     challenge
 }
@@ -133,22 +132,23 @@ pub fn partition_challenge_and_location_for_pcs_no_mpi(
     total_vals_len: usize,
     parallel_index: usize,
     parallel_count: usize,
-    is_broadcast: bool,
+    is_broadcast: usize,
 ) -> (ExpanderSingleVarChallenge, Vec) {
     assert_eq!(gkr_challenge.r_mpi.len(), 0);
 
     let mut challenge = gkr_challenge.clone();
     let zero = F::ChallengeField::ZERO;
-    if is_broadcast {
+    if is_broadcast == parallel_count {
         let n_vals_vars = total_vals_len.ilog2() as usize;
         let component_idx_vars = challenge.rz[n_vals_vars..].to_vec();
         challenge.rz.resize(n_vals_vars, zero);
         (challenge, component_idx_vars)
     } else {
-        let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize;
+        let real_parallel_count = parallel_count / is_broadcast;
+        let n_vals_vars = (total_vals_len / real_parallel_count).ilog2() as usize;
         let component_idx_vars = challenge.rz[n_vals_vars..].to_vec();
         challenge.rz.resize(n_vals_vars, zero);
-        let n_index_vars = parallel_count.ilog2() as usize;
+        let n_index_vars = real_parallel_count.ilog2() as usize;
         let index_vars = (0..n_index_vars)
             .map(|i| F::ChallengeField::from(((parallel_index >> i) & 1) as u32))
             .collect::>();
@@ -206,7 +206,7 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl(
     gkr_claim: &ExpanderSingleVarChallenge,
     global_vals: &[impl AsRef<[::SimdCircuitField]>],
     p_keys: &ExpanderProverSetup,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_index: usize,
     parallel_num: usize,
     transcript: &mut C::TranscriptConfig,
@@ -218,7 +218,6 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl(
         >(
             gkr_claim, val_len, parallel_index, parallel_num, *ib
         );
-
         pcs_local_open_impl::(
             commitment_val.as_ref(),
             &challenge_for_pcs,
@@ -235,7 +234,7 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi(
     gkr_claim: &ExpanderDualVarChallenge,
     global_vals: &[impl AsRef<[::SimdCircuitField]>],
     p_keys: &ExpanderProverSetup,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_index: usize,
     parallel_num: usize,
     transcript: &mut C::TranscriptConfig,
diff --git a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs
index d5c40f2a..14abf848 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs
@@ -76,7 +76,7 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi_impl(
     challenge: &ExpanderSingleVarChallenge,
     y: &::ChallengeField,
     commitments: &[&ExpanderCommitment],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_index: usize,
     parallel_count: usize,
     transcript: &mut C::TranscriptConfig,
@@ -144,7 +144,7 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi(
     claim_v0: ::ChallengeField,
     claim_v1: Option<::ChallengeField>,
     commitments: &[&ExpanderCommitment],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_index: usize,
     parallel_count: usize,
     transcript: &mut C::TranscriptConfig,
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
index bc980372..d94eabb5 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs
@@ -63,10 +63,8 @@ where
         (None, None)
     };
     commit_timer.stop();
 
-    let mut vals_ref = vec![];
     let mut challenges = vec![];
-
     let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root());
 
     let proofs = computation_graph
@@ -78,7 +76,6 @@ where
                 .iter()
                 .map(|&idx| values[idx].as_ref())
                 .collect::>();
-
             let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root());
 
             let gkr_end_state = prove_kernel_gkr_no_oversubscribe::<
@@ -200,7 +197,7 @@ pub fn prove_kernel_gkr_no_oversubscribe(
     kernel: &Kernel,
     commitments_values: &[&[F::SimdCircuitField]],
     parallel_count: usize,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     n_bytes_profiler: &mut NBytesProfiler,
 ) -> Option<(T, ExpanderDualVarChallenge)>
 where
@@ -313,6 +310,46 @@ where
             is_broadcast,
             n_bytes_profiler,
         ),
+        4096 => prove_kernel_gkr_internal::, T, ECCConfig>(
+            &local_mpi_config,
+            kernel,
+            commitments_values,
+            parallel_count,
+            is_broadcast,
+            n_bytes_profiler,
+        ),
+        8192 => prove_kernel_gkr_internal::, T, ECCConfig>(
+            &local_mpi_config,
+            kernel,
+            commitments_values,
+            parallel_count,
+            is_broadcast,
+            n_bytes_profiler,
+        ),
+        16384 => prove_kernel_gkr_internal::, T, ECCConfig>(
+            &local_mpi_config,
+            kernel,
+            commitments_values,
+            parallel_count,
+            is_broadcast,
+            n_bytes_profiler,
+        ),
+        32768 => prove_kernel_gkr_internal::, T, ECCConfig>(
+            &local_mpi_config,
+            kernel,
+            commitments_values,
+            parallel_count,
+            is_broadcast,
+            n_bytes_profiler,
+        ),
+        65536 => prove_kernel_gkr_internal::, T, ECCConfig>(
+            &local_mpi_config,
+            kernel,
+            commitments_values,
+            parallel_count,
+            is_broadcast,
+            n_bytes_profiler,
+        ),
         _ => {
             panic!("Unsupported parallel count: {parallel_count}");
         }
@@ -324,7 +361,7 @@ pub fn prove_kernel_gkr_internal(
     kernel: &Kernel,
     commitments_values: &[&[FBasic::SimdCircuitField]],
     parallel_count: usize,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     n_bytes_profiler: &mut NBytesProfiler,
 ) -> Option<(T, ExpanderDualVarChallenge)>
 where
@@ -337,7 +374,6 @@ where
     let world_rank = mpi_config.world_rank();
     let world_size = mpi_config.world_size();
     let n_copies = parallel_count / world_size;
-
     let local_commitment_values = get_local_vals_multi_copies(
         commitments_values,
         is_broadcast,
@@ -365,7 +401,7 @@ pub fn get_local_vals_multi_copies<'vals_life, F: Field>(
     global_vals: &'vals_life [impl AsRef<[F]>],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     local_world_rank: usize,
     n_copies: usize,
     parallel_count: usize,
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs
index 85c5955f..a0bc026e 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs
@@ -59,7 +59,6 @@ where
     values: &[impl AsRef<[SIMDField]>],
 ) -> Option>> {
     let mut n_bytes_profiler = NBytesProfiler::new();
-
     #[cfg(feature = "zkcuda_profile")]
     {
         use arith::SimdField;
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
index 42315b39..25773ce0 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs
@@ -147,7 +147,6 @@ where
     ECCConfig: Config,
 {
     let timer = Timer::new("prove", true);
-
     SharedMemoryEngine::write_witness_to_shared_memory::(device_memories);
 
     wait_async(ClientHttpHelper::request_prove());
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs
index 5605daf0..a045593b 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs
@@ -127,7 +127,7 @@ pub fn prove_kernel_gkr(
     kernel: &Kernel,
     commitments_values: &[&[F::SimdCircuitField]],
     parallel_count: usize,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
 ) -> Option<(T, ExpanderDualVarChallenge)>
 where
     F: FieldEngine,
@@ -169,22 +169,25 @@ pub fn partition_challenge_and_location_for_pcs_mpi(
     gkr_challenge: &ExpanderSingleVarChallenge,
     total_vals_len: usize,
     parallel_count: usize,
-    is_broadcast: bool,
+    broadcast_num: usize,
 ) -> (ExpanderSingleVarChallenge, Vec) {
     let mut challenge = gkr_challenge.clone();
     let zero = F::ChallengeField::ZERO;
-    if is_broadcast {
+    if broadcast_num == parallel_count {
         let n_vals_vars = total_vals_len.ilog2() as usize;
         let component_idx_vars = challenge.rz[n_vals_vars..].to_vec();
         challenge.rz.resize(n_vals_vars, zero);
         challenge.r_mpi.clear();
         (challenge, component_idx_vars)
     } else {
-        let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize;
+        let n_vals_vars = (total_vals_len / (parallel_count / broadcast_num.next_power_of_two()))
+            .ilog2() as usize;
         let component_idx_vars = challenge.rz[n_vals_vars..].to_vec();
         challenge.rz.resize(n_vals_vars, zero);
-
-        challenge.rz.extend_from_slice(&challenge.r_mpi);
+        // TODO: Clarify why challenge.r_mpi is needed when the input is not fully broadcast.
+        challenge.rz.extend_from_slice(
+            &challenge.r_mpi[..(parallel_count / broadcast_num).ilog2() as usize],
+        );
         challenge.r_mpi.clear();
         (challenge, component_idx_vars)
     }
@@ -196,7 +199,7 @@ pub fn partition_single_gkr_claim_and_open_pcs_mpi(
     commitments_values: &[impl AsRef<[SIMDField]>],
     commitments_state: &[&ExpanderCommitmentState],
     gkr_challenge: &ExpanderSingleVarChallenge,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     transcript: &mut C::TranscriptConfig,
 ) {
     let parallel_count = 1 << gkr_challenge.r_mpi.len();
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs
index 5b1d28a6..3d17ac3a 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs
@@ -30,7 +30,7 @@ pub fn verify_kernel(
     proof: &ExpanderProof,
     commitments: &[&ExpanderCommitment],
     parallel_count: usize,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
 ) -> bool
 where
     C: GKREngine,
@@ -86,7 +86,7 @@ pub fn verify_pcs_opening_and_aggregation_mpi_impl(
     challenge: &ExpanderSingleVarChallenge,
     y: &::ChallengeField,
     commitments: &[&ExpanderCommitment],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_count: usize,
     transcript: &mut C::TranscriptConfig,
 ) -> bool
@@ -145,7 +145,7 @@ pub fn verify_pcs_opening_and_aggregation_mpi(
     claim_v0: ::ChallengeField,
     claim_v1: Option<::ChallengeField>,
     commitments: &[&ExpanderCommitment],
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_count: usize,
     transcript: &mut C::TranscriptConfig,
 ) -> bool
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs
index 72545956..bdcb1123 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs
@@ -189,7 +189,7 @@ where
 pub fn extract_pcs_claims<'a, C: GKREngine>(
     commitments_values: &[&'a [SIMDField]],
     gkr_challenge: &ExpanderSingleVarChallenge,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_count: usize,
 ) -> (
     Vec<&'a [SIMDField]>,
diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs
index b06be3df..b08e297c 100644
--- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs
+++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs
@@ -26,7 +26,7 @@ use crate::{
 fn verifier_extract_pcs_claims<'a, C, ECCConfig>(
     commitments: &[&'a ExpanderCommitment],
     gkr_challenge: &ExpanderSingleVarChallenge,
-    is_broadcast: &[bool],
+    is_broadcast: &[usize],
     parallel_count: usize,
 ) -> (
     Vec<&'a ExpanderCommitment>,
diff --git a/expander_compiler/src/zkcuda/proving_system/traits.rs b/expander_compiler/src/zkcuda/proving_system/traits.rs
index 539a0f43..82082f1a 100644
--- a/expander_compiler/src/zkcuda/proving_system/traits.rs
+++ b/expander_compiler/src/zkcuda/proving_system/traits.rs
@@ -30,7 +30,7 @@ pub trait KernelWiseProvingSystem {
         commitments_state: &[&Self::CommitmentState],
         commitments_values: &[&[SIMDField]],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> Self::Proof;
 
     fn verify_kernel(
@@ -39,7 +39,7 @@ pub trait KernelWiseProvingSystem {
         proof: &Self::Proof,
         commitments: &[&Self::Commitment],
         parallel_count: usize,
-        is_broadcast: &[bool],
+        is_broadcast: &[usize],
     ) -> bool;
 
     fn post_process() {}
diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs
index 443651d2..21c6ef72 100644
--- a/expander_compiler/src/zkcuda/shape.rs
+++ b/expander_compiler/src/zkcuda/shape.rs
@@ -242,8 +242,8 @@ impl ShapeHistory {
     // Suppose we need to ensure that the current shape is legal
     // This function returns a list of dimension lengths where the initial vector must be split
     // split_first_dim: first dimension of current shape will be split
-    pub fn get_initial_split_list(&self, split_first_dim: bool) -> Vec {
-        let last_entry = self.entries.last().unwrap().minimize(split_first_dim);
+    pub fn get_initial_split_list(&self, keep_first_dim: bool) -> Vec {
+        let last_entry = self.entries.last().unwrap().minimize(keep_first_dim);
         let mut split_list = prefix_products(&last_entry.shape);
         for e in self.entries.iter().rev().skip(1) {
             let e = e.minimize(false);
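Editor's note on the new broadcast convention (an illustration added for review, not part of the patch): `is_broadcast` changes from `Vec<bool>` to `Vec<usize>`, and each entry is now a repetition factor rather than a flag. `NOT_BROADCAST` (= 1) means every one of the `parallel_count` instances reads its own chunk; a factor equal to `parallel_count` means all instances share the entire buffer (the old `true`); an intermediate factor `f` dividing `parallel_count` means the buffer holds `parallel_count / f` chunks that are reused cyclically, which is why `check_shape_compat` now returns `Some(parallel_count / io_shape[0])` and `prepare_inputs` reduces the index with `parallel_index % (parallel_count / ib)`. Below is a minimal self-contained sketch of the wrap-around indexing used by `ir_copy_from_device_memory` and `get_local_vals`; the helper name `local_chunk` is hypothetical.

fn local_chunk<T>(
    values: &[T],
    factor: usize,         // the is_broadcast entry for this input
    parallel_count: usize, // kernel_call.num_parallel
    parallel_index: usize, // index of the parallel instance being served
) -> &[T] {
    // With factor f, `values` holds parallel_count / f distinct chunks.
    assert_eq!(parallel_count % factor, 0);
    let num_chunks = parallel_count / factor;
    assert_eq!(values.len() % num_chunks, 0);
    let chunk_size = values.len() / num_chunks;
    // Wrap-around indexing: instances whose indices are congruent modulo
    // parallel_count / f read the same chunk.
    let start = chunk_size * parallel_index % values.len();
    &values[start..start + chunk_size]
}

fn main() {
    let vals: Vec<u32> = (0..8).collect();
    // Factor 1 (NOT_BROADCAST): 4 instances, 4 private chunks of length 2.
    assert_eq!(local_chunk(&vals, 1, 4, 3), &[6, 7]);
    // Factor 4 == parallel_count: fully broadcast, every instance sees all 8 values.
    assert_eq!(local_chunk(&vals, 4, 4, 3), &vals[..]);
    // Factor 2: two chunks of length 4; instances 0 and 2 share chunk 0.
    assert_eq!(local_chunk(&vals, 2, 4, 2), &[0, 1, 2, 3]);
}

Under this reading, `ib > NOT_BROADCAST` in the context.rs hunks means "shared by more than one instance", and `ib == kernel_call.num_parallel` singles out the fully broadcast case.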