Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions python/tvm/relax/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@ def all_class_non_max_suppression(
`num_total_detection` of shape `(1,)` representing the total number of selected
boxes. The three values in `indices` encode batch, class, and box indices.
Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
first, in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of
`batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection`
rows are valid.
first, in descending order of scores, followed by boxes from batch 0, class 1 etc.
The output uses `dynamic_strided_slice` to trim to only valid detections,
so the first tensor has shape `(num_total_detection, 3)` containing only valid rows.

TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly.
This would eliminate the need for manual trimming and improve memory efficiency.
If `output_format` is "tensorflow", the output is three tensors, the first
is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of
size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
Expand Down
114 changes: 49 additions & 65 deletions python/tvm/relax/transform/legalize_ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,64 +15,27 @@
# specific language governing permissions and limitations
# under the License.
"""Default legalization function for vision network related operators."""
from tvm import topi, te
from tvm import relax
from tvm import relax, te, tir, topi

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from ...expr import Call, Expr, TupleGetItem
from .common import register_legalize


def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
    """Build TE tensors for all-class NMS with a fixed-size, ONNX-shaped output.

    The heavy lifting is delegated to ``topi.vision.all_class_non_max_suppression``
    (called with the ``"onnx"`` output format); this helper only reshapes its
    first output to ``(batch * num_classes * max_boxes, 3)`` rows.

    Parameters
    ----------
    boxes, scores, iou_threshold, score_threshold
        Forwarded unchanged to the TOPI NMS implementation. ``scores`` must be
        rank 3 ``(batch, num_classes, _)`` or rank 2 ``(num_classes, _)``; in the
        rank-2 case the batch size is taken to be 1.
    max_output_boxes_per_class
        Forwarded to TOPI. If it exposes a ``.data`` attribute, its value is
        read to size the output; otherwise a per-class limit of 3 is assumed.

    Returns
    -------
    list
        ``[sliced_indices, actual_detections]`` where ``sliced_indices`` has shape
        ``(batch * num_classes * max_boxes, 3)`` and ``actual_detections`` has
        shape ``(1,)`` and holds that same row count as int64.

    Raises
    ------
    ValueError
        If ``scores`` is neither rank 2 nor rank 3.
    """
    scores_shape = list(scores.shape)
    if len(scores_shape) == 3:
        batch, num_classes, _ = scores_shape
    elif len(scores_shape) == 2:
        num_classes, _ = scores_shape
        batch = 1
    else:
        raise ValueError(f"Unexpected scores shape: {scores_shape}")

    if hasattr(max_output_boxes_per_class, "data"):
        max_boxes = int(max_output_boxes_per_class.data.numpy())
    else:
        max_boxes = 3  # Default value

    expected_detections = batch * num_classes * max_boxes

    selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
    )

    def slice_to_onnx_shape(data, expected_size):
        # Guard against the real row count of `data`: `i < expected_size` is always
        # true inside a compute of extent `expected_size`, so it would not prevent
        # reading past the end of `data` when expected_size exceeds data.shape[0].
        available_rows = data.shape[0]

        def compute_element(i, j):
            return tir.if_then_else(i < available_rows, data[i, j], tir.const(0, "int64"))

        return te.compute((expected_size, 3), compute_element, name="sliced_indices")

    sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections)

    # NOTE(review): this reports the fixed output row count, not the number of
    # detections computed by TOPI (its second output is discarded above) -- confirm
    # callers do not rely on it as the true detection count.
    actual_detections = te.compute(
        (1,), lambda i: tir.Cast("int64", expected_detections), name="actual_detections"
    )

    return [sliced_indices, actual_detections]


@register_legalize("relax.vision.all_class_non_max_suppression")
def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
"""Legalize all_class_non_max_suppression with fixed shape output.

Note: This implementation outputs fixed-size tensors with trailing garbage data.
Only the first `num_total_detection` rows contain valid data. Users should use
the `valid_count` tensor to determine how many rows are actually valid.

For complete ONNX compatibility, users can post-process the output:
```python
selected_indices, valid_count = nms_output
actual_count = int(valid_count.numpy()[0])
valid_indices = selected_indices.numpy()[:actual_count, :]
```
"""Legalize all_class_non_max_suppression with dynamic output trimming.

This implementation uses dynamic_strided_slice to trim the NMS output to only
contain valid detections, improving memory efficiency and ONNX compatibility.

