diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 5524ee45cf5..11cf4d40f5d 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -432,6 +432,14 @@ Debug settings for passes. | Default: Multi-output pointwise fusion is enabled. + * - | ``MIGRAPHX_DISABLE_ELIMINATE_INT64`` + | When set, int64 types are preserved instead of being converted to int32. + + - | ``1``: int64 types are preserved natively. + | ``0``: Returns to default behavior. + + | Default: int64 is converted to int32. + * - | ``MIGRAPHX_TRACE_PASSES`` | Turns on printing of the compile passes and the program after the passes. diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index 067f730db99..72b404626eb 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -112,6 +112,16 @@ void eliminate_data_type::apply(module& m) const if(unsupported_types.empty()) return; + // Warn when converting int64 to int32 as this may cause overflow for large indices + if(contains(unsupported_types, shape::type_t::int64_type) and + target_type == shape::type_t::int32_type) + { + std::cerr << "Warning: Converting int64 to int32. Values exceeding int32 range " + "(>2147483647) may overflow and cause incorrect indexing. " + "Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to preserve int64 precision." + << std::endl; + } + for(auto ins : iterator_for(m)) { if(ins->name()[0] == '@') diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 891d80631bc..ab77521e946 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -90,6 +90,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_ELIMINATE_INT64) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -107,6 +108,13 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); + // int64_type handling: If MIGRAPHX_DISABLE_ELIMINATE_INT64=1, support int64 natively. + // Otherwise, it will be converted to int32_type via a separate pass below + // to preserve integer semantics for operations like mod, gather, slice. + if(enabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{})) + { + unsupported_types.erase(shape::type_t::int64_type); + } unsupported_types.erase(shape::type_t::tuple_type); // No BF-16 Support on Navi21 @@ -197,6 +205,10 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, // workaround for rocBLAS unsupported error when using uint8 in quant_dot, quant_convolution & pooling eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_convolution", "quant_dot", "pooling"}}, + // Convert int64 to int32 to preserve integer semantics (mod, gather, slice need exact integers) + // This must run before the general float elimination to avoid int64 -> float conversion + // Set MIGRAPHX_DISABLE_ELIMINATE_INT64=1 to keep native int64 support + enable_pass(disabled(MIGRAPHX_DISABLE_ELIMINATE_INT64{}), eliminate_data_type{{shape::type_t::int64_type}, shape::type_t::int32_type}), eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, diff --git a/test/eliminate_data_type_test.cpp b/test/eliminate_data_type_test.cpp index d93b37ebd7b..1a15a1af30f 100644 --- a/test/eliminate_data_type_test.cpp +++ b/test/eliminate_data_type_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,6 +39,15 @@ static void run_pass(migraphx::module& m, std::set type migraphx::dead_code_elimination{}}); } +static void run_pass_to_int32(migraphx::module& m, std::set types) +{ + migraphx::run_passes( + m, + {migraphx::eliminate_data_type{std::move(types), migraphx::shape::int32_type}, + migraphx::eliminate_identity{}, + migraphx::dead_code_elimination{}}); +} + TEST_CASE(simple) { migraphx::shape s{migraphx::shape::int8_type, {2, 2}}; @@ -91,4 +100,71 @@ TEST_CASE(quant) EXPECT(mm1 == mm2); } +// Test int64 to int32 conversion for mod operation +// This ensures integer semantics are preserved (not converted to float) +TEST_CASE(int64_to_int32_mod) +{ + migraphx::shape s{migraphx::shape::int64_type, {1, 96}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + mm1.add_instruction(migraphx::make_op("mod"), x, y); + } + run_pass_to_int32(mm1, {migraphx::shape::int64_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto int32x = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); + auto int32y = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); + auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), int32x, int32y); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), + mod_result); + } + EXPECT(mm1 == mm2); +} + +// Test int64 to int32 conversion for mul then mod (common pattern in gather index computation) +TEST_CASE(int64_to_int32_mul_mod) +{ + migraphx::shape s{migraphx::shape::int64_type, {1, 96}}; + migraphx::module mm1; + { + auto x = mm1.add_parameter("x", s); + auto y = mm1.add_parameter("y", s); + auto z = mm1.add_parameter("z", s); + auto mul = mm1.add_instruction(migraphx::make_op("mul"), x, y); + mm1.add_instruction(migraphx::make_op("mod"), mul, z); + } + run_pass_to_int32(mm1, {migraphx::shape::int64_type}); + + migraphx::module mm2; + { + auto x = mm2.add_parameter("x", s); + auto y = mm2.add_parameter("y", s); + auto z = mm2.add_parameter("z", s); + auto int32x = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), x); + auto int32y = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), y); + auto mul = mm2.add_instruction(migraphx::make_op("mul"), int32x, int32y); + auto mul_back = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), mul); + auto mul_to_int32 = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), mul_back); + auto int32z = mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int32_type}}), z); + auto mod_result = mm2.add_instruction(migraphx::make_op("mod"), mul_to_int32, int32z); + mm2.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int64_type}}), + mod_result); + } + EXPECT(mm1 == mm2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }