Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions circuit-std-rs/tests/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,14 @@ fn rangeproof_zkcuda_test() {
let kernel: KernelPrimitive<M31Config> = compile_rangeproof_test_kernel().unwrap();
let mut ctx: Context<M31Config, _> = Context::new(hint_registry);

let a = M31::from(1 << 9);
let a = ctx.copy_to_device(&a);
let a_value = M31::from(1 << 9);
let (a, a_id) = ctx.new_device_memory(vec![]);
let a = a.reshape(&[1]);
call_kernel!(ctx, kernel, 1, a).unwrap();

type P = Expander<M31Config>;
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value, a_id);
ctx.solve_witness().unwrap();
let (prover_setup, verifier_setup) = <P as ProvingSystem<M31Config>>::setup(&computation_graph);
let proof = P::prove(
Expand All @@ -180,13 +181,14 @@ fn rangeproof_zkcuda_test_fail() {
let kernel: KernelPrimitive<M31Config> = compile_rangeproof_test_kernel().unwrap();
let mut ctx: Context<M31Config, _> = Context::new(hint_registry);

let a = M31::from(1 << 11);
let a = ctx.copy_to_device(&a);
let a_value = M31::from(1 << 11);
let (a, a_id) = ctx.new_device_memory(vec![]);
let a = a.reshape(&[1]);
call_kernel!(ctx, kernel, 1, a).unwrap();

type P = Expander<M31Config>;
let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value, a_id);
ctx.solve_witness().unwrap();
let (prover_setup, verifier_setup) = <P as ProvingSystem<M31Config>>::setup(&computation_graph);
let proof = P::prove(
Expand Down
6 changes: 4 additions & 2 deletions expander_compiler/bin/zkcuda_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ pub fn zkcuda_matmul<C: Config, P: ProvingSystem<C>, const N: usize>() {
}
}

let a = ctx.copy_to_device(&mat_a);
let b = ctx.copy_to_device(&mat_b);
let (a, a_id) = ctx.new_device_memory(vec![N, M]);
let (b, b_id) = ctx.new_device_memory(vec![M, K]);
let mut c = None;
call_kernel!(ctx, kernel_mul_line, N, a, b, mut c).unwrap();

Expand All @@ -72,6 +72,8 @@ pub fn zkcuda_matmul<C: Config, P: ProvingSystem<C>, const N: usize>() {
assert_eq!(result, expected_result);

let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&mat_a, a_id);
ctx.copy_to_device(&mat_b, b_id);
ctx.solve_witness().unwrap();

let (prover_setup, verifier_setup) = P::setup(&computation_graph);
Expand Down
203 changes: 123 additions & 80 deletions expander_compiler/src/zkcuda/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ use super::{
pub use macros::call_kernel;

struct DeviceMemory<C: Config> {
values: Vec<SIMDField<C>>,
pub values: Vec<SIMDField<C>>,
required_shape_products: Vec<usize>,
}

#[derive(Clone, Debug, ExpSerde)]
pub struct DeviceMemoryHandleRaw {
id: usize,
pub id: usize,
shape_history: ShapeHistory,
}

Expand Down Expand Up @@ -217,29 +217,50 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}

/// Registers a new, still-empty device memory slot for data of the given `shape`.
///
/// No values are uploaded here: the slot's `values` start out as an empty vector.
/// Returns the handle for the slot together with the slot's index in
/// `self.device_memories`.
pub fn new_device_memory(&mut self, shape: Shape) -> (DeviceMemoryHandle, usize) {
    let shape_len = shape_vec_len(&shape);
    let required_shape_products = match shape_len {
        1 => vec![1],
        n => vec![1, n],
    };
    // The slot is appended right below, so its index is the current length.
    let id = self.device_memories.len();
    self.device_memories.push(DeviceMemory {
        values: Vec::new(),
        required_shape_products,
    });
    let handle = DeviceMemoryHandleRaw {
        id,
        shape_history: ShapeHistory::new(shape),
    };
    (Some(handle), id)
}

pub fn copy_to_device<T: VecShaped<CircuitField<C>>>(
&mut self,
host_memory: &T,
) -> DeviceMemoryHandle {
device_memory_id: usize,
) {
assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
let (flat, shape) = flatten_shaped(host_memory);
// assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assertion is commented out. If it's no longer necessary, it should be removed to improve code clarity. If it is still needed, it should be fixed and re-enabled.

The current logic `shape_vec_len(&self.device_memories[device_memory_id].required_shape_products)` seems incorrect, as `required_shape_products` is already a vector of products, not a shape. A correct check might involve comparing against the last element of `required_shape_products`, which represents the total size.

let simd_flat = pack_vec::<C>(&flat);
make_device_mem(&mut self.device_memories, simd_flat, shape)
self.device_memories[device_memory_id].values = simd_flat;
}

pub fn copy_to_device_and_pack_simd<T: VecShaped<CircuitField<C>>>(
&mut self,
host_memory: &T,
) -> DeviceMemoryHandle {
device_memory_id: usize,
) {
assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
let (flat, shape) = flatten_shaped_pack_simd(host_memory);
make_device_mem(&mut self.device_memories, flat, shape)
self.device_memories[device_memory_id].values = flat;
}

pub fn copy_simd_to_device<T: VecShaped<SIMDField<C>>>(
&mut self,
host_memory: &T,
) -> DeviceMemoryHandle {
device_memory_id: usize,
) {
assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist.");
let (flat, shape) = flatten_shaped(host_memory);
make_device_mem(&mut self.device_memories, flat, shape)
// assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match.");
self.device_memories[device_memory_id].values = flat;
}

pub fn copy_to_host<T: VecShaped<CircuitField<C>> + Default>(
Expand Down Expand Up @@ -367,72 +388,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {

let kernel_id = self.kernel_primitives.add(kernel);

let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()];
let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in ios
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
{
if input.is_none() {
continue;
}
let handle = ensure_handle(input.clone());
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
for parallel_i in 0..num_parallel {
let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
for (i, ((input, input_start), input_end)) in ios
.iter()
.zip(kernel.ir_input_offsets().iter())
.zip(kernel.ir_input_offsets().iter().skip(1))
.enumerate()
{
if input.is_none() {
continue;
}
self.ir_copy_from_device_memory(
&ir_inputs_all[i],
&mut ir_inputs[*input_start..*input_end],
is_broadcast[i],
parallel_i,
chunk_sizes[i],
);
}
ir_inputs_per_parallel.push(ir_inputs);
}
let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
.into_par_iter()
.map(|ir_inputs| {
kernel
.ir_for_calling()
.eval_safe_simd(ir_inputs, &[], &self.hint_caller)
})
.collect();
for ir_outputs in ir_outputs_per_parallel {
let ir_outputs = ir_outputs?;
for (((spec, output_start), output_end), out) in kernel
.io_specs()
.iter()
.zip(kernel.ir_output_offsets().iter())
.zip(kernel.ir_output_offsets().iter().skip(1))
.zip(outputs_tmp.iter_mut())
{
if !spec.is_output {
continue;
}
out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
}
}
let outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()];
let input_handles = ios.to_vec();
let mut output_handles = vec![None; kernel.io_specs().len()];

Expand All @@ -447,12 +403,7 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
*output = None;
continue;
}
let handle = make_device_mem(
&mut self.device_memories,
ov,
shape_prepend(shape, num_parallel),
);
let id = handle.as_ref().unwrap().id;
let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel));
self.device_memories[id].required_shape_products = merge_shape_products(
&handle
.as_ref()
Expand Down Expand Up @@ -720,6 +671,96 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
Ok(())
}

