5 changes: 2 additions & 3 deletions python/tvm/relax/backend/metal/coreml.py
@@ -142,11 +142,10 @@ def _conv2d_pattern(pattern_name):
*default_unary_patterns(op_name="nn.relu"),
*default_unary_patterns(op_name="expand_dims"),
*default_unary_patterns(op_name="nn.avg_pool2d"),
*default_unary_patterns(op_name="nn.batch_flatten"),
*conv2d_patterns(),
*clip_patterns(),
*matmul_patterns(),
# TODO(@tvm-team): enable when relax op is implemented
# ("coreml.nn.batch_flatten", is_op("relax.nn.batch_flatten")(wildcard())),
]
)

@@ -271,7 +270,7 @@ def _convert_avg_pool2d(builder, name, inputs, outputs, args, attrs):
"clip": _convert_clip,
"expand_dims": _convert_expand_dims,
"nn.relu": _convert_relu,
# "nn.batch_flatten": _convert_batch_flatten,
"nn.batch_flatten": _convert_batch_flatten,
"nn.softmax": _convert_softmax,
"nn.conv2d": _convert_conv2d,
"nn.avg_pool2d": _convert_avg_pool2d,
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
@@ -25,6 +25,7 @@
avg_pool1d,
avg_pool2d,
avg_pool3d,
batch_flatten,
batch_norm,
conv1d,
conv1d_transpose,
19 changes: 19 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -2249,3 +2249,22 @@ def attention_var_len(
causal_mask,
window_size,
) # type: ignore


def batch_flatten(data: Expr) -> Expr:
"""Flatten all dimensions except the first (batch) dimension.

This operation flattens a tensor of shape `(N, C, H, W, ...)` into
a 2D tensor of shape `(N, C*H*W*...)`.

Parameters
----------
data : relax.Expr
The input data to the operator.

Returns
-------
result : relax.Expr
The flattened result with shape `(batch_size, flattened_features)`.
"""
return _ffi_api.batch_flatten(data) # type: ignore
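
An illustrative aside, not part of the diff: a minimal usage sketch of the new Python API, assuming a TVM build with Relax enabled. The variable names and shapes below are hypothetical.

from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x]):
    # Keep the batch axis and collapse the rest: (2, 3, 4, 5) -> (2, 60).
    lv = bb.emit(relax.op.nn.batch_flatten(x))
    bb.emit_func_output(lv)
mod = bb.get()
# Struct info inference (see InferStructInfoBatchFlatten below) should give
# lv the type R.Tensor((2, 60), "float32").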
7 changes: 7 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -775,3 +775,10 @@ def nll_loss_without_weight(predictions, targets, reduction, ignore_index):
reduction=call.attrs.reduction,
ignore_index=call.attrs.ignore_index,
)


@register_legalize("relax.nn.batch_flatten")
def _nn_batch_flatten(bb: BlockBuilder, call: Call) -> Expr:
if call.struct_info.shape is None:
return call
return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values)
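
A second sketch, also not part of the diff, showing the legalization rule above in action: when the output shape is statically known, batch_flatten lowers to a call_tir into a TOPI reshape; when it is unknown, the call is left untouched. The module below is hypothetical and mirrors the test added later in this PR.

import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R

@tvm.script.ir_module
class Mod:
    @R.function
    def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60), "float32"):
        gv: R.Tensor((2, 60), "float32") = R.nn.batch_flatten(x)
        return gv

legalized = LegalizeOps()(Mod)
# `main` now calls R.call_tir into a generated `reshape` PrimFunc with
# out_sinfo=R.Tensor((2, 60), "float32").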
50 changes: 50 additions & 0 deletions src/relax/op/nn/nn.cc
@@ -1192,5 +1192,55 @@ TVM_REGISTER_OP("relax.nn.nll_loss")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNLLLoss)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.batch_flatten */

Expr batch_flatten(Expr data) {
static const Op& op = Op::Get("relax.nn.batch_flatten");
return Call(op, {std::move(data)}, {}, {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.op.nn.batch_flatten", batch_flatten);
}

StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);

if (data_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
}

if (data_sinfo->ndim < 2) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "batch_flatten expects input tensor to have at least 2 dimensions, "
<< "but got " << data_sinfo->ndim);
}

if (data_sinfo->ndim == 2) {
return data_sinfo;
}

const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr) {
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
}

PrimExpr batch_dim = data_shape->values[0];
PrimExpr flat_dim = IntImm(DataType::Int(64), 1);
for (size_t i = 1; i < data_shape->values.size(); ++i) {
flat_dim = flat_dim * data_shape->values[i];
}

return TensorStructInfo(ShapeExpr({batch_dim, flat_dim}), data_sinfo->dtype, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.nn.batch_flatten")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchFlatten)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
@@ -114,6 +114,9 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels);
Expr nll_loss(Expr predictions, Expr targets, ffi::Optional<Expr> weights, ffi::String reduction,
int ignore_index);

/*! \brief Batch flatten: flatten all dimensions except the first (batch) dimension. */
Expr batch_flatten(Expr data);

} // namespace relax
} // namespace tvm

1 change: 0 additions & 1 deletion tests/python/relax/test_codegen_coreml.py
@@ -198,7 +198,6 @@ def test_relu():
verify(mod, [x_data])


@pytest.mark.skip("`batch_flatten` is not implemented yet.")
def test_batch_flatten():
x = relax.Var("x", relax.TensorStructInfo([10, 10, 10], "float32"))
bb = relax.BlockBuilder()
54 changes: 54 additions & 0 deletions tests/python/relax/test_op_nn.py
@@ -1845,5 +1845,59 @@ def test_pixel_shuffle_infer_struct_info():
)


def test_batch_flatten_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
assert relax.op.nn.batch_flatten(x).op == Op.get("relax.nn.batch_flatten")


def test_batch_flatten_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
x1 = relax.Var("x", R.Tensor("float32", ndim=4))
x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
x4 = relax.Var("x", R.Tensor((10, 20), "float32"))
x5 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))

_check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((2, 60), "float32"))
_check_inference(
bb, relax.op.nn.batch_flatten(x5), relax.TensorStructInfo((2, 60), "float32", vdev0)
)
_check_inference(
bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo(dtype="float32", ndim=2)
)
_check_inference(
bb, relax.op.nn.batch_flatten(x2), relax.TensorStructInfo(dtype="float32", ndim=2)
)
_check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorStructInfo((2, 60), dtype=""))
_check_inference(bb, relax.op.nn.batch_flatten(x4), relax.TensorStructInfo((10, 20), "float32"))


def test_batch_flatten_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
h = tir.Var("h", "int64")
w = tir.Var("w", "int64")
x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32"))
x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32"))

_check_inference(
bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((m, n * h * w), "float32")
)
_check_inference(
bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo((4, n * 8 * 8), "float32")
)


def test_batch_flatten_infer_struct_info_wrong_ndim():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((3,), "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.nn.batch_flatten(x0))


if __name__ == "__main__":
tvm.testing.main()
43 changes: 43 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3898,5 +3898,48 @@ def pad(
tvm.ir.assert_structural_equal(mod, Expected)


def test_batch_flatten():
# fmt: off
@tvm.script.ir_module
class BatchFlatten:
@R.function
def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60), "float32"):
gv: R.Tensor((2, 60), "float32") = R.nn.batch_flatten(x)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype="float32"):
gv = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((2, 60), dtype="float32"))
return gv

@T.prim_func(private=True)
def reshape(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(60)), "float32")):
T.func_attr({"tir.noalias": True})
for ax0, ax1 in T.grid(T.int64(2), T.int64(60)):
with T.block("T_reshape"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)])
T.writes(T_reshape[v_ax0, v_ax1])
T_reshape[v_ax0, v_ax1] = x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)]
# fmt: on

mod = LegalizeOps()(BatchFlatten)
tvm.ir.assert_structural_equal(mod, Expected)


def test_batch_flatten_undefined_shape():
@tvm.script.ir_module
class BatchFlattenUndefinedShape:
@R.function
def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
gv: R.Tensor(ndim=2, dtype="float32") = R.nn.batch_flatten(x)
return gv

mod = LegalizeOps()(BatchFlattenUndefinedShape)
tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)


if __name__ == "__main__":
tvm.testing.main()