Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins)
if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(s.elements() != 1 and not(s.scalar()))
if(s.dynamic() or (s.elements() != 1 and not(s.scalar())))
return {};
if(not ins->can_eval())
return {};
Expand Down Expand Up @@ -340,16 +340,20 @@ struct pointwise_reshape : rewrite_reshapes_base
static std::string name() { return "pointwise"; }
};

struct pointwise_broadcast_pointwise
struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes
{
auto matcher() const
{
auto pointwise = match::name("pointwise")(match::used_once()).bind("x");
auto broadcast_pointwise =
match::name("multibroadcast")(
match::used_once(),
match::args(match::name("pointwise")(match::used_once()).bind("x")))
match::name("multibroadcast")(match::used_once(), match::args(pointwise))
.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(),
match::nargs(2),
match::arg(1)(pointwise))
Copy link

Copilot AI Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The matcher for dynamic broadcast pointwise appears to be checking the wrong argument index. For the 2-input dynamic multibroadcast operation, arg(0) is the input to broadcast and arg(1) is a reference tensor. The matcher should use match::arg(0)(pointwise) instead of match::arg(1)(pointwise) to correctly identify when the broadcasted input is a pointwise operation.

Suggested change
match::arg(1)(pointwise))
match::arg(0)(pointwise))

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment by copilot is right. Why is this looking at arg(1) of the 2 input multibroadcast?

.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](
match::any_of(broadcast_pointwise, dyn_broadcast_pointwise)));
}

void apply(module& m, const match::matcher_result& r) const
Expand Down
10 changes: 7 additions & 3 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -63,15 +63,19 @@ struct fused_reduce
if(not sm->bypass())
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
auto names = sm->get_parameter_names();
check_shapes{inputs, *this}.has(names.size()).same_ndims();
check_shapes{inputs, *this, true}.has(names.size()).same_ndims();
std::sort(names.begin(), names.end());
auto shapes = sm->get_parameter_shapes();
// Check dimension matches for each input
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
return shapes.at(name).lens() == input.lens();
auto s = shapes.at(name);
return shape::same_lens(input, s);
}))
MIGRAPHX_THROW("Input dimension does not match the submodule.");

if(sm->get_output_shapes().front().dynamic())
return sm->get_output_shapes().front();

return shape::from_permutation(sm->get_output_shapes().front().type(),
sm->get_output_shapes().front().lens(),
find_permutation(inputs));
Expand Down
9 changes: 6 additions & 3 deletions src/include/migraphx/op/pointwise.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -61,15 +61,18 @@ struct pointwise
MIGRAPHX_THROW("pointwise should have at least one input");
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
check_shapes{inputs, *this, true}.has(pnames.size()).same_dims();

std::vector<std::size_t> scalar_const_out_lens =
inputs.front().dynamic() ? std::vector<std::size_t>{} : inputs.front().lens();

const auto rank = inputs.front().ndim();
const bool has_broadcasts =
std::any_of(inputs.begin(), inputs.end(), [](auto s) { return s.broadcasted(); });

auto result = pm->compute_shapes(
(rank > 1 and has_broadcasts) ? remove_broadcasts(inputs) : inputs,
{.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()});
{.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens});
if(result.size() == 1)
return result.front();
return shape{result};
Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -359,6 +359,8 @@ struct MIGRAPHX_EXPORT shape
MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y);
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const shape& x);

static bool same_lens(const shape& x, const shape& y);

template <class T>
struct as
{
Expand Down
15 changes: 14 additions & 1 deletion src/shape.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -828,6 +828,19 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os;
}

bool shape::same_lens(const shape& x, const shape& y)
{
if(x.dynamic() and y.dynamic())
{
return x.dyn_dims() == y.dyn_dims();
}
else if(x.dynamic() or y.dynamic())
{
MIGRAPHX_THROW("SHAPE: same_lens() called on mixed dynamic and static shapes");
}
return x.lens() == y.lens();
}

shape::type_t shape::parse_type(const std::string& s)
{
static const std::unordered_map<std::string, shape::type_t> m = {
Expand Down
31 changes: 30 additions & 1 deletion test/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -67,6 +67,35 @@ TEST_CASE(single)
EXPECT(p1 == p2);
}

TEST_CASE(single_dyn)
{
migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}};
Copy link

Copilot AI Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dynamic shape definition appears to have invalid min/max values. For dimension 1, min=4 but max=3, which is impossible since the minimum value cannot exceed the maximum. The parameters should likely be swapped, e.g., mins={1, 3} and maxes={4, 3}, or corrected to valid ranges like mins={1, 3} and maxes={3, 4}.

Suggested change
migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}};
migraphx::shape s{migraphx::shape::float_type, {1, 3}, {3, 4}, {}};

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot is right, this is using shape(type_t t, std::vector<std::size_t> mins, std::vector<std::size_t> maxes, std::vector<std::set<std::size_t>> optimals_list);. I would use the simpler to understand constructor for a dynamic shape of shape(float_type, {{1, 3}, {3, 4}})

migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z);
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add"));
auto pass = mm->add_instruction(pass_op{}, add1);
auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add"));
mm->add_return({add2});
}
EXPECT(p1 == p2);
}

TEST_CASE(double_add)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
Expand Down
27 changes: 26 additions & 1 deletion test/fuse_reduce.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -63,6 +63,31 @@ TEST_CASE(single)
EXPECT(p1 == p2);
}

TEST_CASE(single_dyn)
{
migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}};
Copy link

Copilot AI Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dynamic shape definition appears to have invalid min/max values. For dimension 1, min=4 but max=3, which is impossible since the minimum value cannot exceed the maximum. The parameters should likely be swapped, e.g., mins={1, 3} and maxes={4, 3}, or corrected to valid ranges like mins={1, 3} and maxes={3, 4}.

Suggested change
migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}};
migraphx::shape s{migraphx::shape::float_type, {1, 3}, {4, 3}, {}};

Copilot uses AI. Check for mistakes.
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y);
mm->add_return({rsum1, rsum2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum"));
mm->add_return({rsum1, rsum2});
}
EXPECT(p1 == p2);
}

TEST_CASE(pointwise_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
Expand Down
15 changes: 11 additions & 4 deletions test/include/reduce.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -61,9 +61,16 @@ migraphx::module_ref add_reduce_module(migraphx::program& p,
rm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return rm->add_parameter(
"x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
migraphx::shape s;
if(input->get_shape().dynamic())
{
s = input->get_shape();
}
else
{
s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()};
}
Comment on lines +65 to +72
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(input->get_shape().dynamic())
{
s = input->get_shape();
}
else
{
s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()};
}
s = input.as_standard();

return rm->add_parameter("x" + std::to_string(params.size()), s);
});
auto r = f(rm, params, axes);
auto_add_return(rm, r);
Expand Down
Loading