-
Notifications
You must be signed in to change notification settings - Fork 113
[AIMIGRAPHX-231] Enable dynamic shapes for pointwise and reduce fusions #4534
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||
| /* | ||||||
| * The MIT License (MIT) | ||||||
| * | ||||||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||||||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||||||
| * | ||||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| * of this software and associated documentation files (the "Software"), to deal | ||||||
|
|
@@ -67,6 +67,35 @@ TEST_CASE(single) | |||||
| EXPECT(p1 == p2); | ||||||
| } | ||||||
|
|
||||||
| TEST_CASE(single_dyn) | ||||||
| { | ||||||
| migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; | ||||||
|
||||||
| migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; | |
| migraphx::shape s{migraphx::shape::float_type, {1, 3}, {3, 4}, {}}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot is right, this is using shape(type_t t, std::vector<std::size_t> mins, std::vector<std::size_t> maxes, std::vector<std::set<std::size_t>> optimals_list);. I would use the simpler to understand constructor for a dynamic shape of shape(float_type, {{1, 3}, {3, 4}})
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||
| /* | ||||||
| * The MIT License (MIT) | ||||||
| * | ||||||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||||||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||||||
| * | ||||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| * of this software and associated documentation files (the "Software"), to deal | ||||||
|
|
@@ -63,6 +63,31 @@ TEST_CASE(single) | |||||
| EXPECT(p1 == p2); | ||||||
| } | ||||||
|
|
||||||
| TEST_CASE(single_dyn) | ||||||
| { | ||||||
| migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; | ||||||
|
||||||
| migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; | |
| migraphx::shape s{migraphx::shape::float_type, {1, 3}, {4, 3}, {}}; |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||||||||||||||||||
| /* | ||||||||||||||||||||||
| * The MIT License (MIT) | ||||||||||||||||||||||
| * | ||||||||||||||||||||||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||||||||||||||||||||||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||||||||||||||||||||||
| * | ||||||||||||||||||||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||||||||||||||||||
| * of this software and associated documentation files (the "Software"), to deal | ||||||||||||||||||||||
|
|
@@ -61,9 +61,16 @@ migraphx::module_ref add_reduce_module(migraphx::program& p, | |||||||||||||||||||||
| rm->set_bypass(); | ||||||||||||||||||||||
| std::vector<migraphx::instruction_ref> params; | ||||||||||||||||||||||
| std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { | ||||||||||||||||||||||
| return rm->add_parameter( | ||||||||||||||||||||||
| "x" + std::to_string(params.size()), | ||||||||||||||||||||||
| migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); | ||||||||||||||||||||||
| migraphx::shape s; | ||||||||||||||||||||||
| if(input->get_shape().dynamic()) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| s = input->get_shape(); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()}; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
Comment on lines
+65
to
+72
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
| return rm->add_parameter("x" + std::to_string(params.size()), s); | ||||||||||||||||||||||
| }); | ||||||||||||||||||||||
| auto r = f(rm, params, axes); | ||||||||||||||||||||||
| auto_add_return(rm, r); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The matcher for dynamic broadcast pointwise appears to be checking the wrong argument index. For the 2-input dynamic multibroadcast operation, arg(0) is the input to broadcast and arg(1) is a reference tensor. The matcher should use match::arg(0)(pointwise) instead of match::arg(1)(pointwise) to correctly identify when the broadcasted input is a pointwise operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment by copilot is right. Why is this looking at arg(1) of the 2 input multibroadcast?