diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py index 3714b00b01e2..4c50748bdbf7 100644 --- a/python/tvm/relax/op/vision/nms.py +++ b/python/tvm/relax/op/vision/nms.py @@ -54,12 +54,10 @@ def all_class_non_max_suppression( `num_total_detection` of shape `(1,)` representing the total number of selected boxes. The three values in `indices` encode batch, class, and box indices. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come - first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of - `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` - rows are valid. + first, in descending of scores, followed by boxes from batch 0, class 1 etc. + The output uses dynamic_strided_slice to trim to only valid detections, + so the first tensor has shape (num_total_detection, 3) containing only valid rows. - TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly. - This would eliminate the need for manual trimming and improve memory efficiency. If `output_format` is "tensorflow", the output is three tensors, the first is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index f910f62cec64..9511c130183a 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -15,64 +15,27 @@ # specific language governing permissions and limitations # under the License. """Default legalization function for vision network related operators.""" -from tvm import topi, te -from tvm import relax +from tvm import relax, te, tir, topi + from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Call, Expr, TupleGetItem from .common import register_legalize -def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): - """Create a proper NMS implementation that follows the correct algorithm""" - scores_shape = list(scores.shape) - if len(scores_shape) == 3: - batch, num_classes, _ = scores_shape - elif len(scores_shape) == 2: - num_classes, _ = scores_shape - batch = 1 - else: - raise ValueError(f"Unexpected scores shape: {scores_shape}") - - if hasattr(max_output_boxes_per_class, "data"): - max_boxes = int(max_output_boxes_per_class.data.numpy()) - else: - max_boxes = 3 # Default value - - expected_detections = batch * num_classes * max_boxes - - selected_indices_full, _ = topi.vision.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" - ) - - def slice_to_onnx_shape(data, expected_size): - def compute_element(i, j): - return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0)) - - return te.compute((expected_size, 3), compute_element, name="sliced_indices") - - sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections) - - actual_detections = te.compute( - (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections" - ) - - return [sliced_indices, actual_detections] - - @register_legalize("relax.vision.all_class_non_max_suppression") def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: - """Legalize all_class_non_max_suppression with fixed shape output. - - Note: This implementation outputs fixed-size tensors with trailing garbage data. - Only the first `num_total_detection` rows contain valid data. Users should use - the `valid_count` tensor to determine how many rows are actually valid. - - For complete ONNX compatibility, users can post-process the output: - ```python - selected_indices, valid_count = nms_output - actual_count = int(valid_count.numpy()[0]) - valid_indices = selected_indices.numpy()[:actual_count, :] - ``` + """Legalize all_class_non_max_suppression with dynamic output trimming. + + This implementation uses dynamic_strided_slice to trim the NMS output to only + contain valid detections, improving memory efficiency and ONNX compatibility. + + Returns + ------- + result : Tuple[Tensor, Tensor] + A tuple of (trimmed_indices, num_total_detections) where: + - trimmed_indices: Tensor of shape (num_total_detections, 3) containing only + valid detection indices (batch_id, class_id, box_id) + - num_total_detections: Tensor of shape (1,) with the count of valid detections """ boxes = call.args[0] scores = call.args[1] @@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E output_format, ) - # TODO: Implement dynamic output trimming for better memory efficiency - # Current approach returns fixed-size output with trailing garbage data - # Future improvements could include: - # 1. Dynamic strided_slice based on num_total_detections - # 2. Custom Relax operator with true dynamic shapes - # 3. VM builtin functions for runtime shape adjustment - # 4. Symbolic shape inference in Relax IR - # - # For now, users should trim manually: - # actual_count = int(num_total_detections.numpy()[0]) - # valid_indices = selected_indices.numpy()[:actual_count, :] - - return nms_result + # Dynamic output trimming using dynamic_strided_slice + # Extract selected_indices and num_total_detections from the NMS result + selected_indices = block_builder.emit(TupleGetItem(nms_result, 0)) + num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1)) + + # Build slicing parameters using TE to avoid high-level Relax ops during legalization + def build_begin(): + return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin") + + def build_strides(): + return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides") + + def build_end(count_tensor): + # end = [count_tensor[0], 3] + def compute_end(i): + return tir.if_then_else( + i == 0, + tir.Cast("int64", count_tensor[0]), + tir.const(3, "int64"), + ) + + return te.compute((2,), compute_end, name="end") + + begin = block_builder.call_te(build_begin) + strides = block_builder.call_te(build_strides) + end = block_builder.call_te(build_end, num_total_detections) + + # Apply dynamic strided slice to trim to valid detections only + trimmed_indices = block_builder.emit( + relax.op.dynamic_strided_slice(selected_indices, begin, end, strides) + ) + + # Return trimmed indices along with num_total_detections for compatibility + return relax.Tuple([trimmed_indices, num_total_detections]) diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 97145a53ff3b..660b5d27720b 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import pytest + import tvm import tvm.testing -from tvm import relax, tir -from tvm import TVMError -from tvm.ir import Op, VDevice +from tvm import TVMError, relax, tir +from tvm.relax.transform import LegalizeOps from tvm.script import relax as R @@ -53,7 +54,6 @@ def test_all_class_non_max_suppression_infer_struct_info(): def test_all_class_non_max_suppression_wrong_input_number(): - bb = relax.BlockBuilder() boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32")) scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32")) @@ -86,5 +86,95 @@ def test_all_class_non_max_suppression_infer_struct_info_shape_var(): ) +def test_all_class_non_max_suppression_legalize_dynamic_trim(): + @tvm.script.ir_module + class NMSModule: + @R.function + def main( + boxes: R.Tensor((1, 5, 4), "float32"), + scores: R.Tensor((1, 2, 5), "float32"), + ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")): + max_output_boxes_per_class = R.const(3, "int64") + iou_threshold = R.const(0.5, "float32") + score_threshold = R.const(0.1, "float32") + return R.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + + mod = LegalizeOps()(NMSModule) + + # Check legalized function has dynamic output (uses dynamic_strided_slice) + assert "dynamic_strided_slice" in str(mod) + + ret_sinfo = mod["main"].ret_struct_info + tvm.ir.assert_structural_equal( + ret_sinfo, + relax.TupleStructInfo( + [ + relax.TensorStructInfo(ndim=2, dtype="int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + +def test_all_class_non_max_suppression_legalize_e2e(): + @tvm.script.ir_module + class NMSModule: + @R.function + def main( + boxes: R.Tensor((1, 5, 4), "float32"), + scores: R.Tensor((1, 2, 5), "float32"), + ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")): + max_output_boxes_per_class = R.const(3, "int64") + iou_threshold = R.const(0.5, "float32") + score_threshold = R.const(0.1, "float32") + return R.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx" + ) + + boxes_data = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.1, 0.1, 1.1, 1.1], + [2.0, 2.0, 3.0, 3.0], + [4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0], + ] + ], + dtype=np.float32, + ) + scores_data = np.array( + [[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]], + dtype=np.float32, + ) + + mod = LegalizeOps()(NMSModule) + + # Check struct info + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo(ndim=2, dtype="int64"), + relax.TensorStructInfo((1,), "int64"), + ] + ), + ) + + # Check runtime execution + exe = tvm.compile(mod, target="llvm") + vm = relax.VirtualMachine(exe, tvm.cpu()) + result = vm["main"]( + tvm.runtime.tensor(boxes_data, tvm.cpu()), + tvm.runtime.tensor(scores_data, tvm.cpu()), + ) + + selected_indices = result[0].numpy() + num_total_detections = int(result[1].numpy()[0]) + tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3)) + + if __name__ == "__main__": tvm.testing.main()