5 changes: 4 additions & 1 deletion include/tvm/relax/transform.h
@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
using Function = tvm::relax::Function;
using DataflowBlock = tvm::relax::DataflowBlock;
using tvm::transform::CreateModulePass;
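// Signature of the ConvertLayout callback: given a Call, return the desired layouts
// for it, in the same {op name: [layouts]} format as desired_layouts.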
using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;

/*!
* \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
/*!
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
 * \param layout_cb Custom callback that defines operator layouts dynamically; when non-null, desired_layouts is ignored.
* \return The Pass.
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);
TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
LayoutCb layout_cb);

/*!
* \brief A pass that converts consecutive dataflow operations
10 changes: 8 additions & 2 deletions python/tvm/relax/transform/transform.py
@@ -1367,7 +1367,10 @@ def AlterOpImpl(
) # type: ignore


def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
def ConvertLayout(
desired_layouts: Dict[str, List[str]],
    layout_cb: Optional[Callable] = None,
) -> tvm.ir.transform.Pass:
"""Automatic layout conversion pass.

Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
of the desired feature map, weight and output. For example, if we want to convert the
layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
    layout_cb : Optional[Callable]
        A user-defined callback that determines operator layouts dynamically from the
        ``Call`` node it receives. ``desired_layouts`` is ignored when ``layout_cb`` is provided.

Returns
-------
ret : tvm.transform.Pass
The registered pass for layout conversion.
"""
return _ffi_api.ConvertLayout(desired_layouts) # type: ignore
return _ffi_api.ConvertLayout(desired_layouts, layout_cb) # type: ignore


def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
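For context, a minimal usage sketch of the new hook from Python (not part of this PR's diff): the toy module mirrors the Input module in the test further down, and the callback returns layouts in the same {op name: [layouts]} format as desired_layouts. The attribute access call.op.name is assumed to behave as it does for any operator call.

import tvm
from tvm.relax.transform import ConvertLayout, Normalize
from tvm.script.parser import ir as I, relax as R


@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((2, 4, 28, 28), "float32"),
        w: R.Tensor((4, 4, 3, 3), "float32"),
    ) -> R.Tensor(None, "float32", ndim=4):
        with R.dataflow():
            gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
            R.output(gv)
        return gv


def pick_layouts(call: tvm.relax.Call):
    # Invoked once per operator call site; return the desired layouts for it.
    if call.op.name == "relax.nn.conv2d":
        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
    # An empty map leaves other ops in their initial layouts.
    return {}


# desired_layouts is ignored whenever layout_cb is supplied.
mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]}, layout_cb=pick_layouts)(Model)
mod = Normalize()(mod)
print(mod)

Run on the module above, this would rewrite the convolution into packed NCHW4c/OIHW4o form, as in test_layout_cb below.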
28 changes: 19 additions & 9 deletions src/relax/transform/convert_layout.cc
@@ -38,6 +38,7 @@ namespace relax {

using tir::IndexMap;
using tir::Layout;
using LayoutCb = tvm::relax::transform::LayoutCb;

/*!
* \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
class LayoutConvertMutator : public ExprMutator {
public:
explicit LayoutConvertMutator(
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts)
: desired_layouts_(desired_layouts) {}
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, LayoutCb layout_cb)
: desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}

private:
ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,15 +202,21 @@ class LayoutConvertMutator : public ExprMutator {
ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
const CallNode* call_node,
const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
const VarLayoutMap& var_layout_map) {
const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
const OpNode* op_node = call_node->op.as<OpNode>();
if (op_node == nullptr) return std::nullopt;
Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
const auto attr_map = Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
// If the op has FRelaxInferLayout, and all the input tensors have known ndim
FRelaxInferLayout f = attr_map[op];
return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
auto call = ffi::GetRef<Call>(call_node);
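    // A user-supplied callback takes precedence: its result replaces desired_layouts for this call.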
if (layout_cb != nullptr) {
auto custom_layouts = layout_cb(call);
return f(call, custom_layouts, var_layout_map);
} else {
return f(call, desired_layouts, var_layout_map);
}
} else {
// Otherwise, we use the default policy.
return std::nullopt;
@@ -218,7 +225,7 @@

void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
ffi::Optional<InferLayoutOutput> res =
GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_);
ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
new_call->struct_info_ = std::nullopt;
if (!res.defined() ||
@@ -335,20 +342,23 @@

std::unordered_map<Var, NLayout> var_layout_map_;
ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts_;
LayoutCb layout_cb_;
};  // class LayoutConvertMutator

DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
LayoutConvertMutator mutator(desired_layouts);
ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
LayoutCb layout_cb) {
LayoutConvertMutator mutator(desired_layouts, layout_cb);
return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
}

namespace transform {

Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
LayoutCb layout_cb) {
ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
[=](DataflowBlock df_block, IRModule m, PassContext pc) {
return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts));
return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts, layout_cb));
};
return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
}
95 changes: 93 additions & 2 deletions tests/python/relax/test_transform_convert_layout.py
@@ -21,10 +21,10 @@
from tvm.script.parser import ir as I, relax as R, tir as T


def verify(input, expected, extra_ops={}):
def verify(input, expected, extra_ops={}, cb=None):
desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
desired_layouts.update(extra_ops)
mod = ConvertLayout(desired_layouts)(input)
mod = ConvertLayout(desired_layouts, cb)(input)
mod = Normalize()(mod)
tvm.ir.assert_structural_equal(mod, expected)

@@ -5487,5 +5487,96 @@ def main(
verify(Input, Expected)


def test_layout_cb():
@I.ir_module
class Input:
@R.function
def main(
x: R.Tensor((2, 4, 28, 28), "float32"),
w: R.Tensor((4, 4, 3, 3), "float32"),
bias: R.Tensor((2, 4, 26, 26), "float32"),
) -> R.Tensor(None, "float32", ndim=4):
with R.dataflow():
gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
R.output(gv4)
return gv4

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor((2, 4, 28, 28), dtype="float32"),
w: R.Tensor((4, 4, 3, 3), dtype="float32"),
bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
with R.dataflow():
lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
),
)
lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
w,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
),
)
gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
lv,
lv1,
strides=[1, 1],
padding=[0, 0, 0, 0],
dilation=[1, 1],
groups=1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32",
)
lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform(
bias,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
),
)
gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
w,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
),
)
lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d(
gv3,
lv3,
strides=[1, 1],
padding=[0, 0, 0, 0],
dilation=[1, 1],
groups=1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32",
)
gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform(
lv4,
index_map=T.index_map(
lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32"
),
)
R.output(gv4)
return gv4

def layout_cb(call: tvm.relax.Call):
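        # With a callback supplied, verify()'s NHWC/OHWI desired_layouts are ignored;
        # request packed layouts for every call instead.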
return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}

verify(Input, Expected, cb=layout_cb)


if __name__ == "__main__":
tvm.testing.main()