From ae040c1277232ebe4b3e14bb5187d2ae6711b0bc Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:06:35 -0500 Subject: [PATCH 01/60] add some timer --- .../proving_system/expander_pcs_defered/prove_impl.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index 03195d4c..d3e9e194 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -1,4 +1,5 @@ use arith::Field; +use expander_utils::timer::Timer; use gkr_engine::{ ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Proof as BytesProof, Transcript, @@ -115,6 +116,7 @@ where ECCConfig: Config, C::FieldConfig: FieldEngine, { + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, _states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() @@ -124,10 +126,15 @@ where } else { (None, None) }; + commit_timer.stop(); let mut vals_ref = vec![]; let mut challenges = vec![]; + let prove_timer = Timer::new( + "Prove all kernels (NO PCS Opening)", + global_mpi_config.is_root(), + ); let proofs = computation_graph .proof_templates() .iter() @@ -169,12 +176,16 @@ where } }) .collect::>(); + prove_timer.stop(); if global_mpi_config.is_root() { let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); let pcs_batch_opening = open_defered_pcs::(prover_setup, &vals_ref, &challenges); + pcs_opening_timer.stop(); + proofs.push(pcs_batch_opening); Some(CombinedProof { commitments: commitments.unwrap(), From aa92f5ee492391b9e9d55cae6a0b66297b846da8 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:20:04 -0500 Subject: [PATCH 02/60] some more tests --- expander_compiler/bin/zkcuda_matmul_pcs_defered.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs index fa8e17be..dea7e024 100644 --- a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs +++ b/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs @@ -1,6 +1,9 @@ #![allow(unused)] mod zkcuda_matmul; -use expander_compiler::{frontend::BN254Config, zkcuda::proving_system::ExpanderPCSDefered}; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderPCSDefered}, +}; use gkr::BN254ConfigSha2Hyrax; use zkcuda_matmul::zkcuda_matmul; @@ -8,4 +11,8 @@ fn main() { zkcuda_matmul::, 4>(); zkcuda_matmul::, 8>(); zkcuda_matmul::, 16>(); + + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); } From ed47ac32f20587a3b16be89ed0e33ef7ad548870 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:36:01 -0500 Subject: [PATCH 03/60] timer in verifier --- .../proving_system/expander_parallelized/api_parallel.rs | 3 +++ .../proving_system/expander_pcs_defered/verify_impl.rs | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index cfd77e38..50b7b6eb 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -14,6 +14,7 @@ use crate::zkcuda::proving_system::{CombinedProof, ProvingSystem}; use super::super::Expander; +use expander_utils::timer::Timer; use gkr_engine::{FieldEngine, GKREngine}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; @@ -51,6 +52,7 @@ where computation_graph: &ComputationGraph, proof: &Self::Proof, ) -> bool { + let verification_timer = Timer::new("Verify all kernels", true); let verified = proof .proofs .par_iter() @@ -72,6 +74,7 @@ where ) }) .collect::>(); + verification_timer.stop(); verified.iter().all(|x| *x) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 86deb072..9a429033 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -1,6 +1,7 @@ use std::io::Cursor; use arith::Field; +use expander_utils::timer::Timer; use gkr::gkr_verify; use gkr_engine::{ ExpanderDualVarChallenge, ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, @@ -155,8 +156,10 @@ where >::Commitment: AsRef<>::Commitment>, { + let verification_timer = Timer::new("Total Verification", true); let pcs_batch_opening = proof.proofs.pop().unwrap(); + let gkr_verification_timer = Timer::new("GKR Verification", true); let verified_with_pcs_claims = proof .proofs .par_iter() @@ -193,7 +196,9 @@ where println!("Failed to verify GKR proofs"); return false; } + gkr_verification_timer.stop(); + let pcs_verification_timer = Timer::new("PCS Verification", true); let commitments_ref = verified_with_pcs_claims .iter() .flat_map(|(_, c, _)| c) @@ -211,6 +216,8 @@ where &commitments_ref, &challenges, ); + pcs_verification_timer.stop(); + verification_timer.stop(); gkr_verified && pcs_verified } From fc2e657384538a5eafe8262d8b74d4794857ab7c Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:45:36 -0500 Subject: [PATCH 04/60] switch to uni-kzg in plain parallelized expander --- .../proving_system/expander_parallelized/server_bin.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index 53e76bbe..20fc4f8a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -3,12 +3,12 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::{ frontend::{BN254Config, BabyBearConfig, GF2Config, GoldilocksConfig, M31Config}, - zkcuda::proving_system::expander_parallelized::{ + zkcuda::proving_system::{expander_parallelized::{ server_ctrl::{serve, ExpanderExecArgs}, ParallelizedExpander, - }, + }, expander_pcs_defered::BN254ConfigSha2UniKZG}, }; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -55,7 +55,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; From 46e8f54ba384877a6c0af9e7cce19d81c9e9f408 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:45:57 -0500 Subject: [PATCH 05/60] minor --- .../expander_parallelized/server_bin.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index 20fc4f8a..6250428d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -3,10 +3,13 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::{ frontend::{BN254Config, BabyBearConfig, GF2Config, GoldilocksConfig, M31Config}, - zkcuda::proving_system::{expander_parallelized::{ - server_ctrl::{serve, ExpanderExecArgs}, - ParallelizedExpander, - }, expander_pcs_defered::BN254ConfigSha2UniKZG}, + zkcuda::proving_system::{ + expander_parallelized::{ + server_ctrl::{serve, ExpanderExecArgs}, + ParallelizedExpander, + }, + expander_pcs_defered::BN254ConfigSha2UniKZG, + }, }; use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; From 7ef3472b549e6b137836f396b46c12feab58ebd4 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:50:46 -0500 Subject: [PATCH 06/60] switch to uni-kzg by default --- expander_compiler/bin/zkcuda_matmul.rs | 13 +++++++------ expander_compiler/bin/zkcuda_matmul_pcs_defered.rs | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index 75bd7b4a..4eea1357 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -4,6 +4,7 @@ use expander_compiler::frontend::{ BN254Config, BasicAPI, CircuitField, Config, Error, FieldArith, Variable, API, }; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{ @@ -93,10 +94,10 @@ pub fn zkcuda_matmul, const N: usize>() { } fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + // zkcuda_matmul::, 4>(); + // zkcuda_matmul::, 8>(); + // zkcuda_matmul::, 16>(); + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); } diff --git a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs index dea7e024..0970798a 100644 --- a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs +++ b/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs @@ -8,9 +8,9 @@ use gkr::BN254ConfigSha2Hyrax; use zkcuda_matmul::zkcuda_matmul; fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + // zkcuda_matmul::, 4>(); + // zkcuda_matmul::, 8>(); + // zkcuda_matmul::, 16>(); zkcuda_matmul::, 4>(); zkcuda_matmul::, 8>(); From 710c6cf6b7fb9748aa86bfea2d2b27ec0cfc6bf0 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 20:51:12 -0500 Subject: [PATCH 07/60] clippy auto fix --- expander_compiler/bin/zkcuda_matmul.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index 4eea1357..1d2c80f5 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -5,13 +5,12 @@ use expander_compiler::frontend::{ BN254Config, BasicAPI, CircuitField, Config, Error, FieldArith, Variable, API, }; use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; -use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; +use expander_compiler::zkcuda::proving_system::{ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{ context::{call_kernel, Context}, kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, }; -use gkr::BN254ConfigSha2Hyrax; const M: usize = 512; const K: usize = 512; From 5c23c6925890f0c95f85726c4734ef62420db1f8 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Wed, 25 Jun 2025 21:24:36 -0500 Subject: [PATCH 08/60] fix test --- expander_compiler/tests/zkcuda_examples.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index c70ef94e..29cbeff3 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -1,4 +1,5 @@ use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; @@ -85,7 +86,7 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { From b56a743ff1998f6df2de93b2ed6163883afd17cd Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 26 Jun 2025 18:11:09 -0500 Subject: [PATCH 09/60] switch to original kzg in non-pcs-batching --- .../proving_system/expander_parallelized/server_bin.rs | 5 ++--- expander_compiler/tests/zkcuda_examples.rs | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index 6250428d..d61907a5 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -8,10 +8,9 @@ use expander_compiler::{ server_ctrl::{serve, ExpanderExecArgs}, ParallelizedExpander, }, - expander_pcs_defered::BN254ConfigSha2UniKZG, }, }; -use gkr::BN254ConfigSha2Hyrax; +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -58,7 +57,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index 29cbeff3..c70ef94e 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -1,5 +1,4 @@ use expander_compiler::frontend::*; -use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; @@ -86,7 +85,7 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { From 0bfc04912cbdba5061e70c7964c713d5406c7bdc Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 6 Jul 2025 21:10:04 -0500 Subject: [PATCH 10/60] wip --- Cargo.lock | 77 +++--- Cargo.toml | 38 +-- expander_compiler/Cargo.toml | 4 + expander_compiler/src/utils/misc.rs | 11 + .../src/zkcuda/proving_system.rs | 3 + .../proving_system/expander/prove_impl.rs | 16 +- .../expander_no_oversubscribe.rs | 3 + .../api_no_oversubscribe.rs | 58 +++++ .../expander_no_oversubscribe/cmd_utils.rs | 0 .../expander_no_oversubscribe/prove_impl.rs | 224 ++++++++++++++++++ .../expander_no_oversubscribe/server_bin.rs | 40 ++++ .../expander_no_oversubscribe/server_fn.rs | 82 +++++++ .../expander_parallelized/api_parallel.rs | 2 +- .../expander_parallelized/client_utils.rs | 17 +- .../expander_pcs_defered/api_pcs_defered.rs | 2 +- 15 files changed, 513 insertions(+), 64 deletions(-) create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs diff --git a/Cargo.lock b/Cargo.lock index 04297c86..b6b5278f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "babybear", @@ -499,9 +499,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.18.1" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "byteorder" @@ -523,9 +523,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" dependencies = [ "shlex", ] @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "env_logger", @@ -835,9 +835,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -1301,7 +1301,7 @@ dependencies = [ [[package]] name = "halo2curves" version = "0.6.1" -source = "git+https://github.com/PolyhedraZK/halo2curves#33ebdf4a2845e96055438cc66305a34699fa3a19" +source = "git+https://github.com/PolyhedraZK/halo2curves#abb020f388b519c1f00033e267faa0709b1249e2" dependencies = [ "blake2", "digest", @@ -1655,14 +1655,25 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown 0.15.4", ] +[[package]] +name = "io-uring" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1870,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -2248,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -2273,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -2694,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "ethnum", "halo2curves", @@ -2705,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "proc-macro2", "quote", @@ -2829,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "circuit", @@ -2980,17 +2991,19 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.1" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", + "slab", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -3093,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "gkr_engine", @@ -3116,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "arith", "ark-std", @@ -3218,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#f4f91e2959e22898d6a9ea49876d5329e70b31ad" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" dependencies = [ "colored", ] diff --git a/Cargo.toml b/Cargo.toml index 98820fbf..d391828a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,22 +47,22 @@ stacker = "0.1.17" tiny-keccak = { version = "2.0", features = ["keccak"] } tokio = { version = "1", features = ["full"] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "circuit" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "transcript" } -expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "bin" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "poly_commit" } -polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "serdes" } -gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "circuit" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "transcript" } +expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "bin" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "poly_commit" } +polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "serdes" } +gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } +expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "utils" } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 36906b39..7b54210a 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -61,6 +61,10 @@ path = "src/zkcuda/proving_system/expander_parallelized/server_bin.rs" name = "expander_server_pcs_defered" path = "src/zkcuda/proving_system/expander_pcs_defered/server_bin.rs" +[[bin]] +name = "expander_server_no_oversubscribe" +path = "src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs" + [[bin]] name = "zkcuda_matmul" path = "bin/zkcuda_matmul.rs" diff --git a/expander_compiler/src/utils/misc.rs b/expander_compiler/src/utils/misc.rs index af06365c..58cb1031 100644 --- a/expander_compiler/src/utils/misc.rs +++ b/expander_compiler/src/utils/misc.rs @@ -1,5 +1,16 @@ use std::collections::{HashMap, HashSet}; +pub fn prev_power_of_two(x: usize) -> usize { + if x == 0 { + return 0; + } + let mut padk: usize = 0; + while (1 << padk) <= x { + padk += 1; + } + 1 << (padk - 1) +} + pub fn next_power_of_two(x: usize) -> usize { let mut padk: usize = 0; while (1 << padk) < x { diff --git a/expander_compiler/src/zkcuda/proving_system.rs b/expander_compiler/src/zkcuda/proving_system.rs index a5f3fce5..20e627f0 100644 --- a/expander_compiler/src/zkcuda/proving_system.rs +++ b/expander_compiler/src/zkcuda/proving_system.rs @@ -18,3 +18,6 @@ pub use expander_parallelized::api_parallel::*; pub mod expander_pcs_defered; pub use expander_pcs_defered::api_pcs_defered::*; + +pub mod expander_no_oversubscribe; +pub use expander_no_oversubscribe::api_no_oversubscribe::*; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 3530abae..4743505a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -84,16 +84,14 @@ pub fn prepare_inputs_with_local_vals( input_vals } -pub fn prove_gkr_with_local_vals( - expander_circuit: &mut Circuit, - prover_scratch: &mut ProverScratchPad, - local_commitment_values: &[impl AsRef<[::SimdCircuitField]>], +pub fn prove_gkr_with_local_vals( + expander_circuit: &mut Circuit, + prover_scratch: &mut ProverScratchPad, + local_commitment_values: &[impl AsRef<[F::SimdCircuitField]>], partition_info: &[LayeredCircuitInputVec], - transcript: &mut C::TranscriptConfig, + transcript: &mut T, mpi_config: &MPIConfig, -) -> ExpanderDualVarChallenge -where - C::FieldConfig: FieldEngine, +) -> ExpanderDualVarChallenge { expander_circuit.layers[0].input_vals = prepare_inputs_with_local_vals( 1 << expander_circuit.log_input_size(), @@ -106,7 +104,7 @@ where gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); assert_eq!( claimed_v, - ::ChallengeField::from(0) + F::ChallengeField::from(0) ); challenge } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs new file mode 100644 index 00000000..54c80bbd --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs @@ -0,0 +1,3 @@ +pub mod api_no_oversubscribe; +pub mod server_fn; +pub mod prove_impl; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs new file mode 100644 index 00000000..516df91d --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -0,0 +1,58 @@ +use crate::circuit::config::Config; +use crate::frontend::SIMDField; +use crate::zkcuda::context::ComputationGraph; +use crate::zkcuda::proving_system::expander::structs::{ + ExpanderProverSetup, ExpanderVerifierSetup, +}; +use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ + client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, + ClientHttpHelper, +}; +use crate::zkcuda::proving_system::{CombinedProof, ParallelizedExpander, ProvingSystem}; + +use super::super::Expander; + +use gkr_engine::{FieldEngine, GKREngine}; + +pub struct ExpanderNoOverSubscribe { + _config: std::marker::PhantomData, +} + +impl> ProvingSystem + for ExpanderNoOverSubscribe +where + C::FieldConfig: FieldEngine, +{ + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; + type Proof = CombinedProof>; + + fn setup( + computation_graph: &crate::zkcuda::context::ComputationGraph, + ) -> (Self::ProverSetup, Self::VerifierSetup) { + let server_binary = + client_parse_args().unwrap_or("../target/release/expander_server_no_oversubscribe".to_owned()); + client_launch_server_and_setup::(&server_binary, computation_graph, false) + } + + fn prove( + _prover_setup: &Self::ProverSetup, + _computation_graph: &crate::zkcuda::context::ComputationGraph, + device_memories: &[Vec>], + ) -> Self::Proof { + client_send_witness_and_prove(device_memories) + } + + fn verify( + verifier_setup: &Self::VerifierSetup, + computation_graph: &ComputationGraph, + proof: &Self::Proof, + ) -> bool { + // The proof should be the same as the one returned by ParallelizedExpander::prove + ParallelizedExpander::verify(verifier_setup, computation_graph, proof) + } + + fn post_process() { + wait_async(ClientHttpHelper::request_exit()) + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs new file mode 100644 index 00000000..e69de29b diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs new file mode 100644 index 00000000..83c68583 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -0,0 +1,224 @@ +use arith::Field; +use expander_utils::timer::Timer; +use gkr_engine::{ + ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, + MPIEngine, Transcript, +}; + +use crate::{ + frontend::{Config, SIMDField}, + utils::misc::next_power_of_two, + zkcuda::{ + context::ComputationGraph, + kernel::Kernel, + proving_system::{ + expander::{ + commit_impl::local_commit_impl, + prove_impl::{ + get_local_vals, pcs_local_open_impl, prepare_expander_circuit, + prove_gkr_with_local_vals, + }, + structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, + }, + expander_parallelized::server_ctrl::generate_local_mpi_config, + CombinedProof, Expander, + }, + }, +}; + +pub fn mpi_prove_impl( + global_mpi_config: &MPIConfig<'static>, + prover_setup: &ExpanderProverSetup, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], +) -> Option>> +where + C: GKREngine, + ECCConfig: Config, + C::FieldConfig: FieldEngine, +{ + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); + let (commitments, states) = if global_mpi_config.is_root() { + let (commitments, states) = values + .iter() + .map(|value| local_commit_impl::(prover_setup, value.as_ref())) + .unzip::<_, _, Vec<_>, Vec<_>>(); + (Some(commitments), Some(states)) + } else { + (None, None) + }; + commit_timer.stop(); + + let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); + let proofs = computation_graph + .proof_templates() + .iter() + .map(|template| { + let commitment_values = template + .commitment_indices() + .iter() + .map(|&idx| values[idx].as_ref()) + .collect::>(); + + let single_kernel_gkr_timer = + Timer::new("small gkr kernel", global_mpi_config.is_root()); + let gkr_end_state = prove_kernel_gkr::( + global_mpi_config, + &computation_graph.kernels()[template.kernel_id()], + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + ); + single_kernel_gkr_timer.stop(); + + if global_mpi_config.is_root() { + let pcs_open_timer = Timer::new("pcs open", true); + let (mut transcript, challenge) = gkr_end_state.unwrap(); + let challenges = if let Some(challenge_y) = challenge.challenge_y() { + vec![challenge.challenge_x(), challenge_y] + } else { + vec![challenge.challenge_x()] + }; + + challenges.iter().for_each(|c| { + partition_single_gkr_claim_and_open_pcs_mpi::( + prover_setup, + &commitment_values, + &template + .commitment_indices() + .iter() + .map(|&idx| &states.as_ref().unwrap()[idx]) + .collect::>(), + c, + template.is_broadcast(), + &mut transcript, + ); + }); + + pcs_open_timer.stop(); + Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }) + } else { + None + } + }) + .collect::>(); + prove_timer.stop(); + + if global_mpi_config.is_root() { + let proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_kernel_gkr( + mpi_config: &MPIConfig<'static>, + kernel: &Kernel, + commitments_values: &[&[F::SimdCircuitField]], + parallel_count: usize, + is_broadcast: &[bool], +) -> Option<( + T, + ExpanderDualVarChallenge, +)> +where + F: FieldEngine, + T: Transcript, + ECCConfig: Config, +{ + let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); + + local_mpi_config.as_ref()?; + + let local_mpi_config = local_mpi_config.unwrap(); + let local_world_size = local_mpi_config.world_size(); + let local_world_rank = local_mpi_config.world_rank(); + + let local_commitment_values = get_local_vals( + commitments_values, + is_broadcast, + local_world_rank, + local_world_size, + ); + + let (mut expander_circuit, mut prover_scratch) = + prepare_expander_circuit::(kernel, local_world_size); + + let mut transcript = T::new(); + let challenge = prove_gkr_with_local_vals::( + &mut expander_circuit, + &mut prover_scratch, + &local_commitment_values, + kernel.layered_circuit_input(), + &mut transcript, + &local_mpi_config, + ); + + Some((transcript, challenge)) +} + +pub fn partition_challenge_and_location_for_pcs_mpi( + gkr_challenge: &ExpanderSingleVarChallenge, + total_vals_len: usize, + parallel_count: usize, + is_broadcast: bool, +) -> (ExpanderSingleVarChallenge, Vec) { + let mut challenge = gkr_challenge.clone(); + let zero = F::ChallengeField::ZERO; + if is_broadcast { + let n_vals_vars = total_vals_len.ilog2() as usize; + let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); + challenge.rz.resize(n_vals_vars, zero); + challenge.r_mpi.clear(); + (challenge, component_idx_vars) + } else { + let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize; + let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); + challenge.rz.resize(n_vals_vars, zero); + + challenge.rz.extend_from_slice(&challenge.r_mpi); + challenge.r_mpi.clear(); + (challenge, component_idx_vars) + } +} + +#[allow(clippy::too_many_arguments)] +fn partition_single_gkr_claim_and_open_pcs_mpi( + p_keys: &ExpanderProverSetup, + commitments_values: &[impl AsRef<[SIMDField]>], + commitments_state: &[&ExpanderCommitmentState], + gkr_challenge: &ExpanderSingleVarChallenge, + is_broadcast: &[bool], + transcript: &mut C::TranscriptConfig, +) where + C::FieldConfig: FieldEngine, +{ + let parallel_count = 1 << gkr_challenge.r_mpi.len(); + for ((commitment_val, _state), ib) in commitments_values + .iter() + .zip(commitments_state) + .zip(is_broadcast) + { + let val_len = commitment_val.as_ref().len(); + let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_mpi( + gkr_challenge, + val_len, + parallel_count, + *ib, + ); + + pcs_local_open_impl::( + commitment_val.as_ref(), + &challenge_for_pcs, + p_keys, + transcript, + ); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs new file mode 100644 index 00000000..6359e2b8 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -0,0 +1,40 @@ +use std::str::FromStr; + +use clap::Parser; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, + }, +}; +use gkr::BN254ConfigSha2Hyrax; +use gkr_engine::PolynomialCommitmentType; + +#[tokio::main] +pub async fn main() { + let expander_exec_args = ExpanderExecArgs::parse(); + assert_eq!( + expander_exec_args.fiat_shamir_hash, "SHA256", + "Only SHA256 is supported for now" + ); + + let pcs_type = PolynomialCommitmentType::from_str(&expander_exec_args.poly_commit).unwrap(); + + match (expander_exec_args.field_type.as_str(), pcs_type) { + ("BN254", PolynomialCommitmentType::Hyrax) => { + serve::>( + expander_exec_args.port_number, + ) + .await; + } + ("BN254", PolynomialCommitmentType::KZG) => { + serve::>( + expander_exec_args.port_number, + ) + .await; + } + (field_type, pcs_type) => { + panic!("Combination of {field_type:?} and {pcs_type:?} not supported for no oversubscribe expander proving system."); + } + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs new file mode 100644 index 00000000..a40e735c --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -0,0 +1,82 @@ +use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use serdes::ExpSerde; + +use crate::{ + frontend::{Config, SIMDField}, + zkcuda::{ + context::ComputationGraph, + proving_system::{ + expander::{ + structs::{ExpanderProverSetup, ExpanderVerifierSetup}, + }, expander_parallelized::{prove_impl::mpi_prove_impl, server_fns::ServerFns}, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander + }, + }, +}; + +impl ServerFns for ExpanderNoOverSubscribe +where + C: GKREngine, + ECCConfig: Config, + C::FieldConfig: FieldEngine, +{ + fn setup_request_handler( + global_mpi_config: &MPIConfig<'static>, + setup_file: Option, + computation_graph: &mut ComputationGraph, + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, + ) where + C::FieldConfig: FieldEngine, + { + ParallelizedExpander::::setup_request_handler( + global_mpi_config, + setup_file, + computation_graph, + prover_setup, + verifier_setup, + ); + } + + fn prove_request_handler( + global_mpi_config: &MPIConfig<'static>, + prover_setup: &ExpanderProverSetup, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], + ) -> Option>> + where + C: GKREngine, + ECCConfig: Config, + C::FieldConfig: FieldEngine, + { + mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) + } +} + +pub fn broadcast_string(global_mpi_config: &MPIConfig<'static>, string: Option) -> String { + // Broadcast the setup file path to all workers + if global_mpi_config.is_root() && string.is_none() { + panic!("String must be provided on the root process in broadcast_string"); + } + let mut string_length = string.as_ref().map_or(0, |s| s.len()); + global_mpi_config.root_broadcast_f(&mut string_length); + let mut bytes = string.map_or(vec![0u8; string_length], |s| s.into_bytes()); + global_mpi_config.root_broadcast_bytes(&mut bytes); + String::from_utf8(bytes).expect("Failed to convert broadcasted bytes to String") +} + +pub fn read_circuit( + _global_mpi_config: &MPIConfig<'static>, + setup_file: String, + computation_graph: &mut ComputationGraph, +) where + C: GKREngine, + ECCConfig: Config, + C::FieldConfig: FieldEngine, +{ + let computation_graph_bytes = + std::fs::read(setup_file).expect("Failed to read computation graph from file"); + *computation_graph = ComputationGraph::::deserialize_from(std::io::Cursor::new( + computation_graph_bytes, + )) + .expect("Failed to deserialize computation graph"); +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index 50b7b6eb..1f99669c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -36,7 +36,7 @@ where ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args().unwrap_or("../target/release/expander_server".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph) + client_launch_server_and_setup::(&server_binary, computation_graph, true) } fn prove( diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index e14aba1c..2ccf0e61 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -2,7 +2,7 @@ use std::fs; use crate::{ frontend::{Config, SIMDField}, - utils::misc::next_power_of_two, + utils::misc::{next_power_of_two, prev_power_of_two}, zkcuda::{ context::ComputationGraph, proving_system::{ @@ -77,6 +77,7 @@ pub fn client_parse_args() -> Option { pub fn client_launch_server_and_setup( server_binary: &str, computation_graph: &ComputationGraph, + allow_oversubscribe: bool, ) -> ( ExpanderProverSetup, ExpanderVerifierSetup, @@ -104,10 +105,22 @@ where .map(|t| t.parallel_count()) .max() .unwrap_or(1); + let max_parallel_count = next_power_of_two(max_parallel_count); + + let mpi_size = if allow_oversubscribe { + max_parallel_count + } else { + let num_cpus = prev_power_of_two(num_cpus::get_physical()); + if max_parallel_count > num_cpus { + num_cpus + } else { + max_parallel_count + } + }; let port = parse_port_number(); let server_url = format!("{SERVER_IP}:{port}"); - start_server::(server_binary, next_power_of_two(max_parallel_count), port); + start_server::(server_binary, mpi_size, port); // Keep trying until the server is ready loop { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index 5a8fa237..da0d30fc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -35,7 +35,7 @@ where ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args() .unwrap_or("../target/release/expander_server_pcs_defered".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph) + client_launch_server_and_setup::(&server_binary, computation_graph, true) } fn prove( From b3aea1a4b55c393344c82798db62df5c263ca483 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 6 Jul 2025 21:13:59 -0500 Subject: [PATCH 11/60] remove unnecessay config in circuit preprocess --- expander_compiler/src/circuit/layered/export.rs | 2 +- .../src/zkcuda/proving_system/expander/api_single_thread.rs | 2 +- .../src/zkcuda/proving_system/expander/prove_impl.rs | 2 +- .../zkcuda/proving_system/expander_parallelized/verify_impl.rs | 2 +- .../zkcuda/proving_system/expander_pcs_defered/verify_impl.rs | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index f59b7311..a7011513 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -73,7 +73,7 @@ impl Circuit { pub fn export_to_expander_flatten(&self) -> expander_circuit::Circuit { let circuit = self.export_to_expander::(); let mut flattened = circuit.flatten::(); - flattened.pre_process_gkr::(); + flattened.pre_process_gkr(); flattened } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 0bb500de..a9c7b146 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -119,7 +119,7 @@ where ) -> bool { let timer = Timer::new("verify", true); let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + expander_circuit.pre_process_gkr(); for i in 0..parallel_count { let mut transcript = C::TranscriptConfig::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 4743505a..517df952 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -29,7 +29,7 @@ where C::FieldConfig: FieldEngine, { let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + expander_circuit.pre_process_gkr(); let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); let prover_scratch = ProverScratchPad::::new( max_num_input_var, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs index f05f9898..11e70d93 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs @@ -39,7 +39,7 @@ where { let timer = Timer::new("verify", true); let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + expander_circuit.pre_process_gkr(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 9a429033..147dfa6f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -71,7 +71,7 @@ where C::FieldConfig: FieldEngine, { let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + expander_circuit.pre_process_gkr(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); From 2734bb9ade331ad476092ee332760c665c56baa3 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 6 Jul 2025 21:28:19 -0500 Subject: [PATCH 12/60] fine-grained config to enable flexibility --- Cargo.lock | 42 +++++++-------- .../src/circuit/layered/export.rs | 2 +- .../expander/api_single_thread.rs | 7 ++- .../proving_system/expander/prove_impl.rs | 14 +++-- .../expander_no_oversubscribe/prove_impl.rs | 51 +------------------ .../expander_parallelized/prove_impl.rs | 22 ++++---- .../expander_parallelized/verify_impl.rs | 3 +- .../expander_pcs_defered/verify_impl.rs | 3 +- 8 files changed, 46 insertions(+), 98 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6b5278f..dd30779b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "babybear", @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "env_logger", @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -2259,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -2284,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -2705,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "proc-macro2", "quote", @@ -2840,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "circuit", @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "arith", "ark-std", @@ -3231,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#0e13c8ee2880ff042a09ef4bde687b1c9d9250b6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" dependencies = [ "colored", ] diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index a7011513..558d5ab8 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -72,7 +72,7 @@ impl Circuit { pub fn export_to_expander_flatten(&self) -> expander_circuit::Circuit { let circuit = self.export_to_expander::(); - let mut flattened = circuit.flatten::(); + let mut flattened = circuit.flatten(); flattened.pre_process_gkr(); flattened } diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index a9c7b146..25b96076 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -69,7 +69,7 @@ where check_inputs(kernel, commitments_values, parallel_count, is_broadcast); let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, 1); + prepare_expander_circuit::(kernel, 1); let mut proof = ExpanderProof { data: vec![] }; @@ -83,7 +83,7 @@ where parallel_index, parallel_count, ); - let challenge = prove_gkr_with_local_vals::( + let challenge = prove_gkr_with_local_vals::( &mut expander_circuit, &mut prover_scratch, &local_vals, @@ -118,8 +118,7 @@ where is_broadcast: &[bool], ) -> bool { let timer = Timer::new("verify", true); - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); for i in 0..parallel_count { let mut transcript = C::TranscriptConfig::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 517df952..34e39395 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -19,19 +19,17 @@ use crate::{ /// ECCCircuit -> ExpanderCircuit /// Returns an additional prover scratch pad for later use in GKR. -pub fn prepare_expander_circuit( +pub fn prepare_expander_circuit( kernel: &Kernel, mpi_world_size: usize, -) -> (Circuit, ProverScratchPad) +) -> (Circuit, ProverScratchPad) where - C: GKREngine, - ECCConfig: Config, - C::FieldConfig: FieldEngine, + F: FieldEngine, + ECCConfig: Config, { - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr(); + let expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); - let prover_scratch = ProverScratchPad::::new( + let prover_scratch = ProverScratchPad::::new( max_num_input_var, max_num_output_var, mpi_world_size, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 83c68583..4932679b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -20,7 +20,7 @@ use crate::{ }, structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::server_ctrl::generate_local_mpi_config, + expander_parallelized::{prove_impl::prove_kernel_gkr, server_ctrl::generate_local_mpi_config}, CombinedProof, Expander, }, }, @@ -62,7 +62,7 @@ where let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr::( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, @@ -117,53 +117,6 @@ where } } -#[allow(clippy::too_many_arguments)] -pub fn prove_kernel_gkr( - mpi_config: &MPIConfig<'static>, - kernel: &Kernel, - commitments_values: &[&[F::SimdCircuitField]], - parallel_count: usize, - is_broadcast: &[bool], -) -> Option<( - T, - ExpanderDualVarChallenge, -)> -where - F: FieldEngine, - T: Transcript, - ECCConfig: Config, -{ - let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); - - local_mpi_config.as_ref()?; - - let local_mpi_config = local_mpi_config.unwrap(); - let local_world_size = local_mpi_config.world_size(); - let local_world_rank = local_mpi_config.world_rank(); - - let local_commitment_values = get_local_vals( - commitments_values, - is_broadcast, - local_world_rank, - local_world_size, - ); - - let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, local_world_size); - - let mut transcript = T::new(); - let challenge = prove_gkr_with_local_vals::( - &mut expander_circuit, - &mut prover_scratch, - &local_commitment_values, - kernel.layered_circuit_input(), - &mut transcript, - &local_mpi_config, - ); - - Some((transcript, challenge)) -} - pub fn partition_challenge_and_location_for_pcs_mpi( gkr_challenge: &ExpanderSingleVarChallenge, total_vals_len: usize, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index a7d02721..0ae2fbbc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -62,7 +62,7 @@ where let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr::( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, @@ -118,20 +118,20 @@ where } #[allow(clippy::too_many_arguments)] -pub fn prove_kernel_gkr( +pub fn prove_kernel_gkr( mpi_config: &MPIConfig<'static>, kernel: &Kernel, - commitments_values: &[&[SIMDField]], + commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, is_broadcast: &[bool], ) -> Option<( - C::TranscriptConfig, - ExpanderDualVarChallenge, + T, + ExpanderDualVarChallenge, )> where - C: GKREngine, - ECCConfig: Config, - C::FieldConfig: FieldEngine, + F: FieldEngine, + T: Transcript, + ECCConfig: Config, { let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); @@ -149,10 +149,10 @@ where ); let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, local_world_size); + prepare_expander_circuit::(kernel, local_world_size); - let mut transcript = C::TranscriptConfig::new(); - let challenge = prove_gkr_with_local_vals::( + let mut transcript = T::new(); + let challenge = prove_gkr_with_local_vals::( &mut expander_circuit, &mut prover_scratch, &local_commitment_values, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs index 11e70d93..763c3aff 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs @@ -38,8 +38,7 @@ where C::FieldConfig: FieldEngine, { let timer = Timer::new("verify", true); - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 147dfa6f..95cfbcd0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -70,8 +70,7 @@ where ECCConfig: Config, C::FieldConfig: FieldEngine, { - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); From bb8fdd58b92c2575ed10b21b8010522f31587411 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 6 Jul 2025 21:29:49 -0500 Subject: [PATCH 13/60] minor bug fix --- .../zkcuda/proving_system/expander/prove_impl.rs | 15 ++++----------- .../proving_system/expander_no_oversubscribe.rs | 2 +- .../api_no_oversubscribe.rs | 4 ++-- .../expander_no_oversubscribe/prove_impl.rs | 4 +++- .../expander_no_oversubscribe/server_bin.rs | 4 +++- .../expander_no_oversubscribe/server_fn.rs | 6 +++--- .../expander_parallelized/prove_impl.rs | 5 +---- .../expander_parallelized/server_bin.rs | 8 +++----- .../expander_pcs_defered/prove_impl.rs | 2 +- 9 files changed, 21 insertions(+), 29 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 34e39395..caf619b6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -29,11 +29,8 @@ where { let expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); - let prover_scratch = ProverScratchPad::::new( - max_num_input_var, - max_num_output_var, - mpi_world_size, - ); + let prover_scratch = + ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); (expander_circuit, prover_scratch) } @@ -89,8 +86,7 @@ pub fn prove_gkr_with_local_vals( partition_info: &[LayeredCircuitInputVec], transcript: &mut T, mpi_config: &MPIConfig, -) -> ExpanderDualVarChallenge -{ +) -> ExpanderDualVarChallenge { expander_circuit.layers[0].input_vals = prepare_inputs_with_local_vals( 1 << expander_circuit.log_input_size(), partition_info, @@ -100,10 +96,7 @@ pub fn prove_gkr_with_local_vals( expander_circuit.evaluate(); let (claimed_v, challenge) = gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - assert_eq!( - claimed_v, - F::ChallengeField::from(0) - ); + assert_eq!(claimed_v, F::ChallengeField::from(0)); challenge } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs index 54c80bbd..0ce09b0b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs @@ -1,3 +1,3 @@ pub mod api_no_oversubscribe; -pub mod server_fn; pub mod prove_impl; +pub mod server_fn; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 516df91d..39d51cf4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -30,8 +30,8 @@ where fn setup( computation_graph: &crate::zkcuda::context::ComputationGraph, ) -> (Self::ProverSetup, Self::VerifierSetup) { - let server_binary = - client_parse_args().unwrap_or("../target/release/expander_server_no_oversubscribe".to_owned()); + let server_binary = client_parse_args() + .unwrap_or("../target/release/expander_server_no_oversubscribe".to_owned()); client_launch_server_and_setup::(&server_binary, computation_graph, false) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 4932679b..bdcb1518 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -20,7 +20,9 @@ use crate::{ }, structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::{prove_impl::prove_kernel_gkr, server_ctrl::generate_local_mpi_config}, + expander_parallelized::{ + prove_impl::prove_kernel_gkr, server_ctrl::generate_local_mpi_config, + }, CombinedProof, Expander, }, }, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index 6359e2b8..fd7cb9c7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -4,7 +4,9 @@ use clap::Parser; use expander_compiler::{ frontend::BN254Config, zkcuda::proving_system::{ - expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, + expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, + expander_pcs_defered::BN254ConfigSha2UniKZG, + ExpanderNoOverSubscribe, }, }; use gkr::BN254ConfigSha2Hyrax; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index a40e735c..8a4cbb04 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -6,9 +6,9 @@ use crate::{ zkcuda::{ context::ComputationGraph, proving_system::{ - expander::{ - structs::{ExpanderProverSetup, ExpanderVerifierSetup}, - }, expander_parallelized::{prove_impl::mpi_prove_impl, server_fns::ServerFns}, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander + expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, + expander_parallelized::{prove_impl::mpi_prove_impl, server_fns::ServerFns}, + CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, }, }, }; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 0ae2fbbc..8ec1c5c7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -124,10 +124,7 @@ pub fn prove_kernel_gkr( commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, is_broadcast: &[bool], -) -> Option<( - T, - ExpanderDualVarChallenge, -)> +) -> Option<(T, ExpanderDualVarChallenge)> where F: FieldEngine, T: Transcript, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index d61907a5..53e76bbe 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -3,11 +3,9 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::{ frontend::{BN254Config, BabyBearConfig, GF2Config, GoldilocksConfig, M31Config}, - zkcuda::proving_system::{ - expander_parallelized::{ - server_ctrl::{serve, ExpanderExecArgs}, - ParallelizedExpander, - }, + zkcuda::proving_system::expander_parallelized::{ + server_ctrl::{serve, ExpanderExecArgs}, + ParallelizedExpander, }, }; use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index d3e9e194..cd27e252 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -145,7 +145,7 @@ where .map(|&idx| values[idx].as_ref()) .collect::>(); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr::( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, From cddcf1f4985aee44a311edf983d1c39048e883f4 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 6 Jul 2025 21:45:13 -0500 Subject: [PATCH 14/60] clippy --- .../expander_no_oversubscribe/prove_impl.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index bdcb1518..539657b1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,7 +1,7 @@ use arith::Field; use expander_utils::timer::Timer; use gkr_engine::{ - ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, + ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Transcript, }; @@ -10,19 +10,13 @@ use crate::{ utils::misc::next_power_of_two, zkcuda::{ context::ComputationGraph, - kernel::Kernel, proving_system::{ expander::{ commit_impl::local_commit_impl, - prove_impl::{ - get_local_vals, pcs_local_open_impl, prepare_expander_circuit, - prove_gkr_with_local_vals, - }, + prove_impl::pcs_local_open_impl, structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::{ - prove_impl::prove_kernel_gkr, server_ctrl::generate_local_mpi_config, - }, + expander_parallelized::prove_impl::prove_kernel_gkr, CombinedProof, Expander, }, }, From 6cad241ef6eb1a7c552c051694384317a85842f7 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 7 Jul 2025 19:03:43 -0500 Subject: [PATCH 15/60] remove pcs field --- Cargo.lock | 46 ++--- expander_compiler/ec_go_lib/src/proving.rs | 6 +- .../expander/api_single_thread.rs | 11 +- .../proving_system/expander/commit_impl.rs | 13 +- .../proving_system/expander/prove_impl.rs | 28 +-- .../proving_system/expander/setup_impl.rs | 5 +- .../zkcuda/proving_system/expander/structs.rs | 32 +--- .../proving_system/expander/verify_impl.rs | 42 ++--- .../api_no_oversubscribe.rs | 6 +- .../expander_no_oversubscribe/prove_impl.rs | 178 ++++++++++++++++-- .../expander_no_oversubscribe/server_fn.rs | 13 +- .../expander_parallelized/api_parallel.rs | 6 +- .../expander_parallelized/client_utils.rs | 6 +- .../expander_parallelized/cmd_utils.rs | 8 +- .../expander_parallelized/prove_impl.rs | 11 +- .../expander_parallelized/server_ctrl.rs | 22 +-- .../expander_parallelized/server_fns.rs | 20 +- .../shared_memory_utils.rs | 29 +-- .../expander_parallelized/verify_impl.rs | 21 +-- .../expander_pcs_defered/api_pcs_defered.rs | 10 +- .../expander_pcs_defered/prove_impl.rs | 41 ++-- .../expander_pcs_defered/server_fns.rs | 4 - .../expander_pcs_defered/setup_impl.rs | 5 +- .../expander_pcs_defered/verify_impl.rs | 58 +++--- .../tests/example_call_expander.rs | 6 +- 25 files changed, 334 insertions(+), 293 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dd30779b..5a2ba4da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "babybear", @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "env_logger", @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -1508,9 +1508,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" dependencies = [ "bytes", "futures-core", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -2259,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -2284,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -2705,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "proc-macro2", "quote", @@ -2840,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "circuit", @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "arith", "ark-std", @@ -3231,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#fbd973c1cbe6371e87ecf09ef0d39b9d47f14fb7" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" dependencies = [ "colored", ] diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs index a0a0b459..0fca5d90 100644 --- a/expander_compiler/ec_go_lib/src/proving.rs +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -20,11 +20,7 @@ use super::{match_config_id, ByteArray, Config}; fn prove_circuit_file_inner( circuit_filename: &str, witness: &[u8], -) -> Result, String> -where - C::FieldConfig: FieldEngine, - C::PCSField: SimdField::CircuitField>, -{ +) -> Result, String> { // (None, None) means single core execution let mpi_config = MPIConfig::prover_new(None, None); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 25b96076..f301e672 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -35,13 +35,12 @@ impl KernelWiseProvingSystem for Expander where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = ExpanderProof; - type Commitment = ExpanderCommitment; - type CommitmentState = ExpanderCommitmentState; + type Commitment = ExpanderCommitment; + type CommitmentState = ExpanderCommitmentState; fn setup( computation_graph: &crate::zkcuda::context::ComputationGraph, @@ -172,8 +171,6 @@ where // In this case, generate the implementation with a procedural macro seems to be the best solution. impl> ProvingSystem for Expander -where - C::FieldConfig: FieldEngine, { type ProverSetup = >::ProverSetup; type VerifierSetup = >::VerifierSetup; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs index 2b827083..0b015baa 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs @@ -9,29 +9,28 @@ use crate::{ }; pub fn local_commit_impl( - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, vals: &[SIMDField], ) -> ( - ExpanderCommitment, - ExpanderCommitmentState, + ExpanderCommitment, + ExpanderCommitmentState, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("commit", true); let n_vars = vals.len().ilog2() as usize; - let params = >::gen_params(n_vars, 1); + let params = >::gen_params(n_vars, 1); let p_key = prover_setup.p_keys.get(&vals.len()).unwrap(); - let mut scratch = >::init_scratch_pad( + let mut scratch = >::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ); - let commitment = >::commit( + let commitment = >::commit( ¶ms, &MPIConfig::prover_new(None, None), p_key, diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index caf619b6..984b3366 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -158,18 +158,14 @@ pub fn partition_challenge_and_location_for_pcs_no_mpi( pub fn pcs_local_open_impl( vals: &[::SimdCircuitField], challenge: &ExpanderSingleVarChallenge, - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { assert_eq!(challenge.r_mpi.len(), 0); let val_len = vals.len(); - let params = >::gen_params( - val_len.ilog2() as usize, - 1, - ); + let params = + >::gen_params(val_len.ilog2() as usize, 1); let p_key = p_keys.p_keys.get(&val_len).unwrap(); let poly = RefMultiLinearPoly::from_ref(vals); @@ -180,14 +176,14 @@ pub fn pcs_local_open_impl( transcript.append_field_element(&v); transcript.lock_proof(); - let opening = >::open( + let opening = >::open( ¶ms, &MPIConfig::prover_new(None, None), p_key, &poly, challenge, transcript, - &>::init_scratch_pad( + &>::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ), @@ -206,14 +202,12 @@ pub fn pcs_local_open_impl( pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( gkr_claim: &ExpanderSingleVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, is_broadcast: &[bool], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { for (commitment_val, ib) in global_vals.iter().zip(is_broadcast) { let val_len = commitment_val.as_ref().len(); let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_no_mpi::< @@ -237,14 +231,12 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( pub fn partition_gkr_claims_and_open_pcs_no_mpi( gkr_claim: &ExpanderDualVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, is_broadcast: &[bool], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { let challenges = if let Some(challenge_y) = gkr_claim.challenge_y() { vec![gkr_claim.challenge_x(), challenge_y] } else { diff --git a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs index 7b087004..19a5329a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs @@ -16,13 +16,12 @@ use crate::{ pub fn local_setup_impl( computation_graph: &ComputationGraph, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut p_keys = HashMap::new(); let mut v_keys = HashMap::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs index 07c47480..76738bab 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs @@ -9,14 +9,12 @@ use crate::{frontend::Config, zkcuda::proving_system::Commitment}; /// A wrapper for the PCS Commitment that includes the length of the values committed to. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderCommitment> { +pub struct ExpanderCommitment> { pub vals_len: usize, pub commitment: PCS::Commitment, } -impl> Clone - for ExpanderCommitment -{ +impl> Clone for ExpanderCommitment { fn clone(&self) -> Self { Self { vals_len: self.vals_len, @@ -40,13 +38,11 @@ impl< /// For Raw, KZG, and Hyrax, this is not needed, so the scratchpad can be empty. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderCommitmentState> { +pub struct ExpanderCommitmentState> { pub scratch: PCS::ScratchPad, } -impl> Clone - for ExpanderCommitmentState -{ +impl> Clone for ExpanderCommitmentState { fn clone(&self) -> Self { Self { scratch: self.scratch.clone(), @@ -58,13 +54,11 @@ impl> Clone /// The keys are indexed by the length of values committed to, allowing for different setups based on the length of the values. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderProverSetup> { +pub struct ExpanderProverSetup> { pub p_keys: HashMap::PKey>, } -impl> Default - for ExpanderProverSetup -{ +impl> Default for ExpanderProverSetup { fn default() -> Self { Self { p_keys: HashMap::new(), @@ -72,9 +66,7 @@ impl> Default } } -impl> Clone - for ExpanderProverSetup -{ +impl> Clone for ExpanderProverSetup { fn clone(&self) -> Self { Self { p_keys: self.p_keys.clone(), @@ -86,14 +78,12 @@ impl> Clone /// The keys are indexed by the length of values committed to, allowing for different setups based on the length of the values. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderVerifierSetup> { +pub struct ExpanderVerifierSetup> { pub v_keys: HashMap::VKey>, } // implement default -impl> Default - for ExpanderVerifierSetup -{ +impl> Default for ExpanderVerifierSetup { fn default() -> Self { Self { v_keys: HashMap::new(), @@ -101,9 +91,7 @@ impl> Default } } -impl> Clone - for ExpanderVerifierSetup -{ +impl> Clone for ExpanderVerifierSetup { fn clone(&self) -> Self { Self { v_keys: self.v_keys.clone(), diff --git a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs index 8aed5cbf..d5c40f2a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs @@ -24,35 +24,31 @@ use crate::{ pub fn verify_pcs( mut proof_reader: impl Read, - commitment: &ExpanderCommitment, + commitment: &ExpanderCommitment, challenge: &ExpanderSingleVarChallenge, claim: &::ChallengeField, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, transcript: &mut C::TranscriptConfig, ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { - let val_len = as Commitment< - ECCConfig, - >>::vals_len(commitment); + let val_len = + as Commitment>::vals_len( + commitment, + ); - let params = >::gen_params( - val_len.ilog2() as usize, - 1, - ); + let params = + >::gen_params(val_len.ilog2() as usize, 1); let v_key = v_keys.v_keys.get(&val_len).unwrap(); let opening = - >::Opening::deserialize_from( - &mut proof_reader, - ) - .unwrap(); + >::Opening::deserialize_from(&mut proof_reader) + .unwrap(); transcript.lock_proof(); - let verified = >::verify( + let verified = >::verify( ¶ms, v_key, &commitment.commitment, @@ -76,10 +72,10 @@ where pub fn verify_pcs_opening_and_aggregation_no_mpi_impl( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_index: usize, parallel_count: usize, @@ -88,7 +84,6 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi_impl( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut target_y = ::ChallengeField::ZERO; for ((input, commitment), ib) in kernel @@ -98,9 +93,9 @@ where .zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, component_idx_vars) = partition_challenge_and_location_for_pcs_no_mpi( challenge, @@ -144,11 +139,11 @@ where pub fn verify_pcs_opening_and_aggregation_no_mpi( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderDualVarChallenge, claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_index: usize, parallel_count: usize, @@ -157,7 +152,6 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let challenges = if let Some(challenge_y) = challenge.challenge_y() { vec![challenge.challenge_x(), challenge_y] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 39d51cf4..da66ca9f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -20,11 +20,9 @@ pub struct ExpanderNoOverSubscribe { impl> ProvingSystem for ExpanderNoOverSubscribe -where - C::FieldConfig: FieldEngine, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = CombinedProof>; fn setup( diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 539657b1..229a725d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,8 +1,8 @@ -use arith::Field; +use arith::{Field, Fr}; use expander_utils::timer::Timer; use gkr_engine::{ - ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, - MPIEngine, Transcript, + BN254ConfigXN, ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, FieldType, + GKREngine, MPIConfig, MPIEngine, Transcript, }; use crate::{ @@ -10,28 +10,28 @@ use crate::{ utils::misc::next_power_of_two, zkcuda::{ context::ComputationGraph, + kernel::Kernel, proving_system::{ expander::{ commit_impl::local_commit_impl, prove_impl::pcs_local_open_impl, structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::prove_impl::prove_kernel_gkr, + expander_parallelized::server_ctrl::generate_local_mpi_config, CombinedProof, Expander, }, }, }; -pub fn mpi_prove_impl( +pub fn mpi_prove_no_oversubscribe_impl( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { @@ -58,7 +58,11 @@ where let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< + C::FieldConfig, + C::TranscriptConfig, + ECCConfig, + >( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, @@ -113,6 +117,156 @@ where } } +#[allow(clippy::too_many_arguments)] +pub fn prove_kernel_gkr_no_oversubscribe( + mpi_config: &MPIConfig<'static>, + kernel: &Kernel, + commitments_values: &[&[F::SimdCircuitField]], + parallel_count: usize, + is_broadcast: &[bool], +) -> Option<(T, ExpanderDualVarChallenge)> +where + F: FieldEngine, + T: Transcript, + ECCConfig: Config, +{ + let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); + + local_mpi_config.as_ref()?; + + let local_mpi_config = local_mpi_config.unwrap(); + let local_world_size = local_mpi_config.world_size(); + let local_world_rank = local_mpi_config.world_rank(); + + let n_local_copies = parallel_count / local_world_size; + match n_local_copies { + 1 => prove_kernel_gkr_internal::( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 2 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 4 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 8 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 16 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 32 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 64 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 128 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 256 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 512 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 1024 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + 2048 => prove_kernel_gkr_internal::, T, ECCConfig>( + mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + ), + _ => { + panic!("Unsupported parallel count: {}", parallel_count); + } + } +} + +pub fn prove_kernel_gkr_internal( + mpi_config: &MPIConfig<'static>, + kernel: &Kernel, + commitments_values: &[&[FBasic::SimdCircuitField]], + parallel_count: usize, + is_broadcast: &[bool], +) -> Option<(T, ExpanderDualVarChallenge)> +where + FBasic: FieldEngine, + FMulti: + FieldEngine, + T: Transcript, + ECCConfig: Config, +{ + let local_commitment_values = get_local_vals( + commitments_values, + is_broadcast, + local_world_rank, + local_world_size, + ); + + let (mut expander_circuit, mut prover_scratch) = + prepare_expander_circuit::(kernel, local_world_size); + + let mut transcript = T::new(); + let challenge = prove_gkr_with_local_vals::( + &mut expander_circuit, + &mut prover_scratch, + &local_commitment_values, + kernel.layered_circuit_input(), + &mut transcript, + &local_mpi_config, + ); + + Some((transcript, challenge)) +} + pub fn partition_challenge_and_location_for_pcs_mpi( gkr_challenge: &ExpanderSingleVarChallenge, total_vals_len: usize, @@ -140,15 +294,13 @@ pub fn partition_challenge_and_location_for_pcs_mpi( #[allow(clippy::too_many_arguments)] fn partition_single_gkr_claim_and_open_pcs_mpi( - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, commitments_values: &[impl AsRef<[SIMDField]>], - commitments_state: &[&ExpanderCommitmentState], + commitments_state: &[&ExpanderCommitmentState], gkr_challenge: &ExpanderSingleVarChallenge, is_broadcast: &[bool], transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { let parallel_count = 1 << gkr_challenge.r_mpi.len(); for ((commitment_val, _state), ib) in commitments_values .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 8a4cbb04..08215258 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -17,17 +17,14 @@ impl ServerFns for ExpanderNoOverSubscribe where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, - ) where - C::FieldConfig: FieldEngine, - { + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, + ) { ParallelizedExpander::::setup_request_handler( global_mpi_config, setup_file, @@ -39,14 +36,13 @@ where fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) } @@ -71,7 +67,6 @@ pub fn read_circuit( ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let computation_graph_bytes = std::fs::read(setup_file).expect("Failed to read computation graph from file"); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index 1f99669c..5a934627 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -24,11 +24,9 @@ pub struct ParallelizedExpander { impl> ProvingSystem for ParallelizedExpander -where - C::FieldConfig: FieldEngine, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = CombinedProof>; fn setup( diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 2ccf0e61..379819bd 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -79,13 +79,12 @@ pub fn client_launch_server_and_setup( computation_graph: &ComputationGraph, allow_oversubscribe: bool, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let setup_timer = Timer::new("setup", true); println!("Starting server with binary: {server_binary}"); @@ -143,7 +142,6 @@ pub fn client_send_witness_and_prove( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("prove", true); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index 48dd79de..11db0122 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -2,10 +2,7 @@ use gkr_engine::{ExpanderPCS, FieldEngine, FieldType, GKREngine, PolynomialCommi use std::process::Command; #[allow(clippy::zombie_processes)] -pub fn start_server(binary: &str, max_parallel_count: usize, port_number: u16) -where - C::FieldConfig: FieldEngine, -{ +pub fn start_server(binary: &str, max_parallel_count: usize, port_number: u16) { let (overscribe, field_name, pcs_name) = parse_config::(max_parallel_count); let cmd_str = format!( @@ -16,7 +13,6 @@ where fn parse_config(mpi_size: usize) -> (String, String, String) where - C::FieldConfig: FieldEngine, { let oversubscription = if mpi_size > num_cpus::get_physical() { println!("Warning: Not enough cores available for the requested number of processes. Using oversubscription."); @@ -34,7 +30,7 @@ where _ => panic!("Unsupported field type"), }; - let pcs_name = match >::PCS_TYPE { + let pcs_name = match >::PCS_TYPE { PolynomialCommitmentType::Raw => "Raw", PolynomialCommitmentType::Hyrax => "Hyrax", PolynomialCommitmentType::KZG => "KZG", diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index 8ec1c5c7..e6a2b674 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -28,14 +28,13 @@ use crate::{ pub fn mpi_prove_impl( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { @@ -188,15 +187,13 @@ pub fn partition_challenge_and_location_for_pcs_mpi( #[allow(clippy::too_many_arguments)] fn partition_single_gkr_claim_and_open_pcs_mpi( - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, commitments_values: &[impl AsRef<[SIMDField]>], - commitments_state: &[&ExpanderCommitmentState], + commitments_state: &[&ExpanderCommitmentState], gkr_challenge: &ExpanderSingleVarChallenge, is_broadcast: &[bool], transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { let parallel_count = 1 << gkr_challenge.r_mpi.len(); for ((commitment_val, _state), ib) in commitments_values .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 04116049..80c98334 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -50,38 +50,28 @@ pub static mut UNIVERSE: Option = None; pub static mut GLOBAL_COMMUNICATOR: Option = None; pub static mut LOCAL_COMMUNICATOR: Option = None; -pub struct ServerState> -where - C::FieldConfig: FieldEngine, -{ +pub struct ServerState> { pub lock: Arc>, // For now we want to ensure that only one request is processed at a time pub global_mpi_config: MPIConfig<'static>, pub local_mpi_config: Option>, - pub prover_setup: Arc>>, - pub verifier_setup: - Arc>>, + pub prover_setup: Arc>>, + pub verifier_setup: Arc>>, pub computation_graph: Arc>>, pub shutdown_tx: Arc>>>, } unsafe impl> Send for ServerState -where - C::FieldConfig: FieldEngine, { } unsafe impl> Sync for ServerState -where - C::FieldConfig: FieldEngine, { } impl> Clone for ServerState -where - C::FieldConfig: FieldEngine, { fn clone(&self) -> Self { ServerState { @@ -103,7 +93,7 @@ pub async fn root_main( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, + S: ServerFns, { let _lock = state.lock.lock().await; // Ensure only one request is processed at a time @@ -170,7 +160,7 @@ pub async fn worker_main(global_mpi_config: MPIConfig<'static>) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, + S: ServerFns, { let state = ServerState:: { @@ -266,7 +256,7 @@ pub async fn serve(port_number: String) where C: GKREngine + 'static, ECCConfig: Config + 'static, - C::FieldConfig: FieldEngine, + S: ServerFns + 'static, { let global_mpi_config = unsafe { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index 31aedafe..ccc7bd8f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -20,19 +20,18 @@ pub trait ServerFns where C: gkr_engine::GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, ); fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>>; @@ -42,17 +41,14 @@ impl ServerFns for ParallelizedExpander where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, - ) where - C::FieldConfig: FieldEngine, - { + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, + ) { let setup_file = if global_mpi_config.is_root() { let setup_file = setup_file.expect("Setup file path must be provided"); broadcast_string(global_mpi_config, Some(setup_file)) @@ -69,14 +65,13 @@ where fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) } @@ -101,7 +96,6 @@ pub fn read_circuit( ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let computation_graph_bytes = std::fs::read(setup_file).expect("Failed to read computation graph from file"); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 720d2a95..a98288d2 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -87,15 +87,8 @@ impl SharedMemoryEngine { /// This impl block contains functions for reading/writing specific objects to shared memory. impl SharedMemoryEngine { - pub fn write_pcs_setup_to_shared_memory< - PCSField: Field, - F: FieldEngine, - PCS: ExpanderPCS, - >( - pcs_setup: &( - ExpanderProverSetup, - ExpanderVerifierSetup, - ), + pub fn write_pcs_setup_to_shared_memory>( + pcs_setup: &(ExpanderProverSetup, ExpanderVerifierSetup), ) { Self::write_object_to_shared_memory( pcs_setup, @@ -104,14 +97,8 @@ impl SharedMemoryEngine { ); } - pub fn read_pcs_setup_from_shared_memory< - PCSField: Field, - F: FieldEngine, - PCS: ExpanderPCS, - >() -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, - ) { + pub fn read_pcs_setup_from_shared_memory>( + ) -> (ExpanderProverSetup, ExpanderVerifierSetup) { Self::read_object_from_shared_memory("pcs_setup", 0) } @@ -189,9 +176,7 @@ impl SharedMemoryEngine { ECCConfig: Config, >( proof: &CombinedProof>, - ) where - C::FieldConfig: FieldEngine, - { + ) { Self::write_object_to_shared_memory(proof, unsafe { &mut SHARED_MEMORY.proof }, "proof"); } @@ -199,9 +184,7 @@ impl SharedMemoryEngine { C: GKREngine, ECCConfig: Config, >() -> CombinedProof> - where - C::FieldConfig: FieldEngine, - { +where { Self::read_object_from_shared_memory("proof", 0) } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs index 763c3aff..5b1d28a6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs @@ -25,17 +25,16 @@ use crate::{ }; pub fn verify_kernel( - verifier_setup: &ExpanderVerifierSetup, + verifier_setup: &ExpanderVerifierSetup, kernel: &Kernel, proof: &ExpanderProof, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], parallel_count: usize, is_broadcast: &[bool], ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("verify", true); let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); @@ -83,10 +82,10 @@ where pub fn verify_pcs_opening_and_aggregation_mpi_impl( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_count: usize, transcript: &mut C::TranscriptConfig, @@ -94,7 +93,6 @@ pub fn verify_pcs_opening_and_aggregation_mpi_impl( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut target_y = ::ChallengeField::ZERO; for ((input, commitment), ib) in kernel @@ -104,9 +102,9 @@ where .zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, component_idx_vars) = partition_challenge_and_location_for_pcs_mpi(challenge, val_len, parallel_count, *ib); @@ -142,11 +140,11 @@ where pub fn verify_pcs_opening_and_aggregation_mpi( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderDualVarChallenge, claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_count: usize, transcript: &mut C::TranscriptConfig, @@ -154,7 +152,6 @@ pub fn verify_pcs_opening_and_aggregation_mpi( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let challenges = if let Some(challenge_y) = challenge.challenge_y() { vec![challenge.challenge_x(), challenge_y] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index da0d30fc..0459826a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -20,13 +20,13 @@ impl ProvingSystem for ExpanderPCSDefered where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { - type ProverSetup = ExpanderProverSetup; + type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = CombinedProof>; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index cd27e252..a8d6f284 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -28,16 +28,15 @@ use crate::{ }; pub fn pad_vals_and_commit( - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, vals: &[SIMDField], ) -> ( - ExpanderCommitment, - ExpanderCommitmentState, + ExpanderCommitment, + ExpanderCommitmentState, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { assert_eq!(prover_setup.p_keys.len(), 1); let len_to_commit = prover_setup.p_keys.keys().next().cloned().unwrap(); @@ -56,14 +55,13 @@ where } pub fn open_defered_pcs( - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, vals: &[&[SIMDField]], challenges: &[ExpanderSingleVarChallenge], ) -> ExpanderProof where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { // TODO: Efficiency let polys: Vec<_> = vals @@ -74,26 +72,23 @@ where // TODO: Soundness let mut transcript = C::TranscriptConfig::new(); let max_length = prover_setup.p_keys.keys().max().cloned().unwrap_or(0); - let params = >::gen_params( - max_length.ilog2() as usize, - 1, - ); - let scratch_pad = >::init_scratch_pad( + let params = + >::gen_params(max_length.ilog2() as usize, 1); + let scratch_pad = >::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ); transcript.lock_proof(); - let (vals, opening) = - >::multi_points_batch_open( - ¶ms, - &MPIConfig::prover_new(None, None), - prover_setup.p_keys.get(&max_length).unwrap(), - &polys, - challenges, - &scratch_pad, - &mut transcript, - ); + let (vals, opening) = >::multi_points_batch_open( + ¶ms, + &MPIConfig::prover_new(None, None), + prover_setup.p_keys.get(&max_length).unwrap(), + &polys, + challenges, + &scratch_pad, + &mut transcript, + ); transcript.unlock_proof(); let mut bytes = vec![]; @@ -107,14 +102,13 @@ where pub fn mpi_prove_with_pcs_defered( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, _states) = if global_mpi_config.is_root() { @@ -206,7 +200,6 @@ pub fn extract_pcs_claims<'a, C: GKREngine>( Vec>, ) where - C::FieldConfig: FieldEngine, { let mut commitment_values_rt = vec![]; let mut challenges = vec![]; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs index 496a1ad3..aa281763 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs @@ -19,19 +19,16 @@ impl ServerFns for ExpanderPCSDefered where C: gkr_engine::GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &gkr_engine::MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, prover_setup: &mut ExpanderProverSetup< - ::PCSField, ::FieldConfig, ::PCSConfig, >, verifier_setup: &mut ExpanderVerifierSetup< - ::PCSField, ::FieldConfig, ::PCSConfig, >, @@ -54,7 +51,6 @@ where fn prove_request_handler( global_mpi_config: &gkr_engine::MPIConfig<'static>, prover_setup: &ExpanderProverSetup< - ::PCSField, ::FieldConfig, ::PCSConfig, >, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs index 6aaf73fd..ffc7e419 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs @@ -16,13 +16,12 @@ use crate::{ pub fn pcs_setup_max_length_only( computation_graph: &ComputationGraph, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut p_keys = HashMap::new(); let mut v_keys = HashMap::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 95cfbcd0..08d4f45e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -25,27 +25,26 @@ use crate::{ }; fn verifier_extract_pcs_claims<'a, C, ECCConfig>( - commitments: &[&'a ExpanderCommitment], + commitments: &[&'a ExpanderCommitment], gkr_challenge: &ExpanderSingleVarChallenge, is_broadcast: &[bool], parallel_count: usize, ) -> ( - Vec<&'a ExpanderCommitment>, + Vec<&'a ExpanderCommitment>, Vec>, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut commitments_rt = vec![]; let mut challenges = vec![]; for (&commitment, ib) in commitments.iter().zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_mpi( gkr_challenge, val_len, @@ -68,7 +67,6 @@ pub fn verify_gkr( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); @@ -95,21 +93,20 @@ where pub fn verify_defered_pcs_opening( proof: &BytesProof, - verifier_setup: &ExpanderVerifierSetup, - commitments: &[&ExpanderCommitment], + verifier_setup: &ExpanderVerifierSetup, + commitments: &[&ExpanderCommitment], challenges: &[ExpanderSingleVarChallenge], ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { let mut transcript = C::TranscriptConfig::new(); let max_num_vars = verifier_setup.v_keys.keys().max().cloned().unwrap_or(0); - let params = - >::gen_params(max_num_vars, 1); + let params = >::gen_params(max_num_vars, 1); let mut defered_proof_bytes = proof.bytes.clone(); let mut cursor = Cursor::new(&mut defered_proof_bytes); @@ -122,38 +119,35 @@ where Vec::<::ChallengeField>::deserialize_from(&mut cursor) .unwrap(); let opening = - >::BatchOpening::deserialize_from( - &mut cursor, - ) - .unwrap(); + >::BatchOpening::deserialize_from(&mut cursor) + .unwrap(); transcript.lock_proof(); - let pcs_verified = - >::multi_points_batch_verify( - ¶ms, - verifier_setup.v_keys.get(&max_num_vars).unwrap(), - &commitments, - challenges, - &vals, - &opening, - &mut transcript, - ); + let pcs_verified = >::multi_points_batch_verify( + ¶ms, + verifier_setup.v_keys.get(&max_num_vars).unwrap(), + &commitments, + challenges, + &vals, + &opening, + &mut transcript, + ); transcript.unlock_proof(); pcs_verified } pub fn verify( - verifier_setup: &ExpanderVerifierSetup, + verifier_setup: &ExpanderVerifierSetup, computation_graph: &ComputationGraph, mut proof: CombinedProof>, ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { let verification_timer = Timer::new("Total Verification", true); let pcs_batch_opening = proof.proofs.pop().unwrap(); diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index e38f88e0..da5de25a 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -21,11 +21,7 @@ impl Define for Circuit { } } -fn example() -where - C::PCSField: SimdField::CircuitField>, - C::FieldConfig: FieldEngine, -{ +fn example() { let n_witnesses = SIMDField::::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = From a9b755a5530aeb5a22b622b0671ae95e027c8628 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 7 Jul 2025 21:24:22 -0500 Subject: [PATCH 16/60] no overscribe version done --- .../proving_system/expander/prove_impl.rs | 9 +- .../zkcuda/proving_system/expander/structs.rs | 7 +- .../zkcuda/proving_system/expander/utils.rs | 12 +- .../api_no_oversubscribe.rs | 1 + .../expander_no_oversubscribe/prove_impl.rs | 143 +++++++++++------- .../expander_no_oversubscribe/server_fn.rs | 11 +- .../expander_parallelized/prove_impl.rs | 2 +- 7 files changed, 106 insertions(+), 79 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 984b3366..459f1ae4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -13,7 +13,7 @@ use crate::{ frontend::Config, zkcuda::{ kernel::{Kernel, LayeredCircuitInputVec}, - proving_system::expander::structs::ExpanderProverSetup, + proving_system::expander::{self, structs::ExpanderProverSetup}, }, }; @@ -25,9 +25,12 @@ pub fn prepare_expander_circuit( ) -> (Circuit, ProverScratchPad) where F: FieldEngine, - ECCConfig: Config, + ECCConfig: Config, + ECCConfig::FieldConfig: FieldEngine, { - let expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten(); + expander_circuit.pre_process_gkr(); + let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); let prover_scratch = ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs index 76738bab..fc6e1940 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs @@ -23,11 +23,8 @@ impl> Clone for ExpanderCommitment { } } -impl< - F: FieldEngine, - PCS: ExpanderPCS, - ECCConfig: Config, - > Commitment for ExpanderCommitment +impl, ECCConfig: Config> Commitment + for ExpanderCommitment { fn vals_len(&self) -> usize { self.vals_len diff --git a/expander_compiler/src/zkcuda/proving_system/expander/utils.rs b/expander_compiler/src/zkcuda/proving_system/expander/utils.rs index a525eab8..0273a55d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/utils.rs @@ -3,12 +3,7 @@ use gkr_engine::{ExpanderPCS, FieldEngine, MPIConfig, StructuredReferenceString, use poly_commit::expander_pcs_init_testing_only; #[allow(clippy::type_complexity)] -pub fn pcs_testing_setup_fixed_seed< - 'a, - F: FieldEngine, - T: Transcript, - PCS: ExpanderPCS, ->( +pub fn pcs_testing_setup_fixed_seed<'a, F: FieldEngine, T: Transcript, PCS: ExpanderPCS>( vals_len: usize, mpi_config: &MPIConfig<'a>, ) -> ( @@ -17,10 +12,7 @@ pub fn pcs_testing_setup_fixed_seed< ::VKey, PCS::ScratchPad, ) { - expander_pcs_init_testing_only::( - vals_len.ilog2() as usize, - mpi_config, - ) + expander_pcs_init_testing_only::(vals_len.ilog2() as usize, mpi_config) } pub fn max_n_vars(circuit: &ExpanderCircuit) -> (usize, usize) { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index da66ca9f..5541c3bd 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -12,6 +12,7 @@ use crate::zkcuda::proving_system::{CombinedProof, ParallelizedExpander, Proving use super::super::Expander; +use arith::Fr; use gkr_engine::{FieldEngine, GKREngine}; pub struct ExpanderNoOverSubscribe { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 229a725d..542d3f9d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,4 +1,4 @@ -use arith::{Field, Fr}; +use arith::{Field, Fr, SimdField}; use expander_utils::timer::Timer; use gkr_engine::{ BN254ConfigXN, ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, FieldType, @@ -10,13 +10,17 @@ use crate::{ utils::misc::next_power_of_two, zkcuda::{ context::ComputationGraph, - kernel::Kernel, + kernel::{Kernel, LayeredCircuitInputVec}, proving_system::{ expander::{ commit_impl::local_commit_impl, - prove_impl::pcs_local_open_impl, + prove_impl::{ + get_local_vals, pcs_local_open_impl, prepare_expander_circuit, + prepare_inputs_with_local_vals, + }, structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, }, + expander_parallelized::prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, expander_parallelized::server_ctrl::generate_local_mpi_config, CombinedProof, Expander, }, @@ -31,6 +35,7 @@ pub fn mpi_prove_no_oversubscribe_impl( ) -> Option>> where C: GKREngine, + C::FieldConfig: FieldEngine, ECCConfig: Config, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); @@ -136,7 +141,6 @@ where let local_mpi_config = local_mpi_config.unwrap(); let local_world_size = local_mpi_config.world_size(); - let local_world_rank = local_mpi_config.world_rank(); let n_local_copies = parallel_count / local_world_size; match n_local_copies { @@ -244,82 +248,105 @@ where T: Transcript, ECCConfig: Config, { - let local_commitment_values = get_local_vals( + let world_rank = mpi_config.world_rank(); + let world_size = mpi_config.world_size(); + let n_copies = parallel_count / world_size; + + let local_commitment_values = get_local_vals_multi_copies( commitments_values, is_broadcast, - local_world_rank, - local_world_size, + world_rank, + n_copies, + parallel_count, ); let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, local_world_size); + prepare_expander_circuit::(kernel, world_size); let mut transcript = T::new(); - let challenge = prove_gkr_with_local_vals::( + let challenge = prove_gkr_with_local_vals_multi_copies::( &mut expander_circuit, &mut prover_scratch, &local_commitment_values, kernel.layered_circuit_input(), &mut transcript, - &local_mpi_config, + &mpi_config, ); Some((transcript, challenge)) } -pub fn partition_challenge_and_location_for_pcs_mpi( - gkr_challenge: &ExpanderSingleVarChallenge, - total_vals_len: usize, +pub fn get_local_vals_multi_copies<'vals_life, F: Field>( + global_vals: &'vals_life [impl AsRef<[F]>], + is_broadcast: &[bool], + local_world_rank: usize, + n_copies: usize, parallel_count: usize, - is_broadcast: bool, -) -> (ExpanderSingleVarChallenge, Vec) { - let mut challenge = gkr_challenge.clone(); - let zero = F::ChallengeField::ZERO; - if is_broadcast { - let n_vals_vars = total_vals_len.ilog2() as usize; - let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); - challenge.rz.resize(n_vals_vars, zero); - challenge.r_mpi.clear(); - (challenge, component_idx_vars) - } else { - let n_vals_vars = (total_vals_len / parallel_count).ilog2() as usize; - let component_idx_vars = challenge.rz[n_vals_vars..].to_vec(); - challenge.rz.resize(n_vals_vars, zero); +) -> Vec> { + let parallel_indices = (0..n_copies) + .map(|i| local_world_rank * n_copies + i) + .collect::>(); - challenge.rz.extend_from_slice(&challenge.r_mpi); - challenge.r_mpi.clear(); - (challenge, component_idx_vars) - } + parallel_indices + .iter() + .map(|¶llel_index| { + get_local_vals(global_vals, is_broadcast, parallel_index, parallel_count) + }) + .collect::>() } -#[allow(clippy::too_many_arguments)] -fn partition_single_gkr_claim_and_open_pcs_mpi( - p_keys: &ExpanderProverSetup, - commitments_values: &[impl AsRef<[SIMDField]>], - commitments_state: &[&ExpanderCommitmentState], - gkr_challenge: &ExpanderSingleVarChallenge, - is_broadcast: &[bool], - transcript: &mut C::TranscriptConfig, -) { - let parallel_count = 1 << gkr_challenge.r_mpi.len(); - for ((commitment_val, _state), ib) in commitments_values +pub fn prove_gkr_with_local_vals_multi_copies( + expander_circuit: &mut expander_circuit::Circuit, + prover_scratch: &mut sumcheck::ProverScratchPad, + local_commitment_values_multi_copies: &[Vec>], + partition_info: &[LayeredCircuitInputVec], + transcript: &mut T, + mpi_config: &MPIConfig, +) -> ExpanderDualVarChallenge +where + FBasic: FieldEngine, + FMulti: + FieldEngine, + T: Transcript, +{ + let input_vals_multi_copies = local_commitment_values_multi_copies .iter() - .zip(commitments_state) - .zip(is_broadcast) - { - let val_len = commitment_val.as_ref().len(); - let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_mpi( - gkr_challenge, - val_len, - parallel_count, - *ib, - ); + .map(|local_commitment_values| { + prepare_inputs_with_local_vals( + 1 << expander_circuit.log_input_size(), + partition_info, + local_commitment_values, + ) + }) + .collect::>(); - pcs_local_open_impl::( - commitment_val.as_ref(), - &challenge_for_pcs, - p_keys, - transcript, - ); + let mut input_vals = + vec![FMulti::SimdCircuitField::ZERO; 1 << expander_circuit.log_input_size()]; + for (i, vals) in input_vals.iter_mut().enumerate() { + let vals_unpacked = input_vals_multi_copies + .iter() + .flat_map(|v| v[i].unpack()) + .collect::>(); + *vals = FMulti::SimdCircuitField::pack(&vals_unpacked); + } + expander_circuit.layers[0].input_vals = input_vals; + + expander_circuit.fill_rnd_coefs(transcript); + expander_circuit.evaluate(); + let (claimed_v, challenge) = + gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); + assert_eq!(claimed_v, FBasic::ChallengeField::from(0)); + + let n_simd_vars_basic = FBasic::SimdCircuitField::PACK_SIZE.ilog2() as usize; + + ExpanderDualVarChallenge { + rz_0: challenge.rz_0, + rz_1: challenge.rz_1, + r_simd: challenge.r_simd[..n_simd_vars_basic].to_vec(), + r_mpi: { + let mut v = challenge.r_simd[n_simd_vars_basic..].to_vec(); + v.extend(&challenge.r_mpi); + v + }, } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 08215258..5f1ecd84 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,4 +1,9 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use arith::Fr; +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr_engine::{ + BN254Config, ExpanderPCS, FieldEngine, FieldType, GKREngine, MPIConfig, MPIEngine, + PolynomialCommitmentType, +}; use serdes::ExpSerde; use crate::{ @@ -7,6 +12,7 @@ use crate::{ context::ComputationGraph, proving_system::{ expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, + expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, expander_parallelized::{prove_impl::mpi_prove_impl, server_fns::ServerFns}, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, }, @@ -16,6 +22,7 @@ use crate::{ impl ServerFns for ExpanderNoOverSubscribe where C: GKREngine, + C::FieldConfig: FieldEngine, ECCConfig: Config, { fn setup_request_handler( @@ -44,7 +51,7 @@ where C: GKREngine, ECCConfig: Config, { - mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) + mpi_prove_no_oversubscribe_impl(global_mpi_config, prover_setup, computation_graph, values) } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index e6a2b674..c6a84a93 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -186,7 +186,7 @@ pub fn partition_challenge_and_location_for_pcs_mpi( } #[allow(clippy::too_many_arguments)] -fn partition_single_gkr_claim_and_open_pcs_mpi( +pub fn partition_single_gkr_claim_and_open_pcs_mpi( p_keys: &ExpanderProverSetup, commitments_values: &[impl AsRef<[SIMDField]>], commitments_state: &[&ExpanderCommitmentState], From da3c1e1f71d956da1e814638a8ba1185432a3c09 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 7 Jul 2025 21:25:07 -0500 Subject: [PATCH 17/60] clippy auto fix --- expander_compiler/ec_go_lib/src/proving.rs | 2 -- .../src/zkcuda/proving_system/expander/commit_impl.rs | 2 +- .../src/zkcuda/proving_system/expander/prove_impl.rs | 2 +- .../src/zkcuda/proving_system/expander/setup_impl.rs | 2 +- .../src/zkcuda/proving_system/expander/structs.rs | 1 - .../expander_no_oversubscribe/api_no_oversubscribe.rs | 3 +-- .../expander_no_oversubscribe/prove_impl.rs | 10 +++++----- .../expander_no_oversubscribe/server_fn.rs | 6 ++---- .../expander_parallelized/api_parallel.rs | 2 +- .../expander_parallelized/client_utils.rs | 2 +- .../expander_parallelized/server_ctrl.rs | 2 +- .../proving_system/expander_parallelized/server_fns.rs | 2 +- .../expander_pcs_defered/api_pcs_defered.rs | 2 +- .../proving_system/expander_pcs_defered/prove_impl.rs | 2 +- .../proving_system/expander_pcs_defered/server_fns.rs | 2 +- .../proving_system/expander_pcs_defered/setup_impl.rs | 2 +- 16 files changed, 19 insertions(+), 25 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs index 0fca5d90..809528c3 100644 --- a/expander_compiler/ec_go_lib/src/proving.rs +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -1,12 +1,10 @@ use std::ptr; use std::slice; -use arith::SimdField; use expander_binary::executor; use expander_compiler::frontend::ChallengeField; use expander_compiler::frontend::SIMDField; -use gkr_engine::FieldEngine; use libc::{c_uchar, c_ulong, malloc}; use expander_compiler::circuit::config; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs index 0b015baa..723d8f1f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs @@ -1,5 +1,5 @@ use expander_utils::timer::Timer; -use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{ExpanderPCS, GKREngine, MPIConfig}; use polynomials::RefMultiLinearPoly; use super::structs::ExpanderProverSetup; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 459f1ae4..ad7b3121 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -13,7 +13,7 @@ use crate::{ frontend::Config, zkcuda::{ kernel::{Kernel, LayeredCircuitInputVec}, - proving_system::expander::{self, structs::ExpanderProverSetup}, + proving_system::expander::structs::ExpanderProverSetup, }, }; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs index 19a5329a..e761b6cb 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{GKREngine, MPIConfig}; use crate::{ frontend::Config, diff --git a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs index fc6e1940..1a81a5e1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use arith::Field; use gkr_engine::{ExpanderPCS, FieldEngine, Proof as BytesProof, StructuredReferenceString}; use serdes::ExpSerde; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 5541c3bd..1559d3eb 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -12,8 +12,7 @@ use crate::zkcuda::proving_system::{CombinedProof, ParallelizedExpander, Proving use super::super::Expander; -use arith::Fr; -use gkr_engine::{FieldEngine, GKREngine}; +use gkr_engine::GKREngine; pub struct ExpanderNoOverSubscribe { _config: std::marker::PhantomData, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 542d3f9d..b9de4df0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,7 +1,7 @@ use arith::{Field, Fr, SimdField}; use expander_utils::timer::Timer; use gkr_engine::{ - BN254ConfigXN, ExpanderDualVarChallenge, ExpanderSingleVarChallenge, FieldEngine, FieldType, + BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, Transcript, }; @@ -15,10 +15,10 @@ use crate::{ expander::{ commit_impl::local_commit_impl, prove_impl::{ - get_local_vals, pcs_local_open_impl, prepare_expander_circuit, + get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals, }, - structs::{ExpanderCommitmentState, ExpanderProof, ExpanderProverSetup}, + structs::{ExpanderProof, ExpanderProverSetup}, }, expander_parallelized::prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, expander_parallelized::server_ctrl::generate_local_mpi_config, @@ -229,7 +229,7 @@ where is_broadcast, ), _ => { - panic!("Unsupported parallel count: {}", parallel_count); + panic!("Unsupported parallel count: {parallel_count}"); } } } @@ -270,7 +270,7 @@ where &local_commitment_values, kernel.layered_circuit_input(), &mut transcript, - &mpi_config, + mpi_config, ); Some((transcript, challenge)) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 5f1ecd84..021bd4c7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,8 +1,6 @@ use arith::Fr; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; use gkr_engine::{ - BN254Config, ExpanderPCS, FieldEngine, FieldType, GKREngine, MPIConfig, MPIEngine, - PolynomialCommitmentType, + FieldEngine, GKREngine, MPIConfig, MPIEngine, }; use serdes::ExpSerde; @@ -13,7 +11,7 @@ use crate::{ proving_system::{ expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, - expander_parallelized::{prove_impl::mpi_prove_impl, server_fns::ServerFns}, + expander_parallelized::server_fns::ServerFns, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, }, }, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index 5a934627..aad322f6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -15,7 +15,7 @@ use crate::zkcuda::proving_system::{CombinedProof, ProvingSystem}; use super::super::Expander; use expander_utils::timer::Timer; -use gkr_engine::{FieldEngine, GKREngine}; +use gkr_engine::GKREngine; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; pub struct ParallelizedExpander { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 379819bd..ebec0cd7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -19,7 +19,7 @@ use crate::{ use super::server_ctrl::{RequestType, SERVER_IP, SERVER_PORT}; use expander_utils::timer::Timer; -use gkr_engine::{FieldEngine, GKREngine}; +use gkr_engine::GKREngine; use reqwest::Client; use serdes::ExpSerde; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 80c98334..b5619db1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -18,7 +18,7 @@ use mpi::traits::Communicator; use crate::frontend::Config; use axum::{extract::State, Json}; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use gkr_engine::{GKREngine, MPIConfig, MPIEngine}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::net::{IpAddr, SocketAddr}; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index ccc7bd8f..75100f08 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use gkr_engine::{GKREngine, MPIConfig, MPIEngine}; use serdes::ExpSerde; use crate::{ diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index 0459826a..235262d6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -1,4 +1,4 @@ -use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine}; +use gkr_engine::{ExpanderPCS, GKREngine}; use crate::{ frontend::{Config, SIMDField}, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index a8d6f284..aaabb96f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -1,7 +1,7 @@ use arith::Field; use expander_utils::timer::Timer; use gkr_engine::{ - ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, + ExpanderPCS, ExpanderSingleVarChallenge, GKREngine, MPIConfig, MPIEngine, Proof as BytesProof, Transcript, }; use polynomials::RefMultiLinearPoly; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs index aa281763..0a34c5fc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, MPIEngine}; +use gkr_engine::MPIEngine; use crate::{ frontend::Config, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs index ffc7e419..dba77cf4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{GKREngine, MPIConfig}; use crate::{ frontend::Config, From 4ed35e13b2af124bc758c9492c3ce7edd94c46e9 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 18:04:05 -0500 Subject: [PATCH 18/60] add testing code --- .../proving_system/expander_no_oversubscribe/prove_impl.rs | 7 +++---- .../proving_system/expander_no_oversubscribe/server_fn.rs | 4 +--- .../proving_system/expander_pcs_defered/prove_impl.rs | 4 ++-- expander_compiler/tests/example_call_expander.rs | 1 - expander_compiler/tests/zkcuda_examples.rs | 7 ++++++- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index b9de4df0..8580bba4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -1,8 +1,8 @@ use arith::{Field, Fr, SimdField}; use expander_utils::timer::Timer; use gkr_engine::{ - BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, - GKREngine, MPIConfig, MPIEngine, Transcript, + BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, + Transcript, }; use crate::{ @@ -15,8 +15,7 @@ use crate::{ expander::{ commit_impl::local_commit_impl, prove_impl::{ - get_local_vals, prepare_expander_circuit, - prepare_inputs_with_local_vals, + get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals, }, structs::{ExpanderProof, ExpanderProverSetup}, }, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 021bd4c7..4a8975cf 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,7 +1,5 @@ use arith::Fr; -use gkr_engine::{ - FieldEngine, GKREngine, MPIConfig, MPIEngine, -}; +use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; use serdes::ExpSerde; use crate::{ diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index aaabb96f..071a2887 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -1,8 +1,8 @@ use arith::Field; use expander_utils::timer::Timer; use gkr_engine::{ - ExpanderPCS, ExpanderSingleVarChallenge, GKREngine, MPIConfig, MPIEngine, - Proof as BytesProof, Transcript, + ExpanderPCS, ExpanderSingleVarChallenge, GKREngine, MPIConfig, MPIEngine, Proof as BytesProof, + Transcript, }; use polynomials::RefMultiLinearPoly; use serdes::ExpSerde; diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index da5de25a..1eaf8e05 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -2,7 +2,6 @@ use arith::Field; use arith::SimdField; use expander_binary::executor; use expander_compiler::frontend::*; -use gkr_engine::FieldEngine; use gkr_engine::MPIConfig; use rand::SeedableRng; diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index c70ef94e..e1575b2a 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -1,5 +1,7 @@ use expander_compiler::frontend::*; -use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; +use expander_compiler::zkcuda::proving_system::{ + Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, +}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; @@ -86,6 +88,9 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); + + zkcuda_test::>(); + zkcuda_test::>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { From 5e0a7ea79a53934adda492361930c1dccd3fdfbc Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 18:30:24 -0500 Subject: [PATCH 19/60] bug fix mpi_config -> local_mpi_config --- .../expander_no_oversubscribe/prove_impl.rs | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 8580bba4..f3726bc7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -144,84 +144,84 @@ where let n_local_copies = parallel_count / local_world_size; match n_local_copies { 1 => prove_kernel_gkr_internal::( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 2 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 4 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 8 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 16 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 32 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 64 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 128 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 256 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 512 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 1024 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, is_broadcast, ), 2048 => prove_kernel_gkr_internal::, T, ECCConfig>( - mpi_config, + &local_mpi_config, kernel, commitments_values, parallel_count, From 4cc1d4514d8134cbe647eae9511f55d198ba32a6 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 19:25:45 -0500 Subject: [PATCH 20/60] bug fix in server binary of no oversubscribe --- .../proving_system/expander_no_oversubscribe/server_bin.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index fd7cb9c7..278bdb67 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -5,11 +5,10 @@ use expander_compiler::{ frontend::BN254Config, zkcuda::proving_system::{ expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, - expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, }, }; -use gkr::BN254ConfigSha2Hyrax; +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -30,7 +29,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; From 36442cbbd26d23e338c70d4e25c687475858bd38 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 19:44:36 -0500 Subject: [PATCH 21/60] fix inconsistency after merging --- .../proving_system/expander_no_oversubscribe/server_fn.rs | 4 +++- .../proving_system/expander_parallelized/server_ctrl.rs | 7 +++---- .../proving_system/expander_parallelized/server_fns.rs | 7 +++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 4a8975cf..fd9fab87 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -9,7 +9,7 @@ use crate::{ proving_system::{ expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, - expander_parallelized::server_fns::ServerFns, + expander_parallelized::{server_ctrl::SharedMemoryWINWrapper, server_fns::ServerFns}, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, }, }, @@ -27,6 +27,7 @@ where computation_graph: &mut ComputationGraph, prover_setup: &mut ExpanderProverSetup, verifier_setup: &mut ExpanderVerifierSetup, + mpi_win: &mut Option, ) { ParallelizedExpander::::setup_request_handler( global_mpi_config, @@ -34,6 +35,7 @@ where computation_graph, prover_setup, verifier_setup, + mpi_win, ); } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 806c03d6..604eec66 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -16,7 +16,7 @@ use mpi::ffi::MPI_Win; use mpi::topology::SimpleCommunicator; use mpi::traits::Communicator; -use crate::frontend::Config; +use crate::frontend::{Config, SIMDField}; use axum::{extract::State, Json}; use gkr_engine::{GKREngine, MPIConfig, MPIEngine}; @@ -62,11 +62,10 @@ pub struct ServerState>, pub prover_setup: Arc>>, - pub verifier_setup: - Arc>>, + pub verifier_setup: Arc>>, pub computation_graph: Arc>>, - pub witness: Arc>>>, + pub witness: Arc>>>>, pub cg_shared_memory_win: Arc>>, // Shared memory for computation graph pub wt_shared_memory_win: Arc>>, // Shared memory for witness diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index 8f1b5f14..e80cff71 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -42,7 +42,7 @@ where fn setup_shared_witness( global_mpi_config: &MPIConfig<'static>, - witness_target: &mut Vec>, + witness_target: &mut Vec>>, mpi_shared_memory_win: &mut Option, ) { // dispose of the previous shared memory if it exists @@ -67,7 +67,7 @@ where fn shared_memory_clean_up( global_mpi_config: &MPIConfig<'static>, computation_graph: ComputationGraph, - witness: Vec>, + witness: Vec>>, cg_mpi_win: &mut Option, wt_mpi_win: &mut Option, ) { @@ -98,8 +98,7 @@ where prover_setup: &mut ExpanderProverSetup, verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, - ) where - C::FieldConfig: FieldEngine, + ) { let setup_file = if global_mpi_config.is_root() { let setup_file = setup_file.expect("Setup file path must be provided"); From 3a8daa50ce3f9d22e66fc82fdc6fba6826a2dc3e Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 19:45:15 -0500 Subject: [PATCH 22/60] clippy auto fix --- .../proving_system/expander_parallelized/server_fns.rs | 5 ++--- .../zkcuda/proving_system/expander_pcs_defered/server_fns.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index e80cff71..37d5f8a3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine, MPISharedMemory}; +use gkr_engine::{GKREngine, MPIConfig, MPIEngine, MPISharedMemory}; use serdes::ExpSerde; use crate::{ @@ -98,8 +98,7 @@ where prover_setup: &mut ExpanderProverSetup, verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, - ) - { + ) { let setup_file = if global_mpi_config.is_root() { let setup_file = setup_file.expect("Setup file path must be provided"); broadcast_string(global_mpi_config, Some(setup_file)) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs index e9ce5715..a34894e0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIEngine}; +use gkr_engine::{GKREngine, MPIEngine}; use crate::{ frontend::Config, From 1233863375241c236bc57726eed036fac662dba9 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 20:09:08 -0500 Subject: [PATCH 23/60] benchmark code --- expander_compiler/Cargo.toml | 4 ++++ .../bin/zkcuda_matmul_no_oversubscribe.rs | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) create mode 100644 expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 7b54210a..09e8e996 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -72,3 +72,7 @@ path = "bin/zkcuda_matmul.rs" [[bin]] name = "zkcuda_matmul_pcs_defered" path = "bin/zkcuda_matmul_pcs_defered.rs" + +[[bin]] +name = "zkcuda_matmul_no_oversubscribe" +path = "bin/zkcuda_matmul_no_oversubscribe.rs" diff --git a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs new file mode 100644 index 00000000..89a16258 --- /dev/null +++ b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs @@ -0,0 +1,16 @@ +#![allow(unused)] +mod zkcuda_matmul; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::ExpanderNoOverSubscribe, +}; +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use zkcuda_matmul::zkcuda_matmul; + +fn main() { + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); + + zkcuda_matmul::, 1024>(); +} From e6d6c7436f416f69c66aa65efe1fb4115d96bd8a Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 8 Jul 2025 20:14:30 -0500 Subject: [PATCH 24/60] change config in benchmark code --- expander_compiler/bin/zkcuda_matmul.rs | 9 +++++---- expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs | 5 +---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index 1d2c80f5..ac86ddad 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -4,13 +4,14 @@ use expander_compiler::frontend::{ BN254Config, BasicAPI, CircuitField, Config, Error, FieldArith, Variable, API, }; -use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; +// use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{ context::{call_kernel, Context}, kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, }; +use gkr::BN254ConfigSha2KZG; const M: usize = 512; const K: usize = 512; @@ -96,7 +97,7 @@ fn main() { // zkcuda_matmul::, 4>(); // zkcuda_matmul::, 8>(); // zkcuda_matmul::, 16>(); - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); } diff --git a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs index 89a16258..56955f39 100644 --- a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs +++ b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs @@ -1,9 +1,6 @@ #![allow(unused)] mod zkcuda_matmul; -use expander_compiler::{ - frontend::BN254Config, - zkcuda::proving_system::ExpanderNoOverSubscribe, -}; +use expander_compiler::{frontend::BN254Config, zkcuda::proving_system::ExpanderNoOverSubscribe}; use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; use zkcuda_matmul::zkcuda_matmul; From 6c595c716e10089c06050af6efc446cc1d3140d5 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 10 Jul 2025 21:29:45 -0500 Subject: [PATCH 25/60] update: KZG -> UniKZG --- Cargo.lock | 54 +++++++++---------- expander_compiler/bin/zkcuda_matmul.rs | 8 +-- .../bin/zkcuda_matmul_no_oversubscribe.rs | 16 +++--- .../expander_no_oversubscribe/server_bin.rs | 5 +- .../expander_parallelized/server_bin.rs | 13 +++-- expander_compiler/tests/zkcuda_examples.rs | 9 ++-- 6 files changed, 57 insertions(+), 48 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5a2ba4da..03b1b268 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "babybear", @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -645,9 +645,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -655,9 +655,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -667,9 +667,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "env_logger", @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -2259,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -2284,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -2705,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "proc-macro2", "quote", @@ -2840,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "circuit", @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "arith", "ark-std", @@ -3231,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#7146ecfff167da9331d0b6f38ce599d86ca1c01b" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" dependencies = [ "colored", ] diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index ac86ddad..a0ac65c1 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -4,6 +4,7 @@ use expander_compiler::frontend::{ BN254Config, BasicAPI, CircuitField, Config, Error, FieldArith, Variable, API, }; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; // use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; @@ -11,7 +12,6 @@ use expander_compiler::zkcuda::{ context::{call_kernel, Context}, kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, }; -use gkr::BN254ConfigSha2KZG; const M: usize = 512; const K: usize = 512; @@ -97,7 +97,7 @@ fn main() { // zkcuda_matmul::, 4>(); // zkcuda_matmul::, 8>(); // zkcuda_matmul::, 16>(); - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); } diff --git a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs index 56955f39..9782e841 100644 --- a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs +++ b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs @@ -1,13 +1,17 @@ #![allow(unused)] mod zkcuda_matmul; -use expander_compiler::{frontend::BN254Config, zkcuda::proving_system::ExpanderNoOverSubscribe}; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, + }, +}; use zkcuda_matmul::zkcuda_matmul; fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); - zkcuda_matmul::, 1024>(); + zkcuda_matmul::, 1024>(); } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index 278bdb67..fd7cb9c7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -5,10 +5,11 @@ use expander_compiler::{ frontend::BN254Config, zkcuda::proving_system::{ expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, + expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, }, }; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -29,7 +30,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index 53e76bbe..6250428d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -3,12 +3,15 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::{ frontend::{BN254Config, BabyBearConfig, GF2Config, GoldilocksConfig, M31Config}, - zkcuda::proving_system::expander_parallelized::{ - server_ctrl::{serve, ExpanderExecArgs}, - ParallelizedExpander, + zkcuda::proving_system::{ + expander_parallelized::{ + server_ctrl::{serve, ExpanderExecArgs}, + ParallelizedExpander, + }, + expander_pcs_defered::BN254ConfigSha2UniKZG, }, }; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -55,7 +58,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index e1575b2a..8a085e75 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -1,11 +1,12 @@ use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{ Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, }; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use serdes::ExpSerde; #[kernel] @@ -76,7 +77,7 @@ fn zkcuda_test_single_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); } #[test] @@ -87,10 +88,10 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { From 4056954cf1a481a236a99ec38187d2aee9fa683f Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 13 Jul 2025 19:26:09 -0500 Subject: [PATCH 26/60] add binary in ci --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4efdf4de..6bd51090 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,7 @@ jobs: run: brew install openmpi - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - - run: cargo build --release --bin expander_server --bin expander_server_pcs_defered + - run: cargo build --release --bin expander_server --bin expander_server_pcs_defered --bin expander_server_no_oversubscribe - run: cargo test test-rust-avx512: From 9bb5270d909e2f26ea835f6b93a2d426c20603b0 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 13 Jul 2025 21:24:10 -0500 Subject: [PATCH 27/60] add zkcuda config --- .../src/zkcuda/proving_system/expander.rs | 1 + .../zkcuda/proving_system/expander/config.rs | 52 +++++++++++++++++++ .../proving_system/expander_pcs_defered.rs | 8 +-- 3 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 expander_compiler/src/zkcuda/proving_system/expander/config.rs diff --git a/expander_compiler/src/zkcuda/proving_system/expander.rs b/expander_compiler/src/zkcuda/proving_system/expander.rs index d75cdb6f..ec444713 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander.rs @@ -1,6 +1,7 @@ pub mod api_single_thread; pub mod commit_impl; +pub mod config; pub mod prove_impl; pub mod setup_impl; pub mod structs; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/config.rs b/expander_compiler/src/zkcuda/proving_system/expander/config.rs new file mode 100644 index 00000000..e06267b0 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander/config.rs @@ -0,0 +1,52 @@ +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2Raw, M31x16ConfigSha2RawVanilla}; +use gkr_engine::GKREngine; + +use crate::{frontend::{BN254Config, Config, M31Config}, zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG}; + +pub trait ZKCudaConfig { + type ECCConfig: Config; + type GKRConfig: GKREngine::FieldConfig>; + + const BATCH_PCS: bool = false; +} + +pub type GetPCS = <::GKRConfig as GKREngine>::PCSConfig; +pub type GetTranscript = <::GKRConfig as GKREngine>::TranscriptConfig; +pub type GetFieldConfig = <::GKRConfig as GKREngine>::FieldConfig; + +pub struct ZKCudaConfigImpl +where + ECC: Config, + GKR: GKREngine::FieldConfig>, +{ + _phantom: std::marker::PhantomData<(ECC, GKR, bool)>, +} + +impl ZKCudaConfig for ZKCudaConfigImpl +where + ECC: Config, + GKR: GKREngine::FieldConfig>, +{ + type ECCConfig = ECC; + type GKRConfig = GKR; + + const BATCH_PCS: bool = BATCH_PCS; +} + +// Concrete ZKCudaConfig types for various configurations +pub type ZKCudaBN254Hyrax<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaBN254KZG<'a> = ZKCudaConfigImpl, false>; + +pub type ZKCudaM31<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaGF2<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaGoldilocks<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaBabyBear<'a> = ZKCudaConfigImpl, false>; + +// Batch PCS types +pub type ZKCudaBN254HyraxBatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaBN254KZGBatchPCS<'a> = ZKCudaConfigImpl, true>; + +pub type ZKCudaM31BatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaGF2BatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaGoldilocksBatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaBabyBearBatchPCS<'a> = ZKCudaConfigImpl, true>; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs index 6ecc6058..9be9c805 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs @@ -13,11 +13,13 @@ use gkr_hashers::SHA256hasher; use halo2curves::bn256::Bn256; use poly_commit::HyperUniKZGPCS; -pub struct BN254ConfigSha2UniKZG; +pub struct BN254ConfigSha2UniKZG<'a> { + _phantom: std::marker::PhantomData<&'a ()>, +} -impl GKREngine for BN254ConfigSha2UniKZG { +impl<'a> GKREngine for BN254ConfigSha2UniKZG<'a> { type FieldConfig = ::FieldConfig; - type MPIConfig = MPIConfig<'static>; + type MPIConfig = MPIConfig<'a>; type TranscriptConfig = BytesHashTranscript; type PCSConfig = HyperUniKZGPCS; const SCHEME: GKRScheme = GKRScheme::Vanilla; From e4c78a674f258ae847b25bc5754731f063d17e0c Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 14 Jul 2025 21:15:05 -0500 Subject: [PATCH 28/60] switch to defered PCS --- .../zkcuda/proving_system/expander/config.rs | 17 ++++-- .../api_no_oversubscribe.rs | 12 ++-- .../expander_no_oversubscribe/prove_impl.rs | 55 +++++++++--------- .../expander_no_oversubscribe/server_fn.rs | 56 ++++++------------- 4 files changed, 68 insertions(+), 72 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/config.rs b/expander_compiler/src/zkcuda/proving_system/expander/config.rs index e06267b0..75c78557 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/config.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/config.rs @@ -1,7 +1,10 @@ use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2Raw, M31x16ConfigSha2RawVanilla}; use gkr_engine::GKREngine; -use crate::{frontend::{BN254Config, Config, M31Config}, zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG}; +use crate::{ + frontend::{BN254Config, Config, M31Config}, + zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG, +}; pub trait ZKCudaConfig { type ECCConfig: Config; @@ -11,8 +14,10 @@ pub trait ZKCudaConfig { } pub type GetPCS = <::GKRConfig as GKREngine>::PCSConfig; -pub type GetTranscript = <::GKRConfig as GKREngine>::TranscriptConfig; -pub type GetFieldConfig = <::GKRConfig as GKREngine>::FieldConfig; +pub type GetTranscript = + <::GKRConfig as GKREngine>::TranscriptConfig; +pub type GetFieldConfig = + <::GKRConfig as GKREngine>::FieldConfig; pub struct ZKCudaConfigImpl where @@ -43,8 +48,10 @@ pub type ZKCudaGoldilocks<'a> = ZKCudaConfigImpl = ZKCudaConfigImpl, false>; // Batch PCS types -pub type ZKCudaBN254HyraxBatchPCS<'a> = ZKCudaConfigImpl, true>; -pub type ZKCudaBN254KZGBatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaBN254HyraxBatchPCS<'a> = + ZKCudaConfigImpl, true>; +pub type ZKCudaBN254KZGBatchPCS<'a> = + ZKCudaConfigImpl, true>; pub type ZKCudaM31BatchPCS<'a> = ZKCudaConfigImpl, true>; pub type ZKCudaGF2BatchPCS<'a> = ZKCudaConfigImpl, true>; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 1559d3eb..4d5558e4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -8,11 +8,13 @@ use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, ClientHttpHelper, }; -use crate::zkcuda::proving_system::{CombinedProof, ParallelizedExpander, ProvingSystem}; +use crate::zkcuda::proving_system::{ + CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, +}; use super::super::Expander; -use gkr_engine::GKREngine; +use gkr_engine::{ExpanderPCS, GKREngine}; pub struct ExpanderNoOverSubscribe { _config: std::marker::PhantomData, @@ -20,6 +22,9 @@ pub struct ExpanderNoOverSubscribe { impl> ProvingSystem for ExpanderNoOverSubscribe +where + >::Commitment: + AsRef<>::Commitment>, { type ProverSetup = ExpanderProverSetup; type VerifierSetup = ExpanderVerifierSetup; @@ -46,8 +51,7 @@ impl> ProvingSyste computation_graph: &ComputationGraph, proof: &Self::Proof, ) -> bool { - // The proof should be the same as the one returned by ParallelizedExpander::prove - ParallelizedExpander::verify(verifier_setup, computation_graph, proof) + ExpanderPCSDefered::::verify(verifier_setup, computation_graph, proof) } fn post_process() { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index f3726bc7..d7f65c23 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -19,8 +19,13 @@ use crate::{ }, structs::{ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, - expander_parallelized::server_ctrl::generate_local_mpi_config, + expander_parallelized::{ + prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, + server_ctrl::generate_local_mpi_config, + }, + expander_pcs_defered::prove_impl::{ + extract_pcs_claims, open_defered_pcs, pad_vals_and_commit, + }, CombinedProof, Expander, }, }, @@ -41,7 +46,7 @@ where let (commitments, states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| local_commit_impl::(prover_setup, value.as_ref())) + .map(|value| pad_vals_and_commit::(prover_setup, value.as_ref())) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { @@ -49,6 +54,9 @@ where }; commit_timer.stop(); + let mut vals_ref = vec![]; + let mut challenges = vec![]; + let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); let proofs = computation_graph .proof_templates() @@ -76,30 +84,20 @@ where single_kernel_gkr_timer.stop(); if global_mpi_config.is_root() { - let pcs_open_timer = Timer::new("pcs open", true); let (mut transcript, challenge) = gkr_end_state.unwrap(); - let challenges = if let Some(challenge_y) = challenge.challenge_y() { - vec![challenge.challenge_x(), challenge_y] - } else { - vec![challenge.challenge_x()] - }; + assert!(challenge.challenge_y().is_none()); + let challenge = challenge.challenge_x(); - challenges.iter().for_each(|c| { - partition_single_gkr_claim_and_open_pcs_mpi::( - prover_setup, - &commitment_values, - &template - .commitment_indices() - .iter() - .map(|&idx| &states.as_ref().unwrap()[idx]) - .collect::>(), - c, - template.is_broadcast(), - &mut transcript, - ); - }); + let (local_vals_ref, local_challenges) = extract_pcs_claims::( + &commitment_values, + &challenge, + template.is_broadcast(), + next_power_of_two(template.parallel_count()), + ); + + vals_ref.extend(local_vals_ref); + challenges.extend(local_challenges); - pcs_open_timer.stop(); Some(ExpanderProof { data: vec![transcript.finalize_and_get_proof()], }) @@ -111,7 +109,14 @@ where prove_timer.stop(); if global_mpi_config.is_root() { - let proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); + let pcs_batch_opening = + open_defered_pcs::(prover_setup, &vals_ref, &challenges); + pcs_opening_timer.stop(); + + proofs.push(pcs_batch_opening); Some(CombinedProof { commitments: commitments.unwrap(), proofs, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index fd9fab87..ddcf7bde 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,6 +1,5 @@ use arith::Fr; use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; -use serdes::ExpSerde; use crate::{ frontend::{Config, SIMDField}, @@ -9,7 +8,11 @@ use crate::{ proving_system::{ expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, - expander_parallelized::{server_ctrl::SharedMemoryWINWrapper, server_fns::ServerFns}, + expander_parallelized::{ + server_ctrl::SharedMemoryWINWrapper, + server_fns::{broadcast_string, read_circuit, ServerFns}, + }, + expander_pcs_defered::setup_impl::pcs_setup_max_length_only, CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, }, }, @@ -29,14 +32,19 @@ where verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, ) { - ParallelizedExpander::::setup_request_handler( - global_mpi_config, - setup_file, - computation_graph, - prover_setup, - verifier_setup, - mpi_win, - ); + let setup_file = if global_mpi_config.is_root() { + let setup_file = setup_file.expect("Setup file path must be provided"); + broadcast_string(global_mpi_config, Some(setup_file)) + } else { + // Workers will wait for the setup file to be broadcasted + broadcast_string(global_mpi_config, None) + }; + + read_circuit::(global_mpi_config, setup_file, computation_graph, mpi_win); + if global_mpi_config.is_root() { + (*prover_setup, *verifier_setup) = + pcs_setup_max_length_only::(computation_graph); + } } fn prove_request_handler( @@ -52,31 +60,3 @@ where mpi_prove_no_oversubscribe_impl(global_mpi_config, prover_setup, computation_graph, values) } } - -pub fn broadcast_string(global_mpi_config: &MPIConfig<'static>, string: Option) -> String { - // Broadcast the setup file path to all workers - if global_mpi_config.is_root() && string.is_none() { - panic!("String must be provided on the root process in broadcast_string"); - } - let mut string_length = string.as_ref().map_or(0, |s| s.len()); - global_mpi_config.root_broadcast_f(&mut string_length); - let mut bytes = string.map_or(vec![0u8; string_length], |s| s.into_bytes()); - global_mpi_config.root_broadcast_bytes(&mut bytes); - String::from_utf8(bytes).expect("Failed to convert broadcasted bytes to String") -} - -pub fn read_circuit( - _global_mpi_config: &MPIConfig<'static>, - setup_file: String, - computation_graph: &mut ComputationGraph, -) where - C: GKREngine, - ECCConfig: Config, -{ - let computation_graph_bytes = - std::fs::read(setup_file).expect("Failed to read computation graph from file"); - *computation_graph = ComputationGraph::::deserialize_from(std::io::Cursor::new( - computation_graph_bytes, - )) - .expect("Failed to deserialize computation graph"); -} From 714e57337a5b1c0eacca241bc5e761cd48ca6eaa Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 14 Jul 2025 21:17:27 -0500 Subject: [PATCH 29/60] clippy fix --- .../expander_no_oversubscribe/api_no_oversubscribe.rs | 2 +- .../expander_no_oversubscribe/prove_impl.rs | 8 ++------ .../proving_system/expander_no_oversubscribe/server_fn.rs | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 4d5558e4..57842b13 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -9,7 +9,7 @@ use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ ClientHttpHelper, }; use crate::zkcuda::proving_system::{ - CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, + CombinedProof, ExpanderPCSDefered, ProvingSystem, }; use super::super::Expander; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index d7f65c23..fde8de54 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -13,16 +13,12 @@ use crate::{ kernel::{Kernel, LayeredCircuitInputVec}, proving_system::{ expander::{ - commit_impl::local_commit_impl, prove_impl::{ get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals, }, structs::{ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::{ - prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, - server_ctrl::generate_local_mpi_config, - }, + expander_parallelized::server_ctrl::generate_local_mpi_config, expander_pcs_defered::prove_impl::{ extract_pcs_claims, open_defered_pcs, pad_vals_and_commit, }, @@ -43,7 +39,7 @@ where ECCConfig: Config, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); - let (commitments, states) = if global_mpi_config.is_root() { + let (commitments, _states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() .map(|value| pad_vals_and_commit::(prover_setup, value.as_ref())) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index ddcf7bde..00459eee 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -13,7 +13,7 @@ use crate::{ server_fns::{broadcast_string, read_circuit, ServerFns}, }, expander_pcs_defered::setup_impl::pcs_setup_max_length_only, - CombinedProof, Expander, ExpanderNoOverSubscribe, ParallelizedExpander, + CombinedProof, Expander, ExpanderNoOverSubscribe, }, }, }; From 421b876e1d8fd0d0ef0f5da60ee0e22e80169905 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 14 Jul 2025 21:29:49 -0500 Subject: [PATCH 30/60] fmt --- .../expander_no_oversubscribe/api_no_oversubscribe.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 57842b13..4de71faf 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -8,9 +8,7 @@ use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, ClientHttpHelper, }; -use crate::zkcuda::proving_system::{ - CombinedProof, ExpanderPCSDefered, ProvingSystem, -}; +use crate::zkcuda::proving_system::{CombinedProof, ExpanderPCSDefered, ProvingSystem}; use super::super::Expander; From ec2e67f78542f8d4203e441f1d423747c01f2939 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 14 Jul 2025 21:47:29 -0500 Subject: [PATCH 31/60] mpi version --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6bd51090..ee2c8673 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -71,7 +71,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -90,7 +90,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo build --release --bin expander_commit --bin expander_prove - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test From b625fd813d7e548618bfe06d647ced010d214397 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 14 Jul 2025 21:49:13 -0500 Subject: [PATCH 32/60] mpi version --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee2c8673..3a5d27ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -71,7 +71,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -90,7 +90,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.7" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo build --release --bin expander_commit --bin expander_prove - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test From 60efc090d3b2b3b04e6320dc1d6720e7a59c9fa8 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 15 Jul 2025 21:19:23 -0500 Subject: [PATCH 33/60] control prover behavior with zkcuda config, enabling batching/non-batching pcs --- .../bin/zkcuda_matmul_no_oversubscribe.rs | 11 +- .../api_no_oversubscribe.rs | 53 +++-- .../expander_no_oversubscribe/cmd_utils.rs | 0 .../expander_no_oversubscribe/prove_impl.rs | 200 ++++++++++++------ .../expander_no_oversubscribe/server_bin.rs | 37 +++- .../expander_no_oversubscribe/server_fn.rs | 78 +++---- .../expander_parallelized/api_parallel.rs | 7 +- .../expander_parallelized/client_utils.rs | 3 +- .../expander_parallelized/cmd_utils.rs | 9 +- .../expander_parallelized/server_ctrl.rs | 4 + .../expander_pcs_defered/api_pcs_defered.rs | 7 +- expander_compiler/tests/zkcuda_examples.rs | 9 +- 12 files changed, 276 insertions(+), 142 deletions(-) delete mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs diff --git a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs index 9782e841..df3d19aa 100644 --- a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs +++ b/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs @@ -3,15 +3,16 @@ mod zkcuda_matmul; use expander_compiler::{ frontend::BN254Config, zkcuda::proving_system::{ - expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, + expander::config::ZKCudaBN254Hyrax, expander_pcs_defered::BN254ConfigSha2UniKZG, + ExpanderNoOverSubscribe, }, }; use zkcuda_matmul::zkcuda_matmul; fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 4>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 8>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 16>(); - zkcuda_matmul::, 1024>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 1024>(); } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 4de71faf..5a215097 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -1,6 +1,6 @@ -use crate::circuit::config::Config; use crate::frontend::SIMDField; use crate::zkcuda::context::ComputationGraph; +use crate::zkcuda::proving_system::expander::config::{GetFieldConfig, GetPCS, ZKCudaConfig}; use crate::zkcuda::proving_system::expander::structs::{ ExpanderProverSetup, ExpanderVerifierSetup, }; @@ -8,48 +8,65 @@ use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, ClientHttpHelper, }; -use crate::zkcuda::proving_system::{CombinedProof, ExpanderPCSDefered, ProvingSystem}; +use crate::zkcuda::proving_system::{ + CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, +}; use super::super::Expander; -use gkr_engine::{ExpanderPCS, GKREngine}; +use gkr_engine::ExpanderPCS; -pub struct ExpanderNoOverSubscribe { - _config: std::marker::PhantomData, +pub struct ExpanderNoOverSubscribe { + _config: std::marker::PhantomData, } -impl> ProvingSystem - for ExpanderNoOverSubscribe +impl ProvingSystem for ExpanderNoOverSubscribe where - >::Commitment: - AsRef<>::Commitment>, + as ExpanderPCS>>::Commitment: + AsRef< as ExpanderPCS>>::Commitment>, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; - type Proof = CombinedProof>; + type ProverSetup = ExpanderProverSetup, GetPCS>; + type VerifierSetup = ExpanderVerifierSetup, GetPCS>; + type Proof = CombinedProof>; fn setup( - computation_graph: &crate::zkcuda::context::ComputationGraph, + computation_graph: &ComputationGraph, ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args() .unwrap_or("../target/release/expander_server_no_oversubscribe".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph, false) + client_launch_server_and_setup::( + &server_binary, + computation_graph, + false, + ZC::BATCH_PCS, + ) } fn prove( _prover_setup: &Self::ProverSetup, - _computation_graph: &crate::zkcuda::context::ComputationGraph, - device_memories: &[Vec>], + _computation_graph: &ComputationGraph, + device_memories: &[Vec>], ) -> Self::Proof { client_send_witness_and_prove(device_memories) } fn verify( verifier_setup: &Self::VerifierSetup, - computation_graph: &ComputationGraph, + computation_graph: &ComputationGraph, proof: &Self::Proof, ) -> bool { - ExpanderPCSDefered::::verify(verifier_setup, computation_graph, proof) + match ZC::BATCH_PCS { + true => ExpanderPCSDefered::::verify( + verifier_setup, + computation_graph, + proof, + ), + false => ParallelizedExpander::::verify( + verifier_setup, + computation_graph, + proof, + ), + } } fn post_process() { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/cmd_utils.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index fde8de54..7f8ea82b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -13,12 +13,17 @@ use crate::{ kernel::{Kernel, LayeredCircuitInputVec}, proving_system::{ expander::{ + commit_impl::local_commit_impl, + config::{GetFieldConfig, GetPCS, GetTranscript, ZKCudaConfig}, prove_impl::{ get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals, }, structs::{ExpanderProof, ExpanderProverSetup}, }, - expander_parallelized::server_ctrl::generate_local_mpi_config, + expander_parallelized::{ + prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, + server_ctrl::generate_local_mpi_config, + }, expander_pcs_defered::prove_impl::{ extract_pcs_claims, open_defered_pcs, pad_vals_and_commit, }, @@ -27,22 +32,28 @@ use crate::{ }, }; -pub fn mpi_prove_no_oversubscribe_impl( +pub fn mpi_prove_no_oversubscribe_impl( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, - computation_graph: &ComputationGraph, - values: &[impl AsRef<[SIMDField]>], -) -> Option>> + prover_setup: &ExpanderProverSetup, GetPCS>, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], +) -> Option>> where - C: GKREngine, - C::FieldConfig: FieldEngine, - ECCConfig: Config, + ::FieldConfig: FieldEngine, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); - let (commitments, _states) = if global_mpi_config.is_root() { + let (commitments, states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| pad_vals_and_commit::(prover_setup, value.as_ref())) + .map(|value| match ZC::BATCH_PCS { + true => pad_vals_and_commit::( + prover_setup, + value.as_ref(), + ), + false => { + local_commit_impl::(prover_setup, value.as_ref()) + } + }) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { @@ -54,71 +65,128 @@ where let mut challenges = vec![]; let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); - let proofs = computation_graph - .proof_templates() - .iter() - .map(|template| { - let commitment_values = template - .commitment_indices() - .iter() - .map(|&idx| values[idx].as_ref()) - .collect::>(); - - let single_kernel_gkr_timer = - Timer::new("small gkr kernel", global_mpi_config.is_root()); - let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< - C::FieldConfig, - C::TranscriptConfig, - ECCConfig, - >( - global_mpi_config, - &computation_graph.kernels()[template.kernel_id()], - &commitment_values, - next_power_of_two(template.parallel_count()), - template.is_broadcast(), - ); - single_kernel_gkr_timer.stop(); - - if global_mpi_config.is_root() { - let (mut transcript, challenge) = gkr_end_state.unwrap(); - assert!(challenge.challenge_y().is_none()); - let challenge = challenge.challenge_x(); + let proofs = + computation_graph + .proof_templates() + .iter() + .map(|template| { + let commitment_values = template + .commitment_indices() + .iter() + .map(|&idx| values[idx].as_ref()) + .collect::>(); - let (local_vals_ref, local_challenges) = extract_pcs_claims::( + let single_kernel_gkr_timer = + Timer::new("small gkr kernel", global_mpi_config.is_root()); + let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< + GetFieldConfig, + GetTranscript, + ZC::ECCConfig, + >( + global_mpi_config, + &computation_graph.kernels()[template.kernel_id()], &commitment_values, - &challenge, - template.is_broadcast(), next_power_of_two(template.parallel_count()), + template.is_broadcast(), ); + single_kernel_gkr_timer.stop(); - vals_ref.extend(local_vals_ref); - challenges.extend(local_challenges); + match ZC::BATCH_PCS { + true => { + if global_mpi_config.is_root() { + let (mut transcript, challenge) = gkr_end_state.unwrap(); + assert!(challenge.challenge_y().is_none()); + let challenge = challenge.challenge_x(); - Some(ExpanderProof { - data: vec![transcript.finalize_and_get_proof()], - }) - } else { - None - } - }) - .collect::>(); + let (local_vals_ref, local_challenges) = + extract_pcs_claims::( + &commitment_values, + &challenge, + template.is_broadcast(), + next_power_of_two(template.parallel_count()), + ); + + vals_ref.extend(local_vals_ref); + challenges.extend(local_challenges); + + Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }) + } else { + None + } + } + false => { + if global_mpi_config.is_root() { + let pcs_open_timer = Timer::new("pcs open", true); + let (mut transcript, challenge) = gkr_end_state.unwrap(); + let challenges = if let Some(challenge_y) = challenge.challenge_y() { + vec![challenge.challenge_x(), challenge_y] + } else { + vec![challenge.challenge_x()] + }; + + challenges.iter().for_each(|c| { + partition_single_gkr_claim_and_open_pcs_mpi::( + prover_setup, + &commitment_values, + &template + .commitment_indices() + .iter() + .map(|&idx| &states.as_ref().unwrap()[idx]) + .collect::>(), + c, + template.is_broadcast(), + &mut transcript, + ); + }); + + pcs_open_timer.stop(); + Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }) + } else { + None + } + } + } + }) + .collect::>(); prove_timer.stop(); - if global_mpi_config.is_root() { - let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + match ZC::BATCH_PCS { + true => { + if global_mpi_config.is_root() { + let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); - let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); - let pcs_batch_opening = - open_defered_pcs::(prover_setup, &vals_ref, &challenges); - pcs_opening_timer.stop(); + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); + let pcs_batch_opening = open_defered_pcs::( + prover_setup, + &vals_ref, + &challenges, + ); + pcs_opening_timer.stop(); - proofs.push(pcs_batch_opening); - Some(CombinedProof { - commitments: commitments.unwrap(), - proofs, - }) - } else { - None + proofs.push(pcs_batch_opening); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } + false => { + if global_mpi_config.is_root() { + let proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index fd7cb9c7..2da46f9c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -4,12 +4,17 @@ use clap::Parser; use expander_compiler::{ frontend::BN254Config, zkcuda::proving_system::{ + expander::{ + self, + config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, + }, + }, expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, }, }; -use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -24,16 +29,30 @@ pub async fn main() { match (expander_exec_args.field_type.as_str(), pcs_type) { ("BN254", PolynomialCommitmentType::Hyrax) => { - serve::>( - expander_exec_args.port_number, - ) - .await; + if expander_exec_args.batch_pcs { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } else { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( - expander_exec_args.port_number, - ) - .await; + if expander_exec_args.batch_pcs { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } else { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } } (field_type, pcs_type) => { panic!("Combination of {field_type:?} and {pcs_type:?} not supported for no oversubscribe expander proving system."); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 00459eee..9e8f7a40 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,62 +1,66 @@ use arith::Fr; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; use crate::{ - frontend::{Config, SIMDField}, + frontend::SIMDField, zkcuda::{ context::ComputationGraph, proving_system::{ - expander::structs::{ExpanderProverSetup, ExpanderVerifierSetup}, - expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, - expander_parallelized::{ - server_ctrl::SharedMemoryWINWrapper, - server_fns::{broadcast_string, read_circuit, ServerFns}, + expander::{ + config::{GetFieldConfig, GetPCS, ZKCudaConfig}, + structs::{ExpanderProverSetup, ExpanderVerifierSetup}, }, - expander_pcs_defered::setup_impl::pcs_setup_max_length_only, - CombinedProof, Expander, ExpanderNoOverSubscribe, + expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, + expander_parallelized::{server_ctrl::SharedMemoryWINWrapper, server_fns::ServerFns}, + CombinedProof, Expander, ExpanderNoOverSubscribe, ExpanderPCSDefered, + ParallelizedExpander, }, }, }; -impl ServerFns for ExpanderNoOverSubscribe +impl ServerFns for ExpanderNoOverSubscribe where - C: GKREngine, - C::FieldConfig: FieldEngine, - ECCConfig: Config, + ::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, - computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, + computation_graph: &mut ComputationGraph, + prover_setup: &mut ExpanderProverSetup, GetPCS>, + verifier_setup: &mut ExpanderVerifierSetup, GetPCS>, mpi_win: &mut Option, ) { - let setup_file = if global_mpi_config.is_root() { - let setup_file = setup_file.expect("Setup file path must be provided"); - broadcast_string(global_mpi_config, Some(setup_file)) - } else { - // Workers will wait for the setup file to be broadcasted - broadcast_string(global_mpi_config, None) - }; - - read_circuit::(global_mpi_config, setup_file, computation_graph, mpi_win); - if global_mpi_config.is_root() { - (*prover_setup, *verifier_setup) = - pcs_setup_max_length_only::(computation_graph); + match ZC::BATCH_PCS { + true => ExpanderPCSDefered::::setup_request_handler( + global_mpi_config, + setup_file, + computation_graph, + prover_setup, + verifier_setup, + mpi_win, + ), + false => ParallelizedExpander::::setup_request_handler( + global_mpi_config, + setup_file, + computation_graph, + prover_setup, + verifier_setup, + mpi_win, + ), } } fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, - computation_graph: &ComputationGraph, - values: &[impl AsRef<[SIMDField]>], - ) -> Option>> - where - C: GKREngine, - ECCConfig: Config, - { - mpi_prove_no_oversubscribe_impl(global_mpi_config, prover_setup, computation_graph, values) + prover_setup: &ExpanderProverSetup, GetPCS>, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], + ) -> Option>> { + mpi_prove_no_oversubscribe_impl::( + global_mpi_config, + prover_setup, + computation_graph, + values, + ) } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index aad322f6..1862af98 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -34,7 +34,12 @@ impl> ProvingSyste ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args().unwrap_or("../target/release/expander_server".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph, true) + client_launch_server_and_setup::( + &server_binary, + computation_graph, + true, + false, + ) } fn prove( diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index ebec0cd7..855a6701 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -78,6 +78,7 @@ pub fn client_launch_server_and_setup( server_binary: &str, computation_graph: &ComputationGraph, allow_oversubscribe: bool, + batch_pcs: bool, ) -> ( ExpanderProverSetup, ExpanderVerifierSetup, @@ -119,7 +120,7 @@ where let port = parse_port_number(); let server_url = format!("{SERVER_IP}:{port}"); - start_server::(server_binary, mpi_size, port); + start_server::(server_binary, mpi_size, port, batch_pcs); // Keep trying until the server is ready loop { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index 11db0122..187211cc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -2,11 +2,16 @@ use gkr_engine::{ExpanderPCS, FieldEngine, FieldType, GKREngine, PolynomialCommi use std::process::Command; #[allow(clippy::zombie_processes)] -pub fn start_server(binary: &str, max_parallel_count: usize, port_number: u16) { +pub fn start_server( + binary: &str, + max_parallel_count: usize, + port_number: u16, + batch_pcs: bool, +) { let (overscribe, field_name, pcs_name) = parse_config::(max_parallel_count); let cmd_str = format!( - "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number}" + "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number} --batch-pcs {batch_pcs}" ); exec_command(&cmd_str, false); } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 604eec66..260bbc4e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -381,4 +381,8 @@ pub struct ExpanderExecArgs { /// The port number for the server to listen on. #[arg(short, long, default_value = "Port")] pub port_number: String, + + /// The port number for the server to listen on. + #[arg(short, long, default_value = "false")] + pub batch_pcs: bool, } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index 235262d6..daa93c08 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -35,7 +35,12 @@ where ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args() .unwrap_or("../target/release/expander_server_pcs_defered".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph, true) + client_launch_server_and_setup::( + &server_binary, + computation_graph, + true, + true, + ) } fn prove( diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index 8a085e75..727f08d0 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -1,4 +1,7 @@ use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, +}; use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; use expander_compiler::zkcuda::proving_system::{ Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, @@ -90,8 +93,10 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { From 9aab4d3dc90d72033888b5efcd8764e534cf45c5 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 15 Jul 2025 21:19:47 -0500 Subject: [PATCH 34/60] clippy auto fix --- .../expander_no_oversubscribe/server_bin.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index 2da46f9c..a389e894 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -1,20 +1,13 @@ use std::str::FromStr; use clap::Parser; -use expander_compiler::{ - frontend::BN254Config, - zkcuda::proving_system::{ - expander::{ - self, - config::{ +use expander_compiler::zkcuda::proving_system::{ + expander::config::{ ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, }, - }, expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, - expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderNoOverSubscribe, - }, -}; + }; use gkr_engine::PolynomialCommitmentType; #[tokio::main] From 3e72945aad0604347208e6344461de5ac20df637 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 15 Jul 2025 21:24:26 -0500 Subject: [PATCH 35/60] printing the size of computation graph/witness/proof --- .../expander_no_oversubscribe/server_bin.rs | 12 ++++++------ .../expander_parallelized/client_utils.rs | 2 ++ .../expander_parallelized/shared_memory_utils.rs | 5 +++++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs index a389e894..5c402ac7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -2,12 +2,12 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::zkcuda::proving_system::{ - expander::config::{ - ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, - }, - expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, - ExpanderNoOverSubscribe, - }; + expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, + }, + expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, + ExpanderNoOverSubscribe, +}; use gkr_engine::PolynomialCommitmentType; #[tokio::main] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index 855a6701..fd1a48dc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -92,6 +92,8 @@ where let mut bytes = vec![]; computation_graph.serialize_into(&mut bytes).unwrap(); + println!("Serialized computation graph, size: {}", bytes.len()); + // append current timestamp to the file name to avoid conflicts let setup_filename = format!( "/tmp/computation_graph_{}.bin", diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index b68a3ec8..5f5393a3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -62,6 +62,8 @@ impl SharedMemoryEngine { .serialize_into(&mut buffer) .expect("Failed to serialize object"); + println!("Object size: {}", buffer.len()); + unsafe { Self::allocate_shared_memory_if_necessary(shared_memory_ref, name, buffer.len()); let object_ptr = shared_memory_ref.as_mut().unwrap().as_ptr(); @@ -91,6 +93,7 @@ impl SharedMemoryEngine { pub fn write_pcs_setup_to_shared_memory>( pcs_setup: &(ExpanderProverSetup, ExpanderVerifierSetup), ) { + println!("Writing PCS setup to shared memory..."); Self::write_object_to_shared_memory( pcs_setup, unsafe { &mut SHARED_MEMORY.pcs_setup }, @@ -112,6 +115,7 @@ impl SharedMemoryEngine { .map(|v| std::mem::size_of::() + std::mem::size_of_val(v.as_ref())) .sum::(); + println!("Writing witness to shared memory, total size: {total_size}"); unsafe { Self::allocate_shared_memory_if_necessary( &mut SHARED_MEMORY.witness, @@ -208,6 +212,7 @@ impl SharedMemoryEngine { >( proof: &CombinedProof>, ) { + println!("Writing proof to shared memory..."); Self::write_object_to_shared_memory(proof, unsafe { &mut SHARED_MEMORY.proof }, "proof"); } From f838643711dbeb8a2b049bf8e715e4b440a361fb Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 15 Jul 2025 21:30:43 -0500 Subject: [PATCH 36/60] update dependency --- Cargo.lock | 54 +++++++++++++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 03b1b268..5bfa24f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "babybear", @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "env_logger", @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -1271,9 +1271,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" dependencies = [ "bytes", "fnv", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -2259,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -2284,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -2557,15 +2557,15 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags 2.9.1", "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -2705,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "proc-macro2", "quote", @@ -2840,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "circuit", @@ -2936,7 +2936,7 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "arith", "ark-std", @@ -3231,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#f9aa713c6802e8d872f8607cc0062cb9669f02cd" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" dependencies = [ "colored", ] From 3cd0ba8abc13b139acfb6d861438beb0d35777b2 Mon Sep 17 00:00:00 2001 From: zhiyong Date: Thu, 17 Jul 2025 00:22:39 +0000 Subject: [PATCH 37/60] fix a bug in server cli --- .../proving_system/expander_parallelized/cmd_utils.rs | 4 +++- .../proving_system/expander_parallelized/server_ctrl.rs | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index 187211cc..48b895a7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -10,8 +10,9 @@ pub fn start_server( ) { let (overscribe, field_name, pcs_name) = parse_config::(max_parallel_count); + let batch_pcs_option = if batch_pcs { "--batch-pcs" } else { "" }; let cmd_str = format!( - "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number} --batch-pcs {batch_pcs}" + "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number} {batch_pcs_option}" ); exec_command(&cmd_str, false); } @@ -51,6 +52,7 @@ where #[allow(clippy::zombie_processes)] fn exec_command(cmd: &str, wait_for_completion: bool) { + println!("Executing command: {cmd}"); let mut parts = cmd.split_whitespace(); let command = parts.next().unwrap(); let args: Vec<&str> = parts.collect(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 260bbc4e..a95202ec 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -379,10 +379,10 @@ pub struct ExpanderExecArgs { pub poly_commit: String, /// The port number for the server to listen on. - #[arg(short, long, default_value = "Port")] + #[arg(short, long, default_value = "3000")] pub port_number: String, - /// The port number for the server to listen on. - #[arg(short, long, default_value = "false")] + /// Whether to batch PCS opening in proving. + #[arg(short, long, default_value_t = false)] pub batch_pcs: bool, } From 39b112e361010198566c12f256bdcf4ca77f5897 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 18:50:02 -0500 Subject: [PATCH 38/60] reorganize tests --- expander_compiler/tests/{ => circuit}/example.rs | 0 expander_compiler/tests/{ => circuit}/example_call_expander.rs | 0 expander_compiler/tests/{ => circuit}/keccak_gf2.rs | 0 expander_compiler/tests/{ => circuit}/keccak_gf2_full.rs | 0 .../tests/{ => circuit}/keccak_gf2_full_crosslayer.rs | 0 expander_compiler/tests/{ => circuit}/keccak_gf2_vec.rs | 0 expander_compiler/tests/{ => circuit}/keccak_non_gf2.rs | 0 expander_compiler/tests/{ => circuit}/mul_fanout_limit.rs | 0 expander_compiler/tests/{ => circuit}/multithreading_witness.rs | 0 expander_compiler/tests/{ => circuit}/rsa_mul.py | 0 expander_compiler/tests/{ => circuit}/simple_add_m31.rs | 0 expander_compiler/tests/{ => circuit}/sub_circuit_macro.rs | 0 expander_compiler/tests/{ => circuit}/to_binary_builtin.rs | 0 expander_compiler/tests/{ => circuit}/to_binary_hint.rs | 0 .../tests/{ => circuit}/to_binary_unconstrained_api.rs | 0 expander_compiler/tests/{ => zkcuda}/cg_mpi_share.rs | 0 expander_compiler/tests/{ => zkcuda}/zkcuda_examples.rs | 0 expander_compiler/tests/{ => zkcuda}/zkcuda_keccak.rs | 0 expander_compiler/tests/{ => zkcuda}/zkcuda_matmul.rs | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename expander_compiler/tests/{ => circuit}/example.rs (100%) rename expander_compiler/tests/{ => circuit}/example_call_expander.rs (100%) rename expander_compiler/tests/{ => circuit}/keccak_gf2.rs (100%) rename expander_compiler/tests/{ => circuit}/keccak_gf2_full.rs (100%) rename expander_compiler/tests/{ => circuit}/keccak_gf2_full_crosslayer.rs (100%) rename expander_compiler/tests/{ => circuit}/keccak_gf2_vec.rs (100%) rename expander_compiler/tests/{ => circuit}/keccak_non_gf2.rs (100%) rename expander_compiler/tests/{ => circuit}/mul_fanout_limit.rs (100%) rename expander_compiler/tests/{ => circuit}/multithreading_witness.rs (100%) rename expander_compiler/tests/{ => circuit}/rsa_mul.py (100%) rename expander_compiler/tests/{ => circuit}/simple_add_m31.rs (100%) rename expander_compiler/tests/{ => circuit}/sub_circuit_macro.rs (100%) rename expander_compiler/tests/{ => circuit}/to_binary_builtin.rs (100%) rename expander_compiler/tests/{ => circuit}/to_binary_hint.rs (100%) rename expander_compiler/tests/{ => circuit}/to_binary_unconstrained_api.rs (100%) rename expander_compiler/tests/{ => zkcuda}/cg_mpi_share.rs (100%) rename expander_compiler/tests/{ => zkcuda}/zkcuda_examples.rs (100%) rename expander_compiler/tests/{ => zkcuda}/zkcuda_keccak.rs (100%) rename expander_compiler/tests/{ => zkcuda}/zkcuda_matmul.rs (100%) diff --git a/expander_compiler/tests/example.rs b/expander_compiler/tests/circuit/example.rs similarity index 100% rename from expander_compiler/tests/example.rs rename to expander_compiler/tests/circuit/example.rs diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/circuit/example_call_expander.rs similarity index 100% rename from expander_compiler/tests/example_call_expander.rs rename to expander_compiler/tests/circuit/example_call_expander.rs diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/circuit/keccak_gf2.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2.rs rename to expander_compiler/tests/circuit/keccak_gf2.rs diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/circuit/keccak_gf2_full.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_full.rs rename to expander_compiler/tests/circuit/keccak_gf2_full.rs diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/circuit/keccak_gf2_full_crosslayer.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_full_crosslayer.rs rename to expander_compiler/tests/circuit/keccak_gf2_full_crosslayer.rs diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/circuit/keccak_gf2_vec.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_vec.rs rename to expander_compiler/tests/circuit/keccak_gf2_vec.rs diff --git a/expander_compiler/tests/keccak_non_gf2.rs b/expander_compiler/tests/circuit/keccak_non_gf2.rs similarity index 100% rename from expander_compiler/tests/keccak_non_gf2.rs rename to expander_compiler/tests/circuit/keccak_non_gf2.rs diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/circuit/mul_fanout_limit.rs similarity index 100% rename from expander_compiler/tests/mul_fanout_limit.rs rename to expander_compiler/tests/circuit/mul_fanout_limit.rs diff --git a/expander_compiler/tests/multithreading_witness.rs b/expander_compiler/tests/circuit/multithreading_witness.rs similarity index 100% rename from expander_compiler/tests/multithreading_witness.rs rename to expander_compiler/tests/circuit/multithreading_witness.rs diff --git a/expander_compiler/tests/rsa_mul.py b/expander_compiler/tests/circuit/rsa_mul.py similarity index 100% rename from expander_compiler/tests/rsa_mul.py rename to expander_compiler/tests/circuit/rsa_mul.py diff --git a/expander_compiler/tests/simple_add_m31.rs b/expander_compiler/tests/circuit/simple_add_m31.rs similarity index 100% rename from expander_compiler/tests/simple_add_m31.rs rename to expander_compiler/tests/circuit/simple_add_m31.rs diff --git a/expander_compiler/tests/sub_circuit_macro.rs b/expander_compiler/tests/circuit/sub_circuit_macro.rs similarity index 100% rename from expander_compiler/tests/sub_circuit_macro.rs rename to expander_compiler/tests/circuit/sub_circuit_macro.rs diff --git a/expander_compiler/tests/to_binary_builtin.rs b/expander_compiler/tests/circuit/to_binary_builtin.rs similarity index 100% rename from expander_compiler/tests/to_binary_builtin.rs rename to expander_compiler/tests/circuit/to_binary_builtin.rs diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/circuit/to_binary_hint.rs similarity index 100% rename from expander_compiler/tests/to_binary_hint.rs rename to expander_compiler/tests/circuit/to_binary_hint.rs diff --git a/expander_compiler/tests/to_binary_unconstrained_api.rs b/expander_compiler/tests/circuit/to_binary_unconstrained_api.rs similarity index 100% rename from expander_compiler/tests/to_binary_unconstrained_api.rs rename to expander_compiler/tests/circuit/to_binary_unconstrained_api.rs diff --git a/expander_compiler/tests/cg_mpi_share.rs b/expander_compiler/tests/zkcuda/cg_mpi_share.rs similarity index 100% rename from expander_compiler/tests/cg_mpi_share.rs rename to expander_compiler/tests/zkcuda/cg_mpi_share.rs diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda/zkcuda_examples.rs similarity index 100% rename from expander_compiler/tests/zkcuda_examples.rs rename to expander_compiler/tests/zkcuda/zkcuda_examples.rs diff --git a/expander_compiler/tests/zkcuda_keccak.rs b/expander_compiler/tests/zkcuda/zkcuda_keccak.rs similarity index 100% rename from expander_compiler/tests/zkcuda_keccak.rs rename to expander_compiler/tests/zkcuda/zkcuda_keccak.rs diff --git a/expander_compiler/tests/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs similarity index 100% rename from expander_compiler/tests/zkcuda_matmul.rs rename to expander_compiler/tests/zkcuda/zkcuda_matmul.rs From 9c529fe5a584f2ae9cd5590c12cdecccc94a177c Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 20:57:09 -0500 Subject: [PATCH 39/60] zkcuda integration --- expander_compiler/Cargo.toml | 22 +++++- .../bin/integration/circuit_def.rs | 72 +++++++++++++++++++ expander_compiler/bin/integration/cleanup.rs | 10 +++ expander_compiler/bin/integration/prove.rs | 39 ++++++++++ expander_compiler/bin/integration/run.sh | 0 expander_compiler/bin/integration/setup.rs | 23 ++++++ expander_compiler/bin/integration/verify.rs | 32 +++++++++ .../bin/{ => zkcuda_servers}/zkcuda_matmul.rs | 0 .../zkcuda_matmul_no_oversubscribe.rs | 0 .../zkcuda_matmul_pcs_defered.rs | 0 expander_compiler/tests/circuit/mod.rs | 16 +++++ expander_compiler/tests/mod.rs | 2 + expander_compiler/tests/zkcuda/mod.rs | 3 + 13 files changed, 216 insertions(+), 3 deletions(-) create mode 100644 expander_compiler/bin/integration/circuit_def.rs create mode 100644 expander_compiler/bin/integration/cleanup.rs create mode 100644 expander_compiler/bin/integration/prove.rs create mode 100644 expander_compiler/bin/integration/run.sh create mode 100644 expander_compiler/bin/integration/setup.rs create mode 100644 expander_compiler/bin/integration/verify.rs rename expander_compiler/bin/{ => zkcuda_servers}/zkcuda_matmul.rs (100%) rename expander_compiler/bin/{ => zkcuda_servers}/zkcuda_matmul_no_oversubscribe.rs (100%) rename expander_compiler/bin/{ => zkcuda_servers}/zkcuda_matmul_pcs_defered.rs (100%) create mode 100644 expander_compiler/tests/circuit/mod.rs create mode 100644 expander_compiler/tests/mod.rs create mode 100644 expander_compiler/tests/zkcuda/mod.rs diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 09e8e996..2a6fa7b5 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -67,12 +67,28 @@ path = "src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs" [[bin]] name = "zkcuda_matmul" -path = "bin/zkcuda_matmul.rs" +path = "bin/zkcuda_servers/zkcuda_matmul.rs" [[bin]] name = "zkcuda_matmul_pcs_defered" -path = "bin/zkcuda_matmul_pcs_defered.rs" +path = "bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs" [[bin]] name = "zkcuda_matmul_no_oversubscribe" -path = "bin/zkcuda_matmul_no_oversubscribe.rs" +path = "bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs" + +[[bin]] +name = "zkcuda_setup" +path = "bin/integration/setup.rs" + +[[bin]] +name = "zkcuda_prove" +path = "bin/integration/prove.rs" + +[[bin]] +name = "zkcuda_verify" +path = "bin/integration/verify.rs" + +[[bin]] +name = "zkcuda_cleanup" +path = "bin/integration/cleanup.rs" diff --git a/expander_compiler/bin/integration/circuit_def.rs b/expander_compiler/bin/integration/circuit_def.rs new file mode 100644 index 00000000..dedb02ad --- /dev/null +++ b/expander_compiler/bin/integration/circuit_def.rs @@ -0,0 +1,72 @@ +use expander_compiler::frontend::*; +use expander_compiler::zkcuda::proving_system::expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, +}; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; +use expander_compiler::zkcuda::proving_system::{ + Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, +}; +use expander_compiler::zkcuda::shape::Reshape; +use expander_compiler::zkcuda::{context::*, kernel::*}; + +use gkr::BN254ConfigSha2Hyrax; +use serdes::ExpSerde; + +#[kernel] +fn add_2_macro(api: &mut API, a: &[InputVariable; 2], b: &mut OutputVariable) { + *b = api.add(a[0], a[1]); +} + +#[kernel] +fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut OutputVariable) { + let mut sum = api.constant(0); + for i in 0..16 { + sum = api.add(sum, a[i]); + } + *b = sum; +} + +pub fn gen_computation_graph_and_witness( + input: Option>>>, +) -> (ComputationGraph, Option>>>) { + let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap(); + let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap(); + + let mut ctx: Context = Context::default(); + let a = if let Some(input) = input.as_ref() { + assert_eq!(input.len(), 16); + assert!(input.iter().all(|v| v.len() == 2)); + input.clone() + } else { + let mut tmp = vec![vec![]; 16]; + for i in 0..16 { + for j in 0..2 { + tmp[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + tmp + }; + + let expected_result = a.iter().flat_map(|v| v).sum::>(); + + let a = ctx.copy_to_device(&a); + let mut b: DeviceMemoryHandle = None; + call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); + let b = b.reshape(&[1, 16]); + let mut c: DeviceMemoryHandle = None; + call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); + let c = c.reshape(&[]); + let result: CircuitField = ctx.copy_to_host(c); + assert_eq!(result, expected_result); + + let computation_graph = ctx.compile_computation_graph().unwrap(); + + let extended_witness = if let Some(_) = input { + ctx.solve_witness().unwrap(); + Some(ctx.export_device_memories()) + } else { + None + }; + + (computation_graph, extended_witness) +} diff --git a/expander_compiler/bin/integration/cleanup.rs b/expander_compiler/bin/integration/cleanup.rs new file mode 100644 index 00000000..c989c36c --- /dev/null +++ b/expander_compiler/bin/integration/cleanup.rs @@ -0,0 +1,10 @@ +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; + +fn main() { + as ProvingSystem>::post_process(); +} diff --git a/expander_compiler/bin/integration/prove.rs b/expander_compiler/bin/integration/prove.rs new file mode 100644 index 00000000..7d677125 --- /dev/null +++ b/expander_compiler/bin/integration/prove.rs @@ -0,0 +1,39 @@ +mod circuit_def; +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::{BN254Config, CircuitField}, + zkcuda::{ + context::ComputationGraph, + proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, + }, +}; +use serdes::ExpSerde; + +fn main() { + let mut input = vec![vec![]; 16]; + for i in 0..16 { + for j in 0..2 { + input[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + + let (_, extended_witness) = gen_computation_graph_and_witness::(Some(input)); + + // Note: we've saved the computation graph and setup in the server. In order to generate a proof, we only need to submit the witness. + let dummy_prover_setup = as ProvingSystem< + BN254Config, + >>::ProverSetup::default(); + let dummy_computation_graph = ComputationGraph::::default(); + + let proof = ExpanderNoOverSubscribe::::prove( + &dummy_prover_setup, + &dummy_computation_graph, + &extended_witness.unwrap(), + ); + + let mut bytes = vec![]; + proof.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/proof.bin", &bytes).unwrap(); +} diff --git a/expander_compiler/bin/integration/run.sh b/expander_compiler/bin/integration/run.sh new file mode 100644 index 00000000..e69de29b diff --git a/expander_compiler/bin/integration/setup.rs b/expander_compiler/bin/integration/setup.rs new file mode 100644 index 00000000..bc94515b --- /dev/null +++ b/expander_compiler/bin/integration/setup.rs @@ -0,0 +1,23 @@ +mod circuit_def; +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; +use serdes::ExpSerde; + +fn main() { + let (computation_graph, _) = gen_computation_graph_and_witness::(None); + let (prover_setup, verifier_setup) = + ExpanderNoOverSubscribe::::setup(&computation_graph); + + let mut bytes = vec![]; + prover_setup.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/prover_setup.bin", &bytes).unwrap(); + + bytes.clear(); + verifier_setup.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/verifier_setup.bin", &bytes).unwrap(); +} diff --git a/expander_compiler/bin/integration/verify.rs b/expander_compiler/bin/integration/verify.rs new file mode 100644 index 00000000..73c93656 --- /dev/null +++ b/expander_compiler/bin/integration/verify.rs @@ -0,0 +1,32 @@ +mod circuit_def; +use std::io::Cursor; + +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; +use serdes::ExpSerde; + +fn main() { + let (computation_graph, _) = gen_computation_graph_and_witness::(None); + + let verifier_setup_bytes = std::fs::read("/tmp/verifier_setup.bin").unwrap(); + let verifier_setup = as ProvingSystem< + BN254Config, + >>::VerifierSetup::deserialize_from(Cursor::new(verifier_setup_bytes)) + .unwrap(); + + let proof_bytes = std::fs::read("/tmp/proof.bin").unwrap(); + let proof = as ProvingSystem>::Proof::deserialize_from(Cursor::new(proof_bytes)).unwrap(); + + let verified = + as ProvingSystem>::verify( + &verifier_setup, + &computation_graph, + &proof, + ); + assert!(verified, "Proof verification failed"); +} diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_servers/zkcuda_matmul.rs similarity index 100% rename from expander_compiler/bin/zkcuda_matmul.rs rename to expander_compiler/bin/zkcuda_servers/zkcuda_matmul.rs diff --git a/expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs similarity index 100% rename from expander_compiler/bin/zkcuda_matmul_no_oversubscribe.rs rename to expander_compiler/bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs diff --git a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs similarity index 100% rename from expander_compiler/bin/zkcuda_matmul_pcs_defered.rs rename to expander_compiler/bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs diff --git a/expander_compiler/tests/circuit/mod.rs b/expander_compiler/tests/circuit/mod.rs new file mode 100644 index 00000000..c4216d8a --- /dev/null +++ b/expander_compiler/tests/circuit/mod.rs @@ -0,0 +1,16 @@ +mod example; +mod example_call_expander; +mod keccak_gf2; +mod keccak_gf2_full; +mod keccak_gf2_full_crosslayer; +mod keccak_gf2_vec; +mod keccak_non_gf2; + +mod mul_fanout_limit; +mod multithreading_witness; + +mod simple_add_m31; +mod sub_circuit_macro; +mod to_binary_builtin; +mod to_binary_hint; +mod to_binary_unconstrained_api; diff --git a/expander_compiler/tests/mod.rs b/expander_compiler/tests/mod.rs new file mode 100644 index 00000000..e9f75e05 --- /dev/null +++ b/expander_compiler/tests/mod.rs @@ -0,0 +1,2 @@ +mod circuit; +mod zkcuda; diff --git a/expander_compiler/tests/zkcuda/mod.rs b/expander_compiler/tests/zkcuda/mod.rs new file mode 100644 index 00000000..07bcfda6 --- /dev/null +++ b/expander_compiler/tests/zkcuda/mod.rs @@ -0,0 +1,3 @@ +mod zkcuda_examples; +mod zkcuda_keccak; +mod zkcuda_matmul; From 1323f8841c05dedbca0b2b017e8d18c1e75ac43d Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 21:05:25 -0500 Subject: [PATCH 40/60] testing script --- .../bin/integration/circuit_def.rs | 25 +++++++++---------- expander_compiler/bin/integration/prove.rs | 1 + expander_compiler/bin/integration/run.sh | 15 +++++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/expander_compiler/bin/integration/circuit_def.rs b/expander_compiler/bin/integration/circuit_def.rs index dedb02ad..6c54e382 100644 --- a/expander_compiler/bin/integration/circuit_def.rs +++ b/expander_compiler/bin/integration/circuit_def.rs @@ -1,16 +1,14 @@ -use expander_compiler::frontend::*; -use expander_compiler::zkcuda::proving_system::expander::config::{ - ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, -}; -use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; -use expander_compiler::zkcuda::proving_system::{ - Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, +#![allow(clippy::ptr_arg)] +#![allow(clippy::needless_range_loop)] + +use expander_compiler::frontend::{ + BasicAPI, CircuitField, Config, Error, SIMDField, Variable, API, }; use expander_compiler::zkcuda::shape::Reshape; -use expander_compiler::zkcuda::{context::*, kernel::*}; - -use gkr::BN254ConfigSha2Hyrax; -use serdes::ExpSerde; +use expander_compiler::zkcuda::{ + context::{call_kernel, ComputationGraph, Context, DeviceMemoryHandle}, + kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, +}; #[kernel] fn add_2_macro(api: &mut API, a: &[InputVariable; 2], b: &mut OutputVariable) { @@ -26,6 +24,7 @@ fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut Ou *b = sum; } +#[allow(clippy::type_complexity)] pub fn gen_computation_graph_and_witness( input: Option>>>, ) -> (ComputationGraph, Option>>>) { @@ -47,7 +46,7 @@ pub fn gen_computation_graph_and_witness( tmp }; - let expected_result = a.iter().flat_map(|v| v).sum::>(); + let expected_result = a.iter().flatten().sum::>(); let a = ctx.copy_to_device(&a); let mut b: DeviceMemoryHandle = None; @@ -61,7 +60,7 @@ pub fn gen_computation_graph_and_witness( let computation_graph = ctx.compile_computation_graph().unwrap(); - let extended_witness = if let Some(_) = input { + let extended_witness = if input.is_some() { ctx.solve_witness().unwrap(); Some(ctx.export_device_memories()) } else { diff --git a/expander_compiler/bin/integration/prove.rs b/expander_compiler/bin/integration/prove.rs index 7d677125..4595a793 100644 --- a/expander_compiler/bin/integration/prove.rs +++ b/expander_compiler/bin/integration/prove.rs @@ -11,6 +11,7 @@ use expander_compiler::{ }; use serdes::ExpSerde; +#[allow(clippy::needless_range_loop)] fn main() { let mut input = vec![vec![]; 16]; for i in 0..16 { diff --git a/expander_compiler/bin/integration/run.sh b/expander_compiler/bin/integration/run.sh index e69de29b..59654858 100644 --- a/expander_compiler/bin/integration/run.sh +++ b/expander_compiler/bin/integration/run.sh @@ -0,0 +1,15 @@ + +cargo build --release --bin zkcuda_setup --bin zkcuda_prove --bin zkcuda_verify --bin zkcuda_cleanup + +cargo run --release --bin zkcuda_setup + +# prove a first instance +cargo run --release --bin zkcuda_prove +cargo run --release --bin zkcuda_verify + +# prove a second instance +cargo run --release --bin zkcuda_prove +cargo run --release --bin zkcuda_verify + +# shutdown the server +cargo run --release --bin zkcuda_cleanup From 61eaeced15c355ba5ee79d5e4cafe9db7986d7de Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 21:11:54 -0500 Subject: [PATCH 41/60] simple renaming --- expander_compiler/Cargo.toml | 8 ++++---- .../{integration => zkcuda_integration}/circuit_def.rs | 0 .../bin/{integration => zkcuda_integration}/cleanup.rs | 0 .../bin/{integration => zkcuda_integration}/prove.rs | 1 + .../bin/{integration => zkcuda_integration}/run.sh | 2 ++ .../bin/{integration => zkcuda_integration}/setup.rs | 0 .../bin/{integration => zkcuda_integration}/verify.rs | 0 7 files changed, 7 insertions(+), 4 deletions(-) rename expander_compiler/bin/{integration => zkcuda_integration}/circuit_def.rs (100%) rename expander_compiler/bin/{integration => zkcuda_integration}/cleanup.rs (100%) rename expander_compiler/bin/{integration => zkcuda_integration}/prove.rs (96%) rename expander_compiler/bin/{integration => zkcuda_integration}/run.sh (93%) mode change 100644 => 100755 rename expander_compiler/bin/{integration => zkcuda_integration}/setup.rs (100%) rename expander_compiler/bin/{integration => zkcuda_integration}/verify.rs (100%) diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 2a6fa7b5..6058a03c 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -79,16 +79,16 @@ path = "bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs" [[bin]] name = "zkcuda_setup" -path = "bin/integration/setup.rs" +path = "bin/zkcuda_integration/setup.rs" [[bin]] name = "zkcuda_prove" -path = "bin/integration/prove.rs" +path = "bin/zkcuda_integration/prove.rs" [[bin]] name = "zkcuda_verify" -path = "bin/integration/verify.rs" +path = "bin/zkcuda_integration/verify.rs" [[bin]] name = "zkcuda_cleanup" -path = "bin/integration/cleanup.rs" +path = "bin/zkcuda_integration/cleanup.rs" diff --git a/expander_compiler/bin/integration/circuit_def.rs b/expander_compiler/bin/zkcuda_integration/circuit_def.rs similarity index 100% rename from expander_compiler/bin/integration/circuit_def.rs rename to expander_compiler/bin/zkcuda_integration/circuit_def.rs diff --git a/expander_compiler/bin/integration/cleanup.rs b/expander_compiler/bin/zkcuda_integration/cleanup.rs similarity index 100% rename from expander_compiler/bin/integration/cleanup.rs rename to expander_compiler/bin/zkcuda_integration/cleanup.rs diff --git a/expander_compiler/bin/integration/prove.rs b/expander_compiler/bin/zkcuda_integration/prove.rs similarity index 96% rename from expander_compiler/bin/integration/prove.rs rename to expander_compiler/bin/zkcuda_integration/prove.rs index 4595a793..1a393eff 100644 --- a/expander_compiler/bin/integration/prove.rs +++ b/expander_compiler/bin/zkcuda_integration/prove.rs @@ -13,6 +13,7 @@ use serdes::ExpSerde; #[allow(clippy::needless_range_loop)] fn main() { + // Replace this with your actual input data. let mut input = vec![vec![]; 16]; for i in 0..16 { for j in 0..2 { diff --git a/expander_compiler/bin/integration/run.sh b/expander_compiler/bin/zkcuda_integration/run.sh old mode 100644 new mode 100755 similarity index 93% rename from expander_compiler/bin/integration/run.sh rename to expander_compiler/bin/zkcuda_integration/run.sh index 59654858..206119cd --- a/expander_compiler/bin/integration/run.sh +++ b/expander_compiler/bin/zkcuda_integration/run.sh @@ -1,6 +1,8 @@ +#!/bin/bash cargo build --release --bin zkcuda_setup --bin zkcuda_prove --bin zkcuda_verify --bin zkcuda_cleanup +# setup the server cargo run --release --bin zkcuda_setup # prove a first instance diff --git a/expander_compiler/bin/integration/setup.rs b/expander_compiler/bin/zkcuda_integration/setup.rs similarity index 100% rename from expander_compiler/bin/integration/setup.rs rename to expander_compiler/bin/zkcuda_integration/setup.rs diff --git a/expander_compiler/bin/integration/verify.rs b/expander_compiler/bin/zkcuda_integration/verify.rs similarity index 100% rename from expander_compiler/bin/integration/verify.rs rename to expander_compiler/bin/zkcuda_integration/verify.rs From 1a4e7a95cc9f0863442b96d47cf5c85650635120 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 22:01:37 -0500 Subject: [PATCH 42/60] client drop witness after serialization --- expander_compiler/Cargo.toml | 6 +++--- .../zkcuda_matmul.rs | 2 +- .../zkcuda_matmul_no_oversubscribe.rs | 0 .../zkcuda_matmul_pcs_defered.rs | 0 expander_compiler/bin/zkcuda_integration/prove.rs | 2 +- .../src/zkcuda/proving_system/dummy.rs | 2 +- .../proving_system/expander/api_single_thread.rs | 2 +- .../api_no_oversubscribe.rs | 2 +- .../expander_parallelized/api_parallel.rs | 2 +- .../expander_parallelized/client_utils.rs | 6 ++---- .../expander_parallelized/shared_memory_utils.rs | 12 +++++------- .../expander_pcs_defered/api_pcs_defered.rs | 2 +- .../src/zkcuda/proving_system/traits.rs | 2 +- expander_compiler/src/zkcuda/tests.rs | 2 +- expander_compiler/tests/zkcuda/zkcuda_examples.rs | 14 +++++++------- expander_compiler/tests/zkcuda/zkcuda_keccak.rs | 4 ++-- expander_compiler/tests/zkcuda/zkcuda_matmul.rs | 2 +- 17 files changed, 29 insertions(+), 33 deletions(-) rename expander_compiler/bin/{zkcuda_servers => zkcuda_bench}/zkcuda_matmul.rs (98%) rename expander_compiler/bin/{zkcuda_servers => zkcuda_bench}/zkcuda_matmul_no_oversubscribe.rs (100%) rename expander_compiler/bin/{zkcuda_servers => zkcuda_bench}/zkcuda_matmul_pcs_defered.rs (100%) diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 6058a03c..00fa500a 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -67,15 +67,15 @@ path = "src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs" [[bin]] name = "zkcuda_matmul" -path = "bin/zkcuda_servers/zkcuda_matmul.rs" +path = "bin/zkcuda_bench/zkcuda_matmul.rs" [[bin]] name = "zkcuda_matmul_pcs_defered" -path = "bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs" +path = "bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs" [[bin]] name = "zkcuda_matmul_no_oversubscribe" -path = "bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs" +path = "bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs" [[bin]] name = "zkcuda_setup" diff --git a/expander_compiler/bin/zkcuda_servers/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs similarity index 98% rename from expander_compiler/bin/zkcuda_servers/zkcuda_matmul.rs rename to expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs index a0ac65c1..17425403 100644 --- a/expander_compiler/bin/zkcuda_servers/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs @@ -81,7 +81,7 @@ pub fn zkcuda_matmul, const N: usize>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); let elapsed = timer.elapsed(); println!("Parallel Count {N}, Proving time: {elapsed:?}"); diff --git a/expander_compiler/bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs similarity index 100% rename from expander_compiler/bin/zkcuda_servers/zkcuda_matmul_no_oversubscribe.rs rename to expander_compiler/bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs diff --git a/expander_compiler/bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs similarity index 100% rename from expander_compiler/bin/zkcuda_servers/zkcuda_matmul_pcs_defered.rs rename to expander_compiler/bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs diff --git a/expander_compiler/bin/zkcuda_integration/prove.rs b/expander_compiler/bin/zkcuda_integration/prove.rs index 1a393eff..0e77c7a2 100644 --- a/expander_compiler/bin/zkcuda_integration/prove.rs +++ b/expander_compiler/bin/zkcuda_integration/prove.rs @@ -32,7 +32,7 @@ fn main() { let proof = ExpanderNoOverSubscribe::::prove( &dummy_prover_setup, &dummy_computation_graph, - &extended_witness.unwrap(), + extended_witness.unwrap(), ); let mut bytes = vec![]; diff --git a/expander_compiler/src/zkcuda/proving_system/dummy.rs b/expander_compiler/src/zkcuda/proving_system/dummy.rs index 56f2717a..b18beb6b 100644 --- a/expander_compiler/src/zkcuda/proving_system/dummy.rs +++ b/expander_compiler/src/zkcuda/proving_system/dummy.rs @@ -146,7 +146,7 @@ impl ProvingSystem for DummyProvingSystem { fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { let (commitments, states) = device_memories .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index f301e672..9091558a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -185,7 +185,7 @@ impl> ProvingSyste fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { let (commitments, states) = device_memories .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs index 5a215097..7d7fed98 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -45,7 +45,7 @@ where fn prove( _prover_setup: &Self::ProverSetup, _computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { client_send_witness_and_prove(device_memories) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index 1862af98..071044e6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -45,7 +45,7 @@ impl> ProvingSyste fn prove( _prover_setup: &Self::ProverSetup, _computation_graph: &crate::zkcuda::context::ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { client_send_witness_and_prove(device_memories) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index fd1a48dc..42315b39 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -140,7 +140,7 @@ where } pub fn client_send_witness_and_prove( - device_memories: &[Vec>], + device_memories: Vec>>, ) -> CombinedProof> where C: GKREngine, @@ -148,9 +148,7 @@ where { let timer = Timer::new("prove", true); - SharedMemoryEngine::write_witness_to_shared_memory::( - &device_memories.iter().map(|m| &m[..]).collect::>(), - ); + SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); wait_async(ClientHttpHelper::request_prove()); let proof = SharedMemoryEngine::read_proof_from_shared_memory(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 5f5393a3..56ed79d8 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -106,13 +106,11 @@ impl SharedMemoryEngine { Self::read_object_from_shared_memory("pcs_setup", 0) } - pub fn write_witness_to_shared_memory( - values: &[impl AsRef<[F::SimdCircuitField]>], - ) { + pub fn write_witness_to_shared_memory(values: Vec>) { let total_size = std::mem::size_of::() + values .iter() - .map(|v| std::mem::size_of::() + std::mem::size_of_val(v.as_ref())) + .map(|v| std::mem::size_of::() + std::mem::size_of_val(v)) .sum::(); println!("Writing witness to shared memory, total size: {total_size}"); @@ -132,13 +130,13 @@ impl SharedMemoryEngine { ptr = ptr.add(std::mem::size_of::()); for vals in values { - let vals_len = vals.as_ref().len(); + let vals_len = vals.len(); let len_ptr = &vals_len as *const usize as *const u8; std::ptr::copy_nonoverlapping(len_ptr, ptr, std::mem::size_of::()); ptr = ptr.add(std::mem::size_of::()); - let vals_size = std::mem::size_of_val(vals.as_ref()); - std::ptr::copy_nonoverlapping(vals.as_ref().as_ptr() as *const u8, ptr, vals_size); + let vals_size = std::mem::size_of_val(&vals); + std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, ptr, vals_size); ptr = ptr.add(vals_size); } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index daa93c08..0d4d9d39 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -46,7 +46,7 @@ where fn prove( _prover_setup: &Self::ProverSetup, _computation_graph: &crate::zkcuda::context::ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { client_send_witness_and_prove(device_memories) } diff --git a/expander_compiler/src/zkcuda/proving_system/traits.rs b/expander_compiler/src/zkcuda/proving_system/traits.rs index cc791c50..539a0f43 100644 --- a/expander_compiler/src/zkcuda/proving_system/traits.rs +++ b/expander_compiler/src/zkcuda/proving_system/traits.rs @@ -70,7 +70,7 @@ pub trait ProvingSystem { fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof; fn verify( diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index ad7c609a..715dd77d 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -119,7 +119,7 @@ fn context_shape_test_1_impl>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); P::post_process(); diff --git a/expander_compiler/tests/zkcuda/zkcuda_examples.rs b/expander_compiler/tests/zkcuda/zkcuda_examples.rs index 727f08d0..328f9926 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_examples.rs @@ -56,7 +56,7 @@ fn zkcuda_test>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); P::post_process(); @@ -149,7 +149,7 @@ fn zkcuda_test_simd() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); @@ -177,7 +177,7 @@ fn zkcuda_test_simd() { let proof3 = P::prove( &prover_setup3, &computation_graph, - &ctx3.export_device_memories(), + ctx3.export_device_memories(), ); assert!(P::verify(&verifier_setup2, &computation_graph, &proof3)); } @@ -222,7 +222,7 @@ fn zkcuda_test_simd_autopack() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -285,7 +285,7 @@ fn zkcuda_to_binary() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -311,7 +311,7 @@ fn zkcuda_assertion() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -333,7 +333,7 @@ fn zkcuda_assertion_fail() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } diff --git a/expander_compiler/tests/zkcuda/zkcuda_keccak.rs b/expander_compiler/tests/zkcuda/zkcuda_keccak.rs index f8bfeb7d..995637a0 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_keccak.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_keccak.rs @@ -353,7 +353,7 @@ fn zkcuda_keccak_1_helper>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); println!("proof generation ok"); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); @@ -416,7 +416,7 @@ fn zkcuda_keccak_2_helper>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); println!("proof generation ok"); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); diff --git a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 605d449b..7b20e63d 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -93,7 +93,7 @@ fn zkcuda_matmul_sum() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } From ccb39885aeee1d06f1b1c282f011f122cf69bc4d Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 22:11:36 -0500 Subject: [PATCH 43/60] bug fix --- circuit-std-rs/tests/logup.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 3181d3d3..d22a7177 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -165,7 +165,7 @@ fn rangeproof_zkcuda_test() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -192,7 +192,7 @@ fn rangeproof_zkcuda_test_fail() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } From 59370d53922641eb78259f58038ccd636fb7c0da Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 17 Jul 2025 23:42:19 -0500 Subject: [PATCH 44/60] bug fix --- .../expander_parallelized/shared_memory_utils.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 56ed79d8..648f33a8 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -110,7 +110,7 @@ impl SharedMemoryEngine { let total_size = std::mem::size_of::() + values .iter() - .map(|v| std::mem::size_of::() + std::mem::size_of_val(v)) + .map(|v| std::mem::size_of::() + std::mem::size_of_val(v.as_slice())) .sum::(); println!("Writing witness to shared memory, total size: {total_size}"); @@ -135,7 +135,7 @@ impl SharedMemoryEngine { std::ptr::copy_nonoverlapping(len_ptr, ptr, std::mem::size_of::()); ptr = ptr.add(std::mem::size_of::()); - let vals_size = std::mem::size_of_val(&vals); + let vals_size = std::mem::size_of_val(vals.as_slice()); std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, ptr, vals_size); ptr = ptr.add(vals_size); } From fe9b62f37bbef170fdc10bc80b6bcfeb240535cb Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 20:20:36 -0500 Subject: [PATCH 45/60] a simple profiler to detect the number of bytes used to store bn254 fields --- .../expander_no_oversubscribe.rs | 1 + .../expander_no_oversubscribe/profiler.rs | 78 +++++++++++++++++++ .../expander_no_oversubscribe/prove_impl.rs | 7 +- .../expander_no_oversubscribe/server_fn.rs | 15 ++++ 4 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs index 0ce09b0b..65954def 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs @@ -1,3 +1,4 @@ pub mod api_no_oversubscribe; +pub mod profiler; pub mod prove_impl; pub mod server_fn; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs new file mode 100644 index 00000000..7c239d00 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs @@ -0,0 +1,78 @@ +#[cfg(feature = "profile")] +mod profiler_enabled { + use std::collections::HashMap; + + use arith::Fr; + use halo2curves::ff::PrimeField; + + #[derive(Clone, Debug, Default)] + pub struct NBytesProfiler { + pub bytes_stats: HashMap, + } + + impl NBytesProfiler { + pub fn new() -> Self { + NBytesProfiler { + bytes_stats: HashMap::new(), + } + } + + pub fn add_bytes(&mut self, n_bytes: usize) { + *self.bytes_stats.entry(n_bytes).or_insert(0) += 1; + } + + pub fn add_fr(&mut self, fr: Fr) { + let le_bytes = fr.to_repr(); + let be_leading_zeros_bytes = le_bytes.into_iter().rev().take_while(|&b| b == 0).count(); + let n_bytes = le_bytes.len() - be_leading_zeros_bytes; + self.add_bytes(n_bytes); + } + + pub fn print_stats(&self) { + for (bytes, count) in &self.bytes_stats { + println!("{} bytes: {}", bytes, count); + } + } + } +} + +#[cfg(not(feature = "profile"))] +mod profiler_disabled { + use arith::Fr; + + #[derive(Clone, Debug, Default)] + pub struct NBytesProfiler; + + impl NBytesProfiler { + pub fn new() -> Self { + NBytesProfiler + } + + pub fn add_bytes(&mut self, _n_bytes: usize) {} + + pub fn add_fr(&mut self, _fr: Fr) {} + + pub fn print_stats(&self) {} + } +} + +#[cfg(not(feature = "profile"))] +pub use profiler_disabled::NBytesProfiler; +#[cfg(feature = "profile")] +pub use profiler_enabled::NBytesProfiler; + +#[cfg(feature = "profile")] +mod test { + use arith::Fr; + + use super::profiler_enabled::NBytesProfiler; + + #[test] + fn test_n_bytes_profiler() { + let mut profiler = NBytesProfiler::new(); + profiler.add_bytes(32); + profiler.add_bytes(64); + profiler.add_fr(Fr::from(256u64)); + profiler.print_stats(); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index 7f8ea82b..ec2289f5 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -310,7 +310,7 @@ pub fn prove_kernel_gkr_internal( is_broadcast: &[bool], ) -> Option<(T, ExpanderDualVarChallenge)> where - FBasic: FieldEngine, + FBasic: FieldEngine, FMulti: FieldEngine, T: Transcript, @@ -372,7 +372,7 @@ pub fn prove_gkr_with_local_vals_multi_copies( mpi_config: &MPIConfig, ) -> ExpanderDualVarChallenge where - FBasic: FieldEngine, + FBasic: FieldEngine, FMulti: FieldEngine, T: Transcript, @@ -390,6 +390,7 @@ where let mut input_vals = vec![FMulti::SimdCircuitField::ZERO; 1 << expander_circuit.log_input_size()]; + for (i, vals) in input_vals.iter_mut().enumerate() { let vals_unpacked = input_vals_multi_copies .iter() @@ -403,7 +404,7 @@ where expander_circuit.evaluate(); let (claimed_v, challenge) = gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - assert_eq!(claimed_v, FBasic::ChallengeField::from(0)); + assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32)); let n_simd_vars_basic = FBasic::SimdCircuitField::PACK_SIZE.ilog2() as usize; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 9e8f7a40..cc84e672 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -56,6 +56,21 @@ where computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> { + #[cfg(feature = "profile")] + { + use arith::SimdField; + use expander_no_oversubscribe::prove_impl::NBytesProfiler; + + let n_bytes_profiler = NBytesProfiler::new(); + values.iter().for_each(|vals| { + vals.as_ref().iter().for_each(|fr| { + let fr_unpacked = fr.unpack(); + assert!(fr_unpacked.len() == 1); + n_bytes_profiler.add_fr(fr_unpacked[0]); + }); + }); + } + mpi_prove_no_oversubscribe_impl::( global_mpi_config, prover_setup, From e28e367efa798a83bfc42cc35abeffd39526e37d Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 20:25:18 -0500 Subject: [PATCH 46/60] bug fix --- .../proving_system/expander_no_oversubscribe/profiler.rs | 6 +----- .../proving_system/expander_no_oversubscribe/server_fn.rs | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs index 7c239d00..1c993df7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs @@ -30,7 +30,7 @@ mod profiler_enabled { pub fn print_stats(&self) { for (bytes, count) in &self.bytes_stats { - println!("{} bytes: {}", bytes, count); + println!("{bytes} bytes: {count}"); } } } @@ -63,10 +63,6 @@ pub use profiler_enabled::NBytesProfiler; #[cfg(feature = "profile")] mod test { - use arith::Fr; - - use super::profiler_enabled::NBytesProfiler; - #[test] fn test_n_bytes_profiler() { let mut profiler = NBytesProfiler::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index cc84e672..e5ee4d69 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -59,9 +59,9 @@ where #[cfg(feature = "profile")] { use arith::SimdField; - use expander_no_oversubscribe::prove_impl::NBytesProfiler; + use crate::zkcuda::proving_system::expander_no_oversubscribe::profiler::NBytesProfiler; - let n_bytes_profiler = NBytesProfiler::new(); + let mut n_bytes_profiler = NBytesProfiler::new(); values.iter().for_each(|vals| { vals.as_ref().iter().for_each(|fr| { let fr_unpacked = fr.unpack(); From fdda64bee5177fd1d8d90dabd793d9c872219409 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 20:40:14 -0500 Subject: [PATCH 47/60] make compiler happy about unused import --- .../proving_system/expander_no_oversubscribe/profiler.rs | 4 ++++ .../proving_system/expander_no_oversubscribe/server_fn.rs | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs index 1c993df7..5b54d7c1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs @@ -63,6 +63,10 @@ pub use profiler_enabled::NBytesProfiler; #[cfg(feature = "profile")] mod test { + #![allow(unused_imports)] + use super::NBytesProfiler; + use arith::Fr; + #[test] fn test_n_bytes_profiler() { let mut profiler = NBytesProfiler::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index e5ee4d69..cb9b6630 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -58,8 +58,8 @@ where ) -> Option>> { #[cfg(feature = "profile")] { - use arith::SimdField; use crate::zkcuda::proving_system::expander_no_oversubscribe::profiler::NBytesProfiler; + use arith::SimdField; let mut n_bytes_profiler = NBytesProfiler::new(); values.iter().for_each(|vals| { From 7a1abc7579a9c5073da93503eeb461f24e9acd41 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 20:42:23 -0500 Subject: [PATCH 48/60] minor --- .../zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index cb9b6630..57f9c237 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -69,6 +69,7 @@ where n_bytes_profiler.add_fr(fr_unpacked[0]); }); }); + n_bytes_profiler.print_stats(); } mpi_prove_no_oversubscribe_impl::( From 58770c322f10bac4d6a408e5f7f63f64875247db Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 21:02:31 -0500 Subject: [PATCH 49/60] including intermediate evaluations --- expander_compiler/Cargo.toml | 1 + .../expander_no_oversubscribe/profiler.rs | 10 +++--- .../expander_no_oversubscribe/prove_impl.rs | 31 +++++++++++++++++++ .../expander_no_oversubscribe/server_fn.rs | 23 +++++++++----- 4 files changed, 53 insertions(+), 12 deletions(-) diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 00fa500a..ffed0fa4 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -48,6 +48,7 @@ sha2 = "0.10.8" [features] default = [] profile = ["expander_utils/profile"] +zkcuda_profile = [] [[bin]] name = "trivial_circuit" diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs index 5b54d7c1..ed9421cc 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "profile")] +#[cfg(feature = "zkcuda_profile")] mod profiler_enabled { use std::collections::HashMap; @@ -36,7 +36,7 @@ mod profiler_enabled { } } -#[cfg(not(feature = "profile"))] +#[cfg(not(feature = "zkcuda_profile"))] mod profiler_disabled { use arith::Fr; @@ -56,12 +56,12 @@ mod profiler_disabled { } } -#[cfg(not(feature = "profile"))] +#[cfg(not(feature = "zkcuda_profile"))] pub use profiler_disabled::NBytesProfiler; -#[cfg(feature = "profile")] +#[cfg(feature = "zkcuda_profile")] pub use profiler_enabled::NBytesProfiler; -#[cfg(feature = "profile")] +#[cfg(feature = "zkcuda_profile")] mod test { #![allow(unused_imports)] use super::NBytesProfiler; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index ec2289f5..a636a117 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -20,6 +20,7 @@ use crate::{ }, structs::{ExpanderProof, ExpanderProverSetup}, }, + expander_no_oversubscribe::profiler::NBytesProfiler, expander_parallelized::{ prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, server_ctrl::generate_local_mpi_config, @@ -37,6 +38,7 @@ pub fn mpi_prove_no_oversubscribe_impl( prover_setup: &ExpanderProverSetup, GetPCS>, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], + n_bytes_profiler: &mut NBytesProfiler, ) -> Option>> where ::FieldConfig: FieldEngine, @@ -88,6 +90,7 @@ where &commitment_values, next_power_of_two(template.parallel_count()), template.is_broadcast(), + n_bytes_profiler, ); single_kernel_gkr_timer.stop(); @@ -197,6 +200,7 @@ pub fn prove_kernel_gkr_no_oversubscribe( commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, is_broadcast: &[bool], + n_bytes_profiler: &mut NBytesProfiler, ) -> Option<(T, ExpanderDualVarChallenge)> where F: FieldEngine, @@ -218,6 +222,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 2 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -225,6 +230,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 4 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -232,6 +238,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 8 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -239,6 +246,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 16 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -246,6 +254,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 32 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -253,6 +262,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 64 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -260,6 +270,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 128 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -267,6 +278,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 256 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -274,6 +286,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 512 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -281,6 +294,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 1024 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -288,6 +302,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), 2048 => prove_kernel_gkr_internal::, T, ECCConfig>( &local_mpi_config, @@ -295,6 +310,7 @@ where commitments_values, parallel_count, is_broadcast, + n_bytes_profiler, ), _ => { panic!("Unsupported parallel count: {parallel_count}"); @@ -308,6 +324,7 @@ pub fn prove_kernel_gkr_internal( commitments_values: &[&[FBasic::SimdCircuitField]], parallel_count: usize, is_broadcast: &[bool], + n_bytes_profiler: &mut NBytesProfiler, ) -> Option<(T, ExpanderDualVarChallenge)> where FBasic: FieldEngine, @@ -339,6 +356,7 @@ where kernel.layered_circuit_input(), &mut transcript, mpi_config, + n_bytes_profiler, ); Some((transcript, challenge)) @@ -370,6 +388,7 @@ pub fn prove_gkr_with_local_vals_multi_copies( partition_info: &[LayeredCircuitInputVec], transcript: &mut T, mpi_config: &MPIConfig, + _n_bytes_profiler: &mut NBytesProfiler, ) -> ExpanderDualVarChallenge where FBasic: FieldEngine, @@ -402,6 +421,18 @@ where expander_circuit.fill_rnd_coefs(transcript); expander_circuit.evaluate(); + + #[cfg(feature = "zkcuda_profile")] + { + expander_circuit.layers.iter().for_each(|layer| { + layer.input_vals.iter().for_each(|val| { + val.unpack().iter().for_each(|fr| { + _n_bytes_profiler.add_fr(*fr); + }) + }); + }); + } + let (claimed_v, challenge) = gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32)); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 57f9c237..0c31b11b 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,5 +1,5 @@ use arith::Fr; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; use crate::{ frontend::SIMDField, @@ -10,7 +10,9 @@ use crate::{ config::{GetFieldConfig, GetPCS, ZKCudaConfig}, structs::{ExpanderProverSetup, ExpanderVerifierSetup}, }, - expander_no_oversubscribe::prove_impl::mpi_prove_no_oversubscribe_impl, + expander_no_oversubscribe::{ + profiler::NBytesProfiler, prove_impl::mpi_prove_no_oversubscribe_impl, + }, expander_parallelized::{server_ctrl::SharedMemoryWINWrapper, server_fns::ServerFns}, CombinedProof, Expander, ExpanderNoOverSubscribe, ExpanderPCSDefered, ParallelizedExpander, @@ -56,12 +58,12 @@ where computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> { - #[cfg(feature = "profile")] + let mut n_bytes_profiler = NBytesProfiler::new(); + + #[cfg(feature = "zkcuda_profile")] { - use crate::zkcuda::proving_system::expander_no_oversubscribe::profiler::NBytesProfiler; use arith::SimdField; - let mut n_bytes_profiler = NBytesProfiler::new(); values.iter().for_each(|vals| { vals.as_ref().iter().for_each(|fr| { let fr_unpacked = fr.unpack(); @@ -72,11 +74,18 @@ where n_bytes_profiler.print_stats(); } - mpi_prove_no_oversubscribe_impl::( + let proof = mpi_prove_no_oversubscribe_impl::( global_mpi_config, prover_setup, computation_graph, values, - ) + &mut n_bytes_profiler, + ); + + if global_mpi_config.is_root() { + n_bytes_profiler.print_stats(); + } + + proof } } From ea30303ccc0aa465abbfca26ac5ef6fd92b5e2e0 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 21:22:15 -0500 Subject: [PATCH 50/60] clearer profiler --- .../expander_no_oversubscribe/server_fn.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 0c31b11b..1f2b5ac5 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -71,7 +71,10 @@ where n_bytes_profiler.add_fr(fr_unpacked[0]); }); }); - n_bytes_profiler.print_stats(); + if global_mpi_config.is_root() { + println!("NBytesProfiler stats before proving:"); + n_bytes_profiler.print_stats(); + } } let proof = mpi_prove_no_oversubscribe_impl::( @@ -82,8 +85,12 @@ where &mut n_bytes_profiler, ); - if global_mpi_config.is_root() { - n_bytes_profiler.print_stats(); + #[cfg(feature = "zkcuda_profile")] + { + if global_mpi_config.is_root() { + println!("NBytesProfiler stats after proving:"); + n_bytes_profiler.print_stats(); + } } proof From f07ab90d34c74b7f99bd5709dab0a444b5d0eab4 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 21:22:49 -0500 Subject: [PATCH 51/60] clippy --- .../proving_system/expander_no_oversubscribe/server_fn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index 1f2b5ac5..b530d187 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -1,5 +1,5 @@ use arith::Fr; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; use crate::{ frontend::SIMDField, From f8df18bd6e639ff404625a8c0185ba3dbfb2f6fd Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 20 Jul 2025 21:24:25 -0500 Subject: [PATCH 52/60] bug fix --- .../proving_system/expander_no_oversubscribe/server_fn.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs index b530d187..85c5955f 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -63,6 +63,7 @@ where #[cfg(feature = "zkcuda_profile")] { use arith::SimdField; + use gkr_engine::MPIEngine; values.iter().for_each(|vals| { vals.as_ref().iter().for_each(|fr| { @@ -87,6 +88,8 @@ where #[cfg(feature = "zkcuda_profile")] { + use gkr_engine::MPIEngine; + if global_mpi_config.is_root() { println!("NBytesProfiler stats after proving:"); n_bytes_profiler.print_stats(); From fcd1edeb4c1f228099f302253c639b4b067e3f3c Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 21 Jul 2025 19:06:37 -0500 Subject: [PATCH 53/60] remove par verifier --- .../src/zkcuda/proving_system/expander/api_single_thread.rs | 4 ++-- .../proving_system/expander_parallelized/api_parallel.rs | 4 ++-- .../zkcuda/proving_system/expander_pcs_defered/verify_impl.rs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 9091558a..af45f796 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -234,8 +234,8 @@ impl> ProvingSyste ) -> bool { let verified = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index 071044e6..bc2a25b3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -58,8 +58,8 @@ impl> ProvingSyste let verification_timer = Timer::new("Verify all kernels", true); let verified = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 08d4f45e..84c9f3f3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -155,8 +155,8 @@ where let gkr_verification_timer = Timer::new("GKR Verification", true); let verified_with_pcs_claims = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() From 3a4469397f341e52640b12debf770ff1b89f3277 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 21 Jul 2025 19:07:05 -0500 Subject: [PATCH 54/60] clippy auto fix --- .../src/zkcuda/proving_system/expander/api_single_thread.rs | 2 +- .../zkcuda/proving_system/expander_parallelized/api_parallel.rs | 2 +- .../zkcuda/proving_system/expander_pcs_defered/verify_impl.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index af45f796..ee39b3ae 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -1,4 +1,4 @@ -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::io::Cursor; use crate::circuit::config::Config; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index bc2a25b3..b767ff2e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -16,7 +16,7 @@ use super::super::Expander; use expander_utils::timer::Timer; use gkr_engine::GKREngine; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; pub struct ParallelizedExpander { _config: std::marker::PhantomData, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 84c9f3f3..9a313dc6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -7,7 +7,7 @@ use gkr_engine::{ ExpanderDualVarChallenge, ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, Proof as BytesProof, Transcript, }; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use serdes::ExpSerde; use crate::{ From 372d39bd77a3130a12c772917dadc6ad526f9f10 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 21 Jul 2025 21:24:30 -0500 Subject: [PATCH 55/60] no need to padding for unikzg --- .../proving_system/expander/api_single_thread.rs | 2 +- .../src/zkcuda/proving_system/expander/commit_impl.rs | 5 ++--- .../expander_no_oversubscribe/prove_impl.rs | 11 ++++++----- .../expander_parallelized/prove_impl.rs | 7 ++++++- .../proving_system/expander_pcs_defered/prove_impl.rs | 11 ++++------- expander_compiler/tests/zkcuda/zkcuda_examples.rs | 4 ++-- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index ee39b3ae..5588599c 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -52,7 +52,7 @@ where prover_setup: &Self::ProverSetup, vals: &[SIMDField], ) -> (Self::Commitment, Self::CommitmentState) { - local_commit_impl::(prover_setup, vals) + local_commit_impl::(prover_setup.p_keys.get(&vals.len()).unwrap(), vals) } fn prove_kernel( diff --git a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs index 723d8f1f..ceeb7592 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs @@ -1,5 +1,5 @@ use expander_utils::timer::Timer; -use gkr_engine::{ExpanderPCS, GKREngine, MPIConfig}; +use gkr_engine::{ExpanderPCS, GKREngine, MPIConfig, StructuredReferenceString}; use polynomials::RefMultiLinearPoly; use super::structs::ExpanderProverSetup; @@ -9,7 +9,7 @@ use crate::{ }; pub fn local_commit_impl( - prover_setup: &ExpanderProverSetup, + p_key: &<>::SRS as StructuredReferenceString>::PKey, vals: &[SIMDField], ) -> ( ExpanderCommitment, @@ -23,7 +23,6 @@ where let n_vars = vals.len().ilog2() as usize; let params = >::gen_params(n_vars, 1); - let p_key = prover_setup.p_keys.get(&vals.len()).unwrap(); let mut scratch = >::init_scratch_pad( ¶ms, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs index a636a117..bc980372 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -26,7 +26,7 @@ use crate::{ server_ctrl::generate_local_mpi_config, }, expander_pcs_defered::prove_impl::{ - extract_pcs_claims, open_defered_pcs, pad_vals_and_commit, + extract_pcs_claims, max_len_setup_commit_impl, open_defered_pcs, }, CombinedProof, Expander, }, @@ -48,13 +48,14 @@ where let (commitments, states) = values .iter() .map(|value| match ZC::BATCH_PCS { - true => pad_vals_and_commit::( + true => max_len_setup_commit_impl::( prover_setup, value.as_ref(), ), - false => { - local_commit_impl::(prover_setup, value.as_ref()) - } + false => local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ), }) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index c6a84a93..5605daf0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -40,7 +40,12 @@ where let (commitments, states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| local_commit_impl::(prover_setup, value.as_ref())) + .map(|value| { + local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ) + }) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index 071a2887..6629e6bb 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -27,7 +27,7 @@ use crate::{ }, }; -pub fn pad_vals_and_commit( +pub fn max_len_setup_commit_impl( prover_setup: &ExpanderProverSetup, vals: &[SIMDField], ) -> ( @@ -44,11 +44,8 @@ where let actual_len = vals.len(); assert!(len_to_commit >= actual_len); - // padding to max length and commit, this may be very inefficient - // TODO: optimize this - let mut vals = vals.to_vec(); - vals.resize(len_to_commit, SIMDField::::ZERO); - let (mut commitment, state) = local_commit_impl::(prover_setup, &vals); + let (mut commitment, state) = + local_commit_impl::(prover_setup.p_keys.get(&len_to_commit).unwrap(), &vals); commitment.vals_len = actual_len; // Store the actual length in the commitment (commitment, state) @@ -114,7 +111,7 @@ where let (commitments, _states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| pad_vals_and_commit::(prover_setup, value.as_ref())) + .map(|value| max_len_setup_commit_impl::(prover_setup, value.as_ref())) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { diff --git a/expander_compiler/tests/zkcuda/zkcuda_examples.rs b/expander_compiler/tests/zkcuda/zkcuda_examples.rs index 328f9926..a9bb04fe 100644 --- a/expander_compiler/tests/zkcuda/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_examples.rs @@ -93,8 +93,8 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::<_, ExpanderNoOverSubscribe>(); - zkcuda_test::<_, ExpanderNoOverSubscribe>(); + // zkcuda_test::<_, ExpanderNoOverSubscribe>(); + // zkcuda_test::<_, ExpanderNoOverSubscribe>(); zkcuda_test::<_, ExpanderNoOverSubscribe>(); zkcuda_test::<_, ExpanderNoOverSubscribe>(); } From 9d3eb948b324768673b3115d60f6bc013d0d287c Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Mon, 21 Jul 2025 21:26:43 -0500 Subject: [PATCH 56/60] clippy --- .../src/zkcuda/proving_system/expander/api_single_thread.rs | 1 - .../src/zkcuda/proving_system/expander/commit_impl.rs | 1 - .../proving_system/expander_parallelized/api_parallel.rs | 1 - .../zkcuda/proving_system/expander_pcs_defered/prove_impl.rs | 3 +-- .../zkcuda/proving_system/expander_pcs_defered/verify_impl.rs | 1 - 5 files changed, 1 insertion(+), 6 deletions(-) diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 5588599c..836201ac 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -1,4 +1,3 @@ -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::io::Cursor; use crate::circuit::config::Config; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs index ceeb7592..400296ba 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs @@ -2,7 +2,6 @@ use expander_utils::timer::Timer; use gkr_engine::{ExpanderPCS, GKREngine, MPIConfig, StructuredReferenceString}; use polynomials::RefMultiLinearPoly; -use super::structs::ExpanderProverSetup; use crate::{ frontend::{Config, SIMDField}, zkcuda::proving_system::expander::structs::{ExpanderCommitment, ExpanderCommitmentState}, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index b767ff2e..81005317 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -16,7 +16,6 @@ use super::super::Expander; use expander_utils::timer::Timer; use gkr_engine::GKREngine; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; pub struct ParallelizedExpander { _config: std::marker::PhantomData, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index 6629e6bb..72545956 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -1,4 +1,3 @@ -use arith::Field; use expander_utils::timer::Timer; use gkr_engine::{ ExpanderPCS, ExpanderSingleVarChallenge, GKREngine, MPIConfig, MPIEngine, Proof as BytesProof, @@ -45,7 +44,7 @@ where assert!(len_to_commit >= actual_len); let (mut commitment, state) = - local_commit_impl::(prover_setup.p_keys.get(&len_to_commit).unwrap(), &vals); + local_commit_impl::(prover_setup.p_keys.get(&len_to_commit).unwrap(), vals); commitment.vals_len = actual_len; // Store the actual length in the commitment (commitment, state) diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 9a313dc6..b06be3df 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -7,7 +7,6 @@ use gkr_engine::{ ExpanderDualVarChallenge, ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, Proof as BytesProof, Transcript, }; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use serdes::ExpSerde; use crate::{ From 5e2b6f3dee9e15c7534d91626dcad70b479e1635 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 24 Jul 2025 18:44:50 -0500 Subject: [PATCH 57/60] update --- Cargo.lock | 66 +++++++++++++++++++++++++++--------------------------- Cargo.toml | 38 +++++++++++++++---------------- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5bfa24f3..1705ae5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "babybear", @@ -523,9 +523,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.29" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "shlex", ] @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "env_logger", @@ -1143,7 +1143,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1160,7 +1160,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1179,7 +1179,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1212,7 +1212,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "babybear", @@ -1231,7 +1231,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1249,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1508,9 +1508,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f66d5bd4c6f02bf0542fad85d626775bab9258cf795a4256dcaf3161114d1df" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "bytes", "futures-core", @@ -1665,9 +1665,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ "bitflags 2.9.1", "cfg-if", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2259,7 +2259,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2284,7 +2284,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2330,9 +2330,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn 2.0.104", @@ -2429,9 +2429,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.13" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags 2.9.1", ] @@ -2670,9 +2670,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "itoa", "memchr", @@ -2705,7 +2705,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "proc-macro2", "quote", @@ -2840,7 +2840,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "circuit", @@ -3106,7 +3106,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3129,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -3231,7 +3231,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Frep_field_engine#a4fa31b9ea7a580da60c640cf0e38035b249831c" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "colored", ] diff --git a/Cargo.toml b/Cargo.toml index d391828a..bdf159a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,22 +47,22 @@ stacker = "0.1.17" tiny-keccak = { version = "2.0", features = ["keccak"] } tokio = { version = "1", features = ["full"] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "circuit" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "transcript" } -expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "bin" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "poly_commit" } -polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "serdes" } -gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine" } -expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/rep_field_engine", package = "utils" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "circuit" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "transcript" } +expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "bin" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "poly_commit" } +polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "serdes" } +gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "utils" } From d37fd538033d866f7be8c0609ddea76b2e83412f Mon Sep 17 00:00:00 2001 From: chonps Date: Sun, 27 Jul 2025 22:42:44 -0700 Subject: [PATCH 58/60] try to fix undo_transpose_shape_products --- Cargo.lock | 23 ++ expander_compiler/Cargo.toml | 1 + expander_compiler/src/zkcuda/shape.rs | 6 + expander_compiler/tests/circuit1.rs | 565 ++++++++++++++++++++++++++ 4 files changed, 595 insertions(+) create mode 100755 expander_compiler/tests/circuit1.rs diff --git a/Cargo.lock b/Cargo.lock index 1705ae5d..3385ba66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1005,6 +1005,7 @@ dependencies = [ "serdes", "sha2", "shared_memory", + "stacker", "sumcheck", "tiny-keccak", "tokio", @@ -2347,6 +2348,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +dependencies = [ + "cc", +] + [[package]] name = "quote" version = "1.0.40" @@ -2813,6 +2823,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index ffed0fa4..98cfb413 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -44,6 +44,7 @@ once_cell = "1.21.3" [dev-dependencies] rayon = "1.9" sha2 = "0.10.8" +stacker = "0.1.15" [features] default = [] diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 443651d2..0a5a581a 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -114,6 +114,12 @@ impl Entry { let mut cur_ts_prod = 1; let mut cur_ts_idx = 0; for &x in products.iter().skip(1) { + while ts[cur_ts_idx] == 1 && cur_ts_idx < ts.len() { + cur_ts_idx += 1; + } + if cur_ts_idx >= ts.len() { + break; + } segments_in_ts[self.axes.as_ref().unwrap()[cur_ts_idx]].push(x / cur_ts_prod); if x == cur_ts_prod * ts[cur_ts_idx] { cur_ts_prod = x; diff --git a/expander_compiler/tests/circuit1.rs b/expander_compiler/tests/circuit1.rs new file mode 100755 index 00000000..acc79f36 --- /dev/null +++ b/expander_compiler/tests/circuit1.rs @@ -0,0 +1,565 @@ +use expander_compiler::frontend::*; +use expander_compiler::zkcuda::{context::*, kernel::*}; +use gkr::BN254ConfigSha2Hyrax; +use gkr_engine::FieldEngine; +use expander_compiler::zkcuda::shape::{Reshape, Transpose}; +use serdes::ExpSerde; +use serde::{Deserialize, Serialize}; +use std::fs; +struct Circuit { + output: Vec>, + input: Vec>>>, + _features_features_0_Conv_output_0_conv: Vec>>>, + _features_features_0_Conv_output_0_div: Vec>>>, + _features_features_0_Conv_output_0_rem: Vec>>>, + _features_features_0_Conv_output_0_floor: Vec>>>, + _features_features_2_Relu_output_0: Vec>>>, + _features_features_3_Conv_output_0_conv: Vec>>>, + _features_features_3_Conv_output_0_div: Vec>>>, + _features_features_3_Conv_output_0_rem: Vec>>>, + _features_features_3_Conv_output_0_floor: Vec>>>, + _features_features_5_Relu_output_0: Vec>>>, + _features_features_6_MaxPool_output_0: Vec>>>, + _features_features_7_Conv_output_0_conv: Vec>>>, + _features_features_7_Conv_output_0_div: Vec>>>, + _features_features_7_Conv_output_0_rem: Vec>>>, + _features_features_7_Conv_output_0_floor: Vec>>>, + _features_features_9_Relu_output_0: Vec>>>, + _features_features_10_Conv_output_0_conv: Vec>>>, + _features_features_10_Conv_output_0_div: Vec>>>, + _features_features_10_Conv_output_0_rem: Vec>>>, + _features_features_10_Conv_output_0_floor: Vec>>>, + _features_features_12_Relu_output_0: Vec>>>, + _features_features_13_MaxPool_output_0: Vec>>>, + _features_features_14_Conv_output_0_conv: Vec>>>, + _features_features_14_Conv_output_0_div: Vec>>>, + _features_features_14_Conv_output_0_rem: Vec>>>, + _features_features_14_Conv_output_0_floor: Vec>>>, + _features_features_16_Relu_output_0: Vec>>>, + _features_features_17_Conv_output_0_conv: Vec>>>, + _features_features_17_Conv_output_0_div: Vec>>>, + _features_features_17_Conv_output_0_rem: Vec>>>, + _features_features_17_Conv_output_0_floor: Vec>>>, + _features_features_19_Relu_output_0: Vec>>>, + _features_features_20_Conv_output_0_conv: Vec>>>, + _features_features_20_Conv_output_0_div: Vec>>>, + _features_features_20_Conv_output_0_rem: Vec>>>, + _features_features_20_Conv_output_0_floor: Vec>>>, + _features_features_22_Relu_output_0: Vec>>>, + _features_features_23_MaxPool_output_0: Vec>>>, + _features_features_24_Conv_output_0_conv: Vec>>>, + _features_features_24_Conv_output_0_div: Vec>>>, + _features_features_24_Conv_output_0_rem: Vec>>>, + _features_features_24_Conv_output_0_floor: Vec>>>, + _features_features_26_Relu_output_0: Vec>>>, + _features_features_27_Conv_output_0_conv: Vec>>>, + _features_features_27_Conv_output_0_div: Vec>>>, + _features_features_27_Conv_output_0_rem: Vec>>>, + _features_features_27_Conv_output_0_floor: Vec>>>, + _features_features_29_Relu_output_0: Vec>>>, + _features_features_30_Conv_output_0_conv: Vec>>>, + _features_features_30_Conv_output_0_div: Vec>>>, + _features_features_30_Conv_output_0_rem: Vec>>>, + _features_features_30_Conv_output_0_floor: Vec>>>, + _features_features_32_Relu_output_0: Vec>>>, + _features_features_33_MaxPool_output_0: Vec>>>, + _features_features_34_Conv_output_0_conv: Vec>>>, + _features_features_34_Conv_output_0_div: Vec>>>, + _features_features_34_Conv_output_0_rem: Vec>>>, + _features_features_34_Conv_output_0_floor: Vec>>>, + _features_features_36_Relu_output_0: Vec>>>, + _features_features_37_Conv_output_0_conv: Vec>>>, + _features_features_37_Conv_output_0_div: Vec>>>, + _features_features_37_Conv_output_0_rem: Vec>>>, + _features_features_37_Conv_output_0_floor: Vec>>>, + _features_features_39_Relu_output_0: Vec>>>, + _features_features_40_Conv_output_0_conv: Vec>>>, + _features_features_40_Conv_output_0_div: Vec>>>, + _features_features_40_Conv_output_0_rem: Vec>>>, + _features_features_40_Conv_output_0_floor: Vec>>>, + _features_features_42_Relu_output_0: Vec>>>, + _features_features_43_MaxPool_output_0: Vec>>>, + _avgpool_GlobalAveragePool_output_0: Vec>>>, + _classifier_classifier_0_Gemm_output_0_matmul: Vec>, + _classifier_classifier_0_Gemm_output_0_div: Vec>, + _classifier_classifier_0_Gemm_output_0_rem: Vec>, + _classifier_classifier_0_Gemm_output_0_floor: Vec>, + _classifier_classifier_1_Relu_output_0: Vec>, + _classifier_classifier_3_Gemm_output_0_matmul: Vec>, + _classifier_classifier_3_Gemm_output_0_div: Vec>, + _classifier_classifier_3_Gemm_output_0_rem: Vec>, + _classifier_classifier_3_Gemm_output_0_floor: Vec>, + _classifier_classifier_4_Relu_output_0: Vec>, + output_matmul: Vec>, + output_div: Vec>, + output_rem: Vec>, + output_floor: Vec>, + onnx__Conv_150: Vec>>>, + onnx__Conv_151: Vec, + onnx__Conv_151_q: Vec>>, + onnx__Conv_150_nscale: BN254Fr, + onnx__Conv_150_dscale: BN254Fr, + onnx__Conv_153: Vec>>>, + onnx__Conv_154: Vec, + onnx__Conv_154_q: Vec>>, + onnx__Conv_153_nscale: BN254Fr, + onnx__Conv_153_dscale: BN254Fr, + onnx__Conv_156: Vec>>>, + onnx__Conv_157: Vec, + onnx__Conv_157_q: Vec>>, + onnx__Conv_156_nscale: BN254Fr, + onnx__Conv_156_dscale: BN254Fr, + onnx__Conv_159: Vec>>>, + onnx__Conv_160: Vec, + onnx__Conv_160_q: Vec>>, + onnx__Conv_159_nscale: BN254Fr, + onnx__Conv_159_dscale: BN254Fr, + onnx__Conv_162: Vec>>>, + onnx__Conv_163: Vec, + onnx__Conv_163_q: Vec>>, + onnx__Conv_162_nscale: BN254Fr, + onnx__Conv_162_dscale: BN254Fr, + onnx__Conv_165: Vec>>>, + onnx__Conv_166: Vec, + onnx__Conv_166_q: Vec>>, + onnx__Conv_165_nscale: BN254Fr, + onnx__Conv_165_dscale: BN254Fr, + onnx__Conv_168: Vec>>>, + onnx__Conv_169: Vec, + onnx__Conv_169_q: Vec>>, + onnx__Conv_168_nscale: BN254Fr, + onnx__Conv_168_dscale: BN254Fr, + onnx__Conv_171: Vec>>>, + onnx__Conv_172: Vec, + onnx__Conv_172_q: Vec>>, + onnx__Conv_171_nscale: BN254Fr, + onnx__Conv_171_dscale: BN254Fr, + onnx__Conv_174: Vec>>>, + onnx__Conv_175: Vec, + onnx__Conv_175_q: Vec>>, + onnx__Conv_174_nscale: BN254Fr, + onnx__Conv_174_dscale: BN254Fr, + onnx__Conv_177: Vec>>>, + onnx__Conv_178: Vec, + onnx__Conv_178_q: Vec>>, + onnx__Conv_177_nscale: BN254Fr, + onnx__Conv_177_dscale: BN254Fr, + onnx__Conv_180: Vec>>>, + onnx__Conv_181: Vec, + onnx__Conv_181_q: Vec>>, + onnx__Conv_180_nscale: BN254Fr, + onnx__Conv_180_dscale: BN254Fr, + onnx__Conv_183: Vec>>>, + onnx__Conv_184: Vec, + onnx__Conv_184_q: Vec>>, + onnx__Conv_183_nscale: BN254Fr, + onnx__Conv_183_dscale: BN254Fr, + onnx__Conv_186: Vec>>>, + onnx__Conv_187: Vec, + onnx__Conv_187_q: Vec>>, + onnx__Conv_186_nscale: BN254Fr, + onnx__Conv_186_dscale: BN254Fr, + classifier_0_weight: Vec>, + classifier_0_bias_q: Vec, + classifier_0_weight_nscale: BN254Fr, + classifier_0_weight_dscale: BN254Fr, + classifier_3_weight: Vec>, + classifier_3_bias_q: Vec, + classifier_3_weight_nscale: BN254Fr, + classifier_3_weight_dscale: BN254Fr, + classifier_6_weight: Vec>, + classifier_6_bias_q: Vec, + classifier_6_weight_nscale: BN254Fr, + classifier_6_weight_dscale: BN254Fr, + input_mat_ru: Vec, + onnx__Conv_150_mat_rv: Vec, + _features_features_2_Relu_output_0_mat_ru: Vec, + onnx__Conv_153_mat_rv: Vec, + _features_features_6_MaxPool_output_0_mat_ru: Vec, + onnx__Conv_156_mat_rv: Vec, + _features_features_9_Relu_output_0_mat_ru: Vec, + onnx__Conv_159_mat_rv: Vec, + _features_features_13_MaxPool_output_0_mat_ru: Vec, + onnx__Conv_162_mat_rv: Vec, + _features_features_16_Relu_output_0_mat_ru: Vec, + onnx__Conv_165_mat_rv: Vec, + _features_features_19_Relu_output_0_mat_ru: Vec, + onnx__Conv_168_mat_rv: Vec, + _features_features_23_MaxPool_output_0_mat_ru: Vec, + onnx__Conv_171_mat_rv: Vec, + _features_features_26_Relu_output_0_mat_ru: Vec, + onnx__Conv_174_mat_rv: Vec, + _features_features_29_Relu_output_0_mat_ru: Vec, + onnx__Conv_177_mat_rv: Vec, + _features_features_33_MaxPool_output_0_mat_ru: Vec, + onnx__Conv_180_mat_rv: Vec, + _features_features_36_Relu_output_0_mat_ru: Vec, + onnx__Conv_183_mat_rv: Vec, + _features_features_39_Relu_output_0_mat_ru: Vec, + onnx__Conv_186_mat_rv: Vec, + _Flatten_output_0_mat_ru: Vec, + classifier_0_weight_mat_rv: Vec, + _classifier_classifier_1_Relu_output_0_mat_ru: Vec, + classifier_3_weight_mat_rv: Vec, + _classifier_classifier_4_Relu_output_0_mat_ru: Vec, + classifier_6_weight_mat_rv: Vec, +} + +fn default_variable() -> Circuit{ + let output = vec![vec![BN254Fr::default();10];1]; + let input = vec![vec![vec![vec![BN254Fr::default();32];32];3];1]; + let _features_features_0_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_2_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_5_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_6_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];64];1]; + let _features_features_7_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_9_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_12_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_13_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];128];1]; + let _features_features_14_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_16_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_19_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_22_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_23_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];256];1]; + let _features_features_24_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_26_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_29_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_32_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_33_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_36_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_39_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_42_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_43_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; + let _avgpool_GlobalAveragePool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; + let _classifier_classifier_0_Gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_Gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_Gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_Gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_1_Relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_Gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_Gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_Gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_Gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_4_Relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let output_matmul = vec![vec![BN254Fr::default();10];1]; + let output_div = vec![vec![BN254Fr::default();10];1]; + let output_rem = vec![vec![BN254Fr::default();10];1]; + let output_floor = vec![vec![BN254Fr::default();10];1]; + let onnx__Conv_150 = vec![vec![vec![vec![BN254Fr::default();3];3];3];64]; + let onnx__Conv_151 = vec![BN254Fr::default();64]; + let onnx__Conv_151_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx__Conv_150_nscale = BN254Fr::default(); + let onnx__Conv_150_dscale = BN254Fr::default(); + let onnx__Conv_153 = vec![vec![vec![vec![BN254Fr::default();3];3];64];64]; + let onnx__Conv_154 = vec![BN254Fr::default();64]; + let onnx__Conv_154_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx__Conv_153_nscale = BN254Fr::default(); + let onnx__Conv_153_dscale = BN254Fr::default(); + let onnx__Conv_156 = vec![vec![vec![vec![BN254Fr::default();3];3];64];128]; + let onnx__Conv_157 = vec![BN254Fr::default();128]; + let onnx__Conv_157_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx__Conv_156_nscale = BN254Fr::default(); + let onnx__Conv_156_dscale = BN254Fr::default(); + let onnx__Conv_159 = vec![vec![vec![vec![BN254Fr::default();3];3];128];128]; + let onnx__Conv_160 = vec![BN254Fr::default();128]; + let onnx__Conv_160_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx__Conv_159_nscale = BN254Fr::default(); + let onnx__Conv_159_dscale = BN254Fr::default(); + let onnx__Conv_162 = vec![vec![vec![vec![BN254Fr::default();3];3];128];256]; + let onnx__Conv_163 = vec![BN254Fr::default();256]; + let onnx__Conv_163_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx__Conv_162_nscale = BN254Fr::default(); + let onnx__Conv_162_dscale = BN254Fr::default(); + let onnx__Conv_165 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx__Conv_166 = vec![BN254Fr::default();256]; + let onnx__Conv_166_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx__Conv_165_nscale = BN254Fr::default(); + let onnx__Conv_165_dscale = BN254Fr::default(); + let onnx__Conv_168 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx__Conv_169 = vec![BN254Fr::default();256]; + let onnx__Conv_169_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx__Conv_168_nscale = BN254Fr::default(); + let onnx__Conv_168_dscale = BN254Fr::default(); + let onnx__Conv_171 = vec![vec![vec![vec![BN254Fr::default();3];3];256];512]; + let onnx__Conv_172 = vec![BN254Fr::default();512]; + let onnx__Conv_172_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_171_nscale = BN254Fr::default(); + let onnx__Conv_171_dscale = BN254Fr::default(); + let onnx__Conv_174 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx__Conv_175 = vec![BN254Fr::default();512]; + let onnx__Conv_175_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_174_nscale = BN254Fr::default(); + let onnx__Conv_174_dscale = BN254Fr::default(); + let onnx__Conv_177 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx__Conv_178 = vec![BN254Fr::default();512]; + let onnx__Conv_178_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_177_nscale = BN254Fr::default(); + let onnx__Conv_177_dscale = BN254Fr::default(); + let onnx__Conv_180 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx__Conv_181 = vec![BN254Fr::default();512]; + let onnx__Conv_181_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_180_nscale = BN254Fr::default(); + let onnx__Conv_180_dscale = BN254Fr::default(); + let onnx__Conv_183 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx__Conv_184 = vec![BN254Fr::default();512]; + let onnx__Conv_184_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_183_nscale = BN254Fr::default(); + let onnx__Conv_183_dscale = BN254Fr::default(); + let onnx__Conv_186 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx__Conv_187 = vec![BN254Fr::default();512]; + let onnx__Conv_187_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx__Conv_186_nscale = BN254Fr::default(); + let onnx__Conv_186_dscale = BN254Fr::default(); + let classifier_0_weight = vec![vec![BN254Fr::default();512];512]; + let classifier_0_bias_q = vec![BN254Fr::default();512]; + let classifier_0_weight_nscale = BN254Fr::default(); + let classifier_0_weight_dscale = BN254Fr::default(); + let classifier_3_weight = vec![vec![BN254Fr::default();512];512]; + let classifier_3_bias_q = vec![BN254Fr::default();512]; + let classifier_3_weight_nscale = BN254Fr::default(); + let classifier_3_weight_dscale = BN254Fr::default(); + let classifier_6_weight = vec![vec![BN254Fr::default();10];512]; + let classifier_6_bias_q = vec![BN254Fr::default();10]; + let classifier_6_weight_nscale = BN254Fr::default(); + let classifier_6_weight_dscale = BN254Fr::default(); + let input_mat_ru = vec![BN254Fr::default();1024]; + let onnx__Conv_150_mat_rv = vec![BN254Fr::default();64]; + let _features_features_2_Relu_output_0_mat_ru = vec![BN254Fr::default();1024]; + let onnx__Conv_153_mat_rv = vec![BN254Fr::default();64]; + let _features_features_6_MaxPool_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx__Conv_156_mat_rv = vec![BN254Fr::default();128]; + let _features_features_9_Relu_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx__Conv_159_mat_rv = vec![BN254Fr::default();128]; + let _features_features_13_MaxPool_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx__Conv_162_mat_rv = vec![BN254Fr::default();256]; + let _features_features_16_Relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx__Conv_165_mat_rv = vec![BN254Fr::default();256]; + let _features_features_19_Relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx__Conv_168_mat_rv = vec![BN254Fr::default();256]; + let _features_features_23_MaxPool_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx__Conv_171_mat_rv = vec![BN254Fr::default();512]; + let _features_features_26_Relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx__Conv_174_mat_rv = vec![BN254Fr::default();512]; + let _features_features_29_Relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx__Conv_177_mat_rv = vec![BN254Fr::default();512]; + let _features_features_33_MaxPool_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx__Conv_180_mat_rv = vec![BN254Fr::default();512]; + let _features_features_36_Relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx__Conv_183_mat_rv = vec![BN254Fr::default();512]; + let _features_features_39_Relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx__Conv_186_mat_rv = vec![BN254Fr::default();512]; + let _Flatten_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_0_weight_mat_rv = vec![BN254Fr::default();512]; + let _classifier_classifier_1_Relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_3_weight_mat_rv = vec![BN254Fr::default();512]; + let _classifier_classifier_4_Relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_6_weight_mat_rv = vec![BN254Fr::default();10]; + let ass = Circuit{output,input,_features_features_0_Conv_output_0_conv,_features_features_0_Conv_output_0_div,_features_features_0_Conv_output_0_rem,_features_features_0_Conv_output_0_floor,_features_features_2_Relu_output_0,_features_features_3_Conv_output_0_conv,_features_features_3_Conv_output_0_div,_features_features_3_Conv_output_0_rem,_features_features_3_Conv_output_0_floor,_features_features_5_Relu_output_0,_features_features_6_MaxPool_output_0,_features_features_7_Conv_output_0_conv,_features_features_7_Conv_output_0_div,_features_features_7_Conv_output_0_rem,_features_features_7_Conv_output_0_floor,_features_features_9_Relu_output_0,_features_features_10_Conv_output_0_conv,_features_features_10_Conv_output_0_div,_features_features_10_Conv_output_0_rem,_features_features_10_Conv_output_0_floor,_features_features_12_Relu_output_0,_features_features_13_MaxPool_output_0,_features_features_14_Conv_output_0_conv,_features_features_14_Conv_output_0_div,_features_features_14_Conv_output_0_rem,_features_features_14_Conv_output_0_floor,_features_features_16_Relu_output_0,_features_features_17_Conv_output_0_conv,_features_features_17_Conv_output_0_div,_features_features_17_Conv_output_0_rem,_features_features_17_Conv_output_0_floor,_features_features_19_Relu_output_0,_features_features_20_Conv_output_0_conv,_features_features_20_Conv_output_0_div,_features_features_20_Conv_output_0_rem,_features_features_20_Conv_output_0_floor,_features_features_22_Relu_output_0,_features_features_23_MaxPool_output_0,_features_features_24_Conv_output_0_conv,_features_features_24_Conv_output_0_div,_features_features_24_Conv_output_0_rem,_features_features_24_Conv_output_0_floor,_features_features_26_Relu_output_0,_features_features_27_Conv_output_0_conv,_features_features_27_Conv_output_0_div,_features_features_27_Conv_output_0_rem,_features_features_27_Conv_output_0_floor,_features_features_29_Relu_output_0,_features_features_30_Conv_output_0_conv,_features_features_30_Conv_output_0_div,_features_features_30_Conv_output_0_rem,_features_features_30_Conv_output_0_floor,_features_features_32_Relu_output_0,_features_features_33_MaxPool_output_0,_features_features_34_Conv_output_0_conv,_features_features_34_Conv_output_0_div,_features_features_34_Conv_output_0_rem,_features_features_34_Conv_output_0_floor,_features_features_36_Relu_output_0,_features_features_37_Conv_output_0_conv,_features_features_37_Conv_output_0_div,_features_features_37_Conv_output_0_rem,_features_features_37_Conv_output_0_floor,_features_features_39_Relu_output_0,_features_features_40_Conv_output_0_conv,_features_features_40_Conv_output_0_div,_features_features_40_Conv_output_0_rem,_features_features_40_Conv_output_0_floor,_features_features_42_Relu_output_0,_features_features_43_MaxPool_output_0,_avgpool_GlobalAveragePool_output_0,_classifier_classifier_0_Gemm_output_0_matmul,_classifier_classifier_0_Gemm_output_0_div,_classifier_classifier_0_Gemm_output_0_rem,_classifier_classifier_0_Gemm_output_0_floor,_classifier_classifier_1_Relu_output_0,_classifier_classifier_3_Gemm_output_0_matmul,_classifier_classifier_3_Gemm_output_0_div,_classifier_classifier_3_Gemm_output_0_rem,_classifier_classifier_3_Gemm_output_0_floor,_classifier_classifier_4_Relu_output_0,output_matmul,output_div,output_rem,output_floor,onnx__Conv_150,onnx__Conv_151,onnx__Conv_151_q,onnx__Conv_150_nscale,onnx__Conv_150_dscale,onnx__Conv_153,onnx__Conv_154,onnx__Conv_154_q,onnx__Conv_153_nscale,onnx__Conv_153_dscale,onnx__Conv_156,onnx__Conv_157,onnx__Conv_157_q,onnx__Conv_156_nscale,onnx__Conv_156_dscale,onnx__Conv_159,onnx__Conv_160,onnx__Conv_160_q,onnx__Conv_159_nscale,onnx__Conv_159_dscale,onnx__Conv_162,onnx__Conv_163,onnx__Conv_163_q,onnx__Conv_162_nscale,onnx__Conv_162_dscale,onnx__Conv_165,onnx__Conv_166,onnx__Conv_166_q,onnx__Conv_165_nscale,onnx__Conv_165_dscale,onnx__Conv_168,onnx__Conv_169,onnx__Conv_169_q,onnx__Conv_168_nscale,onnx__Conv_168_dscale,onnx__Conv_171,onnx__Conv_172,onnx__Conv_172_q,onnx__Conv_171_nscale,onnx__Conv_171_dscale,onnx__Conv_174,onnx__Conv_175,onnx__Conv_175_q,onnx__Conv_174_nscale,onnx__Conv_174_dscale,onnx__Conv_177,onnx__Conv_178,onnx__Conv_178_q,onnx__Conv_177_nscale,onnx__Conv_177_dscale,onnx__Conv_180,onnx__Conv_181,onnx__Conv_181_q,onnx__Conv_180_nscale,onnx__Conv_180_dscale,onnx__Conv_183,onnx__Conv_184,onnx__Conv_184_q,onnx__Conv_183_nscale,onnx__Conv_183_dscale,onnx__Conv_186,onnx__Conv_187,onnx__Conv_187_q,onnx__Conv_186_nscale,onnx__Conv_186_dscale,classifier_0_weight,classifier_0_bias_q,classifier_0_weight_nscale,classifier_0_weight_dscale,classifier_3_weight,classifier_3_bias_q,classifier_3_weight_nscale,classifier_3_weight_dscale,classifier_6_weight,classifier_6_bias_q,classifier_6_weight_nscale,classifier_6_weight_dscale,input_mat_ru,onnx__Conv_150_mat_rv,_features_features_2_Relu_output_0_mat_ru,onnx__Conv_153_mat_rv,_features_features_6_MaxPool_output_0_mat_ru,onnx__Conv_156_mat_rv,_features_features_9_Relu_output_0_mat_ru,onnx__Conv_159_mat_rv,_features_features_13_MaxPool_output_0_mat_ru,onnx__Conv_162_mat_rv,_features_features_16_Relu_output_0_mat_ru,onnx__Conv_165_mat_rv,_features_features_19_Relu_output_0_mat_ru,onnx__Conv_168_mat_rv,_features_features_23_MaxPool_output_0_mat_ru,onnx__Conv_171_mat_rv,_features_features_26_Relu_output_0_mat_ru,onnx__Conv_174_mat_rv,_features_features_29_Relu_output_0_mat_ru,onnx__Conv_177_mat_rv,_features_features_33_MaxPool_output_0_mat_ru,onnx__Conv_180_mat_rv,_features_features_36_Relu_output_0_mat_ru,onnx__Conv_183_mat_rv,_features_features_39_Relu_output_0_mat_ru,onnx__Conv_186_mat_rv,_Flatten_output_0_mat_ru,classifier_0_weight_mat_rv,_classifier_classifier_1_Relu_output_0_mat_ru,classifier_3_weight_mat_rv,_classifier_classifier_4_Relu_output_0_mat_ru,classifier_6_weight_mat_rv}; + ass +} + +#[kernel] +fn _features_features_0_Conv_conv_copy_macro( + api: &mut API, + onnx__Conv_150: &[[[[InputVariable;3];3];3];64], + _features_features_0_Conv_output_0_conv: &[[[[InputVariable;32];32];64];1], + input: &[[[[InputVariable;32];32];3];1], + + onnx__Conv_150_mat: &mut [[OutputVariable;64];27], + _features_features_0_Conv_output_0_conv_mat: &mut [[OutputVariable;1024];64], + input_mat: &mut [[OutputVariable;1024];27], +) { + // for i in 0..64 { + // for j in 0..3 { + // for k in 0..3 { + // for l in 0..3 { + // onnx__Conv_150_mat[((j)*3 + k)*3 + l][i] = onnx__Conv_150[i][j][k][l]; + // } + // } + // } + // } + // for i in 0..1 { + // for j in 0..64 { + // for k in 0..32 { + // for l in 0..32 { + // _features_features_0_Conv_output_0_conv_mat[j][((i)*32 + k)*32 + l] = _features_features_0_Conv_output_0_conv[i][j][k][l]; + // } + // } + // } + // } + for i in (0..(1 + 0 + 0 - 1 + 1)).step_by(1) { + for j in (0..(3 + 0 + 0 - 3 + 1)).step_by(3) { + for k in (0..(32 + 1 + 1 - 3 + 1)).step_by(1) { + for l in (0..(32 + 1 + 1 - 3 + 1)).step_by(1) { + for m in 0..1 { + for n in 0..3 { + for o in 0..3 { + for p in 0..3 { + if true && (i+m-0) >= 0 && (i+m-0) < 1 && (j+n-0) >= 0 && (j+n-0) < 3 && (k+o-1) >= 0 && (k+o-1) < 32 && (l+p-1) >= 0 && (l+p-1) < 32 { input_mat[((n)*3 + o)*3 + p][((i)*32 + k)*32 + l] = input[i+m-0][j+n-0][k+o-1][l+p-1]} + else { input_mat[((n)*3 + o)*3 + p][((i)*32 + k)*32 + l] = api.constant(0)}; + } + } + } + } + } + } + } + } +} + +#[kernel] +fn _features_features_0_Conv_conv_ab_matrix_macro( + api: &mut API, + input_mat: & [InputVariable;1024], + onnx__Conv_150_mat: & [InputVariable;64], + input_mat_ru: & [InputVariable;1024], + onnx__Conv_150_mat_rv: & [InputVariable;64], + _features_features_0_Conv_conv_ab_matrix_rx: &mut OutputVariable, + _features_features_0_Conv_conv_ab_matrix_ry: &mut OutputVariable, +) { + *_features_features_0_Conv_conv_ab_matrix_rx = api.constant(0); + for i in 0..1024 { + let tmp = api.mul(input_mat_ru[i], input_mat[i]); + *_features_features_0_Conv_conv_ab_matrix_rx = api.add(tmp, *_features_features_0_Conv_conv_ab_matrix_rx); + } + *_features_features_0_Conv_conv_ab_matrix_ry = api.constant(0); + for i in 0..64 { + let tmp = api.mul(onnx__Conv_150_mat_rv[i], onnx__Conv_150_mat[i]); + *_features_features_0_Conv_conv_ab_matrix_ry = api.add(tmp, *_features_features_0_Conv_conv_ab_matrix_ry); + } +} +#[kernel] +fn _features_features_0_Conv_conv_c_matrix_macro( + api: &mut API, + _features_features_0_Conv_output_0_conv_mat: & [InputVariable;1024], + input_mat_ru: & [InputVariable;1024], + _features_features_0_Conv_conv_c_matrix_rz: &mut OutputVariable, +) { + *_features_features_0_Conv_conv_c_matrix_rz = api.constant(0); + for i in 0..1024 { + let tmp = api.mul(input_mat_ru[i], _features_features_0_Conv_output_0_conv_mat[i]); + *_features_features_0_Conv_conv_c_matrix_rz = api.add(tmp, *_features_features_0_Conv_conv_c_matrix_rz); + } +} + +#[kernel] // multiply operation +fn _features_features_0_Conv_mul_macro( + api: &mut API, + _features_features_0_Conv_output_0_conv: &[[InputVariable;32];32], + onnx__Conv_150_nscale: &InputVariable, + _features_features_0_Conv_output_0_mul: &mut [[OutputVariable;32];32], +) { + for i in 0..32 { + for j in 0..32 { + _features_features_0_Conv_output_0_mul[i][j] = api.mul(_features_features_0_Conv_output_0_conv[i][j], onnx__Conv_150_nscale); + } + } +} + +#[kernel] // divide operation +fn _features_features_0_Conv_div_macro( + api: &mut API, + _features_features_0_Conv_output_0_mul: &[[InputVariable;32];32], + onnx__Conv_150_dscale: &InputVariable, + _features_features_0_Conv_output_0_floor: &[[InputVariable;32];32], + _features_features_0_Conv_output_0_rem: &[[InputVariable;32];32], +) { + for i in 0..32 { + for j in 0..32 { + let tmp1 = api.mul(_features_features_0_Conv_output_0_floor[i][j], onnx__Conv_150_dscale); + let tmp2 = api.sub(_features_features_0_Conv_output_0_mul[i][j], _features_features_0_Conv_output_0_rem[i][j]); + api.assert_is_equal(tmp1, tmp2); + } + } +} + +#[test] +fn expander_circuit() -> std::io::Result<()>{ + let compile_result = stacker::grow(32 * 1024 * 1024 * 1024, || + { + let mut ctx = Context::::default(); + let mut assignment = default_variable(); + + let onnx__Conv_150_mat = ctx.copy_to_device(&assignment.onnx__Conv_150); // [64, 3, 3, 3] + let onnx__Conv_150_mat = onnx__Conv_150_mat.reshape(&[64, 27]); // [64, 27] + let onnx__Conv_150_mat = onnx__Conv_150_mat.transpose(&[1, 0]); // [27, 64] + + let kernel__features_features_0_Conv_conv_ab_matrix: KernelPrimitive = compile__features_features_0_Conv_conv_ab_matrix_macro().unwrap(); + let input_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];27]); + let input_mat_ru = ctx.copy_to_device(&assignment.input_mat_ru); + let onnx__Conv_150_mat_rv = ctx.copy_to_device(&assignment.onnx__Conv_150_mat_rv); + let mut _features_features_0_Conv_conv_rx = None; + let mut _features_features_0_Conv_conv_ry = None; + let mut input_mat_clone = input_mat.clone(); + let mut onnx__Conv_150_mat_clone = onnx__Conv_150_mat.clone(); + let mut input_mat_ru_clone = input_mat_ru.clone(); + let mut onnx__Conv_150_mat_rv_clone = onnx__Conv_150_mat_rv.clone(); + call_kernel!(ctx, kernel__features_features_0_Conv_conv_ab_matrix, 27, input_mat_clone, onnx__Conv_150_mat_clone, input_mat_ru_clone, onnx__Conv_150_mat_rv_clone, mut _features_features_0_Conv_conv_rx, mut _features_features_0_Conv_conv_ry).unwrap(); + + let _features_features_0_Conv_output_0_conv = ctx.copy_to_device(&assignment._features_features_0_Conv_output_0_conv); // [1, 64, 32, 32] + let _features_features_0_Conv_output_0_conv_mat = _features_features_0_Conv_output_0_conv.transpose(&[1, 0, 2, 3]); // [64, 1, 32, 32] + let _features_features_0_Conv_output_0_conv_mat = _features_features_0_Conv_output_0_conv_mat.reshape(&[64, 1024]); // [64, 1024] + + let kernel__features_features_0_Conv_conv_c_matrix: KernelPrimitive = compile__features_features_0_Conv_conv_c_matrix_macro().unwrap(); + // let _features_features_0_Conv_output_0_conv_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];64]); + let mut _features_features_0_Conv_conv_rz = None; + let _features_features_0_Conv_output_0_conv_mat_clone = _features_features_0_Conv_output_0_conv_mat.clone(); + let input_mat_ru_clone = input_mat_ru.clone(); + call_kernel!(ctx, kernel__features_features_0_Conv_conv_c_matrix, 64, _features_features_0_Conv_output_0_conv_mat_clone, input_mat_ru_clone, mut _features_features_0_Conv_conv_rz).unwrap(); + + let computation_graph = ctx.compile_computation_graph().unwrap(); + let file = std::fs::File::create("graph.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + computation_graph.serialize_into(writer); + } + ); + Ok(()) +} From 531042662d0d6da711de6ab95269fe0d9bfa26cb Mon Sep 17 00:00:00 2001 From: chonps Date: Mon, 28 Jul 2025 23:06:52 -0700 Subject: [PATCH 59/60] try to fix --- expander_compiler/src/zkcuda/context.rs | 1 + expander_compiler/src/zkcuda/shape.rs | 63 +- expander_compiler/tests/circuit1.rs | 802 ++++++++++++------------ 3 files changed, 447 insertions(+), 419 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index aaeda969..65cd5231 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -486,6 +486,7 @@ impl>> Context { loop { let get_pad_shape = |x: &DeviceMemoryHandle| { x.as_ref().map(|handle| { +println!("get handle {:?}", handle.id); handle .shape_history .get_transposed_shape_and_bit_order(&dm_shapes[handle.id]) diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 0a5a581a..2318b7eb 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -83,22 +83,41 @@ impl Entry { if self.axes.is_none() { return shape.to_vec(); } +println!("self axes {:?}", self.axes); +println!("self shape {:?}", self.shape); let mut segments = vec![]; - let mut cur_prod = 1; - let mut target = 1; - let mut self_shape_iter = self.shape.iter(); - for &x in shape.iter() { - if cur_prod == target { - cur_prod = x.0; - target = *self_shape_iter.next().unwrap(); - segments.push(vec![x]); - } else { - cur_prod *= x.0; - segments.last_mut().unwrap().push(x); + // let mut cur_prod = 1; + // let mut target = 1; + // let mut self_shape_iter = self.shape.iter(); + // for &x in shape.iter() { + // if cur_prod == target { + // cur_prod = x.0; + // target = *self_shape_iter.next().unwrap(); + // segments.push(vec![x]); + // } else { + // cur_prod *= x.0; + // segments.last_mut().unwrap().push(x); + // } + // } + // assert_eq!(cur_prod, target); + // assert_eq!(self_shape_iter.next(), None); + let mut shape_iter = shape.iter(); + for &x in self.shape.iter() { + // if x == 1 { + // continue; + // } + let mut cur_prod = 1; + segments.push(vec![]); + while let Some(y) = shape_iter.next() { + cur_prod *= y.0; + segments.last_mut().unwrap().push(*y); + if cur_prod == x { + break; + } } + assert_eq!(cur_prod, x); } - assert_eq!(cur_prod, target); - assert_eq!(self_shape_iter.next(), None); +println!("segments {:?}", segments); let mut res = Vec::with_capacity(shape.len()); for i in self.axes.as_ref().unwrap() { res.extend(segments[*i].iter()); @@ -110,6 +129,9 @@ impl Entry { return products.to_vec(); } let ts = self.transposed_shape(); +println!("undo transpose shape products {:?}", products); +println!("self axes {:?}", self.axes); +println!("transposed shape {:?}", ts); let mut segments_in_ts = vec![Vec::new(); ts.len()]; let mut cur_ts_prod = 1; let mut cur_ts_idx = 0; @@ -167,9 +189,13 @@ pub fn prefix_products(shape: &[usize]) -> Vec { } pub fn prefix_products_to_shape(products: &[usize]) -> Vec { - let mut shape = Vec::with_capacity(products.len() - 1); + // let mut shape = Vec::with_capacity(products.len() - 1); + // for i in 1..products.len() { + // shape.push(products[i] / products[i - 1]); + // } + let mut shape = products.to_vec(); for i in 1..products.len() { - shape.push(products[i] / products[i - 1]); + shape[i] /= products[i - 1]; } shape } @@ -274,9 +300,12 @@ impl ShapeHistory { cur = if e.axes.as_ref().is_none() { cur } else if cur.is_none() { - Some(e.transpose_shape(&initial_shape())) +println!("intiial shape {:?}", initial_shape()); + // Some(e.transpose_shape(&initial_shape())) + Some(e.minimize(false).transpose_shape(&initial_shape())) } else { - Some(e.transpose_shape(&cur.unwrap())) + // Some(e.transpose_shape(&cur.unwrap())) + Some(e.minimize(false).transpose_shape(&cur.unwrap())) }; } let new_shape_and_id = match cur { diff --git a/expander_compiler/tests/circuit1.rs b/expander_compiler/tests/circuit1.rs index acc79f36..77ff29ab 100755 --- a/expander_compiler/tests/circuit1.rs +++ b/expander_compiler/tests/circuit1.rs @@ -1,164 +1,162 @@ use expander_compiler::frontend::*; use expander_compiler::zkcuda::{context::*, kernel::*}; -use gkr::BN254ConfigSha2Hyrax; -use gkr_engine::FieldEngine; use expander_compiler::zkcuda::shape::{Reshape, Transpose}; use serdes::ExpSerde; -use serde::{Deserialize, Serialize}; -use std::fs; + +#[allow(dead_code)] struct Circuit { output: Vec>, input: Vec>>>, - _features_features_0_Conv_output_0_conv: Vec>>>, - _features_features_0_Conv_output_0_div: Vec>>>, - _features_features_0_Conv_output_0_rem: Vec>>>, - _features_features_0_Conv_output_0_floor: Vec>>>, - _features_features_2_Relu_output_0: Vec>>>, - _features_features_3_Conv_output_0_conv: Vec>>>, - _features_features_3_Conv_output_0_div: Vec>>>, - _features_features_3_Conv_output_0_rem: Vec>>>, - _features_features_3_Conv_output_0_floor: Vec>>>, - _features_features_5_Relu_output_0: Vec>>>, - _features_features_6_MaxPool_output_0: Vec>>>, - _features_features_7_Conv_output_0_conv: Vec>>>, - _features_features_7_Conv_output_0_div: Vec>>>, - _features_features_7_Conv_output_0_rem: Vec>>>, - _features_features_7_Conv_output_0_floor: Vec>>>, - _features_features_9_Relu_output_0: Vec>>>, - _features_features_10_Conv_output_0_conv: Vec>>>, - _features_features_10_Conv_output_0_div: Vec>>>, - _features_features_10_Conv_output_0_rem: Vec>>>, - _features_features_10_Conv_output_0_floor: Vec>>>, - _features_features_12_Relu_output_0: Vec>>>, - _features_features_13_MaxPool_output_0: Vec>>>, - _features_features_14_Conv_output_0_conv: Vec>>>, - _features_features_14_Conv_output_0_div: Vec>>>, - _features_features_14_Conv_output_0_rem: Vec>>>, - _features_features_14_Conv_output_0_floor: Vec>>>, - _features_features_16_Relu_output_0: Vec>>>, - _features_features_17_Conv_output_0_conv: Vec>>>, - _features_features_17_Conv_output_0_div: Vec>>>, - _features_features_17_Conv_output_0_rem: Vec>>>, - _features_features_17_Conv_output_0_floor: Vec>>>, - _features_features_19_Relu_output_0: Vec>>>, - _features_features_20_Conv_output_0_conv: Vec>>>, - _features_features_20_Conv_output_0_div: Vec>>>, - _features_features_20_Conv_output_0_rem: Vec>>>, - _features_features_20_Conv_output_0_floor: Vec>>>, - _features_features_22_Relu_output_0: Vec>>>, - _features_features_23_MaxPool_output_0: Vec>>>, - _features_features_24_Conv_output_0_conv: Vec>>>, - _features_features_24_Conv_output_0_div: Vec>>>, - _features_features_24_Conv_output_0_rem: Vec>>>, - _features_features_24_Conv_output_0_floor: Vec>>>, - _features_features_26_Relu_output_0: Vec>>>, - _features_features_27_Conv_output_0_conv: Vec>>>, - _features_features_27_Conv_output_0_div: Vec>>>, - _features_features_27_Conv_output_0_rem: Vec>>>, - _features_features_27_Conv_output_0_floor: Vec>>>, - _features_features_29_Relu_output_0: Vec>>>, - _features_features_30_Conv_output_0_conv: Vec>>>, - _features_features_30_Conv_output_0_div: Vec>>>, - _features_features_30_Conv_output_0_rem: Vec>>>, - _features_features_30_Conv_output_0_floor: Vec>>>, - _features_features_32_Relu_output_0: Vec>>>, - _features_features_33_MaxPool_output_0: Vec>>>, - _features_features_34_Conv_output_0_conv: Vec>>>, - _features_features_34_Conv_output_0_div: Vec>>>, - _features_features_34_Conv_output_0_rem: Vec>>>, - _features_features_34_Conv_output_0_floor: Vec>>>, - _features_features_36_Relu_output_0: Vec>>>, - _features_features_37_Conv_output_0_conv: Vec>>>, - _features_features_37_Conv_output_0_div: Vec>>>, - _features_features_37_Conv_output_0_rem: Vec>>>, - _features_features_37_Conv_output_0_floor: Vec>>>, - _features_features_39_Relu_output_0: Vec>>>, - _features_features_40_Conv_output_0_conv: Vec>>>, - _features_features_40_Conv_output_0_div: Vec>>>, - _features_features_40_Conv_output_0_rem: Vec>>>, - _features_features_40_Conv_output_0_floor: Vec>>>, - _features_features_42_Relu_output_0: Vec>>>, - _features_features_43_MaxPool_output_0: Vec>>>, + _features_features_0_conv_output_0_conv: Vec>>>, + _features_features_0_conv_output_0_div: Vec>>>, + _features_features_0_conv_output_0_rem: Vec>>>, + _features_features_0_conv_output_0_floor: Vec>>>, + _features_features_2_relu_output_0: Vec>>>, + _features_features_3_conv_output_0_conv: Vec>>>, + _features_features_3_conv_output_0_div: Vec>>>, + _features_features_3_conv_output_0_rem: Vec>>>, + _features_features_3_conv_output_0_floor: Vec>>>, + _features_features_5_relu_output_0: Vec>>>, + _features_features_6_maxpool_output_0: Vec>>>, + _features_features_7_conv_output_0_conv: Vec>>>, + _features_features_7_conv_output_0_div: Vec>>>, + _features_features_7_conv_output_0_rem: Vec>>>, + _features_features_7_conv_output_0_floor: Vec>>>, + _features_features_9_relu_output_0: Vec>>>, + _features_features_10_conv_output_0_conv: Vec>>>, + _features_features_10_conv_output_0_div: Vec>>>, + _features_features_10_conv_output_0_rem: Vec>>>, + _features_features_10_conv_output_0_floor: Vec>>>, + _features_features_12_relu_output_0: Vec>>>, + _features_features_13_maxpool_output_0: Vec>>>, + _features_features_14_conv_output_0_conv: Vec>>>, + _features_features_14_conv_output_0_div: Vec>>>, + _features_features_14_conv_output_0_rem: Vec>>>, + _features_features_14_conv_output_0_floor: Vec>>>, + _features_features_16_relu_output_0: Vec>>>, + _features_features_17_conv_output_0_conv: Vec>>>, + _features_features_17_conv_output_0_div: Vec>>>, + _features_features_17_conv_output_0_rem: Vec>>>, + _features_features_17_conv_output_0_floor: Vec>>>, + _features_features_19_relu_output_0: Vec>>>, + _features_features_20_conv_output_0_conv: Vec>>>, + _features_features_20_conv_output_0_div: Vec>>>, + _features_features_20_conv_output_0_rem: Vec>>>, + _features_features_20_conv_output_0_floor: Vec>>>, + _features_features_22_relu_output_0: Vec>>>, + _features_features_23_maxpool_output_0: Vec>>>, + _features_features_24_conv_output_0_conv: Vec>>>, + _features_features_24_conv_output_0_div: Vec>>>, + _features_features_24_conv_output_0_rem: Vec>>>, + _features_features_24_conv_output_0_floor: Vec>>>, + _features_features_26_relu_output_0: Vec>>>, + _features_features_27_conv_output_0_conv: Vec>>>, + _features_features_27_conv_output_0_div: Vec>>>, + _features_features_27_conv_output_0_rem: Vec>>>, + _features_features_27_conv_output_0_floor: Vec>>>, + _features_features_29_relu_output_0: Vec>>>, + _features_features_30_conv_output_0_conv: Vec>>>, + _features_features_30_conv_output_0_div: Vec>>>, + _features_features_30_conv_output_0_rem: Vec>>>, + _features_features_30_conv_output_0_floor: Vec>>>, + _features_features_32_relu_output_0: Vec>>>, + _features_features_33_maxpool_output_0: Vec>>>, + _features_features_34_conv_output_0_conv: Vec>>>, + _features_features_34_conv_output_0_div: Vec>>>, + _features_features_34_conv_output_0_rem: Vec>>>, + _features_features_34_conv_output_0_floor: Vec>>>, + _features_features_36_relu_output_0: Vec>>>, + _features_features_37_conv_output_0_conv: Vec>>>, + _features_features_37_conv_output_0_div: Vec>>>, + _features_features_37_conv_output_0_rem: Vec>>>, + _features_features_37_conv_output_0_floor: Vec>>>, + _features_features_39_relu_output_0: Vec>>>, + _features_features_40_conv_output_0_conv: Vec>>>, + _features_features_40_conv_output_0_div: Vec>>>, + _features_features_40_conv_output_0_rem: Vec>>>, + _features_features_40_conv_output_0_floor: Vec>>>, + _features_features_42_relu_output_0: Vec>>>, + _features_features_43_maxpool_output_0: Vec>>>, _avgpool_GlobalAveragePool_output_0: Vec>>>, - _classifier_classifier_0_Gemm_output_0_matmul: Vec>, - _classifier_classifier_0_Gemm_output_0_div: Vec>, - _classifier_classifier_0_Gemm_output_0_rem: Vec>, - _classifier_classifier_0_Gemm_output_0_floor: Vec>, - _classifier_classifier_1_Relu_output_0: Vec>, - _classifier_classifier_3_Gemm_output_0_matmul: Vec>, - _classifier_classifier_3_Gemm_output_0_div: Vec>, - _classifier_classifier_3_Gemm_output_0_rem: Vec>, - _classifier_classifier_3_Gemm_output_0_floor: Vec>, - _classifier_classifier_4_Relu_output_0: Vec>, + _classifier_classifier_0_gemm_output_0_matmul: Vec>, + _classifier_classifier_0_gemm_output_0_div: Vec>, + _classifier_classifier_0_gemm_output_0_rem: Vec>, + _classifier_classifier_0_gemm_output_0_floor: Vec>, + _classifier_classifier_1_relu_output_0: Vec>, + _classifier_classifier_3_gemm_output_0_matmul: Vec>, + _classifier_classifier_3_gemm_output_0_div: Vec>, + _classifier_classifier_3_gemm_output_0_rem: Vec>, + _classifier_classifier_3_gemm_output_0_floor: Vec>, + _classifier_classifier_4_relu_output_0: Vec>, output_matmul: Vec>, output_div: Vec>, output_rem: Vec>, output_floor: Vec>, - onnx__Conv_150: Vec>>>, - onnx__Conv_151: Vec, - onnx__Conv_151_q: Vec>>, - onnx__Conv_150_nscale: BN254Fr, - onnx__Conv_150_dscale: BN254Fr, - onnx__Conv_153: Vec>>>, - onnx__Conv_154: Vec, - onnx__Conv_154_q: Vec>>, - onnx__Conv_153_nscale: BN254Fr, - onnx__Conv_153_dscale: BN254Fr, - onnx__Conv_156: Vec>>>, - onnx__Conv_157: Vec, - onnx__Conv_157_q: Vec>>, - onnx__Conv_156_nscale: BN254Fr, - onnx__Conv_156_dscale: BN254Fr, - onnx__Conv_159: Vec>>>, - onnx__Conv_160: Vec, - onnx__Conv_160_q: Vec>>, - onnx__Conv_159_nscale: BN254Fr, - onnx__Conv_159_dscale: BN254Fr, - onnx__Conv_162: Vec>>>, - onnx__Conv_163: Vec, - onnx__Conv_163_q: Vec>>, - onnx__Conv_162_nscale: BN254Fr, - onnx__Conv_162_dscale: BN254Fr, - onnx__Conv_165: Vec>>>, - onnx__Conv_166: Vec, - onnx__Conv_166_q: Vec>>, - onnx__Conv_165_nscale: BN254Fr, - onnx__Conv_165_dscale: BN254Fr, - onnx__Conv_168: Vec>>>, - onnx__Conv_169: Vec, - onnx__Conv_169_q: Vec>>, - onnx__Conv_168_nscale: BN254Fr, - onnx__Conv_168_dscale: BN254Fr, - onnx__Conv_171: Vec>>>, - onnx__Conv_172: Vec, - onnx__Conv_172_q: Vec>>, - onnx__Conv_171_nscale: BN254Fr, - onnx__Conv_171_dscale: BN254Fr, - onnx__Conv_174: Vec>>>, - onnx__Conv_175: Vec, - onnx__Conv_175_q: Vec>>, - onnx__Conv_174_nscale: BN254Fr, - onnx__Conv_174_dscale: BN254Fr, - onnx__Conv_177: Vec>>>, - onnx__Conv_178: Vec, - onnx__Conv_178_q: Vec>>, - onnx__Conv_177_nscale: BN254Fr, - onnx__Conv_177_dscale: BN254Fr, - onnx__Conv_180: Vec>>>, - onnx__Conv_181: Vec, - onnx__Conv_181_q: Vec>>, - onnx__Conv_180_nscale: BN254Fr, - onnx__Conv_180_dscale: BN254Fr, - onnx__Conv_183: Vec>>>, - onnx__Conv_184: Vec, - onnx__Conv_184_q: Vec>>, - onnx__Conv_183_nscale: BN254Fr, - onnx__Conv_183_dscale: BN254Fr, - onnx__Conv_186: Vec>>>, - onnx__Conv_187: Vec, - onnx__Conv_187_q: Vec>>, - onnx__Conv_186_nscale: BN254Fr, - onnx__Conv_186_dscale: BN254Fr, + onnx_conv_150: Vec>>>, + onnx_conv_151: Vec, + onnx_conv_151_q: Vec>>, + onnx_conv_150_nscale: BN254Fr, + onnx_conv_150_dscale: BN254Fr, + onnx_conv_153: Vec>>>, + onnx_conv_154: Vec, + onnx_conv_154_q: Vec>>, + onnx_conv_153_nscale: BN254Fr, + onnx_conv_153_dscale: BN254Fr, + onnx_conv_156: Vec>>>, + onnx_conv_157: Vec, + onnx_conv_157_q: Vec>>, + onnx_conv_156_nscale: BN254Fr, + onnx_conv_156_dscale: BN254Fr, + onnx_conv_159: Vec>>>, + onnx_conv_160: Vec, + onnx_conv_160_q: Vec>>, + onnx_conv_159_nscale: BN254Fr, + onnx_conv_159_dscale: BN254Fr, + onnx_conv_162: Vec>>>, + onnx_conv_163: Vec, + onnx_conv_163_q: Vec>>, + onnx_conv_162_nscale: BN254Fr, + onnx_conv_162_dscale: BN254Fr, + onnx_conv_165: Vec>>>, + onnx_conv_166: Vec, + onnx_conv_166_q: Vec>>, + onnx_conv_165_nscale: BN254Fr, + onnx_conv_165_dscale: BN254Fr, + onnx_conv_168: Vec>>>, + onnx_conv_169: Vec, + onnx_conv_169_q: Vec>>, + onnx_conv_168_nscale: BN254Fr, + onnx_conv_168_dscale: BN254Fr, + onnx_conv_171: Vec>>>, + onnx_conv_172: Vec, + onnx_conv_172_q: Vec>>, + onnx_conv_171_nscale: BN254Fr, + onnx_conv_171_dscale: BN254Fr, + onnx_conv_174: Vec>>>, + onnx_conv_175: Vec, + onnx_conv_175_q: Vec>>, + onnx_conv_174_nscale: BN254Fr, + onnx_conv_174_dscale: BN254Fr, + onnx_conv_177: Vec>>>, + onnx_conv_178: Vec, + onnx_conv_178_q: Vec>>, + onnx_conv_177_nscale: BN254Fr, + onnx_conv_177_dscale: BN254Fr, + onnx_conv_180: Vec>>>, + onnx_conv_181: Vec, + onnx_conv_181_q: Vec>>, + onnx_conv_180_nscale: BN254Fr, + onnx_conv_180_dscale: BN254Fr, + onnx_conv_183: Vec>>>, + onnx_conv_184: Vec, + onnx_conv_184_q: Vec>>, + onnx_conv_183_nscale: BN254Fr, + onnx_conv_183_dscale: BN254Fr, + onnx_conv_186: Vec>>>, + onnx_conv_187: Vec, + onnx_conv_187_q: Vec>>, + onnx_conv_186_nscale: BN254Fr, + onnx_conv_186_dscale: BN254Fr, classifier_0_weight: Vec>, classifier_0_bias_q: Vec, classifier_0_weight_nscale: BN254Fr, @@ -172,192 +170,192 @@ struct Circuit { classifier_6_weight_nscale: BN254Fr, classifier_6_weight_dscale: BN254Fr, input_mat_ru: Vec, - onnx__Conv_150_mat_rv: Vec, - _features_features_2_Relu_output_0_mat_ru: Vec, - onnx__Conv_153_mat_rv: Vec, - _features_features_6_MaxPool_output_0_mat_ru: Vec, - onnx__Conv_156_mat_rv: Vec, - _features_features_9_Relu_output_0_mat_ru: Vec, - onnx__Conv_159_mat_rv: Vec, - _features_features_13_MaxPool_output_0_mat_ru: Vec, - onnx__Conv_162_mat_rv: Vec, - _features_features_16_Relu_output_0_mat_ru: Vec, - onnx__Conv_165_mat_rv: Vec, - _features_features_19_Relu_output_0_mat_ru: Vec, - onnx__Conv_168_mat_rv: Vec, - _features_features_23_MaxPool_output_0_mat_ru: Vec, - onnx__Conv_171_mat_rv: Vec, - _features_features_26_Relu_output_0_mat_ru: Vec, - onnx__Conv_174_mat_rv: Vec, - _features_features_29_Relu_output_0_mat_ru: Vec, - onnx__Conv_177_mat_rv: Vec, - _features_features_33_MaxPool_output_0_mat_ru: Vec, - onnx__Conv_180_mat_rv: Vec, - _features_features_36_Relu_output_0_mat_ru: Vec, - onnx__Conv_183_mat_rv: Vec, - _features_features_39_Relu_output_0_mat_ru: Vec, - onnx__Conv_186_mat_rv: Vec, + onnx_conv_150_mat_rv: Vec, + _features_features_2_relu_output_0_mat_ru: Vec, + onnx_conv_153_mat_rv: Vec, + _features_features_6_maxpool_output_0_mat_ru: Vec, + onnx_conv_156_mat_rv: Vec, + _features_features_9_relu_output_0_mat_ru: Vec, + onnx_conv_159_mat_rv: Vec, + _features_features_13_maxpool_output_0_mat_ru: Vec, + onnx_conv_162_mat_rv: Vec, + _features_features_16_relu_output_0_mat_ru: Vec, + onnx_conv_165_mat_rv: Vec, + _features_features_19_relu_output_0_mat_ru: Vec, + onnx_conv_168_mat_rv: Vec, + _features_features_23_maxpool_output_0_mat_ru: Vec, + onnx_conv_171_mat_rv: Vec, + _features_features_26_relu_output_0_mat_ru: Vec, + onnx_conv_174_mat_rv: Vec, + _features_features_29_relu_output_0_mat_ru: Vec, + onnx_conv_177_mat_rv: Vec, + _features_features_33_maxpool_output_0_mat_ru: Vec, + onnx_conv_180_mat_rv: Vec, + _features_features_36_relu_output_0_mat_ru: Vec, + onnx_conv_183_mat_rv: Vec, + _features_features_39_relu_output_0_mat_ru: Vec, + onnx_conv_186_mat_rv: Vec, _Flatten_output_0_mat_ru: Vec, classifier_0_weight_mat_rv: Vec, - _classifier_classifier_1_Relu_output_0_mat_ru: Vec, + _classifier_classifier_1_relu_output_0_mat_ru: Vec, classifier_3_weight_mat_rv: Vec, - _classifier_classifier_4_Relu_output_0_mat_ru: Vec, + _classifier_classifier_4_relu_output_0_mat_ru: Vec, classifier_6_weight_mat_rv: Vec, } fn default_variable() -> Circuit{ let output = vec![vec![BN254Fr::default();10];1]; let input = vec![vec![vec![vec![BN254Fr::default();32];32];3];1]; - let _features_features_0_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_0_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_0_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_0_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_2_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_3_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_3_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_3_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_3_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_5_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; - let _features_features_6_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];64];1]; - let _features_features_7_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_7_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_7_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_7_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_9_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_10_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_10_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_10_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_10_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_12_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; - let _features_features_13_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];128];1]; - let _features_features_14_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_14_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_14_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_14_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_16_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_17_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_17_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_17_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_17_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_19_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_20_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_20_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_20_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_20_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_22_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; - let _features_features_23_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];256];1]; - let _features_features_24_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_24_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_24_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_24_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_26_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_27_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_27_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_27_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_27_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_29_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_30_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_30_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_30_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_30_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_32_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; - let _features_features_33_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_34_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_34_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_34_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_34_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_36_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_37_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_37_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_37_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_37_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_39_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_40_Conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_40_Conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_40_Conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_40_Conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_42_Relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; - let _features_features_43_MaxPool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; + let _features_features_0_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_2_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_5_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_6_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];64];1]; + let _features_features_7_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_9_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_12_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_13_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];128];1]; + let _features_features_14_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_16_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_19_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_22_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_23_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];256];1]; + let _features_features_24_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_26_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_29_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_32_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_33_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_36_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_39_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_42_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_43_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; let _avgpool_GlobalAveragePool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; - let _classifier_classifier_0_Gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_0_Gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_0_Gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_0_Gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_1_Relu_output_0 = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_3_Gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_3_Gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_3_Gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_3_Gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; - let _classifier_classifier_4_Relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_1_relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_4_relu_output_0 = vec![vec![BN254Fr::default();512];1]; let output_matmul = vec![vec![BN254Fr::default();10];1]; let output_div = vec![vec![BN254Fr::default();10];1]; let output_rem = vec![vec![BN254Fr::default();10];1]; let output_floor = vec![vec![BN254Fr::default();10];1]; - let onnx__Conv_150 = vec![vec![vec![vec![BN254Fr::default();3];3];3];64]; - let onnx__Conv_151 = vec![BN254Fr::default();64]; - let onnx__Conv_151_q = vec![vec![vec![BN254Fr::default();1];1];64]; - let onnx__Conv_150_nscale = BN254Fr::default(); - let onnx__Conv_150_dscale = BN254Fr::default(); - let onnx__Conv_153 = vec![vec![vec![vec![BN254Fr::default();3];3];64];64]; - let onnx__Conv_154 = vec![BN254Fr::default();64]; - let onnx__Conv_154_q = vec![vec![vec![BN254Fr::default();1];1];64]; - let onnx__Conv_153_nscale = BN254Fr::default(); - let onnx__Conv_153_dscale = BN254Fr::default(); - let onnx__Conv_156 = vec![vec![vec![vec![BN254Fr::default();3];3];64];128]; - let onnx__Conv_157 = vec![BN254Fr::default();128]; - let onnx__Conv_157_q = vec![vec![vec![BN254Fr::default();1];1];128]; - let onnx__Conv_156_nscale = BN254Fr::default(); - let onnx__Conv_156_dscale = BN254Fr::default(); - let onnx__Conv_159 = vec![vec![vec![vec![BN254Fr::default();3];3];128];128]; - let onnx__Conv_160 = vec![BN254Fr::default();128]; - let onnx__Conv_160_q = vec![vec![vec![BN254Fr::default();1];1];128]; - let onnx__Conv_159_nscale = BN254Fr::default(); - let onnx__Conv_159_dscale = BN254Fr::default(); - let onnx__Conv_162 = vec![vec![vec![vec![BN254Fr::default();3];3];128];256]; - let onnx__Conv_163 = vec![BN254Fr::default();256]; - let onnx__Conv_163_q = vec![vec![vec![BN254Fr::default();1];1];256]; - let onnx__Conv_162_nscale = BN254Fr::default(); - let onnx__Conv_162_dscale = BN254Fr::default(); - let onnx__Conv_165 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; - let onnx__Conv_166 = vec![BN254Fr::default();256]; - let onnx__Conv_166_q = vec![vec![vec![BN254Fr::default();1];1];256]; - let onnx__Conv_165_nscale = BN254Fr::default(); - let onnx__Conv_165_dscale = BN254Fr::default(); - let onnx__Conv_168 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; - let onnx__Conv_169 = vec![BN254Fr::default();256]; - let onnx__Conv_169_q = vec![vec![vec![BN254Fr::default();1];1];256]; - let onnx__Conv_168_nscale = BN254Fr::default(); - let onnx__Conv_168_dscale = BN254Fr::default(); - let onnx__Conv_171 = vec![vec![vec![vec![BN254Fr::default();3];3];256];512]; - let onnx__Conv_172 = vec![BN254Fr::default();512]; - let onnx__Conv_172_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_171_nscale = BN254Fr::default(); - let onnx__Conv_171_dscale = BN254Fr::default(); - let onnx__Conv_174 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; - let onnx__Conv_175 = vec![BN254Fr::default();512]; - let onnx__Conv_175_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_174_nscale = BN254Fr::default(); - let onnx__Conv_174_dscale = BN254Fr::default(); - let onnx__Conv_177 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; - let onnx__Conv_178 = vec![BN254Fr::default();512]; - let onnx__Conv_178_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_177_nscale = BN254Fr::default(); - let onnx__Conv_177_dscale = BN254Fr::default(); - let onnx__Conv_180 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; - let onnx__Conv_181 = vec![BN254Fr::default();512]; - let onnx__Conv_181_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_180_nscale = BN254Fr::default(); - let onnx__Conv_180_dscale = BN254Fr::default(); - let onnx__Conv_183 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; - let onnx__Conv_184 = vec![BN254Fr::default();512]; - let onnx__Conv_184_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_183_nscale = BN254Fr::default(); - let onnx__Conv_183_dscale = BN254Fr::default(); - let onnx__Conv_186 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; - let onnx__Conv_187 = vec![BN254Fr::default();512]; - let onnx__Conv_187_q = vec![vec![vec![BN254Fr::default();1];1];512]; - let onnx__Conv_186_nscale = BN254Fr::default(); - let onnx__Conv_186_dscale = BN254Fr::default(); + let onnx_conv_150 = vec![vec![vec![vec![BN254Fr::default();3];3];3];64]; + let onnx_conv_151 = vec![BN254Fr::default();64]; + let onnx_conv_151_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx_conv_150_nscale = BN254Fr::default(); + let onnx_conv_150_dscale = BN254Fr::default(); + let onnx_conv_153 = vec![vec![vec![vec![BN254Fr::default();3];3];64];64]; + let onnx_conv_154 = vec![BN254Fr::default();64]; + let onnx_conv_154_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx_conv_153_nscale = BN254Fr::default(); + let onnx_conv_153_dscale = BN254Fr::default(); + let onnx_conv_156 = vec![vec![vec![vec![BN254Fr::default();3];3];64];128]; + let onnx_conv_157 = vec![BN254Fr::default();128]; + let onnx_conv_157_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx_conv_156_nscale = BN254Fr::default(); + let onnx_conv_156_dscale = BN254Fr::default(); + let onnx_conv_159 = vec![vec![vec![vec![BN254Fr::default();3];3];128];128]; + let onnx_conv_160 = vec![BN254Fr::default();128]; + let onnx_conv_160_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx_conv_159_nscale = BN254Fr::default(); + let onnx_conv_159_dscale = BN254Fr::default(); + let onnx_conv_162 = vec![vec![vec![vec![BN254Fr::default();3];3];128];256]; + let onnx_conv_163 = vec![BN254Fr::default();256]; + let onnx_conv_163_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_162_nscale = BN254Fr::default(); + let onnx_conv_162_dscale = BN254Fr::default(); + let onnx_conv_165 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx_conv_166 = vec![BN254Fr::default();256]; + let onnx_conv_166_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_165_nscale = BN254Fr::default(); + let onnx_conv_165_dscale = BN254Fr::default(); + let onnx_conv_168 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx_conv_169 = vec![BN254Fr::default();256]; + let onnx_conv_169_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_168_nscale = BN254Fr::default(); + let onnx_conv_168_dscale = BN254Fr::default(); + let onnx_conv_171 = vec![vec![vec![vec![BN254Fr::default();3];3];256];512]; + let onnx_conv_172 = vec![BN254Fr::default();512]; + let onnx_conv_172_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_171_nscale = BN254Fr::default(); + let onnx_conv_171_dscale = BN254Fr::default(); + let onnx_conv_174 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_175 = vec![BN254Fr::default();512]; + let onnx_conv_175_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_174_nscale = BN254Fr::default(); + let onnx_conv_174_dscale = BN254Fr::default(); + let onnx_conv_177 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_178 = vec![BN254Fr::default();512]; + let onnx_conv_178_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_177_nscale = BN254Fr::default(); + let onnx_conv_177_dscale = BN254Fr::default(); + let onnx_conv_180 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_181 = vec![BN254Fr::default();512]; + let onnx_conv_181_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_180_nscale = BN254Fr::default(); + let onnx_conv_180_dscale = BN254Fr::default(); + let onnx_conv_183 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_184 = vec![BN254Fr::default();512]; + let onnx_conv_184_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_183_nscale = BN254Fr::default(); + let onnx_conv_183_dscale = BN254Fr::default(); + let onnx_conv_186 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_187 = vec![BN254Fr::default();512]; + let onnx_conv_187_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_186_nscale = BN254Fr::default(); + let onnx_conv_186_dscale = BN254Fr::default(); let classifier_0_weight = vec![vec![BN254Fr::default();512];512]; let classifier_0_bias_q = vec![BN254Fr::default();512]; let classifier_0_weight_nscale = BN254Fr::default(); @@ -371,57 +369,57 @@ fn default_variable() -> Circuit{ let classifier_6_weight_nscale = BN254Fr::default(); let classifier_6_weight_dscale = BN254Fr::default(); let input_mat_ru = vec![BN254Fr::default();1024]; - let onnx__Conv_150_mat_rv = vec![BN254Fr::default();64]; - let _features_features_2_Relu_output_0_mat_ru = vec![BN254Fr::default();1024]; - let onnx__Conv_153_mat_rv = vec![BN254Fr::default();64]; - let _features_features_6_MaxPool_output_0_mat_ru = vec![BN254Fr::default();256]; - let onnx__Conv_156_mat_rv = vec![BN254Fr::default();128]; - let _features_features_9_Relu_output_0_mat_ru = vec![BN254Fr::default();256]; - let onnx__Conv_159_mat_rv = vec![BN254Fr::default();128]; - let _features_features_13_MaxPool_output_0_mat_ru = vec![BN254Fr::default();64]; - let onnx__Conv_162_mat_rv = vec![BN254Fr::default();256]; - let _features_features_16_Relu_output_0_mat_ru = vec![BN254Fr::default();64]; - let onnx__Conv_165_mat_rv = vec![BN254Fr::default();256]; - let _features_features_19_Relu_output_0_mat_ru = vec![BN254Fr::default();64]; - let onnx__Conv_168_mat_rv = vec![BN254Fr::default();256]; - let _features_features_23_MaxPool_output_0_mat_ru = vec![BN254Fr::default();16]; - let onnx__Conv_171_mat_rv = vec![BN254Fr::default();512]; - let _features_features_26_Relu_output_0_mat_ru = vec![BN254Fr::default();16]; - let onnx__Conv_174_mat_rv = vec![BN254Fr::default();512]; - let _features_features_29_Relu_output_0_mat_ru = vec![BN254Fr::default();16]; - let onnx__Conv_177_mat_rv = vec![BN254Fr::default();512]; - let _features_features_33_MaxPool_output_0_mat_ru = vec![BN254Fr::default();4]; - let onnx__Conv_180_mat_rv = vec![BN254Fr::default();512]; - let _features_features_36_Relu_output_0_mat_ru = vec![BN254Fr::default();4]; - let onnx__Conv_183_mat_rv = vec![BN254Fr::default();512]; - let _features_features_39_Relu_output_0_mat_ru = vec![BN254Fr::default();4]; - let onnx__Conv_186_mat_rv = vec![BN254Fr::default();512]; + let onnx_conv_150_mat_rv = vec![BN254Fr::default();64]; + let _features_features_2_relu_output_0_mat_ru = vec![BN254Fr::default();1024]; + let onnx_conv_153_mat_rv = vec![BN254Fr::default();64]; + let _features_features_6_maxpool_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx_conv_156_mat_rv = vec![BN254Fr::default();128]; + let _features_features_9_relu_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx_conv_159_mat_rv = vec![BN254Fr::default();128]; + let _features_features_13_maxpool_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_162_mat_rv = vec![BN254Fr::default();256]; + let _features_features_16_relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_165_mat_rv = vec![BN254Fr::default();256]; + let _features_features_19_relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_168_mat_rv = vec![BN254Fr::default();256]; + let _features_features_23_maxpool_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_171_mat_rv = vec![BN254Fr::default();512]; + let _features_features_26_relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_174_mat_rv = vec![BN254Fr::default();512]; + let _features_features_29_relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_177_mat_rv = vec![BN254Fr::default();512]; + let _features_features_33_maxpool_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_180_mat_rv = vec![BN254Fr::default();512]; + let _features_features_36_relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_183_mat_rv = vec![BN254Fr::default();512]; + let _features_features_39_relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_186_mat_rv = vec![BN254Fr::default();512]; let _Flatten_output_0_mat_ru = vec![BN254Fr::default();1]; let classifier_0_weight_mat_rv = vec![BN254Fr::default();512]; - let _classifier_classifier_1_Relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let _classifier_classifier_1_relu_output_0_mat_ru = vec![BN254Fr::default();1]; let classifier_3_weight_mat_rv = vec![BN254Fr::default();512]; - let _classifier_classifier_4_Relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let _classifier_classifier_4_relu_output_0_mat_ru = vec![BN254Fr::default();1]; let classifier_6_weight_mat_rv = vec![BN254Fr::default();10]; - let ass = Circuit{output,input,_features_features_0_Conv_output_0_conv,_features_features_0_Conv_output_0_div,_features_features_0_Conv_output_0_rem,_features_features_0_Conv_output_0_floor,_features_features_2_Relu_output_0,_features_features_3_Conv_output_0_conv,_features_features_3_Conv_output_0_div,_features_features_3_Conv_output_0_rem,_features_features_3_Conv_output_0_floor,_features_features_5_Relu_output_0,_features_features_6_MaxPool_output_0,_features_features_7_Conv_output_0_conv,_features_features_7_Conv_output_0_div,_features_features_7_Conv_output_0_rem,_features_features_7_Conv_output_0_floor,_features_features_9_Relu_output_0,_features_features_10_Conv_output_0_conv,_features_features_10_Conv_output_0_div,_features_features_10_Conv_output_0_rem,_features_features_10_Conv_output_0_floor,_features_features_12_Relu_output_0,_features_features_13_MaxPool_output_0,_features_features_14_Conv_output_0_conv,_features_features_14_Conv_output_0_div,_features_features_14_Conv_output_0_rem,_features_features_14_Conv_output_0_floor,_features_features_16_Relu_output_0,_features_features_17_Conv_output_0_conv,_features_features_17_Conv_output_0_div,_features_features_17_Conv_output_0_rem,_features_features_17_Conv_output_0_floor,_features_features_19_Relu_output_0,_features_features_20_Conv_output_0_conv,_features_features_20_Conv_output_0_div,_features_features_20_Conv_output_0_rem,_features_features_20_Conv_output_0_floor,_features_features_22_Relu_output_0,_features_features_23_MaxPool_output_0,_features_features_24_Conv_output_0_conv,_features_features_24_Conv_output_0_div,_features_features_24_Conv_output_0_rem,_features_features_24_Conv_output_0_floor,_features_features_26_Relu_output_0,_features_features_27_Conv_output_0_conv,_features_features_27_Conv_output_0_div,_features_features_27_Conv_output_0_rem,_features_features_27_Conv_output_0_floor,_features_features_29_Relu_output_0,_features_features_30_Conv_output_0_conv,_features_features_30_Conv_output_0_div,_features_features_30_Conv_output_0_rem,_features_features_30_Conv_output_0_floor,_features_features_32_Relu_output_0,_features_features_33_MaxPool_output_0,_features_features_34_Conv_output_0_conv,_features_features_34_Conv_output_0_div,_features_features_34_Conv_output_0_rem,_features_features_34_Conv_output_0_floor,_features_features_36_Relu_output_0,_features_features_37_Conv_output_0_conv,_features_features_37_Conv_output_0_div,_features_features_37_Conv_output_0_rem,_features_features_37_Conv_output_0_floor,_features_features_39_Relu_output_0,_features_features_40_Conv_output_0_conv,_features_features_40_Conv_output_0_div,_features_features_40_Conv_output_0_rem,_features_features_40_Conv_output_0_floor,_features_features_42_Relu_output_0,_features_features_43_MaxPool_output_0,_avgpool_GlobalAveragePool_output_0,_classifier_classifier_0_Gemm_output_0_matmul,_classifier_classifier_0_Gemm_output_0_div,_classifier_classifier_0_Gemm_output_0_rem,_classifier_classifier_0_Gemm_output_0_floor,_classifier_classifier_1_Relu_output_0,_classifier_classifier_3_Gemm_output_0_matmul,_classifier_classifier_3_Gemm_output_0_div,_classifier_classifier_3_Gemm_output_0_rem,_classifier_classifier_3_Gemm_output_0_floor,_classifier_classifier_4_Relu_output_0,output_matmul,output_div,output_rem,output_floor,onnx__Conv_150,onnx__Conv_151,onnx__Conv_151_q,onnx__Conv_150_nscale,onnx__Conv_150_dscale,onnx__Conv_153,onnx__Conv_154,onnx__Conv_154_q,onnx__Conv_153_nscale,onnx__Conv_153_dscale,onnx__Conv_156,onnx__Conv_157,onnx__Conv_157_q,onnx__Conv_156_nscale,onnx__Conv_156_dscale,onnx__Conv_159,onnx__Conv_160,onnx__Conv_160_q,onnx__Conv_159_nscale,onnx__Conv_159_dscale,onnx__Conv_162,onnx__Conv_163,onnx__Conv_163_q,onnx__Conv_162_nscale,onnx__Conv_162_dscale,onnx__Conv_165,onnx__Conv_166,onnx__Conv_166_q,onnx__Conv_165_nscale,onnx__Conv_165_dscale,onnx__Conv_168,onnx__Conv_169,onnx__Conv_169_q,onnx__Conv_168_nscale,onnx__Conv_168_dscale,onnx__Conv_171,onnx__Conv_172,onnx__Conv_172_q,onnx__Conv_171_nscale,onnx__Conv_171_dscale,onnx__Conv_174,onnx__Conv_175,onnx__Conv_175_q,onnx__Conv_174_nscale,onnx__Conv_174_dscale,onnx__Conv_177,onnx__Conv_178,onnx__Conv_178_q,onnx__Conv_177_nscale,onnx__Conv_177_dscale,onnx__Conv_180,onnx__Conv_181,onnx__Conv_181_q,onnx__Conv_180_nscale,onnx__Conv_180_dscale,onnx__Conv_183,onnx__Conv_184,onnx__Conv_184_q,onnx__Conv_183_nscale,onnx__Conv_183_dscale,onnx__Conv_186,onnx__Conv_187,onnx__Conv_187_q,onnx__Conv_186_nscale,onnx__Conv_186_dscale,classifier_0_weight,classifier_0_bias_q,classifier_0_weight_nscale,classifier_0_weight_dscale,classifier_3_weight,classifier_3_bias_q,classifier_3_weight_nscale,classifier_3_weight_dscale,classifier_6_weight,classifier_6_bias_q,classifier_6_weight_nscale,classifier_6_weight_dscale,input_mat_ru,onnx__Conv_150_mat_rv,_features_features_2_Relu_output_0_mat_ru,onnx__Conv_153_mat_rv,_features_features_6_MaxPool_output_0_mat_ru,onnx__Conv_156_mat_rv,_features_features_9_Relu_output_0_mat_ru,onnx__Conv_159_mat_rv,_features_features_13_MaxPool_output_0_mat_ru,onnx__Conv_162_mat_rv,_features_features_16_Relu_output_0_mat_ru,onnx__Conv_165_mat_rv,_features_features_19_Relu_output_0_mat_ru,onnx__Conv_168_mat_rv,_features_features_23_MaxPool_output_0_mat_ru,onnx__Conv_171_mat_rv,_features_features_26_Relu_output_0_mat_ru,onnx__Conv_174_mat_rv,_features_features_29_Relu_output_0_mat_ru,onnx__Conv_177_mat_rv,_features_features_33_MaxPool_output_0_mat_ru,onnx__Conv_180_mat_rv,_features_features_36_Relu_output_0_mat_ru,onnx__Conv_183_mat_rv,_features_features_39_Relu_output_0_mat_ru,onnx__Conv_186_mat_rv,_Flatten_output_0_mat_ru,classifier_0_weight_mat_rv,_classifier_classifier_1_Relu_output_0_mat_ru,classifier_3_weight_mat_rv,_classifier_classifier_4_Relu_output_0_mat_ru,classifier_6_weight_mat_rv}; + let ass = Circuit{output,input,_features_features_0_conv_output_0_conv,_features_features_0_conv_output_0_div,_features_features_0_conv_output_0_rem,_features_features_0_conv_output_0_floor,_features_features_2_relu_output_0,_features_features_3_conv_output_0_conv,_features_features_3_conv_output_0_div,_features_features_3_conv_output_0_rem,_features_features_3_conv_output_0_floor,_features_features_5_relu_output_0,_features_features_6_maxpool_output_0,_features_features_7_conv_output_0_conv,_features_features_7_conv_output_0_div,_features_features_7_conv_output_0_rem,_features_features_7_conv_output_0_floor,_features_features_9_relu_output_0,_features_features_10_conv_output_0_conv,_features_features_10_conv_output_0_div,_features_features_10_conv_output_0_rem,_features_features_10_conv_output_0_floor,_features_features_12_relu_output_0,_features_features_13_maxpool_output_0,_features_features_14_conv_output_0_conv,_features_features_14_conv_output_0_div,_features_features_14_conv_output_0_rem,_features_features_14_conv_output_0_floor,_features_features_16_relu_output_0,_features_features_17_conv_output_0_conv,_features_features_17_conv_output_0_div,_features_features_17_conv_output_0_rem,_features_features_17_conv_output_0_floor,_features_features_19_relu_output_0,_features_features_20_conv_output_0_conv,_features_features_20_conv_output_0_div,_features_features_20_conv_output_0_rem,_features_features_20_conv_output_0_floor,_features_features_22_relu_output_0,_features_features_23_maxpool_output_0,_features_features_24_conv_output_0_conv,_features_features_24_conv_output_0_div,_features_features_24_conv_output_0_rem,_features_features_24_conv_output_0_floor,_features_features_26_relu_output_0,_features_features_27_conv_output_0_conv,_features_features_27_conv_output_0_div,_features_features_27_conv_output_0_rem,_features_features_27_conv_output_0_floor,_features_features_29_relu_output_0,_features_features_30_conv_output_0_conv,_features_features_30_conv_output_0_div,_features_features_30_conv_output_0_rem,_features_features_30_conv_output_0_floor,_features_features_32_relu_output_0,_features_features_33_maxpool_output_0,_features_features_34_conv_output_0_conv,_features_features_34_conv_output_0_div,_features_features_34_conv_output_0_rem,_features_features_34_conv_output_0_floor,_features_features_36_relu_output_0,_features_features_37_conv_output_0_conv,_features_features_37_conv_output_0_div,_features_features_37_conv_output_0_rem,_features_features_37_conv_output_0_floor,_features_features_39_relu_output_0,_features_features_40_conv_output_0_conv,_features_features_40_conv_output_0_div,_features_features_40_conv_output_0_rem,_features_features_40_conv_output_0_floor,_features_features_42_relu_output_0,_features_features_43_maxpool_output_0,_avgpool_GlobalAveragePool_output_0,_classifier_classifier_0_gemm_output_0_matmul,_classifier_classifier_0_gemm_output_0_div,_classifier_classifier_0_gemm_output_0_rem,_classifier_classifier_0_gemm_output_0_floor,_classifier_classifier_1_relu_output_0,_classifier_classifier_3_gemm_output_0_matmul,_classifier_classifier_3_gemm_output_0_div,_classifier_classifier_3_gemm_output_0_rem,_classifier_classifier_3_gemm_output_0_floor,_classifier_classifier_4_relu_output_0,output_matmul,output_div,output_rem,output_floor,onnx_conv_150,onnx_conv_151,onnx_conv_151_q,onnx_conv_150_nscale,onnx_conv_150_dscale,onnx_conv_153,onnx_conv_154,onnx_conv_154_q,onnx_conv_153_nscale,onnx_conv_153_dscale,onnx_conv_156,onnx_conv_157,onnx_conv_157_q,onnx_conv_156_nscale,onnx_conv_156_dscale,onnx_conv_159,onnx_conv_160,onnx_conv_160_q,onnx_conv_159_nscale,onnx_conv_159_dscale,onnx_conv_162,onnx_conv_163,onnx_conv_163_q,onnx_conv_162_nscale,onnx_conv_162_dscale,onnx_conv_165,onnx_conv_166,onnx_conv_166_q,onnx_conv_165_nscale,onnx_conv_165_dscale,onnx_conv_168,onnx_conv_169,onnx_conv_169_q,onnx_conv_168_nscale,onnx_conv_168_dscale,onnx_conv_171,onnx_conv_172,onnx_conv_172_q,onnx_conv_171_nscale,onnx_conv_171_dscale,onnx_conv_174,onnx_conv_175,onnx_conv_175_q,onnx_conv_174_nscale,onnx_conv_174_dscale,onnx_conv_177,onnx_conv_178,onnx_conv_178_q,onnx_conv_177_nscale,onnx_conv_177_dscale,onnx_conv_180,onnx_conv_181,onnx_conv_181_q,onnx_conv_180_nscale,onnx_conv_180_dscale,onnx_conv_183,onnx_conv_184,onnx_conv_184_q,onnx_conv_183_nscale,onnx_conv_183_dscale,onnx_conv_186,onnx_conv_187,onnx_conv_187_q,onnx_conv_186_nscale,onnx_conv_186_dscale,classifier_0_weight,classifier_0_bias_q,classifier_0_weight_nscale,classifier_0_weight_dscale,classifier_3_weight,classifier_3_bias_q,classifier_3_weight_nscale,classifier_3_weight_dscale,classifier_6_weight,classifier_6_bias_q,classifier_6_weight_nscale,classifier_6_weight_dscale,input_mat_ru,onnx_conv_150_mat_rv,_features_features_2_relu_output_0_mat_ru,onnx_conv_153_mat_rv,_features_features_6_maxpool_output_0_mat_ru,onnx_conv_156_mat_rv,_features_features_9_relu_output_0_mat_ru,onnx_conv_159_mat_rv,_features_features_13_maxpool_output_0_mat_ru,onnx_conv_162_mat_rv,_features_features_16_relu_output_0_mat_ru,onnx_conv_165_mat_rv,_features_features_19_relu_output_0_mat_ru,onnx_conv_168_mat_rv,_features_features_23_maxpool_output_0_mat_ru,onnx_conv_171_mat_rv,_features_features_26_relu_output_0_mat_ru,onnx_conv_174_mat_rv,_features_features_29_relu_output_0_mat_ru,onnx_conv_177_mat_rv,_features_features_33_maxpool_output_0_mat_ru,onnx_conv_180_mat_rv,_features_features_36_relu_output_0_mat_ru,onnx_conv_183_mat_rv,_features_features_39_relu_output_0_mat_ru,onnx_conv_186_mat_rv,_Flatten_output_0_mat_ru,classifier_0_weight_mat_rv,_classifier_classifier_1_relu_output_0_mat_ru,classifier_3_weight_mat_rv,_classifier_classifier_4_relu_output_0_mat_ru,classifier_6_weight_mat_rv}; ass } #[kernel] -fn _features_features_0_Conv_conv_copy_macro( +fn _features_features_0_conv_conv_copy_macro( api: &mut API, - onnx__Conv_150: &[[[[InputVariable;3];3];3];64], - _features_features_0_Conv_output_0_conv: &[[[[InputVariable;32];32];64];1], + onnx_conv_150: &[[[[InputVariable;3];3];3];64], + _features_features_0_conv_output_0_conv: &[[[[InputVariable;32];32];64];1], input: &[[[[InputVariable;32];32];3];1], - onnx__Conv_150_mat: &mut [[OutputVariable;64];27], - _features_features_0_Conv_output_0_conv_mat: &mut [[OutputVariable;1024];64], + onnx_conv_150_mat: &mut [[OutputVariable;64];27], + _features_features_0_conv_output_0_conv_mat: &mut [[OutputVariable;1024];64], input_mat: &mut [[OutputVariable;1024];27], ) { // for i in 0..64 { // for j in 0..3 { // for k in 0..3 { // for l in 0..3 { - // onnx__Conv_150_mat[((j)*3 + k)*3 + l][i] = onnx__Conv_150[i][j][k][l]; + // onnx_conv_150_mat[((j)*3 + k)*3 + l][i] = onnx_conv_150[i][j][k][l]; // } // } // } @@ -430,7 +428,7 @@ fn _features_features_0_Conv_conv_copy_macro( // for j in 0..64 { // for k in 0..32 { // for l in 0..32 { - // _features_features_0_Conv_output_0_conv_mat[j][((i)*32 + k)*32 + l] = _features_features_0_Conv_output_0_conv[i][j][k][l]; + // _features_features_0_conv_output_0_conv_mat[j][((i)*32 + k)*32 + l] = _features_features_0_conv_output_0_conv[i][j][k][l]; // } // } // } @@ -456,66 +454,66 @@ fn _features_features_0_Conv_conv_copy_macro( } #[kernel] -fn _features_features_0_Conv_conv_ab_matrix_macro( +fn _features_features_0_conv_conv_ab_matrix_macro( api: &mut API, input_mat: & [InputVariable;1024], - onnx__Conv_150_mat: & [InputVariable;64], + onnx_conv_150_mat: & [InputVariable;64], input_mat_ru: & [InputVariable;1024], - onnx__Conv_150_mat_rv: & [InputVariable;64], - _features_features_0_Conv_conv_ab_matrix_rx: &mut OutputVariable, - _features_features_0_Conv_conv_ab_matrix_ry: &mut OutputVariable, + onnx_conv_150_mat_rv: & [InputVariable;64], + _features_features_0_conv_conv_ab_matrix_rx: &mut OutputVariable, + _features_features_0_conv_conv_ab_matrix_ry: &mut OutputVariable, ) { - *_features_features_0_Conv_conv_ab_matrix_rx = api.constant(0); + *_features_features_0_conv_conv_ab_matrix_rx = api.constant(0); for i in 0..1024 { let tmp = api.mul(input_mat_ru[i], input_mat[i]); - *_features_features_0_Conv_conv_ab_matrix_rx = api.add(tmp, *_features_features_0_Conv_conv_ab_matrix_rx); + *_features_features_0_conv_conv_ab_matrix_rx = api.add(tmp, *_features_features_0_conv_conv_ab_matrix_rx); } - *_features_features_0_Conv_conv_ab_matrix_ry = api.constant(0); + *_features_features_0_conv_conv_ab_matrix_ry = api.constant(0); for i in 0..64 { - let tmp = api.mul(onnx__Conv_150_mat_rv[i], onnx__Conv_150_mat[i]); - *_features_features_0_Conv_conv_ab_matrix_ry = api.add(tmp, *_features_features_0_Conv_conv_ab_matrix_ry); + let tmp = api.mul(onnx_conv_150_mat_rv[i], onnx_conv_150_mat[i]); + *_features_features_0_conv_conv_ab_matrix_ry = api.add(tmp, *_features_features_0_conv_conv_ab_matrix_ry); } } #[kernel] -fn _features_features_0_Conv_conv_c_matrix_macro( +fn _features_features_0_conv_conv_c_matrix_macro( api: &mut API, - _features_features_0_Conv_output_0_conv_mat: & [InputVariable;1024], + _features_features_0_conv_output_0_conv_mat: & [InputVariable;1024], input_mat_ru: & [InputVariable;1024], - _features_features_0_Conv_conv_c_matrix_rz: &mut OutputVariable, + _features_features_0_conv_conv_c_matrix_rz: &mut OutputVariable, ) { - *_features_features_0_Conv_conv_c_matrix_rz = api.constant(0); + *_features_features_0_conv_conv_c_matrix_rz = api.constant(0); for i in 0..1024 { - let tmp = api.mul(input_mat_ru[i], _features_features_0_Conv_output_0_conv_mat[i]); - *_features_features_0_Conv_conv_c_matrix_rz = api.add(tmp, *_features_features_0_Conv_conv_c_matrix_rz); + let tmp = api.mul(input_mat_ru[i], _features_features_0_conv_output_0_conv_mat[i]); + *_features_features_0_conv_conv_c_matrix_rz = api.add(tmp, *_features_features_0_conv_conv_c_matrix_rz); } } #[kernel] // multiply operation -fn _features_features_0_Conv_mul_macro( +fn _features_features_0_conv_mul_macro( api: &mut API, - _features_features_0_Conv_output_0_conv: &[[InputVariable;32];32], - onnx__Conv_150_nscale: &InputVariable, - _features_features_0_Conv_output_0_mul: &mut [[OutputVariable;32];32], + _features_features_0_conv_output_0_conv: &[[InputVariable;32];32], + onnx_conv_150_nscale: &InputVariable, + _features_features_0_conv_output_0_mul: &mut [[OutputVariable;32];32], ) { for i in 0..32 { for j in 0..32 { - _features_features_0_Conv_output_0_mul[i][j] = api.mul(_features_features_0_Conv_output_0_conv[i][j], onnx__Conv_150_nscale); + _features_features_0_conv_output_0_mul[i][j] = api.mul(_features_features_0_conv_output_0_conv[i][j], onnx_conv_150_nscale); } } } #[kernel] // divide operation -fn _features_features_0_Conv_div_macro( +fn _features_features_0_conv_div_macro( api: &mut API, - _features_features_0_Conv_output_0_mul: &[[InputVariable;32];32], - onnx__Conv_150_dscale: &InputVariable, - _features_features_0_Conv_output_0_floor: &[[InputVariable;32];32], - _features_features_0_Conv_output_0_rem: &[[InputVariable;32];32], + _features_features_0_conv_output_0_mul: &[[InputVariable;32];32], + onnx_conv_150_dscale: &InputVariable, + _features_features_0_conv_output_0_floor: &[[InputVariable;32];32], + _features_features_0_conv_output_0_rem: &[[InputVariable;32];32], ) { for i in 0..32 { for j in 0..32 { - let tmp1 = api.mul(_features_features_0_Conv_output_0_floor[i][j], onnx__Conv_150_dscale); - let tmp2 = api.sub(_features_features_0_Conv_output_0_mul[i][j], _features_features_0_Conv_output_0_rem[i][j]); + let tmp1 = api.mul(_features_features_0_conv_output_0_floor[i][j], onnx_conv_150_dscale); + let tmp2 = api.sub(_features_features_0_conv_output_0_mul[i][j], _features_features_0_conv_output_0_rem[i][j]); api.assert_is_equal(tmp1, tmp2); } } @@ -528,32 +526,32 @@ fn expander_circuit() -> std::io::Result<()>{ let mut ctx = Context::::default(); let mut assignment = default_variable(); - let onnx__Conv_150_mat = ctx.copy_to_device(&assignment.onnx__Conv_150); // [64, 3, 3, 3] - let onnx__Conv_150_mat = onnx__Conv_150_mat.reshape(&[64, 27]); // [64, 27] - let onnx__Conv_150_mat = onnx__Conv_150_mat.transpose(&[1, 0]); // [27, 64] + let onnx_conv_150_mat = ctx.copy_to_device(&assignment.onnx_conv_150); // [64, 3, 3, 3] + let onnx_conv_150_mat = onnx_conv_150_mat.reshape(&[64, 27]); // [64, 27] + let onnx_conv_150_mat = onnx_conv_150_mat.transpose(&[1, 0]); // [27, 64] - let kernel__features_features_0_Conv_conv_ab_matrix: KernelPrimitive = compile__features_features_0_Conv_conv_ab_matrix_macro().unwrap(); + let kernel__features_features_0_conv_conv_ab_matrix: KernelPrimitive = compile__features_features_0_conv_conv_ab_matrix_macro().unwrap(); let input_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];27]); let input_mat_ru = ctx.copy_to_device(&assignment.input_mat_ru); - let onnx__Conv_150_mat_rv = ctx.copy_to_device(&assignment.onnx__Conv_150_mat_rv); - let mut _features_features_0_Conv_conv_rx = None; - let mut _features_features_0_Conv_conv_ry = None; + let onnx_conv_150_mat_rv = ctx.copy_to_device(&assignment.onnx_conv_150_mat_rv); + let mut _features_features_0_conv_conv_rx = None; + let mut _features_features_0_conv_conv_ry = None; let mut input_mat_clone = input_mat.clone(); - let mut onnx__Conv_150_mat_clone = onnx__Conv_150_mat.clone(); + let mut onnx_conv_150_mat_clone = onnx_conv_150_mat.clone(); let mut input_mat_ru_clone = input_mat_ru.clone(); - let mut onnx__Conv_150_mat_rv_clone = onnx__Conv_150_mat_rv.clone(); - call_kernel!(ctx, kernel__features_features_0_Conv_conv_ab_matrix, 27, input_mat_clone, onnx__Conv_150_mat_clone, input_mat_ru_clone, onnx__Conv_150_mat_rv_clone, mut _features_features_0_Conv_conv_rx, mut _features_features_0_Conv_conv_ry).unwrap(); + let mut onnx_conv_150_mat_rv_clone = onnx_conv_150_mat_rv.clone(); + call_kernel!(ctx, kernel__features_features_0_conv_conv_ab_matrix, 27, input_mat_clone, onnx_conv_150_mat_clone, input_mat_ru_clone, onnx_conv_150_mat_rv_clone, mut _features_features_0_conv_conv_rx, mut _features_features_0_conv_conv_ry).unwrap(); - let _features_features_0_Conv_output_0_conv = ctx.copy_to_device(&assignment._features_features_0_Conv_output_0_conv); // [1, 64, 32, 32] - let _features_features_0_Conv_output_0_conv_mat = _features_features_0_Conv_output_0_conv.transpose(&[1, 0, 2, 3]); // [64, 1, 32, 32] - let _features_features_0_Conv_output_0_conv_mat = _features_features_0_Conv_output_0_conv_mat.reshape(&[64, 1024]); // [64, 1024] + let _features_features_0_conv_output_0_conv = ctx.copy_to_device(&assignment._features_features_0_conv_output_0_conv); // [1, 64, 32, 32] + let _features_features_0_conv_output_0_conv_mat = _features_features_0_conv_output_0_conv.transpose(&[1, 0, 2, 3]); // [64, 1, 32, 32] + let _features_features_0_conv_output_0_conv_mat = _features_features_0_conv_output_0_conv_mat.reshape(&[64, 1024]); // [64, 1024] - let kernel__features_features_0_Conv_conv_c_matrix: KernelPrimitive = compile__features_features_0_Conv_conv_c_matrix_macro().unwrap(); - // let _features_features_0_Conv_output_0_conv_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];64]); - let mut _features_features_0_Conv_conv_rz = None; - let _features_features_0_Conv_output_0_conv_mat_clone = _features_features_0_Conv_output_0_conv_mat.clone(); + let kernel__features_features_0_conv_conv_c_matrix: KernelPrimitive = compile__features_features_0_conv_conv_c_matrix_macro().unwrap(); + // let _features_features_0_conv_output_0_conv_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];64]); + let mut _features_features_0_conv_conv_rz = None; + let _features_features_0_conv_output_0_conv_mat_clone = _features_features_0_conv_output_0_conv_mat.clone(); let input_mat_ru_clone = input_mat_ru.clone(); - call_kernel!(ctx, kernel__features_features_0_Conv_conv_c_matrix, 64, _features_features_0_Conv_output_0_conv_mat_clone, input_mat_ru_clone, mut _features_features_0_Conv_conv_rz).unwrap(); + call_kernel!(ctx, kernel__features_features_0_conv_conv_c_matrix, 64, _features_features_0_conv_output_0_conv_mat_clone, input_mat_ru_clone, mut _features_features_0_conv_conv_rz).unwrap(); let computation_graph = ctx.compile_computation_graph().unwrap(); let file = std::fs::File::create("graph.txt").unwrap(); From aa80f74cce313d3e0e893b149c8b172e18f83591 Mon Sep 17 00:00:00 2001 From: chonps Date: Mon, 28 Jul 2025 23:26:07 -0700 Subject: [PATCH 60/60] clear --- expander_compiler/src/zkcuda/context.rs | 1 - expander_compiler/src/zkcuda/shape.rs | 27 ------------------------- 2 files changed, 28 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 65cd5231..aaeda969 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -486,7 +486,6 @@ impl>> Context { loop { let get_pad_shape = |x: &DeviceMemoryHandle| { x.as_ref().map(|handle| { -println!("get handle {:?}", handle.id); handle .shape_history .get_transposed_shape_and_bit_order(&dm_shapes[handle.id]) diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 2318b7eb..651a962b 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -83,29 +83,9 @@ impl Entry { if self.axes.is_none() { return shape.to_vec(); } -println!("self axes {:?}", self.axes); -println!("self shape {:?}", self.shape); let mut segments = vec![]; - // let mut cur_prod = 1; - // let mut target = 1; - // let mut self_shape_iter = self.shape.iter(); - // for &x in shape.iter() { - // if cur_prod == target { - // cur_prod = x.0; - // target = *self_shape_iter.next().unwrap(); - // segments.push(vec![x]); - // } else { - // cur_prod *= x.0; - // segments.last_mut().unwrap().push(x); - // } - // } - // assert_eq!(cur_prod, target); - // assert_eq!(self_shape_iter.next(), None); let mut shape_iter = shape.iter(); for &x in self.shape.iter() { - // if x == 1 { - // continue; - // } let mut cur_prod = 1; segments.push(vec![]); while let Some(y) = shape_iter.next() { @@ -117,7 +97,6 @@ println!("self shape {:?}", self.shape); } assert_eq!(cur_prod, x); } -println!("segments {:?}", segments); let mut res = Vec::with_capacity(shape.len()); for i in self.axes.as_ref().unwrap() { res.extend(segments[*i].iter()); @@ -129,9 +108,6 @@ println!("segments {:?}", segments); return products.to_vec(); } let ts = self.transposed_shape(); -println!("undo transpose shape products {:?}", products); -println!("self axes {:?}", self.axes); -println!("transposed shape {:?}", ts); let mut segments_in_ts = vec![Vec::new(); ts.len()]; let mut cur_ts_prod = 1; let mut cur_ts_idx = 0; @@ -300,11 +276,8 @@ impl ShapeHistory { cur = if e.axes.as_ref().is_none() { cur } else if cur.is_none() { -println!("intiial shape {:?}", initial_shape()); - // Some(e.transpose_shape(&initial_shape())) Some(e.minimize(false).transpose_shape(&initial_shape())) } else { - // Some(e.transpose_shape(&cur.unwrap())) Some(e.minimize(false).transpose_shape(&cur.unwrap())) }; }