diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 42c95514555..563c420de82 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins) if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); - if(s.elements() != 1 and not(s.scalar())) + if(s.dynamic() or (s.elements() != 1 and not(s.scalar()))) return {}; if(not ins->can_eval()) return {}; @@ -340,16 +340,20 @@ struct pointwise_reshape : rewrite_reshapes_base static std::string name() { return "pointwise"; } }; -struct pointwise_broadcast_pointwise +struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes { auto matcher() const { + auto pointwise = match::name("pointwise")(match::used_once()).bind("x"); auto broadcast_pointwise = - match::name("multibroadcast")( - match::used_once(), - match::args(match::name("pointwise")(match::used_once()).bind("x"))) + match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(1)(pointwise)) + .bind("broadcast"); + return match::name("pointwise")(match::any_of[match::inputs()]( + match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..f15eeabeea8 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -63,15 +63,19 @@ struct fused_reduce if(not sm->bypass()) MIGRAPHX_THROW("fused_reduce: bypass flag is not set"); auto names = sm->get_parameter_names(); - check_shapes{inputs, *this}.has(names.size()).same_ndims(); + check_shapes{inputs, *this, true}.has(names.size()).same_ndims(); std::sort(names.begin(), names.end()); auto shapes = sm->get_parameter_shapes(); // Check dimension matches for each input if(not equal(names, inputs, [&](const auto& name, const auto& input) { - return shapes.at(name).lens() == input.lens(); + auto s = shapes.at(name); + return shape::same_lens(input, s); })) MIGRAPHX_THROW("Input dimension does not match the submodule."); + if(sm->get_output_shapes().front().dynamic()) + return sm->get_output_shapes().front(); + return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), find_permutation(inputs)); diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 51ca75ee92b..9d16879cf35 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -61,7 +61,10 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? std::vector{} : inputs.front().lens(); const auto rank = inputs.front().ndim(); const bool has_broadcasts = @@ -69,7 +72,7 @@ struct pointwise auto result = pm->compute_shapes( (rank > 1 and has_broadcasts) ? remove_broadcasts(inputs) : inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 22f6d1663d6..64b510176c3 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -359,6 +359,8 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const shape& x); + static bool same_lens(const shape& x, const shape& y); + template struct as { diff --git a/src/shape.cpp b/src/shape.cpp index 678bf5b53f0..6698ff126fc 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -828,6 +828,19 @@ std::ostream& operator<<(std::ostream& os, const shape& x) return os; } +bool shape::same_lens(const shape& x, const shape& y) +{ + if(x.dynamic() and y.dynamic()) + { + return x.dyn_dims() == y.dyn_dims(); + } + else if(x.dynamic() or y.dynamic()) + { + MIGRAPHX_THROW("SHAPE: same_lens() called on mixed dynamic and static shapes"); + } + return x.lens() == y.lens(); +} + shape::type_t shape::parse_type(const std::string& s) { static const std::unordered_map m = { diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 44e08e63c65..08a1950fc63 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -67,6 +67,35 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add")); + mm->add_return({add2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(double_add) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..589c3c07cb1 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -63,6 +63,31 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y); + mm->add_return({rsum1, rsum2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum")); + auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum")); + mm->add_return({rsum1, rsum2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(pointwise_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp index 81d87566893..6583bb900eb 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -61,9 +61,16 @@ migraphx::module_ref add_reduce_module(migraphx::program& p, rm->set_bypass(); std::vector params; std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + migraphx::shape s; + if(input->get_shape().dynamic()) + { + s = input->get_shape(); + } + else + { + s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()}; + } + return rm->add_parameter("x" + std::to_string(params.size()), s); }); auto r = f(rm, params, axes); auto_add_return(rm, r);