diff --git a/src/onnx/broadcast_qdq.cpp b/src/onnx/broadcast_qdq.cpp index dbe6878a4cd..9469894d200 100644 --- a/src/onnx/broadcast_qdq.cpp +++ b/src/onnx/broadcast_qdq.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -61,6 +61,42 @@ instruction_ref bcast_qdq_instr(const std::string& op_name, return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt); } +instruction_ref bcast_qdq_instr_matmul(const std::string& op_name, + instruction_ref x_in, + instruction_ref arg_fscale, + instruction_ref arg_z_pt, + const onnx_parser::node_info& info) +{ + auto in_lens = x_in->get_shape().lens(); + + // prep 1: broadcast scale. it can come as a scalar or a 1-D tensor. + instruction_ref bcast_scale; + if(arg_fscale->get_shape().elements() > 1) + { + auto axis = x_in->get_shape().lens().size() - arg_fscale->get_shape().lens().size(); + bcast_scale = info.add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", in_lens}}), arg_fscale); + } + else + bcast_scale = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_fscale); + + // prep 2: broadcast zero point. it can come as a scalar or a 1-D tensor. + instruction_ref bcast_zero_pt; + if(arg_z_pt->get_shape().elements() > 1) + { + auto axis = x_in->get_shape().lens().size() - arg_z_pt->get_shape().lens().size(); + bcast_zero_pt = info.add_instruction( + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", in_lens}}), arg_z_pt); + } + else + bcast_zero_pt = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_z_pt); + + // op_name is either quantizelinear or dequantizelinear: + return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt); +} + // Multibroadcast a scaler.. instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out, instruction_ref arg_in, diff --git a/src/onnx/include/migraphx/onnx/broadcast_qdq.hpp b/src/onnx/include/migraphx/onnx/broadcast_qdq.hpp index 04432b01d86..b1362b6ad4b 100644 --- a/src/onnx/include/migraphx/onnx/broadcast_qdq.hpp +++ b/src/onnx/include/migraphx/onnx/broadcast_qdq.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -44,6 +44,12 @@ instruction_ref bcast_qdq_instr(const std::string& op_name, instruction_ref arg_z_pt, const onnx_parser::node_info& info); +instruction_ref bcast_qdq_instr_matmul(const std::string& op_name, + instruction_ref x_in, + instruction_ref arg_fscale, + instruction_ref arg_z_pt, + const onnx_parser::node_info& info); + // Multibroadcast a scaler.. instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out, instruction_ref arg_in, diff --git a/src/onnx/parse_qlinearconv.cpp b/src/onnx/parse_qlinearconv.cpp index 26f2f7b9125..50a06a55d2c 100644 --- a/src/onnx/parse_qlinearconv.cpp +++ b/src/onnx/parse_qlinearconv.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -229,7 +229,25 @@ struct parse_qlinearconv : op_parser // Biases, if any.. : is an optional argument. if(args.size() > 8) - conv_x_w = add_bias_to_conv(args[8], conv_x_w, info); + { + const auto& in_b = args[8]; + auto b_sh = in_b->get_shape(); + + auto bcast_scale_x = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", in_scale_w->get_shape().lens()}}), + in_scale_x); + + auto bias_scale = + info.add_instruction(migraphx::make_op("mul"), bcast_scale_x, in_scale_w); + auto zero_lit = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0}}); + auto bias_zp = info.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", b_sh.lens()}}), zero_lit); + auto dquant_bias = info.add_instruction( + migraphx::make_op("dequantizelinear"), args[8], bias_scale, bias_zp); + + conv_x_w = add_bias_to_conv(dquant_bias, conv_x_w, info); + } return bcast_qdq_instr("quantizelinear", conv_x_w, in_scale_y, in_zero_pt_y, info); } diff --git a/src/onnx/parse_qlinearmatmul.cpp b/src/onnx/parse_qlinearmatmul.cpp index 1b430ab6fbd..4d3a1670749 100644 --- a/src/onnx/parse_qlinearmatmul.cpp +++ b/src/onnx/parse_qlinearmatmul.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -138,9 +138,17 @@ struct parse_qlinearmatmul : op_parser not std::equal(lens_a.rbegin() + 2, lens_a.rend(), lens_b.rbegin() + 2, lens_b.rend())) MIGRAPHX_THROW("QLINEARMATMUL: mismatched input dimensions"); - if(migraphx::any_of({args[1], args[2], args[4], args[5]}, + if(migraphx::any_of({args[1], args[2]}, [](auto arg) { return not arg->get_shape().scalar(); })) MIGRAPHX_THROW("QLINEARMATMUL: unsupported row/column quantization"); + + const auto& in_scale_b = args[4]; + const auto& in_zero_pt_b = args[5]; + size_t dim_scale_b = in_scale_b->get_shape().lens().size(); + size_t dim_zero_pt_b = in_zero_pt_b->get_shape().lens().size(); + + if((dim_scale_b > 1) or (dim_zero_pt_b > 1)) + MIGRAPHX_THROW("QLINEARMATMUL: unsupported row/column quantization"); } instruction_ref parse(const op_desc& /* opd */, @@ -154,13 +162,15 @@ struct parse_qlinearmatmul : op_parser const auto& in_a = args[0]; const auto& in_scale_a = args[1]; const auto& in_zero_pt_a = args[2]; - auto dquant_a = bcast_qdq_instr("dequantizelinear", in_a, in_scale_a, in_zero_pt_a, info); + auto dquant_a = + bcast_qdq_instr_matmul("dequantizelinear", in_a, in_scale_a, in_zero_pt_a, info); // B const auto& in_b = args[3]; const auto& in_scale_b = args[4]; const auto& in_zero_pt_b = args[5]; - auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info); + auto dquant_b = + bcast_qdq_instr_matmul("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info); bool is_a_prepended = false; bool is_b_appended = false;