From 29cecd400b61d0a7c1854092d965b72500d0c845 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 14:39:31 -0800 Subject: [PATCH 1/2] support dynamic shapes in mlir_op --- src/module.cpp | 5 +-- src/targets/gpu/fuse_mlir.cpp | 12 ++++--- test/gpu/fuse_mlir.cpp | 60 +++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 4838d241904..92f09db3eea 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -717,7 +717,8 @@ std::vector module::compute_shapes(const std::vector& inputs, ins->get_shape().type_string() + " but passed " + ins_shapes[ins].type_string()); } - if(options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) + if(not ins->get_shape().dynamic() and options.strict_lens and + ins->get_shape().lens() != ins_shapes[ins].lens()) { MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" + to_string_range(ins->get_shape().lens()) + "} but passed {" + diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index c8f0b0f154f..f24951f5477 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -177,7 +177,7 @@ struct mlir_op // Check if the shape can be created from a transpose/broadcast/slice static bool is_mlir_compatible(const shape& s) { - if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1) + if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1 or s.dynamic()) return true; auto ns = reorder_shape(s, find_permutation(s)); std::vector stride_ratios; @@ -202,7 +202,7 @@ struct mlir_op shape compute_shape(const std::vector& inputs, const std::vector& mods) const { module_ref mod = mods[0]; - check_shapes{inputs, *this}.has_at_least(1); + check_shapes{inputs, *this, true}.has_at_least(1); if(mods.size() != 1) MIGRAPHX_THROW("should have one submodule."); @@ -319,6 +319,8 @@ auto is_mlir_dot(mlir_mode mode) return false; if(ins->name() != "dot" and ins->name() != "quant_dot") return false; + if(ins->get_shape().dynamic()) + return true; // dot operation where (FP8 * FP8 = FP8) is not available in MLIR. rocBLAS/hipBLASLt should // have the support for it. if(contains(fp8_types{}.get(), ins->get_shape().type())) @@ -355,6 +357,8 @@ auto is_mlir_conv(mlir_mode mode) return false; if(ins->name() != "convolution" and ins->name() != "quant_convolution") return false; + if(ins->get_shape().dynamic()) + return true; auto input = ins->inputs().front()->get_shape(); value v = ins->get_operator().to_value(); auto group = v.at("group").to(); @@ -698,7 +702,7 @@ struct find_mlir_split_reduce * Fuses rocMLIR compatible dot or conv op -> reshapes -> pointwise * into a mlir_op with submodule. */ -struct find_mlir_fused_ops +struct find_mlir_fused_ops : match::supports_dynamic_shapes { mlir_mode conv_mode = mlir_mode::none; mlir_mode dot_mode = mlir_mode::none; @@ -974,7 +978,7 @@ struct find_mlir_fused_geg_ops }; template -struct find_mlir_standalone_op +struct find_mlir_standalone_op : match::supports_dynamic_shapes { mlir_mode mode = mlir_mode::none; std::size_t* counter = nullptr; diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 5323e5c7341..bc3cd5a25c3 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -2727,6 +2727,66 @@ TEST_CASE_SKIP(dot_add_dot_both_multi_user, "Not supported in rocMLIR") EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(dyn_dot) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 6}, {4, 6}, {}}; + migraphx::shape s2{migraphx::shape::float_type, {6, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("a", s1); + auto b = mm->add_parameter("b", s2); + auto a_cont = mm->add_instruction(migraphx::make_op("contiguous"), a); + + auto fused = + add_mlir(p2, "mlir_dot0", {a_cont, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(dyn_conv) +{ + migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {56, 56}, {8, 64}, {8, 64}}}; + migraphx::shape s2{migraphx::shape::float_type, {14, 56, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto w = mm->add_parameter("w", s2); + auto conv = mm->add_instruction(migraphx::make_op("convolution"), x, w); + mm->add_return({conv}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto w = mm->add_parameter("w", s2); + auto x_cont = mm->add_instruction(migraphx::make_op("contiguous"), x); + auto conv = add_mlir( + p2, "mlir_convolution0", {x_cont, w}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) { + auto c = + pm->add_instruction(migraphx::make_op("convolution"), inputs[0], inputs[1]); + return std::make_tuple(c->get_operator(), c); + }); + mm->add_return({conv}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { if(migraphx::gpu::mlir_enabled()) From a52b8069426242a9e454c35152581c2516caaef1 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 14:44:02 -0800 Subject: [PATCH 2/2] fix test case readability --- test/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index bc3cd5a25c3..f66f61fc2f2 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -2729,7 +2729,7 @@ TEST_CASE_SKIP(dot_add_dot_both_multi_user, "Not supported in rocMLIR") TEST_CASE(dyn_dot) { - migraphx::shape s1{migraphx::shape::float_type, {1, 6}, {4, 6}, {}}; + migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {6, 6}}}; migraphx::shape s2{migraphx::shape::float_type, {6, 3}}; migraphx::program p1; {