diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..861cc80 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.wasm32-unknown-unknown] +rustflags = ["--cfg=getrandom_backend=\"wasm_js\""] diff --git a/Cargo.lock b/Cargo.lock index 9fdc800..3a8d405 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,7 +75,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", "rayon", ] @@ -91,6 +91,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitflags" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" + [[package]] name = "bumpalo" version = "3.16.0" @@ -155,13 +161,12 @@ dependencies = [ [[package]] name = "custom-constraints" -version = "0.1.0" +version = "0.2.0" dependencies = [ "ark-ff", "ark-std", "getrandom", - "rand", - "rayon", + "rand 0.9.0", "rstest", "wasm-bindgen", "wasm-bindgen-test", @@ -226,15 +231,16 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" dependencies = [ "cfg-if", "js-sys", "libc", "wasi", "wasm-bindgen", + "windows-targets", ] [[package]] @@ -336,7 +342,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -363,9 +369,19 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.0", + "zerocopy 0.8.18", ] [[package]] @@ -375,7 +391,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.0", ] [[package]] @@ -383,8 +409,15 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "rand_core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" dependencies = [ "getrandom", + "zerocopy 0.8.18", ] [[package]] @@ -546,9 +579,12 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.13.3+wasi-0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] [[package]] name = "wasm-bindgen" @@ -737,6 +773,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -744,7 +789,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" +dependencies = [ + "zerocopy-derive 0.8.18", ] [[package]] @@ -758,6 +812,17 @@ dependencies = [ "syn", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 4adc85e..c431786 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,22 +7,27 @@ license = "MIT" name = "custom-constraints" readme = "README.md" repository = "https://github.com/autoparallel/custom-constraints" -version = "0.1.0" +version = "0.2.0" + +# Needed to deal with ark-ff's BS +[features] +asm = [] +default = [] [dependencies] -ark-ff = { version = "0.5", default-features = false, features = [ - "parallel", - "asm", -] } -rayon = { version = "1.10" } +[target.'cfg(target_arch = "x86_64")'.dependencies] +ark-ff = { version = "0.5", features = ["parallel", "asm"] } + +[target.'cfg(not(target_arch = "x86_64"))'.dependencies] +ark-ff = { version = "0.5", default-features = false, features = ["parallel"] } [dev-dependencies] ark-std = { version = "0.5", default-features = false, features = ["std"] } -rand = "0.8" +rand = "0.9" rstest = { version = "0.24", default-features = false } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] -getrandom = { version = "0.2", features = ["js"] } +getrandom = { version = "0.3", features = ["wasm_js"] } wasm-bindgen = { version = "0.2" } wasm-bindgen-test = { version = "0.3" } diff --git a/README.md b/README.md index 7ceab5c..71b6ac7 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,11 @@ See the [Customizable constraint systems for succinct arguments](https://eprint. ## Roadmap - [x] CSR Sparse matrices -- [ ] CCS structure -- [ ] CCS checking -- [ ] CCS builder/allocator (i.e., from constraints) +- [x] CCS structure +- [x] CCS checking +- [x] CCS builder/allocator (i.e., from constraints) +- [x] Plonkish CCS +- [ ] Noir to Plonkish CCS ## Contributing diff --git a/justfile b/justfile index b573579..349362a 100644 --- a/justfile +++ b/justfile @@ -124,6 +124,7 @@ build-wasm: test: @just header "Running native architecture tests" cargo test --workspace --tests --all-features + cargo test --workspace --doc --all-features @just header "Running wasm tests" wasm-pack test --node @@ -159,7 +160,7 @@ semver: # Run format for the workspace fmt: @just header "Formatting code" - cargo fmt --all + cargo +nightly fmt --all taplo fmt # Check for unused dependencies diff --git a/src/ccs.rs b/src/ccs/generic.rs similarity index 62% rename from src/ccs.rs rename to src/ccs/generic.rs index efcdae5..5047815 100644 --- a/src/ccs.rs +++ b/src/ccs/generic.rs @@ -1,34 +1,88 @@ -//! Implements the Customizable Constraint System (CCS) format. +//! Implementation of the standard/generic Customizable Constraint Systems (CCS). //! -//! A CCS represents arithmetic constraints through a combination of matrices -//! and multisets, allowing efficient verification of arithmetic computations. +//! This module provides a standard implementation of CCS where each selector is a single +//! field element. The constraint system has the form: //! -//! The system consists of: -//! - A set of sparse matrices representing linear combinations -//! - Multisets defining which matrices participate in each constraint -//! - Constants applied to each constraint term - -use matrix::SparseMatrix; +//! ```text +//! sum_i q_i * (prod_{j in S_i} M_j z) = 0 +//! ``` +//! +//! where: +//! - `q_i` are field element selectors +//! - `S_i` are multisets of matrix indices +//! - `M_j` are the selector matrices +//! - `z` is the combined input vector `z = (w, 1, x)` +//! - `prod` denotes the Hadamard (element-wise) product +//! +//! # Example Usage +//! +//! Creating a constraint system for `x * y = z`: +//! ``` +//! use custom_constraints::{ +//! ccs::{generic::Generic, CCS}, +//! matrix::SparseMatrix, +//! }; +//! # use ark_ff::{Field, Fp, MontBackend, MontConfig}; +//! # #[derive(MontConfig)] +//! # #[modulus = "17"] +//! # #[generator = "3"] +//! # struct FConfig; +//! # type F = Fp, 1>; +//! +//! // Create matrices to select variables +//! let mut m1 = SparseMatrix::new_rows_cols(1, 4); +//! m1.write(0, 3, F::ONE); // Select x +//! +//! let mut m2 = SparseMatrix::new_rows_cols(1, 4); +//! m2.write(0, 0, F::ONE); // Select y +//! +//! let mut m3 = SparseMatrix::new_rows_cols(1, 4); +//! m3.write(0, 1, F::ONE); // Select z +//! +//! // Create CCS and set matrices +//! let mut ccs = CCS::, F>::new(); +//! ccs.matrices = vec![m1, m2, m3]; +//! +//! // Encode x * y - z = 0 +//! ccs.multisets = vec![vec![0, 1], vec![2]]; // Terms: (M1·z ∘ M2·z), (M3·z) +//! ccs.selectors = vec![F::ONE, F::from(-1)]; // Coefficients: 1, -1 +//! ``` +//! +//! # Features +//! +//! - Support for constraints up to arbitrary degree +//! - Efficient sparse matrix operations +//! - Verification of constraint satisfaction +//! - Pretty-printing of constraint systems +//! +//! # Implementation Details +//! +//! The generic CCS uses: +//! - Single field elements for selectors +//! - Multisets to specify which matrices participate in each term +//! - Sparse matrices for efficient variable selection +//! - Combined input vector z = (w, 1, x) where: +//! - w is the witness vector +//! - 1 is a constant term +//! - x is the public input vector +//! +//! The system can represent arbitrary degree constraints through +//! the `new_degree` constructor, which sets up the appropriate +//! number of matrices and terms for constraints up to the specified +//! degree. use super::*; -/// A Customizable Constraint System over a field F. -#[derive(Debug, Default)] -pub struct CCS { - /// Constants for each constraint term - pub constants: Vec, - /// Sets of matrix indices for Hadamard products - pub multisets: Vec>, - /// Constraint matrices - pub matrices: Vec>, -} +/// A type marker for the standard/generic CCS format with scalar constants as "selectors". +#[derive(Clone, Debug, Default)] +pub struct Generic(PhantomData); -impl CCS { - /// Creates a new empty CCS. - pub fn new() -> Self { - Self::default() - } +impl CCSType for Generic { + /// For Generic CCS, selectors are just single field elements + type Selectors = F; +} +impl CCS, F> { /// Checks if a witness and public input satisfy the constraint system. /// /// Forms vector z = (w, 1, x) and verifies that all constraints are satisfied. @@ -77,7 +131,7 @@ impl CCS { term *= products[idx][row]; } - let contribution = self.constants[i] * term; + let contribution = self.selectors[i] * term; sum += contribution; } @@ -99,7 +153,7 @@ impl CCS { pub fn new_degree(d: usize) -> Self { assert!(d >= 2, "Degree must be positive"); - let mut ccs = Self { constants: Vec::new(), multisets: Vec::new(), matrices: Vec::new() }; + let mut ccs = Self { selectors: Vec::new(), multisets: Vec::new(), matrices: Vec::new() }; // We'll create terms starting from highest degree down to degree 1 // For a degree d CCS, we need terms of all degrees from d down to 1 @@ -112,7 +166,7 @@ impl CCS { // Add this term's multiset and its coefficient ccs.multisets.push(matrix_indices); - ccs.constants.push(F::ONE); + ccs.selectors.push(F::ONE); // Update our tracking of matrix indices next_matrix_index += degree; @@ -132,7 +186,7 @@ impl CCS { } } -impl Display for CCS { +impl Display for CCS, F> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { writeln!(f, "Customizable Constraint System:")?; @@ -149,7 +203,7 @@ impl Display for CCS { // We expect multisets to come in pairs, each pair forming one constraint for i in 0..self.multisets.len() { // Write the constant for the first multiset - write!(f, "{}·(", self.constants[i])?; + write!(f, "{}·(", self.selectors[i])?; // Write the Hadamard product for the first multiset if let Some(first_idx) = self.multisets[i].first() { @@ -193,11 +247,11 @@ mod tests { println!("M2 (selects y): {m2:?}"); println!("M3 (selects z): {m3:?}"); - let mut ccs = CCS::new(); + let mut ccs = CCS::, _>::new(); ccs.matrices = vec![m1, m2, m3]; // Encode x * y - z = 0 ccs.multisets = vec![vec![0, 1], vec![2]]; - ccs.constants = vec![F17::ONE, F17::from(-1)]; + ccs.selectors = vec![F17::ONE, F17::from(-1)]; println!("\nTesting valid case: x=2, y=3, z=6"); let x = vec![F17::from(2)]; // public input x = 2 diff --git a/src/ccs/mod.rs b/src/ccs/mod.rs new file mode 100644 index 0000000..d945031 --- /dev/null +++ b/src/ccs/mod.rs @@ -0,0 +1,49 @@ +//! Implements the Customizable Constraint System (CCS) format. +//! +//! A CCS represents arithmetic constraints through a combination of matrices +//! and multisets, allowing efficient verification of arithmetic computations. +//! +//! The system consists of: +//! - A set of sparse matrices representing linear combinations +//! - Multisets defining which matrices participate in each constraint +//! - Constants applied to each constraint term + +use std::marker::PhantomData; + +use matrix::SparseMatrix; + +use super::*; + +pub mod generic; +pub mod plonkish; + +/// A trait for configuring different types of Customizable Constraint Systems (CCS). +/// +/// This trait allows different CCS variants to specify their selector types. +/// Different CCS designs can use different types of selectors: +/// - Generic CCS uses single field elements as selectors +/// - Plonkish CCS uses vectors of field elements for multi-constraint support +/// - Other variants might use matrices or more complex structures +/// +/// The selector type must implement Default to provide a zero/empty value +/// when initializing a new CCS. +pub trait CCSType { + /// The type of selectors used in this CCS variant. + type Selectors: Default; +} + +/// A Customizable Constraint System over a field F. +#[derive(Debug, Default)] +pub struct CCS, F: Field> { + /// Constants for each constraint term + pub selectors: Vec, + /// Sets of matrix indices for Hadamard products + pub multisets: Vec>, + /// Constraint matrices + pub matrices: Vec>, +} + +impl + Default, F: Field> CCS { + /// Creates a new empty CCS. + pub fn new() -> Self { Self::default() } +} diff --git a/src/ccs/plonkish.rs b/src/ccs/plonkish.rs new file mode 100644 index 0000000..efcfb69 --- /dev/null +++ b/src/ccs/plonkish.rs @@ -0,0 +1,675 @@ +//! PLONK-style Customizable Constraint Systems (CCS). +//! +//! This module implements a variant of CCS that follows the PLONK (Permutations over +//! Lagrange-bases for Oecumenical Noninteractive arguments of Knowledge) design pattern. +//! The constraint system has the form: +//! +//! ```text +//! sum_{i, 1>; +//! +//! // Create a system for the constraint x * y + z = 0 +//! let mut ccs = CCS::, F>::new_width(3); +//! let c = ccs.add_constraint(); +//! +//! // Set up matrices to select variables +//! let mut a1 = SparseMatrix::new_rows_cols(1, 3); +//! a1.write(0, 0, F::ONE); // Select x +//! ccs.matrices[0] = a1; +//! +//! let mut a2 = SparseMatrix::new_rows_cols(1, 3); +//! a2.write(0, 1, F::ONE); // Select y +//! ccs.matrices[1] = a2; +//! +//! let mut a3 = SparseMatrix::new_rows_cols(1, 3); +//! a3.write(0, 2, F::ONE); // Select z +//! ccs.matrices[2] = a3; +//! +//! // Set coefficients +//! ccs.set_cross_term(0, 1, c, F::ONE); // x * y +//! ccs.set_linear(2, c, F::ONE); // + z +//! ``` + +use super::*; + +/// A type marker for PLONK-style constraint systems. +/// +/// This type configures a CCS to use vector-valued selectors suitable for +/// PLONK-style constraints where each selector holds coefficients for multiple +/// constraints. +#[derive(Clone, Debug, Default)] +pub struct Plonkish(PhantomData); +impl CCSType for Plonkish { + type Selectors = Vec; +} + +impl CCS, F> { + /// Creates a new Plonkish CCS with the specified width. + /// + /// The width determines the number of selector matrices A_i in the system. + /// Each matrix can select different variables from the input vector z. + /// Cross terms (multiplications) are only allowed between different matrices. + /// + /// # Arguments + /// * `width` - Number of matrices A_i (must be >= 2) + /// + /// # Panics + /// If width < 2 + /// + /// # Example + /// ``` + /// use custom_constraints::ccs::{plonkish::Plonkish, CCS}; + /// # use ark_ff::{Field, Fp, MontBackend, MontConfig}; + /// # #[derive(MontConfig)] + /// # #[modulus = "17"] + /// # #[generator = "3"] + /// # struct FConfig; + /// # type F = Fp, 1>; + /// let ccs = CCS::, F>::new_width(3); + /// ``` + pub fn new_width(width: usize) -> Self { + assert!(width >= 2, "Width must be at least 2"); + + let mut ccs = Self::default(); + + // Initialize matrices with no rows + for _ in 0..width { + ccs.matrices.push(SparseMatrix::new_rows_cols(0, 0)); + } + + // Set up multisets + for i in 0..width { + for j in (i + 1)..width { + ccs.multisets.push(vec![i, j]); + } + } + for i in 0..width { + ccs.multisets.push(vec![i]); + } + ccs.multisets.push(vec![]); + + // Initialize selectors with empty vectors + let num_cross_terms = (width * (width - 1)) / 2; + let num_terms = num_cross_terms + width + 1; + ccs.selectors = vec![vec![]; num_terms]; + + ccs + } + + /// Adds a new constraint to the system. + /// + /// This extends all matrices with a new row and all selectors with a new + /// coefficient initialized to zero. The new constraint can then be configured + /// using set_cross_term, set_linear, and set_constant. + /// + /// # Returns + /// The index of the new constraint (0-based) + /// + /// # Example + /// ``` + /// use custom_constraints::ccs::{plonkish::Plonkish, CCS}; + /// # use ark_ff::{Field, Fp, MontBackend, MontConfig}; + /// # #[derive(MontConfig)] + /// # #[modulus = "17"] + /// # #[generator = "3"] + /// # struct FConfig; + /// # type F = Fp, 1>; + /// let mut ccs = CCS::, F>::new_width(2); + /// let c1 = ccs.add_constraint(); // First constraint + /// let c2 = ccs.add_constraint(); // Second constraint + /// ``` + pub fn add_constraint(&mut self) -> usize { + // Get current number of constraints + let constraint_idx = self.matrices.first().map_or(0, |first| first.dimensions().0); + + // Add a new row to each matrix + for matrix in &mut self.matrices { + matrix.add_row(); + } + + // Add a zero coefficient for each selector + for selector in &mut self.selectors { + selector.push(F::ZERO); + } + + constraint_idx + } + + /// Sets a cross-term coefficient q_{i,j} for a specific constraint. + /// + /// This sets the coefficient for the term A_i·z ∘ A_j·z in the specified constraint. + /// + /// # Arguments + /// * `i` - First matrix index + /// * `j` - Second matrix index (must be different from i) + /// * `constraint_idx` - Index of the constraint to modify + /// * `value` - Coefficient value to set + /// + /// # Panics + /// - If i == j (cross terms must be between different matrices) + /// - If i or j are out of bounds + /// - If constraint_idx is out of bounds + pub fn set_cross_term(&mut self, i: usize, j: usize, constraint_idx: usize, value: F) { + assert!(i != j, "Cross terms must be between different matrices"); + let width = self.matrices.len(); + assert!(i < width && j < width, "Matrix index out of bounds"); + + // Ensure i < j for consistent indexing + let (i, j) = if i < j { (i, j) } else { (j, i) }; + + // Calculate index for the cross term + let idx = (i * (2 * width - i - 1)) / 2 + (j - i - 1); + + if let Some(selector) = self.selectors.get_mut(idx) { + if let Some(coeff) = selector.get_mut(constraint_idx) { + *coeff = value; + } + } + } + + /// Sets a linear term coefficient q_i for a specific constraint. + /// + /// This sets the coefficient for the term A_i·z in the specified constraint. + /// + /// # Arguments + /// * `i` - Matrix index + /// * `constraint_idx` - Index of the constraint to modify + /// * `value` - Coefficient value to set + /// + /// # Panics + /// - If i is out of bounds + /// - If constraint_idx is out of bounds + pub fn set_linear(&mut self, i: usize, constraint_idx: usize, value: F) { + let width = self.matrices.len(); + assert!(i < width, "Matrix index out of bounds"); + + let num_cross_terms = (width * (width - 1)) / 2; + if let Some(selector) = self.selectors.get_mut(num_cross_terms + i) { + if let Some(coeff) = selector.get_mut(constraint_idx) { + *coeff = value; + } + } + } + + /// Sets the constant term q_c for a specific constraint. + /// + /// # Arguments + /// * `constraint_idx` - Index of the constraint to modify + /// * `value` - Constant value to set + /// + /// # Panics + /// If constraint_idx is out of bounds + pub fn set_constant(&mut self, constraint_idx: usize, value: F) { + if let Some(selector) = self.selectors.last_mut() { + if let Some(coeff) = selector.get_mut(constraint_idx) { + *coeff = value; + } + } + } + + /// Helper to calculate number of cross terms + fn num_cross_terms(&self) -> usize { + let width = self.matrices.len(); + (width * (width - 1)) / 2 + } + + /// Checks if a witness and public input satisfy the Plonkish constraint system. + /// The constraint has the form: + /// sum_{i bool { + let mut z = Vec::with_capacity(x.len() + w.len()); + z.extend(x.iter().copied()); + z.extend(w.iter().copied()); + + let products: Vec> = self + .matrices + .iter() + .enumerate() + .map(|(i, matrix)| { + let result = matrix * &z; + println!("A_{i}·z = {result:?}"); + result + }) + .collect(); + + let m = if let Some(first) = products.first() { + first.len() + } else { + return true; + }; + + for row in 0..m { + let mut sum = F::ZERO; + let width = self.matrices.len(); + let mut term_idx = 0; + + // Process quadratic terms (i < j) + for i in 0..width { + for j in (i + 1)..width { + if let Some(selector) = self.selectors.get(term_idx) { + let term = products[i][row] * products[j][row]; + for &coeff in selector { + sum += coeff * term; + } + } + term_idx += 1; + } + } + + // Process linear terms + products.iter().take(width).zip(self.selectors.iter().skip(self.num_cross_terms())).for_each( + |(product, selector)| { + let term = product[row]; + for &coeff in selector { + sum += coeff * term; + } + }, + ); + + // Add constant term + if let Some(selector) = self.selectors.last() { + for &coeff in selector { + sum += coeff; + } + } + + println!("Row {row}: sum = {sum:?}"); + if sum != F::ZERO { + return false; + } + } + + true + } +} + +impl Display for CCS, F> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let width = self.matrices.len(); + + writeln!(f, "Plonkish Constraint System (width = {width}):")?; + + // Display matrices + writeln!(f, "\nMatrices:")?; + for (i, matrix) in self.matrices.iter().enumerate() { + writeln!(f, "A_{i} =")?; + writeln!(f, "{matrix}")?; + } + + // Display selectors + writeln!(f, "\nSelectors:")?; + let mut term_idx = 0; + + // Display cross term selectors + for i in 0..width { + for j in (i + 1)..width { + if let Some(selector) = self.selectors.get(term_idx) { + write!(f, "q_{i},{j} = [")?; + for (idx, &coeff) in selector.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{coeff}")?; + } + writeln!(f, "]")?; + } + term_idx += 1; + } + } + + // Display linear term selectors + for i in 0..width { + if let Some(selector) = self.selectors.get(term_idx) { + write!(f, "q_{i} = [")?; + for (idx, &coeff) in selector.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{coeff}")?; + } + writeln!(f, "]")?; + } + term_idx += 1; + } + + if let Some(selector) = self.selectors.last() { + write!(f, "q_c = [")?; + for (idx, &coeff) in selector.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{coeff}")?; + } + writeln!(f, "]")?; + } + + // Display constraint equation + writeln!(f, "\nConstraint equation:")?; + + let mut first_term = true; + term_idx = 0; + + // Display cross terms (i != j) + for i in 0..width { + for j in (i + 1)..width { + if let Some(selector) = self.selectors.get(term_idx) { + if !selector.iter().all(|&x| x == F::ZERO) { + if !first_term { + write!(f, " + ")?; + } + write!(f, "q_{i},{j}·(A_{i}·z ∘ A_{j}·z)")?; + first_term = false; + } + } + term_idx += 1; + } + } + + // Display linear terms + for i in 0..width { + if let Some(selector) = self.selectors.get(term_idx) { + if !selector.iter().all(|&x| x == F::ZERO) { + if !first_term { + write!(f, " + ")?; + } + write!(f, "q_{i}·(A_{i}·z)")?; + first_term = false; + } + } + term_idx += 1; + } + + if self.selectors.last().is_some() { + write!(f, " + q_c")?; + } + + writeln!(f, " = 0")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mock::F17; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_plonkish_structure() { + let ccs = CCS::, F17>::new_width(3); + + // For width 3, we should have: + // - 3 cross terms (1,2), (1,3), (2,3) + // - 3 linear terms + // - 1 constant term + assert_eq!(ccs.multisets.len(), 7, "Should have 6 terms total"); + + // Check cross term multisets + assert_eq!(ccs.multisets[0], vec![0, 1], "First cross term incorrect"); + assert_eq!(ccs.multisets[1], vec![0, 2], "Second cross term incorrect"); + assert_eq!(ccs.multisets[2], vec![1, 2], "Third cross term incorrect"); + + // Check linear term multisets + assert_eq!(ccs.multisets[3], vec![0], "First linear term incorrect"); + assert_eq!(ccs.multisets[4], vec![1], "Second linear term incorrect"); + assert_eq!(ccs.multisets[5], vec![2], "Third linear term incorrect"); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_plonkish_display() { + let mut ccs = CCS::, F17>::new_width(2); + + // Set up display for one constraint + ccs.add_constraint(); + + // Set up test matrices + let mut a1 = SparseMatrix::new_rows_cols(1, 4); + a1.write(0, 0, F17::ONE); + ccs.matrices[0] = a1; + + let mut a2 = SparseMatrix::new_rows_cols(1, 4); + a2.write(0, 1, F17::ONE); + ccs.matrices[1] = a2; + + // Set some coefficients + ccs.set_cross_term(0, 1, 0, F17::from(3)); // 3(A_1·z)(A_2·z) + ccs.set_linear(0, 0, F17::from(4)); // 4(A_1·z) + ccs.set_linear(1, 0, F17::from(5)); // 5(A_2·z) + + println!("{ccs}"); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_plonkish_satisfaction() { + let mut ccs = CCS::, F17>::new_width(2); + + // Test one constraint + ccs.add_constraint(); + + // Set up matrices for x * y + 2x + 3y + 4 = 0 + let mut a1 = SparseMatrix::new_rows_cols(1, 2); + a1.write(0, 0, F17::ONE); // Select x + ccs.matrices[0] = a1; + + let mut a2 = SparseMatrix::new_rows_cols(1, 2); + a2.write(0, 1, F17::ONE); // Select y + ccs.matrices[1] = a2; + + // Set coefficients + ccs.set_cross_term(0, 1, 0, F17::ONE); // 1 * (x * y) + ccs.set_linear(0, 0, F17::from(2)); // + 2x + ccs.set_linear(1, 0, F17::from(3)); // + 3y + ccs.set_constant(0, F17::from(8)); // + 4 + + println!("ccs: {ccs}"); + + // With: + // x = 4, y = 5 + // 4 * 5 + 2*4 + 3*5 + 8 = 51 ≡ 0 (mod 17) + let x = vec![]; + let w = vec![F17::from(4), F17::from(5)]; + + // Let's print the computation + println!("\nVerifying computation:"); + let prod = F17::from(4) * F17::from(5); // x * y + let lin1 = F17::from(2) * F17::from(4); // 2x + let lin2 = F17::from(3) * F17::from(5); // 3y + let constant = F17::from(8); // 4 + println!("x * y = {prod}"); + println!("2x = {lin1}"); + println!("3y = {lin2}"); + println!("constant = {constant}"); + println!("sum = {}", prod + lin1 + lin2 + constant); + + assert!(ccs.is_satisfied(&x, &w)); + + // Test with invalid assignment + let w = vec![F17::from(2), F17::from(3)]; + assert!(!ccs.is_satisfied(&x, &w)); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_plonkish_simple() { + let mut ccs = CCS::, F17>::new_width(2); + + // Test one simple constraint + ccs.add_constraint(); + + // Set up matrices for x * y + 1 = 0 + let mut a1 = SparseMatrix::new_rows_cols(1, 2); + a1.write(0, 0, F17::ONE); // Select x + ccs.matrices[0] = a1; + + let mut a2 = SparseMatrix::new_rows_cols(1, 2); + a2.write(0, 1, F17::ONE); // Select y + ccs.matrices[1] = a2; + + // Set coefficients + ccs.set_cross_term(0, 1, 0, F17::ONE); // x * y + ccs.set_constant(0, F17::ONE); // + 1 + + println!("ccs: {ccs}"); + + // 16 * 16 + 1 = 257 ≡ 0 (mod 17) + let x = vec![]; + let w = vec![-F17::from(1), F17::from(1)]; + assert!(ccs.is_satisfied(&x, &w)); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_plonkish_width3() { + let mut ccs = CCS::, F17>::new_width(3); + + // Let's create a constraint: + // (x * y) + (y * z) + (x * z) + 2x + 3y + 4z + 5 = 0 + ccs.add_constraint(); + + // Set up matrices + let mut a0 = SparseMatrix::new_rows_cols(1, 3); + a0.write(0, 0, F17::ONE); // Select x + ccs.matrices[0] = a0; + + let mut a1 = SparseMatrix::new_rows_cols(1, 3); + a1.write(0, 1, F17::ONE); // Select y + ccs.matrices[1] = a1; + + let mut a2 = SparseMatrix::new_rows_cols(1, 3); + a2.write(0, 2, F17::ONE); // Select z + ccs.matrices[2] = a2; + + // Set cross terms + ccs.set_cross_term(0, 1, 0, F17::ONE); // x * y + ccs.set_cross_term(1, 2, 0, F17::ONE); // y * z + ccs.set_cross_term(0, 2, 0, F17::ONE); // x * z + + // Set linear terms + ccs.set_linear(0, 0, F17::from(2)); // 2x + ccs.set_linear(1, 0, F17::from(3)); // 3y + ccs.set_linear(2, 0, F17::from(4)); // 4z + + // Set constant term + ccs.set_constant(0, -F17::from(4)); // - 4 + + println!("ccs: {ccs}"); + + // Let's print the computation + println!("\nVerifying computation:"); + let xy = F17::from(2) * F17::from(3); + let yz = F17::from(3) * F17::from(4); + let xz = F17::from(2) * F17::from(4); + let x_term = F17::from(2) * F17::from(2); + let y_term = F17::from(3) * F17::from(3); + let z_term = F17::from(4) * F17::from(4); + let constant = -F17::from(4); + + println!("x * y = {xy}"); + println!("y * z = {yz}"); + println!("x * z = {xz}"); + println!("2x = {x_term}"); + println!("3y = {y_term}"); + println!("4z = {z_term}"); + println!("constant = {constant}"); + println!("sum = {}", xy + yz + xz + x_term + y_term + z_term + constant); + + let x = vec![]; + + // Find solution where this equals 0 (mod 17) + // Solution: x = 2, y = 3, z = 1 + let w = vec![F17::from(2), F17::from(3), F17::from(4)]; + assert!(ccs.is_satisfied(&x, &w)); + + // Invalid assignment should fail + let w = vec![F17::from(1), F17::from(1), F17::from(1)]; + assert!(!ccs.is_satisfied(&x, &w)); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_multiple_constraints() { + let mut ccs = CCS::, F17>::new_width(3); + + // First constraint: x * y + z = 0 + let c1 = ccs.add_constraint(); + + // Second constraint: y * z + x = 0 + let c2 = ccs.add_constraint(); + + // Set up matrices + let mut a1 = SparseMatrix::new_rows_cols(2, 3); + a1.write(0, 0, F17::ONE); // x in first constraint + a1.write(1, 0, F17::ONE); // x in second constraint + ccs.matrices[0] = a1; + + let mut a2 = SparseMatrix::new_rows_cols(2, 3); + a2.write(0, 1, F17::ONE); // y in first constraint + a2.write(1, 1, F17::ONE); // y in second constraint + ccs.matrices[1] = a2; + + let mut a3 = SparseMatrix::new_rows_cols(2, 3); + a3.write(0, 2, F17::ONE); // z in first constraint + a3.write(1, 2, F17::ONE); // z in second constraint + ccs.matrices[2] = a3; + + // Set coefficients for first constraint: x * y + z + 12 = 0 + ccs.set_cross_term(0, 1, c1, F17::ONE); // x * y + ccs.set_linear(2, c1, F17::ONE); // + z + ccs.set_constant(c1, F17::from(12)); // + 12 + + // Set coefficients for second constraint: y * z + x + 10 = 0 + ccs.set_cross_term(1, 2, c2, F17::ONE); // y * z + ccs.set_linear(0, c2, F17::ONE); // + x + ccs.set_constant(c2, F17::from(10)); // + 10 + + println!("ccs: {ccs}"); + + // Test with satisfying assignment + // For first constraint: 1 * 2 + 3 + 12 ≡ 0 (mod 17) + // For second constraint: 2 * 3 + 1 + 10 ≡ 0 (mod 17) + let x = vec![]; + let w = vec![F17::from(1), F17::from(2), F17::from(3)]; + assert!(ccs.is_satisfied(&x, &w)); + + // Test with invalid assignment + let w = vec![F17::from(1), F17::from(1), F17::from(1)]; + assert!(!ccs.is_satisfied(&x, &w)); + } +} diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index 595df10..a0bb04d 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -5,15 +5,15 @@ //! 2. DegreeConstrained: Circuit with enforced degree bounds //! 3. Optimized: Circuit after optimization passes -use super::*; - use std::{collections::HashMap, marker::PhantomData}; +use ccs::generic::Generic; + +use super::*; use crate::{ccs::CCS, matrix::SparseMatrix}; pub mod expression; -#[cfg(test)] -mod tests; +#[cfg(test)] mod tests; use self::expression::*; @@ -40,32 +40,32 @@ impl CircuitState for Optimized {} #[derive(Debug, Clone, Default)] pub struct Circuit { /// Number of public inputs - pub pub_inputs: usize, + pub pub_inputs: usize, /// Number of witness inputs - pub wit_inputs: usize, + pub wit_inputs: usize, /// Number of auxiliary variables - pub aux_count: usize, + pub aux_count: usize, /// Number of output variables pub output_count: usize, /// Circuit expressions and their assigned variables - expressions: Vec<(Expression, Variable)>, + expressions: Vec<(Expression, Variable)>, /// Memoization cache for expressions - memo: HashMap, + memo: HashMap, /// State type marker - _marker: PhantomData, + _marker: PhantomData, } impl Circuit { /// Creates a new empty circuit. pub fn new() -> Self { Self { - pub_inputs: 0, - wit_inputs: 0, - aux_count: 0, + pub_inputs: 0, + wit_inputs: 0, + aux_count: 0, output_count: 0, - expressions: Vec::new(), - memo: HashMap::new(), - _marker: PhantomData, + expressions: Vec::new(), + memo: HashMap::new(), + _marker: PhantomData, } } @@ -84,9 +84,7 @@ impl Circuit { } /// Creates a constant expression. - pub const fn constant(c: F) -> Expression { - Expression::Constant(c) - } + pub const fn constant(c: F) -> Expression { Expression::Constant(c) } /// Adds an internal auxiliary variable. pub fn add_internal(&mut self, expr: Expression) -> Expression { @@ -144,13 +142,13 @@ impl Circuit { } Circuit { - pub_inputs: self.pub_inputs, - wit_inputs: self.wit_inputs, - aux_count: self.aux_count, + pub_inputs: self.pub_inputs, + wit_inputs: self.wit_inputs, + aux_count: self.aux_count, output_count: self.output_count, - expressions: self.expressions, - memo: self.memo, - _marker: PhantomData, + expressions: self.expressions, + memo: self.memo, + _marker: PhantomData, } } @@ -244,12 +242,13 @@ impl Circuit { impl Circuit, F> { /// Converts circuit to CCS format. - pub fn into_ccs(self) -> CCS { + pub fn into_ccs(self) -> CCS, F> { let mut ccs = CCS::new_degree(D); // Calculate dimensions let num_cols = 1 + self.pub_inputs + self.wit_inputs + self.aux_count + self.output_count; + // TODO: Num rows does not need to equal the num cols // Initialize matrices for matrix in &mut ccs.matrices { *matrix = SparseMatrix::new_rows_cols(num_cols, num_cols); @@ -274,7 +273,7 @@ impl Circuit, F> { /// * `output` - Output variable fn create_constraint( &self, - ccs: &mut CCS, + ccs: &mut CCS, F>, d: usize, row: usize, expr: &Expression, @@ -285,17 +284,16 @@ impl Circuit, F> { ccs.matrices.last_mut().unwrap().write(row, output_pos, -F::ONE); match expr { - Expression::Add(terms) => { + Expression::Add(terms) => for term in terms { self.process_term(ccs, d, row, term); - } - }, + }, _ => self.process_term(ccs, d, row, expr), } } /// Processes term in constraint creation. - fn process_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + fn process_term(&self, ccs: &mut CCS, F>, d: usize, row: usize, term: &Expression) { // First, fully expand the expression let expanded = expand_expression(term); @@ -311,7 +309,13 @@ impl Circuit, F> { } /// Processes a simple (non-compound) term. - fn process_simple_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + fn process_simple_term( + &self, + ccs: &mut CCS, F>, + d: usize, + row: usize, + term: &Expression, + ) { match term { Expression::Mul(factors) => { // Collect constants and variables @@ -488,13 +492,13 @@ impl Circuit, F> { // Convert to optimized circuit Circuit { - pub_inputs: new_circuit.pub_inputs, - wit_inputs: new_circuit.wit_inputs, - aux_count: new_circuit.aux_count, + pub_inputs: new_circuit.pub_inputs, + wit_inputs: new_circuit.wit_inputs, + aux_count: new_circuit.aux_count, output_count: new_circuit.output_count, - expressions: new_circuit.expressions, - memo: new_circuit.memo, - _marker: PhantomData, + expressions: new_circuit.expressions, + memo: new_circuit.memo, + _marker: PhantomData, } } } @@ -554,7 +558,7 @@ fn multiply_expressions(a: &Expression, b: &Expression) -> Expre impl Circuit, F> { /// Converts and `Optimized` circuit into CCS. - pub fn into_ccs(self) -> CCS { + pub fn into_ccs(self) -> CCS, F> { let mut ccs = CCS::new_degree(D); // Calculate dimensions @@ -577,7 +581,7 @@ impl Circuit, F> { /// Creates a constraint in the constraint system fn create_constraint( &self, - ccs: &mut CCS, + ccs: &mut CCS, F>, d: usize, row: usize, expr: &Expression, @@ -588,17 +592,16 @@ impl Circuit, F> { ccs.matrices.last_mut().unwrap().write(row, output_pos, -F::ONE); match expr { - Expression::Add(terms) => { + Expression::Add(terms) => for term in terms { self.process_term(ccs, d, row, term); - } - }, + }, _ => self.process_term(ccs, d, row, expr), } } /// Processes term in constraint creation. - fn process_term(&self, ccs: &mut CCS, d: usize, row: usize, term: &Expression) { + fn process_term(&self, ccs: &mut CCS, F>, d: usize, row: usize, term: &Expression) { match term { Expression::Mul(factors) => { let degree = factors.len(); @@ -664,9 +667,7 @@ impl Circuit, F> { impl Circuit { /// Returns circuit expressions. - pub fn expressions(&self) -> &[(Expression, Variable)] { - &self.expressions - } + pub fn expressions(&self) -> &[(Expression, Variable)] { &self.expressions } // TODO: Should this really only be some kind of `#[cfg(test)]` fn? /// Expands an expression by substituting definitions. @@ -677,25 +678,21 @@ impl Circuit { | Expression::Variable(Variable::Public(_) | Variable::Witness(_)) => expr.clone(), // For auxiliary and output variables, look up their definition - Expression::Variable(var @ (Variable::Aux(_) | Variable::Output(_))) => { - self.get_definition(var).map_or_else(|| expr.clone(), |definition| self.expand(definition)) - }, + Expression::Variable(var @ (Variable::Aux(_) | Variable::Output(_))) => + self.get_definition(var).map_or_else(|| expr.clone(), |definition| self.expand(definition)), - Expression::Add(terms) => { - Expression::Add(terms.iter().map(|term| self.expand(term)).collect()) - }, - Expression::Mul(factors) => { - Expression::Mul(factors.iter().map(|factor| self.expand(factor)).collect()) - }, + Expression::Add(terms) => + Expression::Add(terms.iter().map(|term| self.expand(term)).collect()), + Expression::Mul(factors) => + Expression::Mul(factors.iter().map(|factor| self.expand(factor)).collect()), } } /// Gets definition for a variable if it exists. fn get_definition(&self, var: &Variable) -> Option<&Expression> { match var { - Variable::Aux(idx) | Variable::Output(idx) => { - self.expressions.get(*idx).map(|(expr, _)| expr) - }, + Variable::Aux(idx) | Variable::Output(idx) => + self.expressions.get(*idx).map(|(expr, _)| expr), _ => None, } } @@ -734,7 +731,7 @@ impl Circuit { } /// Computes the degree of an expression. -fn compute_degree(expr: &Expression) -> usize { +pub fn compute_degree(expr: &Expression) -> usize { match expr { // Constants are degree 0 Expression::Constant(_) => 0, diff --git a/src/lib.rs b/src/lib.rs index d0afb6f..bd44e5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,10 +15,10 @@ //! - [`CCS`](ccs::CCS): The customizable constraint system representation //! - [`SparseMatrix`](matrix::SparseMatrix): Efficient sparse matrix operations -use ark_ff::Field; -#[cfg(test)] -use mock::F17; use std::fmt::{self, Display, Formatter}; + +use ark_ff::Field; +#[cfg(test)] use mock::F17; #[cfg(all(target_arch = "wasm32", test))] use wasm_bindgen_test::wasm_bindgen_test; diff --git a/src/matrix.rs b/src/matrix.rs index 16333b8..e369f15 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -17,9 +17,9 @@ pub struct SparseMatrix { /// Column indices of non-zero elements col_indices: Vec, /// Values of non-zero elements - values: Vec, + values: Vec, /// Number of columns in the matrix - num_cols: usize, + num_cols: usize, } impl SparseMatrix { @@ -73,6 +73,44 @@ impl SparseMatrix { } } + /// Writes a value to the matrix, expanding its dimensions if necessary. + /// + /// If the specified position is outside the current matrix dimensions, + /// the matrix will be expanded to accommodate the new position. + /// + /// # Arguments + /// * `row` - Row index for the value + /// * `col` - Column index for the value + /// * `val` - Value to write + /// + /// # Panics + /// - If attempting to write a zero value + pub fn write_expand(&mut self, row: usize, col: usize, val: F) { + assert_ne!(val, F::ZERO, "Trying to add a zero element into the `SparseMatrix`!"); + + // Expand the matrix if necessary + if row >= self.row_offsets.len() - 1 { + // Add new row offsets, copying the last offset + let last_offset = *self.row_offsets.last().unwrap(); + self.row_offsets.resize(row + 2, last_offset); + } + if col >= self.num_cols { + self.num_cols = col + 1; + } + + // Now we can use the existing write logic + self.write(row, col, val); + } + + /// Returns the current dimensions of the matrix. + /// + /// # Returns + /// A tuple (rows, cols) representing the matrix dimensions + pub fn dimensions(&self) -> (usize, usize) { (self.row_offsets.len() - 1, self.num_cols) } + + /// Adds a new empty row to the matrix + pub fn add_row(&mut self) { self.row_offsets.push(*self.row_offsets.last().unwrap_or(&0)); } + #[allow(unused)] /// Removes an entry from the [`SparseMatrix`] fn remove(&mut self, row: usize, col: usize) { @@ -316,15 +354,89 @@ mod tests { // [6 0 0] // [0 6 0] // [0 0 10] - assert_eq!( - result.values, - [ - F17::from(6), // 2*3 at (0,0) - F17::from(6), // 3*2 at (1,1) - F17::from(10), // 5*2 at (2,2) - ] - ); + assert_eq!(result.values, [ + F17::from(6), // 2*3 at (0,0) + F17::from(6), // 3*2 at (1,1) + F17::from(10), // 5*2 at (2,2) + ]); assert_eq!(result.col_indices, [0, 1, 2]); assert_eq!(result.row_offsets, [0, 1, 2, 3]); } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_write_expand() { + // Create a 2x2 matrix + let mut matrix = SparseMatrix::new_rows_cols(2, 2); + + // Write within bounds + matrix.write_expand(0, 0, F17::from(1)); + matrix.write_expand(1, 1, F17::from(2)); + + // Write beyond current dimensions + matrix.write_expand(3, 4, F17::from(3)); + + // Check dimensions + let (rows, cols) = matrix.dimensions(); + assert_eq!(rows, 4); + assert_eq!(cols, 5); + + // Verify values + let expected_values = vec![(0, 0, F17::from(1)), (1, 1, F17::from(2)), (3, 4, F17::from(3))]; + + for (row, col, expected_val) in expected_values { + // Find the value in the sparse representation + let row_start = matrix.row_offsets[row]; + let row_end = matrix.row_offsets[row + 1]; + let pos = matrix.col_indices[row_start..row_end].iter().position(|&c| c == col); + + match pos { + Some(idx) => { + assert_eq!( + matrix.values[row_start + idx], + expected_val, + "Value mismatch at position ({}, {})", + row, + col + ); + }, + None => panic!("Expected value not found at position ({}, {})", row, col), + } + } + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_write_expand_multiple() { + let mut matrix = SparseMatrix::new_rows_cols(2, 2); + + // Write values in various orders to test expansion + matrix.write_expand(5, 3, F17::from(1)); + matrix.write_expand(2, 6, F17::from(2)); + matrix.write_expand(4, 1, F17::from(3)); + + let (rows, cols) = matrix.dimensions(); + assert_eq!(rows, 6); + assert_eq!(cols, 7); + + // Test that row offsets are properly maintained + assert_eq!(matrix.row_offsets.len(), rows + 1); + + // Verify that all rows between have valid offsets + for i in 0..rows { + assert!( + matrix.row_offsets[i] <= matrix.row_offsets[i + 1], + "Row offset invariant violated at row {}", + i + ); + } + } + + #[test] + #[should_panic(expected = "Trying to add a zero element")] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_write_expand_zero() { + let mut matrix = SparseMatrix::new_rows_cols(2, 2); + matrix.write_expand(3, 3, F17::from(0)); + } }