pub fn solve_result(&mut self) -> Result<(), Error> {
for kernel_call in self.kernel_calls.iter() {
let kernel = self.kernel_primitives.get(kernel_call.kernel_id);
let num_parallel = kernel_call.num_parallel;
let is_broadcast = &kernel_call.is_broadcast;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This empty line can be removed to improve code conciseness.

let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()];
let mut chunk_sizes: Vec<Option<usize>> = vec![None; kernel.io_specs().len()];
for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles
.iter()
.zip(is_broadcast.iter())
.zip(ir_inputs_all.iter_mut())
.zip(chunk_sizes.iter_mut())
{
if input.is_none() {
continue;
}
let handle = ensure_handle(input.clone());
let values = handle
.shape_history
.permute_vec(&self.device_memories[handle.id].values);
if !ib {
*chunk_size = Some(values.len() / num_parallel);
}
*ir_inputs = values;
}
let mut ir_inputs_per_parallel = Vec::new();
for parallel_i in 0..num_parallel {
let mut ir_inputs = vec![SIMDField::<C>::zero(); kernel.ir_for_calling().input_size()];
for (i, ((input, input_start), input_end)) in kernel_call.input_handles
.iter()
.zip(kernel.ir_input_offsets().iter())
.zip(kernel.ir_input_offsets().iter().skip(1))
.enumerate()
{
if input.is_none() {
continue;
}
self.ir_copy_from_device_memory(
&ir_inputs_all[i],
&mut ir_inputs[*input_start..*input_end],
is_broadcast[i],
parallel_i,
chunk_sizes[i],
);
}
ir_inputs_per_parallel.push(ir_inputs);
}
let ir_outputs_per_parallel: Vec<Result<Vec<SIMDField<C>>, Error>> = ir_inputs_per_parallel
.into_par_iter()
.map(|ir_inputs| {
kernel
.ir_for_calling()
.eval_safe_simd(ir_inputs, &[], &self.hint_caller)
})
.collect();

let mut outputs_tmp: Vec<Vec<SIMDField::<C>>> = vec![Vec::new(); kernel.io_specs().len()];
for ir_outputs in ir_outputs_per_parallel {
let ir_outputs = ir_outputs?;
for (((spec, output_start), output_end), out) in kernel
.io_specs()
.iter()
.zip(kernel.ir_output_offsets().iter())
.zip(kernel.ir_output_offsets().iter().skip(1))
.zip(outputs_tmp.iter_mut())
{
if !spec.is_output {
continue;
}
out.extend_from_slice(&ir_outputs[*output_start..*output_end]);
}
}

for ((output, spec), ov) in kernel_call.output_handles
.iter()
.zip(kernel.io_specs().iter())
.zip(outputs_tmp.into_iter())
{
if !spec.is_output {
continue;
}
let output_id = output.as_ref().unwrap().id;
self.device_memories[output_id].values = ov;
}
}

