From 78aea3fcf02bc639bb15c16660008768223d7dfd Mon Sep 17 00:00:00 2001 From: SashaMalysehko Date: Mon, 19 Jan 2026 22:49:32 +0200 Subject: [PATCH 1/3] Optimize root public values hashing buffer --- crates/prover/src/lib.rs | 1706 ++++---------------------------------- 1 file changed, 163 insertions(+), 1543 deletions(-) diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index d20e5b53..d4a3f7dd 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,1590 +1,210 @@ -//! An end-to-end-prover implementation for the Ziren zkVM. -//! -//! Separates the proof generation process into multiple stages: -//! -//! 1. Generate shard proofs which split up and prove the valid execution of a MIPS program. -//! 2. Compress shard proofs into a single shard proof. -//! 3. Wrap the shard proof into a SNARK-friendly field. -//! 4. Wrap the last shard proof, proven over the SNARK-friendly field, into a PLONK proof. - -#![allow(clippy::too_many_arguments)] -#![allow(clippy::new_without_default)] -#![allow(clippy::collapsible_else_if)] - -pub mod build; -pub mod components; -pub mod shapes; -pub mod types; -pub mod utils; -pub mod verify; - use std::{ borrow::Borrow, - collections::BTreeMap, - env, - num::NonZeroUsize, - path::Path, - sync::{ - atomic::{AtomicUsize, Ordering}, - mpsc::sync_channel, - Arc, Mutex, OnceLock, - }, - thread, + fs::{self, File}, + io::Read, + iter::{Skip, Take}, }; -use lru::LruCache; -use p3_field::{FieldAlgebra, PrimeField, PrimeField32}; +use itertools::Itertools; +use p3_bn254_fr::Bn254Fr; +use p3_field::{FieldAlgebra, PrimeField32}; use p3_koala_bear::KoalaBear; -use p3_matrix::dense::RowMajorMatrix; -use shapes::ZKMProofShape; -use tracing::instrument; -use zkm_core_executor::{ExecutionError, ExecutionReport, Executor, Program, ZKMContext}; -use zkm_core_machine::{ - io::ZKMStdin, - mips::MipsAir, - reduce::ZKMReduceProof, - shape::CoreShapeConfig, - utils::{concurrency::TurnBasedSync, ZKMCoreProverError}, -}; -use zkm_primitives::{hash_deferred_proof, io::ZKMPublicValues}; -use zkm_recursion_circuit::{ - hash::FieldHasher, - machine::{ - PublicValuesOutputDigest, ZKMCompressRootVerifierWithVKey, ZKMCompressShape, - ZKMCompressWithVKeyVerifier, ZKMCompressWithVKeyWitnessValues, ZKMCompressWithVkeyShape, - ZKMCompressWitnessValues, ZKMDeferredVerifier, ZKMDeferredWitnessValues, - ZKMMerkleProofWitnessValues, ZKMRecursionShape, ZKMRecursionWitnessValues, - ZKMRecursiveVerifier, - }, - merkle_tree::MerkleTree, - witness::Witnessable, - WrapConfig, -}; -use zkm_recursion_compiler::{ - circuit::AsmCompiler, - config::InnerConfig, - ir::{Builder, Witness}, -}; +use p3_symmetric::CryptographicHasher; +use zkm_core_executor::{Executor, Program}; +use zkm_core_machine::{io::ZKMStdin, reduce::ZKMReduceProof}; +use zkm_recursion_circuit::machine::RootPublicValues; use zkm_recursion_core::{ - air::RecursionPublicValues, - machine::RecursionAir, - runtime::ExecutionRecord, - shape::{RecursionShape, RecursionShapeConfig}, + air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}, stark::KoalaBearPoseidon2Outer, - RecursionProgram, Runtime as RecursionRuntime, }; -pub use zkm_recursion_gnark_ffi::proof::{Groth16Bn254Proof, PlonkBn254Proof}; -use zkm_recursion_gnark_ffi::{groth16_bn254::Groth16Bn254Prover, plonk_bn254::PlonkBn254Prover}; -use zkm_stark::{ - air::PublicValues, koala_bear_poseidon2::KoalaBearPoseidon2, Challenge, MachineProver, - ShardProof, StarkGenericConfig, StarkVerifyingKey, Val, Word, ZKMCoreOpts, ZKMProverOpts, - DIGEST_SIZE, -}; -use zkm_stark::{shape::OrderedShape, MachineProvingKey}; - -pub use types::*; -use utils::{words_to_bytes, zkm_committed_values_digest_bn254, zkm_vkey_digest_bn254}; - -use components::{DefaultProverComponents, ZKMProverComponents}; - -pub use zkm_core_machine::ZKM_CIRCUIT_VERSION; - -/// The configuration for the core prover. -pub type CoreSC = KoalaBearPoseidon2; - -/// The configuration for the inner prover. -pub type InnerSC = KoalaBearPoseidon2; - -/// The configuration for the outer prover. -pub type OuterSC = KoalaBearPoseidon2Outer; - -pub type DeviceProvingKey = <::CoreProver as MachineProver< - KoalaBearPoseidon2, - MipsAir, ->>::DeviceProvingKey; - -const COMPRESS_DEGREE: usize = 3; -const SHRINK_DEGREE: usize = 3; -const WRAP_DEGREE: usize = 9; - -const CORE_CACHE_SIZE: usize = 5; -pub const REDUCE_BATCH_SIZE: usize = 2; - -// TODO: FIX -// -// const SHAPES_URL_PREFIX: &str = "https://zkm-circuits.s3.us-east-2.amazonaws.com/shapes"; -// const SHAPES_VERSION: &str = "146079e0e"; -// lazy_static! { -// static ref SHAPES_INIT: Once = Once::new(); -// } - -pub type CompressAir = RecursionAir; -pub type ShrinkAir = RecursionAir; -pub type WrapAir = RecursionAir; - -/// An end-to-end prover implementation for the Ziren zkVM. -pub struct ZKMProver { - /// The machine used for proving the core step. - pub core_prover: C::CoreProver, - - /// The machine used for proving the recursive and reduction steps. - pub compress_prover: C::CompressProver, - - /// The machine used for proving the shrink step. - pub shrink_prover: C::ShrinkProver, - - /// The machine used for proving the wrapping step. - pub wrap_prover: C::WrapProver, - - /// The cache of compiled recursion programs. - pub lift_programs_lru: Mutex>>>, - - /// The number of cache misses for recursion programs. - pub lift_cache_misses: AtomicUsize, - - /// The cache of compiled compression programs. - pub join_programs_map: BTreeMap>>, - - /// The number of cache misses for compression programs. - pub join_cache_misses: AtomicUsize, - - /// The root of the allowed recursion verification keys. - pub recursion_vk_root: >::Digest, - - /// The allowed VKs and their corresponding indices. - pub recursion_vk_map: BTreeMap<>::Digest, usize>, - - /// The Merkle tree for the allowed VKs. - pub recursion_vk_tree: MerkleTree, - - /// The core shape configuration. - pub core_shape_config: Option>, - - /// The recursion shape configuration. - pub compress_shape_config: Option>>, - - /// The program for wrapping. - pub wrap_program: OnceLock>>, +use zkm_stark::{koala_bear_poseidon2::MyHash as InnerHash, Word, ZKMCoreOpts}; - /// The verifying key for wrapping. - pub wrap_vk: OnceLock>, +use crate::{InnerSC, ZKMCoreProofData}; - /// Whether to verify verification keys. - pub vk_verification: bool, +/// Get the Ziren vkey KoalaBear Poseidon2 digest this reduce proof is representing. +pub fn zkm_vkey_digest_koalabear( + proof: &ZKMReduceProof, +) -> [KoalaBear; 8] { + let proof = &proof.proof; + let pv: &RecursionPublicValues = proof.public_values.as_slice().borrow(); + pv.zkm_vk_digest } -impl ZKMProver { - /// Initializes a new [ZKMProver]. - #[instrument(name = "initialize prover", level = "debug", skip_all)] - pub fn new() -> Self { - Self::uninitialized() - } - - /// Creates a new [ZKMProver] with lazily initialized components. - pub fn uninitialized() -> Self { - // Initialize the provers. - let core_machine = MipsAir::machine(CoreSC::default()); - let core_prover = C::CoreProver::new(core_machine); - - let compress_machine = CompressAir::compress_machine(InnerSC::default()); - let compress_prover = C::CompressProver::new(compress_machine); - - // TODO: Put the correct shrink and wrap machines here. - let shrink_machine = ShrinkAir::shrink_machine(InnerSC::compressed()); - let shrink_prover = C::ShrinkProver::new(shrink_machine); - - let wrap_machine = WrapAir::wrap_machine(OuterSC::default()); - let wrap_prover = C::WrapProver::new(wrap_machine); - - let core_cache_size = NonZeroUsize::new( - env::var("PROVER_CORE_CACHE_SIZE") - .unwrap_or_else(|_| CORE_CACHE_SIZE.to_string()) - .parse() - .unwrap_or(CORE_CACHE_SIZE), - ) - .expect("PROVER_CORE_CACHE_SIZE must be a non-zero usize"); - - let core_shape_config = env::var("FIX_CORE_SHAPES") - .map(|v| v.eq_ignore_ascii_case("true")) - .unwrap_or(true) - .then_some(CoreShapeConfig::default()); - - let recursion_shape_config = env::var("FIX_RECURSION_SHAPES") - .map(|v| v.eq_ignore_ascii_case("true")) - .unwrap_or(true) - .then_some(RecursionShapeConfig::default()); - - let vk_verification = - env::var("VERIFY_VK").map(|v| v.eq_ignore_ascii_case("true")).unwrap_or(true); - - tracing::debug!("vk verification: {}", vk_verification); - - // Read the shapes from the shapes directory and deserialize them into memory. - let allowed_vk_map: BTreeMap<[KoalaBear; DIGEST_SIZE], usize> = if vk_verification { - // Regenerate the vk_map.bin when the Ziren circuit is updated. - // ``` - // cd Ziren - // cargo run -r --bin build_compress_vks -- --num-compiler-workers 32 --count-setup-workers 32 --build-dir crates/prover - // ``` - // It takes several days. - bincode::deserialize(include_bytes!("../vk_map.bin")).unwrap() - } else { - bincode::deserialize(include_bytes!("../dummy_vk_map.bin")).unwrap() - }; - - let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().copied().collect()); - - let mut compress_programs = BTreeMap::new(); - if let Some(config) = &recursion_shape_config { - ZKMProofShape::generate_compress_shapes(config, REDUCE_BATCH_SIZE).for_each(|shape| { - let compress_shape = ZKMCompressWithVkeyShape { - compress_shape: shape.into(), - merkle_tree_height: merkle_tree.height, - }; - let input = ZKMCompressWithVKeyWitnessValues::dummy( - compress_prover.machine(), - &compress_shape, - ); - let program = compress_program_from_input::( - recursion_shape_config.as_ref(), - &compress_prover, - vk_verification, - &input, - ); - let program = Arc::new(program); - compress_programs.insert(compress_shape, program); - }); - } - - Self { - core_prover, - compress_prover, - shrink_prover, - wrap_prover, - lift_programs_lru: Mutex::new(LruCache::new(core_cache_size)), - lift_cache_misses: AtomicUsize::new(0), - join_programs_map: compress_programs, - join_cache_misses: AtomicUsize::new(0), - recursion_vk_root: root, - recursion_vk_tree: merkle_tree, - recursion_vk_map: allowed_vk_map, - core_shape_config, - compress_shape_config: recursion_shape_config, - vk_verification, - wrap_program: OnceLock::new(), - wrap_vk: OnceLock::new(), - } - } - - /// Fully initializes the programs, proving keys, and verifying keys that are normally - /// lazily initialized. TODO: remove this. - pub fn initialize(&mut self) {} - - /// Creates a proving key and a verifying key for a given MIPS ELF. - #[instrument(name = "setup", level = "debug", skip_all)] - pub fn setup( - &self, - elf: &[u8], - ) -> (ZKMProvingKey, DeviceProvingKey, Program, ZKMVerifyingKey) { - let program = self.get_program(elf).unwrap(); - let (pk, vk) = self.core_prover.setup(&program); - let vk = ZKMVerifyingKey { vk }; - let pk = ZKMProvingKey { - pk: self.core_prover.pk_to_host(&pk), - elf: elf.to_vec(), - vk: vk.clone(), - }; - let pk_d = self.core_prover.pk_to_device(&pk.pk); - (pk, pk_d, program, vk) - } - - /// Get a program with an allowed preprocessed shape. - pub fn get_program(&self, elf: &[u8]) -> eyre::Result { - let mut program = Program::from(elf).unwrap(); - if let Some(core_shape_config) = &self.core_shape_config { - core_shape_config.fix_preprocessed_shape(&mut program)?; - } - Ok(program) - } - - /// Generate a proof of a Ziren program with the specified inputs. - #[instrument(name = "execute", level = "info", skip_all)] - pub fn execute<'a>( - &'a self, - elf: &[u8], - stdin: &ZKMStdin, - mut context: ZKMContext<'a>, - ) -> Result<(ZKMPublicValues, ExecutionReport), ExecutionError> { - context.subproof_verifier = Some(self); - let program = self.get_program(elf).unwrap(); - let opts = ZKMCoreOpts::default(); - let mut runtime = Executor::with_context(program, opts, context); - runtime.write_vecs(&stdin.buffer); - for (proof, vkey) in stdin.proofs.iter() { - runtime.write_proof(proof.clone(), vkey.clone()); - } - runtime.run_fast()?; - Ok((ZKMPublicValues::from(&runtime.state.public_values_stream), runtime.report)) - } - - /// Generate shard proofs which split up and prove the valid execution of a MIPS program with - /// the core prover. Uses the provided context. - #[instrument(name = "prove_core", level = "info", skip_all)] - pub fn prove_core<'a>( - &'a self, - pk_d: &<::CoreProver as MachineProver< - KoalaBearPoseidon2, - MipsAir, - >>::DeviceProvingKey, - program: Program, - stdin: &ZKMStdin, - opts: ZKMProverOpts, - mut context: ZKMContext<'a>, - ) -> Result { - context.subproof_verifier = Some(self); - let pk = pk_d; - let (proof, public_values_stream, cycles) = - zkm_core_machine::utils::prove_with_context::<_, C::CoreProver>( - &self.core_prover, - pk, - program, - stdin, - opts.core_opts, - context, - self.core_shape_config.as_ref(), - )?; - Self::check_for_high_cycles(cycles); - let public_values = ZKMPublicValues::from(&public_values_stream); - Ok(ZKMCoreProof { - proof: ZKMCoreProofData(proof.shard_proofs), - stdin: stdin.clone(), - public_values, - cycles, - }) - } - - pub fn recursion_program( - &self, - input: &ZKMRecursionWitnessValues, - ) -> Arc> { - let mut cache = self.lift_programs_lru.lock().unwrap_or_else(|e| e.into_inner()); - cache - .get_or_insert(input.shape(), || { - let misses = self.lift_cache_misses.fetch_add(1, Ordering::Relaxed); - tracing::debug!("core cache miss, misses: {}", misses); - // Get the operations. - let builder_span = tracing::debug_span!("build recursion program").entered(); - let mut builder = Builder::::default(); - - let input = input.read(&mut builder); - ZKMRecursiveVerifier::verify(&mut builder, self.core_prover.machine(), input); - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile recursion program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - if let Some(recursion_shape_config) = &self.compress_shape_config { - recursion_shape_config.fix_shape(&mut program); - } - let program = Arc::new(program); - compiler_span.exit(); - program - }) - .clone() - } - - pub fn compress_program( - &self, - input: &ZKMCompressWithVKeyWitnessValues, - ) -> Arc> { - self.join_programs_map.get(&input.shape()).cloned().unwrap_or_else(|| { - tracing::warn!("compress program not found in map, recomputing join program."); - // Get the operations. - Arc::new(compress_program_from_input::( - self.compress_shape_config.as_ref(), - &self.compress_prover, - self.vk_verification, - input, - )) - }) - } - - pub fn shrink_program( - &self, - shrink_shape: RecursionShape, - input: &ZKMCompressWithVKeyWitnessValues, - ) -> Arc> { - // Get the operations. - let builder_span = tracing::debug_span!("build shrink program").entered(); - let mut builder = Builder::::default(); - let input = input.read(&mut builder); - // Verify the proof. - ZKMCompressRootVerifierWithVKey::verify( - &mut builder, - self.compress_prover.machine(), - input, - self.vk_verification, - PublicValuesOutputDigest::Reduce, - ); - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile shrink program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - *program.shape_mut() = Some(shrink_shape); - let program = Arc::new(program); - compiler_span.exit(); - program - } - - pub fn wrap_program(&self) -> Arc> { - self.wrap_program - .get_or_init(|| { - // Get the operations. - let builder_span = tracing::debug_span!("build compress program").entered(); - let mut builder = Builder::::default(); - - let shrink_shape: OrderedShape = ShrinkAir::::shrink_shape().into(); - let input_shape = ZKMCompressShape::from(vec![shrink_shape]); - let shape = ZKMCompressWithVkeyShape { - compress_shape: input_shape, - merkle_tree_height: self.recursion_vk_tree.height, - }; - let dummy_input = - ZKMCompressWithVKeyWitnessValues::dummy(self.shrink_prover.machine(), &shape); - - let input = dummy_input.read(&mut builder); - - // Attest that the merkle tree root is correct. - let root = input.merkle_var.root; - for (val, expected) in root.iter().zip(self.recursion_vk_root.iter()) { - builder.assert_felt_eq(*val, *expected); - } - // Verify the proof. - ZKMCompressRootVerifierWithVKey::verify( - &mut builder, - self.shrink_prover.machine(), - input, - self.vk_verification, - PublicValuesOutputDigest::Root, - ); - - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile compress program").entered(); - let mut compiler = AsmCompiler::::default(); - let program = Arc::new(compiler.compile(operations)); - compiler_span.exit(); - program - }) - .clone() - } - - pub fn deferred_program( - &self, - input: &ZKMDeferredWitnessValues, - ) -> Arc> { - // Compile the program. - - // Get the operations. - let operations_span = - tracing::debug_span!("get operations for the deferred program").entered(); - let mut builder = Builder::::default(); - let input_read_span = tracing::debug_span!("Read input values").entered(); - let input = input.read(&mut builder); - input_read_span.exit(); - let verify_span = tracing::debug_span!("Verify deferred program").entered(); - - // Verify the proof. - ZKMDeferredVerifier::verify( - &mut builder, - self.compress_prover.machine(), - input, - self.vk_verification, - ); - verify_span.exit(); - let operations = builder.into_operations(); - operations_span.exit(); - - let compiler_span = tracing::debug_span!("compile deferred program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - if let Some(recursion_shape_config) = &self.compress_shape_config { - recursion_shape_config.fix_shape(&mut program); - } - let program = Arc::new(program); - compiler_span.exit(); - program - } - - pub fn get_recursion_core_inputs( - &self, - vk: &StarkVerifyingKey, - shard_proofs: &[ShardProof], - batch_size: usize, - is_complete: bool, - ) -> Vec> { - let mut core_inputs = Vec::new(); - - // Prepare the inputs for the recursion programs. - for (batch_idx, batch) in shard_proofs.chunks(batch_size).enumerate() { - let proofs = batch.to_vec(); - - core_inputs.push(ZKMRecursionWitnessValues { - vk: vk.clone(), - shard_proofs: proofs.clone(), - is_complete, - is_first_shard: batch_idx == 0, - vk_root: self.recursion_vk_root, - }); - } - - core_inputs - } - - pub fn get_recursion_deferred_inputs<'a>( - &'a self, - vk: &'a StarkVerifyingKey, - last_proof_pv: &PublicValues, KoalaBear>, - deferred_proofs: &[ZKMReduceProof], - batch_size: usize, - ) -> Vec> { - // Prepare the inputs for the deferred proofs recursive verification. - let mut deferred_digest = [Val::::ZERO; DIGEST_SIZE]; - let mut deferred_inputs = Vec::new(); - - for batch in deferred_proofs.chunks(batch_size) { - let vks_and_proofs = - batch.iter().cloned().map(|proof| (proof.vk, proof.proof)).collect::>(); +/// Get the Ziren vkey Bn Poseidon2 digest this reduce proof is representing. +pub fn zkm_vkey_digest_bn254(proof: &ZKMReduceProof) -> Bn254Fr { + koalabears_to_bn254(&zkm_vkey_digest_koalabear(proof)) +} - let input = ZKMCompressWitnessValues { vks_and_proofs, is_complete: true }; - let input = self.make_merkle_proofs(input); - let ZKMCompressWithVKeyWitnessValues { compress_val, merkle_val } = input; +/// Compute the digest of the public values. +pub fn recursion_public_values_digest( + config: &InnerSC, + public_values: &RecursionPublicValues, +) -> [KoalaBear; 8] { + let hash = InnerHash::new(config.perm.clone()); + let pv_array = public_values.as_array(); + hash.hash_slice(&pv_array[0..NUM_PV_ELMS_TO_HASH]) +} - deferred_inputs.push(ZKMDeferredWitnessValues { - vks_and_proofs: compress_val.vks_and_proofs, - vk_merkle_data: merkle_val, - start_reconstruct_deferred_digest: deferred_digest, - is_complete: false, - zkm_vk_digest: vk.hash_koalabear(), - end_pc: Val::::ZERO, - end_shard: last_proof_pv.shard + KoalaBear::ONE, - end_execution_shard: last_proof_pv.execution_shard, - init_addr_bits: last_proof_pv.last_init_addr_bits, - finalize_addr_bits: last_proof_pv.last_finalize_addr_bits, - committed_value_digest: last_proof_pv.committed_value_digest, - deferred_proofs_digest: last_proof_pv.deferred_proofs_digest, - }); +pub fn root_public_values_digest( + config: &InnerSC, + public_values: &RootPublicValues, +) -> [KoalaBear; 8] { + let hash = InnerHash::new(config.perm.clone()); + let mut input = [KoalaBear::ZERO; 40]; + input[..8].copy_from_slice(public_values.zkm_vk_digest()); + for (i, word) in public_values.committed_value_digest().iter().enumerate() { + let start = 8 + i * 4; + input[start..start + 4].copy_from_slice(&word.0); + } + hash.hash_slice(&input) +} - deferred_digest = Self::hash_deferred_proofs(deferred_digest, batch); +pub fn is_root_public_values_valid( + config: &InnerSC, + public_values: &RootPublicValues, +) -> bool { + let expected_digest = root_public_values_digest(config, public_values); + for (value, expected) in public_values.digest().iter().copied().zip_eq(expected_digest) { + if value != expected { + return false; } - deferred_inputs } + true +} - /// Generate the inputs for the first layer of recursive proofs. - #[allow(clippy::type_complexity)] - pub fn get_first_layer_inputs<'a>( - &'a self, - vk: &'a ZKMVerifyingKey, - shard_proofs: &[ShardProof], - deferred_proofs: &[ZKMReduceProof], - batch_size: usize, - ) -> Vec { - let is_complete = shard_proofs.len() == 1 && deferred_proofs.is_empty(); - let core_inputs = - self.get_recursion_core_inputs(&vk.vk, shard_proofs, batch_size, is_complete); - let last_proof_pv = shard_proofs.last().unwrap().public_values.as_slice().borrow(); - let deferred_inputs = - self.get_recursion_deferred_inputs(&vk.vk, last_proof_pv, deferred_proofs, batch_size); - - let mut inputs = Vec::new(); - inputs.extend(core_inputs.into_iter().map(ZKMCircuitWitness::Core)); - inputs.extend(deferred_inputs.into_iter().map(ZKMCircuitWitness::Deferred)); - inputs - } - - /// Reduce shard proofs to a single shard proof using the recursion prover. - #[instrument(name = "compress", level = "info", skip_all)] - pub fn compress( - &self, - vk: &ZKMVerifyingKey, - proof: ZKMCoreProof, - deferred_proofs: Vec>, - opts: ZKMProverOpts, - ) -> Result, ZKMRecursionProverError> { - // The batch size for reducing two layers of recursion. - let batch_size = REDUCE_BATCH_SIZE; - // The batch size for reducing the first layer of recursion. - let first_layer_batch_size = 1; - - let shard_proofs = &proof.proof.0; - - let first_layer_inputs = - self.get_first_layer_inputs(vk, shard_proofs, &deferred_proofs, first_layer_batch_size); - - // Calculate the expected height of the tree. - let mut expected_height = if first_layer_inputs.len() == 1 { 0 } else { 1 }; - let num_first_layer_inputs = first_layer_inputs.len(); - let mut num_layer_inputs = num_first_layer_inputs; - while num_layer_inputs > batch_size { - num_layer_inputs = num_layer_inputs.div_ceil(2); - expected_height += 1; +/// Check if the digest of the public values is correct. +pub fn is_recursion_public_values_valid( + config: &InnerSC, + public_values: &RecursionPublicValues, +) -> bool { + let expected_digest = recursion_public_values_digest(config, public_values); + for (value, expected) in public_values.digest.iter().copied().zip_eq(expected_digest) { + if value != expected { + return false; } - - // Generate the proofs. - let span = tracing::Span::current().clone(); - let (vk, proof) = thread::scope(|s| { - let _span = span.enter(); - - // Spawn a worker that sends the first layer inputs to a bounded channel. - let input_sync = Arc::new(TurnBasedSync::new()); - let (input_tx, input_rx) = sync_channel::<(usize, usize, ZKMCircuitWitness)>( - opts.recursion_opts.checkpoints_channel_capacity, - ); - let input_tx = Arc::new(Mutex::new(input_tx)); - { - let input_tx = Arc::clone(&input_tx); - let input_sync = Arc::clone(&input_sync); - s.spawn(move || { - for (index, input) in first_layer_inputs.into_iter().enumerate() { - input_sync.wait_for_turn(index); - input_tx.lock().unwrap().send((index, 0, input)).unwrap(); - input_sync.advance_turn(); - } - }); - } - - // Spawn workers who generate the records and traces. - let record_and_trace_sync = Arc::new(TurnBasedSync::new()); - let (record_and_trace_tx, record_and_trace_rx) = - sync_channel::<( - usize, - usize, - Arc>, - ExecutionRecord, - Vec<(String, RowMajorMatrix)>, - )>(opts.recursion_opts.records_and_traces_channel_capacity); - let record_and_trace_tx = Arc::new(Mutex::new(record_and_trace_tx)); - let record_and_trace_rx = Arc::new(Mutex::new(record_and_trace_rx)); - let input_rx = Arc::new(Mutex::new(input_rx)); - for _ in 0..opts.recursion_opts.trace_gen_workers { - let record_and_trace_sync = Arc::clone(&record_and_trace_sync); - let record_and_trace_tx = Arc::clone(&record_and_trace_tx); - let input_rx = Arc::clone(&input_rx); - let span = tracing::debug_span!("generate records and traces"); - s.spawn(move || { - let _span = span.enter(); - loop { - let received = { input_rx.lock().unwrap().recv() }; - if let Ok((index, height, input)) = received { - // Get the program and witness stream. - let (program, witness_stream) = tracing::debug_span!( - "get program and witness stream" - ) - .in_scope(|| match input { - ZKMCircuitWitness::Core(input) => { - let mut witness_stream = Vec::new(); - Witnessable::::write(&input, &mut witness_stream); - (self.recursion_program(&input), witness_stream) - } - ZKMCircuitWitness::Deferred(input) => { - let mut witness_stream = Vec::new(); - Witnessable::::write(&input, &mut witness_stream); - (self.deferred_program(&input), witness_stream) - } - ZKMCircuitWitness::Compress(input) => { - let mut witness_stream = Vec::new(); - - let input_with_merkle = self.make_merkle_proofs(input); - - Witnessable::::write( - &input_with_merkle, - &mut witness_stream, - ); - - (self.compress_program(&input_with_merkle), witness_stream) - } - }); - - // Execute the runtime. - let record = tracing::debug_span!("execute runtime").in_scope(|| { - let mut runtime = - RecursionRuntime::, Challenge, _>::new( - program.clone(), - self.compress_prover.config().perm.clone(), - ); - runtime.witness_stream = witness_stream.into(); - runtime - .run() - .map_err(|e| { - ZKMRecursionProverError::RuntimeError(e.to_string()) - }) - .unwrap(); - runtime.record - }); - - // Generate the dependencies. - let mut records = vec![record]; - tracing::debug_span!("generate dependencies").in_scope(|| -> Result<(), ZKMRecursionProverError> { - match self.compress_prover.machine().generate_dependencies( - &mut records, - &opts.recursion_opts, - None, - ) { - Ok(_) => Ok(()), - Err(e) => { - tracing::error!( - "Failed to generate dependencies for recursion proof: {}", - e - ); - Err(ZKMRecursionProverError::DependenciesGenerationError) - } - } - })?; - - // Generate the traces. - let record = records.into_iter().next().unwrap(); - let traces = tracing::debug_span!("generate traces") - .in_scope(|| self.compress_prover.generate_traces(&record)); - let traces = match traces { - Ok(traces) => traces, - Err(e) => { - tracing::error!( - "Failed to generate traces for recursion proof: {}", - e - ); - return Err(ZKMRecursionProverError::TracesGenerationError); - } - }; - - // Wait for our turn to update the state. - record_and_trace_sync.wait_for_turn(index); - - // Send the record and traces to the worker. - record_and_trace_tx - .lock() - .unwrap() - .send((index, height, program, record, traces)) - .unwrap(); - - // Advance the turn. - record_and_trace_sync.advance_turn(); - } else { - break Ok(()); - } - } - }); - } - - // Spawn workers who generate the compress proofs. - let proofs_sync = Arc::new(TurnBasedSync::new()); - let (proofs_tx, proofs_rx) = - sync_channel::<(usize, usize, StarkVerifyingKey, ShardProof)>( - num_first_layer_inputs * 2, - ); - let proofs_tx = Arc::new(Mutex::new(proofs_tx)); - let proofs_rx = Arc::new(Mutex::new(proofs_rx)); - let mut prover_handles = Vec::new(); - for _ in 0..opts.recursion_opts.shard_batch_size { - let prover_sync = Arc::clone(&proofs_sync); - let record_and_trace_rx = Arc::clone(&record_and_trace_rx); - let proofs_tx = Arc::clone(&proofs_tx); - let span = tracing::debug_span!("prove"); - let handle = s.spawn(move || { - let _span = span.enter(); - loop { - let received = { record_and_trace_rx.lock().unwrap().recv() }; - if let Ok((index, height, program, record, traces)) = received { - tracing::debug_span!("batch").in_scope(|| { - // Get the keys. - let (pk, vk) = tracing::debug_span!("Setup compress program") - .in_scope(|| self.compress_prover.setup(&program)); - - // Observe the proving key. - let mut challenger = self.compress_prover.config().challenger(); - tracing::debug_span!("observe proving key").in_scope(|| { - pk.observe_into(&mut challenger); - }); - - #[cfg(feature = "debug")] - self.compress_prover.debug_constraints( - &self.compress_prover.pk_to_host(&pk), - vec![record.clone()], - &mut challenger.clone(), - ); - - // Commit to the record and traces. - let data = tracing::debug_span!("commit") - .in_scope(|| self.compress_prover.commit(&record, traces)); - - // Generate the proof. - let proof = tracing::debug_span!("open").in_scope(|| { - self.compress_prover.open(&pk, data, &mut challenger).unwrap() - }); - - // Verify the proof. - #[cfg(feature = "debug")] - self.compress_prover - .machine() - .verify( - &vk, - &zkm_stark::MachineProof { - shard_proofs: vec![proof.clone()], - }, - &mut self.compress_prover.config().challenger(), - ) - .unwrap(); - - // Wait for our turn to update the state. - prover_sync.wait_for_turn(index); - - // Send the proof. - proofs_tx.lock().unwrap().send((index, height, vk, proof)).unwrap(); - - // Advance the turn. - prover_sync.advance_turn(); - }); - } else { - break; - } - } - }); - prover_handles.push(handle); - } - - // Spawn a worker that generates inputs for the next layer. - let handle = { - let input_tx = Arc::clone(&input_tx); - let proofs_rx = Arc::clone(&proofs_rx); - let span = tracing::debug_span!("generate next layer inputs"); - s.spawn(move || { - let _span = span.enter(); - let mut count = num_first_layer_inputs; - let mut batch: Vec<( - usize, - usize, - StarkVerifyingKey, - ShardProof, - )> = Vec::new(); - loop { - if expected_height == 0 { - break; - } - let received = { proofs_rx.lock().unwrap().recv() }; - if let Ok((index, height, vk, proof)) = received { - batch.push((index, height, vk, proof)); - - // If we haven't reached the batch size, continue. - if batch.len() < batch_size { - continue; - } - - // Compute whether we're at the last input of a layer. - let mut is_last = false; - if let Some(first) = batch.first() { - is_last = first.1 != height; - } - - // If we're at the last input of a layer, we need to only include the - // first input, otherwise we include all inputs. - let inputs = - if is_last { vec![batch[0].clone()] } else { batch.clone() }; - - let next_input_height = inputs[0].1 + 1; - - let is_complete = next_input_height == expected_height; - - let vks_and_proofs = inputs - .into_iter() - .map(|(_, _, vk, proof)| (vk, proof)) - .collect::>(); - let input = ZKMCircuitWitness::Compress(ZKMCompressWitnessValues { - vks_and_proofs, - is_complete, - }); - - input_sync.wait_for_turn(count); - input_tx - .lock() - .unwrap() - .send((count, next_input_height, input)) - .unwrap(); - input_sync.advance_turn(); - count += 1; - - // If we're at the root of the tree, stop generating inputs. - if is_complete { - break; - } - - // If we were at the last input of a layer, we keep everything but the - // first input. Otherwise, we empty the batch. - if is_last { - batch = vec![batch[1].clone()]; - } else { - batch = Vec::new(); - } - } else { - break; - } - } - }) - }; - - // Wait for all the provers to finish. - drop(input_tx); - drop(record_and_trace_tx); - drop(proofs_tx); - for handle in prover_handles { - handle.join().unwrap(); - } - handle.join().unwrap(); - - let (_, _, vk, proof) = proofs_rx.lock().unwrap().recv().unwrap(); - (vk, proof) - }); - - Ok(ZKMReduceProof { vk, proof }) } + true +} - /// Wrap a reduce proof into a STARK proven over a SNARK-friendly field. - #[instrument(name = "shrink", level = "info", skip_all)] - pub fn shrink( - &self, - reduced_proof: ZKMReduceProof, - opts: ZKMProverOpts, - ) -> Result, ZKMRecursionProverError> { - // Make the compress proof. - let ZKMReduceProof { vk: compressed_vk, proof: compressed_proof } = reduced_proof; - let input = ZKMCompressWitnessValues { - vks_and_proofs: vec![(compressed_vk, compressed_proof)], - is_complete: true, - }; - - let input_with_merkle = self.make_merkle_proofs(input); - - let program = - self.shrink_program(ShrinkAir::::shrink_shape(), &input_with_merkle); - - // Run the compress program. - let mut runtime = RecursionRuntime::, Challenge, _>::new( - program.clone(), - self.shrink_prover.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - Witnessable::::write(&input_with_merkle, &mut witness_stream); - - runtime.witness_stream = witness_stream.into(); - - runtime.run().map_err(|e| ZKMRecursionProverError::RuntimeError(e.to_string()))?; - - runtime.print_stats(); - tracing::debug!("Shrink program executed successfully"); - - let (shrink_pk, shrink_vk) = - tracing::debug_span!("setup shrink").in_scope(|| self.shrink_prover.setup(&program)); - - // Prove the compress program. - let mut compress_challenger = self.shrink_prover.config().challenger(); - let mut compress_proof = self - .shrink_prover - .prove(&shrink_pk, vec![runtime.record], &mut compress_challenger, opts.recursion_opts) - .unwrap(); +/// Get the committed values Bn Poseidon2 digest this reduce proof is representing. +pub fn zkm_committed_values_digest_bn254( + proof: &ZKMReduceProof, +) -> Bn254Fr { + let proof = &proof.proof; + let pv: &RecursionPublicValues = proof.public_values.as_slice().borrow(); + let committed_values_digest_bytes: [KoalaBear; 32] = + words_to_bytes(&pv.committed_value_digest).try_into().unwrap(); + koalabear_bytes_to_bn254(&committed_values_digest_bytes) +} - Ok(ZKMReduceProof { vk: shrink_vk, proof: compress_proof.shard_proofs.pop().unwrap() }) +impl ZKMCoreProofData { + pub fn save(&self, path: &str) -> Result<(), std::io::Error> { + let data = serde_json::to_string(self).unwrap(); + fs::write(path, data).unwrap(); + Ok(()) } +} - /// Wrap a reduce proof into a STARK proven over a SNARK-friendly field. - #[instrument(name = "wrap_bn254", level = "info", skip_all)] - pub fn wrap_bn254( - &self, - compressed_proof: ZKMReduceProof, - opts: ZKMProverOpts, - ) -> Result, ZKMRecursionProverError> { - let ZKMReduceProof { vk: compressed_vk, proof: compressed_proof } = compressed_proof; - let input = ZKMCompressWitnessValues { - vks_and_proofs: vec![(compressed_vk, compressed_proof)], - is_complete: true, - }; - let input_with_vk = self.make_merkle_proofs(input); - - let program = self.wrap_program(); - - // Run the compress program. - let mut runtime = RecursionRuntime::, Challenge, _>::new( - program.clone(), - self.shrink_prover.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - Witnessable::::write(&input_with_vk, &mut witness_stream); - - runtime.witness_stream = witness_stream.into(); +/// Get the number of cycles for a given program. +pub fn get_cycles(elf: &[u8], stdin: &ZKMStdin) -> u64 { + let program = Program::from(elf).unwrap(); + let mut runtime = Executor::new(program, ZKMCoreOpts::default()); + runtime.write_vecs(&stdin.buffer); + runtime.run_fast().unwrap(); + runtime.state.global_clk +} - runtime.run().map_err(|e| ZKMRecursionProverError::RuntimeError(e.to_string()))?; +/// Load an ELF file from a given path. +pub fn load_elf(path: &str) -> Result, std::io::Error> { + let mut elf_code = Vec::new(); + File::open(path)?.read_to_end(&mut elf_code)?; + Ok(elf_code) +} - runtime.print_stats(); - tracing::debug!("wrap program executed successfully"); +pub fn words_to_bytes(words: &[Word]) -> Vec { + words.iter().flat_map(|word| word.0).collect() +} - // Setup the wrap program. - let (wrap_pk, wrap_vk) = - tracing::debug_span!("setup wrap").in_scope(|| self.wrap_prover.setup(&program)); +/// Convert 8 KoalaBear words into a Bn254Fr field element by shifting by 31 bits each time. The last +/// word becomes the least significant bits. +pub fn koalabears_to_bn254(digest: &[KoalaBear; 8]) -> Bn254Fr { + let mut result = Bn254Fr::ZERO; + for word in digest.iter() { + // Since KoalaBear prime is less than 2^31, we can shift by 31 bits each time and still be + // within the Bn254Fr field, so we don't have to truncate the top 3 bits. + result *= Bn254Fr::from_canonical_u64(1 << 31); + result += Bn254Fr::from_canonical_u32(word.as_canonical_u32()); + } + result +} - if self.wrap_vk.set(wrap_vk.clone()).is_ok() { - tracing::debug!("wrap verifier key set"); +/// Convert 32 KoalaBear bytes into a Bn254Fr field element. The first byte's most significant 3 bits +/// (which would become the 3 most significant bits) are truncated. +pub fn koalabear_bytes_to_bn254(bytes: &[KoalaBear; 32]) -> Bn254Fr { + let mut result = Bn254Fr::ZERO; + for (i, byte) in bytes.iter().enumerate() { + debug_assert!(byte < &KoalaBear::from_canonical_u32(256)); + if i == 0 { + // 32 bytes is more than Bn254 prime, so we need to truncate the top 3 bits. + result = Bn254Fr::from_canonical_u32(byte.as_canonical_u32() & 0x1f); + } else { + result *= Bn254Fr::from_canonical_u32(256); + result += Bn254Fr::from_canonical_u32(byte.as_canonical_u32()); } - - // Prove the wrap program. - let mut wrap_challenger = self.wrap_prover.config().challenger(); - let time = std::time::Instant::now(); - let mut wrap_proof = self - .wrap_prover - .prove(&wrap_pk, vec![runtime.record], &mut wrap_challenger, opts.recursion_opts) - .unwrap(); - let elapsed = time.elapsed(); - tracing::debug!("wrap proving time: {:?}", elapsed); - let mut wrap_challenger = self.wrap_prover.config().challenger(); - self.wrap_prover.machine().verify(&wrap_vk, &wrap_proof, &mut wrap_challenger).unwrap(); - tracing::info!("wrapping successful"); - - Ok(ZKMReduceProof { vk: wrap_vk, proof: wrap_proof.shard_proofs.pop().unwrap() }) } + result +} - /// Wrap the STARK proven over a SNARK-friendly field into a PLONK proof. - #[instrument(name = "wrap_plonk_bn254", level = "info", skip_all)] - pub fn wrap_plonk_bn254( - &self, - proof: ZKMReduceProof, - build_dir: &Path, - ) -> PlonkBn254Proof { - let input = ZKMCompressWitnessValues { - vks_and_proofs: vec![(proof.vk.clone(), proof.proof.clone())], - is_complete: true, - }; - let vkey_hash = zkm_vkey_digest_bn254(&proof); - let committed_values_digest = zkm_committed_values_digest_bn254(&proof); - - let mut witness = Witness::default(); - input.write(&mut witness); - witness.write_committed_values_digest(committed_values_digest); - witness.write_vkey_hash(vkey_hash); - - let prover = PlonkBn254Prover::new(); - let proof = prover.prove(witness, build_dir.to_path_buf()); - - // Verify the proof. - prover - .verify( - &proof, - &vkey_hash.as_canonical_biguint(), - &committed_values_digest.as_canonical_biguint(), - build_dir, - ) - .unwrap(); - - proof - } - - /// Wrap the STARK proven over a SNARK-friendly field into a Groth16 proof. - #[instrument(name = "wrap_groth16_bn254", level = "info", skip_all)] - pub fn wrap_groth16_bn254( - &self, - proof: ZKMReduceProof, - build_dir: &Path, - ) -> Groth16Bn254Proof { - let input = ZKMCompressWitnessValues { - vks_and_proofs: vec![(proof.vk.clone(), proof.proof.clone())], - is_complete: true, - }; - let vkey_hash = zkm_vkey_digest_bn254(&proof); - let committed_values_digest = zkm_committed_values_digest_bn254(&proof); - - let mut witness = Witness::default(); - input.write(&mut witness); - witness.write_committed_values_digest(committed_values_digest); - witness.write_vkey_hash(vkey_hash); - - let prover = Groth16Bn254Prover::new(); - let proof = prover.prove(witness, build_dir.to_path_buf()); - - // Verify the proof. - prover - .verify( - &proof, - &vkey_hash.as_canonical_biguint(), - &committed_values_digest.as_canonical_biguint(), - build_dir, - ) - .unwrap(); - - proof +/// Utility method for converting u32 words to bytes in big endian. +pub fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] { + let mut bytes = [0u8; 32]; + for i in 0..8 { + let word_bytes = words[i].to_be_bytes(); + bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes); } + bytes +} - /// Accumulate deferred proofs into a single digest. - pub fn hash_deferred_proofs( - prev_digest: [Val; DIGEST_SIZE], - deferred_proofs: &[ZKMReduceProof], - ) -> [Val; 8] { - let mut digest = prev_digest; - for proof in deferred_proofs.iter() { - let pv: &RecursionPublicValues> = - proof.proof.public_values.as_slice().borrow(); - let committed_values_digest = words_to_bytes(&pv.committed_value_digest); - digest = hash_deferred_proof( - &digest, - &pv.zkm_vk_digest, - &committed_values_digest.try_into().unwrap(), - ); +pub trait MaybeTakeIterator: Iterator { + fn maybe_skip(self, bound: Option) -> RangedIterator + where + Self: Sized, + { + match bound { + Some(bound) => RangedIterator::Skip(self.skip(bound)), + None => RangedIterator::Unbounded(self), } - digest - } - - pub fn make_merkle_proofs( - &self, - input: ZKMCompressWitnessValues, - ) -> ZKMCompressWithVKeyWitnessValues { - let num_vks = self.recursion_vk_map.len(); - let (vk_indices, vk_digest_values): (Vec<_>, Vec<_>) = if self.vk_verification { - input - .vks_and_proofs - .iter() - .map(|(vk, _)| { - let vk_digest = vk.hash_koalabear(); - let index = self.recursion_vk_map.get(&vk_digest).expect("vk not allowed"); - (index, vk_digest) - }) - .unzip() - } else { - input - .vks_and_proofs - .iter() - .map(|(vk, _)| { - let vk_digest = vk.hash_koalabear(); - let index = (vk_digest[0].as_canonical_u32() as usize) % num_vks; - (index, [KoalaBear::from_canonical_usize(index); 8]) - }) - .unzip() - }; - - let proofs = vk_indices - .iter() - .map(|index| { - let (_, proof) = MerkleTree::open(&self.recursion_vk_tree, *index); - proof - }) - .collect(); - - let merkle_val = ZKMMerkleProofWitnessValues { - root: self.recursion_vk_root, - values: vk_digest_values, - vk_merkle_proofs: proofs, - }; - - ZKMCompressWithVKeyWitnessValues { compress_val: input, merkle_val } } - fn check_for_high_cycles(cycles: u64) { - if cycles > 100_000_000 { - tracing::warn!( - "high cycle count, consider using the prover network for proof generation" - ); + fn maybe_take(self, bound: Option) -> RangedIterator + where + Self: Sized, + { + match bound { + Some(bound) => RangedIterator::Take(self.take(bound)), + None => RangedIterator::Unbounded(self), } } } -pub fn compress_program_from_input( - config: Option<&RecursionShapeConfig>>, - compress_prover: &C::CompressProver, - vk_verification: bool, - input: &ZKMCompressWithVKeyWitnessValues, -) -> RecursionProgram { - let builder_span = tracing::debug_span!("build compress program").entered(); - let mut builder = Builder::::default(); - // read the input. - let input = input.read(&mut builder); - // Verify the proof. - ZKMCompressWithVKeyVerifier::verify( - &mut builder, - compress_prover.machine(), - input, - vk_verification, - PublicValuesOutputDigest::Reduce, - ); - let operations = builder.into_operations(); - builder_span.exit(); - - // Compile the program. - let compiler_span = tracing::debug_span!("compile compress program").entered(); - let mut compiler = AsmCompiler::::default(); - let mut program = compiler.compile(operations); - if let Some(config) = config { - config.fix_shape(&mut program); - } - compiler_span.exit(); +impl MaybeTakeIterator for I {} - program +pub enum RangedIterator { + Unbounded(I), + Skip(Skip), + Take(Take), + Range(Take>), } -#[cfg(test)] -pub mod tests { - use std::{ - collections::BTreeSet, - fs::File, - io::{Read, Write}, - }; - - use super::*; - - use crate::build::try_build_plonk_bn254_artifacts_dev; - use anyhow::Result; - use build::{build_constraints_and_witness, try_build_groth16_bn254_artifacts_dev}; - use p3_field::PrimeField32; - - use shapes::ZKMProofShape; - use zkm_recursion_core::air::RecursionPublicValues; - - #[cfg(test)] - use serial_test::serial; - use utils::zkm_vkey_digest_koalabear; - #[cfg(test)] - use zkm_core_machine::utils::setup_logger; - - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub enum Test { - Core, - Compress, - Shrink, - Wrap, - CircuitTest, - All, - } - - pub fn test_e2e_prover( - prover: &ZKMProver, - elf: &[u8], - stdin: ZKMStdin, - opts: ZKMProverOpts, - test_kind: Test, - ) -> Result<()> { - run_e2e_prover_with_options(prover, elf, stdin, opts, test_kind, true) - } - - pub fn bench_e2e_prover( - prover: &ZKMProver, - elf: &[u8], - stdin: ZKMStdin, - opts: ZKMProverOpts, - test_kind: Test, - ) -> Result<()> { - run_e2e_prover_with_options(prover, elf, stdin, opts, test_kind, false) - } - - pub fn run_e2e_prover_with_options( - prover: &ZKMProver, - elf: &[u8], - stdin: ZKMStdin, - opts: ZKMProverOpts, - test_kind: Test, - verify: bool, - ) -> Result<()> { - tracing::info!("initializing prover"); - let context = ZKMContext::default(); - - tracing::info!("setup elf"); - let (_, pk_d, program, vk) = prover.setup(elf); +impl Iterator for RangedIterator { + type Item = I::Item; - tracing::info!("prove core"); - let core_proof = prover.prove_core(&pk_d, program, &stdin, opts, context)?; - let public_values = core_proof.public_values.clone(); - - if env::var("COLLECT_SHAPES").is_ok() { - let mut shapes = BTreeSet::new(); - for proof in core_proof.proof.0.iter() { - let shape = ZKMProofShape::Recursion(proof.shape()); - tracing::info!("shape: {:?}", shape); - shapes.insert(shape); - } - - let mut file = File::create("../shapes.bin").unwrap(); - bincode::serialize_into(&mut file, &shapes).unwrap(); - } - - if verify { - tracing::info!("verify core"); - prover.verify(&core_proof.proof, &vk)?; - } - - if test_kind == Test::Core { - return Ok(()); - } - - tracing::info!("compress"); - let compress_span = tracing::debug_span!("compress").entered(); - let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; - compress_span.exit(); - - if verify { - tracing::info!("verify compressed"); - prover.verify_compressed(&compressed_proof, &vk)?; - } - - if test_kind == Test::Compress { - return Ok(()); - } - - tracing::info!("shrink"); - let shrink_proof = prover.shrink(compressed_proof, opts)?; - - if verify { - tracing::info!("verify shrink"); - prover.verify_shrink(&shrink_proof, &vk)?; - } - - if test_kind == Test::Shrink { - return Ok(()); + fn next(&mut self) -> Option { + match self { + RangedIterator::Unbounded(unbounded) => unbounded.next(), + RangedIterator::Skip(skip) => skip.next(), + RangedIterator::Take(take) => take.next(), + RangedIterator::Range(range) => range.next(), } - - tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); - - // Save the proof. - let mut file = File::create("proof-with-pis.bin").unwrap(); - file.write_all(bytes.as_slice()).unwrap(); - - // Load the proof. - let mut file = File::open("proof-with-pis.bin").unwrap(); - let mut bytes = Vec::new(); - file.read_to_end(&mut bytes).unwrap(); - - let wrapped_bn254_proof = bincode::deserialize(&bytes).unwrap(); - - if verify { - tracing::info!("verify wrap bn254"); - prover.verify_wrap_bn254(&wrapped_bn254_proof, &vk).unwrap(); - } - - if test_kind == Test::Wrap { - return Ok(()); - } - - tracing::info!("checking vkey hash koalabear"); - let vk_digest_koalabear = zkm_vkey_digest_koalabear(&wrapped_bn254_proof); - assert_eq!(vk_digest_koalabear, vk.hash_koalabear()); - - tracing::info!("checking vkey hash bn254"); - let vk_digest_bn254 = zkm_vkey_digest_bn254(&wrapped_bn254_proof); - assert_eq!(vk_digest_bn254, vk.hash_bn254()); - - tracing::info!("Test the outer Plonk circuit"); - let (constraints, witness) = - build_constraints_and_witness(&wrapped_bn254_proof.vk, &wrapped_bn254_proof.proof); - PlonkBn254Prover::test(constraints, witness); - tracing::info!("Circuit test succeeded"); - - if test_kind == Test::CircuitTest { - return Ok(()); - } - - tracing::info!("generate plonk bn254 proof"); - let artifacts_dir = try_build_plonk_bn254_artifacts_dev( - &wrapped_bn254_proof.vk, - &wrapped_bn254_proof.proof, - ); - let plonk_bn254_proof = - prover.wrap_plonk_bn254(wrapped_bn254_proof.clone(), &artifacts_dir); - println!("{plonk_bn254_proof:?}"); - - prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; - - tracing::info!("generate groth16 bn254 proof"); - let artifacts_dir = try_build_groth16_bn254_artifacts_dev( - &wrapped_bn254_proof.vk, - &wrapped_bn254_proof.proof, - ); - let groth16_bn254_proof = prover.wrap_groth16_bn254(wrapped_bn254_proof, &artifacts_dir); - println!("{groth16_bn254_proof:?}"); - - if verify { - prover.verify_groth16_bn254( - &groth16_bn254_proof, - &vk, - &public_values, - &artifacts_dir, - )?; - } - - Ok(()) - } - - pub fn test_e2e_with_deferred_proofs_prover( - opts: ZKMProverOpts, - ) -> Result<()> { - // Test program which proves the Keccak-256 hash of various inputs. - let keccak_elf = test_artifacts::KECCAK_SPONGE_ELF; - - // Test program which verifies proofs of a vkey and a list of committed inputs. - let verify_elf = test_artifacts::VERIFY_PROOF_ELF; - - tracing::info!("initializing prover"); - let prover = ZKMProver::::new(); - - tracing::info!("setup keccak elf"); - let (_, keccak_pk_d, keccak_program, keccak_vk) = prover.setup(keccak_elf); - - tracing::info!("setup verify elf"); - let (_, verify_pk_d, verify_program, verify_vk) = prover.setup(verify_elf); - - tracing::info!("prove subproof 1"); - let mut stdin = ZKMStdin::new(); - stdin.write(&1usize); - stdin.write(&vec![0u8, 0, 0]); - let deferred_proof_1 = prover.prove_core( - &keccak_pk_d, - keccak_program.clone(), - &stdin, - opts, - Default::default(), - )?; - let pv_1 = deferred_proof_1.public_values.as_slice().to_vec().clone(); - - // Generate a second proof of keccak of various inputs. - tracing::info!("prove subproof 2"); - let mut stdin = ZKMStdin::new(); - stdin.write(&3usize); - stdin.write(&vec![0u8, 1, 2]); - stdin.write(&vec![2, 3, 4]); - stdin.write(&vec![5, 6, 7]); - let deferred_proof_2 = - prover.prove_core(&keccak_pk_d, keccak_program, &stdin, opts, Default::default())?; - let pv_2 = deferred_proof_2.public_values.as_slice().to_vec().clone(); - - // Generate recursive proof of first subproof. - tracing::info!("compress subproof 1"); - let deferred_reduce_1 = prover.compress(&keccak_vk, deferred_proof_1, vec![], opts)?; - - // Generate recursive proof of second subproof. - tracing::info!("compress subproof 2"); - let deferred_reduce_2 = prover.compress(&keccak_vk, deferred_proof_2, vec![], opts)?; - - // Run verify program with keccak vkey, subproofs, and their committed values. - let mut stdin = ZKMStdin::new(); - let vkey_digest = keccak_vk.hash_koalabear(); - let vkey_digest: [u32; 8] = vkey_digest - .iter() - .map(|n| n.as_canonical_u32()) - .collect::>() - .try_into() - .unwrap(); - stdin.write(&vkey_digest); - stdin.write(&vec![pv_1.clone(), pv_2.clone(), pv_2.clone()]); - stdin.write_proof(deferred_reduce_1.clone(), keccak_vk.vk.clone()); - stdin.write_proof(deferred_reduce_2.clone(), keccak_vk.vk.clone()); - stdin.write_proof(deferred_reduce_2.clone(), keccak_vk.vk.clone()); - - tracing::info!("proving verify program (core)"); - let verify_proof = - prover.prove_core(&verify_pk_d, verify_program, &stdin, opts, Default::default())?; - // let public_values = verify_proof.public_values.clone(); - - // Generate recursive proof of verify program - tracing::info!("compress verify program"); - let verify_reduce = prover.compress( - &verify_vk, - verify_proof, - vec![deferred_reduce_1, deferred_reduce_2.clone(), deferred_reduce_2], - opts, - )?; - let reduce_pv: &RecursionPublicValues<_> = - verify_reduce.proof.public_values.as_slice().borrow(); - println!("deferred_hash: {:?}", reduce_pv.deferred_proofs_digest); - println!("complete: {:?}", reduce_pv.is_complete); - - tracing::info!("verify verify program"); - prover.verify_compressed(&verify_reduce, &verify_vk)?; - - let shrink_proof = prover.shrink(verify_reduce, opts)?; - - tracing::info!("verify shrink"); - prover.verify_shrink(&shrink_proof, &verify_vk)?; - - tracing::info!("wrap bn254"); - let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; - - tracing::info!("verify wrap bn254"); - println!("verify wrap bn254 {:#?}", wrapped_bn254_proof.vk.commit); - prover.verify_wrap_bn254(&wrapped_bn254_proof, &verify_vk).unwrap(); - - Ok(()) - } - - /// Tests an end-to-end workflow of proving a program across the entire proof generation - /// pipeline. - /// - /// Add `FRI_QUERIES`=1 to your environment for faster execution. Should only take a few minutes - /// on a Mac M2. Note: This test always re-builds the plonk bn254 artifacts, so setting ZKM_DEV - /// is not needed. - #[test] - #[serial] - #[ignore] - fn test_e2e() -> Result<()> { - let elf = test_artifacts::FIBONACCI_ELF; - setup_logger(); - let opts = ZKMProverOpts::default(); - // TODO(mattstam): We should Test::Plonk here, but this uses the existing - // docker image which has a different API than the current. So we need to wait until the - // next release (v1.2.0+), and then switch it back. - let prover = ZKMProver::::new(); - test_e2e_prover::( - &prover, - elf, - ZKMStdin::default(), - opts, - Test::All, - ) - } - - /// Tests an end-to-end workflow of proving a program across the entire proof generation - /// pipeline. - /// - /// Add `FRI_QUERIES`=1 to your environment for faster execution. Should only take a few minutes - /// on a Mac M2. Note: This test always re-builds the plonk bn254 artifacts, so setting ZKM_DEV - /// is not needed. - #[test] - #[serial] - #[ignore] - fn test_e2e_hello_world() -> Result<()> { - let elf = test_artifacts::HELLO_WORLD_ELF; - - setup_logger(); - let opts = ZKMProverOpts::default(); - // TODO(mattstam): We should Test::Plonk here, but this uses the existing - // docker image which has a different API than the current. So we need to wait until the - // next release (v1.2.0+), and then switch it back. - let prover = ZKMProver::::new(); - test_e2e_prover::( - &prover, - elf, - ZKMStdin::default(), - opts, - Test::All, - ) - } - - /// Tests an end-to-end workflow of proving a program across the entire proof generation - /// pipeline in addition to verifying deferred proofs. - #[test] - #[serial] - #[ignore] - fn test_e2e_with_deferred_proofs() -> Result<()> { - setup_logger(); - test_e2e_with_deferred_proofs_prover::(ZKMProverOpts::default()) } } From 9e5f057f49b8c10ce533a07304e103d2bdc85836 Mon Sep 17 00:00:00 2001 From: SashaMalysehko Date: Mon, 19 Jan 2026 22:50:30 +0200 Subject: [PATCH 2/3] Update utils.rs --- crates/prover/src/utils.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/crates/prover/src/utils.rs b/crates/prover/src/utils.rs index 19793279..d4a3f7dd 100644 --- a/crates/prover/src/utils.rs +++ b/crates/prover/src/utils.rs @@ -50,14 +50,12 @@ pub fn root_public_values_digest( public_values: &RootPublicValues, ) -> [KoalaBear; 8] { let hash = InnerHash::new(config.perm.clone()); - let input = (*public_values.zkm_vk_digest()) - .into_iter() - .chain( - (*public_values.committed_value_digest()) - .into_iter() - .flat_map(|word| word.0.into_iter()), - ) - .collect::>(); + let mut input = [KoalaBear::ZERO; 40]; + input[..8].copy_from_slice(public_values.zkm_vk_digest()); + for (i, word) in public_values.committed_value_digest().iter().enumerate() { + let start = 8 + i * 4; + input[start..start + 4].copy_from_slice(&word.0); + } hash.hash_slice(&input) } From cdd040a9e52f2f59dcd35a8adeb25f0788d96d2f Mon Sep 17 00:00:00 2001 From: SashaMalysehko Date: Mon, 19 Jan 2026 22:51:12 +0200 Subject: [PATCH 3/3] fix copypaste mistake --- crates/prover/src/lib.rs | 1706 ++++++++++++++++++++++++++++++++++---- 1 file changed, 1543 insertions(+), 163 deletions(-) diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index d4a3f7dd..d20e5b53 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,210 +1,1590 @@ +//! An end-to-end-prover implementation for the Ziren zkVM. +//! +//! Separates the proof generation process into multiple stages: +//! +//! 1. Generate shard proofs which split up and prove the valid execution of a MIPS program. +//! 2. Compress shard proofs into a single shard proof. +//! 3. Wrap the shard proof into a SNARK-friendly field. +//! 4. Wrap the last shard proof, proven over the SNARK-friendly field, into a PLONK proof. + +#![allow(clippy::too_many_arguments)] +#![allow(clippy::new_without_default)] +#![allow(clippy::collapsible_else_if)] + +pub mod build; +pub mod components; +pub mod shapes; +pub mod types; +pub mod utils; +pub mod verify; + use std::{ borrow::Borrow, - fs::{self, File}, - io::Read, - iter::{Skip, Take}, + collections::BTreeMap, + env, + num::NonZeroUsize, + path::Path, + sync::{ + atomic::{AtomicUsize, Ordering}, + mpsc::sync_channel, + Arc, Mutex, OnceLock, + }, + thread, }; -use itertools::Itertools; -use p3_bn254_fr::Bn254Fr; -use p3_field::{FieldAlgebra, PrimeField32}; +use lru::LruCache; +use p3_field::{FieldAlgebra, PrimeField, PrimeField32}; use p3_koala_bear::KoalaBear; -use p3_symmetric::CryptographicHasher; -use zkm_core_executor::{Executor, Program}; -use zkm_core_machine::{io::ZKMStdin, reduce::ZKMReduceProof}; -use zkm_recursion_circuit::machine::RootPublicValues; +use p3_matrix::dense::RowMajorMatrix; +use shapes::ZKMProofShape; +use tracing::instrument; +use zkm_core_executor::{ExecutionError, ExecutionReport, Executor, Program, ZKMContext}; +use zkm_core_machine::{ + io::ZKMStdin, + mips::MipsAir, + reduce::ZKMReduceProof, + shape::CoreShapeConfig, + utils::{concurrency::TurnBasedSync, ZKMCoreProverError}, +}; +use zkm_primitives::{hash_deferred_proof, io::ZKMPublicValues}; +use zkm_recursion_circuit::{ + hash::FieldHasher, + machine::{ + PublicValuesOutputDigest, ZKMCompressRootVerifierWithVKey, ZKMCompressShape, + ZKMCompressWithVKeyVerifier, ZKMCompressWithVKeyWitnessValues, ZKMCompressWithVkeyShape, + ZKMCompressWitnessValues, ZKMDeferredVerifier, ZKMDeferredWitnessValues, + ZKMMerkleProofWitnessValues, ZKMRecursionShape, ZKMRecursionWitnessValues, + ZKMRecursiveVerifier, + }, + merkle_tree::MerkleTree, + witness::Witnessable, + WrapConfig, +}; +use zkm_recursion_compiler::{ + circuit::AsmCompiler, + config::InnerConfig, + ir::{Builder, Witness}, +}; use zkm_recursion_core::{ - air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}, + air::RecursionPublicValues, + machine::RecursionAir, + runtime::ExecutionRecord, + shape::{RecursionShape, RecursionShapeConfig}, stark::KoalaBearPoseidon2Outer, + RecursionProgram, Runtime as RecursionRuntime, }; -use zkm_stark::{koala_bear_poseidon2::MyHash as InnerHash, Word, ZKMCoreOpts}; +pub use zkm_recursion_gnark_ffi::proof::{Groth16Bn254Proof, PlonkBn254Proof}; +use zkm_recursion_gnark_ffi::{groth16_bn254::Groth16Bn254Prover, plonk_bn254::PlonkBn254Prover}; +use zkm_stark::{ + air::PublicValues, koala_bear_poseidon2::KoalaBearPoseidon2, Challenge, MachineProver, + ShardProof, StarkGenericConfig, StarkVerifyingKey, Val, Word, ZKMCoreOpts, ZKMProverOpts, + DIGEST_SIZE, +}; +use zkm_stark::{shape::OrderedShape, MachineProvingKey}; -use crate::{InnerSC, ZKMCoreProofData}; +pub use types::*; +use utils::{words_to_bytes, zkm_committed_values_digest_bn254, zkm_vkey_digest_bn254}; -/// Get the Ziren vkey KoalaBear Poseidon2 digest this reduce proof is representing. -pub fn zkm_vkey_digest_koalabear( - proof: &ZKMReduceProof, -) -> [KoalaBear; 8] { - let proof = &proof.proof; - let pv: &RecursionPublicValues = proof.public_values.as_slice().borrow(); - pv.zkm_vk_digest -} +use components::{DefaultProverComponents, ZKMProverComponents}; -/// Get the Ziren vkey Bn Poseidon2 digest this reduce proof is representing. -pub fn zkm_vkey_digest_bn254(proof: &ZKMReduceProof) -> Bn254Fr { - koalabears_to_bn254(&zkm_vkey_digest_koalabear(proof)) -} +pub use zkm_core_machine::ZKM_CIRCUIT_VERSION; -/// Compute the digest of the public values. -pub fn recursion_public_values_digest( - config: &InnerSC, - public_values: &RecursionPublicValues, -) -> [KoalaBear; 8] { - let hash = InnerHash::new(config.perm.clone()); - let pv_array = public_values.as_array(); - hash.hash_slice(&pv_array[0..NUM_PV_ELMS_TO_HASH]) -} +/// The configuration for the core prover. +pub type CoreSC = KoalaBearPoseidon2; + +/// The configuration for the inner prover. +pub type InnerSC = KoalaBearPoseidon2; + +/// The configuration for the outer prover. +pub type OuterSC = KoalaBearPoseidon2Outer; + +pub type DeviceProvingKey = <::CoreProver as MachineProver< + KoalaBearPoseidon2, + MipsAir, +>>::DeviceProvingKey; + +const COMPRESS_DEGREE: usize = 3; +const SHRINK_DEGREE: usize = 3; +const WRAP_DEGREE: usize = 9; + +const CORE_CACHE_SIZE: usize = 5; +pub const REDUCE_BATCH_SIZE: usize = 2; + +// TODO: FIX +// +// const SHAPES_URL_PREFIX: &str = "https://zkm-circuits.s3.us-east-2.amazonaws.com/shapes"; +// const SHAPES_VERSION: &str = "146079e0e"; +// lazy_static! { +// static ref SHAPES_INIT: Once = Once::new(); +// } + +pub type CompressAir = RecursionAir; +pub type ShrinkAir = RecursionAir; +pub type WrapAir = RecursionAir; + +/// An end-to-end prover implementation for the Ziren zkVM. +pub struct ZKMProver { + /// The machine used for proving the core step. + pub core_prover: C::CoreProver, + + /// The machine used for proving the recursive and reduction steps. + pub compress_prover: C::CompressProver, + + /// The machine used for proving the shrink step. + pub shrink_prover: C::ShrinkProver, + + /// The machine used for proving the wrapping step. + pub wrap_prover: C::WrapProver, + + /// The cache of compiled recursion programs. + pub lift_programs_lru: Mutex>>>, + + /// The number of cache misses for recursion programs. + pub lift_cache_misses: AtomicUsize, + + /// The cache of compiled compression programs. + pub join_programs_map: BTreeMap>>, + + /// The number of cache misses for compression programs. + pub join_cache_misses: AtomicUsize, + + /// The root of the allowed recursion verification keys. + pub recursion_vk_root: >::Digest, + + /// The allowed VKs and their corresponding indices. + pub recursion_vk_map: BTreeMap<>::Digest, usize>, + + /// The Merkle tree for the allowed VKs. + pub recursion_vk_tree: MerkleTree, + + /// The core shape configuration. + pub core_shape_config: Option>, + + /// The recursion shape configuration. + pub compress_shape_config: Option>>, + + /// The program for wrapping. + pub wrap_program: OnceLock>>, + + /// The verifying key for wrapping. + pub wrap_vk: OnceLock>, -pub fn root_public_values_digest( - config: &InnerSC, - public_values: &RootPublicValues, -) -> [KoalaBear; 8] { - let hash = InnerHash::new(config.perm.clone()); - let mut input = [KoalaBear::ZERO; 40]; - input[..8].copy_from_slice(public_values.zkm_vk_digest()); - for (i, word) in public_values.committed_value_digest().iter().enumerate() { - let start = 8 + i * 4; - input[start..start + 4].copy_from_slice(&word.0); - } - hash.hash_slice(&input) + /// Whether to verify verification keys. + pub vk_verification: bool, } -pub fn is_root_public_values_valid( - config: &InnerSC, - public_values: &RootPublicValues, -) -> bool { - let expected_digest = root_public_values_digest(config, public_values); - for (value, expected) in public_values.digest().iter().copied().zip_eq(expected_digest) { - if value != expected { - return false; +impl ZKMProver { + /// Initializes a new [ZKMProver]. + #[instrument(name = "initialize prover", level = "debug", skip_all)] + pub fn new() -> Self { + Self::uninitialized() + } + + /// Creates a new [ZKMProver] with lazily initialized components. + pub fn uninitialized() -> Self { + // Initialize the provers. + let core_machine = MipsAir::machine(CoreSC::default()); + let core_prover = C::CoreProver::new(core_machine); + + let compress_machine = CompressAir::compress_machine(InnerSC::default()); + let compress_prover = C::CompressProver::new(compress_machine); + + // TODO: Put the correct shrink and wrap machines here. + let shrink_machine = ShrinkAir::shrink_machine(InnerSC::compressed()); + let shrink_prover = C::ShrinkProver::new(shrink_machine); + + let wrap_machine = WrapAir::wrap_machine(OuterSC::default()); + let wrap_prover = C::WrapProver::new(wrap_machine); + + let core_cache_size = NonZeroUsize::new( + env::var("PROVER_CORE_CACHE_SIZE") + .unwrap_or_else(|_| CORE_CACHE_SIZE.to_string()) + .parse() + .unwrap_or(CORE_CACHE_SIZE), + ) + .expect("PROVER_CORE_CACHE_SIZE must be a non-zero usize"); + + let core_shape_config = env::var("FIX_CORE_SHAPES") + .map(|v| v.eq_ignore_ascii_case("true")) + .unwrap_or(true) + .then_some(CoreShapeConfig::default()); + + let recursion_shape_config = env::var("FIX_RECURSION_SHAPES") + .map(|v| v.eq_ignore_ascii_case("true")) + .unwrap_or(true) + .then_some(RecursionShapeConfig::default()); + + let vk_verification = + env::var("VERIFY_VK").map(|v| v.eq_ignore_ascii_case("true")).unwrap_or(true); + + tracing::debug!("vk verification: {}", vk_verification); + + // Read the shapes from the shapes directory and deserialize them into memory. + let allowed_vk_map: BTreeMap<[KoalaBear; DIGEST_SIZE], usize> = if vk_verification { + // Regenerate the vk_map.bin when the Ziren circuit is updated. + // ``` + // cd Ziren + // cargo run -r --bin build_compress_vks -- --num-compiler-workers 32 --count-setup-workers 32 --build-dir crates/prover + // ``` + // It takes several days. + bincode::deserialize(include_bytes!("../vk_map.bin")).unwrap() + } else { + bincode::deserialize(include_bytes!("../dummy_vk_map.bin")).unwrap() + }; + + let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().copied().collect()); + + let mut compress_programs = BTreeMap::new(); + if let Some(config) = &recursion_shape_config { + ZKMProofShape::generate_compress_shapes(config, REDUCE_BATCH_SIZE).for_each(|shape| { + let compress_shape = ZKMCompressWithVkeyShape { + compress_shape: shape.into(), + merkle_tree_height: merkle_tree.height, + }; + let input = ZKMCompressWithVKeyWitnessValues::dummy( + compress_prover.machine(), + &compress_shape, + ); + let program = compress_program_from_input::( + recursion_shape_config.as_ref(), + &compress_prover, + vk_verification, + &input, + ); + let program = Arc::new(program); + compress_programs.insert(compress_shape, program); + }); + } + + Self { + core_prover, + compress_prover, + shrink_prover, + wrap_prover, + lift_programs_lru: Mutex::new(LruCache::new(core_cache_size)), + lift_cache_misses: AtomicUsize::new(0), + join_programs_map: compress_programs, + join_cache_misses: AtomicUsize::new(0), + recursion_vk_root: root, + recursion_vk_tree: merkle_tree, + recursion_vk_map: allowed_vk_map, + core_shape_config, + compress_shape_config: recursion_shape_config, + vk_verification, + wrap_program: OnceLock::new(), + wrap_vk: OnceLock::new(), } } - true -} -/// Check if the digest of the public values is correct. -pub fn is_recursion_public_values_valid( - config: &InnerSC, - public_values: &RecursionPublicValues, -) -> bool { - let expected_digest = recursion_public_values_digest(config, public_values); - for (value, expected) in public_values.digest.iter().copied().zip_eq(expected_digest) { - if value != expected { - return false; + /// Fully initializes the programs, proving keys, and verifying keys that are normally + /// lazily initialized. TODO: remove this. + pub fn initialize(&mut self) {} + + /// Creates a proving key and a verifying key for a given MIPS ELF. + #[instrument(name = "setup", level = "debug", skip_all)] + pub fn setup( + &self, + elf: &[u8], + ) -> (ZKMProvingKey, DeviceProvingKey, Program, ZKMVerifyingKey) { + let program = self.get_program(elf).unwrap(); + let (pk, vk) = self.core_prover.setup(&program); + let vk = ZKMVerifyingKey { vk }; + let pk = ZKMProvingKey { + pk: self.core_prover.pk_to_host(&pk), + elf: elf.to_vec(), + vk: vk.clone(), + }; + let pk_d = self.core_prover.pk_to_device(&pk.pk); + (pk, pk_d, program, vk) + } + + /// Get a program with an allowed preprocessed shape. + pub fn get_program(&self, elf: &[u8]) -> eyre::Result { + let mut program = Program::from(elf).unwrap(); + if let Some(core_shape_config) = &self.core_shape_config { + core_shape_config.fix_preprocessed_shape(&mut program)?; } + Ok(program) } - true -} -/// Get the committed values Bn Poseidon2 digest this reduce proof is representing. -pub fn zkm_committed_values_digest_bn254( - proof: &ZKMReduceProof, -) -> Bn254Fr { - let proof = &proof.proof; - let pv: &RecursionPublicValues = proof.public_values.as_slice().borrow(); - let committed_values_digest_bytes: [KoalaBear; 32] = - words_to_bytes(&pv.committed_value_digest).try_into().unwrap(); - koalabear_bytes_to_bn254(&committed_values_digest_bytes) -} + /// Generate a proof of a Ziren program with the specified inputs. + #[instrument(name = "execute", level = "info", skip_all)] + pub fn execute<'a>( + &'a self, + elf: &[u8], + stdin: &ZKMStdin, + mut context: ZKMContext<'a>, + ) -> Result<(ZKMPublicValues, ExecutionReport), ExecutionError> { + context.subproof_verifier = Some(self); + let program = self.get_program(elf).unwrap(); + let opts = ZKMCoreOpts::default(); + let mut runtime = Executor::with_context(program, opts, context); + runtime.write_vecs(&stdin.buffer); + for (proof, vkey) in stdin.proofs.iter() { + runtime.write_proof(proof.clone(), vkey.clone()); + } + runtime.run_fast()?; + Ok((ZKMPublicValues::from(&runtime.state.public_values_stream), runtime.report)) + } -impl ZKMCoreProofData { - pub fn save(&self, path: &str) -> Result<(), std::io::Error> { - let data = serde_json::to_string(self).unwrap(); - fs::write(path, data).unwrap(); - Ok(()) + /// Generate shard proofs which split up and prove the valid execution of a MIPS program with + /// the core prover. Uses the provided context. + #[instrument(name = "prove_core", level = "info", skip_all)] + pub fn prove_core<'a>( + &'a self, + pk_d: &<::CoreProver as MachineProver< + KoalaBearPoseidon2, + MipsAir, + >>::DeviceProvingKey, + program: Program, + stdin: &ZKMStdin, + opts: ZKMProverOpts, + mut context: ZKMContext<'a>, + ) -> Result { + context.subproof_verifier = Some(self); + let pk = pk_d; + let (proof, public_values_stream, cycles) = + zkm_core_machine::utils::prove_with_context::<_, C::CoreProver>( + &self.core_prover, + pk, + program, + stdin, + opts.core_opts, + context, + self.core_shape_config.as_ref(), + )?; + Self::check_for_high_cycles(cycles); + let public_values = ZKMPublicValues::from(&public_values_stream); + Ok(ZKMCoreProof { + proof: ZKMCoreProofData(proof.shard_proofs), + stdin: stdin.clone(), + public_values, + cycles, + }) } -} -/// Get the number of cycles for a given program. -pub fn get_cycles(elf: &[u8], stdin: &ZKMStdin) -> u64 { - let program = Program::from(elf).unwrap(); - let mut runtime = Executor::new(program, ZKMCoreOpts::default()); - runtime.write_vecs(&stdin.buffer); - runtime.run_fast().unwrap(); - runtime.state.global_clk -} + pub fn recursion_program( + &self, + input: &ZKMRecursionWitnessValues, + ) -> Arc> { + let mut cache = self.lift_programs_lru.lock().unwrap_or_else(|e| e.into_inner()); + cache + .get_or_insert(input.shape(), || { + let misses = self.lift_cache_misses.fetch_add(1, Ordering::Relaxed); + tracing::debug!("core cache miss, misses: {}", misses); + // Get the operations. + let builder_span = tracing::debug_span!("build recursion program").entered(); + let mut builder = Builder::::default(); -/// Load an ELF file from a given path. -pub fn load_elf(path: &str) -> Result, std::io::Error> { - let mut elf_code = Vec::new(); - File::open(path)?.read_to_end(&mut elf_code)?; - Ok(elf_code) -} + let input = input.read(&mut builder); + ZKMRecursiveVerifier::verify(&mut builder, self.core_prover.machine(), input); + let operations = builder.into_operations(); + builder_span.exit(); -pub fn words_to_bytes(words: &[Word]) -> Vec { - words.iter().flat_map(|word| word.0).collect() -} + // Compile the program. + let compiler_span = tracing::debug_span!("compile recursion program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + if let Some(recursion_shape_config) = &self.compress_shape_config { + recursion_shape_config.fix_shape(&mut program); + } + let program = Arc::new(program); + compiler_span.exit(); + program + }) + .clone() + } -/// Convert 8 KoalaBear words into a Bn254Fr field element by shifting by 31 bits each time. The last -/// word becomes the least significant bits. -pub fn koalabears_to_bn254(digest: &[KoalaBear; 8]) -> Bn254Fr { - let mut result = Bn254Fr::ZERO; - for word in digest.iter() { - // Since KoalaBear prime is less than 2^31, we can shift by 31 bits each time and still be - // within the Bn254Fr field, so we don't have to truncate the top 3 bits. - result *= Bn254Fr::from_canonical_u64(1 << 31); - result += Bn254Fr::from_canonical_u32(word.as_canonical_u32()); - } - result -} + pub fn compress_program( + &self, + input: &ZKMCompressWithVKeyWitnessValues, + ) -> Arc> { + self.join_programs_map.get(&input.shape()).cloned().unwrap_or_else(|| { + tracing::warn!("compress program not found in map, recomputing join program."); + // Get the operations. + Arc::new(compress_program_from_input::( + self.compress_shape_config.as_ref(), + &self.compress_prover, + self.vk_verification, + input, + )) + }) + } -/// Convert 32 KoalaBear bytes into a Bn254Fr field element. The first byte's most significant 3 bits -/// (which would become the 3 most significant bits) are truncated. -pub fn koalabear_bytes_to_bn254(bytes: &[KoalaBear; 32]) -> Bn254Fr { - let mut result = Bn254Fr::ZERO; - for (i, byte) in bytes.iter().enumerate() { - debug_assert!(byte < &KoalaBear::from_canonical_u32(256)); - if i == 0 { - // 32 bytes is more than Bn254 prime, so we need to truncate the top 3 bits. - result = Bn254Fr::from_canonical_u32(byte.as_canonical_u32() & 0x1f); - } else { - result *= Bn254Fr::from_canonical_u32(256); - result += Bn254Fr::from_canonical_u32(byte.as_canonical_u32()); + pub fn shrink_program( + &self, + shrink_shape: RecursionShape, + input: &ZKMCompressWithVKeyWitnessValues, + ) -> Arc> { + // Get the operations. + let builder_span = tracing::debug_span!("build shrink program").entered(); + let mut builder = Builder::::default(); + let input = input.read(&mut builder); + // Verify the proof. + ZKMCompressRootVerifierWithVKey::verify( + &mut builder, + self.compress_prover.machine(), + input, + self.vk_verification, + PublicValuesOutputDigest::Reduce, + ); + let operations = builder.into_operations(); + builder_span.exit(); + + // Compile the program. + let compiler_span = tracing::debug_span!("compile shrink program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + *program.shape_mut() = Some(shrink_shape); + let program = Arc::new(program); + compiler_span.exit(); + program + } + + pub fn wrap_program(&self) -> Arc> { + self.wrap_program + .get_or_init(|| { + // Get the operations. + let builder_span = tracing::debug_span!("build compress program").entered(); + let mut builder = Builder::::default(); + + let shrink_shape: OrderedShape = ShrinkAir::::shrink_shape().into(); + let input_shape = ZKMCompressShape::from(vec![shrink_shape]); + let shape = ZKMCompressWithVkeyShape { + compress_shape: input_shape, + merkle_tree_height: self.recursion_vk_tree.height, + }; + let dummy_input = + ZKMCompressWithVKeyWitnessValues::dummy(self.shrink_prover.machine(), &shape); + + let input = dummy_input.read(&mut builder); + + // Attest that the merkle tree root is correct. + let root = input.merkle_var.root; + for (val, expected) in root.iter().zip(self.recursion_vk_root.iter()) { + builder.assert_felt_eq(*val, *expected); + } + // Verify the proof. + ZKMCompressRootVerifierWithVKey::verify( + &mut builder, + self.shrink_prover.machine(), + input, + self.vk_verification, + PublicValuesOutputDigest::Root, + ); + + let operations = builder.into_operations(); + builder_span.exit(); + + // Compile the program. + let compiler_span = tracing::debug_span!("compile compress program").entered(); + let mut compiler = AsmCompiler::::default(); + let program = Arc::new(compiler.compile(operations)); + compiler_span.exit(); + program + }) + .clone() + } + + pub fn deferred_program( + &self, + input: &ZKMDeferredWitnessValues, + ) -> Arc> { + // Compile the program. + + // Get the operations. + let operations_span = + tracing::debug_span!("get operations for the deferred program").entered(); + let mut builder = Builder::::default(); + let input_read_span = tracing::debug_span!("Read input values").entered(); + let input = input.read(&mut builder); + input_read_span.exit(); + let verify_span = tracing::debug_span!("Verify deferred program").entered(); + + // Verify the proof. + ZKMDeferredVerifier::verify( + &mut builder, + self.compress_prover.machine(), + input, + self.vk_verification, + ); + verify_span.exit(); + let operations = builder.into_operations(); + operations_span.exit(); + + let compiler_span = tracing::debug_span!("compile deferred program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + if let Some(recursion_shape_config) = &self.compress_shape_config { + recursion_shape_config.fix_shape(&mut program); } + let program = Arc::new(program); + compiler_span.exit(); + program } - result -} -/// Utility method for converting u32 words to bytes in big endian. -pub fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] { - let mut bytes = [0u8; 32]; - for i in 0..8 { - let word_bytes = words[i].to_be_bytes(); - bytes[i * 4..(i + 1) * 4].copy_from_slice(&word_bytes); + pub fn get_recursion_core_inputs( + &self, + vk: &StarkVerifyingKey, + shard_proofs: &[ShardProof], + batch_size: usize, + is_complete: bool, + ) -> Vec> { + let mut core_inputs = Vec::new(); + + // Prepare the inputs for the recursion programs. + for (batch_idx, batch) in shard_proofs.chunks(batch_size).enumerate() { + let proofs = batch.to_vec(); + + core_inputs.push(ZKMRecursionWitnessValues { + vk: vk.clone(), + shard_proofs: proofs.clone(), + is_complete, + is_first_shard: batch_idx == 0, + vk_root: self.recursion_vk_root, + }); + } + + core_inputs } - bytes -} -pub trait MaybeTakeIterator: Iterator { - fn maybe_skip(self, bound: Option) -> RangedIterator - where - Self: Sized, - { - match bound { - Some(bound) => RangedIterator::Skip(self.skip(bound)), - None => RangedIterator::Unbounded(self), + pub fn get_recursion_deferred_inputs<'a>( + &'a self, + vk: &'a StarkVerifyingKey, + last_proof_pv: &PublicValues, KoalaBear>, + deferred_proofs: &[ZKMReduceProof], + batch_size: usize, + ) -> Vec> { + // Prepare the inputs for the deferred proofs recursive verification. + let mut deferred_digest = [Val::::ZERO; DIGEST_SIZE]; + let mut deferred_inputs = Vec::new(); + + for batch in deferred_proofs.chunks(batch_size) { + let vks_and_proofs = + batch.iter().cloned().map(|proof| (proof.vk, proof.proof)).collect::>(); + + let input = ZKMCompressWitnessValues { vks_and_proofs, is_complete: true }; + let input = self.make_merkle_proofs(input); + let ZKMCompressWithVKeyWitnessValues { compress_val, merkle_val } = input; + + deferred_inputs.push(ZKMDeferredWitnessValues { + vks_and_proofs: compress_val.vks_and_proofs, + vk_merkle_data: merkle_val, + start_reconstruct_deferred_digest: deferred_digest, + is_complete: false, + zkm_vk_digest: vk.hash_koalabear(), + end_pc: Val::::ZERO, + end_shard: last_proof_pv.shard + KoalaBear::ONE, + end_execution_shard: last_proof_pv.execution_shard, + init_addr_bits: last_proof_pv.last_init_addr_bits, + finalize_addr_bits: last_proof_pv.last_finalize_addr_bits, + committed_value_digest: last_proof_pv.committed_value_digest, + deferred_proofs_digest: last_proof_pv.deferred_proofs_digest, + }); + + deferred_digest = Self::hash_deferred_proofs(deferred_digest, batch); + } + deferred_inputs + } + + /// Generate the inputs for the first layer of recursive proofs. + #[allow(clippy::type_complexity)] + pub fn get_first_layer_inputs<'a>( + &'a self, + vk: &'a ZKMVerifyingKey, + shard_proofs: &[ShardProof], + deferred_proofs: &[ZKMReduceProof], + batch_size: usize, + ) -> Vec { + let is_complete = shard_proofs.len() == 1 && deferred_proofs.is_empty(); + let core_inputs = + self.get_recursion_core_inputs(&vk.vk, shard_proofs, batch_size, is_complete); + let last_proof_pv = shard_proofs.last().unwrap().public_values.as_slice().borrow(); + let deferred_inputs = + self.get_recursion_deferred_inputs(&vk.vk, last_proof_pv, deferred_proofs, batch_size); + + let mut inputs = Vec::new(); + inputs.extend(core_inputs.into_iter().map(ZKMCircuitWitness::Core)); + inputs.extend(deferred_inputs.into_iter().map(ZKMCircuitWitness::Deferred)); + inputs + } + + /// Reduce shard proofs to a single shard proof using the recursion prover. + #[instrument(name = "compress", level = "info", skip_all)] + pub fn compress( + &self, + vk: &ZKMVerifyingKey, + proof: ZKMCoreProof, + deferred_proofs: Vec>, + opts: ZKMProverOpts, + ) -> Result, ZKMRecursionProverError> { + // The batch size for reducing two layers of recursion. + let batch_size = REDUCE_BATCH_SIZE; + // The batch size for reducing the first layer of recursion. + let first_layer_batch_size = 1; + + let shard_proofs = &proof.proof.0; + + let first_layer_inputs = + self.get_first_layer_inputs(vk, shard_proofs, &deferred_proofs, first_layer_batch_size); + + // Calculate the expected height of the tree. + let mut expected_height = if first_layer_inputs.len() == 1 { 0 } else { 1 }; + let num_first_layer_inputs = first_layer_inputs.len(); + let mut num_layer_inputs = num_first_layer_inputs; + while num_layer_inputs > batch_size { + num_layer_inputs = num_layer_inputs.div_ceil(2); + expected_height += 1; + } + + // Generate the proofs. + let span = tracing::Span::current().clone(); + let (vk, proof) = thread::scope(|s| { + let _span = span.enter(); + + // Spawn a worker that sends the first layer inputs to a bounded channel. + let input_sync = Arc::new(TurnBasedSync::new()); + let (input_tx, input_rx) = sync_channel::<(usize, usize, ZKMCircuitWitness)>( + opts.recursion_opts.checkpoints_channel_capacity, + ); + let input_tx = Arc::new(Mutex::new(input_tx)); + { + let input_tx = Arc::clone(&input_tx); + let input_sync = Arc::clone(&input_sync); + s.spawn(move || { + for (index, input) in first_layer_inputs.into_iter().enumerate() { + input_sync.wait_for_turn(index); + input_tx.lock().unwrap().send((index, 0, input)).unwrap(); + input_sync.advance_turn(); + } + }); + } + + // Spawn workers who generate the records and traces. + let record_and_trace_sync = Arc::new(TurnBasedSync::new()); + let (record_and_trace_tx, record_and_trace_rx) = + sync_channel::<( + usize, + usize, + Arc>, + ExecutionRecord, + Vec<(String, RowMajorMatrix)>, + )>(opts.recursion_opts.records_and_traces_channel_capacity); + let record_and_trace_tx = Arc::new(Mutex::new(record_and_trace_tx)); + let record_and_trace_rx = Arc::new(Mutex::new(record_and_trace_rx)); + let input_rx = Arc::new(Mutex::new(input_rx)); + for _ in 0..opts.recursion_opts.trace_gen_workers { + let record_and_trace_sync = Arc::clone(&record_and_trace_sync); + let record_and_trace_tx = Arc::clone(&record_and_trace_tx); + let input_rx = Arc::clone(&input_rx); + let span = tracing::debug_span!("generate records and traces"); + s.spawn(move || { + let _span = span.enter(); + loop { + let received = { input_rx.lock().unwrap().recv() }; + if let Ok((index, height, input)) = received { + // Get the program and witness stream. + let (program, witness_stream) = tracing::debug_span!( + "get program and witness stream" + ) + .in_scope(|| match input { + ZKMCircuitWitness::Core(input) => { + let mut witness_stream = Vec::new(); + Witnessable::::write(&input, &mut witness_stream); + (self.recursion_program(&input), witness_stream) + } + ZKMCircuitWitness::Deferred(input) => { + let mut witness_stream = Vec::new(); + Witnessable::::write(&input, &mut witness_stream); + (self.deferred_program(&input), witness_stream) + } + ZKMCircuitWitness::Compress(input) => { + let mut witness_stream = Vec::new(); + + let input_with_merkle = self.make_merkle_proofs(input); + + Witnessable::::write( + &input_with_merkle, + &mut witness_stream, + ); + + (self.compress_program(&input_with_merkle), witness_stream) + } + }); + + // Execute the runtime. + let record = tracing::debug_span!("execute runtime").in_scope(|| { + let mut runtime = + RecursionRuntime::, Challenge, _>::new( + program.clone(), + self.compress_prover.config().perm.clone(), + ); + runtime.witness_stream = witness_stream.into(); + runtime + .run() + .map_err(|e| { + ZKMRecursionProverError::RuntimeError(e.to_string()) + }) + .unwrap(); + runtime.record + }); + + // Generate the dependencies. + let mut records = vec![record]; + tracing::debug_span!("generate dependencies").in_scope(|| -> Result<(), ZKMRecursionProverError> { + match self.compress_prover.machine().generate_dependencies( + &mut records, + &opts.recursion_opts, + None, + ) { + Ok(_) => Ok(()), + Err(e) => { + tracing::error!( + "Failed to generate dependencies for recursion proof: {}", + e + ); + Err(ZKMRecursionProverError::DependenciesGenerationError) + } + } + })?; + + // Generate the traces. + let record = records.into_iter().next().unwrap(); + let traces = tracing::debug_span!("generate traces") + .in_scope(|| self.compress_prover.generate_traces(&record)); + let traces = match traces { + Ok(traces) => traces, + Err(e) => { + tracing::error!( + "Failed to generate traces for recursion proof: {}", + e + ); + return Err(ZKMRecursionProverError::TracesGenerationError); + } + }; + + // Wait for our turn to update the state. + record_and_trace_sync.wait_for_turn(index); + + // Send the record and traces to the worker. + record_and_trace_tx + .lock() + .unwrap() + .send((index, height, program, record, traces)) + .unwrap(); + + // Advance the turn. + record_and_trace_sync.advance_turn(); + } else { + break Ok(()); + } + } + }); + } + + // Spawn workers who generate the compress proofs. + let proofs_sync = Arc::new(TurnBasedSync::new()); + let (proofs_tx, proofs_rx) = + sync_channel::<(usize, usize, StarkVerifyingKey, ShardProof)>( + num_first_layer_inputs * 2, + ); + let proofs_tx = Arc::new(Mutex::new(proofs_tx)); + let proofs_rx = Arc::new(Mutex::new(proofs_rx)); + let mut prover_handles = Vec::new(); + for _ in 0..opts.recursion_opts.shard_batch_size { + let prover_sync = Arc::clone(&proofs_sync); + let record_and_trace_rx = Arc::clone(&record_and_trace_rx); + let proofs_tx = Arc::clone(&proofs_tx); + let span = tracing::debug_span!("prove"); + let handle = s.spawn(move || { + let _span = span.enter(); + loop { + let received = { record_and_trace_rx.lock().unwrap().recv() }; + if let Ok((index, height, program, record, traces)) = received { + tracing::debug_span!("batch").in_scope(|| { + // Get the keys. + let (pk, vk) = tracing::debug_span!("Setup compress program") + .in_scope(|| self.compress_prover.setup(&program)); + + // Observe the proving key. + let mut challenger = self.compress_prover.config().challenger(); + tracing::debug_span!("observe proving key").in_scope(|| { + pk.observe_into(&mut challenger); + }); + + #[cfg(feature = "debug")] + self.compress_prover.debug_constraints( + &self.compress_prover.pk_to_host(&pk), + vec![record.clone()], + &mut challenger.clone(), + ); + + // Commit to the record and traces. + let data = tracing::debug_span!("commit") + .in_scope(|| self.compress_prover.commit(&record, traces)); + + // Generate the proof. + let proof = tracing::debug_span!("open").in_scope(|| { + self.compress_prover.open(&pk, data, &mut challenger).unwrap() + }); + + // Verify the proof. + #[cfg(feature = "debug")] + self.compress_prover + .machine() + .verify( + &vk, + &zkm_stark::MachineProof { + shard_proofs: vec![proof.clone()], + }, + &mut self.compress_prover.config().challenger(), + ) + .unwrap(); + + // Wait for our turn to update the state. + prover_sync.wait_for_turn(index); + + // Send the proof. + proofs_tx.lock().unwrap().send((index, height, vk, proof)).unwrap(); + + // Advance the turn. + prover_sync.advance_turn(); + }); + } else { + break; + } + } + }); + prover_handles.push(handle); + } + + // Spawn a worker that generates inputs for the next layer. + let handle = { + let input_tx = Arc::clone(&input_tx); + let proofs_rx = Arc::clone(&proofs_rx); + let span = tracing::debug_span!("generate next layer inputs"); + s.spawn(move || { + let _span = span.enter(); + let mut count = num_first_layer_inputs; + let mut batch: Vec<( + usize, + usize, + StarkVerifyingKey, + ShardProof, + )> = Vec::new(); + loop { + if expected_height == 0 { + break; + } + let received = { proofs_rx.lock().unwrap().recv() }; + if let Ok((index, height, vk, proof)) = received { + batch.push((index, height, vk, proof)); + + // If we haven't reached the batch size, continue. + if batch.len() < batch_size { + continue; + } + + // Compute whether we're at the last input of a layer. + let mut is_last = false; + if let Some(first) = batch.first() { + is_last = first.1 != height; + } + + // If we're at the last input of a layer, we need to only include the + // first input, otherwise we include all inputs. + let inputs = + if is_last { vec![batch[0].clone()] } else { batch.clone() }; + + let next_input_height = inputs[0].1 + 1; + + let is_complete = next_input_height == expected_height; + + let vks_and_proofs = inputs + .into_iter() + .map(|(_, _, vk, proof)| (vk, proof)) + .collect::>(); + let input = ZKMCircuitWitness::Compress(ZKMCompressWitnessValues { + vks_and_proofs, + is_complete, + }); + + input_sync.wait_for_turn(count); + input_tx + .lock() + .unwrap() + .send((count, next_input_height, input)) + .unwrap(); + input_sync.advance_turn(); + count += 1; + + // If we're at the root of the tree, stop generating inputs. + if is_complete { + break; + } + + // If we were at the last input of a layer, we keep everything but the + // first input. Otherwise, we empty the batch. + if is_last { + batch = vec![batch[1].clone()]; + } else { + batch = Vec::new(); + } + } else { + break; + } + } + }) + }; + + // Wait for all the provers to finish. + drop(input_tx); + drop(record_and_trace_tx); + drop(proofs_tx); + for handle in prover_handles { + handle.join().unwrap(); + } + handle.join().unwrap(); + + let (_, _, vk, proof) = proofs_rx.lock().unwrap().recv().unwrap(); + (vk, proof) + }); + + Ok(ZKMReduceProof { vk, proof }) + } + + /// Wrap a reduce proof into a STARK proven over a SNARK-friendly field. + #[instrument(name = "shrink", level = "info", skip_all)] + pub fn shrink( + &self, + reduced_proof: ZKMReduceProof, + opts: ZKMProverOpts, + ) -> Result, ZKMRecursionProverError> { + // Make the compress proof. + let ZKMReduceProof { vk: compressed_vk, proof: compressed_proof } = reduced_proof; + let input = ZKMCompressWitnessValues { + vks_and_proofs: vec![(compressed_vk, compressed_proof)], + is_complete: true, + }; + + let input_with_merkle = self.make_merkle_proofs(input); + + let program = + self.shrink_program(ShrinkAir::::shrink_shape(), &input_with_merkle); + + // Run the compress program. + let mut runtime = RecursionRuntime::, Challenge, _>::new( + program.clone(), + self.shrink_prover.config().perm.clone(), + ); + + let mut witness_stream = Vec::new(); + Witnessable::::write(&input_with_merkle, &mut witness_stream); + + runtime.witness_stream = witness_stream.into(); + + runtime.run().map_err(|e| ZKMRecursionProverError::RuntimeError(e.to_string()))?; + + runtime.print_stats(); + tracing::debug!("Shrink program executed successfully"); + + let (shrink_pk, shrink_vk) = + tracing::debug_span!("setup shrink").in_scope(|| self.shrink_prover.setup(&program)); + + // Prove the compress program. + let mut compress_challenger = self.shrink_prover.config().challenger(); + let mut compress_proof = self + .shrink_prover + .prove(&shrink_pk, vec![runtime.record], &mut compress_challenger, opts.recursion_opts) + .unwrap(); + + Ok(ZKMReduceProof { vk: shrink_vk, proof: compress_proof.shard_proofs.pop().unwrap() }) + } + + /// Wrap a reduce proof into a STARK proven over a SNARK-friendly field. + #[instrument(name = "wrap_bn254", level = "info", skip_all)] + pub fn wrap_bn254( + &self, + compressed_proof: ZKMReduceProof, + opts: ZKMProverOpts, + ) -> Result, ZKMRecursionProverError> { + let ZKMReduceProof { vk: compressed_vk, proof: compressed_proof } = compressed_proof; + let input = ZKMCompressWitnessValues { + vks_and_proofs: vec![(compressed_vk, compressed_proof)], + is_complete: true, + }; + let input_with_vk = self.make_merkle_proofs(input); + + let program = self.wrap_program(); + + // Run the compress program. + let mut runtime = RecursionRuntime::, Challenge, _>::new( + program.clone(), + self.shrink_prover.config().perm.clone(), + ); + + let mut witness_stream = Vec::new(); + Witnessable::::write(&input_with_vk, &mut witness_stream); + + runtime.witness_stream = witness_stream.into(); + + runtime.run().map_err(|e| ZKMRecursionProverError::RuntimeError(e.to_string()))?; + + runtime.print_stats(); + tracing::debug!("wrap program executed successfully"); + + // Setup the wrap program. + let (wrap_pk, wrap_vk) = + tracing::debug_span!("setup wrap").in_scope(|| self.wrap_prover.setup(&program)); + + if self.wrap_vk.set(wrap_vk.clone()).is_ok() { + tracing::debug!("wrap verifier key set"); + } + + // Prove the wrap program. + let mut wrap_challenger = self.wrap_prover.config().challenger(); + let time = std::time::Instant::now(); + let mut wrap_proof = self + .wrap_prover + .prove(&wrap_pk, vec![runtime.record], &mut wrap_challenger, opts.recursion_opts) + .unwrap(); + let elapsed = time.elapsed(); + tracing::debug!("wrap proving time: {:?}", elapsed); + let mut wrap_challenger = self.wrap_prover.config().challenger(); + self.wrap_prover.machine().verify(&wrap_vk, &wrap_proof, &mut wrap_challenger).unwrap(); + tracing::info!("wrapping successful"); + + Ok(ZKMReduceProof { vk: wrap_vk, proof: wrap_proof.shard_proofs.pop().unwrap() }) + } + + /// Wrap the STARK proven over a SNARK-friendly field into a PLONK proof. + #[instrument(name = "wrap_plonk_bn254", level = "info", skip_all)] + pub fn wrap_plonk_bn254( + &self, + proof: ZKMReduceProof, + build_dir: &Path, + ) -> PlonkBn254Proof { + let input = ZKMCompressWitnessValues { + vks_and_proofs: vec![(proof.vk.clone(), proof.proof.clone())], + is_complete: true, + }; + let vkey_hash = zkm_vkey_digest_bn254(&proof); + let committed_values_digest = zkm_committed_values_digest_bn254(&proof); + + let mut witness = Witness::default(); + input.write(&mut witness); + witness.write_committed_values_digest(committed_values_digest); + witness.write_vkey_hash(vkey_hash); + + let prover = PlonkBn254Prover::new(); + let proof = prover.prove(witness, build_dir.to_path_buf()); + + // Verify the proof. + prover + .verify( + &proof, + &vkey_hash.as_canonical_biguint(), + &committed_values_digest.as_canonical_biguint(), + build_dir, + ) + .unwrap(); + + proof + } + + /// Wrap the STARK proven over a SNARK-friendly field into a Groth16 proof. + #[instrument(name = "wrap_groth16_bn254", level = "info", skip_all)] + pub fn wrap_groth16_bn254( + &self, + proof: ZKMReduceProof, + build_dir: &Path, + ) -> Groth16Bn254Proof { + let input = ZKMCompressWitnessValues { + vks_and_proofs: vec![(proof.vk.clone(), proof.proof.clone())], + is_complete: true, + }; + let vkey_hash = zkm_vkey_digest_bn254(&proof); + let committed_values_digest = zkm_committed_values_digest_bn254(&proof); + + let mut witness = Witness::default(); + input.write(&mut witness); + witness.write_committed_values_digest(committed_values_digest); + witness.write_vkey_hash(vkey_hash); + + let prover = Groth16Bn254Prover::new(); + let proof = prover.prove(witness, build_dir.to_path_buf()); + + // Verify the proof. + prover + .verify( + &proof, + &vkey_hash.as_canonical_biguint(), + &committed_values_digest.as_canonical_biguint(), + build_dir, + ) + .unwrap(); + + proof + } + + /// Accumulate deferred proofs into a single digest. + pub fn hash_deferred_proofs( + prev_digest: [Val; DIGEST_SIZE], + deferred_proofs: &[ZKMReduceProof], + ) -> [Val; 8] { + let mut digest = prev_digest; + for proof in deferred_proofs.iter() { + let pv: &RecursionPublicValues> = + proof.proof.public_values.as_slice().borrow(); + let committed_values_digest = words_to_bytes(&pv.committed_value_digest); + digest = hash_deferred_proof( + &digest, + &pv.zkm_vk_digest, + &committed_values_digest.try_into().unwrap(), + ); } + digest + } + + pub fn make_merkle_proofs( + &self, + input: ZKMCompressWitnessValues, + ) -> ZKMCompressWithVKeyWitnessValues { + let num_vks = self.recursion_vk_map.len(); + let (vk_indices, vk_digest_values): (Vec<_>, Vec<_>) = if self.vk_verification { + input + .vks_and_proofs + .iter() + .map(|(vk, _)| { + let vk_digest = vk.hash_koalabear(); + let index = self.recursion_vk_map.get(&vk_digest).expect("vk not allowed"); + (index, vk_digest) + }) + .unzip() + } else { + input + .vks_and_proofs + .iter() + .map(|(vk, _)| { + let vk_digest = vk.hash_koalabear(); + let index = (vk_digest[0].as_canonical_u32() as usize) % num_vks; + (index, [KoalaBear::from_canonical_usize(index); 8]) + }) + .unzip() + }; + + let proofs = vk_indices + .iter() + .map(|index| { + let (_, proof) = MerkleTree::open(&self.recursion_vk_tree, *index); + proof + }) + .collect(); + + let merkle_val = ZKMMerkleProofWitnessValues { + root: self.recursion_vk_root, + values: vk_digest_values, + vk_merkle_proofs: proofs, + }; + + ZKMCompressWithVKeyWitnessValues { compress_val: input, merkle_val } } - fn maybe_take(self, bound: Option) -> RangedIterator - where - Self: Sized, - { - match bound { - Some(bound) => RangedIterator::Take(self.take(bound)), - None => RangedIterator::Unbounded(self), + fn check_for_high_cycles(cycles: u64) { + if cycles > 100_000_000 { + tracing::warn!( + "high cycle count, consider using the prover network for proof generation" + ); } } } -impl MaybeTakeIterator for I {} +pub fn compress_program_from_input( + config: Option<&RecursionShapeConfig>>, + compress_prover: &C::CompressProver, + vk_verification: bool, + input: &ZKMCompressWithVKeyWitnessValues, +) -> RecursionProgram { + let builder_span = tracing::debug_span!("build compress program").entered(); + let mut builder = Builder::::default(); + // read the input. + let input = input.read(&mut builder); + // Verify the proof. + ZKMCompressWithVKeyVerifier::verify( + &mut builder, + compress_prover.machine(), + input, + vk_verification, + PublicValuesOutputDigest::Reduce, + ); + let operations = builder.into_operations(); + builder_span.exit(); + + // Compile the program. + let compiler_span = tracing::debug_span!("compile compress program").entered(); + let mut compiler = AsmCompiler::::default(); + let mut program = compiler.compile(operations); + if let Some(config) = config { + config.fix_shape(&mut program); + } + compiler_span.exit(); -pub enum RangedIterator { - Unbounded(I), - Skip(Skip), - Take(Take), - Range(Take>), + program } -impl Iterator for RangedIterator { - type Item = I::Item; +#[cfg(test)] +pub mod tests { + use std::{ + collections::BTreeSet, + fs::File, + io::{Read, Write}, + }; + + use super::*; + + use crate::build::try_build_plonk_bn254_artifacts_dev; + use anyhow::Result; + use build::{build_constraints_and_witness, try_build_groth16_bn254_artifacts_dev}; + use p3_field::PrimeField32; + + use shapes::ZKMProofShape; + use zkm_recursion_core::air::RecursionPublicValues; + + #[cfg(test)] + use serial_test::serial; + use utils::zkm_vkey_digest_koalabear; + #[cfg(test)] + use zkm_core_machine::utils::setup_logger; + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum Test { + Core, + Compress, + Shrink, + Wrap, + CircuitTest, + All, + } + + pub fn test_e2e_prover( + prover: &ZKMProver, + elf: &[u8], + stdin: ZKMStdin, + opts: ZKMProverOpts, + test_kind: Test, + ) -> Result<()> { + run_e2e_prover_with_options(prover, elf, stdin, opts, test_kind, true) + } + + pub fn bench_e2e_prover( + prover: &ZKMProver, + elf: &[u8], + stdin: ZKMStdin, + opts: ZKMProverOpts, + test_kind: Test, + ) -> Result<()> { + run_e2e_prover_with_options(prover, elf, stdin, opts, test_kind, false) + } + + pub fn run_e2e_prover_with_options( + prover: &ZKMProver, + elf: &[u8], + stdin: ZKMStdin, + opts: ZKMProverOpts, + test_kind: Test, + verify: bool, + ) -> Result<()> { + tracing::info!("initializing prover"); + let context = ZKMContext::default(); + + tracing::info!("setup elf"); + let (_, pk_d, program, vk) = prover.setup(elf); - fn next(&mut self) -> Option { - match self { - RangedIterator::Unbounded(unbounded) => unbounded.next(), - RangedIterator::Skip(skip) => skip.next(), - RangedIterator::Take(take) => take.next(), - RangedIterator::Range(range) => range.next(), + tracing::info!("prove core"); + let core_proof = prover.prove_core(&pk_d, program, &stdin, opts, context)?; + let public_values = core_proof.public_values.clone(); + + if env::var("COLLECT_SHAPES").is_ok() { + let mut shapes = BTreeSet::new(); + for proof in core_proof.proof.0.iter() { + let shape = ZKMProofShape::Recursion(proof.shape()); + tracing::info!("shape: {:?}", shape); + shapes.insert(shape); + } + + let mut file = File::create("../shapes.bin").unwrap(); + bincode::serialize_into(&mut file, &shapes).unwrap(); + } + + if verify { + tracing::info!("verify core"); + prover.verify(&core_proof.proof, &vk)?; + } + + if test_kind == Test::Core { + return Ok(()); + } + + tracing::info!("compress"); + let compress_span = tracing::debug_span!("compress").entered(); + let compressed_proof = prover.compress(&vk, core_proof, vec![], opts)?; + compress_span.exit(); + + if verify { + tracing::info!("verify compressed"); + prover.verify_compressed(&compressed_proof, &vk)?; + } + + if test_kind == Test::Compress { + return Ok(()); + } + + tracing::info!("shrink"); + let shrink_proof = prover.shrink(compressed_proof, opts)?; + + if verify { + tracing::info!("verify shrink"); + prover.verify_shrink(&shrink_proof, &vk)?; + } + + if test_kind == Test::Shrink { + return Ok(()); } + + tracing::info!("wrap bn254"); + let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); + + // Save the proof. + let mut file = File::create("proof-with-pis.bin").unwrap(); + file.write_all(bytes.as_slice()).unwrap(); + + // Load the proof. + let mut file = File::open("proof-with-pis.bin").unwrap(); + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes).unwrap(); + + let wrapped_bn254_proof = bincode::deserialize(&bytes).unwrap(); + + if verify { + tracing::info!("verify wrap bn254"); + prover.verify_wrap_bn254(&wrapped_bn254_proof, &vk).unwrap(); + } + + if test_kind == Test::Wrap { + return Ok(()); + } + + tracing::info!("checking vkey hash koalabear"); + let vk_digest_koalabear = zkm_vkey_digest_koalabear(&wrapped_bn254_proof); + assert_eq!(vk_digest_koalabear, vk.hash_koalabear()); + + tracing::info!("checking vkey hash bn254"); + let vk_digest_bn254 = zkm_vkey_digest_bn254(&wrapped_bn254_proof); + assert_eq!(vk_digest_bn254, vk.hash_bn254()); + + tracing::info!("Test the outer Plonk circuit"); + let (constraints, witness) = + build_constraints_and_witness(&wrapped_bn254_proof.vk, &wrapped_bn254_proof.proof); + PlonkBn254Prover::test(constraints, witness); + tracing::info!("Circuit test succeeded"); + + if test_kind == Test::CircuitTest { + return Ok(()); + } + + tracing::info!("generate plonk bn254 proof"); + let artifacts_dir = try_build_plonk_bn254_artifacts_dev( + &wrapped_bn254_proof.vk, + &wrapped_bn254_proof.proof, + ); + let plonk_bn254_proof = + prover.wrap_plonk_bn254(wrapped_bn254_proof.clone(), &artifacts_dir); + println!("{plonk_bn254_proof:?}"); + + prover.verify_plonk_bn254(&plonk_bn254_proof, &vk, &public_values, &artifacts_dir)?; + + tracing::info!("generate groth16 bn254 proof"); + let artifacts_dir = try_build_groth16_bn254_artifacts_dev( + &wrapped_bn254_proof.vk, + &wrapped_bn254_proof.proof, + ); + let groth16_bn254_proof = prover.wrap_groth16_bn254(wrapped_bn254_proof, &artifacts_dir); + println!("{groth16_bn254_proof:?}"); + + if verify { + prover.verify_groth16_bn254( + &groth16_bn254_proof, + &vk, + &public_values, + &artifacts_dir, + )?; + } + + Ok(()) + } + + pub fn test_e2e_with_deferred_proofs_prover( + opts: ZKMProverOpts, + ) -> Result<()> { + // Test program which proves the Keccak-256 hash of various inputs. + let keccak_elf = test_artifacts::KECCAK_SPONGE_ELF; + + // Test program which verifies proofs of a vkey and a list of committed inputs. + let verify_elf = test_artifacts::VERIFY_PROOF_ELF; + + tracing::info!("initializing prover"); + let prover = ZKMProver::::new(); + + tracing::info!("setup keccak elf"); + let (_, keccak_pk_d, keccak_program, keccak_vk) = prover.setup(keccak_elf); + + tracing::info!("setup verify elf"); + let (_, verify_pk_d, verify_program, verify_vk) = prover.setup(verify_elf); + + tracing::info!("prove subproof 1"); + let mut stdin = ZKMStdin::new(); + stdin.write(&1usize); + stdin.write(&vec![0u8, 0, 0]); + let deferred_proof_1 = prover.prove_core( + &keccak_pk_d, + keccak_program.clone(), + &stdin, + opts, + Default::default(), + )?; + let pv_1 = deferred_proof_1.public_values.as_slice().to_vec().clone(); + + // Generate a second proof of keccak of various inputs. + tracing::info!("prove subproof 2"); + let mut stdin = ZKMStdin::new(); + stdin.write(&3usize); + stdin.write(&vec![0u8, 1, 2]); + stdin.write(&vec![2, 3, 4]); + stdin.write(&vec![5, 6, 7]); + let deferred_proof_2 = + prover.prove_core(&keccak_pk_d, keccak_program, &stdin, opts, Default::default())?; + let pv_2 = deferred_proof_2.public_values.as_slice().to_vec().clone(); + + // Generate recursive proof of first subproof. + tracing::info!("compress subproof 1"); + let deferred_reduce_1 = prover.compress(&keccak_vk, deferred_proof_1, vec![], opts)?; + + // Generate recursive proof of second subproof. + tracing::info!("compress subproof 2"); + let deferred_reduce_2 = prover.compress(&keccak_vk, deferred_proof_2, vec![], opts)?; + + // Run verify program with keccak vkey, subproofs, and their committed values. + let mut stdin = ZKMStdin::new(); + let vkey_digest = keccak_vk.hash_koalabear(); + let vkey_digest: [u32; 8] = vkey_digest + .iter() + .map(|n| n.as_canonical_u32()) + .collect::>() + .try_into() + .unwrap(); + stdin.write(&vkey_digest); + stdin.write(&vec![pv_1.clone(), pv_2.clone(), pv_2.clone()]); + stdin.write_proof(deferred_reduce_1.clone(), keccak_vk.vk.clone()); + stdin.write_proof(deferred_reduce_2.clone(), keccak_vk.vk.clone()); + stdin.write_proof(deferred_reduce_2.clone(), keccak_vk.vk.clone()); + + tracing::info!("proving verify program (core)"); + let verify_proof = + prover.prove_core(&verify_pk_d, verify_program, &stdin, opts, Default::default())?; + // let public_values = verify_proof.public_values.clone(); + + // Generate recursive proof of verify program + tracing::info!("compress verify program"); + let verify_reduce = prover.compress( + &verify_vk, + verify_proof, + vec![deferred_reduce_1, deferred_reduce_2.clone(), deferred_reduce_2], + opts, + )?; + let reduce_pv: &RecursionPublicValues<_> = + verify_reduce.proof.public_values.as_slice().borrow(); + println!("deferred_hash: {:?}", reduce_pv.deferred_proofs_digest); + println!("complete: {:?}", reduce_pv.is_complete); + + tracing::info!("verify verify program"); + prover.verify_compressed(&verify_reduce, &verify_vk)?; + + let shrink_proof = prover.shrink(verify_reduce, opts)?; + + tracing::info!("verify shrink"); + prover.verify_shrink(&shrink_proof, &verify_vk)?; + + tracing::info!("wrap bn254"); + let wrapped_bn254_proof = prover.wrap_bn254(shrink_proof, opts)?; + + tracing::info!("verify wrap bn254"); + println!("verify wrap bn254 {:#?}", wrapped_bn254_proof.vk.commit); + prover.verify_wrap_bn254(&wrapped_bn254_proof, &verify_vk).unwrap(); + + Ok(()) + } + + /// Tests an end-to-end workflow of proving a program across the entire proof generation + /// pipeline. + /// + /// Add `FRI_QUERIES`=1 to your environment for faster execution. Should only take a few minutes + /// on a Mac M2. Note: This test always re-builds the plonk bn254 artifacts, so setting ZKM_DEV + /// is not needed. + #[test] + #[serial] + #[ignore] + fn test_e2e() -> Result<()> { + let elf = test_artifacts::FIBONACCI_ELF; + setup_logger(); + let opts = ZKMProverOpts::default(); + // TODO(mattstam): We should Test::Plonk here, but this uses the existing + // docker image which has a different API than the current. So we need to wait until the + // next release (v1.2.0+), and then switch it back. + let prover = ZKMProver::::new(); + test_e2e_prover::( + &prover, + elf, + ZKMStdin::default(), + opts, + Test::All, + ) + } + + /// Tests an end-to-end workflow of proving a program across the entire proof generation + /// pipeline. + /// + /// Add `FRI_QUERIES`=1 to your environment for faster execution. Should only take a few minutes + /// on a Mac M2. Note: This test always re-builds the plonk bn254 artifacts, so setting ZKM_DEV + /// is not needed. + #[test] + #[serial] + #[ignore] + fn test_e2e_hello_world() -> Result<()> { + let elf = test_artifacts::HELLO_WORLD_ELF; + + setup_logger(); + let opts = ZKMProverOpts::default(); + // TODO(mattstam): We should Test::Plonk here, but this uses the existing + // docker image which has a different API than the current. So we need to wait until the + // next release (v1.2.0+), and then switch it back. + let prover = ZKMProver::::new(); + test_e2e_prover::( + &prover, + elf, + ZKMStdin::default(), + opts, + Test::All, + ) + } + + /// Tests an end-to-end workflow of proving a program across the entire proof generation + /// pipeline in addition to verifying deferred proofs. + #[test] + #[serial] + #[ignore] + fn test_e2e_with_deferred_proofs() -> Result<()> { + setup_logger(); + test_e2e_with_deferred_proofs_prover::(ZKMProverOpts::default()) } }