Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
138 commits
Select commit Hold shift + click to select a range
0c853ef
Rewrite gather
pfultz2 Oct 2, 2025
4abca29
Format
pfultz2 Oct 2, 2025
f0048fc
Add unit tests
pfultz2 Oct 2, 2025
852e0b5
Support nearest resize
pfultz2 Oct 3, 2025
ac28112
Format
pfultz2 Oct 3, 2025
cd3aeb5
Support linear indexing
pfultz2 Oct 3, 2025
3913fc1
Format
pfultz2 Oct 3, 2025
24ab281
Handle over channels
pfultz2 Oct 3, 2025
b87e980
Format
pfultz2 Oct 3, 2025
291add2
Handle 4 factor
pfultz2 Oct 4, 2025
2d9601b
Format
pfultz2 Oct 4, 2025
8a96c59
Wider coverage
pfultz2 Oct 4, 2025
04607d9
Format
pfultz2 Oct 4, 2025
aafd973
Support a split concat
pfultz2 Oct 4, 2025
cd08cbc
Format
pfultz2 Oct 4, 2025
2b9b845
Add gather verify tests
pfultz2 Oct 5, 2025
d40e31e
Format
pfultz2 Oct 5, 2025
2ec87b5
Handle nearest downsample
pfultz2 Oct 5, 2025
4921233
Format
pfultz2 Oct 5, 2025
5b9fb63
Handle by 4 downsample
pfultz2 Oct 5, 2025
fc9c9ca
Format
pfultz2 Oct 5, 2025
1b45f3b
Handle multiple axes and larger sizes
pfultz2 Oct 5, 2025
8083ca0
Format
pfultz2 Oct 5, 2025
4b6f626
Add more tests
pfultz2 Oct 5, 2025
3a87046
Format
pfultz2 Oct 5, 2025
48650ab
Refactor into strategy classes
pfultz2 Oct 7, 2025
43bdce9
Format
pfultz2 Oct 7, 2025
5b065ba
Refactor
pfultz2 Oct 8, 2025
f3a4dda
Format
pfultz2 Oct 8, 2025
4adff3f
Add rectangular pattern
pfultz2 Oct 8, 2025
e060da6
Format
pfultz2 Oct 8, 2025
097db77
Some more cleanup
pfultz2 Oct 8, 2025
ba45916
Format
pfultz2 Oct 8, 2025
6d35650
Reduce searching
pfultz2 Oct 8, 2025
ff59c19
Format
pfultz2 Oct 8, 2025
93682e0
Refactor
pfultz2 Oct 9, 2025
1910cc6
Format
pfultz2 Oct 9, 2025
893a310
Remove segment type enum
pfultz2 Oct 9, 2025
171a15d
Format
pfultz2 Oct 9, 2025
968b794
Format
pfultz2 Oct 9, 2025
c6ca6ec
Format
pfultz2 Oct 9, 2025
18321a3
Remove rectangular pattern
pfultz2 Oct 10, 2025
0a02c3e
Format
pfultz2 Oct 10, 2025
79e9817
Handle grid tiles
pfultz2 Oct 10, 2025
da348a5
Format
pfultz2 Oct 10, 2025
8ac727f
Add multi-axis
pfultz2 Oct 10, 2025
c9340ca
Format
pfultz2 Oct 10, 2025
547cf62
Refactor target shape
pfultz2 Oct 10, 2025
4096a52
Format
pfultz2 Oct 10, 2025
f1e880a
Refactor gather context
pfultz2 Oct 13, 2025
8de9e1d
Format
pfultz2 Oct 13, 2025
1d8d4fc
Merge branch 'develop' into codex-gather-rewrite
pfultz2 Oct 13, 2025
bb74fd2
Add verify tests
pfultz2 Oct 14, 2025
158204e
Format
pfultz2 Oct 14, 2025
db3701b
Re-eanble the tests
pfultz2 Oct 14, 2025
4adc54b
Format
pfultz2 Oct 14, 2025
b71c512
Update flatten tests
pfultz2 Oct 14, 2025
ced0567
Format
pfultz2 Oct 14, 2025
f6f1bc9
Fix param shape
pfultz2 Oct 14, 2025
e335dfc
Insert squeeze/unsqueeze
pfultz2 Oct 15, 2025
229bbf3
Format
pfultz2 Oct 15, 2025
cb3554b
Normalize gather
pfultz2 Oct 15, 2025
7b8fa2d
Format
pfultz2 Oct 15, 2025
3da78a6
Add more unit tests
pfultz2 Oct 15, 2025
d61fb22
Format
pfultz2 Oct 15, 2025
dd6167c
Try to use slice for step
pfultz2 Oct 15, 2025
800c055
Format
pfultz2 Oct 15, 2025
d921d51
Fix slice with step
pfultz2 Oct 16, 2025
2eb0c96
Remove the repetive meta class
pfultz2 Oct 16, 2025
482a633
Format
pfultz2 Oct 16, 2025
310ca2a
1d tests pass
pfultz2 Oct 20, 2025
bfe6d6b
Format
pfultz2 Oct 20, 2025
7a8b7ea
Fix rtr_window 1d
pfultz2 Oct 21, 2025
c443573
Format
pfultz2 Oct 21, 2025
981208e
Add comment
pfultz2 Oct 22, 2025
e922cdb
Use tensor_view instead
pfultz2 Oct 24, 2025
a38e01f
FOrmat
pfultz2 Oct 24, 2025
31e3fbf
Update segmenting of indices
pfultz2 Oct 25, 2025
9f6fc2f
Format
pfultz2 Oct 25, 2025
7bb12be
Use stride-based view to create ops
pfultz2 Oct 27, 2025
4c2d481
Format
pfultz2 Oct 27, 2025
28043a1
Skip overlapping
pfultz2 Oct 30, 2025
f212f31
Move shape generation to shape transform desc
pfultz2 Oct 31, 2025
584a424
Format
pfultz2 Oct 31, 2025
1289f27
Move to shape_transform_descriptor
pfultz2 Oct 31, 2025
973975c
Format
pfultz2 Oct 31, 2025
04c0f84
Add offset test
pfultz2 Oct 31, 2025
f4261ca
Format
pfultz2 Oct 31, 2025
64ee0a4
Improve offset handling
pfultz2 Nov 1, 2025
fe5c15c
Format
pfultz2 Nov 1, 2025
bc83f5d
Add gather verify tests
pfultz2 Nov 1, 2025
0a24050
Format
pfultz2 Nov 1, 2025
87825c2
Remove debug prints
pfultz2 Nov 1, 2025
2c74c7f
Format
pfultz2 Nov 1, 2025
bc18bfd
Cleanup
pfultz2 Nov 1, 2025
bcf56b9
Format
pfultz2 Nov 1, 2025
b99aebe
Add unit test for common dims
pfultz2 Nov 2, 2025
df30b2d
Fix common dims
pfultz2 Nov 2, 2025
17a78a9
Format
pfultz2 Nov 2, 2025
5b0bbc4
Some more cleanup
pfultz2 Nov 2, 2025
fb4fe9e
Format
pfultz2 Nov 2, 2025
82288e2
Only find uniform segments
pfultz2 Nov 3, 2025
72286ab
Make algo parallel
pfultz2 Nov 3, 2025
545a69c
Format
pfultz2 Nov 3, 2025
9c15621
Rewrite gather once
pfultz2 Nov 3, 2025
47996d1
Format
pfultz2 Nov 3, 2025
b4d1100
Always use standard shape
pfultz2 Nov 3, 2025
dfcdfaf
Add flag for uniform
pfultz2 Nov 4, 2025
82939f3
Merge
pfultz2 Jan 14, 2026
519aa73
Format
pfultz2 Jan 14, 2026
b4409fc
Fix tests
pfultz2 Jan 14, 2026
25c0fe8
Format
pfultz2 Jan 14, 2026
3acf7ac
Remove unused functions
pfultz2 Jan 14, 2026
87a7636
Disable matcher
pfultz2 Jan 14, 2026
3c50695
Fix hang
pfultz2 Jan 14, 2026
a11c704
Format
pfultz2 Jan 14, 2026
78565eb
Simplify more
pfultz2 Jan 14, 2026
393c90b
Renable test case
pfultz2 Jan 14, 2026
98f6c6f
Format
pfultz2 Jan 14, 2026
49e7771
Remove commented tests
pfultz2 Jan 14, 2026
d542141
Disable find_where
pfultz2 Jan 14, 2026
71d2fba
Remove mutable var
pfultz2 Jan 14, 2026
cc110e2
Refactor
pfultz2 Jan 14, 2026
aa8851d
Format
pfultz2 Jan 14, 2026
e341fd6
Still simplify shape transforms
pfultz2 Jan 15, 2026
b2ebdb6
Format
pfultz2 Jan 15, 2026
9f12141
Format
pfultz2 Jan 15, 2026
c63a3f6
Update tests
pfultz2 Jan 15, 2026
47129c1
Format
pfultz2 Jan 15, 2026
1062c6f
Merge branch 'codex-gather-rewrite' into codex-gather-rewrite2
pfultz2 Jan 15, 2026
86d434e
Update tests
pfultz2 Jan 16, 2026
386cc41
Move into find_gather class
pfultz2 Jan 16, 2026
19cdfed
Format
pfultz2 Jan 16, 2026
1eb0fca
Some more cleanup
pfultz2 Jan 16, 2026
70ac01c
Improve comment
pfultz2 Jan 16, 2026
a1d5e6f
Fix tidy
pfultz2 Jan 16, 2026
2c20298
Use insert_ops
pfultz2 Jan 16, 2026
ce3bae5
Fix compilation error
pfultz2 Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 65 additions & 8 deletions src/common_dims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,66 @@ static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
assert(state1.get() < state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);

