diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 0b983f12..e3460bd8 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -7,8 +7,14 @@ use circuit_std_rs::{ use expander_compiler::{ field::{BN254Fr, Goldilocks}, frontend::*, - zkcuda::{context::*, kernel::*, proving_system::*, shape::Reshape}, + zkcuda::{ + context::*, + kernel::*, + proving_system::{expander::config::ZKCudaBN254Hyrax, *}, + shape::Reshape, + }, }; +use serdes::ExpSerde; #[test] fn logup_test() { @@ -196,3 +202,31 @@ fn rangeproof_zkcuda_test_fail() { ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } + +#[test] +fn rangeproof_zkcuda_no_oversubscribe_test() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); + //compile and test + let kernel: KernelPrimitive = compile_rangeproof_test_kernel().unwrap(); + let mut ctx: Context = Context::new(hint_registry); + + let a = BN254Fr::from((1 << 9) as u32); + let a = ctx.copy_to_device(&a); + let a = a.reshape(&[1]); + call_kernel!(ctx, kernel, 1, a).unwrap(); + + let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.solve_witness().unwrap(); + let (prover_setup, _) = ExpanderNoOverSubscribe::::setup(&computation_graph); + let proof = ExpanderNoOverSubscribe::::prove( + &prover_setup, + &computation_graph, + ctx.export_device_memories(), + ); + let file = std::fs::File::create("proof.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + proof.serialize_into(writer).expect("serialize failed"); + as ProvingSystem>::post_process(); +} diff --git a/expander_compiler/src/zkcuda/mpi_mem_share.rs b/expander_compiler/src/zkcuda/mpi_mem_share.rs index 5c79c7db..12c3d6a0 100644 --- a/expander_compiler/src/zkcuda/mpi_mem_share.rs +++ b/expander_compiler/src/zkcuda/mpi_mem_share.rs @@ -65,18 +65,16 @@ impl MPISharedMemory for ComputationGraph { impl MPISharedMemory for Kernel { fn bytes_size(&self) -> usize { - assert!( - self.hint_solver.is_none(), - "Hint solver is not supported in MPISharedMemory for Kernel" - ); + if self.hint_solver.is_some() { + eprintln!("Warning: Shared Memory will ignore the hint solver in Kernel"); + } self.layered_circuit.bytes_size() + self.layered_circuit_input.bytes_size() } fn to_memory(&self, ptr: &mut *mut u8) { - assert!( - self.hint_solver.is_none(), - "Hint solver is not supported in MPISharedMemory for Kernel" - ); + if self.hint_solver.is_some() { + eprintln!("Warning: Shared Memory will ignore the hint solver in Kernel"); + } self.layered_circuit.to_memory(ptr); self.layered_circuit_input.to_memory(ptr); }