From b49aaa61cecc416c4eefcbd1c31be1ccfd421b6d Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Fri, 9 Jan 2026 00:00:43 -0600 Subject: [PATCH 1/4] try to convert int64 to int32 or don't eliminate it with flag --- src/eliminate_data_type.cpp | 12 ++++- src/targets/gpu/target.cpp | 12 +++++ test/eliminate_data_type_test.cpp | 76 ++++++++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index 067f730db99..72b404626eb 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -112,6 +112,16 @@ void eliminate_data_type::apply(module& m) const if(unsupported_types.empty()) return; + // Warn when converting int64 to int32 as this may cause overflow for large indices + if(contains(unsupported_types, shape::type_t::int64_type) and + target_type == shape::type_t::int32_type) + { + std::cerr << "Warning: Converting int64 to int32. Values exceeding int32 range " + "(>2147483647) may overflow and cause incorrect indexing. " + "Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to preserve int64 precision." + << std::endl; + } + for(auto ins : iterator_for(m)) { if(ins->name()[0] == '@') diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 891d80631bc..ab77521e946 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -90,6 +90,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_ELIMINATE_INT64) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -107,6 +108,13 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); + // int64_type handling: If MIGRAPHX_DISABLE_ELIMINATE_INT64=1, support int64 natively. + // Otherwise, it will be converted to int32_type via a separate pass below + // to preserve integer semantics for operations like mod, gather, slice. + if(enabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{})) + { + unsupported_types.erase(shape::type_t::int64_type); + } unsupported_types.erase(shape::type_t::tuple_type); // No BF-16 Support on Navi21 @@ -197,6 +205,10 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, + // Convert int64 to int32 to preserve integer semantics (mod, gather, slice need exact integers) + // This must run before the general float elimination to avoid int64 -> float conversion + // Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to keep native int64 support + enable_pass(disabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{}), eliminate_data_type{{shape::type_t::int64_type}, shape::type_t::int32_type}), eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index d93b37ebd7b..b868f1f2472 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,6 +39,15 @@ static void run_pass(migraphx::module& m, std::set type migraphx::dead_code_elimination{}}); } +static void run_pass_to_int32(migraphx::module& m, std::set types) +{ + migraphx::run_passes( + m, + {migraphx::eliminate_data_type{std::move(types), migraphx::shape::int32_type}, + migraphx::eliminate_identity{}, + migraphx::dead_code_elimination{}}); +} + TEST_CASE(simple) { migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; @@ -91,4 +100,69 @@ TEST_CASE(quant) EXPECT(mm1 == mm2); } +// Test int64 to int32 conversion for mod operation +// This ensures integer semantics are preserved (not converted to float) +TEST_CASE(int64_to_int32_mod) +{ + migraphx::shape s{migraphx::shape::int64_type, {1, 96}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + mm1.add_instruction(migraphx::make_op("mod"), x, y); + } + run_pass_to_int32(mm1, {migraphx::shape::int64_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto int32x = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); + auto int32y = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); + auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), int32x, int32y); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mod_result); + } + EXPECT(mm1 == mm2); +} + +// Test int64 to int32 conversion for mul then mod (common pattern in gather index computation) +TEST_CASE(int64_to_int32_mul_mod) +{ + migraphx::shape s{migraphx::shape::int64_type, {1, 96}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + auto z = mm1.add_parameter("z", s); + auto mul = mm1.add_instruction(migraphx::make_op("mul"), x, y); + mm1.add_instruction(migraphx::make_op("mod"), mul, z); + } + run_pass_to_int32(mm1, {migraphx::shape::int64_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto z = mm2.add_parameter("z", s); + auto int32x = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); + auto int32y = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); + auto int32z = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z); + auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y); + auto mul_back = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mul); + auto mul_to_int32 = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mul_back); + auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), mul_to_int32, int32z); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mod_result); + } + EXPECT(mm1 == mm2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From e9f0f0e955d387b7045b0282f68307fb14cb0e74 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Fri, 9 Jan 2026 00:02:04 -0600 Subject: [PATCH 2/4] formatting --- test/eliminate_data_type_test.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index b868f1f2472..e6971ac61f1 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -115,15 +115,16 @@ TEST_CASE(int64_to_int32_mod) migraphx::module mm2; { - auto x = mm2.add_parameter("x", s); - auto y = mm2.add_parameter("y", s); + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); auto int32x = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); auto int32y = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), int32x, int32y); mm2.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mod_result); + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), + mod_result); } EXPECT(mm1 == mm2); } @@ -134,9 +135,9 @@ TEST_CASE(int64_to_int32_mul_mod) migraphx::shape s{migraphx::shape::int64_type, {1, 96}}; migraphx::module mm1; { - auto x = mm1.add_parameter("x", s); - auto y = mm1.add_parameter("y", s); - auto z = mm1.add_parameter("z", s); + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + auto z = mm1.add_parameter("z", s); auto mul = mm1.add_instruction(migraphx::make_op("mul"), x, y); mm1.add_instruction(migraphx::make_op("mod"), mul, z); } @@ -153,14 +154,15 @@ TEST_CASE(int64_to_int32_mul_mod) migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); auto int32z = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z); - auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y); + auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y); auto mul_back = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mul); auto mul_to_int32 = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mul_back); auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), mul_to_int32, int32z); mm2.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mod_result); + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), + mod_result); } EXPECT(mm1 == mm2); } From cb8ac2e305d662aa42de175e400b31aca23d2b59 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Fri, 9 Jan 2026 00:05:04 -0600 Subject: [PATCH 3/4] fix test --- test/eliminate_data_type_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index e6971ac61f1..1a15a1af30f 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -152,13 +152,13 @@ TEST_CASE(int64_to_int32_mul_mod) migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); auto int32y = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); - auto int32z = mm2.add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z); auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y); auto mul_back = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mul); auto mul_to_int32 = mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mul_back); + auto int32z = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z); auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), mul_to_int32, int32z); mm2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), From cdbcc6825657d7943e15fa3bf673d95243812260 Mon Sep 17 00:00:00 2001 From: kahmed10 <15948690+kahmed10@users.noreply.github.com> Date: Fri, 9 Jan 2026 00:13:22 -0600 Subject: [PATCH 4/4] update doc --- docs/reference/MIGraphX-dev-env-vars.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 5524ee45cf5..11cf4d40f5d 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -432,6 +432,14 @@ Debug settings for passes. | Default: Multi-output pointwise fusion is enabled. + * - | ``MIGRAPHX_DISABLE_ELIMINATE_INT64`` + | When set, int64 types are preserved instead of being converted to int32. + + - | ``1``: int64 types are preserved natively. + | ``0``: Returns to default behavior. + + | Default: int64 is converted to int32. + * - | ``MIGRAPHX_TRACE_PASSES`` | Turns on printing of the compile passes and the program after the passes.