Remove redundant cast chains #1944
Conversation
Pull Request Overview
This PR introduces a new optimization pass that removes redundant cast chains in the IR. The pass identifies patterns where values are converted from f32 to smaller precision types (e.g., f16) and then immediately extended back to f32, eliminating these unnecessary conversions to simplify the IR and improve performance.
Key changes include:
- Implementation of a new `RockRemoveRedundantCastsPass` that handles MFMA operations and generic cast chains
- Integration of the pass into the compilation pipeline at appropriate stages
- Comprehensive test coverage for different cast chain scenarios
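For illustration, a minimal sketch of the pattern the pass targets (the SSA names, shapes, and `#id` indexing map are hypothetical, not taken from this PR):

```mlir
// %gemm_out : tensor<5x3xf32> is truncated to f16 and immediately extended back.
%trunc = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel", "parallel"]}
    ins(%gemm_out : tensor<5x3xf32>) outs(%t : tensor<5x3xf16>) {
  ^bb0(%in: f32, %out: f16):
    %0 = arith.truncf %in : f32 to f16
    linalg.yield %0 : f16
} -> tensor<5x3xf16>
%ext = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel", "parallel"]}
    ins(%trunc : tensor<5x3xf16>) outs(%e : tensor<5x3xf32>) {
  ^bb0(%in: f16, %out: f32):
    %1 = arith.extf %in : f16 to f32
    linalg.yield %1 : f32
} -> tensor<5x3xf32>
// After the pass, consumers of %ext can use %gemm_out directly.
```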
Reviewed Changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp | Core implementation of the redundant cast removal pass |
| mlir/test/Dialect/Rock/remove_redundant_casts*.mlir | Test files covering various cast chain removal scenarios |
| mlir/test/rocmlir-driver/pipelines.mlir | Pipeline test update to include the new pass |
| mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | Integration of the pass into the compilation pipelines |
| mlir/include/mlir/Dialect/Rock/Passes.* | Pass declaration and registration |
| mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt | Build system integration |
| mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp | Debug output addition |
```tablegen
def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> {
  let summary = "Remove redundant casts between ops";
  let dependentDialects = ["rock::RockDialect", "linalg::LinalgDialect"];
```
Q: It uses the arith dialect as well. It is not explicitly listed here, but it seems to be working fine. Do you know why that is?
I believe this is because `dependentDialects` just serves as a hint to the compiler about which dialects it needs to load for a given pass. Regardless, I've added Arith here, as it makes sense to do so.
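For reference, a hedged sketch of what the updated declaration might look like (assuming the upstream dialect's C++ class is `arith::ArithDialect`, as registered in upstream MLIR):

```tablegen
let dependentDialects = ["rock::RockDialect", "linalg::LinalgDialect",
                         "arith::ArithDialect"];
```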
```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
```
Why do you require the GPU dialect?
```cpp
// with either a single trunc or ext op. e.g.,:
// %1 = migraphx.convert %0 : <1x5x3xf16, 15x3x1> to <1x5x3xf32, 15x3x1>
template <typename OpType>
bool isGenericWithSingleOp(linalg::GenericOp generic) const {
```
We are running linalg-elementwise-op-fusion in the high-level pipeline, which would fuse multiple linalg generic ops into a single one. Therefore, I don't think it is always guaranteed that the linalg.generic will only contain a single operation.
I think extf/truncf are a special case because they require tensors/memrefs of different types on inputs and outputs; therefore it is probably true that the linalg.generic will only contain extf/truncf and won't be fused with other linalg ops when we run linalg-elementwise-fusion. Do you think that's the case?
You can probably write a small test and run it through linalg-elementwise-fusion.
Yes, I agree. I think it would be better to just check its uses and walk back until you find the pattern dtype -> dtype2, dtype2 -> dtype, instead of assuming the linalg.generic will have one op.
```cpp
// We don't need to investigate BlockArguments any further
Value input = generic.getInputs()[0];
if (isa<BlockArgument>(input))
  return nullptr;
```
Should it be returning `input`?
```cpp
// If this op uses mfma, it will accumulate in higher precision (F32 or I32)
auto features = rock::getFeatures(rockOp);
bool isMfma = bitEnumContainsAll(features, GemmFeatures::mfma);
```
This should also work on wmma; it also accumulates in higher precision.
Also check whether the non-accel path does its accumulation in higher precision or not.
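A hedged fragment (not standalone; it assumes rocMLIR's `GemmFeatures` bit enum and the MLIR tablegen-generated `bitEnumContainsAny` helper) of what checking both accelerated paths might look like:

```cpp
// Both mfma and wmma accumulate in higher precision (f32/i32),
// so treat either feature as an accumulating accelerator.
auto features = rock::getFeatures(rockOp);
bool accumulatesHigh =
    bitEnumContainsAny(features, GemmFeatures::mfma | GemmFeatures::wmma);
```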
```cpp
auto inputType = cast<RankedTensorType>(input->getResult(0).getType());
auto outputType = cast<RankedTensorType>(output->getResult(0).getType());
```
Use `ShapedType`.
```cpp
Value outputArg = args[1];
Type oType = outputArg.getType();
Value truncResult =
    builder.create<arith::TruncFOp>(loc, oType, blockArg);
```
`arith::TruncFOp` wouldn't work on integer types.
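For reference, upstream arith has separate float and integer truncation ops; dispatching on the element type is the fix being suggested (the operand names and bit widths here are illustrative):

```mlir
%f = arith.truncf %a : f32 to f16   // floating-point truncation only
%i = arith.trunci %b : i32 to i8    // integer truncation only
```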
```mlir
%5 = rock.transform %4 by #transform_map2 : tensor<1x5x3xf32> to tensor<5x3xf32>

%temp_alloc = bufferization.alloc_tensor() : tensor<5x3xf16>
// CHECK-NOT: %downcast
```
Do not use variable names inside checks; variable names can change. For example, if I comment out the logic for remove-redundant-cast and just run an empty pass, the variable name changes from %downcast to %6.
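A hedged sketch of a name-agnostic check: capture the SSA value with a FileCheck substitution block and match on the op itself rather than on the variable name (the `CAST` capture name and the checked ops are illustrative, not from this PR's tests):

```mlir
// CHECK: %[[CAST:.+]] = linalg.generic
// CHECK-NOT: arith.truncf
```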
```cpp
} else if (isa<linalg::GenericOp>(user) &&
           (isGenericWithSingleOp<arith::ExtFOp>(
                cast<linalg::GenericOp>(user)) ||
            isGenericWithSingleOp<arith::ExtSIOp>(
                cast<linalg::GenericOp>(user)) ||
            isGenericWithSingleOp<arith::ExtUIOp>(
                cast<linalg::GenericOp>(user)))) {
```
Why does the formatting look incorrect here?
More tests to add:
```mlir
    %1 = migraphx.convert %0 : <1x5x3xf16, 15x3x1> to <1x5x3xf32, 15x3x1>
    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
  }
}
```
Add a newline at the end of the file.
Can we add a test for rock.attention? Especially the one in the ticket of this PR.
```cpp
auto &funcPm3 = pm.nest<func::FuncOp>();
funcPm3.addPass(bufferization::createEmptyTensorToAllocTensorPass());
funcPm3.addPass(createLinalgFoldUnitExtentDimsPass());
funcPm3.addPass(rock::createRockRemoveRedundantCastsPass());
```
The attention pass will create its own conversion linalgs well after this. So I think for this to work on attention we need it to happen at least after ToBlockwise. Note that the case highlighted in the ticket is attention, so it's the main goal of this PR I think: https://github.com/ROCm/rocMLIR-internal/issues/1932
Unfortunately, this is going to be a major change to the code in this PR, because all of this happens after bufferization.
I spoke with @umangyadav about this last week as I was trying to avoid this case. Tracing uses becomes tricky because we need to come up with some additional memory analysis passes that can trace uses of a memref to find out if any reads/writes happen between two ops. I had an initial version of this in my first commits, but then changed it to this approach. Maybe we can discuss further after standup tomorrow.
IMO it only makes sense to do it after ToBlockwise, otherwise we won't be solving the issue described in the ticket.
I think doing it after ToBlockwise is simpler: you don't need to have a special case for rock.fusionop (gemm/conv etc.); it's only linalg.generics. You only need to get the linalg.generic and keep tracing back (through allocs with BufferDependencyAnalysis) to find the pattern you are looking for.
Yes, let's discuss in the meeting.
```tablegen
}

def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> {
  let summary = "Remove redundant casts between ops";
```
nit: be more explicit about what we consider redundant.
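One possible wording, for illustration only:

```tablegen
let summary = "Remove redundant trunc/ext cast chains (e.g. f32 -> f16 -> f32) between ops";
```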
```cpp
// IR, specifically patterns where a value is converted from f32 to a smaller
// type (e.g., f16) and then immediately extended back to f32. By eliminating
// these unnecessary conversions, the pass simplifies the IR and can improve
// performance by reducing superfluous operations and memory traffic.
```
nit: I don't think memory traffic is relevant here; everything is in registers in the assembly in most cases, right?
```cpp
// the return type isn't F32 or I32 (highest level of precision), then it
// means that there will be a trunc op inserted in RockGemmToGridwise
// that will potentially be redundant.
changed |= handleRockGemmWrapper(input, generic, rewriter);
```
As discussed earlier, we need this pass to happen later on to work properly for attention. So there's no need to handle this special case; there will be a linalg.generic for this after ToBlockwise (I think).
```cpp
Value getExtInput(linalg::GenericOp generic) const {
  // Check that there is only one input to the generic operation
  if (generic.getInputs().size() != 1) {
```
I don't think this is a requirement.
```cpp
// Helper function to create a new output value for a
// RockGemmWrapperInterface/RockGemmGemmwrapperInterface Op or LinalgGeneric
// TruncOp, and any corresponding rock.transformOps
Value createNewOutput(Value prevValue,
```
Use rock::transform.
```cpp
// If there is only a single ExtOp use, then we can go ahead and remove
// all of the TransformOps and the original truncf
auto singleExtOp = std::get<1>(tup);
```
If some ops are not used, they will get removed by the canonicalization stage; I think there's no need to do it explicitly.
Went with a different approach (at the LLVMIR dialect level) to avoid a lot of the issues in dealing with linalg.generics. That new approach is in draft here: #2202
Motivation
This pass identifies and removes redundant cast chains in the IR, specifically patterns where a value is converted from f32 to a smaller type (e.g., f16) and then immediately extended back to f32. By eliminating these unnecessary conversions, the pass simplifies the IR and can improve performance by reducing superfluous operations and memory traffic.
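As a side note, the f32 -> f16 -> f32 round-trip that the pass removes can be reproduced with stdlib Python's half-precision `struct` format code `'e'`; a minimal illustration (independent of the rocMLIR code, purely to show the conversion semantics):

```python
import struct

def roundtrip_f16(x: float) -> float:
    """Truncate a float to IEEE half precision, then extend it back."""
    return struct.unpack('e', struct.pack('e', x))[0]

# Values exactly representable in f16 survive the round-trip unchanged;
# others pick up a rounding step that removal of the chain would skip.
print(roundtrip_f16(1.0))   # 1.0
print(roundtrip_f16(0.1))   # 0.0999755859375
```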
Technical Details
This pass can currently handle the following situations:
Test Plan
Test Result
I've manually examined the final generated assembly in the attached design and confirmed that we no longer have the redundant cast chains.
Submission Checklist