From 9c7cff31e95dd385ce0943716d785f07bad853c5 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 19 Jan 2026 16:32:33 +0700 Subject: [PATCH 1/4] [Relax][Onnx] Support Multi Input Ops with Multidirectional Broadcasting - Compute target shape for Multidirectional Broadcasting - Workflow: broadcast_to -> stack -> reduce ops --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 28 +++++++++++++-- tests/python/relax/test_frontend_onnx.py | 34 ++++++++++++------- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4dbb0ca36fa9..90ee6ac038a2 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1660,6 +1660,22 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.op.sqrt(inputs[0]) +def compute_broadcast_shape(shape_a, shape_b): + """Compute target shape for Multidirectional Broadcasting""" + rank = max(len(shape_a), len(shape_b)) + + a = (1,) * (rank - len(shape_a)) + tuple(shape_a) + b = (1,) * (rank - len(shape_b)) + tuple(shape_b) + + target = [] + for ai, bi in zip(a, b): + if ai == bi or ai == 1 or bi == 1: + target.append(max(ai, bi)) + else: + raise ValueError(f"Cannot broadcast {ai} and {bi}") + return tuple(target) + + class MultiInputBase(OnnxOpConverter): """Converts an onnx MultiInputBase node into an equivalent Relax expression.""" @@ -1675,9 +1691,15 @@ def _impl_v1(cls, bb, inputs, attr, params): output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) - # Expand inputs, stack them, then perform minimum over the new axis. - inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] - stacked_tensor = relax.op.concat(inputs, axis=0) + input_shapes = [inp.struct_info.shape for inp in inputs] + current_target_shape = input_shapes[0] + for next_shape in input_shapes[1:]: + current_target_shape = compute_broadcast_shape(current_target_shape, next_shape) + print("target shape", current_target_shape) + + # broadcast_to, stack them, then perform minimum over the new axis. + inputs = [bb.normalize(relax.op.broadcast_to(i, current_target_shape)) for i in inputs] + stacked_tensor = bb.normalize(relax.op.stack(inputs, axis=0)) return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 6f5c7da5ef7e..708fcd380442 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -393,22 +393,32 @@ def test_mod(int_mode: bool): verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype) -@pytest.mark.parametrize("num_inputs", [1, 2, 4]) +SHAPE_PARAMS = [ + ([[32, 32], [32, 32]], [32, 32]), + ([[32, 1], [1, 2]], [32, 2]), + ([[32,], [1,]], [32,]), + ([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]), + ([[32, 32, 1, 1], [1, 32, 1], [32,]], [32, 32, 32, 32]) +] + + +@pytest.mark.parametrize("input_shapes, expected_output_shape", SHAPE_PARAMS) @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"]) -def test_multi_input(op_name: str, num_inputs: int): - input_shape = [32, 32] - input_var = ["i" + str(i) for i in range(num_inputs)] - input_values = [ - helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for var in input_var - ] - test_node = helper.make_node(op_name, input_var, ["c"]) +def test_multi_input_broadcasting(op_name, input_shapes, expected_output_shape): + num_inputs = len(input_shapes) + input_names = [f"i{i}" for i in range(num_inputs)] + + input_values_info = [] + for name, shape in zip(input_names, input_shapes): + input_values_info.append(helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)) + test_node = helper.make_node(op_name, input_names, ["output"]) + output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, expected_output_shape) graph = helper.make_graph( [test_node], - "multi_input_test", - inputs=input_values, - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, input_shape)], + f"multi_input_{op_name}_test", + inputs=input_values_info, + outputs=[output_info], ) - model = helper.make_model(graph, producer_name="multi_input_test") check_correctness(model) From 1a4f984e38187234221aa7d4d11f37469d62c26e Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 19 Jan 2026 16:48:32 +0700 Subject: [PATCH 2/4] remove debug --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 90ee6ac038a2..9283b3a6366f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1695,7 +1695,6 @@ def _impl_v1(cls, bb, inputs, attr, params): current_target_shape = input_shapes[0] for next_shape in input_shapes[1:]: current_target_shape = compute_broadcast_shape(current_target_shape, next_shape) - print("target shape", current_target_shape) # broadcast_to, stack them, then perform minimum over the new axis. inputs = [bb.normalize(relax.op.broadcast_to(i, current_target_shape)) for i in inputs] From 216036106e4777ccb80183e73bbdf93e2f898b9d Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 19 Jan 2026 17:46:34 +0700 Subject: [PATCH 3/4] update based on review --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 8 +++--- tests/python/relax/test_frontend_onnx.py | 25 +++++++++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 9283b3a6366f..4b5d9d519f4a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1691,13 +1691,13 @@ def _impl_v1(cls, bb, inputs, attr, params): output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) + import functools + input_shapes = [inp.struct_info.shape for inp in inputs] - current_target_shape = input_shapes[0] - for next_shape in input_shapes[1:]: - current_target_shape = compute_broadcast_shape(current_target_shape, next_shape) + target_shape = functools.reduce(compute_broadcast_shape, input_shapes) # broadcast_to, stack them, then perform minimum over the new axis. - inputs = [bb.normalize(relax.op.broadcast_to(i, current_target_shape)) for i in inputs] + inputs = [bb.normalize(relax.op.broadcast_to(i, target_shape)) for i in inputs] stacked_tensor = bb.normalize(relax.op.stack(inputs, axis=0)) return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 708fcd380442..037bb53eb39c 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -396,9 +396,30 @@ def test_mod(int_mode: bool): SHAPE_PARAMS = [ ([[32, 32], [32, 32]], [32, 32]), ([[32, 1], [1, 2]], [32, 2]), - ([[32,], [1,]], [32,]), + ( + [ + [ + 32, + ], + [ + 1, + ], + ], + [ + 32, + ], + ), ([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]), - ([[32, 32, 1, 1], [1, 32, 1], [32,]], [32, 32, 32, 32]) + ( + [ + [32, 32, 1, 1], + [1, 32, 1], + [ + 32, + ], + ], + [32, 32, 32, 32], + ), ] From 96735b20e407bd9aa51b0097c5e88a93b2508ca3 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 19 Jan 2026 18:11:22 +0700 Subject: [PATCH 4/4] Import outside toplevel (functools) --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4b5d9d519f4a..acc4a89824ff 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -38,6 +38,7 @@ import operator import re import warnings +import functools from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as _np @@ -1691,8 +1692,6 @@ def _impl_v1(cls, bb, inputs, attr, params): output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable return relax.const(output, output.dtype) - import functools - input_shapes = [inp.struct_info.shape for inp in inputs] target_shape = functools.reduce(compute_broadcast_shape, input_shapes)