Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/module.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -717,7 +717,8 @@ std::vector<shape> module::compute_shapes(const std::vector<shape>& inputs,
ins->get_shape().type_string() + " but passed " +
ins_shapes[ins].type_string());
}
if(options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens())
if(not ins->get_shape().dynamic() and options.strict_lens and
ins->get_shape().lens() != ins_shapes[ins].lens())
{
MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" +
to_string_range(ins->get_shape().lens()) + "} but passed {" +
Expand Down
12 changes: 8 additions & 4 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ struct mlir_op
// Check if the shape can be created from a transpose/broadcast/slice
static bool is_mlir_compatible(const shape& s)
{
if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1)
if(s.standard() or s.packed() or s.scalar() or s.ndim() == 1 or s.dynamic())
return true;
auto ns = reorder_shape(s, find_permutation(s));
std::vector<std::size_t> stride_ratios;
Expand All @@ -202,7 +202,7 @@ struct mlir_op
shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
module_ref mod = mods[0];
check_shapes{inputs, *this}.has_at_least(1);
check_shapes{inputs, *this, true}.has_at_least(1);
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");

Expand Down Expand Up @@ -319,6 +319,8 @@ auto is_mlir_dot(mlir_mode mode)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(ins->get_shape().dynamic())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we're accepting any dynamic shape for the mlir_op and then going to resolve what actually can be done in mlir when the compute object is created?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding of this function is that it decides whether the gemm should use MLIR (as opposed to hipBLAS) based on the dims. So accepting any dynamic shape essentially means always using MLIR for the dynamic case.

return true;
// dot operation where (FP8 * FP8 = FP8) is not available in MLIR. rocBLAS/hipBLASLt should
// have the support for it.
if(contains(fp8_types{}.get(), ins->get_shape().type()))
Expand Down Expand Up @@ -355,6 +357,8 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
if(ins->get_shape().dynamic())
return true;
auto input = ins->inputs().front()->get_shape();
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
Expand Down Expand Up @@ -698,7 +702,7 @@ struct find_mlir_split_reduce
* Fuses rocMLIR compatible dot or conv op -> reshapes -> pointwise
* into a mlir_op with submodule.
*/
struct find_mlir_fused_ops
struct find_mlir_fused_ops : match::supports_dynamic_shapes
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
Expand Down Expand Up @@ -974,7 +978,7 @@ struct find_mlir_fused_geg_ops
};

template <auto Matcher>
struct find_mlir_standalone_op
struct find_mlir_standalone_op : match::supports_dynamic_shapes
{
mlir_mode mode = mlir_mode::none;
std::size_t* counter = nullptr;
Expand Down
60 changes: 60 additions & 0 deletions test/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2727,6 +2727,66 @@ TEST_CASE_SKIP(dot_add_dot_both_multi_user, "Not supported in rocMLIR")
EXPECT(p1.sort() == p2.sort());
}

// Checks that a dot whose first operand has a dynamic batch dimension is
// fused into an mlir_op by the pass under test; the pass is expected to
// insert a contiguous on the dynamic operand before the fused submodule.
TEST_CASE(dyn_dot)
{
    // Leading dimension is dynamic in [1, 4]; the contraction dim is fixed at 6.
    migraphx::shape dyn_shape{migraphx::shape::float_type, {{1, 4}, {6, 6}}};
    migraphx::shape fixed_shape{migraphx::shape::float_type, {6, 3}};

    // Program to transform: a plain dot instruction.
    migraphx::program p1;
    {
        auto* mod    = p1.get_main_module();
        auto lhs     = mod->add_parameter("a", dyn_shape);
        auto rhs     = mod->add_parameter("b", fixed_shape);
        auto product = mod->add_instruction(migraphx::make_op("dot"), lhs, rhs);
        mod->add_return({product});
    }
    run_pass(p1);

    // Expected result: the dot wrapped inside an mlir_op submodule, with a
    // contiguous applied to the dynamic-shaped operand first.
    migraphx::program p2;
    {
        auto* mod     = p2.get_main_module();
        auto lhs      = mod->add_parameter("a", dyn_shape);
        auto rhs      = mod->add_parameter("b", fixed_shape);
        auto lhs_cont = mod->add_instruction(migraphx::make_op("contiguous"), lhs);

        auto fused = add_mlir(
            p2, "mlir_dot0", {lhs_cont, rhs}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
                auto product =
                    pm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]);
                return std::make_tuple(product->get_operator(), product);
            });
        mod->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}

// Checks that a convolution with dynamic input dimensions is fused into an
// mlir_op by the pass under test; the pass is expected to insert a
// contiguous on the dynamic operand before the fused submodule.
TEST_CASE(dyn_conv)
{
    // Batch and spatial dims are dynamic; channel count is fixed at 56.
    migraphx::shape dyn_shape{migraphx::shape::float_type, {{1, 4}, {56, 56}, {8, 64}, {8, 64}}};
    migraphx::shape wei_shape{migraphx::shape::float_type, {14, 56, 3, 3}};

    // Program to transform: a plain convolution instruction.
    migraphx::program p1;
    {
        auto* mod = p1.get_main_module();
        auto data = mod->add_parameter("x", dyn_shape);
        auto wei  = mod->add_parameter("w", wei_shape);
        auto out  = mod->add_instruction(migraphx::make_op("convolution"), data, wei);
        mod->add_return({out});
    }
    run_pass(p1);

    // Expected result: the convolution wrapped inside an mlir_op submodule,
    // with a contiguous applied to the dynamic-shaped input first.
    migraphx::program p2;
    {
        auto* mod      = p2.get_main_module();
        auto data      = mod->add_parameter("x", dyn_shape);
        auto wei       = mod->add_parameter("w", wei_shape);
        auto data_cont = mod->add_instruction(migraphx::make_op("contiguous"), data);
        auto fused     = add_mlir(
            p2, "mlir_convolution0", {data_cont, wei}, {"y0", "y1"},
            [=](auto* pm, const auto& inputs) {
                auto c =
                    pm->add_instruction(migraphx::make_op("convolution"), inputs[0], inputs[1]);
                return std::make_tuple(c->get_operator(), c);
            });
        mod->add_return({fused});
    }
    EXPECT(p1.sort() == p2.sort());
}

int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
Expand Down
Loading