Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/onnx/broadcast_qdq.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -61,6 +61,42 @@ instruction_ref bcast_qdq_instr(const std::string& op_name,
return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt);
}

instruction_ref bcast_qdq_instr_matmul(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info)
{
auto in_lens = x_in->get_shape().lens();

// prep 1: broadcast scale. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_scale;
if(arg_fscale->get_shape().elements() > 1)
{
auto axis = x_in->get_shape().lens().size() - arg_fscale->get_shape().lens().size();
bcast_scale = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", in_lens}}), arg_fscale);
}
else
bcast_scale = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_fscale);

// prep 2: broadcast zero point. it can come as a scalar or a 1-D tensor.
instruction_ref bcast_zero_pt;
if(arg_z_pt->get_shape().elements() > 1)
{
auto axis = x_in->get_shape().lens().size() - arg_z_pt->get_shape().lens().size();
bcast_zero_pt = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", in_lens}}), arg_z_pt);
}
else
bcast_zero_pt = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), arg_z_pt);

// op_name is either quantizelinear or dequantizelinear:
return info.add_instruction(migraphx::make_op(op_name), x_in, bcast_scale, bcast_zero_pt);
}

// Multibroadcast a scaler..
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
Expand Down
8 changes: 7 additions & 1 deletion src/onnx/include/migraphx/onnx/broadcast_qdq.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -44,6 +44,12 @@ instruction_ref bcast_qdq_instr(const std::string& op_name,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info);

instruction_ref bcast_qdq_instr_matmul(const std::string& op_name,
instruction_ref x_in,
instruction_ref arg_fscale,
instruction_ref arg_z_pt,
const onnx_parser::node_info& info);

// Multibroadcast a scaler..
instruction_ref bcast_scalar_instr(const migraphx::shape& shape_out,
instruction_ref arg_in,
Expand Down
22 changes: 20 additions & 2 deletions src/onnx/parse_qlinearconv.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -229,7 +229,25 @@ struct parse_qlinearconv : op_parser<parse_qlinearconv>

// Biases, if any.. : is an optional argument.
if(args.size() > 8)
conv_x_w = add_bias_to_conv(args[8], conv_x_w, info);
{
const auto& in_b = args[8];
auto b_sh = in_b->get_shape();

auto bcast_scale_x = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", in_scale_w->get_shape().lens()}}),
in_scale_x);

auto bias_scale =
info.add_instruction(migraphx::make_op("mul"), bcast_scale_x, in_scale_w);
auto zero_lit = info.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0}});
auto bias_zp = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", b_sh.lens()}}), zero_lit);
auto dquant_bias = info.add_instruction(
migraphx::make_op("dequantizelinear"), args[8], bias_scale, bias_zp);

conv_x_w = add_bias_to_conv(dquant_bias, conv_x_w, info);
}

return bcast_qdq_instr("quantizelinear", conv_x_w, in_scale_y, in_zero_pt_y, info);
}
Expand Down
18 changes: 14 additions & 4 deletions src/onnx/parse_qlinearmatmul.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -138,9 +138,17 @@ struct parse_qlinearmatmul : op_parser<parse_qlinearmatmul>
not std::equal(lens_a.rbegin() + 2, lens_a.rend(), lens_b.rbegin() + 2, lens_b.rend()))
MIGRAPHX_THROW("QLINEARMATMUL: mismatched input dimensions");

if(migraphx::any_of({args[1], args[2], args[4], args[5]},
if(migraphx::any_of({args[1], args[2]},
[](auto arg) { return not arg->get_shape().scalar(); }))
MIGRAPHX_THROW("QLINEARMATMUL: unsupported row/column quantization");

const auto& in_scale_b = args[4];
const auto& in_zero_pt_b = args[5];
size_t dim_scale_b = in_scale_b->get_shape().lens().size();
size_t dim_zero_pt_b = in_zero_pt_b->get_shape().lens().size();

if((dim_scale_b > 1) or (dim_zero_pt_b > 1))
MIGRAPHX_THROW("QLINEARMATMUL: unsupported row/column quantization");
}

instruction_ref parse(const op_desc& /* opd */,
Expand All @@ -154,13 +162,15 @@ struct parse_qlinearmatmul : op_parser<parse_qlinearmatmul>
const auto& in_a = args[0];
const auto& in_scale_a = args[1];
const auto& in_zero_pt_a = args[2];
auto dquant_a = bcast_qdq_instr("dequantizelinear", in_a, in_scale_a, in_zero_pt_a, info);
auto dquant_a =
bcast_qdq_instr_matmul("dequantizelinear", in_a, in_scale_a, in_zero_pt_a, info);

// B
const auto& in_b = args[3];
const auto& in_scale_b = args[4];
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
auto dquant_b =
bcast_qdq_instr_matmul("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);

bool is_a_prepended = false;
bool is_b_appended = false;
Expand Down
Loading