Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/reference/MIGraphX-dev-env-vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,14 @@ Debug settings for passes.

| Default: Multi-output pointwise fusion is enabled.

* - | ``MIGRAPHX_DISABLE_ELIMINATE_INT64``
| When set, int64 types are preserved instead of being converted to int32.

- | ``1``: int64 types are preserved natively.
| ``0``: Returns to default behavior.

| Default: int64 is converted to int32.

* - | ``MIGRAPHX_TRACE_PASSES``
| Turns on printing of the compile passes and the program after the passes.

Expand Down
12 changes: 11 additions & 1 deletion src/eliminate_data_type.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -112,6 +112,16 @@ void eliminate_data_type::apply(module& m) const
if(unsupported_types.empty())
return;

// Warn when converting int64 to int32 as this may cause overflow for large indices
if(contains(unsupported_types, shape::type_t::int64_type) and
target_type == shape::type_t::int32_type)
{
std::cerr << "Warning: Converting int64 to int32. Values exceeding int32 range "
"(>2147483647) may overflow and cause incorrect indexing. "
"Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to preserve int64 precision."
<< std::endl;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we should have a warning for this, especially here in the pass as the MIGRAPHX_DISABLE_ELIMINATE_INT64 variable doesnt change the behavior of the class. It would be better to have a more general env variable to skip type elimination in the pass with MIGRAPHX_DISABLE_ELIMINATE_TYPES=int64_t,uint64_t.


for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
Expand Down
12 changes: 12 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_ELIMINATE_INT64)

std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{
Expand All @@ -107,6 +108,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::int8_type);
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
// int64_type handling: If MIGRAPHX_DISABLE_ELIMINATE_INT64=1, support int64 natively.
// Otherwise, it will be converted to int32_type via a separate pass below
// to preserve integer semantics for operations like mod, gather, slice.
if(enabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{}))
{
unsupported_types.erase(shape::type_t::int64_type);
}
unsupported_types.erase(shape::type_t::tuple_type);

// No BF-16 Support on Navi21
Expand Down Expand Up @@ -197,6 +205,10 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
// workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling
eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}},
// Convert int64 to int32 to preserve integer semantics (mod, gather, slice need exact integers)
// This must run before the general float elimination to avoid int64 -> float conversion
// Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to keep native int64 support
enable_pass(disabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{}), eliminate_data_type{{shape::type_t::int64_type}, shape::type_t::int32_type}),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to have a one pass to convert float types to fp32 and another to convert integer types to int32.

Also think all these eliminate_data_type pass should be moved to a meta pass that calls them all instead of putting al this logic into get_passes.

eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
Expand Down
78 changes: 77 additions & 1 deletion test/eliminate_data_type_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -39,6 +39,15 @@ static void run_pass(migraphx::module& m, std::set<migraphx::shape::type_t> type
migraphx::dead_code_elimination{}});
}

static void run_pass_to_int32(migraphx::module& m, std::set<migraphx::shape::type_t> types)
{
migraphx::run_passes(
m,
{migraphx::eliminate_data_type{std::move(types), migraphx::shape::int32_type},
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{}});
}

TEST_CASE(simple)
{
migraphx::shape s{migraphx::shape::int8_type, {2, 2}};
Expand Down Expand Up @@ -91,4 +100,71 @@ TEST_CASE(quant)
EXPECT(mm1 == mm2);
}

// Test int64 to int32 conversion for mod operation
// This ensures integer semantics are preserved (not converted to float)
TEST_CASE(int64_to_int32_mod)
{
migraphx::shape s{migraphx::shape::int64_type, {1, 96}};
migraphx::module mm1;
{
auto x = mm1.add_parameter("x", s);
auto y = mm1.add_parameter("y", s);
mm1.add_instruction(migraphx::make_op("mod"), x, y);
}
run_pass_to_int32(mm1, {migraphx::shape::int64_type});

migraphx::module mm2;
{
auto x = mm2.add_parameter("x", s);
auto y = mm2.add_parameter("y", s);
auto int32x = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x);
auto int32y = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y);
auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), int32x, int32y);
mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}),
mod_result);
}
EXPECT(mm1 == mm2);
}

// Test int64 to int32 conversion for mul then mod (common pattern in gather index computation)
TEST_CASE(int64_to_int32_mul_mod)
{
migraphx::shape s{migraphx::shape::int64_type, {1, 96}};
migraphx::module mm1;
{
auto x = mm1.add_parameter("x", s);
auto y = mm1.add_parameter("y", s);
auto z = mm1.add_parameter("z", s);
auto mul = mm1.add_instruction(migraphx::make_op("mul"), x, y);
mm1.add_instruction(migraphx::make_op("mod"), mul, z);
}
run_pass_to_int32(mm1, {migraphx::shape::int64_type});

migraphx::module mm2;
{
auto x = mm2.add_parameter("x", s);
auto y = mm2.add_parameter("y", s);
auto z = mm2.add_parameter("z", s);
auto int32x = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x);
auto int32y = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y);
auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y);
auto mul_back = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mul);
auto mul_to_int32 = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mul_back);
auto int32z = mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z);
auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), mul_to_int32, int32z);
mm2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}),
mod_result);
}
EXPECT(mm1 == mm2);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading