diff --git a/src/common_dims.cpp b/src/common_dims.cpp index 1afe92087fd..6dbe5f95167 100644 --- a/src/common_dims.cpp +++ b/src/common_dims.cpp @@ -97,25 +97,66 @@ static bool compute_common_dim(std::vector& cd_dims, assert(state1.get() < state2.get()); auto d2 = state2.get(); auto dims = state1.dims_for(d2); - auto n = elements(dims); auto naxes = distance(dims); + if(naxes == 0) return false; + + // Check if state1 has a remainder from previous split + bool has_remainder = (state1.rem != 1); + + // Compute the product of dimensions, adjusting for remainder if needed + auto n = elements(dims); + if(has_remainder and naxes > 0) + { + n = n / *dims.begin() * (*dims.begin() / state1.rem); + } + // If not divisible then we can't compute a common dim if((d2 % n) != 0) return false; + auto rem = d2 / n; - state1.add_multi_axes(naxes, cd_dims.size()); - state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size()); + auto start_pos = cd_dims.size(); + // Add axes mappings + if(has_remainder) + { + // state1: dimension was split, keep axes together + state1.add_axes(naxes, start_pos); + // state2: axes should include the previous remainder dimension + state2.add_axes(rem == 1 ? naxes : naxes + 1, start_pos - 1); + } + else + { + // state1: separate axes for each dimension + state1.add_multi_axes(naxes, start_pos); + // state2: normal axes mapping + state2.add_axes(rem == 1 ? 
naxes : naxes + 1, start_pos); + } + + // Add dimensions to cd_dims + if(has_remainder and naxes > 0) + { + // Adjust the first dimension by dividing by the remainder + cd_dims.push_back(*dims.begin() / state1.rem); + cd_dims.insert(cd_dims.end(), std::next(dims.begin()), dims.end()); + } + else + { + cd_dims.insert(cd_dims.end(), dims.begin(), dims.end()); + } + + // Add remainder dimension if needed + if(rem != 1) + cd_dims.push_back(rem); + + // Update states state1.rem = rem; state2.rem = 1; - - cd_dims.insert(cd_dims.end(), dims.begin(), dims.end()); - if(state1.rem != 1) - cd_dims.push_back(state1.rem); - state1.next(distance(dims)); + state1.next(naxes); state2.next(); + return true; } @@ -152,6 +193,22 @@ common_dims common_dims::compute(const std::vector& dims1, return {}; } } + + // Handle case where one state has a remainder that equals the next dimension + // In this case, the dimension was already added as a remainder, we just need the axes mapping + auto handle_remaining_dimension = [&cd](common_dim_state& state) { + if(not state.is_end() and state.rem != 1 and state.get() == 1) + { + // The remainder already added to cd_dims matches this dimension + // Add a single axes mapping + state.axes_map->push_back({cd.dims.size() - 1}); + state.next(); + } + }; + + handle_remaining_dimension(state1); + handle_remaining_dimension(state2); + assert(elements(dims1) == elements(cd.dims)); return cd; } diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index 1b9329b809c..abdba8c711d 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -37,7 +37,7 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_attention { - bool attn_enabled = false; + bool attn_enabled = false; bool flash_decoding_enabled = false; std::size_t flash_decoding_num_splits = 0; std::size_t flash_decoding_threshold = 32; diff --git a/src/include/migraphx/shape_transform_descriptor.hpp 
b/src/include/migraphx/shape_transform_descriptor.hpp index c8d42119b98..458501c6328 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -106,6 +107,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor std::vector> common_axes_map_from_src() const; std::vector> common_axes_map_from_dst() const; + std::vector get_dst_axes_from_src(std::size_t axis) const; + bool empty() const; std::vector lens() const; @@ -158,6 +161,10 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor MIGRAPHX_EXPORT std::vector optimize_shape_transforms(const std::vector& dims, const std::vector& ops); +// Generate the shape transforms for strided view +MIGRAPHX_EXPORT optional> +generate_shape_transforms_for(shape s, const std::vector& idims, std::int64_t offset); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_SHAPE_TRANSFORM_DESCRIPTOR_HPP diff --git a/src/include/migraphx/simplify_reshapes.hpp b/src/include/migraphx/simplify_reshapes.hpp index 9c02dc9c00d..269aa6e4b30 100644 --- a/src/include/migraphx/simplify_reshapes.hpp +++ b/src/include/migraphx/simplify_reshapes.hpp @@ -40,6 +40,7 @@ struct MIGRAPHX_EXPORT simplify_reshapes { size_t depth = 4; bool enable_op_shape_transform_op = false; + bool enable_gather_rewrite = false; std::string name() const { return "simplify_reshapes"; } void apply(module& m) const; }; diff --git a/src/module.cpp b/src/module.cpp index 4838d241904..b529f2b3c08 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1320,7 +1320,7 @@ module::print_py(std::ostream& os, if(ins->name() == "@literal") { os << mname << ".add_literal("; - if(ins->get_shape().elements() < 10) + if(ins->get_shape().elements() < 1024) { os << "migraphx.create_argument("; print_py_shape(os, ins->get_shape()); diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index 6bbeb5e73db..ba81218cb6a 
100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -182,7 +182,7 @@ struct module_pm : module_pass_manager catch(const std::exception& e) { std::cerr << "Error " << p.name() << ": " << e.what() << std::endl; - auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); fs::path dirname = fs::temp_directory_path() / "migraphx"; fs::create_directories(dirname); std::string base = p.name() + std::to_string(clk) + ".mxr"; diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index f544318938b..654c75940f3 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1870,6 +1870,28 @@ std::vector> shape_transform_descriptor::common_axes_ma return result; } +std::vector shape_transform_descriptor::get_dst_axes_from_src(std::size_t axis) const +{ + std::vector result; + for(auto i : range(dimensions.size())) + { + const auto& d = dimensions[i]; + auto it = std::find_if(d.subdimensions.begin(), d.subdimensions.end(), [&](auto& s) { + if(s.axis.empty()) + return false; + return s.axis.front() == axis; + }); + if(it == d.subdimensions.end()) + continue; + // If it maps to a subdimesion then exit as there isnt a clear mapping + if(d.len() != it->len) + return {}; + result.push_back(i); + } + // TODO: Put it in the correct order if there is multiple axes + return result; +} + bool shape_transform_descriptor::empty() const { return dimensions.empty(); } std::vector shape_transform_descriptor::lens() const @@ -2011,5 +2033,209 @@ std::vector optimize_shape_transforms(const std::vector& return sd.generate(); } +// Replace broadcasted dimensions with size 1, and set the stride to the previous stride +static shape unbroadcast(const shape& s) +{ + std::vector lens = s.lens(); + std::vector strides = s.strides(); + auto stride_it = std::find_if( + s.strides().begin(), s.strides().end(), [](auto stride) { return stride != 0; 
}); + std::size_t prev_stride = stride_it == s.strides().end() ? 1 : *stride_it; + for(std::size_t i = 0; i < lens.size(); ++i) + { + if(strides[i] == 0) + { + lens[i] = 1; + strides[i] = prev_stride; + } + else + { + prev_stride = strides[i]; + } + } + return {s.type(), lens, strides}; +} + +static std::size_t adjust_strided_shape(shape& s, std::size_t n) +{ + auto lens = s.lens(); + auto strides = s.strides(); + + // Insert a dim of 1 so it can be used to handle steps + if(std::none_of(strides.begin(), strides.end(), [](auto stride) { return stride == 1; }) and + std::any_of(strides.begin(), strides.end(), [](auto stride) { return stride != 0; })) + { + lens.push_back(1); + strides.push_back(1); + } + + auto last_axis = std::max_element(strides.begin(), strides.end()) - strides.begin(); + auto total_elements = std::max(1, strides[last_axis] * lens[last_axis]); + // Add a dim of 1 to the front so it can handle extra elements + auto extra = n / total_elements; + if(extra > 1) + { + strides.insert(strides.begin(), total_elements); + lens.insert(lens.begin(), 1); + } + s = shape(s.type(), lens, strides); + return std::max(1, extra); +} + +// Generate the shape transforms for strided view +optional> +generate_shape_transforms_for(shape s, const std::vector& idims, std::int64_t offset) +{ + std::vector result; + if(s.lens().empty()) + return std::nullopt; + std::size_t ielements = + std::accumulate(idims.begin(), idims.end(), std::size_t(1), std::multiplies<>()); + auto extra = adjust_strided_shape(s, ielements); + // TODO: Improve handling of multiple dimensions, for now just reshape to 1 dimension + if(idims.size() != 1) + { + result.push_back(make_op("reshape", {{"dims", {ielements}}})); + auto ops = generate_shape_transforms_for(s, {ielements}, offset); + if(not ops) + return std::nullopt; + result.insert(result.end(), ops->begin(), ops->end()); + return result; + } + auto pre_broadcast = unbroadcast(s); + auto perm = find_permutation(pre_broadcast); + auto iperm = 
invert_permutation(perm); + auto pre_transpose = reorder_shape(pre_broadcast, perm); + + std::vector start_lens; + std::adjacent_difference(pre_transpose.strides().begin(), + pre_transpose.strides().end(), + std::back_inserter(start_lens), + [](auto y, auto x) -> std::size_t { + assert(x >= y); + assert(y != 0); + if((x % y) != 0) + return 0; + return x / y; + }); + if(std::any_of(start_lens.begin(), start_lens.end(), [](auto len) { return len == 0; })) + return std::nullopt; + start_lens.front() = extra > 1 ? extra : pre_transpose.lens().front(); + + std::size_t nelements = + std::accumulate(start_lens.begin(), start_lens.end(), std::size_t(1), std::multiplies<>()); + + if(nelements < pre_transpose.elements() * extra) + return std::nullopt; + + std::vector start_mask(start_lens.size(), 0); + if(offset != 0) + { + shape start_shape{shape::float_type, start_lens}; + auto idx = start_shape.multi(offset); + + std::vector overhead; + std::transform(start_lens.begin(), + start_lens.end(), + pre_transpose.lens().begin(), + std::back_inserter(overhead), + [](auto start_len, auto len) { return start_len - len; }); + if(std::equal( + idx.begin(), idx.end(), overhead.begin(), overhead.end(), [](auto i, auto over) { + return i <= over; + })) + { + start_mask = reorder_dims(idx, iperm); + offset = 0; + } + } + + std::vector pre_slice_mask; + std::transform(start_lens.begin(), + start_lens.end(), + pre_transpose.lens().begin(), + std::back_inserter(pre_slice_mask), + [](auto start_len, auto len) -> std::size_t { + if(start_len == len) + return 0; + return len; + }); + auto slice_mask = reorder_dims(pre_slice_mask, iperm); + + std::vector blens = reorder_dims(start_lens, iperm); + std::transform(s.lens().begin(), + s.lens().end(), + blens.begin(), + blens.begin(), + [](auto len, auto blen) -> std::size_t { + if(blen == 1) + return len; + return blen; + }); + + std::vector ops; + ops.push_back(make_op("multibroadcast", {{"out_lens", blens}})); + ops.push_back(make_op("transpose", 
{{"permutation", invert_permutation(perm)}})); + ops.push_back(make_op("reshape", {{"dims", start_lens}})); + std::reverse(ops.begin(), ops.end()); + + auto desc = shape_transform_descriptor::create({nelements}, ops); + + auto end = offset + nelements; + if(offset != 0 or nelements != ielements) + { + + // If the end is out of bounds broadcast it to pad it + if(end > ielements) + { + result.push_back(make_op("broadcast", {{"axis", 1}, {"out_lens", {2, ielements}}})); + result.push_back(make_op("reshape", {{"dims", {2 * ielements}}})); + } + result.push_back(make_op("slice", {{"axes", {0}}, {"starts", {offset}}, {"ends", {end}}})); + } + + auto opt_ops = desc.generate(); + result.insert(result.end(), opt_ops.begin(), opt_ops.end()); + + std::vector axes; + std::transform(slice_mask.begin(), + slice_mask.end(), + range(slice_mask.size()).begin(), + join_back_inserter(axes), + [](std::size_t mask, std::size_t idx) -> std::vector { + if(mask > 0) + return {idx}; + return {}; + }); + + if(not axes.empty()) + { + std::vector starts; + std::transform(slice_mask.begin(), + slice_mask.end(), + start_mask.begin(), + join_back_inserter(starts), + [](std::size_t mask, std::size_t start) -> std::vector { + if(mask == 0) + return {}; + return {start}; + }); + std::vector ends; + std::transform(slice_mask.begin(), + slice_mask.end(), + s.lens().begin(), + join_back_inserter(ends), + [](std::size_t mask, std::size_t len) -> std::vector { + if(mask == 0) + return {}; + return {len}; + }); + std::transform(ends.begin(), ends.end(), starts.begin(), ends.begin(), std::plus<>{}); + + result.push_back(make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}})); + } + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 51b5e199db5..5e6748880fc 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -21,15 +21,20 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 
THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ +#include +#include #include #include #include #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -40,13 +45,105 @@ #include #include #include +#include +#include +#include #include +#include +#include +#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace { + +template +instruction_ref +insert_auto_reshape(module& m, instruction_ref ins, const Dims& dims, instruction_ref input) +{ + assert(std::all_of(dims.begin(), dims.end(), [](auto i) { return i > 0; })); + if(std::equal(dims.begin(), + dims.end(), + input->get_shape().lens().begin(), + input->get_shape().lens().end())) + { + return input; + } + + auto curr_lens = input->get_shape().lens(); + // Check if we can use squeeze (removing dimensions of size 1) + if(curr_lens.size() > dims.size()) + { + // Potential squeeze - check if we're just removing 1s + std::vector axes_to_squeeze; + std::size_t target_idx = 0; + for(std::size_t curr_idx = 0; curr_idx < curr_lens.size(); ++curr_idx) + { + if(curr_lens[curr_idx] == 1) + { + axes_to_squeeze.push_back(curr_idx); + } + else + { + if(target_idx >= dims.size() or curr_lens[curr_idx] != dims[target_idx]) + { + axes_to_squeeze.clear(); + break; + } + ++target_idx; + } + } + if(not axes_to_squeeze.empty() and target_idx == dims.size()) + { + return m.insert_instruction( + ins, make_op("squeeze", {{"axes", axes_to_squeeze}}), input); + } + } + // Check if we can use unsqueeze (adding dimensions of size 1) + else if(curr_lens.size() < dims.size()) + { + // Potential unsqueeze - check if we're just adding 1s + std::vector axes_to_unsqueeze; + std::size_t curr_idx = 0; + for(std::size_t target_idx = 0; target_idx < dims.size(); ++target_idx) + { + if(dims[target_idx] == 1) + { + axes_to_unsqueeze.push_back(target_idx); + } + else + { + if(curr_idx >= curr_lens.size() or dims[target_idx] != curr_lens[curr_idx]) + { + 
axes_to_unsqueeze.clear(); + break; + } + ++curr_idx; + } + } + if(not axes_to_unsqueeze.empty() and curr_idx == curr_lens.size()) + { + return m.insert_instruction( + ins, make_op("unsqueeze", {{"axes", axes_to_unsqueeze}}), input); + } + } + + return m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), input); +} + +template +instruction_ref insert_auto_reshape(module& m, + instruction_ref ins, + const std::initializer_list& dims, + instruction_ref input) +{ + return insert_auto_reshape(m, ins, std::vector(dims), input); +} + const auto& reshaper_names() { // clang-format off @@ -61,6 +158,16 @@ const auto& reshaper_names() return names; } +instruction_ref +insert_ops(module& m, instruction_ref ins, std::vector& ops, instruction_ref input) +{ + for(const auto& op : ops) + { + input = m.insert_instruction(ins, op, input); + } + return input; +} + struct find_nested_shape_transforms { static const auto& shape_transform_ops() @@ -114,9 +221,7 @@ struct find_nested_shape_transforms auto opt_ops = optimize_shape_transforms(x->get_shape().lens(), ops); if(ops == opt_ops) return; - auto y = x; - for(const auto& op : opt_ops) - y = m.insert_instruction(ins, op, y); + auto y = insert_ops(m, ins, opt_ops, x); m.replace_instruction(ins, y); } } @@ -379,6 +484,118 @@ struct find_op_shape_transform_op } }; +struct find_slice_shape_transforms +{ + static const auto& shape_transform_ops() + { + static const std::unordered_set names = { + "reshape", + "squeeze", + "unsqueeze", + "flatten", + "transpose", + "contiguous", + "multibroadcast", + "broadcast", + }; + return names; + } + + // auto matcher() const + // { + // auto reshapes = match::name(shape_transform_ops()); + // auto match_op = match::any_of(match::reduce(), match::pointwise()); + // auto x_op = + // match_op(match::none_of(fusable_split())); + // auto reshapes_x_op = reshapes(match::arg(0)(match::skip(reshapes())(x_op.bind("x")))); + // return 
match_op(match::any_of[match::inputs()](reshapes_x_op.bind("input"))); + // } + + auto matcher() const + { + auto reshapes = match::name(shape_transform_ops()); + auto slice_op = match::name("slice")(match::arg(0)(match::used_once())); + return reshapes(reshapes(match::none_of[match::outputs()](reshapes())), + match::arg(0)(match::skip(reshapes())(slice_op.bind("slice")))); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto slice = mr.instructions["slice"]; + auto slice_op = slice->get_operator().to_value(); + auto axes = slice_op.at("axes").to_vector(); + + std::vector ops; + auto x = ins; + while(contains(shape_transform_ops(), x->get_operator().name())) + { + ops.push_back(x->get_operator()); + x = x->inputs().front(); + } + if(x != slice) + return; + x = x->inputs().front(); + std::reverse(ops.begin(), ops.end()); + auto desc = shape_transform_descriptor::create(slice->get_shape().lens(), ops); + + // std::cout << "desc: " << desc << std::endl; + + std::vector new_axes; + std::transform(axes.begin(), + axes.end(), + join_back_inserter(new_axes), + [&](auto axis) -> std::vector { + auto result = desc.get_dst_axes_from_src(axis); + if(result.size() != 1) + return {}; + return result; + }); + + // Optimizes shape transforms if the slice cant be optimized + if(axes.size() != new_axes.size()) + { + auto opt_ops = desc.generate(); + auto y = insert_ops(m, ins, opt_ops, slice); + m.replace_instruction(ins, y); + return; + } + slice_op["axes"] = new_axes; + + auto new_desc = desc.rebase(slice->inputs().front()->get_shape().lens()); + if(new_desc.empty()) + return; + new_desc.simplify(); + + auto opt_ops = new_desc.generate(); + auto y = insert_ops(m, ins, opt_ops, x); + y = m.insert_instruction(ins, make_op("slice", slice_op), y); + m.replace_instruction(ins, y); + + // auto opt_ops = optimize_shape_transforms(x->get_shape().lens(), ops); + // if(ops == opt_ops) + // return; + // auto y = x; + // for(const auto& op : 
opt_ops) + // y = m.insert_instruction(ins, op, y); + // m.replace_instruction(ins, y); + // if(x->get_shape().scalar()) + // { + // m.replace_instruction( + // ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), x); + // } + // else if(x->get_shape().elements() == 1 and ins->get_shape().elements() == 1) + // { + // // TODO: Use squeeze or unsqueeze + // m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), + // x); + // } + // else + // { + // } + } +}; + struct find_nop_reshapes { auto matcher() const @@ -790,165 +1007,313 @@ struct find_nested_concat } }; -struct find_resize +struct find_gather { - auto matcher() const + struct arithmetic_segment { - return match::name("gather")( - match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); - } + int64_t base = 0; + int64_t stride = 0; + std::size_t count = 0; - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto ins_rsp = r.instructions["data"]; - auto ins_ind = r.instructions["ind"]; - - // resize input shape - if(ins_rsp->get_shape().lens().size() != 1) + template + static std::vector from_ints(Iterator begin, Iterator end) { - return; + std::vector result(std::distance(begin, end)); + par_transform( + begin, end, result.begin(), [](auto x) { return arithmetic_segment{x, 1, 1}; }); + return result; } - // resize output shape - const auto& in_shape = ins_rsp->inputs().front()->get_shape(); - const auto& out_shape = ins->get_shape(); - // check if output shape is multiple of input shape - const auto& in_lens = in_shape.lens(); - const auto& out_lens = out_shape.lens(); - if(in_lens.size() != out_lens.size()) + template + static Iterator find_largest(Iterator start, Iterator last, OutputIterator out) { - return; + for(auto it = start; it != last;) + { + auto [seg, next_it] = find(it, last); + it = next_it; + *out = seg; + out++; + } + return last; } - // output shape must be multiple of input 
shape - std::vector is_multi(in_lens.size()); - std::transform( - in_lens.begin(), in_lens.end(), out_lens.begin(), is_multi.begin(), [](auto x, auto y) { - return (y % x == 0); - }); - if(not std::all_of(is_multi.begin(), is_multi.end(), [](auto b) { return b; })) + template + static Iterator find_n(Iterator start, Iterator last, std::size_t n, OutputIterator out) { - return; + for(auto it = start; it != last;) + { + auto [seg, next_it] = find(it, it + n); + if(next_it != it + n) + return next_it; + it = next_it; + *out = seg; + out++; + } + return last; } - // output must be multiple of inputs - std::vector scales(in_lens.size()); - std::transform( - in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) { - return y / x; - }); + static std::vector + make_segments(const std::vector& segments, bool uniform = true) + { + std::vector result; + auto [first_seg, first_it] = find(segments.begin(), segments.end()); + result.push_back(first_seg); + // Try to find segments that are the same size + auto it = find_n(first_it, segments.end(), first_seg.count, std::back_inserter(result)); + if(it != segments.end()) + { + if(uniform) + return {}; + result.resize(1); + find_largest(first_it, segments.end(), std::back_inserter(result)); + } + return result; + } - // if ind is not constant, cannot optimize - std::vector vec_ind; - auto arg_ind = ins_ind->eval(); - if(arg_ind.empty()) + static std::vector shift(std::vector segments, + std::int64_t shift) { - return; + par_transform( + segments.begin(), segments.end(), segments.begin(), [&](arithmetic_segment x) { + x.base += shift; + return x; + }); + return segments; } - arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); - if(not all_of(range(out_shape.elements()), [&](auto i) { - auto out_idx = out_shape.multi(i); - auto in_idx = out_idx; - std::transform(out_idx.begin(), - out_idx.end(), - scales.begin(), - in_idx.begin(), - [&](auto io, auto scale) { return io - (io % scale); }); 
- return vec_ind[i] == vec_ind[out_shape.index(in_idx)]; - })) + + /// Detect arithmetic segment pattern + template + static std::pair find(Iterator begin, Iterator end) { - return; + std::size_t length = std::distance(begin, end); + if(length == 0) + return std::make_pair(arithmetic_segment{}, begin); + if(length == 1) + return std::make_pair(*begin, std::next(begin)); + auto start = *begin; + // auto base = *begin; + auto stride = std::next(begin)->base - start.base; + if(stride < 0) + return std::make_pair(*begin, std::next(begin)); + auto diff = + std::adjacent_find(begin, end, [&](arithmetic_segment x, arithmetic_segment y) { + return y.base - x.base != stride; + }); + if(diff != end) + diff++; + return std::make_pair( + arithmetic_segment{start.base, stride, std::size_t(std::distance(begin, diff))}, + diff); } - // wrap up shapes for multibroadcast - std::vector> dim_scales; - std::transform(in_lens.begin(), - in_lens.end(), - out_lens.begin(), - std::back_inserter(dim_scales), - [](auto x, auto y) { return std::make_pair(x, y / x); }); - - std::vector in_dims; - std::vector out_dims; - for(auto& isp : dim_scales) + static shape make_strided_view(std::vector segments) { - in_dims.push_back(isp.first); - out_dims.push_back(isp.first * isp.second); - if(isp.first == 1 or isp.second == 1) + std::vector lens; + std::vector strides; + + do { - continue; + segments = make_segments(segments); + // std::cout << "nsegments: " << segments.size() << std::endl; + // for(auto segment : segments) + // std::cout << " {" << segment.base << ", " << segment.stride << ", " + // << segment.count << "}\n"; + if(segments.empty()) + return {}; + auto seg = segments.front(); + if(seg.stride < 0) + return {}; + if(std::any_of(segments.begin(), segments.end(), [](const arithmetic_segment& seg) { + return seg.base < 0; + })) + return {}; + if(not std::all_of( + segments.begin(), segments.end(), [&](const arithmetic_segment& seg) { + return seg.stride == segments.front().stride and + 
seg.count == segments.front().count; + })) + return {}; + lens.push_back(seg.count); + strides.push_back(seg.stride); + } while(segments.size() > 1); + + std::reverse(lens.begin(), lens.end()); + std::reverse(strides.begin(), strides.end()); + + if(std::none_of( + strides.begin(), strides.end(), [](auto stride) { return stride == 1; })) + { + lens.push_back(1); + strides.push_back(1); } - out_dims.back() = isp.first; - in_dims.push_back(1); - out_dims.push_back(isp.second); + return {shape::float_type, lens, strides}; } - auto in_rsp = ins_rsp->inputs().front(); - auto rsp_data = m.insert_instruction( - ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); - auto mb_rsp = m.insert_instruction( - ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); - std::vector rsp_dims(out_lens.begin(), out_lens.end()); - m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp); - } -}; + template + static std::optional + transform_indices(const Indices& indices, module& m, instruction_ref start) + { + auto isegments = from_ints(indices.begin(), indices.end()); + std::int64_t offset = isegments.front().base; + auto s = make_strided_view(shift(std::move(isegments), -offset)); + auto ops = generate_shape_transforms_for(s, {start->get_shape().elements()}, offset); + if(not ops.has_value()) + return std::nullopt; + return insert_ops(m, std::next(start), *ops, start); + } + }; -struct find_where_op -{ + static std::vector build_flat_gather_indices(instruction_ref gather_ins, + const argument& indices_arg, + std::size_t axis_index) + { + auto data_ins = gather_ins->inputs()[0]; + auto output_dims = gather_ins->get_shape().lens(); + const auto r_in = data_ins->get_shape().lens().size(); + const auto r_idx = indices_arg.get_shape().lens().size(); + assert(axis_index < r_in); + + shape output_s{shape::float_type, output_dims}; // element type doesn't matter here + const auto out_n = output_s.elements(); + 
std::vector flat(out_n); + std::iota(flat.begin(), flat.end(), 0); + + auto indices = indices_arg.to_vector(); + + transform(flat, flat.begin(), [&](std::size_t out_lin) -> std::int64_t { + // 1) output linear -> output multi-index + auto out_multi = output_s.multi(out_lin); + + // 2) isolate the "indices" coordinates from the output coords (inserted at `axis`) + std::vector idx_multi(r_idx); + std::copy(out_multi.begin() + axis_index, + out_multi.begin() + axis_index + r_idx, + idx_multi.begin()); + + // 3) look up the actual index value (may be negative) + const std::int64_t idx_lin = indices_arg.get_shape().index(idx_multi); + const std::int64_t axis_len = data_ins->get_shape().lens().at(axis_index); + auto idx_val = indices.at(idx_lin); + + // Normalize negative indices into [0, axis_len) + if(idx_val < 0) + idx_val += axis_len; + + assert(idx_val >= 0 and idx_val < axis_len); + + // 4) construct corresponding INPUT multi-index + std::vector in_multi(r_in); + + // copy dims before axis + std::copy(out_multi.begin(), out_multi.begin() + axis_index, in_multi.begin()); + + // axis coordinate from indices + in_multi.at(axis_index) = idx_val; + + // copy dims after axis; they are shifted by r_idx in output + std::copy(out_multi.begin() + axis_index + r_idx, + out_multi.end(), + in_multi.begin() + axis_index + 1); + + // 5) map input multi-index -> flat offset in contiguous buffer + const auto in_lin = data_ins->get_shape().index(in_multi); + return in_lin; + }); + + return flat; + } auto matcher() const { return match::name("gather")( - match::args(match::name("reshape")(match::arg(0)(match::name("concat").bind("data"))), - match::is_constant().bind("ind"))); + match::args(match::any().bind("data"), match::is_constant().bind("indices"))); } void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - auto concat = r.instructions["data"]; - auto ins_ind = r.instructions["ind"]; - std::vector vec_ind; - auto arg_ind = ins_ind->eval(); - 
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); }); - // ind has to be the same value - auto val = vec_ind.front(); - if(not std::all_of(vec_ind.begin(), vec_ind.end(), [&](auto v) { return (v == val); })) - { + auto ins = r.result; + auto indices_ins = r.instructions["indices"]; + auto data_ins = r.instructions["data"]; + auto gather_op = any_cast(ins->get_operator()); + const auto& dlens = data_ins->get_shape().lens(); + if(dlens.empty()) return; - } - // concat axis must be 0 - auto op = any_cast(concat->get_operator()); - if(op.axis != 0) - { + const auto axis_index = static_cast( + tune_axis(static_cast(dlens.size()), gather_op.axis, gather_op.name())); + const auto axis_len = dlens.at(axis_index); + if(axis_len == 0) return; - } - // check concat inputs, it has to be 2 and have the same shape - const auto& inputs = concat->inputs(); - if(inputs.size() != 2) - { + auto arg_ind = indices_ins->eval(); + if(arg_ind.empty()) return; - } - if(inputs.at(0)->get_shape() != inputs.at(1)->get_shape()) - { + + std::vector indices_values; + arg_ind.visit([&](auto v) { + indices_values.resize(v.size()); + std::transform(v.begin(), v.end(), indices_values.begin(), [](auto x) { + return static_cast(x); + }); + }); + if(indices_values.empty()) return; - } - if(inputs.at(0)->get_shape().lens() != ins_ind->get_shape().lens()) - { + + const auto& indices_shape = indices_ins->get_shape(); + if(indices_shape.elements() != indices_values.size()) + return; + + // Skip if indices have broadcast strides (e.g., scalar broadcast) + if(indices_shape.broadcasted()) return; - } - if(val) + // Normalize negative indices using transform + std::transform(indices_values.begin(), + indices_values.end(), + indices_values.begin(), + [axis_len](auto idx) { + if(idx < 0) + idx += static_cast(axis_len); + return idx; + }); + + // Validate all indices are in bounds + bool all_valid = + std::all_of(indices_values.begin(), indices_values.end(), [axis_len](auto idx) { + return idx >= 0 
and idx < static_cast(axis_len); + }); + if(not all_valid) + return; + + // Create indices argument with normalized values + shape normalized_indices_shape{shape::int64_type, indices_shape.lens()}; + literal indices_lit(normalized_indices_shape, indices_values.begin(), indices_values.end()); + argument indices_arg = indices_lit.get_argument(); + + // Sanity check: ensure the argument shape matches + assert(indices_arg.get_shape().lens() == indices_shape.lens()); + assert(indices_arg.get_shape().elements() == indices_values.size()); + + std::optional new_ins = std::nullopt; + + if(data_ins->get_shape().ndim() == 1 and indices_ins->get_shape().ndim() == 1) { - m.replace_instruction(ins, inputs.at(0)); + new_ins = arithmetic_segment::transform_indices(indices_values, m, data_ins); } else { - m.replace_instruction(ins, inputs.at(1)); + auto data_1d = + insert_auto_reshape(m, ins, {data_ins->get_shape().elements()}, data_ins); + auto new_indices = build_flat_gather_indices(ins, indices_arg, axis_index); + new_ins = arithmetic_segment::transform_indices(new_indices, m, data_1d); } + + if(not new_ins.has_value()) + return; + + auto reshaped = insert_auto_reshape(m, ins, ins->get_shape().lens(), *new_ins); + + m.replace_instruction(ins, reshaped); } }; @@ -1420,13 +1785,14 @@ struct find_flatten void simplify_reshapes::apply(module& m) const { + if(enable_gather_rewrite) + match::find_matches(m, find_gather{}); m.repeat_while_changes(depth, [&] { match::find_matches(m, - find_where_op{}, - find_resize{}, find_nop_reshapes{}, find_flatten{}, find_reshape_cont{}, + find_slice_shape_transforms{}, find_nested_shape_transforms{}, find_concat_slice{}, find_concat_transpose{}, diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 891d80631bc..6d68f8dd808 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -198,7 +198,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // workaround for rocBLAS unsupported 
error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, - simplify_reshapes{}, + simplify_reshapes{.enable_gather_rewrite = true}, eliminate_identity{}, eliminate_pad{}, dead_code_elimination{}, diff --git a/test/algorithm.cpp b/test/algorithm.cpp index 358ae0063e5..47d671b4763 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -85,34 +85,34 @@ MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(adjacent_remove_if_non_equivalence, int) MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_basic, int) { - Container v = {5, 3, 7, 1, 9, 2}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + Container v = {5, 3, 7, 1, 9, 2}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); EXPECT(it != v.end()); EXPECT(*it == 2); } MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_no_valid, int) { - Container v = {5, 3, 7, 1, 9}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + Container v = {5, 3, 7, 1, 9}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); EXPECT(it == v.end()); } MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_all_valid, int) { - Container v = {6, 2, 8, 4, 10}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + Container v = {6, 2, 8, 4, 10}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); EXPECT(it != v.end()); EXPECT(*it == 2); } 
MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_custom_compare, int) { - Container v = {5, 3, 7, 1, 9, 2, 8}; - auto is_even = [](int x) { return x % 2 == 0; }; + Container v = {5, 3, 7, 1, 9, 2, 8}; + auto is_even = [](int x) { return x % 2 == 0; }; // Find the largest even number auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::greater<>{}); EXPECT(it != v.end()); @@ -129,9 +129,9 @@ MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_empty, int) MIGRAPHX_FORWARD_CONTAINER_TEST_CASE(min_element_if_first_element, int) { - Container v = {2, 5, 3, 7, 1, 9}; - auto is_even = [](int x) { return x % 2 == 0; }; - auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); + Container v = {2, 5, 3, 7, 1, 9}; + auto is_even = [](int x) { return x % 2 == 0; }; + auto it = migraphx::min_element_if(v.begin(), v.end(), is_even, std::less<>{}); EXPECT(it != v.end()); EXPECT(*it == 2); EXPECT(it == v.begin()); diff --git a/test/common_dims.cpp b/test/common_dims.cpp index 416e7495255..04e0581dbca 100644 --- a/test/common_dims.cpp +++ b/test/common_dims.cpp @@ -83,6 +83,22 @@ TEST_CASE(common4) EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3, 4}}); } +TEST_CASE(common5) +{ + auto cd = migraphx::common_dims::compute({3, 8, 5}, {12, 10}); + EXPECT(cd.dims == std::vector{3, 4, 2, 5}); + EXPECT(cd.axes_map1 == axes_map{{0}, {1, 2}, {3}}); + EXPECT(cd.axes_map2 == axes_map{{0, 1}, {1, 2}}); +} + +TEST_CASE(common6) +{ + auto cd = migraphx::common_dims::compute({12, 10}, {3, 8, 5}); + EXPECT(cd.dims == std::vector{3, 4, 2, 5}); + EXPECT(cd.axes_map1 == axes_map{{0, 1}, {1, 2}}); + EXPECT(cd.axes_map2 == axes_map{{0}, {1, 2}, {3}}); +} + TEST_CASE(common_same_dims) { auto cd = migraphx::common_dims::compute({{2, 32, 4}}, {64, 2, 2}); diff --git a/test/include/test.hpp b/test/include/test.hpp index 259905381b9..7ae503e2db4 100644 --- a/test/include/test.hpp +++ b/test/include/test.hpp @@ -190,10 +190,24 @@ Stream& print_stream_impl(rank<4>, 
Stream& s, std::nullptr_t) return s; } +template +auto print_stream_impl(rank<5>, Stream& s, const Optional& x) + -> decltype(bool(Optional{*x}), x.has_value(), x.value(), void()) +{ + if(x.has_value()) + { + print_stream(s, x.value()); + } + else + { + s << "nullopt"; + } +} + template void print_stream(Stream& s, const T& x) { - print_stream_impl(rank<5>{}, s, x); + print_stream_impl(rank<6>{}, s, x); } template diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index ccf887c5e52..789001a54d2 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -97,6 +97,15 @@ static std::vector run_shape_transforms(const std::vector& return result.to_vector(); } +static std::vector run_strided_view(const migraphx::shape& s, std::int64_t offset) +{ + auto n = s.element_space(); + std::vector data(n); + std::iota(data.begin(), data.end(), offset); + migraphx::literal l(migraphx::shape{migraphx::shape::int64_type, {n}}, data); + return l.get_argument().reshape(s).to_vector(); +} + static std::vector check_optimize_shape_transforms(const std::vector& dims, const std::vector& ops) @@ -125,6 +134,21 @@ static shape_transform_descriptor make_simple_descriptor(const std::vector> +generate_for(const std::vector& dims, + const std::vector& strides, + const std::vector& idims, + std::int64_t offset = 0) +{ + migraphx::shape s{migraphx::shape::int64_type, dims, strides}; + auto result = migraphx::generate_shape_transforms_for(s, idims, offset); + if(result) + { + CHECK(run_strided_view(s, offset) == run_shape_transforms(idims, result.value())); + } + return result; +} + TEST_CASE(dimension_len) { dimension dim; @@ -1241,4 +1265,121 @@ TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast) } } +TEST_CASE(generate_shape_transforms_for) +{ + EXPECT(generate_for({3}, {1}, {3}) == ops{}); + EXPECT(generate_for({3}, {0}, {1}) == ops{make_op("multibroadcast", {{"out_lens", {3}}})}); + EXPECT(generate_for({3}, {3}, {9}) == + 
ops{ + make_op("reshape", {{"dims", {3, 3}}}), + make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), + }); + + EXPECT(generate_for({3, 4, 5, 2}, {2, 0, 0, 1}, {6}) == + ops{ + make_op("reshape", {{"dims", {3, 1, 1, 2}}}), + make_op("multibroadcast", {{"out_lens", {3, 4, 5, 2}}}), + }); + EXPECT(generate_for({3, 2}, {3, 0}, {9}) == + ops{ + make_op("reshape", {{"dims", {3, 1, 3}}}), + make_op("multibroadcast", {{"out_lens", {3, 2, 3}}}), + make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), + }); + + EXPECT(generate_for({3, 2}, {2, 1}, {6}) == ops{ + make_op("reshape", {{"dims", {3, 2}}}), + }); + + EXPECT(generate_for({3, 2}, {1, 3}, {6}) == ops{ + make_op("reshape", {{"dims", {2, 3}}}), + make_op("transpose", {{"permutation", {1, 0}}}), + }); + + EXPECT(generate_for({2, 2, 2, 2, 3}, {0, 2, 0, 1, 0}, {4}) == + ops{ + make_op("reshape", {{"dims", {1, 2, 1, 2, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2, 3}}}), + }); + + EXPECT(generate_for({2, 2, 3}, {4, 1, 0}, {8}) == + ops{ + make_op("reshape", {{"dims", {2, 4}}}), + make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4, 3}}}), + make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), + }); + + EXPECT(generate_for({2, 3, 4, 1}, {4, 16, 1, 1}, {48}) == + ops{ + make_op("reshape", {{"dims", {3, 4, 4, 1}}}), + make_op("transpose", {{"permutation", {1, 0, 2, 3}}}), + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), + }); +} + +TEST_CASE(generate_shape_transforms_for_overlap) +{ + // TODO: Overlaping strides not supported yet, need to support something like torch.unfold. 
+ + // Case 1: {2, 3} with strides {1, 1} - overlapping rows + // Row 0 accesses [0, 1, 2], Row 1 accesses [1, 2, 3] + // Total elements needed: 4 (exactly matches input size) + EXPECT(generate_for({2, 3}, {1, 1}, {4}) == std::nullopt); + // EXPECT(generate_for({2, 3}, {1, 1}, {4}) == + // ops{ + // make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4}}}), + // make_op("reshape", {{"dims", {8}}}), + // make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), + // make_op("reshape", {{"dims", {4}}}), + // make_op("reshape", {{"dims", {2, 2}}}), + // make_op("multibroadcast", {{"out_lens", {2, 3}}}), + // make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {3}}}), + // }); + + // Case 2: {3, 2, 1} with strides {3, 2, 1} + // Element at (i,j,k) is at index i*3 + j*2 + k*1 + // Max index is (2,1,0) = 2*3 + 1*2 + 0*1 = 8 + // So we need 9 elements total (indices 0-8) + EXPECT(generate_for({3, 2, 1}, {3, 2, 1}, {9}) == std::nullopt); + // EXPECT(generate_for({3, 2, 1}, {3, 2, 1}, {9}) == + // ops{ + // make_op("reshape", {{"dims", {9}}}), + // // Extract the specific pattern of elements based on strides + // make_op("reshape", {{"dims", {3, 3}}}), + // make_op("transpose", {{"permutation", {1, 0}}}), + // make_op("reshape", {{"dims", {3, 3}}}), + // make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), + // make_op("reshape", {{"dims", {3, 2}}}), + // make_op("multibroadcast", {{"out_lens", {3, 2, 1}}}), + // }); +} + +TEST_CASE(generate_shape_transforms_for_offset) +{ + EXPECT(generate_for({3, 1}, {4, 1}, {30}, 1) == + ops{ + make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {24}}}), + make_op("reshape", {{"dims", {2, 3, 4}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + }); + + EXPECT(generate_for({3, 1}, {5, 1}, {30}, 1) == + ops{ + make_op("reshape", {{"dims", {2, 3, 5}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + }); + + EXPECT(generate_for({3, 2}, {10, 
1}, {60}, 1) == + ops{ + make_op("reshape", {{"dims", {2, 3, 10}}}), + make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 3}}}), + }); + + EXPECT(generate_for({4, 3, 2}, {24, 4, 1}, {96}, 5) == + ops{ + make_op("reshape", {{"dims", {4, 6, 4}}}), + make_op("slice", {{"axes", {1, 2}}, {"starts", {1, 1}}, {"ends", {4, 3}}}), + }); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ce8f1f50eec..a7bedcba91e 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -38,7 +38,8 @@ static void run_pass(migraphx::module& m) { migraphx::run_passes(m, { - migraphx::simplify_reshapes{.enable_op_shape_transform_op = true}, + migraphx::simplify_reshapes{.enable_op_shape_transform_op = true, + .enable_gather_rewrite = true}, migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}, }); @@ -1516,8 +1517,7 @@ TEST_CASE(optimize_resize) auto create_optimized_module = [&] { migraphx::module m; - auto inx = m.add_parameter("X", sx); - std::vector dims = {1, 1, 2, 1, 2, 1}; + auto inx = m.add_parameter("X", sx); auto rspx = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3, 5}}}), inx); auto mbx = m.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2, 3}}}), rspx); @@ -1534,6 +1534,47 @@ TEST_CASE(optimize_resize) EXPECT(m1 == create_optimized_module()); } +TEST_CASE(optimize_resize_flatten) +{ + migraphx::shape sx{migraphx::shape::float_type, {4}}; + auto create_resize_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, + 3, 3, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 0, 0, + 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m.add_literal(migraphx::literal(si, ind)); + + auto gr = 
m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), inx, li); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), gr); + m.add_return({r}); + + return m; + }; + + auto m1 = create_resize_module(); + run_pass(m1); + + auto create_optimized_module = [&] { + migraphx::module m; + auto inx = m.add_parameter("X", sx); + auto rspx = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 1, 2, 1}}}), inx); + auto mbx = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 2, 3}}}), rspx); + std::vector orig_dims = {1, 2, 4, 6}; + auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), mbx); + auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), rmb); + m.add_return({r}); + + return m; + }; + + EXPECT(m1 == create_optimized_module()); +} + TEST_CASE(optimize_resize_ind_not_apply) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; @@ -1588,30 +1629,89 @@ TEST_CASE(optimize_resize_ndims_unequal) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 3, 2}}; - auto create_resize_module = [&] { - migraphx::module m; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 3, 2}}; std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; - auto li = m.add_literal(migraphx::literal(si, ind)); + auto li = m1.add_literal(migraphx::literal(si, ind)); - auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), inx); - auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); - auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); - m.add_return({r}); + auto lrsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", 
{4}}}), inx); + auto gr = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m1.add_instruction(migraphx::make_op("sub"), iny, gr); + m1.add_return({r}); + } + run_pass(m1); - return m; - }; + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto rsp_y = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 3}}}), iny); + auto trans_x = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1}}}), inx); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 2, 3}}}), trans_x); + auto sub = m2.add_instruction(migraphx::make_op("sub"), rsp_y, mb); + auto rsp_out = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 4, 3, 2}}}), sub); + m2.add_return({rsp_out}); + } - auto m = create_resize_module(); - run_pass(m); - EXPECT(m == create_resize_module()); + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(optimize_resize_ind_non_brcst) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; + std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; + auto li = m1.add_literal(migraphx::literal(si, ind)); + + auto lrsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto gr = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); + auto r = m1.add_instruction(migraphx::make_op("sub"), iny, gr); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); + auto slc = m2.add_instruction( + 
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), rsp1); + auto rsp2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 1, 2, 1}}}), slc); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 2, 2, 2, 3}}}), rsp2); + auto rsp_y = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 2, 2, 2, 3}}}), iny); + auto sub = m2.add_instruction(migraphx::make_op("sub"), rsp_y, mb); + auto rsp3 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 4, 6}}}), sub); + m2.add_return({rsp3}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(optimize_resize_ind_non_const) { migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; @@ -1621,197 +1721,1089 @@ TEST_CASE(optimize_resize_ind_non_brcst) auto iny = m.add_parameter("Y", sy); migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; - std::vector ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, - 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3}; - auto li = m.add_literal(migraphx::literal(si, ind)); - + auto li = m.add_parameter("ind", si); auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), inx); auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); m.add_return({r}); - return m; - }; + return m; + }; + + auto m = create_resize_module(); + run_pass(m); + EXPECT(m == create_resize_module()); +} + +TEST_CASE(optimize_where_true) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto create_where_module = [&](bool cond) { + migraphx::module m; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(si.elements(), static_cast(cond)); + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = 
m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto create_expected = [&](bool cond) { + migraphx::module m; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto bc = m.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 3, 2}}}), rsp); + int64_t start = cond ? 1 : 0; + int64_t end = cond ? 2 : 1; + auto slc = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {start}}, {"ends", {end}}}), bc); + m.add_return({slc}); + return m; + }; + + auto m = create_where_module(true); + run_pass(m); + auto expected = create_expected(true); + EXPECT(m.sort() == expected.sort()); + + auto m1 = create_where_module(false); + run_pass(m1); + auto expected1 = create_expected(false); + EXPECT(m1.sort() == expected1.sort()); +} + +TEST_CASE(where_different_cond_values) +{ + auto create_where_module = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + auto inx = m.add_parameter("X", s); + auto iny = m.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata = {1, 1, 0, 1, 0, 1}; + auto li = m.add_literal(migraphx::literal(si, idata)); + auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m.add_return({r}); + return m; + }; + + auto m = create_where_module(); + run_pass(m); + EXPECT(m == 
create_where_module()); +} + +TEST_CASE(where_axis_nonzero) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", s); + auto iny = m1.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", s); + auto iny = m2.add_parameter("Y", s); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {1, 2, 3, 0}}}), data); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 1}}}), tr); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(where_three_concat_inputs) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", s); + auto iny = m1.add_parameter("Y", s); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 
0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", s); + auto iny = m2.add_parameter("Y", s); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {18, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(where_three_inputs_diff_shapes) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {18, 1, 3, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + 
+TEST_CASE(where_three_lens_diff) +{ + migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; + migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}}; + + migraphx::module m1; + { + auto inx = m1.add_parameter("X", sx); + auto iny = m1.add_parameter("Y", sy); + + migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}}; + std::vector idata(6, 1); + auto li = m1.add_literal(migraphx::literal(si, idata)); + auto data = m1.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto data_1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto r = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); + m1.add_return({r}); + } + run_pass(m1); + + migraphx::module m2; + { + auto inx = m2.add_parameter("X", sx); + auto iny = m2.add_parameter("Y", sy); + + auto data = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 1, 6}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_1d_nd_indices) +{ + migraphx::module m; + auto x = m.add_parameter("x", {migraphx::shape::float_type, {6}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {0, 1, 2, 3, 4, 5}; + auto li = m.add_literal(migraphx::literal(si, indices)); + auto g = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, li); + m.add_return({g}); + + run_pass(m); + + migraphx::module expected; + auto xe = expected.add_parameter("x", {migraphx::shape::float_type, {6}}); + auto reshaped = expected.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3}}}), xe); + expected.add_return({reshaped}); + + EXPECT(m == expected); +} + 
+TEST_CASE(gather_axis_slice_broadcast) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {1, 1, 1, 2, 2, 2}; + auto li = m1.add_literal(migraphx::literal(si, indices)); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 4}}); + auto br = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {2, 4, 3}}}), x); + auto sliced = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {3}}}), br); + m2.add_return({sliced}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_single_index) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {1}}; + auto indices = m1.add_literal(migraphx::literal{si, {2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + // Verify gather was optimized away + EXPECT( + std::none_of(m1.begin(), m1.end(), [](const auto& ins) { return ins.name() == "gather"; })); + + // Verify output shape is correct: {3, 1, 5} + auto result = + std::find_if(m1.begin(), m1.end(), [](const auto& ins) { return ins.name() == "@return"; }); + EXPECT(result != m1.end()); + EXPECT(result->inputs().front()->get_shape().lens() == std::vector{3, 1, 5}); + + // Verify only view operations are used (transpose, slice, reshape, squeeze, unsqueeze, + // broadcast) + EXPECT(std::all_of(m1.begin(), m1.end(), [](const auto& ins) { + return ins.name() == "@param" or ins.name() == "@literal" or ins.name() == "@return" or + ins.name() == "transpose" or ins.name() == 
"slice" or ins.name() == "reshape" or + ins.name() == "squeeze" or ins.name() == "unsqueeze" or + ins.name() == "multibroadcast" or ins.name() == "broadcast"; + })); +} + +TEST_CASE(gather_multi_axis_stride) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto flatten = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3, 1, 4}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto li = m1.add_literal(migraphx::literal{indices_shape, indices}); + auto gather = m1.add_instruction(migraphx::make_op("gather"), flatten, li); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), x); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 1, 3, 4}}}), unsq); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), tr); + auto sliced = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), sq); + m2.add_return({sliced}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_multi_axis_stride) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {24}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto li = m1.add_literal(migraphx::literal{indices_shape, indices}); + auto gather = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + 
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {24}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_same_indices) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 1, 1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4, 5}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), data); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 1, 20}}}), unsq); + auto mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 3, 20}}}), rsp1); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), mb); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 3, 4, 5}}}), slc); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), rsp2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_same_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 1, 
1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + auto unsqueeze = + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), gather); + m1.add_return({unsqueeze}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m2.add_parameter("data", s); + auto bc = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 0}, {"out_lens", {12, 3}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), bc); + m2.add_return({slc}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_sequential_indices) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {5, 6}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 2, 3}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {5, 6}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {30}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {6}}, {"ends", {24}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 6}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_sequential_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 2, 3}}); + auto gather = 
m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), data); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 30}}}), unsq); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), mb); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {31}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {10, 3}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), rsp2); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 5, 9}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {24}}}), data); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 4}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 1}}, {"ends", {1, 2}}}), + rsp); + auto sq = 
m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 5, 10}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 5}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 0}}, {"ends", {1, 1}}}), + rsp); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_indices_window_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {5, 10, 15}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {30}}; + auto data = m2.add_parameter("data", s); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), data); + auto mb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 30}}}), unsq); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), mb); + auto 
slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {5}}, {"ends", {35}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 3, 5}}}), slc1); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 2}}, {"starts", {0, 0}}, {"ends", {1, 1}}}), + rsp2); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slc2); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_constant_stride_divisible_both_indices_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {15}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {3}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 5, 10}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {15}}; + auto data = m2.add_parameter("data", s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 5}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), rsp); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), slc); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_sequential_stride_rtr_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {8}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {8}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 4, 1, 5, 2, 6, 3, 7}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, 
{8}}; + auto data = m2.add_parameter("data", s); + auto reshape1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4}}}), data); + auto transpose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), reshape1); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), transpose); + m2.add_return({reshape2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_sequential_stride_rtr_window_1d) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {8}}; + auto indices = m1.add_literal(migraphx::literal{si, {1, 4, 7, 10, 2, 5, 8, 11}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {12}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3}}}), data); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), rsp1); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis0_half_split_concat) +{ + // This pattern is not optimized - gather remains + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {4, 3}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {2, 3, 0, 1}; + auto li = m1.add_literal(migraphx::literal(si, indices)); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, li); + m1.add_return({g}); + } + auto m2 = m1; + run_pass(m1); + + // Verify 
output shape is correct: {4, 3} + auto result = + std::find_if(m1.begin(), m1.end(), [](const auto& ins) { return ins.name() == "@return"; }); + EXPECT(result != m1.end()); + EXPECT(result->inputs().front()->get_shape().lens() == std::vector{4, 3}); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_axis1_same_stride_diff_base) +{ + // This pattern is not optimized - gather remains + migraphx::module m1; + { + migraphx::shape si{migraphx::shape::int32_type, {2, 2}}; + std::vector indices = {1, 1, 0, 2}; + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {3, 3}}); + auto tx = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x); + auto ind = m1.add_literal(migraphx::literal{si, indices}); + auto tind = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), ind); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), tx, tind); + m1.add_return({g}); + } + auto m2 = m1; + // Verify there is no hang + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_stride_slice) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {8}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {1, 5, 2, 6}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {8}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == 
m2.sort()); +} + +TEST_CASE(gather_flatten_stride_first) +{ + migraphx::module m; + auto x = m.add_parameter("X", {migraphx::shape::float_type, {8}}); + migraphx::shape si{migraphx::shape::int32_type, {4}}; + std::vector indices = {0, 2, 4, 6}; + auto li = m.add_literal(migraphx::literal{si, indices}); + auto g = m.add_instruction(migraphx::make_op("gather"), x, li); + m.add_return({g}); + + run_pass(m); + + migraphx::module expected; + auto xe = expected.add_parameter("X", {migraphx::shape::float_type, {8}}); + auto reshape_block = + expected.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 2}}}), xe); + auto slice = expected.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), reshape_block); + auto result = expected.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), slice); + expected.add_return({result}); + + EXPECT(m == expected); +} + +TEST_CASE(gather_flatten_stride_offset) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {16}}); + migraphx::shape si{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {1, 5, 9, 13}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {16}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 4}}}), x); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), rsp); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), slc); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), unsq); + m2.add_return({sq}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_flatten_stride_grid) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", 
{migraphx::shape::float_type, {768}}); + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector indices = {17, 21, 25, 29, 81, 85, 89, 93, 145, 149, 153, 157, + 209, 213, 217, 221, 273, 277, 281, 285, 337, 341, 345, 349, + 401, 405, 409, 413, 465, 469, 473, 477, 529, 533, 537, 541, + 593, 597, 601, 605, 657, 661, 665, 669, 721, 725, 729, 733}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {768}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 16, 4}}}), x); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {4, 1}}, {"ends", {8, 2}}}), + rsp); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), slc); + m2.add_return({rsp2}); + } - auto m = create_resize_module(); - run_pass(m); - EXPECT(m == create_resize_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(optimize_resize_ind_non_const) +TEST_CASE(gather_flatten_permutation) { - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; - migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}}; - auto create_resize_module = [&] { - migraphx::module m; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + migraphx::module m; + auto x = m.add_parameter("X", {migraphx::shape::float_type, {16}}); + migraphx::shape si{migraphx::shape::int32_type, {16}}; + std::vector indices = {0, 2, 8, 10, 4, 6, 12, 14, 1, 3, 9, 11, 5, 7, 13, 15}; + auto li = m.add_literal(migraphx::literal{si, indices}); + auto g = m.add_instruction(migraphx::make_op("gather"), x, li); + m.add_return({g}); - migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}}; - auto li = m.add_parameter("ind", si); - auto lrsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", {6}}}), 
inx); - auto gr = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); - auto r = m.add_instruction(migraphx::make_op("sub"), iny, gr); - m.add_return({r}); + run_pass(m); - return m; - }; + migraphx::module expected; + auto xe = expected.add_parameter("X", {migraphx::shape::float_type, {16}}); + auto reshape_perm = + expected.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2}}}), xe); + auto transpose = expected.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), reshape_perm); + auto reshape_out = + expected.add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), transpose); + expected.add_return({reshape_out}); - auto m = create_resize_module(); - run_pass(m); - EXPECT(m == create_resize_module()); + expected.debug_print(); + + EXPECT(m == expected); } -TEST_CASE(optimize_where_true) +TEST_CASE(gather_flatten_channel_patch) { - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto create_where_module = [&](bool cond) { - migraphx::module m; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); - - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata(si.elements(), static_cast(cond)); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + migraphx::shape si{migraphx::shape::int32_type, {12}}; + std::vector indices = {5, 21, 37, 9, 25, 41, 6, 22, 38, 10, 26, 42}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } + run_pass(m1); 
- auto return_xy = [&](bool cond) { - migraphx::module m; - auto x = m.add_parameter("X", s); - auto y = m.add_parameter("Y", s); - cond ? m.add_return({x}) : m.add_return({y}); - return m; - }; + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 4}}}), x); + auto tr = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), rsp); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {3, 3}}}), + tr); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), slc); + m2.add_return({rsp2}); + } - auto m = create_where_module(true); - run_pass(m); - EXPECT(m == return_xy(true)); + EXPECT(m1.sort() == m2.sort()); +} - auto m1 = create_where_module(false); +TEST_CASE(gather_flatten_channel_parity_permutation) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("X", {migraphx::shape::float_type, {48}}); + migraphx::shape si{migraphx::shape::int32_type, {48}}; + std::vector indices = {0, 2, 8, 10, 16, 18, 24, 26, 32, 34, 40, 42, + 4, 6, 12, 14, 20, 22, 28, 30, 36, 38, 44, 46, + 1, 3, 9, 11, 17, 19, 25, 27, 33, 35, 41, 43, + 5, 7, 13, 15, 21, 23, 29, 31, 37, 39, 45, 47}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather"), x, li); + m1.add_return({g}); + } run_pass(m1); - EXPECT(m1 == return_xy(false)); + + migraphx::module m2; + { + auto x = m2.add_parameter("X", {migraphx::shape::float_type, {48}}); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {6, 2, 2, 2}}}), x); + auto tr = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {3, 1, 0, 2}}}), rsp); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), tr); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); } 
-TEST_CASE(where_different_cond_values) +TEST_CASE(gather_axis1_factorized_grid_const) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); + migraphx::module m1; + { + auto data = m1.add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 2, 1}}; + std::vector indices = {1, 3, 5, 7}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + m1.add_return({g}); + } + run_pass(m1); - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata = {1, 1, 0, 1, 0, 1}; - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m2; + { + auto data = m2.add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + auto rsp1 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{3, 4, 2, 5}}}), data); + auto rsp2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{12, 10}}}), rsp1); + auto slc = m2.add_instruction(migraphx::make_op("slice", + {{"axes", std::vector{1}}, + {"starts", std::vector{5}}, + {"ends", std::vector{10}}}), + rsp2); + auto rsp3 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{3, 2, 2, 1, 5}}}), slc); + m2.add_return({rsp3}); + } - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(where_axis_nonzero) +TEST_CASE(gather_axis1_factorized_grid_multi_const) { - auto 
create_where_module = [] { - migraphx::module m; - migraphx::shape s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); + migraphx::module m1; + { + auto data = m1.add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {3, 1}}; + std::vector indices = {5, 14, 23}; + auto li = m1.add_literal(migraphx::literal{si, indices}); + auto g = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + m1.add_return({g}); + } + run_pass(m1); - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m2; + { + auto data = m2.add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + auto rsp1 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{2, 3, 9, 4}}}), data); + auto rsp2 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{6, 36}}}), rsp1); + auto slc = m2.add_instruction(migraphx::make_op("slice", + {{"axes", std::vector{1}}, + {"starts", std::vector{20}}, + {"ends", std::vector{24}}}), + rsp2); + auto rsp3 = m2.add_instruction( + migraphx::make_op("reshape", {{"dims", std::vector{2, 3, 1, 4}}}), slc); + m2.add_return({rsp3}); + } - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(where_three_concat_inputs) +TEST_CASE_SKIP(gather_constant_scalar_index, "Scalar indices are not supported yet") { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape 
s{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", s); - auto iny = m.add_parameter("Y", s); + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type}; + auto indices = m1.add_literal(migraphx::literal{si, {2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny, inx); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m2.add_parameter("data", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), data); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); + m2.add_return({squeeze}); + } - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(where_three_inputs_diff_shapes) +TEST_CASE(gather_constant_negative_index) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; - migraphx::shape sy{migraphx::shape::float_type, {2, 1, 3, 2}}; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m1.add_parameter("data", s); + migraphx::shape 
si{migraphx::shape::int32_type, {1}}; + auto indices = m1.add_literal(migraphx::literal{si, {-1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 3, 2}}; - std::vector idata(6, 1); - auto li = m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {18}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto data = m2.add_parameter("data", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), data); + m2.add_return({slice}); + } - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(where_three_lens_diff) +TEST_CASE(gather_non_constant_indices) { - auto create_where_module = [] { - migraphx::module m; - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 3, 2}}; - migraphx::shape sy{migraphx::shape::float_type, {1, 1, 3, 2}}; - auto inx = m.add_parameter("X", sx); - auto iny = m.add_parameter("Y", sy); + // Should not be transformed + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 4}}; + auto si = migraphx::shape{migraphx::shape::int32_type, {2}}; + auto data = m1.add_parameter("data", s); + auto indices = m1.add_parameter("indices", si); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data, indices); + m1.add_return({gather}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} - migraphx::shape si{migraphx::shape::bool_type, {1, 1, 6}}; - std::vector idata(6, 1); - auto li = 
m.add_literal(migraphx::literal(si, idata)); - auto data = m.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), inx, iny); - auto data_1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); - auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), data_1, li); - m.add_return({r}); - return m; - }; +TEST_CASE(gather_axis_1) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 5, 3}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {2}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 1}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + run_pass(m1); - auto m = create_where_module(); - run_pass(m); - EXPECT(m == create_where_module()); + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {2, 5, 3}}; + auto data = m2.add_parameter("data", s); + auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 15}}}), data); + auto slc = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), rsp1); + auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 3}}}), slc); + m2.add_return({rsp2}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(gather_onnx_axis_one_ex) +{ + migraphx::module m1; + { + auto s = migraphx::shape{migraphx::shape::float_type, {3, 3}}; + auto data = m1.add_parameter("data", s); + migraphx::shape si{migraphx::shape::int32_type, {2, 1}}; + auto indices = m1.add_literal(migraphx::literal{si, {0, 2}}); + auto gather = m1.add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, indices); + m1.add_return({gather}); + } + migraphx::module m2 = m1; + run_pass(m1); + + EXPECT(m1.sort() == m2.sort()); } TEST_CASE(reshape_cont) @@ -3474,6 +4466,63 @@ TEST_CASE(add_transpose) EXPECT(m1.sort() == m2.sort()); } 
+TEST_CASE(slice_squeeze_unsqueeze) +{ + migraphx::shape s{migraphx::shape::float_type, {12, 6, 1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto slice = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), slice); + auto unsqueeze = + m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), squeeze); + m1.add_return({unsqueeze}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), x); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {3}}}), unsqueeze); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), squeeze); + m2.add_return({slice}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_reshape_reshape) +{ + migraphx::shape s{migraphx::shape::float_type, {3, 3, 20}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto slice = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto reshape1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {60}}}), slice); + auto reshape2 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 5}}}), reshape1); + m1.add_return({reshape2}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto slice = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), x); + auto reshape = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 3, 4, 5}}}), slice); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), reshape); + m2.add_return({squeeze}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(flatten) { migraphx::shape 
s{migraphx::shape::float_type, {4608, 8, 2}}; diff --git a/test/verify/test_gather_axis0_half_split_concat.cpp b/test/verify/test_gather_axis0_half_split_concat.cpp new file mode 100644 index 00000000000..351241f034a --- /dev/null +++ b/test/verify/test_gather_axis0_half_split_concat.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis0_half_split_concat : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {4, 3}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4}}; + std::vector indices = {2, 3, 0, 1}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction( + migraphx::make_op("gather", {{"axis", int64_t{0}}}), data, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis0_slice_broadcast.cpp b/test/verify/test_gather_axis0_slice_broadcast.cpp new file mode 100644 index 00000000000..16706609346 --- /dev/null +++ b/test/verify/test_gather_axis0_slice_broadcast.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis0_slice_broadcast : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 4}}); + auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 8}}; + std::vector indices = {0, 0, 0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3, 3, 3}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), reshape, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis1_factorized_grid_const.cpp b/test/verify/test_gather_axis1_factorized_grid_const.cpp new file mode 100644 index 00000000000..de41e37caaa --- /dev/null +++ b/test/verify/test_gather_axis1_factorized_grid_const.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis1_factorized_grid_const + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {3, 8, 5}}); + migraphx::shape si{migraphx::shape::int32_type, {2, 2, 1}}; + std::vector indices = {1, 3, 5, 7}; + auto li = mm->add_literal(migraphx::literal{si, indices}); + auto g = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + mm->add_return({g}); + + return p; + } +}; diff --git a/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp b/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp new file mode 100644 index 00000000000..1e9a967b623 --- /dev/null +++ b/test/verify/test_gather_axis1_factorized_grid_multi_const.cpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_axis1_factorized_grid_multi_const + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("data", {migraphx::shape::float_type, {2, 27, 4}}); + migraphx::shape si{migraphx::shape::int32_type, {3, 1}}; + std::vector indices = {5, 14, 23}; + auto li = mm->add_literal(migraphx::literal{si, indices}); + auto g = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, li); + mm->add_return({g}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_channel_parity.cpp b/test/verify/test_gather_flatten_channel_parity.cpp new file mode 100644 index 00000000000..44fbf442084 --- /dev/null +++ b/test/verify/test_gather_flatten_channel_parity.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_channel_parity : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 3, 2, 2}}; + std::vector indices = {0, 2, 8, 10, 16, 18, 24, 26, 32, 34, 40, 42, + 4, 6, 12, 14, 20, 22, 28, 30, 36, 38, 44, 46, + 1, 3, 9, 11, 17, 19, 25, 27, 33, 35, 41, 43, + 5, 7, 13, 15, 21, 23, 29, 31, 37, 39, 45, 47}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_channel_patch.cpp b/test/verify/test_gather_flatten_channel_patch.cpp new file mode 100644 index 00000000000..bc4dd1b7600 --- /dev/null +++ b/test/verify/test_gather_flatten_channel_patch.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_channel_patch : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 3, 1, 1}}; + std::vector indices = {5, 21, 37, 9, 25, 41, 6, 22, 38, 10, 26, 42}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_multi_axis_stride.cpp b/test/verify/test_gather_flatten_multi_axis_stride.cpp new file mode 100644 index 00000000000..3669121828e --- /dev/null +++ b/test/verify/test_gather_flatten_multi_axis_stride.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_multi_axis_stride : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {48}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3, 1, 4}}; + std::vector indices = {0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34, 35, + 4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_permutation.cpp b/test/verify/test_gather_flatten_permutation.cpp new file mode 100644 index 00000000000..15862adb759 --- /dev/null +++ b/test/verify/test_gather_flatten_permutation.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_permutation : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 1, 4, 4}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {4, 1, 2, 2}}; + std::vector indices = {0, 2, 8, 10, 4, 6, 12, 14, 1, 3, 9, 11, 5, 7, 13, 15}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_rectangular_three_axes.cpp b/test/verify/test_gather_flatten_rectangular_three_axes.cpp new file mode 100644 index 00000000000..3cf567420c4 --- /dev/null +++ b/test/verify/test_gather_flatten_rectangular_three_axes.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_rectangular_three_axes + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {2, 24, 5}}); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 2, 3}}; + std::vector indices = {4, 5, 6, 8, 9, 10, 16, 17, 18, 20, 21, 22}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction( + migraphx::make_op("gather", {{"axis", int64_t{1}}}), data, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_rectangular_two_axes.cpp b/test/verify/test_gather_flatten_rectangular_two_axes.cpp new file mode 100644 index 00000000000..b4a14f53e8a --- /dev/null +++ b/test/verify/test_gather_flatten_rectangular_two_axes.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_rectangular_two_axes + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 12}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {4, 5, 6, 8, 9, 10}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_first.cpp b/test/verify/test_gather_flatten_stride_first.cpp new file mode 100644 index 00000000000..07058073374 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_first.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_stride_first : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 8}}); + auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {0, 2, 4, 6}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), reshape, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_grid.cpp b/test/verify/test_gather_flatten_stride_grid.cpp new file mode 100644 index 00000000000..d6a61137597 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_grid.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights + * reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_stride_grid : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 3, 16, 16}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {768}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 3, 4, 4}}; + std::vector indices = {17, 21, 25, 29, 81, 85, 89, 93, 145, 149, 153, 157, + 209, 213, 217, 221, 273, 277, 281, 285, 337, 341, 345, 349, + 401, 405, 409, 413, 465, 469, 473, 477, 529, 533, 537, 541, + 593, 597, 601, 605, 657, 661, 665, 669, 721, 725, 729, 733}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_offset.cpp b/test/verify/test_gather_flatten_stride_offset.cpp new file mode 100644 index 00000000000..05f87955ea9 --- /dev/null +++ b/test/verify/test_gather_flatten_stride_offset.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights + * reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_stride_offset : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto data = mm->add_parameter("X", {migraphx::shape::float_type, {1, 16}}); + auto flatten = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {16}}}), data); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {1, 4}}; + std::vector indices = {1, 5, 9, 13}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = mm->add_instruction(migraphx::make_op("gather"), flatten, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_flatten_stride_slice.cpp b/test/verify/test_gather_flatten_stride_slice.cpp new file mode 100644 index 00000000000..62148bb163f --- /dev/null +++ b/test/verify/test_gather_flatten_stride_slice.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_flatten_stride_slice : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("X", {migraphx::shape::float_type, {1, 8}}); + auto reshape_flat = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); + + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 2}}; + std::vector indices = {1, 5, 2, 6}; + auto indices_literal = mm->add_literal(migraphx::literal{indices_shape, indices}); + + auto gather = + mm->add_instruction(migraphx::make_op("gather"), reshape_flat, indices_literal); + mm->add_return({gather}); + + return p; + } +}; diff --git a/test/verify/test_gather_simplify.cpp b/test/verify/test_gather_simplify.cpp new file mode 100644 index 00000000000..cf5a3d62dcd --- /dev/null +++ b/test/verify/test_gather_simplify.cpp @@ -0,0 +1,46 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_gather_simplify : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape data_shape{migraphx::shape::float_type, {2, 4}}; + migraphx::shape indices_shape{migraphx::shape::int32_type, {2, 3}}; + std::vector indices = {1, 1, 1, 2, 2, 2}; + auto data = mm->add_parameter("data", data_shape); + auto idx_lit = mm->add_literal(migraphx::literal{indices_shape, indices}); + auto gather = + mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), data, idx_lit); + mm->add_return({gather}); + return p; + } +};