diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4efdf4de..3a5d27ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -71,12 +71,12 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - - run: cargo build --release --bin expander_server --bin expander_server_pcs_defered + - run: cargo build --release --bin expander_server --bin expander_server_pcs_defered --bin expander_server_no_oversubscribe - run: cargo test test-rust-avx512: @@ -90,7 +90,7 @@ jobs: with: workspaces: "expander_compiler -> expander_compiler/target" # The prefix cache key, this can be changed to start a new cache manually. - prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version + prefix-key: "mpi-v5.0.8" # update me if brew formula changes to a new version - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo build --release --bin expander_commit --bin expander_prove - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test diff --git a/Cargo.lock b/Cargo.lock index 90f583eb..3385ba66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "ark-std", "criterion", @@ -330,7 +330,7 @@ dependencies = [ [[package]] name = "babybear" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "bin" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "babybear", @@ -523,9 +523,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "shlex", ] @@ -589,7 +589,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -645,9 +645,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -655,9 +655,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -667,9 +667,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "gkr_engine", "gkr_hashers", @@ -817,7 +817,7 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "env_logger", @@ -1005,6 +1005,7 @@ dependencies = [ "serdes", "sha2", "shared_memory", + "stacker", "sumcheck", "tiny-keccak", "tokio", @@ -1143,7 +1144,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1160,7 +1161,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1179,7 +1180,7 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1212,7 +1213,7 @@ dependencies = [ [[package]] name = "gkr_engine" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "babybear", @@ -1231,7 +1232,7 @@ dependencies = [ [[package]] name = "gkr_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "halo2curves", @@ -1249,7 +1250,7 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "goldilocks" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -1271,9 +1272,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" dependencies = [ "bytes", "fnv", @@ -1508,9 +1509,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "bytes", "futures-core", @@ -1665,9 +1666,9 @@ dependencies = [ [[package]] name = "io-uring" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ "bitflags 2.9.1", "cfg-if", @@ -1881,7 +1882,7 @@ dependencies = [ [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2259,7 +2260,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2284,7 +2285,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -2330,9 +2331,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn 2.0.104", @@ -2347,6 +2348,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +dependencies = [ + "cc", +] + [[package]] name = "quote" version = "1.0.40" @@ -2429,9 +2439,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.13" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ "bitflags 2.9.1", ] @@ -2557,15 +2567,15 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags 2.9.1", "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -2670,9 +2680,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "itoa", "memchr", @@ -2705,7 +2715,7 @@ dependencies = [ [[package]] name = "serdes" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "ethnum", "halo2curves", @@ -2716,7 +2726,7 @@ dependencies = [ [[package]] name = "serdes_derive" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "proc-macro2", "quote", @@ -2813,6 +2823,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -2840,7 +2863,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "circuit", @@ -2936,7 +2959,7 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.7", + "rustix 1.0.8", "windows-sys 0.59.0", ] @@ -2991,9 +3014,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.46.0" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1140bb80481756a8cbe10541f37433b459c5aa1e727b4c020fbfebdc25bf3ec4" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", @@ -3106,7 +3129,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "gkr_engine", @@ -3129,7 +3152,7 @@ dependencies = [ [[package]] name = "tree" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "arith", "ark-std", @@ -3231,7 +3254,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utils" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=main#18dd57f6cc5f9bb8531c73f1061c481a0e951d87" +source = "git+https://github.com/PolyhedraZK/Expander?branch=zf%2Foptimize_pcs_claim_merging#300aa3d2e58ca9f5814ab39a502b265b9d7c7d92" dependencies = [ "colored", ] diff --git a/Cargo.toml b/Cargo.toml index 98820fbf..bdf159a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,22 +47,22 @@ stacker = "0.1.17" tiny-keccak = { version = "2.0", features = ["keccak"] } tokio = { version = "1", features = ["full"] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "circuit" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "transcript" } -expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "bin" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "poly_commit" } -polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "serdes" } -gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "main" } -expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "main", package = "utils" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +babybear = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "circuit" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "transcript" } +expander_binary = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "bin" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +goldilocks = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +poly_commit = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "poly_commit" } +polynomials = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +sumcheck = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +serdes = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "serdes" } +gkr_engine = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +gkr_hashers = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging" } +expander_utils = { git = "https://github.com/PolyhedraZK/Expander", branch = "zf/optimize_pcs_claim_merging", package = "utils" } diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 3181d3d3..d22a7177 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -165,7 +165,7 @@ fn rangeproof_zkcuda_test() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -192,7 +192,7 @@ fn rangeproof_zkcuda_test_fail() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 36906b39..98cfb413 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -44,10 +44,12 @@ once_cell = "1.21.3" [dev-dependencies] rayon = "1.9" sha2 = "0.10.8" +stacker = "0.1.15" [features] default = [] profile = ["expander_utils/profile"] +zkcuda_profile = [] [[bin]] name = "trivial_circuit" @@ -61,10 +63,34 @@ path = "src/zkcuda/proving_system/expander_parallelized/server_bin.rs" name = "expander_server_pcs_defered" path = "src/zkcuda/proving_system/expander_pcs_defered/server_bin.rs" +[[bin]] +name = "expander_server_no_oversubscribe" +path = "src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs" + [[bin]] name = "zkcuda_matmul" -path = "bin/zkcuda_matmul.rs" +path = "bin/zkcuda_bench/zkcuda_matmul.rs" [[bin]] name = "zkcuda_matmul_pcs_defered" -path = "bin/zkcuda_matmul_pcs_defered.rs" +path = "bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs" + +[[bin]] +name = "zkcuda_matmul_no_oversubscribe" +path = "bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs" + +[[bin]] +name = "zkcuda_setup" +path = "bin/zkcuda_integration/setup.rs" + +[[bin]] +name = "zkcuda_prove" +path = "bin/zkcuda_integration/prove.rs" + +[[bin]] +name = "zkcuda_verify" +path = "bin/zkcuda_integration/verify.rs" + +[[bin]] +name = "zkcuda_cleanup" +path = "bin/zkcuda_integration/cleanup.rs" diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs similarity index 82% rename from expander_compiler/bin/zkcuda_matmul.rs rename to expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs index 75bd7b4a..17425403 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul.rs @@ -4,13 +4,14 @@ use expander_compiler::frontend::{ BN254Config, BasicAPI, CircuitField, Config, Error, FieldArith, Variable, API, }; -use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; +// use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; +use expander_compiler::zkcuda::proving_system::{ParallelizedExpander, ProvingSystem}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{ context::{call_kernel, Context}, kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, }; -use gkr::BN254ConfigSha2Hyrax; const M: usize = 512; const K: usize = 512; @@ -80,7 +81,7 @@ pub fn zkcuda_matmul, const N: usize>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); let elapsed = timer.elapsed(); println!("Parallel Count {N}, Proving time: {elapsed:?}"); @@ -93,10 +94,10 @@ pub fn zkcuda_matmul, const N: usize>() { } fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); + // zkcuda_matmul::, 4>(); + // zkcuda_matmul::, 8>(); + // zkcuda_matmul::, 16>(); + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); } diff --git a/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs new file mode 100644 index 00000000..df3d19aa --- /dev/null +++ b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_no_oversubscribe.rs @@ -0,0 +1,18 @@ +#![allow(unused)] +mod zkcuda_matmul; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, expander_pcs_defered::BN254ConfigSha2UniKZG, + ExpanderNoOverSubscribe, + }, +}; +use zkcuda_matmul::zkcuda_matmul; + +fn main() { + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 4>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 8>(); + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 16>(); + + zkcuda_matmul::<_, ExpanderNoOverSubscribe, 1024>(); +} diff --git a/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs new file mode 100644 index 00000000..0970798a --- /dev/null +++ b/expander_compiler/bin/zkcuda_bench/zkcuda_matmul_pcs_defered.rs @@ -0,0 +1,18 @@ +#![allow(unused)] +mod zkcuda_matmul; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{expander_pcs_defered::BN254ConfigSha2UniKZG, ExpanderPCSDefered}, +}; +use gkr::BN254ConfigSha2Hyrax; +use zkcuda_matmul::zkcuda_matmul; + +fn main() { + // zkcuda_matmul::, 4>(); + // zkcuda_matmul::, 8>(); + // zkcuda_matmul::, 16>(); + + zkcuda_matmul::, 4>(); + zkcuda_matmul::, 8>(); + zkcuda_matmul::, 16>(); +} diff --git a/expander_compiler/bin/zkcuda_integration/circuit_def.rs b/expander_compiler/bin/zkcuda_integration/circuit_def.rs new file mode 100644 index 00000000..6c54e382 --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/circuit_def.rs @@ -0,0 +1,71 @@ +#![allow(clippy::ptr_arg)] +#![allow(clippy::needless_range_loop)] + +use expander_compiler::frontend::{ + BasicAPI, CircuitField, Config, Error, SIMDField, Variable, API, +}; +use expander_compiler::zkcuda::shape::Reshape; +use expander_compiler::zkcuda::{ + context::{call_kernel, ComputationGraph, Context, DeviceMemoryHandle}, + kernel::{compile_with_spec_and_shapes, kernel, IOVecSpec, KernelPrimitive}, +}; + +#[kernel] +fn add_2_macro(api: &mut API, a: &[InputVariable; 2], b: &mut OutputVariable) { + *b = api.add(a[0], a[1]); +} + +#[kernel] +fn add_16_macro(api: &mut API, a: &[InputVariable; 16], b: &mut OutputVariable) { + let mut sum = api.constant(0); + for i in 0..16 { + sum = api.add(sum, a[i]); + } + *b = sum; +} + +#[allow(clippy::type_complexity)] +pub fn gen_computation_graph_and_witness( + input: Option>>>, +) -> (ComputationGraph, Option>>>) { + let kernel_add_2: KernelPrimitive = compile_add_2_macro().unwrap(); + let kernel_add_16: KernelPrimitive = compile_add_16_macro().unwrap(); + + let mut ctx: Context = Context::default(); + let a = if let Some(input) = input.as_ref() { + assert_eq!(input.len(), 16); + assert!(input.iter().all(|v| v.len() == 2)); + input.clone() + } else { + let mut tmp = vec![vec![]; 16]; + for i in 0..16 { + for j in 0..2 { + tmp[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + tmp + }; + + let expected_result = a.iter().flatten().sum::>(); + + let a = ctx.copy_to_device(&a); + let mut b: DeviceMemoryHandle = None; + call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); + let b = b.reshape(&[1, 16]); + let mut c: DeviceMemoryHandle = None; + call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); + let c = c.reshape(&[]); + let result: CircuitField = ctx.copy_to_host(c); + assert_eq!(result, expected_result); + + let computation_graph = ctx.compile_computation_graph().unwrap(); + + let extended_witness = if input.is_some() { + ctx.solve_witness().unwrap(); + Some(ctx.export_device_memories()) + } else { + None + }; + + (computation_graph, extended_witness) +} diff --git a/expander_compiler/bin/zkcuda_integration/cleanup.rs b/expander_compiler/bin/zkcuda_integration/cleanup.rs new file mode 100644 index 00000000..c989c36c --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/cleanup.rs @@ -0,0 +1,10 @@ +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; + +fn main() { + as ProvingSystem>::post_process(); +} diff --git a/expander_compiler/bin/zkcuda_integration/prove.rs b/expander_compiler/bin/zkcuda_integration/prove.rs new file mode 100644 index 00000000..0e77c7a2 --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/prove.rs @@ -0,0 +1,41 @@ +mod circuit_def; +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::{BN254Config, CircuitField}, + zkcuda::{ + context::ComputationGraph, + proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, + }, +}; +use serdes::ExpSerde; + +#[allow(clippy::needless_range_loop)] +fn main() { + // Replace this with your actual input data. + let mut input = vec![vec![]; 16]; + for i in 0..16 { + for j in 0..2 { + input[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + + let (_, extended_witness) = gen_computation_graph_and_witness::(Some(input)); + + // Note: we've saved the computation graph and setup in the server. In order to generate a proof, we only need to submit the witness. + let dummy_prover_setup = as ProvingSystem< + BN254Config, + >>::ProverSetup::default(); + let dummy_computation_graph = ComputationGraph::::default(); + + let proof = ExpanderNoOverSubscribe::::prove( + &dummy_prover_setup, + &dummy_computation_graph, + extended_witness.unwrap(), + ); + + let mut bytes = vec![]; + proof.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/proof.bin", &bytes).unwrap(); +} diff --git a/expander_compiler/bin/zkcuda_integration/run.sh b/expander_compiler/bin/zkcuda_integration/run.sh new file mode 100755 index 00000000..206119cd --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/run.sh @@ -0,0 +1,17 @@ + +#!/bin/bash +cargo build --release --bin zkcuda_setup --bin zkcuda_prove --bin zkcuda_verify --bin zkcuda_cleanup + +# setup the server +cargo run --release --bin zkcuda_setup + +# prove a first instance +cargo run --release --bin zkcuda_prove +cargo run --release --bin zkcuda_verify + +# prove a second instance +cargo run --release --bin zkcuda_prove +cargo run --release --bin zkcuda_verify + +# shutdown the server +cargo run --release --bin zkcuda_cleanup diff --git a/expander_compiler/bin/zkcuda_integration/setup.rs b/expander_compiler/bin/zkcuda_integration/setup.rs new file mode 100644 index 00000000..bc94515b --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/setup.rs @@ -0,0 +1,23 @@ +mod circuit_def; +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; +use serdes::ExpSerde; + +fn main() { + let (computation_graph, _) = gen_computation_graph_and_witness::(None); + let (prover_setup, verifier_setup) = + ExpanderNoOverSubscribe::::setup(&computation_graph); + + let mut bytes = vec![]; + prover_setup.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/prover_setup.bin", &bytes).unwrap(); + + bytes.clear(); + verifier_setup.serialize_into(&mut bytes).unwrap(); + std::fs::write("/tmp/verifier_setup.bin", &bytes).unwrap(); +} diff --git a/expander_compiler/bin/zkcuda_integration/verify.rs b/expander_compiler/bin/zkcuda_integration/verify.rs new file mode 100644 index 00000000..73c93656 --- /dev/null +++ b/expander_compiler/bin/zkcuda_integration/verify.rs @@ -0,0 +1,32 @@ +mod circuit_def; +use std::io::Cursor; + +use circuit_def::gen_computation_graph_and_witness; +use expander_compiler::{ + frontend::BN254Config, + zkcuda::proving_system::{ + expander::config::ZKCudaBN254Hyrax, ExpanderNoOverSubscribe, ProvingSystem, + }, +}; +use serdes::ExpSerde; + +fn main() { + let (computation_graph, _) = gen_computation_graph_and_witness::(None); + + let verifier_setup_bytes = std::fs::read("/tmp/verifier_setup.bin").unwrap(); + let verifier_setup = as ProvingSystem< + BN254Config, + >>::VerifierSetup::deserialize_from(Cursor::new(verifier_setup_bytes)) + .unwrap(); + + let proof_bytes = std::fs::read("/tmp/proof.bin").unwrap(); + let proof = as ProvingSystem>::Proof::deserialize_from(Cursor::new(proof_bytes)).unwrap(); + + let verified = + as ProvingSystem>::verify( + &verifier_setup, + &computation_graph, + &proof, + ); + assert!(verified, "Proof verification failed"); +} diff --git a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs b/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs deleted file mode 100644 index fa8e17be..00000000 --- a/expander_compiler/bin/zkcuda_matmul_pcs_defered.rs +++ /dev/null @@ -1,11 +0,0 @@ -#![allow(unused)] -mod zkcuda_matmul; -use expander_compiler::{frontend::BN254Config, zkcuda::proving_system::ExpanderPCSDefered}; -use gkr::BN254ConfigSha2Hyrax; -use zkcuda_matmul::zkcuda_matmul; - -fn main() { - zkcuda_matmul::, 4>(); - zkcuda_matmul::, 8>(); - zkcuda_matmul::, 16>(); -} diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs index a0a0b459..809528c3 100644 --- a/expander_compiler/ec_go_lib/src/proving.rs +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -1,12 +1,10 @@ use std::ptr; use std::slice; -use arith::SimdField; use expander_binary::executor; use expander_compiler::frontend::ChallengeField; use expander_compiler::frontend::SIMDField; -use gkr_engine::FieldEngine; use libc::{c_uchar, c_ulong, malloc}; use expander_compiler::circuit::config; @@ -20,11 +18,7 @@ use super::{match_config_id, ByteArray, Config}; fn prove_circuit_file_inner( circuit_filename: &str, witness: &[u8], -) -> Result, String> -where - C::FieldConfig: FieldEngine, - C::PCSField: SimdField::CircuitField>, -{ +) -> Result, String> { // (None, None) means single core execution let mpi_config = MPIConfig::prover_new(None, None); diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index f59b7311..558d5ab8 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -72,8 +72,8 @@ impl Circuit { pub fn export_to_expander_flatten(&self) -> expander_circuit::Circuit { let circuit = self.export_to_expander::(); - let mut flattened = circuit.flatten::(); - flattened.pre_process_gkr::(); + let mut flattened = circuit.flatten(); + flattened.pre_process_gkr(); flattened } } diff --git a/expander_compiler/src/utils/misc.rs b/expander_compiler/src/utils/misc.rs index af06365c..58cb1031 100644 --- a/expander_compiler/src/utils/misc.rs +++ b/expander_compiler/src/utils/misc.rs @@ -1,5 +1,16 @@ use std::collections::{HashMap, HashSet}; +pub fn prev_power_of_two(x: usize) -> usize { + if x == 0 { + return 0; + } + let mut padk: usize = 0; + while (1 << padk) <= x { + padk += 1; + } + 1 << (padk - 1) +} + pub fn next_power_of_two(x: usize) -> usize { let mut padk: usize = 0; while (1 << padk) < x { diff --git a/expander_compiler/src/zkcuda/proving_system.rs b/expander_compiler/src/zkcuda/proving_system.rs index a5f3fce5..20e627f0 100644 --- a/expander_compiler/src/zkcuda/proving_system.rs +++ b/expander_compiler/src/zkcuda/proving_system.rs @@ -18,3 +18,6 @@ pub use expander_parallelized::api_parallel::*; pub mod expander_pcs_defered; pub use expander_pcs_defered::api_pcs_defered::*; + +pub mod expander_no_oversubscribe; +pub use expander_no_oversubscribe::api_no_oversubscribe::*; diff --git a/expander_compiler/src/zkcuda/proving_system/dummy.rs b/expander_compiler/src/zkcuda/proving_system/dummy.rs index 56f2717a..b18beb6b 100644 --- a/expander_compiler/src/zkcuda/proving_system/dummy.rs +++ b/expander_compiler/src/zkcuda/proving_system/dummy.rs @@ -146,7 +146,7 @@ impl ProvingSystem for DummyProvingSystem { fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { let (commitments, states) = device_memories .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander.rs b/expander_compiler/src/zkcuda/proving_system/expander.rs index d75cdb6f..ec444713 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander.rs @@ -1,6 +1,7 @@ pub mod api_single_thread; pub mod commit_impl; +pub mod config; pub mod prove_impl; pub mod setup_impl; pub mod structs; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs index 0bb500de..836201ac 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/api_single_thread.rs @@ -1,4 +1,3 @@ -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::io::Cursor; use crate::circuit::config::Config; @@ -35,13 +34,12 @@ impl KernelWiseProvingSystem for Expander where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = ExpanderProof; - type Commitment = ExpanderCommitment; - type CommitmentState = ExpanderCommitmentState; + type Commitment = ExpanderCommitment; + type CommitmentState = ExpanderCommitmentState; fn setup( computation_graph: &crate::zkcuda::context::ComputationGraph, @@ -53,7 +51,7 @@ where prover_setup: &Self::ProverSetup, vals: &[SIMDField], ) -> (Self::Commitment, Self::CommitmentState) { - local_commit_impl::(prover_setup, vals) + local_commit_impl::(prover_setup.p_keys.get(&vals.len()).unwrap(), vals) } fn prove_kernel( @@ -69,7 +67,7 @@ where check_inputs(kernel, commitments_values, parallel_count, is_broadcast); let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, 1); + prepare_expander_circuit::(kernel, 1); let mut proof = ExpanderProof { data: vec![] }; @@ -83,7 +81,7 @@ where parallel_index, parallel_count, ); - let challenge = prove_gkr_with_local_vals::( + let challenge = prove_gkr_with_local_vals::( &mut expander_circuit, &mut prover_scratch, &local_vals, @@ -118,8 +116,7 @@ where is_broadcast: &[bool], ) -> bool { let timer = Timer::new("verify", true); - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); for i in 0..parallel_count { let mut transcript = C::TranscriptConfig::new(); @@ -173,8 +170,6 @@ where // In this case, generate the implementation with a procedural macro seems to be the best solution. impl> ProvingSystem for Expander -where - C::FieldConfig: FieldEngine, { type ProverSetup = >::ProverSetup; type VerifierSetup = >::VerifierSetup; @@ -189,7 +184,7 @@ where fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { let (commitments, states) = device_memories .iter() @@ -238,8 +233,8 @@ where ) -> bool { let verified = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() diff --git a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs index 2b827083..400296ba 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/commit_impl.rs @@ -1,37 +1,34 @@ use expander_utils::timer::Timer; -use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{ExpanderPCS, GKREngine, MPIConfig, StructuredReferenceString}; use polynomials::RefMultiLinearPoly; -use super::structs::ExpanderProverSetup; use crate::{ frontend::{Config, SIMDField}, zkcuda::proving_system::expander::structs::{ExpanderCommitment, ExpanderCommitmentState}, }; pub fn local_commit_impl( - prover_setup: &ExpanderProverSetup, + p_key: &<>::SRS as StructuredReferenceString>::PKey, vals: &[SIMDField], ) -> ( - ExpanderCommitment, - ExpanderCommitmentState, + ExpanderCommitment, + ExpanderCommitmentState, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("commit", true); let n_vars = vals.len().ilog2() as usize; - let params = >::gen_params(n_vars, 1); - let p_key = prover_setup.p_keys.get(&vals.len()).unwrap(); + let params = >::gen_params(n_vars, 1); - let mut scratch = >::init_scratch_pad( + let mut scratch = >::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ); - let commitment = >::commit( + let commitment = >::commit( ¶ms, &MPIConfig::prover_new(None, None), p_key, diff --git a/expander_compiler/src/zkcuda/proving_system/expander/config.rs b/expander_compiler/src/zkcuda/proving_system/expander/config.rs new file mode 100644 index 00000000..75c78557 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander/config.rs @@ -0,0 +1,59 @@ +use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2Raw, M31x16ConfigSha2RawVanilla}; +use gkr_engine::GKREngine; + +use crate::{ + frontend::{BN254Config, Config, M31Config}, + zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG, +}; + +pub trait ZKCudaConfig { + type ECCConfig: Config; + type GKRConfig: GKREngine::FieldConfig>; + + const BATCH_PCS: bool = false; +} + +pub type GetPCS = <::GKRConfig as GKREngine>::PCSConfig; +pub type GetTranscript = + <::GKRConfig as GKREngine>::TranscriptConfig; +pub type GetFieldConfig = + <::GKRConfig as GKREngine>::FieldConfig; + +pub struct ZKCudaConfigImpl +where + ECC: Config, + GKR: GKREngine::FieldConfig>, +{ + _phantom: std::marker::PhantomData<(ECC, GKR, bool)>, +} + +impl ZKCudaConfig for ZKCudaConfigImpl +where + ECC: Config, + GKR: GKREngine::FieldConfig>, +{ + type ECCConfig = ECC; + type GKRConfig = GKR; + + const BATCH_PCS: bool = BATCH_PCS; +} + +// Concrete ZKCudaConfig types for various configurations +pub type ZKCudaBN254Hyrax<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaBN254KZG<'a> = ZKCudaConfigImpl, false>; + +pub type ZKCudaM31<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaGF2<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaGoldilocks<'a> = ZKCudaConfigImpl, false>; +pub type ZKCudaBabyBear<'a> = ZKCudaConfigImpl, false>; + +// Batch PCS types +pub type ZKCudaBN254HyraxBatchPCS<'a> = + ZKCudaConfigImpl, true>; +pub type ZKCudaBN254KZGBatchPCS<'a> = + ZKCudaConfigImpl, true>; + +pub type ZKCudaM31BatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaGF2BatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaGoldilocksBatchPCS<'a> = ZKCudaConfigImpl, true>; +pub type ZKCudaBabyBearBatchPCS<'a> = ZKCudaConfigImpl, true>; diff --git a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs index 3530abae..ad7b3121 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/prove_impl.rs @@ -19,23 +19,21 @@ use crate::{ /// ECCCircuit -> ExpanderCircuit /// Returns an additional prover scratch pad for later use in GKR. -pub fn prepare_expander_circuit( +pub fn prepare_expander_circuit( kernel: &Kernel, mpi_world_size: usize, -) -> (Circuit, ProverScratchPad) +) -> (Circuit, ProverScratchPad) where - C: GKREngine, - ECCConfig: Config, - C::FieldConfig: FieldEngine, + F: FieldEngine, + ECCConfig: Config, + ECCConfig::FieldConfig: FieldEngine, { - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten(); + expander_circuit.pre_process_gkr(); + let (max_num_input_var, max_num_output_var) = super::utils::max_n_vars(&expander_circuit); - let prover_scratch = ProverScratchPad::::new( - max_num_input_var, - max_num_output_var, - mpi_world_size, - ); + let prover_scratch = + ProverScratchPad::::new(max_num_input_var, max_num_output_var, mpi_world_size); (expander_circuit, prover_scratch) } @@ -84,17 +82,14 @@ pub fn prepare_inputs_with_local_vals( input_vals } -pub fn prove_gkr_with_local_vals( - expander_circuit: &mut Circuit, - prover_scratch: &mut ProverScratchPad, - local_commitment_values: &[impl AsRef<[::SimdCircuitField]>], +pub fn prove_gkr_with_local_vals( + expander_circuit: &mut Circuit, + prover_scratch: &mut ProverScratchPad, + local_commitment_values: &[impl AsRef<[F::SimdCircuitField]>], partition_info: &[LayeredCircuitInputVec], - transcript: &mut C::TranscriptConfig, + transcript: &mut T, mpi_config: &MPIConfig, -) -> ExpanderDualVarChallenge -where - C::FieldConfig: FieldEngine, -{ +) -> ExpanderDualVarChallenge { expander_circuit.layers[0].input_vals = prepare_inputs_with_local_vals( 1 << expander_circuit.log_input_size(), partition_info, @@ -104,10 +99,7 @@ where expander_circuit.evaluate(); let (claimed_v, challenge) = gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); - assert_eq!( - claimed_v, - ::ChallengeField::from(0) - ); + assert_eq!(claimed_v, F::ChallengeField::from(0)); challenge } @@ -169,18 +161,14 @@ pub fn partition_challenge_and_location_for_pcs_no_mpi( pub fn pcs_local_open_impl( vals: &[::SimdCircuitField], challenge: &ExpanderSingleVarChallenge, - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { assert_eq!(challenge.r_mpi.len(), 0); let val_len = vals.len(); - let params = >::gen_params( - val_len.ilog2() as usize, - 1, - ); + let params = + >::gen_params(val_len.ilog2() as usize, 1); let p_key = p_keys.p_keys.get(&val_len).unwrap(); let poly = RefMultiLinearPoly::from_ref(vals); @@ -191,14 +179,14 @@ pub fn pcs_local_open_impl( transcript.append_field_element(&v); transcript.lock_proof(); - let opening = >::open( + let opening = >::open( ¶ms, &MPIConfig::prover_new(None, None), p_key, &poly, challenge, transcript, - &>::init_scratch_pad( + &>::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ), @@ -217,14 +205,12 @@ pub fn pcs_local_open_impl( pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( gkr_claim: &ExpanderSingleVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, is_broadcast: &[bool], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { for (commitment_val, ib) in global_vals.iter().zip(is_broadcast) { let val_len = commitment_val.as_ref().len(); let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_no_mpi::< @@ -248,14 +234,12 @@ pub fn partition_gkr_claims_and_open_pcs_no_mpi_impl( pub fn partition_gkr_claims_and_open_pcs_no_mpi( gkr_claim: &ExpanderDualVarChallenge, global_vals: &[impl AsRef<[::SimdCircuitField]>], - p_keys: &ExpanderProverSetup, + p_keys: &ExpanderProverSetup, is_broadcast: &[bool], parallel_index: usize, parallel_num: usize, transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { let challenges = if let Some(challenge_y) = gkr_claim.challenge_y() { vec![gkr_claim.challenge_x(), challenge_y] } else { diff --git a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs index 7b087004..e761b6cb 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/setup_impl.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{GKREngine, MPIConfig}; use crate::{ frontend::Config, @@ -16,13 +16,12 @@ use crate::{ pub fn local_setup_impl( computation_graph: &ComputationGraph, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut p_keys = HashMap::new(); let mut v_keys = HashMap::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs index 07c47480..1a81a5e1 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/structs.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/structs.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use arith::Field; use gkr_engine::{ExpanderPCS, FieldEngine, Proof as BytesProof, StructuredReferenceString}; use serdes::ExpSerde; @@ -9,14 +8,12 @@ use crate::{frontend::Config, zkcuda::proving_system::Commitment}; /// A wrapper for the PCS Commitment that includes the length of the values committed to. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderCommitment> { +pub struct ExpanderCommitment> { pub vals_len: usize, pub commitment: PCS::Commitment, } -impl> Clone - for ExpanderCommitment -{ +impl> Clone for ExpanderCommitment { fn clone(&self) -> Self { Self { vals_len: self.vals_len, @@ -25,11 +22,8 @@ impl> Clone } } -impl< - F: FieldEngine, - PCS: ExpanderPCS, - ECCConfig: Config, - > Commitment for ExpanderCommitment +impl, ECCConfig: Config> Commitment + for ExpanderCommitment { fn vals_len(&self) -> usize { self.vals_len @@ -40,13 +34,11 @@ impl< /// For Raw, KZG, and Hyrax, this is not needed, so the scratchpad can be empty. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderCommitmentState> { +pub struct ExpanderCommitmentState> { pub scratch: PCS::ScratchPad, } -impl> Clone - for ExpanderCommitmentState -{ +impl> Clone for ExpanderCommitmentState { fn clone(&self) -> Self { Self { scratch: self.scratch.clone(), @@ -58,13 +50,11 @@ impl> Clone /// The keys are indexed by the length of values committed to, allowing for different setups based on the length of the values. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderProverSetup> { +pub struct ExpanderProverSetup> { pub p_keys: HashMap::PKey>, } -impl> Default - for ExpanderProverSetup -{ +impl> Default for ExpanderProverSetup { fn default() -> Self { Self { p_keys: HashMap::new(), @@ -72,9 +62,7 @@ impl> Default } } -impl> Clone - for ExpanderProverSetup -{ +impl> Clone for ExpanderProverSetup { fn clone(&self) -> Self { Self { p_keys: self.p_keys.clone(), @@ -86,14 +74,12 @@ impl> Clone /// The keys are indexed by the length of values committed to, allowing for different setups based on the length of the values. #[allow(clippy::type_complexity)] #[derive(ExpSerde)] -pub struct ExpanderVerifierSetup> { +pub struct ExpanderVerifierSetup> { pub v_keys: HashMap::VKey>, } // implement default -impl> Default - for ExpanderVerifierSetup -{ +impl> Default for ExpanderVerifierSetup { fn default() -> Self { Self { v_keys: HashMap::new(), @@ -101,9 +87,7 @@ impl> Default } } -impl> Clone - for ExpanderVerifierSetup -{ +impl> Clone for ExpanderVerifierSetup { fn clone(&self) -> Self { Self { v_keys: self.v_keys.clone(), diff --git a/expander_compiler/src/zkcuda/proving_system/expander/utils.rs b/expander_compiler/src/zkcuda/proving_system/expander/utils.rs index a525eab8..0273a55d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/utils.rs @@ -3,12 +3,7 @@ use gkr_engine::{ExpanderPCS, FieldEngine, MPIConfig, StructuredReferenceString, use poly_commit::expander_pcs_init_testing_only; #[allow(clippy::type_complexity)] -pub fn pcs_testing_setup_fixed_seed< - 'a, - F: FieldEngine, - T: Transcript, - PCS: ExpanderPCS, ->( +pub fn pcs_testing_setup_fixed_seed<'a, F: FieldEngine, T: Transcript, PCS: ExpanderPCS>( vals_len: usize, mpi_config: &MPIConfig<'a>, ) -> ( @@ -17,10 +12,7 @@ pub fn pcs_testing_setup_fixed_seed< ::VKey, PCS::ScratchPad, ) { - expander_pcs_init_testing_only::( - vals_len.ilog2() as usize, - mpi_config, - ) + expander_pcs_init_testing_only::(vals_len.ilog2() as usize, mpi_config) } pub fn max_n_vars(circuit: &ExpanderCircuit) -> (usize, usize) { diff --git a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs index 8aed5cbf..d5c40f2a 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander/verify_impl.rs @@ -24,35 +24,31 @@ use crate::{ pub fn verify_pcs( mut proof_reader: impl Read, - commitment: &ExpanderCommitment, + commitment: &ExpanderCommitment, challenge: &ExpanderSingleVarChallenge, claim: &::ChallengeField, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, transcript: &mut C::TranscriptConfig, ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { - let val_len = as Commitment< - ECCConfig, - >>::vals_len(commitment); + let val_len = + as Commitment>::vals_len( + commitment, + ); - let params = >::gen_params( - val_len.ilog2() as usize, - 1, - ); + let params = + >::gen_params(val_len.ilog2() as usize, 1); let v_key = v_keys.v_keys.get(&val_len).unwrap(); let opening = - >::Opening::deserialize_from( - &mut proof_reader, - ) - .unwrap(); + >::Opening::deserialize_from(&mut proof_reader) + .unwrap(); transcript.lock_proof(); - let verified = >::verify( + let verified = >::verify( ¶ms, v_key, &commitment.commitment, @@ -76,10 +72,10 @@ where pub fn verify_pcs_opening_and_aggregation_no_mpi_impl( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_index: usize, parallel_count: usize, @@ -88,7 +84,6 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi_impl( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut target_y = ::ChallengeField::ZERO; for ((input, commitment), ib) in kernel @@ -98,9 +93,9 @@ where .zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, component_idx_vars) = partition_challenge_and_location_for_pcs_no_mpi( challenge, @@ -144,11 +139,11 @@ where pub fn verify_pcs_opening_and_aggregation_no_mpi( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderDualVarChallenge, claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_index: usize, parallel_count: usize, @@ -157,7 +152,6 @@ pub fn verify_pcs_opening_and_aggregation_no_mpi( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let challenges = if let Some(challenge_y) = challenge.challenge_y() { vec![challenge.challenge_x(), challenge_y] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs new file mode 100644 index 00000000..65954def --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe.rs @@ -0,0 +1,4 @@ +pub mod api_no_oversubscribe; +pub mod profiler; +pub mod prove_impl; +pub mod server_fn; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs new file mode 100644 index 00000000..7d7fed98 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/api_no_oversubscribe.rs @@ -0,0 +1,75 @@ +use crate::frontend::SIMDField; +use crate::zkcuda::context::ComputationGraph; +use crate::zkcuda::proving_system::expander::config::{GetFieldConfig, GetPCS, ZKCudaConfig}; +use crate::zkcuda::proving_system::expander::structs::{ + ExpanderProverSetup, ExpanderVerifierSetup, +}; +use crate::zkcuda::proving_system::expander_parallelized::client_utils::{ + client_launch_server_and_setup, client_parse_args, client_send_witness_and_prove, wait_async, + ClientHttpHelper, +}; +use crate::zkcuda::proving_system::{ + CombinedProof, ExpanderPCSDefered, ParallelizedExpander, ProvingSystem, +}; + +use super::super::Expander; + +use gkr_engine::ExpanderPCS; + +pub struct ExpanderNoOverSubscribe { + _config: std::marker::PhantomData, +} + +impl ProvingSystem for ExpanderNoOverSubscribe +where + as ExpanderPCS>>::Commitment: + AsRef< as ExpanderPCS>>::Commitment>, +{ + type ProverSetup = ExpanderProverSetup, GetPCS>; + type VerifierSetup = ExpanderVerifierSetup, GetPCS>; + type Proof = CombinedProof>; + + fn setup( + computation_graph: &ComputationGraph, + ) -> (Self::ProverSetup, Self::VerifierSetup) { + let server_binary = client_parse_args() + .unwrap_or("../target/release/expander_server_no_oversubscribe".to_owned()); + client_launch_server_and_setup::( + &server_binary, + computation_graph, + false, + ZC::BATCH_PCS, + ) + } + + fn prove( + _prover_setup: &Self::ProverSetup, + _computation_graph: &ComputationGraph, + device_memories: Vec>>, + ) -> Self::Proof { + client_send_witness_and_prove(device_memories) + } + + fn verify( + verifier_setup: &Self::VerifierSetup, + computation_graph: &ComputationGraph, + proof: &Self::Proof, + ) -> bool { + match ZC::BATCH_PCS { + true => ExpanderPCSDefered::::verify( + verifier_setup, + computation_graph, + proof, + ), + false => ParallelizedExpander::::verify( + verifier_setup, + computation_graph, + proof, + ), + } + } + + fn post_process() { + wait_async(ClientHttpHelper::request_exit()) + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs new file mode 100644 index 00000000..ed9421cc --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/profiler.rs @@ -0,0 +1,78 @@ +#[cfg(feature = "zkcuda_profile")] +mod profiler_enabled { + use std::collections::HashMap; + + use arith::Fr; + use halo2curves::ff::PrimeField; + + #[derive(Clone, Debug, Default)] + pub struct NBytesProfiler { + pub bytes_stats: HashMap, + } + + impl NBytesProfiler { + pub fn new() -> Self { + NBytesProfiler { + bytes_stats: HashMap::new(), + } + } + + pub fn add_bytes(&mut self, n_bytes: usize) { + *self.bytes_stats.entry(n_bytes).or_insert(0) += 1; + } + + pub fn add_fr(&mut self, fr: Fr) { + let le_bytes = fr.to_repr(); + let be_leading_zeros_bytes = le_bytes.into_iter().rev().take_while(|&b| b == 0).count(); + let n_bytes = le_bytes.len() - be_leading_zeros_bytes; + self.add_bytes(n_bytes); + } + + pub fn print_stats(&self) { + for (bytes, count) in &self.bytes_stats { + println!("{bytes} bytes: {count}"); + } + } + } +} + +#[cfg(not(feature = "zkcuda_profile"))] +mod profiler_disabled { + use arith::Fr; + + #[derive(Clone, Debug, Default)] + pub struct NBytesProfiler; + + impl NBytesProfiler { + pub fn new() -> Self { + NBytesProfiler + } + + pub fn add_bytes(&mut self, _n_bytes: usize) {} + + pub fn add_fr(&mut self, _fr: Fr) {} + + pub fn print_stats(&self) {} + } +} + +#[cfg(not(feature = "zkcuda_profile"))] +pub use profiler_disabled::NBytesProfiler; +#[cfg(feature = "zkcuda_profile")] +pub use profiler_enabled::NBytesProfiler; + +#[cfg(feature = "zkcuda_profile")] +mod test { + #![allow(unused_imports)] + use super::NBytesProfiler; + use arith::Fr; + + #[test] + fn test_n_bytes_profiler() { + let mut profiler = NBytesProfiler::new(); + profiler.add_bytes(32); + profiler.add_bytes(64); + profiler.add_fr(Fr::from(256u64)); + profiler.print_stats(); + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs new file mode 100644 index 00000000..bc980372 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/prove_impl.rs @@ -0,0 +1,453 @@ +use arith::{Field, Fr, SimdField}; +use expander_utils::timer::Timer; +use gkr_engine::{ + BN254ConfigXN, ExpanderDualVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, + Transcript, +}; + +use crate::{ + frontend::{Config, SIMDField}, + utils::misc::next_power_of_two, + zkcuda::{ + context::ComputationGraph, + kernel::{Kernel, LayeredCircuitInputVec}, + proving_system::{ + expander::{ + commit_impl::local_commit_impl, + config::{GetFieldConfig, GetPCS, GetTranscript, ZKCudaConfig}, + prove_impl::{ + get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals, + }, + structs::{ExpanderProof, ExpanderProverSetup}, + }, + expander_no_oversubscribe::profiler::NBytesProfiler, + expander_parallelized::{ + prove_impl::partition_single_gkr_claim_and_open_pcs_mpi, + server_ctrl::generate_local_mpi_config, + }, + expander_pcs_defered::prove_impl::{ + extract_pcs_claims, max_len_setup_commit_impl, open_defered_pcs, + }, + CombinedProof, Expander, + }, + }, +}; + +pub fn mpi_prove_no_oversubscribe_impl( + global_mpi_config: &MPIConfig<'static>, + prover_setup: &ExpanderProverSetup, GetPCS>, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], + n_bytes_profiler: &mut NBytesProfiler, +) -> Option>> +where + ::FieldConfig: FieldEngine, +{ + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); + let (commitments, states) = if global_mpi_config.is_root() { + let (commitments, states) = values + .iter() + .map(|value| match ZC::BATCH_PCS { + true => max_len_setup_commit_impl::( + prover_setup, + value.as_ref(), + ), + false => local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ), + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + (Some(commitments), Some(states)) + } else { + (None, None) + }; + commit_timer.stop(); + + let mut vals_ref = vec![]; + let mut challenges = vec![]; + + let prove_timer = Timer::new("Prove all kernels", global_mpi_config.is_root()); + let proofs = + computation_graph + .proof_templates() + .iter() + .map(|template| { + let commitment_values = template + .commitment_indices() + .iter() + .map(|&idx| values[idx].as_ref()) + .collect::>(); + + let single_kernel_gkr_timer = + Timer::new("small gkr kernel", global_mpi_config.is_root()); + let gkr_end_state = prove_kernel_gkr_no_oversubscribe::< + GetFieldConfig, + GetTranscript, + ZC::ECCConfig, + >( + global_mpi_config, + &computation_graph.kernels()[template.kernel_id()], + &commitment_values, + next_power_of_two(template.parallel_count()), + template.is_broadcast(), + n_bytes_profiler, + ); + single_kernel_gkr_timer.stop(); + + match ZC::BATCH_PCS { + true => { + if global_mpi_config.is_root() { + let (mut transcript, challenge) = gkr_end_state.unwrap(); + assert!(challenge.challenge_y().is_none()); + let challenge = challenge.challenge_x(); + + let (local_vals_ref, local_challenges) = + extract_pcs_claims::( + &commitment_values, + &challenge, + template.is_broadcast(), + next_power_of_two(template.parallel_count()), + ); + + vals_ref.extend(local_vals_ref); + challenges.extend(local_challenges); + + Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }) + } else { + None + } + } + false => { + if global_mpi_config.is_root() { + let pcs_open_timer = Timer::new("pcs open", true); + let (mut transcript, challenge) = gkr_end_state.unwrap(); + let challenges = if let Some(challenge_y) = challenge.challenge_y() { + vec![challenge.challenge_x(), challenge_y] + } else { + vec![challenge.challenge_x()] + }; + + challenges.iter().for_each(|c| { + partition_single_gkr_claim_and_open_pcs_mpi::( + prover_setup, + &commitment_values, + &template + .commitment_indices() + .iter() + .map(|&idx| &states.as_ref().unwrap()[idx]) + .collect::>(), + c, + template.is_broadcast(), + &mut transcript, + ); + }); + + pcs_open_timer.stop(); + Some(ExpanderProof { + data: vec![transcript.finalize_and_get_proof()], + }) + } else { + None + } + } + } + }) + .collect::>(); + prove_timer.stop(); + + match ZC::BATCH_PCS { + true => { + if global_mpi_config.is_root() { + let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); + let pcs_batch_opening = open_defered_pcs::( + prover_setup, + &vals_ref, + &challenges, + ); + pcs_opening_timer.stop(); + + proofs.push(pcs_batch_opening); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } + false => { + if global_mpi_config.is_root() { + let proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + Some(CombinedProof { + commitments: commitments.unwrap(), + proofs, + }) + } else { + None + } + } + } +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_kernel_gkr_no_oversubscribe( + mpi_config: &MPIConfig<'static>, + kernel: &Kernel, + commitments_values: &[&[F::SimdCircuitField]], + parallel_count: usize, + is_broadcast: &[bool], + n_bytes_profiler: &mut NBytesProfiler, +) -> Option<(T, ExpanderDualVarChallenge)> +where + F: FieldEngine, + T: Transcript, + ECCConfig: Config, +{ + let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); + + local_mpi_config.as_ref()?; + + let local_mpi_config = local_mpi_config.unwrap(); + let local_world_size = local_mpi_config.world_size(); + + let n_local_copies = parallel_count / local_world_size; + match n_local_copies { + 1 => prove_kernel_gkr_internal::( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 2 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 4 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 8 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 16 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 32 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 64 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 128 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 256 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 512 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 1024 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + 2048 => prove_kernel_gkr_internal::, T, ECCConfig>( + &local_mpi_config, + kernel, + commitments_values, + parallel_count, + is_broadcast, + n_bytes_profiler, + ), + _ => { + panic!("Unsupported parallel count: {parallel_count}"); + } + } +} + +pub fn prove_kernel_gkr_internal( + mpi_config: &MPIConfig<'static>, + kernel: &Kernel, + commitments_values: &[&[FBasic::SimdCircuitField]], + parallel_count: usize, + is_broadcast: &[bool], + n_bytes_profiler: &mut NBytesProfiler, +) -> Option<(T, ExpanderDualVarChallenge)> +where + FBasic: FieldEngine, + FMulti: + FieldEngine, + T: Transcript, + ECCConfig: Config, +{ + let world_rank = mpi_config.world_rank(); + let world_size = mpi_config.world_size(); + let n_copies = parallel_count / world_size; + + let local_commitment_values = get_local_vals_multi_copies( + commitments_values, + is_broadcast, + world_rank, + n_copies, + parallel_count, + ); + + let (mut expander_circuit, mut prover_scratch) = + prepare_expander_circuit::(kernel, world_size); + + let mut transcript = T::new(); + let challenge = prove_gkr_with_local_vals_multi_copies::( + &mut expander_circuit, + &mut prover_scratch, + &local_commitment_values, + kernel.layered_circuit_input(), + &mut transcript, + mpi_config, + n_bytes_profiler, + ); + + Some((transcript, challenge)) +} + +pub fn get_local_vals_multi_copies<'vals_life, F: Field>( + global_vals: &'vals_life [impl AsRef<[F]>], + is_broadcast: &[bool], + local_world_rank: usize, + n_copies: usize, + parallel_count: usize, +) -> Vec> { + let parallel_indices = (0..n_copies) + .map(|i| local_world_rank * n_copies + i) + .collect::>(); + + parallel_indices + .iter() + .map(|¶llel_index| { + get_local_vals(global_vals, is_broadcast, parallel_index, parallel_count) + }) + .collect::>() +} + +pub fn prove_gkr_with_local_vals_multi_copies( + expander_circuit: &mut expander_circuit::Circuit, + prover_scratch: &mut sumcheck::ProverScratchPad, + local_commitment_values_multi_copies: &[Vec>], + partition_info: &[LayeredCircuitInputVec], + transcript: &mut T, + mpi_config: &MPIConfig, + _n_bytes_profiler: &mut NBytesProfiler, +) -> ExpanderDualVarChallenge +where + FBasic: FieldEngine, + FMulti: + FieldEngine, + T: Transcript, +{ + let input_vals_multi_copies = local_commitment_values_multi_copies + .iter() + .map(|local_commitment_values| { + prepare_inputs_with_local_vals( + 1 << expander_circuit.log_input_size(), + partition_info, + local_commitment_values, + ) + }) + .collect::>(); + + let mut input_vals = + vec![FMulti::SimdCircuitField::ZERO; 1 << expander_circuit.log_input_size()]; + + for (i, vals) in input_vals.iter_mut().enumerate() { + let vals_unpacked = input_vals_multi_copies + .iter() + .flat_map(|v| v[i].unpack()) + .collect::>(); + *vals = FMulti::SimdCircuitField::pack(&vals_unpacked); + } + expander_circuit.layers[0].input_vals = input_vals; + + expander_circuit.fill_rnd_coefs(transcript); + expander_circuit.evaluate(); + + #[cfg(feature = "zkcuda_profile")] + { + expander_circuit.layers.iter().for_each(|layer| { + layer.input_vals.iter().for_each(|val| { + val.unpack().iter().for_each(|fr| { + _n_bytes_profiler.add_fr(*fr); + }) + }); + }); + } + + let (claimed_v, challenge) = + gkr::gkr_prove(expander_circuit, prover_scratch, transcript, mpi_config); + assert_eq!(claimed_v, FBasic::ChallengeField::from(0u32)); + + let n_simd_vars_basic = FBasic::SimdCircuitField::PACK_SIZE.ilog2() as usize; + + ExpanderDualVarChallenge { + rz_0: challenge.rz_0, + rz_1: challenge.rz_1, + r_simd: challenge.r_simd[..n_simd_vars_basic].to_vec(), + r_mpi: { + let mut v = challenge.r_simd[n_simd_vars_basic..].to_vec(); + v.extend(&challenge.r_mpi); + v + }, + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs new file mode 100644 index 00000000..5c402ac7 --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_bin.rs @@ -0,0 +1,54 @@ +use std::str::FromStr; + +use clap::Parser; +use expander_compiler::zkcuda::proving_system::{ + expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, + }, + expander_parallelized::server_ctrl::{serve, ExpanderExecArgs}, + ExpanderNoOverSubscribe, +}; +use gkr_engine::PolynomialCommitmentType; + +#[tokio::main] +pub async fn main() { + let expander_exec_args = ExpanderExecArgs::parse(); + assert_eq!( + expander_exec_args.fiat_shamir_hash, "SHA256", + "Only SHA256 is supported for now" + ); + + let pcs_type = PolynomialCommitmentType::from_str(&expander_exec_args.poly_commit).unwrap(); + + match (expander_exec_args.field_type.as_str(), pcs_type) { + ("BN254", PolynomialCommitmentType::Hyrax) => { + if expander_exec_args.batch_pcs { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } else { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } + } + ("BN254", PolynomialCommitmentType::KZG) => { + if expander_exec_args.batch_pcs { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } else { + serve::<_, _, ExpanderNoOverSubscribe>( + expander_exec_args.port_number, + ) + .await; + } + } + (field_type, pcs_type) => { + panic!("Combination of {field_type:?} and {pcs_type:?} not supported for no oversubscribe expander proving system."); + } + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs new file mode 100644 index 00000000..85c5955f --- /dev/null +++ b/expander_compiler/src/zkcuda/proving_system/expander_no_oversubscribe/server_fn.rs @@ -0,0 +1,101 @@ +use arith::Fr; +use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; + +use crate::{ + frontend::SIMDField, + zkcuda::{ + context::ComputationGraph, + proving_system::{ + expander::{ + config::{GetFieldConfig, GetPCS, ZKCudaConfig}, + structs::{ExpanderProverSetup, ExpanderVerifierSetup}, + }, + expander_no_oversubscribe::{ + profiler::NBytesProfiler, prove_impl::mpi_prove_no_oversubscribe_impl, + }, + expander_parallelized::{server_ctrl::SharedMemoryWINWrapper, server_fns::ServerFns}, + CombinedProof, Expander, ExpanderNoOverSubscribe, ExpanderPCSDefered, + ParallelizedExpander, + }, + }, +}; + +impl ServerFns for ExpanderNoOverSubscribe +where + ::FieldConfig: FieldEngine, +{ + fn setup_request_handler( + global_mpi_config: &MPIConfig<'static>, + setup_file: Option, + computation_graph: &mut ComputationGraph, + prover_setup: &mut ExpanderProverSetup, GetPCS>, + verifier_setup: &mut ExpanderVerifierSetup, GetPCS>, + mpi_win: &mut Option, + ) { + match ZC::BATCH_PCS { + true => ExpanderPCSDefered::::setup_request_handler( + global_mpi_config, + setup_file, + computation_graph, + prover_setup, + verifier_setup, + mpi_win, + ), + false => ParallelizedExpander::::setup_request_handler( + global_mpi_config, + setup_file, + computation_graph, + prover_setup, + verifier_setup, + mpi_win, + ), + } + } + + fn prove_request_handler( + global_mpi_config: &MPIConfig<'static>, + prover_setup: &ExpanderProverSetup, GetPCS>, + computation_graph: &ComputationGraph, + values: &[impl AsRef<[SIMDField]>], + ) -> Option>> { + let mut n_bytes_profiler = NBytesProfiler::new(); + + #[cfg(feature = "zkcuda_profile")] + { + use arith::SimdField; + use gkr_engine::MPIEngine; + + values.iter().for_each(|vals| { + vals.as_ref().iter().for_each(|fr| { + let fr_unpacked = fr.unpack(); + assert!(fr_unpacked.len() == 1); + n_bytes_profiler.add_fr(fr_unpacked[0]); + }); + }); + if global_mpi_config.is_root() { + println!("NBytesProfiler stats before proving:"); + n_bytes_profiler.print_stats(); + } + } + + let proof = mpi_prove_no_oversubscribe_impl::( + global_mpi_config, + prover_setup, + computation_graph, + values, + &mut n_bytes_profiler, + ); + + #[cfg(feature = "zkcuda_profile")] + { + use gkr_engine::MPIEngine; + + if global_mpi_config.is_root() { + println!("NBytesProfiler stats after proving:"); + n_bytes_profiler.print_stats(); + } + } + + proof + } +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs index cfd77e38..81005317 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/api_parallel.rs @@ -14,8 +14,8 @@ use crate::zkcuda::proving_system::{CombinedProof, ProvingSystem}; use super::super::Expander; -use gkr_engine::{FieldEngine, GKREngine}; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use expander_utils::timer::Timer; +use gkr_engine::GKREngine; pub struct ParallelizedExpander { _config: std::marker::PhantomData, @@ -23,11 +23,9 @@ pub struct ParallelizedExpander { impl> ProvingSystem for ParallelizedExpander -where - C::FieldConfig: FieldEngine, { - type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type ProverSetup = ExpanderProverSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = CombinedProof>; fn setup( @@ -35,13 +33,18 @@ where ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args().unwrap_or("../target/release/expander_server".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph) + client_launch_server_and_setup::( + &server_binary, + computation_graph, + true, + false, + ) } fn prove( _prover_setup: &Self::ProverSetup, _computation_graph: &crate::zkcuda::context::ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { client_send_witness_and_prove(device_memories) } @@ -51,10 +54,11 @@ where computation_graph: &ComputationGraph, proof: &Self::Proof, ) -> bool { + let verification_timer = Timer::new("Verify all kernels", true); let verified = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() @@ -72,6 +76,7 @@ where ) }) .collect::>(); + verification_timer.stop(); verified.iter().all(|x| *x) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs index e14aba1c..42315b39 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/client_utils.rs @@ -2,7 +2,7 @@ use std::fs; use crate::{ frontend::{Config, SIMDField}, - utils::misc::next_power_of_two, + utils::misc::{next_power_of_two, prev_power_of_two}, zkcuda::{ context::ComputationGraph, proving_system::{ @@ -19,7 +19,7 @@ use crate::{ use super::server_ctrl::{RequestType, SERVER_IP, SERVER_PORT}; use expander_utils::timer::Timer; -use gkr_engine::{FieldEngine, GKREngine}; +use gkr_engine::GKREngine; use reqwest::Client; use serdes::ExpSerde; @@ -77,20 +77,23 @@ pub fn client_parse_args() -> Option { pub fn client_launch_server_and_setup( server_binary: &str, computation_graph: &ComputationGraph, + allow_oversubscribe: bool, + batch_pcs: bool, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let setup_timer = Timer::new("setup", true); println!("Starting server with binary: {server_binary}"); let mut bytes = vec![]; computation_graph.serialize_into(&mut bytes).unwrap(); + println!("Serialized computation graph, size: {}", bytes.len()); + // append current timestamp to the file name to avoid conflicts let setup_filename = format!( "/tmp/computation_graph_{}.bin", @@ -104,10 +107,22 @@ where .map(|t| t.parallel_count()) .max() .unwrap_or(1); + let max_parallel_count = next_power_of_two(max_parallel_count); + + let mpi_size = if allow_oversubscribe { + max_parallel_count + } else { + let num_cpus = prev_power_of_two(num_cpus::get_physical()); + if max_parallel_count > num_cpus { + num_cpus + } else { + max_parallel_count + } + }; let port = parse_port_number(); let server_url = format!("{SERVER_IP}:{port}"); - start_server::(server_binary, next_power_of_two(max_parallel_count), port); + start_server::(server_binary, mpi_size, port, batch_pcs); // Keep trying until the server is ready loop { @@ -125,18 +140,15 @@ where } pub fn client_send_witness_and_prove( - device_memories: &[Vec>], + device_memories: Vec>>, ) -> CombinedProof> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("prove", true); - SharedMemoryEngine::write_witness_to_shared_memory::( - &device_memories.iter().map(|m| &m[..]).collect::>(), - ); + SharedMemoryEngine::write_witness_to_shared_memory::(device_memories); wait_async(ClientHttpHelper::request_prove()); let proof = SharedMemoryEngine::read_proof_from_shared_memory(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs index 48dd79de..48b895a7 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/cmd_utils.rs @@ -2,21 +2,23 @@ use gkr_engine::{ExpanderPCS, FieldEngine, FieldType, GKREngine, PolynomialCommi use std::process::Command; #[allow(clippy::zombie_processes)] -pub fn start_server(binary: &str, max_parallel_count: usize, port_number: u16) -where - C::FieldConfig: FieldEngine, -{ +pub fn start_server( + binary: &str, + max_parallel_count: usize, + port_number: u16, + batch_pcs: bool, +) { let (overscribe, field_name, pcs_name) = parse_config::(max_parallel_count); + let batch_pcs_option = if batch_pcs { "--batch-pcs" } else { "" }; let cmd_str = format!( - "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number}" + "mpiexec -n {max_parallel_count} {overscribe} {binary} --field-type {field_name} --poly-commit {pcs_name} --port-number {port_number} {batch_pcs_option}" ); exec_command(&cmd_str, false); } fn parse_config(mpi_size: usize) -> (String, String, String) where - C::FieldConfig: FieldEngine, { let oversubscription = if mpi_size > num_cpus::get_physical() { println!("Warning: Not enough cores available for the requested number of processes. Using oversubscription."); @@ -34,7 +36,7 @@ where _ => panic!("Unsupported field type"), }; - let pcs_name = match >::PCS_TYPE { + let pcs_name = match >::PCS_TYPE { PolynomialCommitmentType::Raw => "Raw", PolynomialCommitmentType::Hyrax => "Hyrax", PolynomialCommitmentType::KZG => "KZG", @@ -50,6 +52,7 @@ where #[allow(clippy::zombie_processes)] fn exec_command(cmd: &str, wait_for_completion: bool) { + println!("Executing command: {cmd}"); let mut parts = cmd.split_whitespace(); let command = parts.next().unwrap(); let args: Vec<&str> = parts.collect(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs index a7d02721..5605daf0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/prove_impl.rs @@ -28,20 +28,24 @@ use crate::{ pub fn mpi_prove_impl( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| local_commit_impl::(prover_setup, value.as_ref())) + .map(|value| { + local_commit_impl::( + prover_setup.p_keys.get(&value.as_ref().len()).unwrap(), + value.as_ref(), + ) + }) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { @@ -62,7 +66,7 @@ where let single_kernel_gkr_timer = Timer::new("small gkr kernel", global_mpi_config.is_root()); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr::( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, @@ -118,20 +122,17 @@ where } #[allow(clippy::too_many_arguments)] -pub fn prove_kernel_gkr( +pub fn prove_kernel_gkr( mpi_config: &MPIConfig<'static>, kernel: &Kernel, - commitments_values: &[&[SIMDField]], + commitments_values: &[&[F::SimdCircuitField]], parallel_count: usize, is_broadcast: &[bool], -) -> Option<( - C::TranscriptConfig, - ExpanderDualVarChallenge, -)> +) -> Option<(T, ExpanderDualVarChallenge)> where - C: GKREngine, - ECCConfig: Config, - C::FieldConfig: FieldEngine, + F: FieldEngine, + T: Transcript, + ECCConfig: Config, { let local_mpi_config = generate_local_mpi_config(mpi_config, parallel_count); @@ -149,10 +150,10 @@ where ); let (mut expander_circuit, mut prover_scratch) = - prepare_expander_circuit::(kernel, local_world_size); + prepare_expander_circuit::(kernel, local_world_size); - let mut transcript = C::TranscriptConfig::new(); - let challenge = prove_gkr_with_local_vals::( + let mut transcript = T::new(); + let challenge = prove_gkr_with_local_vals::( &mut expander_circuit, &mut prover_scratch, &local_commitment_values, @@ -190,16 +191,14 @@ pub fn partition_challenge_and_location_for_pcs_mpi( } #[allow(clippy::too_many_arguments)] -fn partition_single_gkr_claim_and_open_pcs_mpi( - p_keys: &ExpanderProverSetup, +pub fn partition_single_gkr_claim_and_open_pcs_mpi( + p_keys: &ExpanderProverSetup, commitments_values: &[impl AsRef<[SIMDField]>], - commitments_state: &[&ExpanderCommitmentState], + commitments_state: &[&ExpanderCommitmentState], gkr_challenge: &ExpanderSingleVarChallenge, is_broadcast: &[bool], transcript: &mut C::TranscriptConfig, -) where - C::FieldConfig: FieldEngine, -{ +) { let parallel_count = 1 << gkr_challenge.r_mpi.len(); for ((commitment_val, _state), ib) in commitments_values .iter() diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs index 53e76bbe..6250428d 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_bin.rs @@ -3,12 +3,15 @@ use std::str::FromStr; use clap::Parser; use expander_compiler::{ frontend::{BN254Config, BabyBearConfig, GF2Config, GoldilocksConfig, M31Config}, - zkcuda::proving_system::expander_parallelized::{ - server_ctrl::{serve, ExpanderExecArgs}, - ParallelizedExpander, + zkcuda::proving_system::{ + expander_parallelized::{ + server_ctrl::{serve, ExpanderExecArgs}, + ParallelizedExpander, + }, + expander_pcs_defered::BN254ConfigSha2UniKZG, }, }; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use gkr_engine::PolynomialCommitmentType; #[tokio::main] @@ -55,7 +58,7 @@ pub async fn main() { .await; } ("BN254", PolynomialCommitmentType::KZG) => { - serve::>( + serve::>( expander_exec_args.port_number, ) .await; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs index 8ac1aaf9..a95202ec 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_ctrl.rs @@ -16,10 +16,10 @@ use mpi::ffi::MPI_Win; use mpi::topology::SimpleCommunicator; use mpi::traits::Communicator; -use crate::frontend::Config; +use crate::frontend::{Config, SIMDField}; use axum::{extract::State, Json}; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine}; +use gkr_engine::{GKREngine, MPIConfig, MPIEngine}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use std::net::{IpAddr, SocketAddr}; @@ -56,20 +56,16 @@ pub struct SharedMemoryWINWrapper { unsafe impl Send for SharedMemoryWINWrapper {} unsafe impl Sync for SharedMemoryWINWrapper {} -pub struct ServerState> -where - C::FieldConfig: FieldEngine, -{ +pub struct ServerState> { pub lock: Arc>, // For now we want to ensure that only one request is processed at a time pub global_mpi_config: MPIConfig<'static>, pub local_mpi_config: Option>, - pub prover_setup: Arc>>, - pub verifier_setup: - Arc>>, + pub prover_setup: Arc>>, + pub verifier_setup: Arc>>, pub computation_graph: Arc>>, - pub witness: Arc>>>, + pub witness: Arc>>>>, pub cg_shared_memory_win: Arc>>, // Shared memory for computation graph pub wt_shared_memory_win: Arc>>, // Shared memory for witness @@ -79,22 +75,16 @@ where unsafe impl> Send for ServerState -where - C::FieldConfig: FieldEngine, { } unsafe impl> Sync for ServerState -where - C::FieldConfig: FieldEngine, { } impl> Clone for ServerState -where - C::FieldConfig: FieldEngine, { fn clone(&self) -> Self { ServerState { @@ -119,7 +109,7 @@ pub async fn root_main( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, + S: ServerFns, { let _lock = state.lock.lock().await; // Ensure only one request is processed at a time @@ -194,7 +184,7 @@ pub async fn worker_main( ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, + S: ServerFns, { loop { @@ -283,7 +273,7 @@ pub async fn serve(port_number: String) where C: GKREngine + 'static, ECCConfig: Config + 'static, - C::FieldConfig: FieldEngine, + S: ServerFns + 'static, { let global_mpi_config = unsafe { @@ -389,6 +379,10 @@ pub struct ExpanderExecArgs { pub poly_commit: String, /// The port number for the server to listen on. - #[arg(short, long, default_value = "Port")] + #[arg(short, long, default_value = "3000")] pub port_number: String, + + /// Whether to batch PCS opening in proving. + #[arg(short, long, default_value_t = false)] + pub batch_pcs: bool, } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs index 371f7150..37d5f8a3 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIConfig, MPIEngine, MPISharedMemory}; +use gkr_engine::{GKREngine, MPIConfig, MPIEngine, MPISharedMemory}; use serdes::ExpSerde; use crate::{ @@ -23,27 +23,26 @@ pub trait ServerFns where C: gkr_engine::GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, ); fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>>; fn setup_shared_witness( global_mpi_config: &MPIConfig<'static>, - witness_target: &mut Vec>, + witness_target: &mut Vec>>, mpi_shared_memory_win: &mut Option, ) { // dispose of the previous shared memory if it exists @@ -68,7 +67,7 @@ where fn shared_memory_clean_up( global_mpi_config: &MPIConfig<'static>, computation_graph: ComputationGraph, - witness: Vec>, + witness: Vec>>, cg_mpi_win: &mut Option, wt_mpi_win: &mut Option, ) { @@ -91,18 +90,15 @@ impl ServerFns for ParallelizedExpander where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, - ) where - C::FieldConfig: FieldEngine, - { + ) { let setup_file = if global_mpi_config.is_root() { let setup_file = setup_file.expect("Setup file path must be provided"); broadcast_string(global_mpi_config, Some(setup_file)) @@ -119,14 +115,13 @@ where fn prove_request_handler( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { mpi_prove_impl(global_mpi_config, prover_setup, computation_graph, values) } @@ -152,7 +147,6 @@ pub fn read_circuit( ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let computation_graph_bytes = std::fs::read(setup_file).expect("Failed to read computation graph from file"); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs index 53a4078f..648f33a8 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/shared_memory_utils.rs @@ -62,6 +62,8 @@ impl SharedMemoryEngine { .serialize_into(&mut buffer) .expect("Failed to serialize object"); + println!("Object size: {}", buffer.len()); + unsafe { Self::allocate_shared_memory_if_necessary(shared_memory_ref, name, buffer.len()); let object_ptr = shared_memory_ref.as_mut().unwrap().as_ptr(); @@ -88,16 +90,10 @@ impl SharedMemoryEngine { /// This impl block contains functions for reading/writing specific objects to shared memory. impl SharedMemoryEngine { - pub fn write_pcs_setup_to_shared_memory< - PCSField: Field, - F: FieldEngine, - PCS: ExpanderPCS, - >( - pcs_setup: &( - ExpanderProverSetup, - ExpanderVerifierSetup, - ), + pub fn write_pcs_setup_to_shared_memory>( + pcs_setup: &(ExpanderProverSetup, ExpanderVerifierSetup), ) { + println!("Writing PCS setup to shared memory..."); Self::write_object_to_shared_memory( pcs_setup, unsafe { &mut SHARED_MEMORY.pcs_setup }, @@ -105,26 +101,19 @@ impl SharedMemoryEngine { ); } - pub fn read_pcs_setup_from_shared_memory< - PCSField: Field, - F: FieldEngine, - PCS: ExpanderPCS, - >() -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, - ) { + pub fn read_pcs_setup_from_shared_memory>( + ) -> (ExpanderProverSetup, ExpanderVerifierSetup) { Self::read_object_from_shared_memory("pcs_setup", 0) } - pub fn write_witness_to_shared_memory( - values: &[impl AsRef<[F::SimdCircuitField]>], - ) { + pub fn write_witness_to_shared_memory(values: Vec>) { let total_size = std::mem::size_of::() + values .iter() - .map(|v| std::mem::size_of::() + std::mem::size_of_val(v.as_ref())) + .map(|v| std::mem::size_of::() + std::mem::size_of_val(v.as_slice())) .sum::(); + println!("Writing witness to shared memory, total size: {total_size}"); unsafe { Self::allocate_shared_memory_if_necessary( &mut SHARED_MEMORY.witness, @@ -141,13 +130,13 @@ impl SharedMemoryEngine { ptr = ptr.add(std::mem::size_of::()); for vals in values { - let vals_len = vals.as_ref().len(); + let vals_len = vals.len(); let len_ptr = &vals_len as *const usize as *const u8; std::ptr::copy_nonoverlapping(len_ptr, ptr, std::mem::size_of::()); ptr = ptr.add(std::mem::size_of::()); - let vals_size = std::mem::size_of_val(vals.as_ref()); - std::ptr::copy_nonoverlapping(vals.as_ref().as_ptr() as *const u8, ptr, vals_size); + let vals_size = std::mem::size_of_val(vals.as_slice()); + std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, ptr, vals_size); ptr = ptr.add(vals_size); } } @@ -220,9 +209,8 @@ impl SharedMemoryEngine { ECCConfig: Config, >( proof: &CombinedProof>, - ) where - C::FieldConfig: FieldEngine, - { + ) { + println!("Writing proof to shared memory..."); Self::write_object_to_shared_memory(proof, unsafe { &mut SHARED_MEMORY.proof }, "proof"); } @@ -230,9 +218,7 @@ impl SharedMemoryEngine { C: GKREngine, ECCConfig: Config, >() -> CombinedProof> - where - C::FieldConfig: FieldEngine, - { +where { Self::read_object_from_shared_memory("proof", 0) } } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs index f05f9898..5b1d28a6 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_parallelized/verify_impl.rs @@ -25,21 +25,19 @@ use crate::{ }; pub fn verify_kernel( - verifier_setup: &ExpanderVerifierSetup, + verifier_setup: &ExpanderVerifierSetup, kernel: &Kernel, proof: &ExpanderProof, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], parallel_count: usize, is_broadcast: &[bool], ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let timer = Timer::new("verify", true); - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); @@ -84,10 +82,10 @@ where pub fn verify_pcs_opening_and_aggregation_mpi_impl( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderSingleVarChallenge, y: &::ChallengeField, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_count: usize, transcript: &mut C::TranscriptConfig, @@ -95,7 +93,6 @@ pub fn verify_pcs_opening_and_aggregation_mpi_impl( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut target_y = ::ChallengeField::ZERO; for ((input, commitment), ib) in kernel @@ -105,9 +102,9 @@ where .zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, component_idx_vars) = partition_challenge_and_location_for_pcs_mpi(challenge, val_len, parallel_count, *ib); @@ -143,11 +140,11 @@ where pub fn verify_pcs_opening_and_aggregation_mpi( mut proof_reader: impl Read, kernel: &Kernel, - v_keys: &ExpanderVerifierSetup, + v_keys: &ExpanderVerifierSetup, challenge: &ExpanderDualVarChallenge, claim_v0: ::ChallengeField, claim_v1: Option<::ChallengeField>, - commitments: &[&ExpanderCommitment], + commitments: &[&ExpanderCommitment], is_broadcast: &[bool], parallel_count: usize, transcript: &mut C::TranscriptConfig, @@ -155,7 +152,6 @@ pub fn verify_pcs_opening_and_aggregation_mpi( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let challenges = if let Some(challenge_y) = challenge.challenge_y() { vec![challenge.challenge_x(), challenge_y] diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs index 6ecc6058..9be9c805 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered.rs @@ -13,11 +13,13 @@ use gkr_hashers::SHA256hasher; use halo2curves::bn256::Bn256; use poly_commit::HyperUniKZGPCS; -pub struct BN254ConfigSha2UniKZG; +pub struct BN254ConfigSha2UniKZG<'a> { + _phantom: std::marker::PhantomData<&'a ()>, +} -impl GKREngine for BN254ConfigSha2UniKZG { +impl<'a> GKREngine for BN254ConfigSha2UniKZG<'a> { type FieldConfig = ::FieldConfig; - type MPIConfig = MPIConfig<'static>; + type MPIConfig = MPIConfig<'a>; type TranscriptConfig = BytesHashTranscript; type PCSConfig = HyperUniKZGPCS; const SCHEME: GKRScheme = GKRScheme::Vanilla; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs index 5a8fa237..0d4d9d39 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/api_pcs_defered.rs @@ -1,4 +1,4 @@ -use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine}; +use gkr_engine::{ExpanderPCS, GKREngine}; use crate::{ frontend::{Config, SIMDField}, @@ -20,13 +20,13 @@ impl ProvingSystem for ExpanderPCSDefered where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { - type ProverSetup = ExpanderProverSetup; + type ProverSetup = ExpanderProverSetup; - type VerifierSetup = ExpanderVerifierSetup; + type VerifierSetup = ExpanderVerifierSetup; type Proof = CombinedProof>; @@ -35,13 +35,18 @@ where ) -> (Self::ProverSetup, Self::VerifierSetup) { let server_binary = client_parse_args() .unwrap_or("../target/release/expander_server_pcs_defered".to_owned()); - client_launch_server_and_setup::(&server_binary, computation_graph) + client_launch_server_and_setup::( + &server_binary, + computation_graph, + true, + true, + ) } fn prove( _prover_setup: &Self::ProverSetup, _computation_graph: &crate::zkcuda::context::ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof { client_send_witness_and_prove(device_memories) } diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs index 03195d4c..72545956 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/prove_impl.rs @@ -1,7 +1,7 @@ -use arith::Field; +use expander_utils::timer::Timer; use gkr_engine::{ - ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, MPIConfig, MPIEngine, - Proof as BytesProof, Transcript, + ExpanderPCS, ExpanderSingleVarChallenge, GKREngine, MPIConfig, MPIEngine, Proof as BytesProof, + Transcript, }; use polynomials::RefMultiLinearPoly; use serdes::ExpSerde; @@ -26,17 +26,16 @@ use crate::{ }, }; -pub fn pad_vals_and_commit( - prover_setup: &ExpanderProverSetup, +pub fn max_len_setup_commit_impl( + prover_setup: &ExpanderProverSetup, vals: &[SIMDField], ) -> ( - ExpanderCommitment, - ExpanderCommitmentState, + ExpanderCommitment, + ExpanderCommitmentState, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { assert_eq!(prover_setup.p_keys.len(), 1); let len_to_commit = prover_setup.p_keys.keys().next().cloned().unwrap(); @@ -44,25 +43,21 @@ where let actual_len = vals.len(); assert!(len_to_commit >= actual_len); - // padding to max length and commit, this may be very inefficient - // TODO: optimize this - let mut vals = vals.to_vec(); - vals.resize(len_to_commit, SIMDField::::ZERO); - let (mut commitment, state) = local_commit_impl::(prover_setup, &vals); + let (mut commitment, state) = + local_commit_impl::(prover_setup.p_keys.get(&len_to_commit).unwrap(), vals); commitment.vals_len = actual_len; // Store the actual length in the commitment (commitment, state) } pub fn open_defered_pcs( - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, vals: &[&[SIMDField]], challenges: &[ExpanderSingleVarChallenge], ) -> ExpanderProof where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { // TODO: Efficiency let polys: Vec<_> = vals @@ -73,26 +68,23 @@ where // TODO: Soundness let mut transcript = C::TranscriptConfig::new(); let max_length = prover_setup.p_keys.keys().max().cloned().unwrap_or(0); - let params = >::gen_params( - max_length.ilog2() as usize, - 1, - ); - let scratch_pad = >::init_scratch_pad( + let params = + >::gen_params(max_length.ilog2() as usize, 1); + let scratch_pad = >::init_scratch_pad( ¶ms, &MPIConfig::prover_new(None, None), ); transcript.lock_proof(); - let (vals, opening) = - >::multi_points_batch_open( - ¶ms, - &MPIConfig::prover_new(None, None), - prover_setup.p_keys.get(&max_length).unwrap(), - &polys, - challenges, - &scratch_pad, - &mut transcript, - ); + let (vals, opening) = >::multi_points_batch_open( + ¶ms, + &MPIConfig::prover_new(None, None), + prover_setup.p_keys.get(&max_length).unwrap(), + &polys, + challenges, + &scratch_pad, + &mut transcript, + ); transcript.unlock_proof(); let mut bytes = vec![]; @@ -106,28 +98,33 @@ where pub fn mpi_prove_with_pcs_defered( global_mpi_config: &MPIConfig<'static>, - prover_setup: &ExpanderProverSetup, + prover_setup: &ExpanderProverSetup, computation_graph: &ComputationGraph, values: &[impl AsRef<[SIMDField]>], ) -> Option>> where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { + let commit_timer = Timer::new("Commit to all input", global_mpi_config.is_root()); let (commitments, _states) = if global_mpi_config.is_root() { let (commitments, states) = values .iter() - .map(|value| pad_vals_and_commit::(prover_setup, value.as_ref())) + .map(|value| max_len_setup_commit_impl::(prover_setup, value.as_ref())) .unzip::<_, _, Vec<_>, Vec<_>>(); (Some(commitments), Some(states)) } else { (None, None) }; + commit_timer.stop(); let mut vals_ref = vec![]; let mut challenges = vec![]; + let prove_timer = Timer::new( + "Prove all kernels (NO PCS Opening)", + global_mpi_config.is_root(), + ); let proofs = computation_graph .proof_templates() .iter() @@ -138,7 +135,7 @@ where .map(|&idx| values[idx].as_ref()) .collect::>(); - let gkr_end_state = prove_kernel_gkr::( + let gkr_end_state = prove_kernel_gkr::( global_mpi_config, &computation_graph.kernels()[template.kernel_id()], &commitment_values, @@ -169,12 +166,16 @@ where } }) .collect::>(); + prove_timer.stop(); if global_mpi_config.is_root() { let mut proofs = proofs.into_iter().map(|p| p.unwrap()).collect::>(); + let pcs_opening_timer = Timer::new("Batch PCS Opening for all kernels", true); let pcs_batch_opening = open_defered_pcs::(prover_setup, &vals_ref, &challenges); + pcs_opening_timer.stop(); + proofs.push(pcs_batch_opening); Some(CombinedProof { commitments: commitments.unwrap(), @@ -195,7 +196,6 @@ pub fn extract_pcs_claims<'a, C: GKREngine>( Vec>, ) where - C::FieldConfig: FieldEngine, { let mut commitment_values_rt = vec![]; let mut challenges = vec![]; diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs index edc311e3..a34894e0 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/server_fns.rs @@ -1,4 +1,4 @@ -use gkr_engine::{FieldEngine, GKREngine, MPIEngine}; +use gkr_engine::{GKREngine, MPIEngine}; use crate::{ frontend::Config, @@ -22,14 +22,13 @@ impl ServerFns for ExpanderPCSDefered where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { fn setup_request_handler( global_mpi_config: &gkr_engine::MPIConfig<'static>, setup_file: Option, computation_graph: &mut ComputationGraph, - prover_setup: &mut ExpanderProverSetup, - verifier_setup: &mut ExpanderVerifierSetup, + prover_setup: &mut ExpanderProverSetup, + verifier_setup: &mut ExpanderVerifierSetup, mpi_win: &mut Option, ) { let setup_file = if global_mpi_config.is_root() { @@ -50,7 +49,6 @@ where fn prove_request_handler( global_mpi_config: &gkr_engine::MPIConfig<'static>, prover_setup: &ExpanderProverSetup< - ::PCSField, ::FieldConfig, ::PCSConfig, >, diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs index 6aaf73fd..dba77cf4 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/setup_impl.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use gkr_engine::{FieldEngine, GKREngine, MPIConfig}; +use gkr_engine::{GKREngine, MPIConfig}; use crate::{ frontend::Config, @@ -16,13 +16,12 @@ use crate::{ pub fn pcs_setup_max_length_only( computation_graph: &ComputationGraph, ) -> ( - ExpanderProverSetup, - ExpanderVerifierSetup, + ExpanderProverSetup, + ExpanderVerifierSetup, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut p_keys = HashMap::new(); let mut v_keys = HashMap::new(); diff --git a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs index 86deb072..b06be3df 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_pcs_defered/verify_impl.rs @@ -1,12 +1,12 @@ use std::io::Cursor; use arith::Field; +use expander_utils::timer::Timer; use gkr::gkr_verify; use gkr_engine::{ ExpanderDualVarChallenge, ExpanderPCS, ExpanderSingleVarChallenge, FieldEngine, GKREngine, Proof as BytesProof, Transcript, }; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; use serdes::ExpSerde; use crate::{ @@ -24,27 +24,26 @@ use crate::{ }; fn verifier_extract_pcs_claims<'a, C, ECCConfig>( - commitments: &[&'a ExpanderCommitment], + commitments: &[&'a ExpanderCommitment], gkr_challenge: &ExpanderSingleVarChallenge, is_broadcast: &[bool], parallel_count: usize, ) -> ( - Vec<&'a ExpanderCommitment>, + Vec<&'a ExpanderCommitment>, Vec>, ) where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { let mut commitments_rt = vec![]; let mut challenges = vec![]; for (&commitment, ib) in commitments.iter().zip(is_broadcast) { let val_len = - as Commitment< - ECCConfig, - >>::vals_len(commitment); + as Commitment>::vals_len( + commitment, + ); let (challenge_for_pcs, _) = partition_challenge_and_location_for_pcs_mpi( gkr_challenge, val_len, @@ -67,10 +66,8 @@ pub fn verify_gkr( where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, { - let mut expander_circuit = kernel.layered_circuit().export_to_expander().flatten::(); - expander_circuit.pre_process_gkr::(); + let mut expander_circuit = kernel.layered_circuit().export_to_expander_flatten(); let mut transcript = C::TranscriptConfig::new(); expander_circuit.fill_rnd_coefs(&mut transcript); @@ -95,21 +92,20 @@ where pub fn verify_defered_pcs_opening( proof: &BytesProof, - verifier_setup: &ExpanderVerifierSetup, - commitments: &[&ExpanderCommitment], + verifier_setup: &ExpanderVerifierSetup, + commitments: &[&ExpanderCommitment], challenges: &[ExpanderSingleVarChallenge], ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { let mut transcript = C::TranscriptConfig::new(); let max_num_vars = verifier_setup.v_keys.keys().max().cloned().unwrap_or(0); - let params = - >::gen_params(max_num_vars, 1); + let params = >::gen_params(max_num_vars, 1); let mut defered_proof_bytes = proof.bytes.clone(); let mut cursor = Cursor::new(&mut defered_proof_bytes); @@ -122,45 +118,44 @@ where Vec::<::ChallengeField>::deserialize_from(&mut cursor) .unwrap(); let opening = - >::BatchOpening::deserialize_from( - &mut cursor, - ) - .unwrap(); + >::BatchOpening::deserialize_from(&mut cursor) + .unwrap(); transcript.lock_proof(); - let pcs_verified = - >::multi_points_batch_verify( - ¶ms, - verifier_setup.v_keys.get(&max_num_vars).unwrap(), - &commitments, - challenges, - &vals, - &opening, - &mut transcript, - ); + let pcs_verified = >::multi_points_batch_verify( + ¶ms, + verifier_setup.v_keys.get(&max_num_vars).unwrap(), + &commitments, + challenges, + &vals, + &opening, + &mut transcript, + ); transcript.unlock_proof(); pcs_verified } pub fn verify( - verifier_setup: &ExpanderVerifierSetup, + verifier_setup: &ExpanderVerifierSetup, computation_graph: &ComputationGraph, mut proof: CombinedProof>, ) -> bool where C: GKREngine, ECCConfig: Config, - C::FieldConfig: FieldEngine, - >::Commitment: - AsRef<>::Commitment>, + + >::Commitment: + AsRef<>::Commitment>, { + let verification_timer = Timer::new("Total Verification", true); let pcs_batch_opening = proof.proofs.pop().unwrap(); + let gkr_verification_timer = Timer::new("GKR Verification", true); let verified_with_pcs_claims = proof .proofs - .par_iter() - .zip(computation_graph.proof_templates().par_iter()) + .iter() + .zip(computation_graph.proof_templates().iter()) .map(|(local_proof, template)| { let local_commitments = template .commitment_indices() @@ -193,7 +188,9 @@ where println!("Failed to verify GKR proofs"); return false; } + gkr_verification_timer.stop(); + let pcs_verification_timer = Timer::new("PCS Verification", true); let commitments_ref = verified_with_pcs_claims .iter() .flat_map(|(_, c, _)| c) @@ -211,6 +208,8 @@ where &commitments_ref, &challenges, ); + pcs_verification_timer.stop(); + verification_timer.stop(); gkr_verified && pcs_verified } diff --git a/expander_compiler/src/zkcuda/proving_system/traits.rs b/expander_compiler/src/zkcuda/proving_system/traits.rs index cc791c50..539a0f43 100644 --- a/expander_compiler/src/zkcuda/proving_system/traits.rs +++ b/expander_compiler/src/zkcuda/proving_system/traits.rs @@ -70,7 +70,7 @@ pub trait ProvingSystem { fn prove( prover_setup: &Self::ProverSetup, computation_graph: &ComputationGraph, - device_memories: &[Vec>], + device_memories: Vec>>, ) -> Self::Proof; fn verify( diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 443651d2..651a962b 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -84,21 +84,19 @@ impl Entry { return shape.to_vec(); } let mut segments = vec![]; - let mut cur_prod = 1; - let mut target = 1; - let mut self_shape_iter = self.shape.iter(); - for &x in shape.iter() { - if cur_prod == target { - cur_prod = x.0; - target = *self_shape_iter.next().unwrap(); - segments.push(vec![x]); - } else { - cur_prod *= x.0; - segments.last_mut().unwrap().push(x); + let mut shape_iter = shape.iter(); + for &x in self.shape.iter() { + let mut cur_prod = 1; + segments.push(vec![]); + while let Some(y) = shape_iter.next() { + cur_prod *= y.0; + segments.last_mut().unwrap().push(*y); + if cur_prod == x { + break; + } } + assert_eq!(cur_prod, x); } - assert_eq!(cur_prod, target); - assert_eq!(self_shape_iter.next(), None); let mut res = Vec::with_capacity(shape.len()); for i in self.axes.as_ref().unwrap() { res.extend(segments[*i].iter()); @@ -114,6 +112,12 @@ impl Entry { let mut cur_ts_prod = 1; let mut cur_ts_idx = 0; for &x in products.iter().skip(1) { + while ts[cur_ts_idx] == 1 && cur_ts_idx < ts.len() { + cur_ts_idx += 1; + } + if cur_ts_idx >= ts.len() { + break; + } segments_in_ts[self.axes.as_ref().unwrap()[cur_ts_idx]].push(x / cur_ts_prod); if x == cur_ts_prod * ts[cur_ts_idx] { cur_ts_prod = x; @@ -161,9 +165,13 @@ pub fn prefix_products(shape: &[usize]) -> Vec { } pub fn prefix_products_to_shape(products: &[usize]) -> Vec { - let mut shape = Vec::with_capacity(products.len() - 1); + // let mut shape = Vec::with_capacity(products.len() - 1); + // for i in 1..products.len() { + // shape.push(products[i] / products[i - 1]); + // } + let mut shape = products.to_vec(); for i in 1..products.len() { - shape.push(products[i] / products[i - 1]); + shape[i] /= products[i - 1]; } shape } @@ -268,9 +276,9 @@ impl ShapeHistory { cur = if e.axes.as_ref().is_none() { cur } else if cur.is_none() { - Some(e.transpose_shape(&initial_shape())) + Some(e.minimize(false).transpose_shape(&initial_shape())) } else { - Some(e.transpose_shape(&cur.unwrap())) + Some(e.minimize(false).transpose_shape(&cur.unwrap())) }; } let new_shape_and_id = match cur { diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index ad7c609a..715dd77d 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -119,7 +119,7 @@ fn context_shape_test_1_impl>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); P::post_process(); diff --git a/expander_compiler/tests/example.rs b/expander_compiler/tests/circuit/example.rs similarity index 100% rename from expander_compiler/tests/example.rs rename to expander_compiler/tests/circuit/example.rs diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/circuit/example_call_expander.rs similarity index 91% rename from expander_compiler/tests/example_call_expander.rs rename to expander_compiler/tests/circuit/example_call_expander.rs index e38f88e0..1eaf8e05 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/circuit/example_call_expander.rs @@ -2,7 +2,6 @@ use arith::Field; use arith::SimdField; use expander_binary::executor; use expander_compiler::frontend::*; -use gkr_engine::FieldEngine; use gkr_engine::MPIConfig; use rand::SeedableRng; @@ -21,11 +20,7 @@ impl Define for Circuit { } } -fn example() -where - C::PCSField: SimdField::CircuitField>, - C::FieldConfig: FieldEngine, -{ +fn example() { let n_witnesses = SIMDField::::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/circuit/keccak_gf2.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2.rs rename to expander_compiler/tests/circuit/keccak_gf2.rs diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/circuit/keccak_gf2_full.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_full.rs rename to expander_compiler/tests/circuit/keccak_gf2_full.rs diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/circuit/keccak_gf2_full_crosslayer.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_full_crosslayer.rs rename to expander_compiler/tests/circuit/keccak_gf2_full_crosslayer.rs diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/circuit/keccak_gf2_vec.rs similarity index 100% rename from expander_compiler/tests/keccak_gf2_vec.rs rename to expander_compiler/tests/circuit/keccak_gf2_vec.rs diff --git a/expander_compiler/tests/keccak_non_gf2.rs b/expander_compiler/tests/circuit/keccak_non_gf2.rs similarity index 100% rename from expander_compiler/tests/keccak_non_gf2.rs rename to expander_compiler/tests/circuit/keccak_non_gf2.rs diff --git a/expander_compiler/tests/circuit/mod.rs b/expander_compiler/tests/circuit/mod.rs new file mode 100644 index 00000000..c4216d8a --- /dev/null +++ b/expander_compiler/tests/circuit/mod.rs @@ -0,0 +1,16 @@ +mod example; +mod example_call_expander; +mod keccak_gf2; +mod keccak_gf2_full; +mod keccak_gf2_full_crosslayer; +mod keccak_gf2_vec; +mod keccak_non_gf2; + +mod mul_fanout_limit; +mod multithreading_witness; + +mod simple_add_m31; +mod sub_circuit_macro; +mod to_binary_builtin; +mod to_binary_hint; +mod to_binary_unconstrained_api; diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/circuit/mul_fanout_limit.rs similarity index 100% rename from expander_compiler/tests/mul_fanout_limit.rs rename to expander_compiler/tests/circuit/mul_fanout_limit.rs diff --git a/expander_compiler/tests/multithreading_witness.rs b/expander_compiler/tests/circuit/multithreading_witness.rs similarity index 100% rename from expander_compiler/tests/multithreading_witness.rs rename to expander_compiler/tests/circuit/multithreading_witness.rs diff --git a/expander_compiler/tests/rsa_mul.py b/expander_compiler/tests/circuit/rsa_mul.py similarity index 100% rename from expander_compiler/tests/rsa_mul.py rename to expander_compiler/tests/circuit/rsa_mul.py diff --git a/expander_compiler/tests/simple_add_m31.rs b/expander_compiler/tests/circuit/simple_add_m31.rs similarity index 100% rename from expander_compiler/tests/simple_add_m31.rs rename to expander_compiler/tests/circuit/simple_add_m31.rs diff --git a/expander_compiler/tests/sub_circuit_macro.rs b/expander_compiler/tests/circuit/sub_circuit_macro.rs similarity index 100% rename from expander_compiler/tests/sub_circuit_macro.rs rename to expander_compiler/tests/circuit/sub_circuit_macro.rs diff --git a/expander_compiler/tests/to_binary_builtin.rs b/expander_compiler/tests/circuit/to_binary_builtin.rs similarity index 100% rename from expander_compiler/tests/to_binary_builtin.rs rename to expander_compiler/tests/circuit/to_binary_builtin.rs diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/circuit/to_binary_hint.rs similarity index 100% rename from expander_compiler/tests/to_binary_hint.rs rename to expander_compiler/tests/circuit/to_binary_hint.rs diff --git a/expander_compiler/tests/to_binary_unconstrained_api.rs b/expander_compiler/tests/circuit/to_binary_unconstrained_api.rs similarity index 100% rename from expander_compiler/tests/to_binary_unconstrained_api.rs rename to expander_compiler/tests/circuit/to_binary_unconstrained_api.rs diff --git a/expander_compiler/tests/circuit1.rs b/expander_compiler/tests/circuit1.rs new file mode 100755 index 00000000..77ff29ab --- /dev/null +++ b/expander_compiler/tests/circuit1.rs @@ -0,0 +1,563 @@ +use expander_compiler::frontend::*; +use expander_compiler::zkcuda::{context::*, kernel::*}; +use expander_compiler::zkcuda::shape::{Reshape, Transpose}; +use serdes::ExpSerde; + +#[allow(dead_code)] +struct Circuit { + output: Vec>, + input: Vec>>>, + _features_features_0_conv_output_0_conv: Vec>>>, + _features_features_0_conv_output_0_div: Vec>>>, + _features_features_0_conv_output_0_rem: Vec>>>, + _features_features_0_conv_output_0_floor: Vec>>>, + _features_features_2_relu_output_0: Vec>>>, + _features_features_3_conv_output_0_conv: Vec>>>, + _features_features_3_conv_output_0_div: Vec>>>, + _features_features_3_conv_output_0_rem: Vec>>>, + _features_features_3_conv_output_0_floor: Vec>>>, + _features_features_5_relu_output_0: Vec>>>, + _features_features_6_maxpool_output_0: Vec>>>, + _features_features_7_conv_output_0_conv: Vec>>>, + _features_features_7_conv_output_0_div: Vec>>>, + _features_features_7_conv_output_0_rem: Vec>>>, + _features_features_7_conv_output_0_floor: Vec>>>, + _features_features_9_relu_output_0: Vec>>>, + _features_features_10_conv_output_0_conv: Vec>>>, + _features_features_10_conv_output_0_div: Vec>>>, + _features_features_10_conv_output_0_rem: Vec>>>, + _features_features_10_conv_output_0_floor: Vec>>>, + _features_features_12_relu_output_0: Vec>>>, + _features_features_13_maxpool_output_0: Vec>>>, + _features_features_14_conv_output_0_conv: Vec>>>, + _features_features_14_conv_output_0_div: Vec>>>, + _features_features_14_conv_output_0_rem: Vec>>>, + _features_features_14_conv_output_0_floor: Vec>>>, + _features_features_16_relu_output_0: Vec>>>, + _features_features_17_conv_output_0_conv: Vec>>>, + _features_features_17_conv_output_0_div: Vec>>>, + _features_features_17_conv_output_0_rem: Vec>>>, + _features_features_17_conv_output_0_floor: Vec>>>, + _features_features_19_relu_output_0: Vec>>>, + _features_features_20_conv_output_0_conv: Vec>>>, + _features_features_20_conv_output_0_div: Vec>>>, + _features_features_20_conv_output_0_rem: Vec>>>, + _features_features_20_conv_output_0_floor: Vec>>>, + _features_features_22_relu_output_0: Vec>>>, + _features_features_23_maxpool_output_0: Vec>>>, + _features_features_24_conv_output_0_conv: Vec>>>, + _features_features_24_conv_output_0_div: Vec>>>, + _features_features_24_conv_output_0_rem: Vec>>>, + _features_features_24_conv_output_0_floor: Vec>>>, + _features_features_26_relu_output_0: Vec>>>, + _features_features_27_conv_output_0_conv: Vec>>>, + _features_features_27_conv_output_0_div: Vec>>>, + _features_features_27_conv_output_0_rem: Vec>>>, + _features_features_27_conv_output_0_floor: Vec>>>, + _features_features_29_relu_output_0: Vec>>>, + _features_features_30_conv_output_0_conv: Vec>>>, + _features_features_30_conv_output_0_div: Vec>>>, + _features_features_30_conv_output_0_rem: Vec>>>, + _features_features_30_conv_output_0_floor: Vec>>>, + _features_features_32_relu_output_0: Vec>>>, + _features_features_33_maxpool_output_0: Vec>>>, + _features_features_34_conv_output_0_conv: Vec>>>, + _features_features_34_conv_output_0_div: Vec>>>, + _features_features_34_conv_output_0_rem: Vec>>>, + _features_features_34_conv_output_0_floor: Vec>>>, + _features_features_36_relu_output_0: Vec>>>, + _features_features_37_conv_output_0_conv: Vec>>>, + _features_features_37_conv_output_0_div: Vec>>>, + _features_features_37_conv_output_0_rem: Vec>>>, + _features_features_37_conv_output_0_floor: Vec>>>, + _features_features_39_relu_output_0: Vec>>>, + _features_features_40_conv_output_0_conv: Vec>>>, + _features_features_40_conv_output_0_div: Vec>>>, + _features_features_40_conv_output_0_rem: Vec>>>, + _features_features_40_conv_output_0_floor: Vec>>>, + _features_features_42_relu_output_0: Vec>>>, + _features_features_43_maxpool_output_0: Vec>>>, + _avgpool_GlobalAveragePool_output_0: Vec>>>, + _classifier_classifier_0_gemm_output_0_matmul: Vec>, + _classifier_classifier_0_gemm_output_0_div: Vec>, + _classifier_classifier_0_gemm_output_0_rem: Vec>, + _classifier_classifier_0_gemm_output_0_floor: Vec>, + _classifier_classifier_1_relu_output_0: Vec>, + _classifier_classifier_3_gemm_output_0_matmul: Vec>, + _classifier_classifier_3_gemm_output_0_div: Vec>, + _classifier_classifier_3_gemm_output_0_rem: Vec>, + _classifier_classifier_3_gemm_output_0_floor: Vec>, + _classifier_classifier_4_relu_output_0: Vec>, + output_matmul: Vec>, + output_div: Vec>, + output_rem: Vec>, + output_floor: Vec>, + onnx_conv_150: Vec>>>, + onnx_conv_151: Vec, + onnx_conv_151_q: Vec>>, + onnx_conv_150_nscale: BN254Fr, + onnx_conv_150_dscale: BN254Fr, + onnx_conv_153: Vec>>>, + onnx_conv_154: Vec, + onnx_conv_154_q: Vec>>, + onnx_conv_153_nscale: BN254Fr, + onnx_conv_153_dscale: BN254Fr, + onnx_conv_156: Vec>>>, + onnx_conv_157: Vec, + onnx_conv_157_q: Vec>>, + onnx_conv_156_nscale: BN254Fr, + onnx_conv_156_dscale: BN254Fr, + onnx_conv_159: Vec>>>, + onnx_conv_160: Vec, + onnx_conv_160_q: Vec>>, + onnx_conv_159_nscale: BN254Fr, + onnx_conv_159_dscale: BN254Fr, + onnx_conv_162: Vec>>>, + onnx_conv_163: Vec, + onnx_conv_163_q: Vec>>, + onnx_conv_162_nscale: BN254Fr, + onnx_conv_162_dscale: BN254Fr, + onnx_conv_165: Vec>>>, + onnx_conv_166: Vec, + onnx_conv_166_q: Vec>>, + onnx_conv_165_nscale: BN254Fr, + onnx_conv_165_dscale: BN254Fr, + onnx_conv_168: Vec>>>, + onnx_conv_169: Vec, + onnx_conv_169_q: Vec>>, + onnx_conv_168_nscale: BN254Fr, + onnx_conv_168_dscale: BN254Fr, + onnx_conv_171: Vec>>>, + onnx_conv_172: Vec, + onnx_conv_172_q: Vec>>, + onnx_conv_171_nscale: BN254Fr, + onnx_conv_171_dscale: BN254Fr, + onnx_conv_174: Vec>>>, + onnx_conv_175: Vec, + onnx_conv_175_q: Vec>>, + onnx_conv_174_nscale: BN254Fr, + onnx_conv_174_dscale: BN254Fr, + onnx_conv_177: Vec>>>, + onnx_conv_178: Vec, + onnx_conv_178_q: Vec>>, + onnx_conv_177_nscale: BN254Fr, + onnx_conv_177_dscale: BN254Fr, + onnx_conv_180: Vec>>>, + onnx_conv_181: Vec, + onnx_conv_181_q: Vec>>, + onnx_conv_180_nscale: BN254Fr, + onnx_conv_180_dscale: BN254Fr, + onnx_conv_183: Vec>>>, + onnx_conv_184: Vec, + onnx_conv_184_q: Vec>>, + onnx_conv_183_nscale: BN254Fr, + onnx_conv_183_dscale: BN254Fr, + onnx_conv_186: Vec>>>, + onnx_conv_187: Vec, + onnx_conv_187_q: Vec>>, + onnx_conv_186_nscale: BN254Fr, + onnx_conv_186_dscale: BN254Fr, + classifier_0_weight: Vec>, + classifier_0_bias_q: Vec, + classifier_0_weight_nscale: BN254Fr, + classifier_0_weight_dscale: BN254Fr, + classifier_3_weight: Vec>, + classifier_3_bias_q: Vec, + classifier_3_weight_nscale: BN254Fr, + classifier_3_weight_dscale: BN254Fr, + classifier_6_weight: Vec>, + classifier_6_bias_q: Vec, + classifier_6_weight_nscale: BN254Fr, + classifier_6_weight_dscale: BN254Fr, + input_mat_ru: Vec, + onnx_conv_150_mat_rv: Vec, + _features_features_2_relu_output_0_mat_ru: Vec, + onnx_conv_153_mat_rv: Vec, + _features_features_6_maxpool_output_0_mat_ru: Vec, + onnx_conv_156_mat_rv: Vec, + _features_features_9_relu_output_0_mat_ru: Vec, + onnx_conv_159_mat_rv: Vec, + _features_features_13_maxpool_output_0_mat_ru: Vec, + onnx_conv_162_mat_rv: Vec, + _features_features_16_relu_output_0_mat_ru: Vec, + onnx_conv_165_mat_rv: Vec, + _features_features_19_relu_output_0_mat_ru: Vec, + onnx_conv_168_mat_rv: Vec, + _features_features_23_maxpool_output_0_mat_ru: Vec, + onnx_conv_171_mat_rv: Vec, + _features_features_26_relu_output_0_mat_ru: Vec, + onnx_conv_174_mat_rv: Vec, + _features_features_29_relu_output_0_mat_ru: Vec, + onnx_conv_177_mat_rv: Vec, + _features_features_33_maxpool_output_0_mat_ru: Vec, + onnx_conv_180_mat_rv: Vec, + _features_features_36_relu_output_0_mat_ru: Vec, + onnx_conv_183_mat_rv: Vec, + _features_features_39_relu_output_0_mat_ru: Vec, + onnx_conv_186_mat_rv: Vec, + _Flatten_output_0_mat_ru: Vec, + classifier_0_weight_mat_rv: Vec, + _classifier_classifier_1_relu_output_0_mat_ru: Vec, + classifier_3_weight_mat_rv: Vec, + _classifier_classifier_4_relu_output_0_mat_ru: Vec, + classifier_6_weight_mat_rv: Vec, +} + +fn default_variable() -> Circuit{ + let output = vec![vec![BN254Fr::default();10];1]; + let input = vec![vec![vec![vec![BN254Fr::default();32];32];3];1]; + let _features_features_0_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_0_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_2_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_3_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_5_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();32];32];64];1]; + let _features_features_6_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];64];1]; + let _features_features_7_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_7_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_9_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_10_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_12_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();16];16];128];1]; + let _features_features_13_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];128];1]; + let _features_features_14_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_14_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_16_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_17_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_19_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_20_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_22_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();8];8];256];1]; + let _features_features_23_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];256];1]; + let _features_features_24_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_24_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_26_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_27_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_29_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_30_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_32_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();4];4];512];1]; + let _features_features_33_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_34_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_36_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_37_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_39_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_conv = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_div = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_rem = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_40_conv_output_0_floor = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_42_relu_output_0 = vec![vec![vec![vec![BN254Fr::default();2];2];512];1]; + let _features_features_43_maxpool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; + let _avgpool_GlobalAveragePool_output_0 = vec![vec![vec![vec![BN254Fr::default();1];1];512];1]; + let _classifier_classifier_0_gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_0_gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_1_relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_matmul = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_div = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_rem = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_3_gemm_output_0_floor = vec![vec![BN254Fr::default();512];1]; + let _classifier_classifier_4_relu_output_0 = vec![vec![BN254Fr::default();512];1]; + let output_matmul = vec![vec![BN254Fr::default();10];1]; + let output_div = vec![vec![BN254Fr::default();10];1]; + let output_rem = vec![vec![BN254Fr::default();10];1]; + let output_floor = vec![vec![BN254Fr::default();10];1]; + let onnx_conv_150 = vec![vec![vec![vec![BN254Fr::default();3];3];3];64]; + let onnx_conv_151 = vec![BN254Fr::default();64]; + let onnx_conv_151_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx_conv_150_nscale = BN254Fr::default(); + let onnx_conv_150_dscale = BN254Fr::default(); + let onnx_conv_153 = vec![vec![vec![vec![BN254Fr::default();3];3];64];64]; + let onnx_conv_154 = vec![BN254Fr::default();64]; + let onnx_conv_154_q = vec![vec![vec![BN254Fr::default();1];1];64]; + let onnx_conv_153_nscale = BN254Fr::default(); + let onnx_conv_153_dscale = BN254Fr::default(); + let onnx_conv_156 = vec![vec![vec![vec![BN254Fr::default();3];3];64];128]; + let onnx_conv_157 = vec![BN254Fr::default();128]; + let onnx_conv_157_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx_conv_156_nscale = BN254Fr::default(); + let onnx_conv_156_dscale = BN254Fr::default(); + let onnx_conv_159 = vec![vec![vec![vec![BN254Fr::default();3];3];128];128]; + let onnx_conv_160 = vec![BN254Fr::default();128]; + let onnx_conv_160_q = vec![vec![vec![BN254Fr::default();1];1];128]; + let onnx_conv_159_nscale = BN254Fr::default(); + let onnx_conv_159_dscale = BN254Fr::default(); + let onnx_conv_162 = vec![vec![vec![vec![BN254Fr::default();3];3];128];256]; + let onnx_conv_163 = vec![BN254Fr::default();256]; + let onnx_conv_163_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_162_nscale = BN254Fr::default(); + let onnx_conv_162_dscale = BN254Fr::default(); + let onnx_conv_165 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx_conv_166 = vec![BN254Fr::default();256]; + let onnx_conv_166_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_165_nscale = BN254Fr::default(); + let onnx_conv_165_dscale = BN254Fr::default(); + let onnx_conv_168 = vec![vec![vec![vec![BN254Fr::default();3];3];256];256]; + let onnx_conv_169 = vec![BN254Fr::default();256]; + let onnx_conv_169_q = vec![vec![vec![BN254Fr::default();1];1];256]; + let onnx_conv_168_nscale = BN254Fr::default(); + let onnx_conv_168_dscale = BN254Fr::default(); + let onnx_conv_171 = vec![vec![vec![vec![BN254Fr::default();3];3];256];512]; + let onnx_conv_172 = vec![BN254Fr::default();512]; + let onnx_conv_172_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_171_nscale = BN254Fr::default(); + let onnx_conv_171_dscale = BN254Fr::default(); + let onnx_conv_174 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_175 = vec![BN254Fr::default();512]; + let onnx_conv_175_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_174_nscale = BN254Fr::default(); + let onnx_conv_174_dscale = BN254Fr::default(); + let onnx_conv_177 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_178 = vec![BN254Fr::default();512]; + let onnx_conv_178_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_177_nscale = BN254Fr::default(); + let onnx_conv_177_dscale = BN254Fr::default(); + let onnx_conv_180 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_181 = vec![BN254Fr::default();512]; + let onnx_conv_181_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_180_nscale = BN254Fr::default(); + let onnx_conv_180_dscale = BN254Fr::default(); + let onnx_conv_183 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_184 = vec![BN254Fr::default();512]; + let onnx_conv_184_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_183_nscale = BN254Fr::default(); + let onnx_conv_183_dscale = BN254Fr::default(); + let onnx_conv_186 = vec![vec![vec![vec![BN254Fr::default();3];3];512];512]; + let onnx_conv_187 = vec![BN254Fr::default();512]; + let onnx_conv_187_q = vec![vec![vec![BN254Fr::default();1];1];512]; + let onnx_conv_186_nscale = BN254Fr::default(); + let onnx_conv_186_dscale = BN254Fr::default(); + let classifier_0_weight = vec![vec![BN254Fr::default();512];512]; + let classifier_0_bias_q = vec![BN254Fr::default();512]; + let classifier_0_weight_nscale = BN254Fr::default(); + let classifier_0_weight_dscale = BN254Fr::default(); + let classifier_3_weight = vec![vec![BN254Fr::default();512];512]; + let classifier_3_bias_q = vec![BN254Fr::default();512]; + let classifier_3_weight_nscale = BN254Fr::default(); + let classifier_3_weight_dscale = BN254Fr::default(); + let classifier_6_weight = vec![vec![BN254Fr::default();10];512]; + let classifier_6_bias_q = vec![BN254Fr::default();10]; + let classifier_6_weight_nscale = BN254Fr::default(); + let classifier_6_weight_dscale = BN254Fr::default(); + let input_mat_ru = vec![BN254Fr::default();1024]; + let onnx_conv_150_mat_rv = vec![BN254Fr::default();64]; + let _features_features_2_relu_output_0_mat_ru = vec![BN254Fr::default();1024]; + let onnx_conv_153_mat_rv = vec![BN254Fr::default();64]; + let _features_features_6_maxpool_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx_conv_156_mat_rv = vec![BN254Fr::default();128]; + let _features_features_9_relu_output_0_mat_ru = vec![BN254Fr::default();256]; + let onnx_conv_159_mat_rv = vec![BN254Fr::default();128]; + let _features_features_13_maxpool_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_162_mat_rv = vec![BN254Fr::default();256]; + let _features_features_16_relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_165_mat_rv = vec![BN254Fr::default();256]; + let _features_features_19_relu_output_0_mat_ru = vec![BN254Fr::default();64]; + let onnx_conv_168_mat_rv = vec![BN254Fr::default();256]; + let _features_features_23_maxpool_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_171_mat_rv = vec![BN254Fr::default();512]; + let _features_features_26_relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_174_mat_rv = vec![BN254Fr::default();512]; + let _features_features_29_relu_output_0_mat_ru = vec![BN254Fr::default();16]; + let onnx_conv_177_mat_rv = vec![BN254Fr::default();512]; + let _features_features_33_maxpool_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_180_mat_rv = vec![BN254Fr::default();512]; + let _features_features_36_relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_183_mat_rv = vec![BN254Fr::default();512]; + let _features_features_39_relu_output_0_mat_ru = vec![BN254Fr::default();4]; + let onnx_conv_186_mat_rv = vec![BN254Fr::default();512]; + let _Flatten_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_0_weight_mat_rv = vec![BN254Fr::default();512]; + let _classifier_classifier_1_relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_3_weight_mat_rv = vec![BN254Fr::default();512]; + let _classifier_classifier_4_relu_output_0_mat_ru = vec![BN254Fr::default();1]; + let classifier_6_weight_mat_rv = vec![BN254Fr::default();10]; + let ass = Circuit{output,input,_features_features_0_conv_output_0_conv,_features_features_0_conv_output_0_div,_features_features_0_conv_output_0_rem,_features_features_0_conv_output_0_floor,_features_features_2_relu_output_0,_features_features_3_conv_output_0_conv,_features_features_3_conv_output_0_div,_features_features_3_conv_output_0_rem,_features_features_3_conv_output_0_floor,_features_features_5_relu_output_0,_features_features_6_maxpool_output_0,_features_features_7_conv_output_0_conv,_features_features_7_conv_output_0_div,_features_features_7_conv_output_0_rem,_features_features_7_conv_output_0_floor,_features_features_9_relu_output_0,_features_features_10_conv_output_0_conv,_features_features_10_conv_output_0_div,_features_features_10_conv_output_0_rem,_features_features_10_conv_output_0_floor,_features_features_12_relu_output_0,_features_features_13_maxpool_output_0,_features_features_14_conv_output_0_conv,_features_features_14_conv_output_0_div,_features_features_14_conv_output_0_rem,_features_features_14_conv_output_0_floor,_features_features_16_relu_output_0,_features_features_17_conv_output_0_conv,_features_features_17_conv_output_0_div,_features_features_17_conv_output_0_rem,_features_features_17_conv_output_0_floor,_features_features_19_relu_output_0,_features_features_20_conv_output_0_conv,_features_features_20_conv_output_0_div,_features_features_20_conv_output_0_rem,_features_features_20_conv_output_0_floor,_features_features_22_relu_output_0,_features_features_23_maxpool_output_0,_features_features_24_conv_output_0_conv,_features_features_24_conv_output_0_div,_features_features_24_conv_output_0_rem,_features_features_24_conv_output_0_floor,_features_features_26_relu_output_0,_features_features_27_conv_output_0_conv,_features_features_27_conv_output_0_div,_features_features_27_conv_output_0_rem,_features_features_27_conv_output_0_floor,_features_features_29_relu_output_0,_features_features_30_conv_output_0_conv,_features_features_30_conv_output_0_div,_features_features_30_conv_output_0_rem,_features_features_30_conv_output_0_floor,_features_features_32_relu_output_0,_features_features_33_maxpool_output_0,_features_features_34_conv_output_0_conv,_features_features_34_conv_output_0_div,_features_features_34_conv_output_0_rem,_features_features_34_conv_output_0_floor,_features_features_36_relu_output_0,_features_features_37_conv_output_0_conv,_features_features_37_conv_output_0_div,_features_features_37_conv_output_0_rem,_features_features_37_conv_output_0_floor,_features_features_39_relu_output_0,_features_features_40_conv_output_0_conv,_features_features_40_conv_output_0_div,_features_features_40_conv_output_0_rem,_features_features_40_conv_output_0_floor,_features_features_42_relu_output_0,_features_features_43_maxpool_output_0,_avgpool_GlobalAveragePool_output_0,_classifier_classifier_0_gemm_output_0_matmul,_classifier_classifier_0_gemm_output_0_div,_classifier_classifier_0_gemm_output_0_rem,_classifier_classifier_0_gemm_output_0_floor,_classifier_classifier_1_relu_output_0,_classifier_classifier_3_gemm_output_0_matmul,_classifier_classifier_3_gemm_output_0_div,_classifier_classifier_3_gemm_output_0_rem,_classifier_classifier_3_gemm_output_0_floor,_classifier_classifier_4_relu_output_0,output_matmul,output_div,output_rem,output_floor,onnx_conv_150,onnx_conv_151,onnx_conv_151_q,onnx_conv_150_nscale,onnx_conv_150_dscale,onnx_conv_153,onnx_conv_154,onnx_conv_154_q,onnx_conv_153_nscale,onnx_conv_153_dscale,onnx_conv_156,onnx_conv_157,onnx_conv_157_q,onnx_conv_156_nscale,onnx_conv_156_dscale,onnx_conv_159,onnx_conv_160,onnx_conv_160_q,onnx_conv_159_nscale,onnx_conv_159_dscale,onnx_conv_162,onnx_conv_163,onnx_conv_163_q,onnx_conv_162_nscale,onnx_conv_162_dscale,onnx_conv_165,onnx_conv_166,onnx_conv_166_q,onnx_conv_165_nscale,onnx_conv_165_dscale,onnx_conv_168,onnx_conv_169,onnx_conv_169_q,onnx_conv_168_nscale,onnx_conv_168_dscale,onnx_conv_171,onnx_conv_172,onnx_conv_172_q,onnx_conv_171_nscale,onnx_conv_171_dscale,onnx_conv_174,onnx_conv_175,onnx_conv_175_q,onnx_conv_174_nscale,onnx_conv_174_dscale,onnx_conv_177,onnx_conv_178,onnx_conv_178_q,onnx_conv_177_nscale,onnx_conv_177_dscale,onnx_conv_180,onnx_conv_181,onnx_conv_181_q,onnx_conv_180_nscale,onnx_conv_180_dscale,onnx_conv_183,onnx_conv_184,onnx_conv_184_q,onnx_conv_183_nscale,onnx_conv_183_dscale,onnx_conv_186,onnx_conv_187,onnx_conv_187_q,onnx_conv_186_nscale,onnx_conv_186_dscale,classifier_0_weight,classifier_0_bias_q,classifier_0_weight_nscale,classifier_0_weight_dscale,classifier_3_weight,classifier_3_bias_q,classifier_3_weight_nscale,classifier_3_weight_dscale,classifier_6_weight,classifier_6_bias_q,classifier_6_weight_nscale,classifier_6_weight_dscale,input_mat_ru,onnx_conv_150_mat_rv,_features_features_2_relu_output_0_mat_ru,onnx_conv_153_mat_rv,_features_features_6_maxpool_output_0_mat_ru,onnx_conv_156_mat_rv,_features_features_9_relu_output_0_mat_ru,onnx_conv_159_mat_rv,_features_features_13_maxpool_output_0_mat_ru,onnx_conv_162_mat_rv,_features_features_16_relu_output_0_mat_ru,onnx_conv_165_mat_rv,_features_features_19_relu_output_0_mat_ru,onnx_conv_168_mat_rv,_features_features_23_maxpool_output_0_mat_ru,onnx_conv_171_mat_rv,_features_features_26_relu_output_0_mat_ru,onnx_conv_174_mat_rv,_features_features_29_relu_output_0_mat_ru,onnx_conv_177_mat_rv,_features_features_33_maxpool_output_0_mat_ru,onnx_conv_180_mat_rv,_features_features_36_relu_output_0_mat_ru,onnx_conv_183_mat_rv,_features_features_39_relu_output_0_mat_ru,onnx_conv_186_mat_rv,_Flatten_output_0_mat_ru,classifier_0_weight_mat_rv,_classifier_classifier_1_relu_output_0_mat_ru,classifier_3_weight_mat_rv,_classifier_classifier_4_relu_output_0_mat_ru,classifier_6_weight_mat_rv}; + ass +} + +#[kernel] +fn _features_features_0_conv_conv_copy_macro( + api: &mut API, + onnx_conv_150: &[[[[InputVariable;3];3];3];64], + _features_features_0_conv_output_0_conv: &[[[[InputVariable;32];32];64];1], + input: &[[[[InputVariable;32];32];3];1], + + onnx_conv_150_mat: &mut [[OutputVariable;64];27], + _features_features_0_conv_output_0_conv_mat: &mut [[OutputVariable;1024];64], + input_mat: &mut [[OutputVariable;1024];27], +) { + // for i in 0..64 { + // for j in 0..3 { + // for k in 0..3 { + // for l in 0..3 { + // onnx_conv_150_mat[((j)*3 + k)*3 + l][i] = onnx_conv_150[i][j][k][l]; + // } + // } + // } + // } + // for i in 0..1 { + // for j in 0..64 { + // for k in 0..32 { + // for l in 0..32 { + // _features_features_0_conv_output_0_conv_mat[j][((i)*32 + k)*32 + l] = _features_features_0_conv_output_0_conv[i][j][k][l]; + // } + // } + // } + // } + for i in (0..(1 + 0 + 0 - 1 + 1)).step_by(1) { + for j in (0..(3 + 0 + 0 - 3 + 1)).step_by(3) { + for k in (0..(32 + 1 + 1 - 3 + 1)).step_by(1) { + for l in (0..(32 + 1 + 1 - 3 + 1)).step_by(1) { + for m in 0..1 { + for n in 0..3 { + for o in 0..3 { + for p in 0..3 { + if true && (i+m-0) >= 0 && (i+m-0) < 1 && (j+n-0) >= 0 && (j+n-0) < 3 && (k+o-1) >= 0 && (k+o-1) < 32 && (l+p-1) >= 0 && (l+p-1) < 32 { input_mat[((n)*3 + o)*3 + p][((i)*32 + k)*32 + l] = input[i+m-0][j+n-0][k+o-1][l+p-1]} + else { input_mat[((n)*3 + o)*3 + p][((i)*32 + k)*32 + l] = api.constant(0)}; + } + } + } + } + } + } + } + } +} + +#[kernel] +fn _features_features_0_conv_conv_ab_matrix_macro( + api: &mut API, + input_mat: & [InputVariable;1024], + onnx_conv_150_mat: & [InputVariable;64], + input_mat_ru: & [InputVariable;1024], + onnx_conv_150_mat_rv: & [InputVariable;64], + _features_features_0_conv_conv_ab_matrix_rx: &mut OutputVariable, + _features_features_0_conv_conv_ab_matrix_ry: &mut OutputVariable, +) { + *_features_features_0_conv_conv_ab_matrix_rx = api.constant(0); + for i in 0..1024 { + let tmp = api.mul(input_mat_ru[i], input_mat[i]); + *_features_features_0_conv_conv_ab_matrix_rx = api.add(tmp, *_features_features_0_conv_conv_ab_matrix_rx); + } + *_features_features_0_conv_conv_ab_matrix_ry = api.constant(0); + for i in 0..64 { + let tmp = api.mul(onnx_conv_150_mat_rv[i], onnx_conv_150_mat[i]); + *_features_features_0_conv_conv_ab_matrix_ry = api.add(tmp, *_features_features_0_conv_conv_ab_matrix_ry); + } +} +#[kernel] +fn _features_features_0_conv_conv_c_matrix_macro( + api: &mut API, + _features_features_0_conv_output_0_conv_mat: & [InputVariable;1024], + input_mat_ru: & [InputVariable;1024], + _features_features_0_conv_conv_c_matrix_rz: &mut OutputVariable, +) { + *_features_features_0_conv_conv_c_matrix_rz = api.constant(0); + for i in 0..1024 { + let tmp = api.mul(input_mat_ru[i], _features_features_0_conv_output_0_conv_mat[i]); + *_features_features_0_conv_conv_c_matrix_rz = api.add(tmp, *_features_features_0_conv_conv_c_matrix_rz); + } +} + +#[kernel] // multiply operation +fn _features_features_0_conv_mul_macro( + api: &mut API, + _features_features_0_conv_output_0_conv: &[[InputVariable;32];32], + onnx_conv_150_nscale: &InputVariable, + _features_features_0_conv_output_0_mul: &mut [[OutputVariable;32];32], +) { + for i in 0..32 { + for j in 0..32 { + _features_features_0_conv_output_0_mul[i][j] = api.mul(_features_features_0_conv_output_0_conv[i][j], onnx_conv_150_nscale); + } + } +} + +#[kernel] // divide operation +fn _features_features_0_conv_div_macro( + api: &mut API, + _features_features_0_conv_output_0_mul: &[[InputVariable;32];32], + onnx_conv_150_dscale: &InputVariable, + _features_features_0_conv_output_0_floor: &[[InputVariable;32];32], + _features_features_0_conv_output_0_rem: &[[InputVariable;32];32], +) { + for i in 0..32 { + for j in 0..32 { + let tmp1 = api.mul(_features_features_0_conv_output_0_floor[i][j], onnx_conv_150_dscale); + let tmp2 = api.sub(_features_features_0_conv_output_0_mul[i][j], _features_features_0_conv_output_0_rem[i][j]); + api.assert_is_equal(tmp1, tmp2); + } + } +} + +#[test] +fn expander_circuit() -> std::io::Result<()>{ + let compile_result = stacker::grow(32 * 1024 * 1024 * 1024, || + { + let mut ctx = Context::::default(); + let mut assignment = default_variable(); + + let onnx_conv_150_mat = ctx.copy_to_device(&assignment.onnx_conv_150); // [64, 3, 3, 3] + let onnx_conv_150_mat = onnx_conv_150_mat.reshape(&[64, 27]); // [64, 27] + let onnx_conv_150_mat = onnx_conv_150_mat.transpose(&[1, 0]); // [27, 64] + + let kernel__features_features_0_conv_conv_ab_matrix: KernelPrimitive = compile__features_features_0_conv_conv_ab_matrix_macro().unwrap(); + let input_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];27]); + let input_mat_ru = ctx.copy_to_device(&assignment.input_mat_ru); + let onnx_conv_150_mat_rv = ctx.copy_to_device(&assignment.onnx_conv_150_mat_rv); + let mut _features_features_0_conv_conv_rx = None; + let mut _features_features_0_conv_conv_ry = None; + let mut input_mat_clone = input_mat.clone(); + let mut onnx_conv_150_mat_clone = onnx_conv_150_mat.clone(); + let mut input_mat_ru_clone = input_mat_ru.clone(); + let mut onnx_conv_150_mat_rv_clone = onnx_conv_150_mat_rv.clone(); + call_kernel!(ctx, kernel__features_features_0_conv_conv_ab_matrix, 27, input_mat_clone, onnx_conv_150_mat_clone, input_mat_ru_clone, onnx_conv_150_mat_rv_clone, mut _features_features_0_conv_conv_rx, mut _features_features_0_conv_conv_ry).unwrap(); + + let _features_features_0_conv_output_0_conv = ctx.copy_to_device(&assignment._features_features_0_conv_output_0_conv); // [1, 64, 32, 32] + let _features_features_0_conv_output_0_conv_mat = _features_features_0_conv_output_0_conv.transpose(&[1, 0, 2, 3]); // [64, 1, 32, 32] + let _features_features_0_conv_output_0_conv_mat = _features_features_0_conv_output_0_conv_mat.reshape(&[64, 1024]); // [64, 1024] + + let kernel__features_features_0_conv_conv_c_matrix: KernelPrimitive = compile__features_features_0_conv_conv_c_matrix_macro().unwrap(); + // let _features_features_0_conv_output_0_conv_mat = ctx.copy_to_device(&vec![vec![BN254Fr::default();1024];64]); + let mut _features_features_0_conv_conv_rz = None; + let _features_features_0_conv_output_0_conv_mat_clone = _features_features_0_conv_output_0_conv_mat.clone(); + let input_mat_ru_clone = input_mat_ru.clone(); + call_kernel!(ctx, kernel__features_features_0_conv_conv_c_matrix, 64, _features_features_0_conv_output_0_conv_mat_clone, input_mat_ru_clone, mut _features_features_0_conv_conv_rz).unwrap(); + + let computation_graph = ctx.compile_computation_graph().unwrap(); + let file = std::fs::File::create("graph.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + computation_graph.serialize_into(writer); + } + ); + Ok(()) +} diff --git a/expander_compiler/tests/mod.rs b/expander_compiler/tests/mod.rs new file mode 100644 index 00000000..e9f75e05 --- /dev/null +++ b/expander_compiler/tests/mod.rs @@ -0,0 +1,2 @@ +mod circuit; +mod zkcuda; diff --git a/expander_compiler/tests/cg_mpi_share.rs b/expander_compiler/tests/zkcuda/cg_mpi_share.rs similarity index 100% rename from expander_compiler/tests/cg_mpi_share.rs rename to expander_compiler/tests/zkcuda/cg_mpi_share.rs diff --git a/expander_compiler/tests/zkcuda/mod.rs b/expander_compiler/tests/zkcuda/mod.rs new file mode 100644 index 00000000..07bcfda6 --- /dev/null +++ b/expander_compiler/tests/zkcuda/mod.rs @@ -0,0 +1,3 @@ +mod zkcuda_examples; +mod zkcuda_keccak; +mod zkcuda_matmul; diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda/zkcuda_examples.rs similarity index 91% rename from expander_compiler/tests/zkcuda_examples.rs rename to expander_compiler/tests/zkcuda/zkcuda_examples.rs index c70ef94e..a9bb04fe 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_examples.rs @@ -1,9 +1,15 @@ use expander_compiler::frontend::*; -use expander_compiler::zkcuda::proving_system::{Expander, ParallelizedExpander, ProvingSystem}; +use expander_compiler::zkcuda::proving_system::expander::config::{ + ZKCudaBN254Hyrax, ZKCudaBN254HyraxBatchPCS, ZKCudaBN254KZG, ZKCudaBN254KZGBatchPCS, +}; +use expander_compiler::zkcuda::proving_system::expander_pcs_defered::BN254ConfigSha2UniKZG; +use expander_compiler::zkcuda::proving_system::{ + Expander, ExpanderNoOverSubscribe, ParallelizedExpander, ProvingSystem, +}; use expander_compiler::zkcuda::shape::Reshape; use expander_compiler::zkcuda::{context::*, kernel::*}; -use gkr::{BN254ConfigSha2Hyrax, BN254ConfigSha2KZG}; +use gkr::BN254ConfigSha2Hyrax; use serdes::ExpSerde; #[kernel] @@ -50,7 +56,7 @@ fn zkcuda_test>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); P::post_process(); @@ -74,7 +80,7 @@ fn zkcuda_test_single_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); } #[test] @@ -85,7 +91,12 @@ fn zkcuda_test_multi_core() { zkcuda_test::>(); zkcuda_test::>(); zkcuda_test::>(); - zkcuda_test::>(); + zkcuda_test::>(); + + // zkcuda_test::<_, ExpanderNoOverSubscribe>(); + // zkcuda_test::<_, ExpanderNoOverSubscribe>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); + zkcuda_test::<_, ExpanderNoOverSubscribe>(); } fn zkcuda_test_simd_prepare_ctx() -> Context { @@ -138,7 +149,7 @@ fn zkcuda_test_simd() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); @@ -166,7 +177,7 @@ fn zkcuda_test_simd() { let proof3 = P::prove( &prover_setup3, &computation_graph, - &ctx3.export_device_memories(), + ctx3.export_device_memories(), ); assert!(P::verify(&verifier_setup2, &computation_graph, &proof3)); } @@ -211,7 +222,7 @@ fn zkcuda_test_simd_autopack() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -274,7 +285,7 @@ fn zkcuda_to_binary() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -300,7 +311,7 @@ fn zkcuda_assertion() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } @@ -322,7 +333,7 @@ fn zkcuda_assertion_fail() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } diff --git a/expander_compiler/tests/zkcuda_keccak.rs b/expander_compiler/tests/zkcuda/zkcuda_keccak.rs similarity index 99% rename from expander_compiler/tests/zkcuda_keccak.rs rename to expander_compiler/tests/zkcuda/zkcuda_keccak.rs index f8bfeb7d..995637a0 100644 --- a/expander_compiler/tests/zkcuda_keccak.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_keccak.rs @@ -353,7 +353,7 @@ fn zkcuda_keccak_1_helper>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); println!("proof generation ok"); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); @@ -416,7 +416,7 @@ fn zkcuda_keccak_2_helper>() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); println!("proof generation ok"); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); diff --git a/expander_compiler/tests/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs similarity index 98% rename from expander_compiler/tests/zkcuda_matmul.rs rename to expander_compiler/tests/zkcuda/zkcuda_matmul.rs index 605d449b..7b20e63d 100644 --- a/expander_compiler/tests/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda/zkcuda_matmul.rs @@ -93,7 +93,7 @@ fn zkcuda_matmul_sum() { let proof = P::prove( &prover_setup, &computation_graph, - &ctx.export_device_memories(), + ctx.export_device_memories(), ); assert!(P::verify(&verifier_setup, &computation_graph, &proof)); }