if(naxes == 0)
return false;

// Check if state1 has a remainder from previous split
bool has_remainder = (state1.rem != 1);

// Compute the product of dimensions, adjusting for remainder if needed
auto n = elements(dims);
if(has_remainder and naxes > 0)
{
n = n / *dims.begin() * (*dims.begin() / state1.rem);
}

// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return false;

auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
auto start_pos = cd_dims.size();

// Add axes mappings
if(has_remainder)
{
// state1: dimension was split, keep axes together
state1.add_axes(naxes, start_pos);
// state2: axes should include the previous remainder dimension
state2.add_axes(rem == 1 ? naxes : naxes + 1, start_pos - 1);
}
else
{
// state1: separate axes for each dimension
state1.add_multi_axes(naxes, start_pos);
// state2: normal axes mapping
state2.add_axes(rem == 1 ? naxes : naxes + 1, start_pos);
}

// Add dimensions to cd_dims
if(has_remainder and naxes > 0)
{
// Adjust the first dimension by dividing by the remainder
cd_dims.push_back(*dims.begin() / state1.rem);
cd_dims.insert(cd_dims.end(), std::next(dims.begin()), dims.end());
}
else
{
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
}

// Add remainder dimension if needed
if(rem != 1)
cd_dims.push_back(rem);

