2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -65,4 +65,4 @@ sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "serdes" }
gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" }
expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" }
expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" }
1 change: 0 additions & 1 deletion expander_compiler/src/circuit/layered/opt.rs
@@ -601,7 +601,6 @@ impl<C: Config, I: InputType> Circuit<C, I> {
.map(|segment| segment.all_gates())
.collect();
let mut edges = Vec::new();
//println!("segments: {}", self.segments.len());
for (i, i_gates) in all_gates.iter().enumerate() {
for (j, j_gates) in sampled_gates.iter().enumerate().take(i) {
let mut common_count = 0;
71 changes: 28 additions & 43 deletions expander_compiler/src/zkcuda/context.rs
@@ -24,6 +24,7 @@ use super::{
};

pub use macros::call_kernel;
const NOT_BROADCAST: usize = 1;

struct DeviceMemory<C: Config> {
values: Vec<SIMDField<C>>,
@@ -44,7 +45,7 @@ pub struct KernelCall {
num_parallel: usize,
input_handles: Vec<DeviceMemoryHandle>,
output_handles: Vec<DeviceMemoryHandle>,
is_broadcast: Vec<bool>,
is_broadcast: Vec<usize>,
}

#[derive(PartialEq, Eq, Clone, Debug, ExpSerde)]
@@ -53,7 +54,7 @@ pub struct ProofTemplate {
pub commitment_indices: Vec<usize>,
pub commitment_bit_orders: Vec<BitOrder>,
pub parallel_count: usize,
pub is_broadcast: Vec<bool>,
pub is_broadcast: Vec<usize>,
}

impl ProofTemplate {
@@ -69,7 +70,7 @@ impl ProofTemplate {
pub fn parallel_count(&self) -> usize {
self.parallel_count
}
pub fn is_broadcast(&self) -> &[bool] {
pub fn is_broadcast(&self) -> &[usize] {
&self.is_broadcast
}
}
@@ -156,17 +157,19 @@ fn check_shape_compat(
kernel_shape: &Shape,
io_shape: &Shape,
parallel_count: usize,
) -> Option<bool> {
) -> Option<usize> {
if kernel_shape.len() == io_shape.len() {
if *kernel_shape == *io_shape {
Some(true)
Some(parallel_count)
} else {
None
}
} else if kernel_shape.len() + 1 == io_shape.len() {
if io_shape.iter().skip(1).eq(kernel_shape.iter()) {
if io_shape[0] == parallel_count {
Some(false)
Some(1)
} else if parallel_count % io_shape[0] == 0 {
Some(parallel_count / io_shape[0])
} else {
None
}
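Note: this hunk carries the PR's central change, generalizing is_broadcast from a bool to a usize broadcast factor (the number of parallel instances that share each committed chunk). A minimal standalone sketch of the rule shown above, written over plain slices rather than the crate's Shape type (the function name and slice-based signature are illustrative, not the real API):

```rust
// Illustrative sketch only: mirrors the shape-compat rule in the hunk above,
// but over plain slices instead of the crate's Shape type.
fn check_shape_compat_sketch(
    kernel_shape: &[usize],
    io_shape: &[usize],
    parallel_count: usize,
) -> Option<usize> {
    if kernel_shape.len() == io_shape.len() {
        // Shapes match exactly: one buffer shared by every parallel instance.
        (kernel_shape == io_shape).then_some(parallel_count)
    } else if kernel_shape.len() + 1 == io_shape.len() && io_shape[1..] == *kernel_shape {
        if io_shape[0] == parallel_count {
            // One distinct chunk per parallel instance (NOT_BROADCAST == 1).
            Some(1)
        } else if parallel_count % io_shape[0] == 0 {
            // io_shape[0] chunks, each reused by parallel_count / io_shape[0] instances.
            Some(parallel_count / io_shape[0])
        } else {
            None
        }
    } else {
        None
    }
}

// For example: kernel [4], io [4],    8 instances -> Some(8) (full broadcast);
//              kernel [4], io [8, 4], 8 instances -> Some(1) (no broadcast);
//              kernel [4], io [2, 4], 8 instances -> Some(4) (partial broadcast).
```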
@@ -299,18 +302,12 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
&self,
values: &[SIMDField<C>],
s: &mut [SIMDField<C>],
is_broadcast: bool,
parallel_index: usize,
chunk_size: Option<usize>,
) {
if is_broadcast {
s.copy_from_slice(values);
} else {
let chunk_size = chunk_size.unwrap();
s.copy_from_slice(
&values[chunk_size * parallel_index..chunk_size * (parallel_index + 1)],
);
}
let chunk_size = chunk_size.unwrap();
let start_index = chunk_size * parallel_index % values.len();
s.copy_from_slice(&values[start_index..(start_index + chunk_size)]);
}
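With the broadcast factor folded into chunk_size, the copy helper above no longer needs an is_broadcast argument: the wrap-around of the start index covers full, partial, and no broadcast uniformly. A self-contained sketch of that indexing, generic over Copy and purely for illustration:

```rust
// Sketch of the unified chunk copy: when the source holds fewer chunks than
// there are parallel instances, the start index wraps and several instances
// read the same chunk; with a single chunk, every instance reads offset 0.
fn copy_chunk_sketch<T: Copy>(values: &[T], dst: &mut [T], parallel_index: usize, chunk_size: usize) {
    let start = chunk_size * parallel_index % values.len();
    dst.copy_from_slice(&values[start..start + chunk_size]);
}
```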

pub fn call_kernel(
@@ -332,13 +329,9 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.enumerate()
{
if !spec.is_input {
is_broadcast.push(false);
is_broadcast.push(NOT_BROADCAST);
continue;
}
/*println!(
"Checking shape compatibility for input/output {}: kernel_shape={:?}, io_shape={:?}, num_parallel={}",
i, kernel_shape, io, num_parallel
);*/
let io_shape = if let Some(handle) = io {
handle.shape_history.shape()
} else {
@@ -350,7 +343,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.as_ref()
.unwrap()
.shape_history
.get_initial_split_list(!ib);
.get_initial_split_list(ib == NOT_BROADCAST);
let t = io.as_ref().unwrap().id;
self.device_memories[t].required_shape_products = merge_shape_products(
&isl,
@@ -371,7 +364,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}
for (io_spec, ib) in kernel.io_specs().iter().zip(is_broadcast.iter()) {
if io_spec.is_output && *ib {
if io_spec.is_output && *ib != NOT_BROADCAST {
panic!("Output is broadcasted, but it shouldn't be");
}
}
@@ -381,11 +374,11 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()];
let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in ios
for (((input, ir_inputs), chunk_size), kernel_shape) in ios
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
.zip(kernel.io_shapes().iter())
{
if input.is_none() {
continue;
@@ -394,9 +387,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*chunk_size = Some(kernel_shape.iter().product());
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
@@ -414,7 +405,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.ir_copy_from_device_memory(
&ir_inputs_all[i],
&mut ir_inputs[*input_start..*input_end],
is_broadcast[i],
parallel_i,
chunk_sizes[i],
);
@@ -513,7 +503,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.input_handles.iter())
.zip(kernel_call.is_broadcast.iter())
{
if !spec.is_input || ib {
if !spec.is_input || ib > NOT_BROADCAST {
continue;
}
let pad_shape = get_pad_shape(input_handle).unwrap();
@@ -526,7 +516,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.output_handles.iter())
.zip(kernel_call.is_broadcast.iter())
{
if !spec.is_output || ib {
if !spec.is_output || ib > NOT_BROADCAST {
continue;
}
let pad_shape = get_pad_shape(output_handle).unwrap();
@@ -576,7 +566,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.state = ContextState::ComputationGraphDone;

let dm_shapes = self.propagate_and_get_shapes();

let (mut cg_kernels, cg_proof_templates, cg_commitments_lens) = if let Some(cg) = cg {
for (i, kernel) in cg.kernels.iter().enumerate() {
assert_eq!(self.kernels.add(kernel), i);
@@ -622,10 +611,10 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let mut psi = Vec::new();
for (s, &ib) in pad_shapes_input.iter().zip(kernel_call.is_broadcast.iter()) {
psi.push(s.as_ref().map(|t| {
if ib {
if ib == kernel_call.num_parallel {
t.0.clone()
} else {
keep_shape_since(&t.0, kernel_call.num_parallel)
keep_shape_since(&t.0, kernel_call.num_parallel / ib)
}
}));
}
@@ -635,7 +624,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
.zip(kernel_call.is_broadcast.iter())
{
pso.push(s.as_ref().map(|t| {
if ib {
if ib == kernel_call.num_parallel {
t.0.clone()
} else {
keep_shape_since(&t.0, kernel_call.num_parallel)
@@ -661,7 +650,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
commitment_indices.push(handle.as_ref().unwrap().id);
commitment_bit_orders.push(shape.1.clone());
is_broadcast.push(ib);
if !ib {
if ib == NOT_BROADCAST {
any_shape = Some(shape.0.clone());
}
}
@@ -678,7 +667,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
commitment_indices.push(handle.as_ref().unwrap().id);
commitment_bit_orders.push(shape.1.clone());
is_broadcast.push(ib);
if !ib {
if ib == NOT_BROADCAST {
any_shape = Some(shape.0.clone());
}
}
@@ -695,7 +684,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
dm_max += 1;
commitment_bit_orders.push((0..n.trailing_zeros() as usize).collect());
commitments_lens.push(n);
is_broadcast.push(false);
is_broadcast.push(NOT_BROADCAST);
}

let kernel_id = self.kernels.add(&kernel);
@@ -778,9 +767,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / kernel_call.num_parallel);
}
*chunk_size = Some(values.len() * ib / kernel_call.num_parallel);
*ir_inputs = values;
}
for (((output, &ib), ir_inputs), chunk_size) in kernel_call
@@ -800,7 +787,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
assert!(!ib);
assert!(ib == NOT_BROADCAST);
*chunk_size = Some(values.len() / kernel_call.num_parallel);
*ir_inputs = values;
}
@@ -823,7 +810,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.ir_copy_from_device_memory(
ir_inputs,
&mut inputs[*input_start..*input_end],
chunk_size.is_none(),
parallel_i,
*chunk_size,
);
@@ -843,7 +829,6 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
self.ir_copy_from_device_memory(
ir_outputs,
&mut inputs[*output_start..*output_end],
chunk_size.is_none(),
parallel_i,
*chunk_size,
);
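As a quick sanity check on the two chunk_size formulas in this file (kernel_shape.iter().product() in call_kernel, and values.len() * ib / kernel_call.num_parallel in the witness path), the following toy arithmetic, with made-up sizes, shows they agree whenever the device buffer holds num_parallel / ib chunks of the kernel's input size:

```rust
// Made-up sizes, only to check the arithmetic relating the two formulas.
fn chunk_size_consistency_check() {
    let kernel_size = 4usize; // product of the kernel io shape
    let num_parallel = 8usize;
    for ib in [1usize, 2, 8] { // no broadcast, partial broadcast, full broadcast
        let values_len = kernel_size * num_parallel / ib; // size of the committed buffer
        assert_eq!(values_len * ib / num_parallel, kernel_size);
    }
}
```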
3 changes: 1 addition & 2 deletions expander_compiler/src/zkcuda/kernel.rs
@@ -301,9 +301,8 @@ fn reorder_ir_inputs<C: Config>(
lc_in[i].len = n;
assert!(var_max % n == 0);
let im = shape_padded_mapping(&pad_shapes[i]);
// println!("{:?}", im.mapping());
for (j, &k) in im.mapping().iter().enumerate() {
var_new_id[prev + k + 1] = var_max + j + 1;
var_new_id[prev + j + 1] = var_max + k + 1;
Reviewer comment (Contributor), marked critical:

This line appears to have incorrect logic for remapping variable IDs: var_new_id should map each old variable ID to its new one. im.mapping() maps a padded index j to an original index k, so the old variable ID is derived from k and the new ID from j.

The current implementation var_new_id[prev + j + 1] = var_max + k + 1; attempts to use a padded index j to index var_new_id, which is sized for original, unpadded inputs. This will likely lead to an out-of-bounds panic when padding is active (padded_len > original_len).

The logic should be var_new_id[old_id] = new_id, which translates to var_new_id[prev + k + 1] = var_max + j + 1;.

Suggested change
var_new_id[prev + j + 1] = var_max + k + 1;
var_new_id[prev + k + 1] = var_max + j + 1;

}
var_max += n;
}
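To make the reviewer's bounds concern concrete, here is a toy version with hypothetical sizes and a made-up mapping (the real contents of shape_padded_mapping are not shown in this diff): var_new_id is sized by the original input count, so indexing it with the padded index j overruns it as soon as padded_len exceeds original_len, while indexing by the original index k stays in bounds.

```rust
// Hypothetical numbers, purely to illustrate the indexing in the hunk above.
fn remap_bounds_demo() {
    let original_len = 3; // unpadded input variables for this slot
    let prev = 0;
    let var_max = 10;
    // Assumed layout: one entry per padded position j, holding an original
    // position k. The values are invented for the demo.
    let mapping = [0usize, 1, 2, 0]; // padded_len = 4 > original_len = 3
    let mut var_new_id = vec![0usize; original_len + 1]; // indices 0..=original_len

    for (j, &k) in mapping.iter().enumerate() {
        // var_new_id[prev + j + 1] = var_max + k + 1; // j == 3 would index slot 4 of a len-4 Vec: panic
        var_new_id[prev + k + 1] = var_max + j + 1; // suggested form: index by the original id k
    }
}
```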
2 changes: 1 addition & 1 deletion expander_compiler/src/zkcuda/mpi_mem_share.rs
@@ -132,7 +132,7 @@ impl MPISharedMemory for ProofTemplate {
.map(|_| BitOrder::new_from_memory(ptr))
.collect();
let parallel_count = usize::new_from_memory(ptr);
let is_broadcast = Vec::<bool>::new_from_memory(ptr);
let is_broadcast = Vec::<usize>::new_from_memory(ptr);

ProofTemplate {
kernel_id,
34 changes: 14 additions & 20 deletions expander_compiler/src/zkcuda/proving_system/common.rs
@@ -14,7 +14,7 @@ pub fn check_inputs<C: Config>(
kernel: &Kernel<C>,
values: &[&[SIMDField<C>]],
parallel_count: usize,
is_broadcast: &[bool],
is_broadcast: &[usize],
) {
if kernel.layered_circuit_input().len() != values.len() {
panic!("Input size mismatch");
Expand All @@ -23,11 +23,9 @@ pub fn check_inputs<C: Config>(
panic!("Input size mismatch");
}
for i in 0..kernel.layered_circuit_input().len() {
if is_broadcast[i] {
if kernel.layered_circuit_input()[i].len != values[i].len() {
panic!("Input size mismatch");
}
} else if kernel.layered_circuit_input()[i].len * parallel_count != values[i].len() {
if kernel.layered_circuit_input()[i].len
!= values[i].len() / (parallel_count / is_broadcast[i])
{
panic!("Input size mismatch");
}
}
@@ -37,24 +35,20 @@ pub fn prepare_inputs<C: Config>(
layered_circuit: &Circuit<C, NormalInputType>,
partition_info: &[LayeredCircuitInputVec],
values: &[&[SIMDField<C>]],
is_broadcast: &[bool],
is_broadcast: &[usize],
parallel_count: usize,
parallel_index: usize,
) -> Vec<SIMDField<C>> {
let mut lc_input = vec![SIMDField::<C>::zero(); layered_circuit.input_size()];
for ((input, value), ib) in partition_info.iter().zip(values.iter()).zip(is_broadcast) {
if *ib {
for (i, x) in value.iter().enumerate() {
lc_input[input.offset + i] = *x;
}
} else {
for (i, x) in value
.iter()
.skip(parallel_index * input.len)
.take(input.len)
.enumerate()
{
lc_input[input.offset + i] = *x;
}
let parallel_index = parallel_index % (parallel_count / ib);
for (i, x) in value
.iter()
.skip(parallel_index * input.len)
.take(input.len)
.enumerate()
{
lc_input[input.offset + i] = *x;
}
}
lc_input
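Both functions above now rely on the same size relation: for broadcast factor ib, the committed buffer holds parallel_count / ib chunks of the kernel's input length, and parallel instance p reads chunk p % (parallel_count / ib). A small illustrative helper (not part of the crate) that captures that invariant:

```rust
// Illustrative only: selects the chunk that parallel instance `parallel_index`
// would read under broadcast factor `ib`, mirroring check_inputs/prepare_inputs.
fn chunk_for_instance<T: Copy>(
    value: &[T],
    kernel_input_len: usize,
    parallel_count: usize,
    ib: usize,
    parallel_index: usize,
) -> Vec<T> {
    let num_chunks = parallel_count / ib;
    assert_eq!(value.len(), kernel_input_len * num_chunks, "Input size mismatch");
    let p = parallel_index % num_chunks;
    value[p * kernel_input_len..(p + 1) * kernel_input_len].to_vec()
}
```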
6 changes: 4 additions & 2 deletions expander_compiler/src/zkcuda/proving_system/dummy.rs
@@ -74,7 +74,7 @@ impl<C: Config> KernelWiseProvingSystem<C> for DummyProvingSystem<C> {
_commitments_state: &[&Self::CommitmentState],
commitments_values: &[&[SIMDField<C>]],
parallel_count: usize,
is_broadcast: &[bool],
is_broadcast: &[usize],
) -> DummyProof {
check_inputs(kernel, commitments_values, parallel_count, is_broadcast);
let mut res = vec![];
@@ -84,6 +84,7 @@ impl<C: Config> KernelWiseProvingSystem<C> for DummyProvingSystem<C> {
kernel.layered_circuit_input(),
commitments_values,
is_broadcast,
parallel_count,
i,
);
let (_, cond) = kernel
@@ -104,7 +105,7 @@ impl<C: Config> KernelWiseProvingSystem<C> for DummyProvingSystem<C> {
proof: &Self::Proof,
commitments: &[&Self::Commitment],
parallel_count: usize,
is_broadcast: &[bool],
is_broadcast: &[usize],
) -> bool {
let values = commitments.iter().map(|c| &c.vals[..]).collect::<Vec<_>>();
check_inputs(kernel, &values, parallel_count, is_broadcast);
@@ -114,6 +115,7 @@ impl<C: Config> KernelWiseProvingSystem<C> for DummyProvingSystem<C> {
kernel.layered_circuit_input(),
&values,
is_broadcast,
parallel_count,
i,
);
let (_, cond) = kernel