Conversation


@umangyadav commented on Dec 9, 2025

Motivation

When scaled GEMM was implemented in rocMLIR, TOSA did not have a matmul operation that takes scale arguments. As a workaround, rocMLIR implemented a "decompose" transform for the migraphx.quant_dot operation, which decomposes a scaled GEMM into:

%scale_f32 = migraphx.convert %scale : f8E8M0FNU to f32
%a_f32 = migraphx.convert %a : f4E2M1FN to f32
%aScaled = migraphx.mul %a_f32, %scale_f32 : f32
// ... the same convert + mul sequence produces %bScaled ...
migraphx.dot %aScaled, %bScaled

During TosaToRock, it would then pattern match the "convert" + "mul" sequence (with data type checks) to determine whether a scaled GEMM was being performed.

TOSA now has tosa.matmul_t_block_scaled, which takes two additional arguments, scale_a and scale_b. This PR makes use of the new TOSA operator and adds lowerings for it in MIGraphXToTosa and then TosaToRock.
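
For illustration, the converted form conceptually looks like the sketch below. The operand order, shapes, and assembly syntax here are assumptions for illustration only; the exact format is in the MIGraphXToTosa tests added by this PR.

%out = tosa.matmul_t_block_scaled %a, %b, %scale_a, %scale_b
    : (tensor<1x16x64xf4E2M1FN>, tensor<1x16x64xf4E2M1FN>, tensor<1x16x2xf8E8M0FNU>, tensor<1x16x2xf8E8M0FNU>) -> tensor<1x16x16xf32>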

Fixes https://github.com/ROCm/rocMLIR-internal/issues/2139

Technical Details

Host side

TOSA has not implemented a TosaToLinalg conversion for tosa.matmul_t_block_scaled, so we still require the decomposition for host functions.

As a result, when writing an E2E test it should first run --clone-harness and then run the host pipeline and the kernel pipeline separately, so that different migraphx pipelines are applied to the host and kernel functions. This PR updates the tests accordingly; a sketch of the resulting RUN structure follows.
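
The sketch below shows the general shape of the updated RUN lines. The function name, flag spellings, and pipeline names are placeholders/assumptions; the authoritative commands are in the updated .mlir tests.

// RUN: rocmlir-gen -fut <test_func> --clone-harness %s \
// RUN: | rocmlir-driver -host-pipeline=<host migraphx pipeline> -kernel-pipeline=<kernel migraphx pipeline> \
// RUN: | <runner and FileCheck, unchanged>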

Broadcasting of scales

The migraphx.quant_dot operator expects the scales and the A/B arguments to have the same shape, which it achieves by broadcasting the scales. tosa.matmul_t_block_scaled doesn't have this restriction. Therefore MIGraphXToTosa needs to unbroadcast the scales, and TosaToRock later adds the broadcasts back, because rock.gemm also requires the scales and A/B to have the same shape. The shapes involved are illustrated below.
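
As a concrete illustration (the shapes are made up for this example; a block size of 32 is the only size currently supported):

  • migraphx level: A is 1x16x64 (f4E2M1FN) and its scale is broadcast to the same 1x16x64 shape (f8E8M0FNU).
  • after MIGraphXToTosa: the scale is unbroadcast to 1x16x2, i.e. one scale value per 32-element block along K.
  • after TosaToRock: the 1x16x2 scale is broadcast back to 1x16x64 so rock.gemm sees scales and A/B with matching shapes.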

Test Plan

CI passes

@umangyadav requested a review from causten as a code owner on December 9, 2025 at 21:20
@umangyadav requested reviews from Copilot, dhernandez0, justinrosner and pabloantoniom and removed the request for causten on December 9, 2025 at 21:20
@umangyadav self-assigned this on Dec 9, 2025

Copilot AI left a comment

Pull request overview

This PR replaces the workaround for scaled GEMM operations by using the new tosa.matmul_t_block_scaled operator that directly supports scale arguments, eliminating the need to decompose scaled GEMMs into separate convert, multiply, and matmul operations.

Key changes:

  • Added support for tosa.matmul_t_block_scaled in MIGraphXToTosa and TosaToRock conversions
  • Modified migraphx.quant_dot decomposition to only apply for non-kernel (host) functions
  • Updated test pipelines to apply different transformations for host vs kernel functions

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.

Summary per file:

mlir/test/fusion/pr-e2e/mixr-gemm-fp4/mixr-dot-fp4.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/pr-e2e/mixr-gemm-fp4/migraphx-quant-dot-fp4.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_trp_rsp_rsp_quant_dot_unsqueeze_broadcast_add_add.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_trp_rsp_rsp_quant_dot_rsp_broadcast_add_relu.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_trp_squeeze_rsp_trp_squeeze_rsp_trp_rsp_rsp_trp_rsp_quant_dot_rsp.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_trp_squeeze_rsp_rsp_trp_rsp_quant_dot.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_rsp_broadcast_mul_add.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_rsp_broadcast_add.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_add_add.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot.mlir: Updated RUN command to use separate host/kernel pipelines
mlir/test/Dialect/MIGraphX/quant-dot-decompose.mlir: Added test verifying kernel functions don't get decomposed
mlir/test/Conversion/TosaToRock/tosa-to-rock.mlir: Removed old pattern-matching tests for scaled GEMM
mlir/test/Conversion/TosaToRock/tosa-to-rock-matmul-t-block-scaled.mlir: Added comprehensive tests for matmul_t_block_scaled lowering
mlir/test/Conversion/MIGraphXToTosa/quant-dot-scaled-to-matmul-t-block-scaled.mlir: Added tests for quant_dot to matmul_t_block_scaled conversion
mlir/lib/Dialect/MIGraphX/Transforms/MIGraphXTransform.cpp: Restricted decomposition to non-kernel functions
mlir/lib/Conversion/TosaToRock/TosaToRockPass.cpp: Added matmul_t_block_scaled as illegal op
mlir/lib/Conversion/TosaToRock/TosaToRock.cpp: Removed scale extraction logic from MatMulConverter; added MatmulTBlockScaledConverter with scale broadcasting
mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp: Added support for converting quant_dot with scales to matmul_t_block_scaled

// Compute 3D output shape
SmallVector<int64_t> newDimsOut = {
batchInfo.newBatchOut,
(batchInfo.batchSizeB == 1 && batchInfo.batchSizeA != 1)

can we use batchInfo.newM instead of this?

Value inBReshaped = inB;
if (batchInfo.needsReshape) {
auto [reshapedA, reshapedB] =
reshapeTo3DForMatmul(rewriter, loc, inA, inB, batchInfo, elementTy);

why do we handle A and B together instead of independently? do they always require a reshape together?

if (hasScales) {
// TODO: only blockSize of 32 is supported for matmul_t_block_scaled for
// now
int64_t blockSize = 32;

isn't blockSize a param of the kernel somehow? can we assert something == 32?

SmallVector<int64_t> physScaleBShape = {batchInfo.newBatchB, scaleKDim,
batchInfo.nDim};

Value scaleBPhysical = unbroadcastScale(

physical implies it's the actual memory layout? they could have done more broadcast or transposes, couldn't they?


// Convert optional attributes
if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
matmulOp->setAttr("perf_config", attr);

please, let's add a test that checks perf_config is copied correctly

don't we need to set "acc_type" here?

auto mergeAttr = mergeB.get();
return rock::TransformOp::create(b, loc, scale, mergeAttr);
}
// else if not transposed

could we unify if/else by having variables for indices etc?

// Output is [batch, M, N]
auto bDataType = cast<RankedTensorType>(bData.getType());
ArrayRef<int64_t> bShape = bDataType.getShape();
// B physical shape depends on transpose:

Same as my question in another comment: do we know for sure this is the "physical" (I guess underlying) layout? Couldn't they have broadcast it before, or reshaped it, etc.? If that's the case I wouldn't use "physical".

transposeB, transposeC, nullptr, nullptr,
rw, loc, outputType, aData, bData, output, brAScale, brBScale,
transposeA, transposeB,
/*cTransposed=*/nullptr, aScaleTransposed, bScaleTransposed,

why can't C be transposed? we have transpose_c in tosa::matmulop lowering

return matMulOp.emitWarning(
"transpose found leading to a matmul input other than A or B");
}
} else if (auto matMulTBlockScaledOp =

I think all of this can be removed since we have the SortDimensions pass. So, basically, whatever we do here (regarding transposes etc.) will get overwritten by SortDimensions, right?

I mean we could remove TransposeRewritePattern completely I think.

// side non-kernel functions, we need to run the conversion manually. for
// the kernel side, tosa.matmul_t_block_scaled is converted to rock.gemm
// with scales in TosaToRock.cpp
if (!func->hasAttr("kernel")) {

nit: could we use options.disableRock instead of this?

if it's too difficult, this is fine; we use it everywhere anyway...
