From 8deac8d9ebb0da28c3b9851a3ad2269c7e771e8e Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 8 Jan 2026 15:52:20 -0800 Subject: [PATCH 1/3] enable dyn shapes for pointwise and reduce fusions --- src/fuse_pointwise.cpp | 16 +++++++++------ src/fuse_reduce.cpp | 8 ++++++-- src/include/migraphx/op/pointwise.hpp | 7 +++++-- src/include/migraphx/shape.hpp | 2 ++ src/shape.cpp | 13 ++++++++++++ test/fuse_pointwise.cpp | 29 +++++++++++++++++++++++++++ test/fuse_reduce.cpp | 25 +++++++++++++++++++++++ test/include/reduce.hpp | 13 +++++++++--- 8 files changed, 100 insertions(+), 13 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 42c95514555..722597a63b2 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins) if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); - if(s.elements() != 1 and not(s.scalar())) + if(s.dynamic() or (s.elements() != 1 and not(s.scalar()))) return {}; if(not ins->can_eval()) return {}; @@ -340,16 +340,20 @@ struct pointwise_reshape : rewrite_reshapes_base static std::string name() { return "pointwise"; } }; -struct pointwise_broadcast_pointwise +struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes { auto matcher() const { + auto pointwise = match::name("pointwise")(match::used_once()).bind("x"); auto broadcast_pointwise = - match::name("multibroadcast")( - match::used_once(), - match::args(match::name("pointwise")(match::used_once()).bind("x"))) + match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(1)(pointwise)) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()]( + match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..f1c7329e402 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -63,15 +63,19 @@ struct fused_reduce if(not sm->bypass()) MIGRAPHX_THROW("fused_reduce: bypass flag is not set"); auto names = sm->get_parameter_names(); - check_shapes{inputs, *this}.has(names.size()).same_ndims(); + check_shapes{inputs, *this, true}.has(names.size()).same_ndims(); std::sort(names.begin(), names.end()); auto shapes = sm->get_parameter_shapes(); // Check dimension matches for each input if(not equal(names, inputs, [&](const auto& name, const auto& input) { - return shapes.at(name).lens() == input.lens(); + auto s = shapes.at(name); + return shape::same_lens(input, s); })) MIGRAPHX_THROW("Input dimension does not match the submodule."); + if(sm->get_output_shapes().front().dynamic()) + return sm->get_output_shapes().front(); + return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), find_permutation(inputs)); diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 51ca75ee92b..03f16ca7d01 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -61,7 +61,10 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? std::vector{} : inputs.front().lens(); const auto rank = inputs.front().ndim(); const bool has_broadcasts = @@ -69,7 +72,7 @@ struct pointwise auto result = pm->compute_shapes( (rank > 1 and has_broadcasts) ? remove_broadcasts(inputs) : inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 22f6d1663d6..d8c9c5ffed2 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -359,6 +359,8 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const shape& x); + static bool same_lens(const shape& x, const shape& y); + template struct as { diff --git a/src/shape.cpp b/src/shape.cpp index 678bf5b53f0..c7a96ed9efd 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -828,6 +828,19 @@ std::ostream& operator<<(std::ostream& os, const shape& x) return os; } +bool shape::same_lens(const shape& x, const shape& y) +{ + if(x.dynamic() and y.dynamic()) + { + return x.dyn_dims() == y.dyn_dims(); + } + else if(x.dynamic() or y.dynamic()) + { + MIGRAPHX_THROW("SHAPE: same_lens() called on mixed dynamic and static shapes"); + } + return x.lens() == y.lens(); +} + shape::type_t shape::parse_type(const std::string& s) { static const std::unordered_map m = { diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 44e08e63c65..62751b1d831 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -67,6 +67,35 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add")); + mm->add_return({add2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(double_add) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..4f07a8be63a 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -63,6 +63,31 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y); + mm->add_return({rsum1, rsum2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum")); + auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum")); + mm->add_return({rsum1, rsum2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(pointwise_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp index 81d87566893..59d62b289fe 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -61,9 +61,16 @@ migraphx::module_ref add_reduce_module(migraphx::program& p, rm->set_bypass(); std::vector params; std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + migraphx::shape s; + if(input->get_shape().dynamic()) + { + s = input->get_shape(); + } + else + { + s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()}; + } + return rm->add_parameter("x" + std::to_string(params.size()), s); }); auto r = f(rm, params, axes); auto_add_return(rm, r); From d83f2a028fe26840897e32a00c9d9c26e92e7f22 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 09:10:20 -0800 Subject: [PATCH 2/3] format --- src/fuse_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index f1c7329e402..b8d8dc9833f 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -75,7 +75,7 @@ struct fused_reduce if(sm->get_output_shapes().front().dynamic()) return sm->get_output_shapes().front(); - + return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), find_permutation(inputs)); From 370e386c5b265f0a5f31b1723db796dd9137528e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 09:14:34 -0800 Subject: [PATCH 3/3] License update --- src/fuse_pointwise.cpp | 2 +- src/fuse_reduce.cpp | 2 +- src/include/migraphx/op/pointwise.hpp | 2 +- src/include/migraphx/shape.hpp | 2 +- src/shape.cpp | 2 +- test/fuse_pointwise.cpp | 2 +- test/fuse_reduce.cpp | 2 +- test/include/reduce.hpp | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 722597a63b2..563c420de82 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index b8d8dc9833f..f15eeabeea8 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 03f16ca7d01..9d16879cf35 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index d8c9c5ffed2..64b510176c3 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/shape.cpp b/src/shape.cpp index c7a96ed9efd..6698ff126fc 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 62751b1d831..08a1950fc63 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 4f07a8be63a..589c3c07cb1 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp index 59d62b289fe..6583bb900eb 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal