diff --git a/gemma/peft/_interceptors_test.py b/gemma/peft/_interceptors_test.py index bb7db5be..7dff5fd1 100644 --- a/gemma/peft/_interceptors_test.py +++ b/gemma/peft/_interceptors_test.py @@ -34,7 +34,7 @@ def __post_init__(self): @nn.compact def __call__(self, *args: Any, **kwargs: Any) -> Any: # Create an extra param. - self.param('extra_param', lambda _: jnp.zeros(())) + self.param('extra_param', nn.initializers.zeros, ()) # Wrapped the output, using the features as key to ensure we captured the # correct module.