// Update states
state1.rem = rem;
state2.rem = 1;

cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state1.next(naxes);
state2.next();

return true;
}

Expand Down Expand Up @@ -152,6 +193,22 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
return {};
}
}

// Handle case where one state has a remainder that equals the next dimension
// In this case, the dimension was already added as a remainder, we just need the axes mapping
auto handle_remaining_dimension = [&cd](common_dim_state& state) {
if(not state.is_end() and state.rem != 1 and state.get() == 1)
{
// The remainder already added to cd_dims matches this dimension
// Add a single axes mapping
state.axes_map->push_back({cd.dims.size() - 1});
state.next();
}
};

handle_remaining_dimension(state1);
handle_remaining_dimension(state2);

assert(elements(dims1) == elements(cd.dims));
return cd;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/migraphx/fuse_attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct module_pass_manager;

struct MIGRAPHX_EXPORT fuse_attention
{
bool attn_enabled = false;
bool attn_enabled = false;
bool flash_decoding_enabled = false;
std::size_t flash_decoding_num_splits = 0;
std::size_t flash_decoding_threshold = 32;
Expand Down
7 changes: 7 additions & 0 deletions src/include/migraphx/shape_transform_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/shape.hpp>
#include <cstdint>
#include <iosfwd>
#include <set>
Expand Down Expand Up @@ -106,6 +107,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
std::vector<std::vector<std::size_t>> common_axes_map_from_src() const;
std::vector<std::vector<std::size_t>> common_axes_map_from_dst() const;

std::vector<std::size_t> get_dst_axes_from_src(std::size_t axis) const;

bool empty() const;
std::vector<std::size_t> lens() const;

Expand Down Expand Up @@ -158,6 +161,10 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
MIGRAPHX_EXPORT std::vector<operation>
optimize_shape_transforms(const std::vector<std::size_t>& dims, const std::vector<operation>& ops);

// Generate the shape transforms for strided view
MIGRAPHX_EXPORT optional<std::vector<operation>>
generate_shape_transforms_for(shape s, const std::vector<std::size_t>& idims, std::int64_t offset);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SHAPE_TRANSFORM_DESCRIPTOR_HPP
1 change: 1 addition & 0 deletions src/include/migraphx/simplify_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct MIGRAPHX_EXPORT simplify_reshapes
{
size_t depth = 4;
bool enable_op_shape_transform_op = false;
bool enable_gather_rewrite = false;
std::string name() const { return "simplify_reshapes"; }
void apply(module& m) const;
};
Expand Down
2 changes: 1 addition & 1 deletion src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ module::print_py(std::ostream& os,
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
if(ins->get_shape().elements() < 10)
if(ins->get_shape().elements() < 1024)
{
os << "migraphx.create_argument(";
print_py_shape(os, ins->get_shape());
Expand Down
2 changes: 1 addition & 1 deletion src/pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ struct module_pm : module_pass_manager
catch(const std::exception& e)
{
std::cerr << "Error " << p.name() << ": " << e.what() << std::endl;
auto clk = std::chrono::steady_clock::now().time_since_epoch().count();
auto clk = std::chrono::steady_clock::now().time_since_epoch().count();
fs::path dirname = fs::temp_directory_path() / "migraphx";
fs::create_directories(dirname);
std::string base = p.name() + std::to_string(clk) + ".mxr";
Expand Down
226 changes: 226 additions & 0 deletions src/shape_transform_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,28 @@
return result;
}

std::vector<std::size_t> shape_transform_descriptor::get_dst_axes_from_src(std::size_t axis) const
{
std::vector<std::size_t> result;
for(auto i : range(dimensions.size()))
{
const auto& d = dimensions[i];
auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), [&](auto& s) {
if(s.axis.empty())
return false;
return s.axis.front() == axis;
});
if(it == d.subdimensions.end())
continue;
// If it maps to a subdimesion then exit as there isnt a clear mapping
if(d.len() != it->len)
return {};
result.push_back(i);
}
// TODO: Put it in the correct order if there is multiple axes
return result;
}

bool shape_transform_descriptor::empty() const { return dimensions.empty(); }

std::vector<std::size_t> shape_transform_descriptor::lens() const
Expand Down Expand Up @@ -2011,5 +2033,209 @@
return sd.generate();
}