Ok(())
}

// actually, this function computes hints
pub fn solve_witness(&mut self) -> Result<(), Error> {
match self.state {
Expand All @@ -732,6 +773,8 @@ impl<C: Config, H: HintCaller<CircuitField<C>>> Context<C, H> {
}
}
self.state = ContextState::WitnessDone;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This empty line can be removed.

self.solve_result();

for (kernel_call, proof_template) in
self.kernel_calls.iter().zip(self.proof_templates.iter())
Expand Down
13 changes: 8 additions & 5 deletions expander_compiler/src/zkcuda/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,16 @@ fn context_shape_test_1_impl<P: ProvingSystem<M31Config>>() {

// Part 1
// Since we only use the shape [15, 1], the representation of the vector is "xxxxxxxxxxxxxxx.".
let mut a = ctx.copy_to_device(&vec![one; 15]);
let a_value_1 = vec![one; 15];
let (mut a, a_id_1) = ctx.new_device_memory(vec![15]);
call_kernel!(ctx, identity_1, 15, mut a).unwrap();
assert_eq!(ctx.copy_to_host::<Vec<F>>(a), vec![one; 15]);

// Part 2
// Since we use [15, 1] and [3, 5], the context will find a representation that is compatible with both.
// The representation of the vector is "xxxxx...xxxxx...xxxxx...........".
let mut a = ctx.copy_to_device(&vec![one; 15]);
let a_value_2 = vec![one; 15];
let (mut a, a_id_2) = ctx.new_device_memory(vec![15]);
let mut b = a.reshape(&[5, 3]);
call_kernel!(ctx, identity_1, 15, mut a).unwrap();
call_kernel!(ctx, identity_3, 5, mut b).unwrap();
Expand All @@ -84,6 +86,8 @@ fn context_shape_test_1_impl<P: ProvingSystem<M31Config>>() {
assert_eq!(ctx.copy_to_host::<Vec<F>>(b), vec![one; 15]);

let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&a_value_1, a_id_1);
ctx.copy_to_device(&a_value_2, a_id_2);
ctx.solve_witness().unwrap();

// Debugging output and assertions
Expand Down Expand Up @@ -143,12 +147,11 @@ fn context_shape_test_1() {
fn context_shape_test_2() {
type C = M31Config;
type F = CircuitField<C>;
let one = F::one();
let identity_3 = compile_identity_3::<C>().unwrap();
let identity_5 = compile_identity_5::<C>().unwrap();

let mut ctx: Context<C> = Context::default();
let a = ctx.copy_to_device(&vec![one; 15]);
let (a, _) = ctx.new_device_memory(vec![15]);
let mut b = a.reshape(&[5, 3]);
let mut a = a.reshape(&[3, 5]);
call_kernel!(ctx, identity_5, 3, mut a).unwrap();
Expand All @@ -164,7 +167,7 @@ fn context_shape_test_2_success() {
let identity_5 = compile_identity_5::<C>().unwrap();

let mut ctx: Context<C> = Context::default();
let a = ctx.copy_to_device(&vec![one; 15]);
let (a, _) = ctx.new_device_memory(vec![15]);
let b = a.reshape(&[5, 3]);
let mut a = a.reshape(&[3, 5]);
call_kernel!(ctx, identity_5, 3, mut a).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion expander_compiler/tests/cg_mpi_share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ fn get_computation_graph() -> ComputationGraph<M31Config> {
}

println!("prepare data ok");
let p = ctx.copy_to_device(&p);
let p_value = p;
let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]);
println!("copy to device ok");
let mut out = None;
call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap();
Expand All @@ -338,6 +339,7 @@ fn get_computation_graph() -> ComputationGraph<M31Config> {
assert_eq!(out[0][0], expected_res[0][0]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The refactoring to deferred execution has broken the logic of this test. The `assert_eq!` calls will now fail because `out` is read before `solve_witness` is called, so the output device memory has not been populated yet.

These assertions should be moved to after the `ctx.solve_witness().unwrap()` call on line 342.


let computation_graph = ctx.compile_computation_graph().unwrap();
ctx.copy_to_device(&p_value, p_id);

computation_graph
}
Expand Down
Loading
Loading