diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py index dfc891dc1f31..c7d922b59986 100644 --- a/python/tvm/relax/backend/metal/coreml.py +++ b/python/tvm/relax/backend/metal/coreml.py @@ -142,11 +142,10 @@ def _conv2d_pattern(pattern_name): *default_unary_patterns(op_name="nn.relu"), *default_unary_patterns(op_name="expand_dims"), *default_unary_patterns(op_name="nn.avg_pool2d"), + *default_unary_patterns(op_name="nn.batch_flatten"), *conv2d_patterns(), *clip_patterns(), *matmul_patterns(), - # TODO(@tvm-team): enable when relax op is implemented - # ("coreml.nn.batch_flatten", is_op("relax.nn.batch_flatten")(wildcard())), ] ) @@ -271,7 +270,7 @@ def _convert_avg_pool2d(builder, name, inputs, outputs, args, attrs): "clip": _convert_clip, "expand_dims": _convert_expand_dims, "nn.relu": _convert_relu, - # "nn.batch_flatten": _convert_batch_flatten, + "nn.batch_flatten": _convert_batch_flatten, "nn.softmax": _convert_softmax, "nn.conv2d": _convert_conv2d, "nn.avg_pool2d": _convert_avg_pool2d, diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index ec1135ef2cd0..cf804440f1e4 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -25,6 +25,7 @@ avg_pool1d, avg_pool2d, avg_pool3d, + batch_flatten, batch_norm, conv1d, conv1d_transpose, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index c7be2a7ba6f6..e9710deca9bf 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -2249,3 +2249,22 @@ def attention_var_len( causal_mask, window_size, ) # type: ignore + + +def batch_flatten(data: Expr) -> Expr: + """Flatten all dimensions except the first (batch) dimension. + + This operation flattens a tensor of shape `(N, C, H, W, ...)` into + a 2D tensor of shape `(N, C*H*W*...)`. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. 
+ + Returns + ------- + result : relax.Expr + The flattened result with shape `(batch_size, flattened_features)`. + """ + return _ffi_api.batch_flatten(data) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index ed9802fc9e63..1a0477af20e2 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -775,3 +775,10 @@ def nll_loss_without_weight(predictions, targets, reduction, ignore_index): reduction=call.attrs.reduction, ignore_index=call.attrs.ignore_index, ) + + +@register_legalize("relax.nn.batch_flatten") +def _nn_batch_flatten(bb: BlockBuilder, call: Call) -> Expr: + if call.struct_info.shape is None: + return call + return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index f4b9fe400bee..0a2335834399 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -1192,5 +1192,55 @@ TVM_REGISTER_OP("relax.nn.nll_loss") .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNLLLoss) .set_attr<Bool>("FPurity", Bool(true)); +/* relax.nn.batch_flatten */ + +Expr batch_flatten(Expr data) { + static const Op& op = Op::Get("relax.nn.batch_flatten"); + return Call(op, {std::move(data)}, {}, {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.nn.batch_flatten", batch_flatten); +} + +StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice); + } + + if (data_sinfo->ndim < 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "batch_flatten expects input tensor to have at least 2 dimensions, " + << "but got " << data_sinfo->ndim); + } + + if (data_sinfo->ndim == 2) { + return data_sinfo; + } + + const 
auto* data_shape = data_sinfo->shape.as<ShapeExprNode>(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice); + } + + PrimExpr batch_dim = data_shape->values[0]; + PrimExpr flat_dim = IntImm(DataType::Int(64), 1); + for (size_t i = 1; i < data_shape->values.size(); ++i) { + flat_dim = flat_dim * data_shape->values[i]; + } + + return TensorStructInfo(ShapeExpr({batch_dim, flat_dim}), data_sinfo->dtype, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.nn.batch_flatten") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchFlatten) + .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr<Bool>("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 989dfbb3f613..b6f749854f36 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -114,6 +114,9 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels); Expr nll_loss(Expr predictions, Expr targets, ffi::Optional<Expr> weights, ffi::String reduction, int ignore_index); +/*! \brief Batch flatten: flatten all dimensions except the first (batch) dimension. 
*/ +Expr batch_flatten(Expr data); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py index b07271e8949a..de3a6d0789f8 100644 --- a/tests/python/relax/test_codegen_coreml.py +++ b/tests/python/relax/test_codegen_coreml.py @@ -198,7 +198,6 @@ def test_relu(): verify(mod, [x_data]) -@pytest.mark.skip("`batch_flatten` is not implemented yet.") def test_batch_flatten(): x = relax.Var("x", relax.TensorStructInfo([10, 10, 10], "float32")) bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index b076827dc4a0..4c419ed0e1ce 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -1845,5 +1845,59 @@ def test_pixel_shuffle_infer_struct_info(): ) +def test_batch_flatten_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + assert relax.op.nn.batch_flatten(x).op == Op.get("relax.nn.batch_flatten") + + +def test_batch_flatten_infer_struct_info(): + bb = relax.BlockBuilder() + vdev0 = VDevice("llvm") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + x4 = relax.Var("x", R.Tensor((10, 20), "float32")) + x5 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0)) + + _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((2, 60), "float32")) + _check_inference( + bb, relax.op.nn.batch_flatten(x5), relax.TensorStructInfo((2, 60), "float32", vdev0) + ) + _check_inference( + bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.nn.batch_flatten(x2), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorStructInfo((2, 60), dtype="")) + _check_inference(bb, 
relax.op.nn.batch_flatten(x4), relax.TensorStructInfo((10, 20), "float32")) + + +def test_batch_flatten_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + h = tir.Var("h", "int64") + w = tir.Var("w", "int64") + x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32")) + x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32")) + + _check_inference( + bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((m, n * h * w), "float32") + ) + _check_inference( + bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo((4, n * 8 * 8), "float32") + ) + + +def test_batch_flatten_infer_struct_info_wrong_ndim(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nn.batch_flatten(x0)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index e81e1bab2af4..f39dc238efe1 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -3898,5 +3898,48 @@ def pad( tvm.ir.assert_structural_equal(mod, Expected) +def test_batch_flatten(): + # fmt: off + @tvm.script.ir_module + class BatchFlatten: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60), "float32"): + gv: R.Tensor((2, 60), "float32") = R.nn.batch_flatten(x) + return gv + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype="float32"): + gv = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((2, 60), dtype="float32")) + return gv + + @T.prim_func(private=True) + def reshape(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(60)), "float32")): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(T.int64(2), 
T.int64(60)): + with T.block("T_reshape"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1]) + T_reshape[v_ax0, v_ax1] = x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)] + # fmt: on + + mod = LegalizeOps()(BatchFlatten) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_batch_flatten_undefined_shape(): + @tvm.script.ir_module + class BatchFlattenUndefinedShape: + @R.function + def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"): + gv: R.Tensor(ndim=2, dtype="float32") = R.nn.batch_flatten(x) + return gv + + mod = LegalizeOps()(BatchFlattenUndefinedShape) + tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape) + + if __name__ == "__main__": tvm.testing.main()