// Replace broadcasted dimensions with size 1, and set the stride to the previous stride
static shape unbroadcast(const shape& s)
{
std::vector<std::size_t> lens = s.lens();
std::vector<std::size_t> strides = s.strides();
auto stride_it = std::find_if(
s.strides().begin(), s.strides().end(), [](auto stride) { return stride != 0; });
std::size_t prev_stride = stride_it == s.strides().end() ? 1 : *stride_it;
for(std::size_t i = 0; i < lens.size(); ++i)
{
if(strides[i] == 0)
{
lens[i] = 1;
strides[i] = prev_stride;
}
else
{
prev_stride = strides[i];
}
}
return {s.type(), lens, strides};
}

static std::size_t adjust_strided_shape(shape& s, std::size_t n)
{
auto lens = s.lens();
auto strides = s.strides();

// Insert a dim of 1 so it can be used to handle steps
if(std::none_of(strides.begin(), strides.end(), [](auto stride) { return stride == 1; }) and
std::any_of(strides.begin(), strides.end(), [](auto stride) { return stride != 0; }))
{
lens.push_back(1);
strides.push_back(1);
}

auto last_axis = std::max_element(strides.begin(), strides.end()) - strides.begin();
auto total_elements = std::max<std::size_t>(1, strides[last_axis] * lens[last_axis]);
// Add a dim of 1 to the front so it can handle extra elements
auto extra = n / total_elements;
if(extra > 1)
{
strides.insert(strides.begin(), total_elements);
lens.insert(lens.begin(), 1);
}
s = shape(s.type(), lens, strides);
return std::max<std::size_t>(1, extra);
}

// Generate the shape transforms for strided view
optional<std::vector<operation>>
generate_shape_transforms_for(shape s, const std::vector<std::size_t>& idims, std::int64_t offset)

Check warning on line 2087 in src/shape_transform_descriptor.cpp

View workflow job for this annotation

GitHub Actions / tidy

