diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9968eb5ed8f8..cc93d812c588 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3347,10 +3347,20 @@ def _impl_v1(cls, bb, inputs, attr, params):
         x = inputs[0]
         dtype = x.struct_info.dtype
         alpha = float(attr.get("alpha", 0.2))
-        alpha = relax.const(alpha, dtype=dtype)
+        alpha_const = relax.const(alpha, dtype=dtype)
         beta = float(attr.get("beta", 0.5))
-        beta = relax.const(beta, dtype=dtype)
-        return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
+        beta_const = relax.const(beta, dtype=dtype)
+
+        is_nan = relax.op.not_equal(x, x)
+
+        zero_const = relax.const(0.0, dtype=dtype)
+        x_safe = relax.op.where(is_nan, zero_const, x)
+
+        transformed = relax.op.add(relax.op.multiply(alpha_const, x_safe), beta_const)
+        clipped = relax.op.clip(transformed, 0, 1)
+
+        nan_const = relax.const(float("nan"), dtype=dtype)
+        return relax.op.where(is_nan, nan_const, clipped)
 
 
 class HardSwish(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index df94c13478cb..246e14d60904 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1089,6 +1089,31 @@ def test_hardsigmoid():
     verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})
 
 
+def test_hardsigmoid_nan():
+    """Test that HardSigmoid preserves NaN values in output."""
+    test_node = helper.make_node("HardSigmoid", ["x"], ["y"])
+    graph = helper.make_graph(
+        [test_node],
+        "hardsigmoid_nan_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
+    )
+
+    model = helper.make_model(graph, producer_name="hardsigmoid_nan_test")
+
+    # Create input with NaN values
+    input_data = np.array(
+        [
+            [np.nan, 0.5, -0.5, 1.0],
+            [0.0, np.nan, 2.0, -2.0],
+            [0.3, 0.7, np.nan, np.nan],
+        ],
+        dtype=np.float32,
+    )
+
+    check_correctness(model, inputs={"x": input_data})
+
+
 def test_shrink():
     verify_unary("Shrink", [32, 32])
     verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})
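
Note on the semantics the patch implements: the rewritten HardSigmoid lowering masks NaN elements out before the clip (so the clamp cannot map them to a finite bound) and writes them back into the result afterwards. Below is a minimal NumPy sketch of that behaviour, purely for illustration; the helper name `hardsigmoid_nan_preserving` and the use of NumPy are assumptions for this sketch, not TVM or ONNX API.

```python
import numpy as np

def hardsigmoid_nan_preserving(x, alpha=0.2, beta=0.5):
    """Illustrative reference for the NaN-preserving HardSigmoid built above."""
    is_nan = x != x                    # NaN is the only value that compares unequal to itself
    x_safe = np.where(is_nan, 0.0, x)  # placeholder value; overwritten again below
    clipped = np.clip(alpha * x_safe + beta, 0.0, 1.0)
    return np.where(is_nan, np.nan, clipped)

# NaN propagates; finite values follow max(0, min(1, alpha * x + beta)).
print(hardsigmoid_nan_preserving(np.array([np.nan, -5.0, 0.0, 5.0], dtype=np.float32)))
# -> [nan 0.  0.5 1. ]
```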