From 36705b47c9bdc80d4b04812d6e1cbfee41f995f4 Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Mon, 15 Dec 2025 13:05:49 +0530
Subject: [PATCH 1/2] add logprob support for leaky ReLU switch transforms

---
 pymc/logprob/mixture.py          | 152 +++++++++++++++++++++++++++++++
 tests/logprob/test_transforms.py |  54 +++++++++++
 2 files changed, 206 insertions(+)

diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index f45f0ccb00..3c47a3cb10 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -47,7 +47,10 @@
 from pytensor.ifelse import IfElse, ifelse
 from pytensor.scalar import Switch
 from pytensor.scalar import switch as scalar_switch
+from pytensor.scalar.basic import GE, GT, LE, LT, Mul
 from pytensor.tensor.basic import Join, MakeVector, switch
+from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.random.rewriting import (
     local_dimshuffle_rv_lift,
     local_rv_size_lift,
@@ -80,7 +83,9 @@
     measurable_ir_rewrites_db,
     subtensor_ops,
 )
+from pymc.logprob.transforms import MeasurableTransform
 from pymc.logprob.utils import (
+    CheckParameterValue,
     check_potential_measurability,
     filter_measurable_variables,
     get_related_valued_nodes,
@@ -407,6 +412,80 @@ class MeasurableSwitchMixture(MeasurableElemwise):
 measurable_switch_mixture = MeasurableSwitchMixture(scalar_switch)
 
 
+class MeasurableLeakyReLUSwitch(MeasurableElemwise):
+    """A placeholder for leaky-ReLU graphs built via `switch(x > 0, x, a * x)`.
+
+    This is an invertible, piecewise-linear transform of a single continuous measurable variable.
+    """
+
+    valid_scalar_types = (Switch,)
+
+
+measurable_leaky_relu_switch = MeasurableLeakyReLUSwitch(scalar_switch)
+
+
+def _is_x_positive_condition(cond: TensorVariable, x: TensorVariable) -> bool:
+    """Return True if `cond` encodes `x > 0` (or an equivalent comparison against zero)."""
+    if cond.owner is None:
+        return False
+    if not isinstance(cond.owner.op, Elemwise):
+        return False
+    scalar_op = cond.owner.op.scalar_op
+    if not isinstance(scalar_op, GT | GE | LT | LE):
+        return False
+
+    left, right = cond.owner.inputs
+
+    def _is_zero(v: TensorVariable) -> bool:
+        try:
+            return pt.get_underlying_scalar_constant_value(v) == 0
+        except NotScalarConstantError:
+            return False
+
+    # x > 0 or x >= 0
+    if left is x and _is_zero(right) and isinstance(scalar_op, GT | GE):
+        return True
+    # 0 < x or 0 <= x
+    if right is x and _is_zero(left) and isinstance(scalar_op, LT | LE):
+        return True
+    return False
+
+
+def _extract_leaky_relu_slope(
+    neg_branch: TensorVariable, x: TensorVariable
+) -> TensorVariable | None:
+    """Extract the slope `a` from `neg_branch`, assuming it represents `a * x`.
+
+    Supports both plain `Elemwise(Mul)` graphs and already-rewritten `MeasurableTransform` scale transforms.
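+
+    For example (illustrative): in `switch(x > 0, x, 0.5 * x)` the negative
+    branch is a plain `Elemwise(Mul)`, whereas after the scale rewrite fires it
+    shows up as a `MeasurableTransform`; in both cases the returned slope is `0.5`.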
+ """ + if neg_branch is x: + return pt.constant(1.0) + + if neg_branch.owner is None: + return None + + # handle case where `a * x` was already rewritten into a measurable scale transform + if isinstance(neg_branch.owner.op, MeasurableTransform): + op = neg_branch.owner.op + if not isinstance(op.scalar_op, Mul): + return None + # MeasurableTransform takes (measurable_input, scale) + if len(neg_branch.owner.inputs) != 2: + return None + if neg_branch.owner.inputs[op.measurable_input_idx] is not x: + return None + scale = neg_branch.owner.inputs[1 - op.measurable_input_idx] + return cast(TensorVariable, scale) + + # plain multiplication + if isinstance(neg_branch.owner.op, Elemwise) and isinstance(neg_branch.owner.op.scalar_op, Mul): + left, right = neg_branch.owner.inputs + if left is x: + return cast(TensorVariable, right) + if right is x: + return cast(TensorVariable, left) + return None + + @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): if isinstance(node.op, MeasurableOp): @@ -431,6 +510,46 @@ def find_measurable_switch_mixture(fgraph, node): return [measurable_switch_mixture(switch_cond, *components)] +@node_rewriter([switch]) +def find_measurable_leaky_relu_switch(fgraph, node): + """Detect `switch(x > 0, x, a * x)` and replace it by a measurable op. + + This enables a change-of-variables logprob derivation instead of treating it as a mixture. + """ + if isinstance(node.op, MeasurableOp): + return None + + cond, pos_branch, neg_branch = node.inputs + + if not filter_measurable_variables([pos_branch]): + return None + x = cast(TensorVariable, pos_branch) + + if x.type.dtype.startswith("int"): + return None + + if x.type.broadcastable != node.outputs[0].type.broadcastable: + return None + + if not _is_x_positive_condition(cast(TensorVariable, cond), x): + return None + + a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), x) + if a is None: + return None + + if check_potential_measurability([a]): + return None + + return [ + measurable_leaky_relu_switch( + cast(TensorVariable, cond), + x, + cast(TensorVariable, neg_branch), + ) + ] + + @_logprob.register(MeasurableSwitchMixture) def logprob_switch_mixture(op, values, switch_cond, component_true, component_false, **kwargs): [value] = values @@ -442,6 +561,32 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa ) +@_logprob.register(MeasurableLeakyReLUSwitch) +def logprob_leaky_relu_switch(op, values, cond, x, neg_branch, **kwargs): + (value,) = values + + a = _extract_leaky_relu_slope(cast(TensorVariable, neg_branch), cast(TensorVariable, x)) + if a is None: + raise NotImplementedError("Could not extract leaky-ReLU slope") + + # enforce a > 0 at runtime to ensure invertibility and valid Jacobian + a = CheckParameterValue("leaky_relu slope > 0")(a, pt.all(pt.gt(a, 0))) + + # inverse mapping x(y) = y if y>0 else y/a + x_inv = pt.switch(pt.gt(value, 0), value, value / a) + + base_logp = _logprob_helper(x, x_inv, **kwargs) + + # jacobian term: 0 for y>0, -log(a) for y<=0 + jacobian = pt.switch(pt.gt(value, 0), pt.zeros_like(value), -pt.log(a)) + + if base_logp.ndim < value.ndim: + ndim_supp = value.ndim - base_logp.ndim + jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0))) + + return base_logp + jacobian + + measurable_ir_rewrites_db.register( "find_measurable_index_mixture", find_measurable_index_mixture, @@ -456,6 +601,13 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa "mixture", ) +measurable_ir_rewrites_db.register( + 
"find_measurable_leaky_relu_switch", + find_measurable_leaky_relu_switch, + "basic", + "transform", +) + class MeasurableIfElse(MeasurableOp, IfElse): """Measurable subclass of IfElse operator.""" diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index c9aeaa8abf..282b66858d 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -43,6 +43,8 @@ from pytensor.graph.basic import equal_computations +import pymc as pm + from pymc.distributions.continuous import Cauchy, ChiSquared from pymc.distributions.discrete import Bernoulli from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp @@ -219,6 +221,7 @@ def test_exp_transform_rv(): logp_fn(y_val), sp.stats.lognorm(s=1).logpdf(y_val), ) + np.testing.assert_almost_equal( logcdf_fn(y_val), sp.stats.lognorm(s=1).logcdf(y_val), @@ -229,6 +232,57 @@ def test_exp_transform_rv(): ) +def test_leaky_relu_switch_logp_scalar(): + a = 0.5 + x = pm.Normal.dist(mu=0, sigma=1) + y = pm.math.switch(x > 0, x, a * x) + + v_pos = 1.2 + np.testing.assert_allclose( + pm.logp(y, v_pos, warn_rvs=False).eval(), + pm.logp(x, v_pos, warn_rvs=False).eval(), + ) + + v_neg = -2.0 + np.testing.assert_allclose( + pm.logp(y, v_neg, warn_rvs=False).eval(), + pm.logp(x, v_neg / a, warn_rvs=False).eval() - np.log(a), + ) + + # boundary point (measure-zero for continuous RVs): should still produce a finite logp + assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval()) + + +def test_leaky_relu_switch_logp_vectorized(): + a = 0.5 + x = pm.Normal.dist(mu=0, sigma=1, size=(3,)) + y = pm.math.switch(x > 0, x, a * x) + + v = np.array([-2.0, 0.0, 1.5]) + expected = pm.logp(x, np.where(v > 0, v, v / a), warn_rvs=False).eval() + np.where( + v > 0, 0.0, -np.log(a) + ) + np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected) + + +def test_leaky_relu_switch_logp_symbolic_slope_checks_positive(): + a = pt.scalar("a") + x = pm.Normal.dist(mu=0, sigma=1) + y = pm.math.switch(x > 0, x, a * x) + + # positive slope passes + res = pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.5}) + expected = pm.logp(x, -1.0 / 0.5, warn_rvs=False).eval() - np.log(0.5) + np.testing.assert_allclose(res, expected) + + # non pos slope raises + with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"): + pm.logp(y, -1.0, warn_rvs=False).eval({a: -0.5}) + + with pytest.raises(ParameterValueError, match="leaky_relu slope > 0"): + pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.0}) + + def test_log_transform_rv(): base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv") y_rv = pt.log(base_rv) From 4d620a3e888f533d0fad4009a54f23920809f751 Mon Sep 17 00:00:00 2001 From: eclipse1605 Date: Tue, 16 Dec 2025 13:42:00 +0530 Subject: [PATCH 2/2] implemented gated branch logps for leaky relu switch --- pymc/logprob/mixture.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 3c47a3cb10..4c16ac24d7 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -521,6 +521,11 @@ def find_measurable_leaky_relu_switch(fgraph, node): cond, pos_branch, neg_branch = node.inputs + # we only mark the switch measurable once both branches are already measurable. + # so, the switch logprob can simply gate between branch logps (delegating inversion/Jacobian details to each branch). 
+    if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
+        return None
+
     if not filter_measurable_variables([pos_branch]):
         return None
     x = cast(TensorVariable, pos_branch)
@@ -569,22 +574,20 @@ def logprob_leaky_relu_switch(op, values, cond, x, neg_branch, **kwargs):
     if a is None:
         raise NotImplementedError("Could not extract leaky-ReLU slope")
 
-    # Enforce a > 0 at runtime to ensure invertibility and a valid Jacobian.
-    a = CheckParameterValue("leaky_relu slope > 0")(a, pt.all(pt.gt(a, 0)))
-
-    # Inverse mapping: x(y) = y if y > 0 else y / a.
-    x_inv = pt.switch(pt.gt(value, 0), value, value / a)
+    # Enforce `a > 0` at runtime; this guarantees invertibility and makes the branch selection predicate depend only on the observed value.
+    a_is_positive = pt.all(pt.gt(a, 0))
 
-    base_logp = _logprob_helper(x, x_inv, **kwargs)
+    # For `a > 0`, `switch(x > 0, x, a * x)` maps the branches to disjoint regions of `value`: the true branch to `value > 0`, the false branch to `value <= 0`.
+    value_implies_true_branch = pt.gt(value, 0)
 
-    # Jacobian term: 0 for y > 0, -log(a) for y <= 0.
-    jacobian = pt.switch(pt.gt(value, 0), pt.zeros_like(value), -pt.log(a))
-
-    if base_logp.ndim < value.ndim:
-        ndim_supp = value.ndim - base_logp.ndim
-        jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
+    logp_expr = pt.switch(
+        value_implies_true_branch,
+        _logprob_helper(x, value, **kwargs),
+        _logprob_helper(neg_branch, value, **kwargs),
+    )
 
-    return base_logp + jacobian
+    # Attach the parameter check to the returned expression so it cannot be rewritten away.
+    return CheckParameterValue("leaky_relu slope > 0")(logp_expr, a_is_positive)
 
 
 measurable_ir_rewrites_db.register(