function 'generate_shape_transforms_for' exceeds recommended size/complexity thresholds [readability-function-size,-warnings-as-errors]
{
std::vector<operation> result;
if(s.lens().empty())
return std::nullopt;
std::size_t ielements =
std::accumulate(idims.begin(), idims.end(), std::size_t(1), std::multiplies<>());
auto extra = adjust_strided_shape(s, ielements);
// TODO: Improve handling of multiple dimensions, for now just reshape to 1 dimension
if(idims.size() != 1)
{
result.push_back(make_op("reshape", {{"dims", {ielements}}}));
auto ops = generate_shape_transforms_for(s, {ielements}, offset);
if(not ops)
return std::nullopt;
result.insert(result.end(), ops->begin(), ops->end());
return result;
}
auto pre_broadcast = unbroadcast(s);
auto perm = find_permutation(pre_broadcast);
auto iperm = invert_permutation(perm);
auto pre_transpose = reorder_shape(pre_broadcast, perm);

std::vector<std::size_t> start_lens;
std::adjacent_difference(pre_transpose.strides().begin(),
pre_transpose.strides().end(),
std::back_inserter(start_lens),
[](auto y, auto x) -> std::size_t {
assert(x >= y);
assert(y != 0);
if((x % y) != 0)
return 0;
return x / y;
});
if(std::any_of(start_lens.begin(), start_lens.end(), [](auto len) { return len == 0; }))
return std::nullopt;
start_lens.front() = extra > 1 ? extra : pre_transpose.lens().front();

std::size_t nelements =
std::accumulate(start_lens.begin(), start_lens.end(), std::size_t(1), std::multiplies<>());

if(nelements < pre_transpose.elements() * extra)
return std::nullopt;

std::vector<std::size_t> start_mask(start_lens.size(), 0);
if(offset != 0)
{
shape start_shape{shape::float_type, start_lens};
auto idx = start_shape.multi(offset);

std::vector<std::size_t> overhead;
std::transform(start_lens.begin(),
start_lens.end(),
pre_transpose.lens().begin(),
std::back_inserter(overhead),
[](auto start_len, auto len) { return start_len - len; });
if(std::equal(
idx.begin(), idx.end(), overhead.begin(), overhead.end(), [](auto i, auto over) {
return i <= over;
}))
{
start_mask = reorder_dims(idx, iperm);
offset = 0;
}
}

std::vector<std::size_t> pre_slice_mask;
std::transform(start_lens.begin(),
start_lens.end(),
pre_transpose.lens().begin(),
std::back_inserter(pre_slice_mask),
[](auto start_len, auto len) -> std::size_t {
if(start_len == len)
return 0;
return len;
});
auto slice_mask = reorder_dims(pre_slice_mask, iperm);

std::vector<std::size_t> blens = reorder_dims(start_lens, iperm);
std::transform(s.lens().begin(),
s.lens().end(),
blens.begin(),
blens.begin(),
[](auto len, auto blen) -> std::size_t {
if(blen == 1)
return len;
return blen;
});

std::vector<operation> ops;
ops.push_back(make_op("multibroadcast", {{"out_lens", blens}}));
ops.push_back(make_op("transpose", {{"permutation", invert_permutation(perm)}}));
ops.push_back(make_op("reshape", {{"dims", start_lens}}));
std::reverse(ops.begin(), ops.end());

auto desc = shape_transform_descriptor::create({nelements}, ops);

auto end = offset + nelements;
if(offset != 0 or nelements != ielements)
{

// If the end is out of bounds broadcast it to pad it
if(end > ielements)
{
result.push_back(make_op("broadcast", {{"axis", 1}, {"out_lens", {2, ielements}}}));
result.push_back(make_op("reshape", {{"dims", {2 * ielements}}}));
}
result.push_back(make_op("slice", {{"axes", {0}}, {"starts", {offset}}, {"ends", {end}}}));
}

auto opt_ops = desc.generate();
result.insert(result.end(), opt_ops.begin(), opt_ops.end());

std::vector<std::size_t> axes;
std::transform(slice_mask.begin(),
slice_mask.end(),
range(slice_mask.size()).begin(),
join_back_inserter(axes),
[](std::size_t mask, std::size_t idx) -> std::vector<std::size_t> {
if(mask > 0)
return {idx};
return {};
});

if(not axes.empty())
{
std::vector<std::size_t> starts;
std::transform(slice_mask.begin(),
slice_mask.end(),
start_mask.begin(),
join_back_inserter(starts),
[](std::size_t mask, std::size_t start) -> std::vector<std::size_t> {
if(mask == 0)
return {};
return {start};
});
std::vector<std::size_t> ends;
std::transform(slice_mask.begin(),
slice_mask.end(),
s.lens().begin(),
join_back_inserter(ends),
[](std::size_t mask, std::size_t len) -> std::vector<std::size_t> {
if(mask == 0)
return {};
return {len};
});
std::transform(ends.begin(), ends.end(), starts.begin(), ends.begin(), std::plus<>{});

result.push_back(make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}));
}
return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Loading
Loading