def compute_broadcast_shape(shape_a, shape_b):
    """Compute the target shape for multidirectional (NumPy-style) broadcasting.

    The shorter shape is left-padded with 1s up to the rank of the longer one,
    then the two shapes are merged dimension by dimension: equal dimensions are
    kept as-is and a dimension of 1 is stretched to match the other one.

    Parameters
    ----------
    shape_a : sequence
        Dimensions of the first operand.

    shape_b : sequence
        Dimensions of the second operand.

    Returns
    -------
    target : tuple
        The broadcast shape, with rank ``max(len(shape_a), len(shape_b))``.

    Raises
    ------
    ValueError
        If a pair of aligned dimensions differs and neither of them is 1.
    """
    rank = max(len(shape_a), len(shape_b))

    # Align both shapes on their trailing dimensions by padding with leading 1s.
    padded_a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (rank - len(shape_b)) + tuple(shape_b)

    target = []
    for dim_a, dim_b in zip(padded_a, padded_b):
        # Pick the dimension that is not being stretched instead of taking
        # max(): a zero-sized dimension paired with 1 must stay 0, not become 1.
        # NOTE(review): dimensions are compared with ``==`` against each other
        # and against 1; behavior for symbolic dimensions is not established
        # here - confirm before relying on it for dynamic shapes.
        if dim_a == 1:
            target.append(dim_b)
        elif dim_b == 1 or dim_a == dim_b:
            target.append(dim_a)
        else:
            raise ValueError(f"Cannot broadcast {dim_a} and {dim_b}")
    return tuple(target)
# Each entry is (shapes of the operator inputs, expected broadcast output shape).
SHAPE_PARAMS = [
    # A single input: the variadic operators must accept one operand.
    ([[32, 32]], [32, 32]),
    # Identical shapes: no broadcasting required.
    ([[32, 32], [32, 32]], [32, 32]),
    # Both operands are stretched along a different axis.
    ([[32, 1], [1, 2]], [32, 2]),
    # Rank-1 operands.
    ([[32], [1]], [32]),
    # Operands of different rank.
    ([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]),
    # Three operands of three different ranks.
    ([[32, 32, 1, 1], [1, 32, 1], [32]], [32, 32, 32, 32]),
    # Four identically shaped operands.
    ([[32, 32], [32, 32], [32, 32], [32, 32]], [32, 32]),
]


@pytest.mark.parametrize("input_shapes, expected_output_shape", SHAPE_PARAMS)
@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
def test_multi_input_broadcasting(op_name, input_shapes, expected_output_shape):
    """Check a variadic ONNX op on inputs whose shapes need broadcasting.

    Builds a one-node ONNX model with one float input per entry of
    ``input_shapes`` and a float output declared with
    ``expected_output_shape``, then hands the model to ``check_correctness``.
    """
    input_names = [f"i{i}" for i in range(len(input_shapes))]
    input_values_info = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        for name, shape in zip(input_names, input_shapes)
    ]
    output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, expected_output_shape)

    test_node = helper.make_node(op_name, input_names, ["output"])
    graph = helper.make_graph(
        [test_node],
        f"multi_input_{op_name}_test",
        inputs=input_values_info,
        outputs=[output_info],
    )
    model = helper.make_model(graph, producer_name="multi_input_test")
    check_correctness(model)