Returns
-------
result : Tuple[Tensor, Tensor]
A tuple of (trimmed_indices, num_total_detections) where:
- trimmed_indices: Tensor of shape (num_total_detections, 3) containing only
valid detection indices (batch_id, class_id, box_id)
- num_total_detections: Tensor of shape (1,) with the count of valid detections
"""
boxes = call.args[0]
scores = call.args[1]
Expand Down Expand Up @@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E
output_format,
)

# TODO: Implement dynamic output trimming for better memory efficiency
# Current approach returns fixed-size output with trailing garbage data
# Future improvements could include:
# 1. Dynamic strided_slice based on num_total_detections
# 2. Custom Relax operator with true dynamic shapes
# 3. VM builtin functions for runtime shape adjustment
# 4. Symbolic shape inference in Relax IR
#
# For now, users should trim manually:
# actual_count = int(num_total_detections.numpy()[0])
# valid_indices = selected_indices.numpy()[:actual_count, :]

return nms_result
# Dynamic output trimming using dynamic_strided_slice
# Extract selected_indices and num_total_detections from the NMS result
selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))

# Build slicing parameters using TE to avoid high-level Relax ops during legalization
def build_begin():
return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin")

def build_strides():
return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides")

def build_end(count_tensor):
# end = [count_tensor[0], 3]
def compute_end(i):
return tir.if_then_else(
i == 0,
tir.Cast("int64", count_tensor[0]),
tir.const(3, "int64"),
)

return te.compute((2,), compute_end, name="end")

begin = block_builder.call_te(build_begin)
strides = block_builder.call_te(build_strides)
end = block_builder.call_te(build_end, num_total_detections)

# Apply dynamic strided slice to trim to valid detections only
trimmed_indices = block_builder.emit(
relax.op.dynamic_strided_slice(selected_indices, begin, end, strides)
)

# Return trimmed indices along with num_total_detections for compatibility
return relax.Tuple([trimmed_indices, num_total_detections])
98 changes: 94 additions & 4 deletions tests/python/relax/test_op_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op, VDevice
from tvm import TVMError, relax, tir
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R


Expand Down Expand Up @@ -53,7 +54,6 @@ def test_all_class_non_max_suppression_infer_struct_info():


def test_all_class_non_max_suppression_wrong_input_number():
bb = relax.BlockBuilder()
boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))

Expand Down Expand Up @@ -86,5 +86,95 @@ def test_all_class_non_max_suppression_infer_struct_info_shape_var():
)


def test_all_class_non_max_suppression_legalize_dynamic_trim():
    """Check that legalizing ONNX-format NMS yields a dynamically trimmed output.

    After ``LegalizeOps`` the printed module must mention ``dynamic_strided_slice``
    and the function's return struct info must be a tuple of a rank-2 int64 tensor
    (no static shape) and a ``(1,)`` int64 tensor.
    """

    @tvm.script.ir_module
    class NMSModule:
        @R.function
        def main(
            boxes: R.Tensor((1, 5, 4), "float32"),
            scores: R.Tensor((1, 2, 5), "float32"),
        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
            max_output_boxes_per_class = R.const(3, "int64")
            iou_threshold = R.const(0.5, "float32")
            score_threshold = R.const(0.1, "float32")
            return R.vision.all_class_non_max_suppression(
                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
            )

    legalized = LegalizeOps()(NMSModule)

    # The trimming step is expected to show up by name in the printed module.
    printed = str(legalized)
    assert "dynamic_strided_slice" in printed

    # First element: rank-2 int64 with unknown extents; second: fixed (1,) int64.
    indices_sinfo = relax.TensorStructInfo(ndim=2, dtype="int64")
    count_sinfo = relax.TensorStructInfo((1,), "int64")
    expected_ret_sinfo = relax.TupleStructInfo([indices_sinfo, count_sinfo])

    tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, expected_ret_sinfo)


def test_all_class_non_max_suppression_legalize_e2e():
    """Legalize, compile and run ONNX-format NMS, then validate the trimmed output.

    The module is built for the ``llvm`` target and executed on CPU with one batch
    of five boxes and two classes. The test checks the legalized return struct
    info, that the first output has exactly ``num_total_detections`` rows of three
    int64 values, and that every returned (batch, class, box) triple lies inside
    the input's index ranges.
    """

    @tvm.script.ir_module
    class NMSModule:
        @R.function
        def main(
            boxes: R.Tensor((1, 5, 4), "float32"),
            scores: R.Tensor((1, 2, 5), "float32"),
        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
            max_output_boxes_per_class = R.const(3, "int64")
            iou_threshold = R.const(0.5, "float32")
            score_threshold = R.const(0.1, "float32")
            return R.vision.all_class_non_max_suppression(
                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
            )

    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],
                [0.1, 0.1, 1.1, 1.1],
                [2.0, 2.0, 3.0, 3.0],
                [4.0, 4.0, 5.0, 5.0],
                [6.0, 6.0, 7.0, 7.0],
            ]
        ],
        dtype=np.float32,
    )
    scores_data = np.array(
        [[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]],
        dtype=np.float32,
    )
    batch_size, num_classes, num_boxes = scores_data.shape
    max_boxes_per_class = 3  # mirrors R.const(3, "int64") in the module above

    mod = LegalizeOps()(NMSModule)

    # Check struct info
    tvm.ir.assert_structural_equal(
        mod["main"].ret_struct_info,
        relax.TupleStructInfo(
            [
                relax.TensorStructInfo(ndim=2, dtype="int64"),
                relax.TensorStructInfo((1,), "int64"),
            ]
        ),
    )

    # Check runtime execution
    exe = tvm.compile(mod, target="llvm")
    vm = relax.VirtualMachine(exe, tvm.cpu())
    result = vm["main"](
        tvm.runtime.tensor(boxes_data, tvm.cpu()),
        tvm.runtime.tensor(scores_data, tvm.cpu()),
    )

    selected_indices = result[0].numpy()
    num_total_detections = int(result[1].numpy()[0])

    # Shapes are exact integers, so compare them directly instead of going
    # through a floating-point tolerance helper.
    assert selected_indices.shape == (num_total_detections, 3)
    assert selected_indices.dtype == np.int64

    # Every score is above the threshold, so each class keeps at least its top box,
    # and no class may keep more than max_boxes_per_class boxes.
    assert 0 < num_total_detections <= batch_size * num_classes * max_boxes_per_class

    # After trimming, every row must be a valid (batch, class, box) triple.
    batch_ids, class_ids, box_ids = selected_indices.T
    assert np.all((batch_ids >= 0) & (batch_ids < batch_size))
    assert np.all((class_ids >= 0) & (class_ids < num_classes))
    assert np.all((box_ids >= 0) & (box_ids < num_boxes))


# When this file is executed as a script, run its tests via tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
Loading