diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 786dfdcdf98c..0e660292c46c 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
 using Function = tvm::relax::Function;
 using DataflowBlock = tvm::relax::DataflowBlock;
 using tvm::transform::CreateModulePass;
+using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;
 
 /*!
  * \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
 /*!
  * \brief Layout conversion pass.
  * \param desired_layouts The desired layouts for some operators.
+ * \param layout_cb Custom callback to define layouts dynamically.
  * \return The Pass.
 * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
 */
-TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);
+TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                           LayoutCb layout_cb);
 
 /*!
  * \brief A pass that converts consecutive dataflow operations
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 46efc17e3d4f..bfd7dbf87d70 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1367,7 +1367,10 @@ def AlterOpImpl(
     )  # type: ignore
 
 
-def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
+def ConvertLayout(
+    desired_layouts: Dict[str, List[str]],
+    layout_cb: Callable = None,
+) -> tvm.ir.transform.Pass:
     """Automatic layout conversion pass.
 
     Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
         of the desired feature map, weight and output. For example, if we want to convert the
         layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+    layout_cb : Callable
+        A user-defined callback that decides operator layouts dynamically based on the
+        Call it receives. desired_layouts is ignored if layout_cb is defined.
 
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for layout conversion.
     """
-    return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
+    return _ffi_api.ConvertLayout(desired_layouts, layout_cb)  # type: ignore
 
 
 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index c543799e3b0d..27684313de02 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -38,6 +38,7 @@ namespace relax {
 
 using tir::IndexMap;
 using tir::Layout;
+using LayoutCb = tvm::relax::transform::LayoutCb;
 
 /*!
  * \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
 class LayoutConvertMutator : public ExprMutator {
  public:
   explicit LayoutConvertMutator(
-      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts)
-      : desired_layouts_(desired_layouts) {}
+      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, LayoutCb layout_cb)
+      : desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}
 
  private:
   ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,7 +202,7 @@ class LayoutConvertMutator : public ExprMutator {
   ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
       const CallNode* call_node,
       const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
-      const VarLayoutMap& var_layout_map) {
+      const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
     const OpNode* op_node = call_node->op.as<OpNode>();
     if (op_node == nullptr) return std::nullopt;
     Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
@@ -209,7 +210,13 @@ class LayoutConvertMutator : public ExprMutator {
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known ndim
       FRelaxInferLayout f = attr_map[op];
-      return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      auto call = ffi::GetRef<Call>(call_node);
+      if (layout_cb != nullptr) {
+        auto custom_layouts = layout_cb(call);
+        return f(call, custom_layouts, var_layout_map);
+      } else {
+        return f(call, desired_layouts, var_layout_map);
+      }
     } else {
       // Otherwise, we use the default policy.
       return std::nullopt;
@@ -218,7 +225,7 @@ class LayoutConvertMutator : public ExprMutator {
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
     ffi::Optional<InferLayoutOutput> res =
-        GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+        GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_);
     ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
     new_call->struct_info_ = std::nullopt;
     if (!res.defined() ||
@@ -335,20 +342,23 @@ class LayoutConvertMutator : public ExprMutator {
   std::unordered_map<const VarNode*, NLayout> var_layout_map_;
   ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts_;
+  LayoutCb layout_cb_;
 };  // namespace relax
 
 DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
-                                ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
-  LayoutConvertMutator mutator(desired_layouts);
+                                ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                                LayoutCb layout_cb) {
+  LayoutConvertMutator mutator(desired_layouts, layout_cb);
   return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
 }
 
 namespace transform {
 
-Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
+Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                   LayoutCb layout_cb) {
   ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
       [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts));
+        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts, layout_cb));
       };
   return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
 }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 84fa9e70c7d7..fe412fd93b18 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -21,10 +21,10 @@
 from tvm.script.parser import ir as I, relax as R, tir as T
 
 
-def verify(input, expected, extra_ops={}):
+def verify(input, expected, extra_ops={}, cb=None):
     desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
     desired_layouts.update(extra_ops)
-    mod = ConvertLayout(desired_layouts)(input)
+    mod = ConvertLayout(desired_layouts, cb)(input)
     mod = Normalize()(mod)
     tvm.ir.assert_structural_equal(mod, expected)
 
@@ -5487,5 +5487,96 @@ def main(
         verify(Input, Expected)
 
 
+def test_layout_cb():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform(
+                    bias,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+                gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+                lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+                    gv3,
+                    lv3,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform(
+                    lv4,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32"
+                    ),
+                )
+                R.output(gv4)
+            return gv4
+
+    def layout_cb(call: tvm.relax.Call):
+        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
+
+    verify(Input, Expected, cb=layout_cb)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
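
Usage sketch (not part of the patch): the snippet below illustrates how the new layout_cb hook could be driven from Python once this change is applied. Only ConvertLayout(desired_layouts, layout_cb) and the dict-of-layouts return convention come from the diff above; the Model module, its shapes, and the channel-divisibility heuristic are illustrative assumptions.

from tvm import relax
from tvm.relax.transform import ConvertLayout, Normalize
from tvm.script import ir as I, relax as R


@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((1, 8, 14, 14), "float32"), w: R.Tensor((8, 8, 3, 3), "float32")
    ) -> R.Tensor(None, "float32", ndim=4):
        with R.dataflow():
            gv = R.nn.conv2d(x, w, out_dtype="float32")
            R.output(gv)
        return gv


def layout_cb(call: relax.Call):
    # Hypothetical policy: for conv2d, use blocked layouts only when the (static)
    # input channel count divides evenly by 4; otherwise keep plain NHWC/OHWI.
    if call.op.name == "relax.nn.conv2d":
        sinfo = call.args[0].struct_info
        if isinstance(sinfo.shape, relax.ShapeExpr) and int(sinfo.shape.values[1]) % 4 == 0:
            return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
    return {"relax.nn.conv2d": ["NHWC", "OHWI"]}


# Per the patch, desired_layouts is still passed positionally but ignored once layout_cb is given.
mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]}, layout_cb)(Model)
mod = Normalize()(mod)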