From d9fb6c671d191e74591ea18716e2893c71840b58 Mon Sep 17 00:00:00 2001 From: samay2504 Date: Wed, 3 Dec 2025 20:08:22 +0530 Subject: [PATCH] fix: update Module.param() call in interceptors test for Flax API compatibility The test was using the old Flax API signature self.param('name', lambda_init) which has been updated to require an explicit shape parameter. Updated to use self.param('name', initializer, shape) syntax with nn.initializers.zeros. This fixes test failures with current Flax versions where Module.param() requires the shape as a positional argument. Tests affected: - test_module - test_module_non_share_scope Both tests now pass successfully with the updated API. --- gemma/peft/_interceptors_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemma/peft/_interceptors_test.py b/gemma/peft/_interceptors_test.py index bb7db5be..7dff5fd1 100644 --- a/gemma/peft/_interceptors_test.py +++ b/gemma/peft/_interceptors_test.py @@ -34,7 +34,7 @@ def __post_init__(self): @nn.compact def __call__(self, *args: Any, **kwargs: Any) -> Any: # Create an extra param. - self.param('extra_param', lambda _: jnp.zeros(())) + self.param('extra_param', nn.initializers.zeros, ()) # Wrapped the output, using the features as key to ensure we captured the # correct module.