From 2a9121b75167f8fcf2d44a73d1d72e36bf182d08 Mon Sep 17 00:00:00 2001
From: Siva
Date: Tue, 20 Jan 2026 08:57:04 +0530
Subject: [PATCH 1/3] [RELAX][LAYOUT] Support for dynamic layout specification

This allows a user-defined callback to specify layouts dynamically based on
the call description. Helpful to alter layouts based on operator shapes or
attributes.
---
 include/tvm/relax/transform.h              |  5 +-
 python/tvm/relax/transform/transform.py    | 10 +-
 src/relax/transform/convert_layout.cc      | 28 ++++--
 .../relax/test_transform_convert_layout.py | 95 ++++++++++++++++++-
 4 files changed, 124 insertions(+), 14 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 786dfdcdf98c..0e660292c46c 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
 using Function = tvm::relax::Function;
 using DataflowBlock = tvm::relax::DataflowBlock;
 using tvm::transform::CreateModulePass;
+using LayoutCb = ffi::TypedFunction<ffi::Map<String, ffi::Array<String>>(Call)>;
 
 /*!
  * \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
 /*!
  * \brief Layout conversion pass.
  * \param desired_layouts The desired layouts for some operators.
+ * \param layout_cb Custom callback to define layouts dynamically.
  * \return The Pass.
  * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
  */
-TVM_DLL Pass ConvertLayout(ffi::Map<String, ffi::Array<String>> desired_layouts);
+TVM_DLL Pass ConvertLayout(ffi::Map<String, ffi::Array<String>> desired_layouts,
+                           LayoutCb layout_cb);
 
 /*!
  * \brief A pass that converts consecutive dataflow operations
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 46efc17e3d4f..b2210e242810 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1367,7 +1367,10 @@ def AlterOpImpl(
     )  # type: ignore
 
 
-def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
+def ConvertLayout(
+    desired_layouts: Dict[str, List[str]],
+    layout_cb: Callable = None,
+) -> tvm.ir.transform.Pass:
     """Automatic layout conversion pass.
 
     Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
         of the desired feature map, weight and output. For example, if we want to convert the
         layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+    layout_cb : Callable
+        A user-defined callback function that can dynamically handle operator layouts
+        based on Call description. desigred_layouts will be ignored if layout_cb is defined.
 
     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for layout conversion.
     """
-    return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
+    return _ffi_api.ConvertLayout(desired_layouts, layout_cb)  # type: ignore
 
 
 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index c543799e3b0d..0f1c74cf16f0 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -38,6 +38,7 @@ namespace relax {
 
 using tir::IndexMap;
 using tir::Layout;
+using LayoutCb = tvm::relax::transform::LayoutCb;
 
 /*!
  * \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
 class LayoutConvertMutator : public ExprMutator {
  public:
   explicit LayoutConvertMutator(
-      const ffi::Map<String, ffi::Array<String>>& desired_layouts)
-      : desired_layouts_(desired_layouts) {}
+      const ffi::Map<String, ffi::Array<String>>& desired_layouts, LayoutCb layout_cb)
+      : desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}
 
  private:
   ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,7 +202,7 @@ class LayoutConvertMutator : public ExprMutator {
 
   ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
       const CallNode* call_node, const ffi::Map<String, ffi::Array<String>>& desired_layouts,
-      const VarLayoutMap& var_layout_map) {
+      const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
     const OpNode* op_node = call_node->op.as<OpNode>();
     if (op_node == nullptr) return std::nullopt;
     Op op = Downcast<Op>(ffi::GetRef<Expr>(op_node));
@@ -209,7 +210,13 @@ class LayoutConvertMutator : public ExprMutator {
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known ndim
       FRelaxInferLayout f = attr_map[op];
-      return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      if (layout_cb != nullptr) {
+        ffi::Map<String, ffi::Array<String>> custom_layouts;
+        custom_layouts = layout_cb(ffi::GetRef<Call>(call_node));
+        return f(ffi::GetRef<Call>(call_node), custom_layouts, var_layout_map);
+      } else {
+        return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      }
     } else {
       // Otherwise, we use the default policy.
@@ -218,7 +225,7 @@ class LayoutConvertMutator : public ExprMutator {
 
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
     ffi::Optional<InferLayoutOutput> res =
-        GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+        GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_);
     ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
     new_call->struct_info_ = std::nullopt;
     if (!res.defined() ||
@@ -335,20 +342,23 @@ class LayoutConvertMutator : public ExprMutator {
 
   std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> var_layout_map_;
   ffi::Map<String, ffi::Array<String>> desired_layouts_;
+  LayoutCb layout_cb_;
 };  // namespace relax
 
 DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
-                                ffi::Map<String, ffi::Array<String>> desired_layouts) {
-  LayoutConvertMutator mutator(desired_layouts);
+                                ffi::Map<String, ffi::Array<String>> desired_layouts,
+                                LayoutCb layout_cb) {
+  LayoutConvertMutator mutator(desired_layouts, layout_cb);
   return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
 }
 
 namespace transform {
 
-Pass ConvertLayout(ffi::Map<String, ffi::Array<String>> desired_layouts) {
+Pass ConvertLayout(ffi::Map<String, ffi::Array<String>> desired_layouts,
+                   LayoutCb layout_cb) {
   ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
       [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts));
+        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts, layout_cb));
       };
   return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
 }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 84fa9e70c7d7..fe412fd93b18 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -21,10 +21,10 @@
 from tvm.script.parser import ir as I, relax as R, tir as T
 
 
-def verify(input, expected, extra_ops={}):
+def verify(input, expected, extra_ops={}, cb=None):
     desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
     desired_layouts.update(extra_ops)
-    mod = ConvertLayout(desired_layouts)(input)
+    mod = ConvertLayout(desired_layouts, cb)(input)
     mod = Normalize()(mod)
     tvm.ir.assert_structural_equal(mod, expected)
 
@@ -5487,5 +5487,96 @@ def main(
     verify(Input, Expected)
 
 
+def test_layout_cb():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform(
+                    bias,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+                gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+                lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+                    gv3,
+                    lv3,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform(
+                    lv4,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32"
+                    ),
+                )
+                R.output(gv4)
+            return gv4
+
+    def layout_cb(call: tvm.relax.Call):
+        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
+
+    verify(Input, Expected, cb=layout_cb)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 58fed45c82e82e03b48992774dde7dfd5318062e Mon Sep 17 00:00:00 2001
From: Siva
Date: Tue, 20 Jan 2026 09:20:22 +0530
Subject: [PATCH 2/3] Update python/tvm/relax/transform/transform.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/tvm/relax/transform/transform.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index b2210e242810..bfd7dbf87d70 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1382,7 +1382,7 @@ def ConvertLayout(
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
     layout_cb : Callable
         A user-defined callback function that can dynamically handle operator layouts
-        based on Call description. desigred_layouts will be ignored if layout_cb is defined.
+        based on Call description. desired_layouts will be ignored if layout_cb is defined.
 
     Returns
     -------

From eb1e1864ce57e95cfeb9de4dc83f20239cc9b5d7 Mon Sep 17 00:00:00 2001
From: Siva
Date: Tue, 20 Jan 2026 09:25:05 +0530
Subject: [PATCH 3/3] review

---
 src/relax/transform/convert_layout.cc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index 0f1c74cf16f0..27684313de02 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -210,12 +210,12 @@ class LayoutConvertMutator : public ExprMutator {
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known ndim
       FRelaxInferLayout f = attr_map[op];
+      auto call = ffi::GetRef<Call>(call_node);
       if (layout_cb != nullptr) {
-        ffi::Map<String, ffi::Array<String>> custom_layouts;
-        custom_layouts = layout_cb(ffi::GetRef<Call>(call_node));
-        return f(ffi::GetRef<Call>(call_node), custom_layouts, var_layout_map);
+        auto custom_layouts = layout_cb(call);
+        return f(call, custom_layouts, var_layout_map);
       } else {
-        return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+        return f(call, desired_layouts, var_layout_map);
       }
     } else {
       // Otherwise, we use the default policy.
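
For reference, a minimal sketch of how the new callback hook could be driven from
Python once the series is applied. ConvertLayout, Normalize and the callback
signature follow the test above; the shape inspection via struct_info and the
channel-divisibility heuristic are assumptions for illustration only, not part of
the patch.

    import tvm
    from tvm import relax
    from tvm.relax.transform import ConvertLayout, Normalize

    # Hypothetical callback: prefer a vectorized NCHW4c layout when the input
    # channel count is divisible by 4, otherwise fall back to plain NHWC.
    def layout_cb(call: relax.Call):
        if call.op.name == "relax.nn.conv2d":
            # Assumes a static NCHW data tensor as the first argument.
            in_channels = int(call.args[0].struct_info.shape[1])
            if in_channels % 4 == 0:
                return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
        return {"relax.nn.conv2d": ["NHWC", "OHWI"]}

    def convert_with_callback(mod: tvm.IRModule) -> tvm.IRModule:
        # desired_layouts is ignored when layout_cb is given,
        # but is still passed positionally.
        mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]}, layout_cb)(mod)
        return Normalize()(mod)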