Use tosa.matmul_t_block_scaled for the scaled GEMMs
#2164
base: develop
Conversation
Pull request overview
This PR replaces the workaround for scaled GEMM operations by using the new tosa.matmul_t_block_scaled operator that directly supports scale arguments, eliminating the need to decompose scaled GEMMs into separate convert, multiply, and matmul operations.
Key changes:
- Added support for `tosa.matmul_t_block_scaled` in MIGraphXToTosa and TosaToRock conversions
- Modified `migraphx.quant_dot` decomposition to only apply to non-kernel (host) functions
- Updated test pipelines to apply different transformations for host vs. kernel functions
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| mlir/test/fusion/pr-e2e/mixr-gemm-fp4/mixr-dot-fp4.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/pr-e2e/mixr-gemm-fp4/migraphx-quant-dot-fp4.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_trp_rsp_rsp_quant_dot_unsqueeze_broadcast_add_add.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_trp_rsp_rsp_quant_dot_rsp_broadcast_add_relu.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_trp_squeeze_rsp_trp_squeeze_rsp_trp_rsp_rsp_trp_rsp_quant_dot_rsp.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_trp_squeeze_rsp_rsp_trp_rsp_quant_dot.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_rsp_broadcast_mul_add.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_rsp_broadcast_add.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot_add_add.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/fusion/e2e/mixr-gemm-fp4/mixr_rsp_rsp_quant_dot.mlir | Updated RUN command to use separate host/kernel pipelines |
| mlir/test/Dialect/MIGraphX/quant-dot-decompose.mlir | Added test verifying kernel functions don't get decomposed |
| mlir/test/Conversion/TosaToRock/tosa-to-rock.mlir | Removed old pattern-matching tests for scaled GEMM |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-matmul-t-block-scaled.mlir | Added comprehensive tests for matmul_t_block_scaled lowering |
| mlir/test/Conversion/MIGraphXToTosa/quant-dot-scaled-to-matmul-t-block-scaled.mlir | Added tests for quant_dot to matmul_t_block_scaled conversion |
| mlir/lib/Dialect/MIGraphX/Transforms/MIGraphXTransform.cpp | Restricted decomposition to non-kernel functions |
| mlir/lib/Conversion/TosaToRock/TosaToRockPass.cpp | Added matmul_t_block_scaled as illegal op |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Removed scale extraction logic from MatMulConverter; added MatmulTBlockScaledConverter with scale broadcasting |
| mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp | Added support for converting quant_dot with scales to matmul_t_block_scaled |
```cpp
// Compute 3D output shape
SmallVector<int64_t> newDimsOut = {
    batchInfo.newBatchOut,
    (batchInfo.batchSizeB == 1 && batchInfo.batchSizeA != 1)
```
can we use batchInfo.newM instead of this?
```cpp
Value inBReshaped = inB;
if (batchInfo.needsReshape) {
  auto [reshapedA, reshapedB] =
      reshapeTo3DForMatmul(rewriter, loc, inA, inB, batchInfo, elementTy);
```
why do we handle A and B together instead of independently? do they always require a reshape together?
```cpp
if (hasScales) {
  // TODO: only blockSize of 32 is supported for matmul_t_block_scaled for
  // now
  int64_t blockSize = 32;
```
isn't blockSize a param of the kernel somehow? can we assert something == 32?
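For illustration, the check could look something like this (the accessor name is hypothetical; I don't know where the op carries its block size):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch: if the block size is recoverable from the op
// (accessor name assumed), assert the current restriction instead of
// silently hard-coding 32.
void checkBlockSize(int64_t opBlockSize) {
  assert(opBlockSize == 32 &&
         "matmul_t_block_scaled lowering currently supports only blockSize 32");
}
```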
```cpp
SmallVector<int64_t> physScaleBShape = {batchInfo.newBatchB, scaleKDim,
                                        batchInfo.nDim};

Value scaleBPhysical = unbroadcastScale(
```
physical implies it's the actual memory layout? they could have done more broadcast or transposes, couldn't they?
```cpp
// Convert optional attributes
if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
  matmulOp->setAttr("perf_config", attr);
```
please, let's add a test where we check perf_config is copied correctly
don't we need to set "acc_type" here?
```cpp
    auto mergeAttr = mergeB.get();
    return rock::TransformOp::create(b, loc, scale, mergeAttr);
  }
  // else if not transposed
```
could we unify if/else by having variables for indices etc?
```cpp
// Output is [batch, M, N]
auto bDataType = cast<RankedTensorType>(bData.getType());
ArrayRef<int64_t> bShape = bDataType.getShape();
// B physical shape depends on transpose:
```
same as my question in another comment, do we know for sure this is the "physical" (I guess underlying layout)? couldn't they broadcast it before? or reshape it etc? if that's the case I wouldn't use "physical"
```cpp
    transposeB, transposeC, nullptr, nullptr,
    rw, loc, outputType, aData, bData, output, brAScale, brBScale,
    transposeA, transposeB,
    /*cTransposed=*/nullptr, aScaleTransposed, bScaleTransposed,
```
why can't C be transposed? we have transpose_c in the tosa::MatMulOp lowering
```cpp
      return matMulOp.emitWarning(
          "transpose found leading to a matmul input other than A or B");
    }
  } else if (auto matMulTBlockScaledOp =
```
I think all of this can be removed since we have the SortDimensions pass. So, basically, whatever we do here (regarding transposes etc.) will get overwritten by SortDimensions, right?
I mean we could remove TransposeRewritePattern completely I think.
```cpp
  // side non-kernel functions, we need to run the conversion manually. for
  // the kernel side, tosa.matmul_t_block_scaled is converted to rock.gemm
  // with scales in TosaToRock.cpp
  if (!func->hasAttr("kernel")) {
```
nit: could we use options.disableRock instead of this?
if it's too difficult, this is fine; we use it everywhere anyway...
Motivation
When scaled GEMM was implemented in rocMLIR, TOSA did not have a matmul operation that takes scale arguments. As a workaround, rocMLIR implemented a "decompose" transform for the `migraphx.quant_dot` operation, which decomposed a scaled GEMM into separate convert, multiply, and matmul operations. During TosaToRock it would then try to pattern-match the "convert" + "mul" sequence, with data type checking, to decide whether it was lowering a scaled GEMM.
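To make the two formulations concrete, here is a minimal reference sketch in plain C++ (floats stand in for the small float element types; the row-major shapes, the transposed-B layout implied by the `_t`, and the block size of 32 are assumptions for illustration, not rocMLIR code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kBlockSize = 32; // assumed per-block scaling granularity

// Decomposed workaround: convert + mul materialize scaled copies of A and B,
// then a plain matmul consumes them. A is [m, k], B is [n, k] (transposed),
// scaleA is [m, k/32], scaleB is [n, k/32], output is [m, n].
std::vector<float> decomposedScaledGemm(const std::vector<float> &a,
                                        const std::vector<float> &b,
                                        const std::vector<float> &scaleA,
                                        const std::vector<float> &scaleB,
                                        int64_t m, int64_t n, int64_t k) {
  assert(k % kBlockSize == 0);
  const int64_t kBlocks = k / kBlockSize;
  std::vector<float> aS(m * k), bS(n * k), out(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i)
    for (int64_t p = 0; p < k; ++p)
      aS[i * k + p] = a[i * k + p] * scaleA[i * kBlocks + p / kBlockSize];
  for (int64_t j = 0; j < n; ++j)
    for (int64_t p = 0; p < k; ++p)
      bS[j * k + p] = b[j * k + p] * scaleB[j * kBlocks + p / kBlockSize];
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j)
      for (int64_t p = 0; p < k; ++p)
        out[i * n + j] += aS[i * k + p] * bS[j * k + p];
  return out;
}

// Fused form (the semantics a block-scaled matmul expresses directly): the
// same math, with the per-block scales applied inside the reduction loop
// instead of via materialized intermediate tensors.
std::vector<float> fusedBlockScaledGemm(const std::vector<float> &a,
                                        const std::vector<float> &b,
                                        const std::vector<float> &scaleA,
                                        const std::vector<float> &scaleB,
                                        int64_t m, int64_t n, int64_t k) {
  assert(k % kBlockSize == 0);
  const int64_t kBlocks = k / kBlockSize;
  std::vector<float> out(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i)
    for (int64_t j = 0; j < n; ++j)
      for (int64_t p = 0; p < k; ++p)
        out[i * n + j] +=
            (a[i * k + p] * scaleA[i * kBlocks + p / kBlockSize]) *
            (b[j * k + p] * scaleB[j * kBlocks + p / kBlockSize]);
  return out;
}
```

Both functions compute the same result; the fused form is what the new operator lets the compiler see as a single op.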
TOSA now has `tosa.matmul_t_block_scaled`, which takes two additional arguments, `scale_a` and `scale_b`. This PR therefore makes use of this new TOSA operator and adds the corresponding lowerings in MIGraphXToTosa and then TosaToRock.

Fixes https://github.com/ROCm/rocMLIR-internal/issues/2139
Technical Details
Host side
TOSA has not implemented the TosaToLinalg conversion for `tosa.matmul_t_block_scaled`, so we still require the decomposition for host functions. As a result, an E2E test should first run `--clone-harness` and then run the host pipeline and the kernel pipeline separately, so that host and kernel functions go through different migraphx pipelines. This PR contains those changes.
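Mechanically, the host/kernel split amounts to gating the decomposition on the `kernel` attribute. A minimal sketch (not the actual rocMLIR pass; the pattern application is elided):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Minimal sketch, not the rocMLIR implementation: run the quant_dot
// decomposition only on host (non-kernel) functions; kernel functions keep
// quant_dot so it can become tosa.matmul_t_block_scaled -> rock.gemm.
struct DecomposeQuantDotOnHost
    : PassWrapper<DecomposeQuantDotOnHost, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    getOperation().walk([&](func::FuncOp func) {
      if (func->hasAttr("kernel"))
        return; // kernel side: left for MIGraphXToTosa/TosaToRock
      // ... apply the quant_dot decomposition patterns to `func` (elided) ...
    });
  }
};
} // namespace
```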
Broadcasting of scales
The `migraphx.quant_dot` operator expects the scales and the A/B arguments to have the same shape, which it achieves by broadcasting the scales. `tosa.matmul_t_block_scaled` doesn't have this restriction. Therefore, MIGraphXToTosa needs to unbroadcast the scales, and TosaToRock later adds the broadcasts back, since rock.gemm also requires the scales and A/B to have the same shape.
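To illustrate the two scale layouts, a small plain-C++ sketch (the helper names, the row-major [rows, k] shapes, and the block size are assumptions for illustration; the real helpers live in the conversion code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Drop the broadcast: keep one scale per blockSize-wide K block, turning a
// [rows, k] scale tensor back into the compact [rows, k/blockSize] form that
// matmul_t_block_scaled takes.
std::vector<float> unbroadcastScales(const std::vector<float> &bcast,
                                     int64_t rows, int64_t k,
                                     int64_t blockSize) {
  assert(k % blockSize == 0);
  const int64_t blocks = k / blockSize;
  std::vector<float> compact(rows * blocks);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t blk = 0; blk < blocks; ++blk)
      compact[r * blocks + blk] = bcast[r * k + blk * blockSize];
  return compact;
}

// Re-introduce the broadcast so every element in a K block sees its scale,
// matching what quant_dot and rock.gemm expect.
std::vector<float> broadcastScales(const std::vector<float> &compact,
                                   int64_t rows, int64_t k,
                                   int64_t blockSize) {
  assert(k % blockSize == 0);
  const int64_t blocks = k / blockSize;
  std::vector<float> bcast(rows * k);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < k; ++c)
      bcast[r * k + c] = compact[r * blocks + c / blockSize];
  return bcast;
}
```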
Test Plan
CI passes