From 1c0b146c488f675da618d519d061e202e8337aba Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Fri, 23 Jan 2026 09:26:31 -0800 Subject: [PATCH] Use interior mutability for Decryptor PRNG too; add constructor; clean up imports. PiperOrigin-RevId: 860138561 --- willow/benches/shell_benchmarks.rs | 6 +---- willow/src/api/client.rs | 2 -- .../testing_utils/shell_testing_decryptor.rs | 20 ++++++++------ willow/src/traits/decryptor.rs | 4 +-- willow/src/willow_v1/BUILD | 9 +------ willow/src/willow_v1/client.rs | 2 -- willow/src/willow_v1/decryptor.rs | 26 ++++++++++++------- willow/src/willow_v1/server.rs | 6 +---- willow/src/willow_v1/verifier.rs | 6 +---- willow/tests/willow_v1_shell.rs | 26 +++++-------------- 10 files changed, 40 insertions(+), 67 deletions(-) diff --git a/willow/benches/shell_benchmarks.rs b/willow/benches/shell_benchmarks.rs index bc5fe59..d9fc07d 100644 --- a/willow/benches/shell_benchmarks.rs +++ b/willow/benches/shell_benchmarks.rs @@ -28,9 +28,7 @@ use messages::{ PartialDecryptionRequest, }; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; -use prng_traits::SecurePrng; use server_traits::SecureAggregationServer; -use single_thread_hkdf::SingleThreadHkdfPrng; use testing_utils::{generate_random_nonce, generate_random_unsigned_vector}; use vahe_shell::ShellVahe; use verifier_traits::SecureAggregationVerifier; @@ -135,10 +133,8 @@ fn setup_base(args: &Args) -> BaseInputs { // Create decryptor. let vahe = ShellVahe::new(ahe_config.clone(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let kahe = ShellKahe::new(kahe_config.clone(), CONTEXT_STRING).unwrap(); diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index 558b3cc..c6f505b 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -22,11 +22,9 @@ use client_traits::SecureAggregationClient; use kahe_shell::ShellKahe; use kahe_traits::KaheBase; use parameters_shell::create_shell_configs; -use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use shell_ciphertexts_rust_proto::ShellAhePublicKey; -use single_thread_hkdf::SingleThreadHkdfPrng; use status::ffi::FfiStatus; use status::StatusError; use std::collections::HashMap; diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index 5db29f7..ae74447 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -31,6 +31,7 @@ use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; +use std::cell::RefCell; use vahe_shell::ShellVahe; use vahe_traits::Recover; use vahe_traits::{HasVahe, VaheBase}; @@ -41,7 +42,7 @@ use vahe_traits::{HasVahe, VaheBase}; pub struct ShellTestingDecryptor { kahe: ShellKahe, vahe: ShellVahe, - prng: SingleThreadHkdfPrng, + prng: RefCell, secret_key: Option<::SecretKeyShare>, } @@ -64,14 +65,14 @@ impl ShellTestingDecryptor { let vahe = ShellVahe::new(ahe_config, context_bytes)?; let seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&seed)?; - Ok(ShellTestingDecryptor { kahe, vahe, prng, secret_key: None }) + Ok(ShellTestingDecryptor { kahe, vahe, prng: RefCell::new(prng), secret_key: None }) } /// Generates a new AHE public key, and stores the corresponding secret key. pub fn generate_public_key( &mut self, ) -> Result<::PublicKey, StatusError> { - let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?; + let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?; self.secret_key = Some(sk_share); let public_key = self.vahe.aggregate_public_key_shares(&[pk_share])?; Ok(public_key) @@ -81,7 +82,7 @@ impl ShellTestingDecryptor { /// the AHE ciphertext and then decrypting the KAHE ciphertext. Does not verify the client proof /// contained in the message. pub fn decrypt( - &mut self, + &self, client_message: &ClientMessage, ) -> Result<::Plaintext, StatusError> { let partial_dec_ciphertext = @@ -94,8 +95,11 @@ impl ShellTestingDecryptor { "No secret key available", )), Some(sk_share) => { - let partial_decryption = - self.vahe.partial_decrypt(&partial_dec_ciphertext, sk_share, &mut self.prng)?; + let partial_decryption = self.vahe.partial_decrypt( + &partial_dec_ciphertext, + sk_share, + &mut self.prng.borrow_mut(), + )?; let decrypted_kahe_key = self.vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?; let decrypted_kahe_key = self.kahe.try_secret_key_from(decrypted_kahe_key)?; @@ -134,7 +138,7 @@ impl ShellTestingDecryptor { } fn decrypt_serialized( - &mut self, + &self, contribution: &[u8], ) -> Result, StatusError> { let client_message_proto = ClientMessageProto::parse(contribution) @@ -192,7 +196,7 @@ impl ShellTestingDecryptor { let partial_decryption = self.vahe.partial_decrypt( &request.partial_dec_ciphertext, sk_share, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; Ok(PartialDecryptionResponse { partial_decryption }) } diff --git a/willow/src/traits/decryptor.rs b/willow/src/traits/decryptor.rs index 31f8934..b0d7dae 100644 --- a/willow/src/traits/decryptor.rs +++ b/willow/src/traits/decryptor.rs @@ -24,14 +24,14 @@ pub trait SecureAggregationDecryptor: HasVahe { /// Creates a public key share to be sent to the Server, updating the /// decryptor state. fn create_public_key_share( - &mut self, + &self, decryptor_state: &mut Self::DecryptorState, ) -> Result::Vahe>, StatusError>; /// Handles a partial decryption request received from the Server. Returns a /// partial decryption to the Server. fn handle_partial_decryption_request( - &mut self, + &self, partial_decryption_request: PartialDecryptionRequest<::Vahe>, decryptor_state: &Self::DecryptorState, ) -> Result::Vahe>, StatusError>; diff --git a/willow/src/willow_v1/BUILD b/willow/src/willow_v1/BUILD index edd7f8a..be432b5 100644 --- a/willow/src/willow_v1/BUILD +++ b/willow/src/willow_v1/BUILD @@ -44,12 +44,10 @@ rust_test( "//willow/src/api:aggregation_config", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/testing_utils:shell_testing_decryptor", "//willow/src/testing_utils:shell_testing_parameters", - "//willow/src/traits:prng_traits", ], ) @@ -59,11 +57,9 @@ rust_test( deps = [ "@crate_index//:googletest", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/traits:ahe_traits", "//willow/src/traits:decryptor_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", ], ) @@ -81,6 +77,7 @@ rust_library( "//willow/src/traits:ahe_traits", "//willow/src/traits:decryptor_traits", "//willow/src/traits:messages", + "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:vahe_traits", ], @@ -96,13 +93,11 @@ rust_test( "@crate_index//:googletest", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/traits:ahe_traits", "//willow/src/traits:client_traits", "//willow/src/traits:decryptor_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:verifier_traits", @@ -158,7 +153,6 @@ rust_test( "//shell_wrapper:status_matchers_rs", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", - "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", "//willow/src/testing_utils:shell_testing_parameters", @@ -166,7 +160,6 @@ rust_test( "//willow/src/traits:client_traits", "//willow/src/traits:decryptor_traits", "//willow/src/traits:kahe_traits", - "//willow/src/traits:prng_traits", "//willow/src/traits:proto_serialization_traits", "//willow/src/traits:server_traits", "//willow/src/traits:vahe_traits", diff --git a/willow/src/willow_v1/client.rs b/willow/src/willow_v1/client.rs index d84f379..4716c0c 100644 --- a/willow/src/willow_v1/client.rs +++ b/willow/src/willow_v1/client.rs @@ -102,9 +102,7 @@ mod test { use googletest::{gtest, verify_eq, verify_that}; use kahe_shell::ShellKahe; use parameters_shell::create_shell_configs; - use prng_traits::SecurePrng; use shell_testing_decryptor::ShellTestingDecryptor; - use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; use testing_utils::generate_random_nonce; use vahe_shell::ShellVahe; diff --git a/willow/src/willow_v1/decryptor.rs b/willow/src/willow_v1/decryptor.rs index 91803ba..81b023e 100644 --- a/willow/src/willow_v1/decryptor.rs +++ b/willow/src/willow_v1/decryptor.rs @@ -16,17 +16,19 @@ use ahe_traits::{AheKeygen, PartialDec}; use decryptor_traits::SecureAggregationDecryptor; use messages::{DecryptorPublicKeyShare, PartialDecryptionRequest, PartialDecryptionResponse}; use messages_rust_proto::DecryptorStateProto; +use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::AsView; use shell_ciphertexts_rust_proto::ShellAheSecretKeyShare; use status::StatusError; +use std::cell::RefCell; use vahe_traits::{EncryptVerify, HasVahe, VaheBase}; /// Lightweight decryptor directly exposing KAHE/VAHE types. It verifies only the client proofs, /// does not provide verifiable partial decryptions. pub struct WillowV1Decryptor { pub vahe: Vahe, - pub prng: Vahe::Rng, + pub prng: RefCell, } impl HasVahe for WillowV1Decryptor { @@ -36,6 +38,14 @@ impl HasVahe for WillowV1Decryptor { } } +impl WillowV1Decryptor { + pub fn new_with_randomly_generated_seed(vahe: Vahe) -> Result { + let seed = Vahe::Rng::generate_seed()?; + let prng = RefCell::new(Vahe::Rng::create(&seed)?); + Ok(Self { vahe, prng }) + } +} + pub struct DecryptorState { sk_share: Option, } @@ -97,10 +107,10 @@ where /// Creates a public key share to be sent to the Server, updating the /// decryptor state. fn create_public_key_share( - &mut self, + &self, decryptor_state: &mut Self::DecryptorState, ) -> Result, status::StatusError> { - let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng)?; + let (sk_share, pk_share, _) = self.vahe.key_gen(&mut self.prng.borrow_mut())?; decryptor_state.sk_share = Some(sk_share); Ok(pk_share) } @@ -108,7 +118,7 @@ where /// Handles a partial decryption request received from the Server. Returns a /// partial decryption to the Server. fn handle_partial_decryption_request( - &mut self, + &self, partial_decryption_request: PartialDecryptionRequest, decryptor_state: &Self::DecryptorState, ) -> Result, status::StatusError> { @@ -121,7 +131,7 @@ where let pd = self.vahe.partial_decrypt( &partial_decryption_request.partial_dec_ciphertext, sk_share, - &mut self.prng, + &mut self.prng.borrow_mut(), )?; Ok(PartialDecryptionResponse { partial_decryption: pd }) } @@ -134,9 +144,7 @@ mod tests { use decryptor_traits::SecureAggregationDecryptor; use googletest::{gtest, verify_true}; use parameters_shell::create_shell_ahe_config; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; - use single_thread_hkdf::SingleThreadHkdfPrng; use vahe_shell::ShellVahe; const CONTEXT_STRING: &[u8] = b"testing_context_string"; @@ -144,9 +152,7 @@ mod tests { #[gtest] fn decryptor_state_serialization_roundtrip() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(1).unwrap(), CONTEXT_STRING).unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; let mut decryptor_state = DecryptorState::default(); // Check empty state serialization. diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index a9d07b4..d2110aa 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -362,10 +362,8 @@ mod tests { use googletest::{gtest, verify_true}; use kahe_shell::ShellKahe; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; - use single_thread_hkdf::SingleThreadHkdfPrng; use std::collections::HashMap; use testing_utils::{generate_aggregation_config, generate_random_nonce}; use vahe_shell::ShellVahe; @@ -400,10 +398,8 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; // Create server. let kahe = diff --git a/willow/src/willow_v1/verifier.rs b/willow/src/willow_v1/verifier.rs index aaecf34..c887bc0 100644 --- a/willow/src/willow_v1/verifier.rs +++ b/willow/src/willow_v1/verifier.rs @@ -271,10 +271,8 @@ mod tests { use kahe_shell::ShellKahe; use kahe_traits::KaheBase; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; - use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; - use single_thread_hkdf::SingleThreadHkdfPrng; use status_matchers_rs::status_is; use std::collections::HashMap; use testing_utils::{generate_aggregation_config, generate_random_nonce}; @@ -314,10 +312,8 @@ mod tests { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed()?; - let prng = SingleThreadHkdfPrng::create(&seed)?; let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe)?; // Create server. let kahe = diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index c0e2d6d..56ff8c9 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -24,10 +24,8 @@ use messages::{ PartialDecryptionRequest, PartialDecryptionResponse, }; use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; -use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use server_traits::SecureAggregationServer; -use single_thread_hkdf::SingleThreadHkdfPrng; use status::StatusErrorCode; use status_matchers_rs::status_is; use std::collections::HashMap; @@ -64,10 +62,8 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -152,10 +148,8 @@ fn encrypt_decrypt_one_serialized() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let kahe = @@ -292,10 +286,8 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -438,10 +430,8 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let vahe = ShellVahe::new(create_shell_ahe_config(max_number_of_decryptors).unwrap(), CONTEXT_STRING) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Create server. let vahe = @@ -606,10 +596,8 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::default(); - let mut decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); // Decryptor generates public key share. let public_key_share = decryptor.create_public_key_share(&mut decryptor_state).unwrap(); @@ -732,10 +720,8 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { CONTEXT_STRING, ) .unwrap(); - let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); - let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let decryptor_state = DecryptorState::default(); - let decryptor = WillowV1Decryptor { vahe, prng }; + let decryptor = WillowV1Decryptor::new_with_randomly_generated_seed(vahe).unwrap(); decryptor_states.push(decryptor_state); decryptors.push(decryptor); }