Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import operator
import re
import warnings
import functools
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as _np
Expand Down Expand Up @@ -1660,6 +1661,22 @@ def _impl_v1(cls, bb, inputs, attr, params):
return relax.op.sqrt(inputs[0])


def compute_broadcast_shape(shape_a, shape_b):
"""Compute target shape for Multidirectional Broadcasting"""
rank = max(len(shape_a), len(shape_b))

a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
b = (1,) * (rank - len(shape_b)) + tuple(shape_b)

target = []
for ai, bi in zip(a, b):
if ai == bi or ai == 1 or bi == 1:
target.append(max(ai, bi))
else:
raise ValueError(f"Cannot broadcast {ai} and {bi}")
return tuple(target)


class MultiInputBase(OnnxOpConverter):
"""Converts an onnx MultiInputBase node into an equivalent Relax expression."""

Expand All @@ -1675,9 +1692,12 @@ def _impl_v1(cls, bb, inputs, attr, params):
output = cls.numpy_op(*np_inputs) # pylint: disable=not-callable
return relax.const(output, output.dtype)

# Expand inputs, stack them, then perform minimum over the new axis.
inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
stacked_tensor = relax.op.concat(inputs, axis=0)
input_shapes = [inp.struct_info.shape for inp in inputs]
target_shape = functools.reduce(compute_broadcast_shape, input_shapes)

# broadcast_to, stack them, then perform minimum over the new axis.
inputs = [bb.normalize(relax.op.broadcast_to(i, target_shape)) for i in inputs]
stacked_tensor = bb.normalize(relax.op.stack(inputs, axis=0))
return cls.relax_op(stacked_tensor, axis=0) # pylint: disable=not-callable


Expand Down
55 changes: 43 additions & 12 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,22 +393,53 @@ def test_mod(int_mode: bool):
verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)


@pytest.mark.parametrize("num_inputs", [1, 2, 4])
SHAPE_PARAMS = [
([[32, 32], [32, 32]], [32, 32]),
([[32, 1], [1, 2]], [32, 2]),
(
[
[
32,
],
[
1,
],
],
[
32,
],
),
([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]),
(
[
[32, 32, 1, 1],
[1, 32, 1],
[
32,
],
],
[32, 32, 32, 32],
),
]


@pytest.mark.parametrize("input_shapes, expected_output_shape", SHAPE_PARAMS)
@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
def test_multi_input(op_name: str, num_inputs: int):
input_shape = [32, 32]
input_var = ["i" + str(i) for i in range(num_inputs)]
input_values = [
helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for var in input_var
]
test_node = helper.make_node(op_name, input_var, ["c"])
def test_multi_input_broadcasting(op_name, input_shapes, expected_output_shape):
num_inputs = len(input_shapes)
input_names = [f"i{i}" for i in range(num_inputs)]

input_values_info = []
for name, shape in zip(input_names, input_shapes):
input_values_info.append(helper.make_tensor_value_info(name, TensorProto.FLOAT, shape))
test_node = helper.make_node(op_name, input_names, ["output"])
output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, expected_output_shape)
graph = helper.make_graph(
[test_node],
"multi_input_test",
inputs=input_values,
outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, input_shape)],
f"multi_input_{op_name}_test",
inputs=input_values_info,
outputs=[output_info],
)

model = helper.make_model(graph, producer_name="multi_input_test")
check_correctness(model)

Expand Down