Conversation

@justinrosner (Contributor) commented on Aug 12, 2025

Motivation

This pass identifies and removes redundant cast chains in the IR, specifically patterns where a value is converted from f32 to a smaller type (e.g., f16) and then immediately extended back to f32. By eliminating these unnecessary conversions, the pass simplifies the IR and can improve performance by reducing superfluous operations and memory traffic.

Technical Details

This pass can currently handle the following situations:

  1. MFMA ops do accumulation in higher precision, so a GEMM returning an f16 type with MFMA enabled will, under the hood, do a TruncF back to the lower-precision type. If this op is then used directly by a convert (extf), we want to remove this chain of casts (a minimal sketch of this fold appears after these examples).
function {
  migraphx.dot A, B : f16 -> C : f16
  migraphx.convert (C) : f16 -> f32
}
  2. Same case as above, but there are additional uses of the initial GEMM, which means the initial truncf needs to stick around.
function {
  migraphx.dot A, B : f16 -> C : f16
  ...
  use of C
  ... 
  migraphx.convert (C) : f16 -> f32
}
  3. We don't explicitly have to look for the MFMA case; we could also have other cast -> cast chains that are redundant and can be removed.
function {
  opA -> f16
  convertA = migraphx.convert(opA) : f16 -> f32
  ...
  migraphx.convert(convertA) : f32 -> f16
}
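
To make the transformation concrete, here is a minimal C++ sketch of the scalar core of case 1 as an MLIR rewrite pattern. This is an illustrative sketch, not the pass's actual implementation; the pattern name FoldTruncExt is hypothetical, and the fold is only sound where the narrowing is known to be redundant (e.g., an MFMA accumulator value truncated and then immediately re-extended):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern: fold %y = arith.extf(arith.truncf(%x)) back to %x
// when the extf restores the exact pre-trunc type. In general truncf loses
// precision, so this is only valid where the trunc is known redundant.
struct FoldTruncExt : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ExtFOp extOp,
                                PatternRewriter &rewriter) const override {
    auto truncOp = extOp.getIn().getDefiningOp<arith::TruncFOp>();
    if (!truncOp || truncOp.getIn().getType() != extOp.getType())
      return failure();
    rewriter.replaceOp(extOp, truncOp.getIn());
    return success();
  }
};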

Test Plan

Test Result

I've manually examined the final generated assembly in the attached design and confirmed that we no longer have the redundant cast chains.

Submission Checklist

@justinrosner justinrosner marked this pull request as ready for review August 12, 2025 19:39
@justinrosner justinrosner requested a review from causten as a code owner August 12, 2025 19:39
@justinrosner justinrosner requested a review from Copilot August 14, 2025 14:26
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces a new optimization pass that removes redundant cast chains in the IR. The pass identifies patterns where values are converted from f32 to smaller precision types (e.g., f16) and then immediately extended back to f32, eliminating these unnecessary conversions to simplify the IR and improve performance.

Key changes include:

  • Implementation of a new RockRemoveRedundantCastsPass that handles MFMA operations and generic cast chains
  • Integration of the pass into the compilation pipeline at appropriate stages
  • Comprehensive test coverage for different cast chain scenarios

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

Summary per file:

mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp: Core implementation of the redundant cast removal pass
mlir/test/Dialect/Rock/remove_redundant_casts*.mlir: Test files covering various cast chain removal scenarios
mlir/test/rocmlir-driver/pipelines.mlir: Pipeline test update to include the new pass
mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp: Integration of the pass into compilation pipelines
mlir/include/mlir/Dialect/Rock/Passes.*: Pass declaration and registration
mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt: Build system integration
mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp: Debug output addition


def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> {
let summary = "Remove redundant casts between ops";
let dependentDialects = ["rock::RockDialect", "linalg::LinalgDialect"];
Member:
Q: it uses the arith dialect as well. It is not explicitly listed here, but it seems to work fine. Do you know why that is?

Contributor Author:
I believe that this is because dependentDialects just serves as a hint to the compiler about what dialects it needs to load for a given pass. Regardless, I've added in Arith here as it makes sense to do so.
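
For reference, a sketch of roughly what that TableGen hint amounts to on the C++ side (illustrative, not the generated code):

// The pass declares the dialects whose ops it may create so the context
// loads them before the pass runs.
void getDependentDialects(DialectRegistry &registry) const override {
  registry.insert<rock::RockDialect, linalg::LinalgDialect,
                  arith::ArithDialect>();
}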

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Member:
Why do you require the GPU dialect?

// with either a single trunc or ext op. e.g.,:
// %1 = migraphx.convert %0 : <1x5x3xf16, 15x3x1> to <1x5x3xf32, 15x3x1>
template <typename OpType>
bool isGenericWithSingleOp(linalg::GenericOp generic) const {
Member:
We are running linalg-elementwise-op-fusion in the high-level pipeline, which would fuse multiple linalg generic ops into a single one.

Therefore, I don't think it is always guaranteed that the linalg will only contain a single operation.

I think extf/truncf are a special case because they require tensors/memrefs of different types on inputs and outputs. Therefore it is probably true that the linalg generic will only contain extf/truncf and won't be fused with other linalg ops when we run linalg-elementwise-fusion. Do you think that's the case?

You can probably write a small test and run it through linalg-elementwise-fusion.

Contributor:
Yes, I agree. I think it would be better to just check its uses and go back until you find the pattern dtype -> dtype2, dtype2 -> dtype, instead of assuming the linalg will have one op.
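
For context, a minimal sketch of the kind of single-op check under discussion (hypothetical body; the PR's actual helper may differ):

template <typename OpType>
static bool isGenericWithSingleOp(linalg::GenericOp generic) {
  // The body should contain exactly the cast op plus the implicit
  // linalg.yield terminator.
  Block &body = generic.getRegion().front();
  if (std::distance(body.begin(), body.end()) != 2)
    return false;
  return isa<OpType>(body.front());
}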

// We don't need to investigate BlockArguments any further
Value input = generic.getInputs()[0];
if (isa<BlockArgument>(input))
return nullptr;
Member:
Should it be returning input?


// If this op uses mfma, it will accumulate in higher precision (F32 or I32)
auto features = rock::getFeatures(rockOp);
bool isMfma = bitEnumContainsAll(features, GemmFeatures::mfma);
Member:
This should also work on wmma; it also accumulates in higher precision.

Also check whether the non-accel path does accumulation in higher precision or not.
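
A sketch of the suggested generalization, assuming GemmFeatures also defines a wmma bit as the reviewer implies:

// Treat both mfma and wmma as accelerated paths that accumulate in
// higher precision (F32 or I32).
auto features = rock::getFeatures(rockOp);
bool accumulatesHigh =
    bitEnumContainsAny(features, GemmFeatures::mfma | GemmFeatures::wmma);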

Comment on lines +289 to +290
auto inputType = cast<RankedTensorType>(input->getResult(0).getType());
auto outputType = cast<RankedTensorType>(output->getResult(0).getType());
Member:
Use ShapedType.
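
That is, something along these lines (sketch):

// ShapedType covers both RankedTensorType and MemRefType, so the check
// keeps working after bufferization.
auto inputType = cast<ShapedType>(input->getResult(0).getType());
auto outputType = cast<ShapedType>(output->getResult(0).getType());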

Value outputArg = args[1];
Type oType = outputArg.getType();
Value truncResult =
builder.create<arith::TruncFOp>(loc, oType, blockArg);
Member:
arith::truncf wouldn't work on integer types.
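
One hedged way to address this is to pick the truncation op by element type, e.g.:

// arith.truncf is only defined on floats; use arith.trunci for integers.
Type oType = outputArg.getType();
Value truncResult;
if (isa<FloatType>(oType))
  truncResult = builder.create<arith::TruncFOp>(loc, oType, blockArg);
else
  truncResult = builder.create<arith::TruncIOp>(loc, oType, blockArg);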

%5 = rock.transform %4 by #transform_map2 : tensor<1x5x3xf32> to tensor<5x3xf32>

%temp_alloc = bufferization.alloc_tensor() : tensor<5x3xf16>
// CHECK-NOT: %downcast
Member:
Do not use variable names inside checks; variable names can change.

For example, if I comment out the logic for remove-redundant-casts and just run an empty pass, the variable name changes from %downcast to %6.
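
A hedged example of this advice (hypothetical check lines): match on the operation, or capture the SSA value with a FileCheck variable, instead of hard-coding its name.

// Match the op itself; this survives SSA renumbering:
// CHECK-NOT: arith.truncf
// Or capture the value and reuse the capture:
// CHECK: %[[ALLOC:.*]] = bufferization.alloc_tensor()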

Comment on lines +206 to +212
} else if (isa<linalg::GenericOp>(user) &&
(isGenericWithSingleOp<arith::ExtFOp>(
cast<linalg::GenericOp>(user)) ||
isGenericWithSingleOp<arith::ExtSIOp>(
cast<linalg::GenericOp>(user)) ||
isGenericWithSingleOp<arith::ExtUIOp>(
cast<linalg::GenericOp>(user)))) {
Member:
Why does the formatting look incorrect here?

@dhernandez0 (Contributor) commented on Aug 18, 2025

More tests to add:

  • a test that checks the conversion instructions in assembly are not found anymore
  • input fusion tests
  • more than two linalgs

%1 = migraphx.convert %0 : <1x5x3xf16, 15x3x1> to <1x5x3xf32, 15x3x1>
return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
}
} No newline at end of file
Contributor:
nit: add a trailing newline.

Contributor:
Can we add a test for rock.attention? Especially the one in the ticket for this PR.

auto &funcPm3 = pm.nest<func::FuncOp>();
funcPm3.addPass(bufferization::createEmptyTensorToAllocTensorPass());
funcPm3.addPass(createLinalgFoldUnitExtentDimsPass());
funcPm3.addPass(rock::createRockRemoveRedundantCastsPass());
Contributor:
The attention pass will create its own conversion linalgs well after this, so I think for this to work on attention it needs to happen at least after ToBlockwise. Note that the case highlighted in the ticket is attention, so I think it's the main goal of this PR: https://github.com/ROCm/rocMLIR-internal/issues/1932

@dhernandez0 (Contributor), Aug 18, 2025:
Unfortunately, this is going to be a major change to the code in this PR, because all of this happens after bufferization.

Contributor Author:
I spoke with @umangyadav about this last week as I was trying to avoid this case. Tracing uses becomes tricky because we need to come up with some additional memory analysis passes that can trace uses of a memref to find out if any reads/writes happen between two ops. I had an initial version of this in my first commits, but then changed it to this approach. Maybe we can discuss further after standup tomorrow.

Contributor:
IMO it only makes sense to do it after ToBlockwise; otherwise we won't be solving the issue described in the ticket.

Contributor:
I think doing it after ToBlockwise is simpler; you don't need a special case for rock.fusionop (gemm/conv, etc.). It's only linalg.generics. You only need to get the linalg.generic and keep tracing back (through Allocs with BufferDependencyAnalysis) to find the pattern you are looking for.

Contributor:
Yes, let's discuss in the meeting.

}

def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> {
let summary = "Remove redundant casts between ops";
Contributor:
nit: be more explicit about what we consider redundant.

// IR, specifically patterns where a value is converted from f32 to a smaller
// type (e.g., f16) and then immediately extended back to f32. By eliminating
// these unnecessary conversions, the pass simplifies the IR and can improve
// performance by reducing superfluous operations and memory traffic.
Contributor:
nit: I don't think memory traffic is relevant here; everything is in registers in the assembly in most cases, right?

// the return type isn't F32 or I32 (highest level of precision), then it
// means that there will be a trunc op inserted in RockGemmToGridwise
// that will potentially be redundant.
changed |= handleRockGemmWrapper(input, generic, rewriter);
Contributor:
As discussed earlier, we need this pass to happen later on to work properly for attention, so there's no need to handle this special case; there will be a linalg.generic for this after ToBlockwise (I think).


Value getExtInput(linalg::GenericOp generic) const {
// Check that there is only one input to the generic operation
if (generic.getInputs().size() != 1) {
Contributor:
I don't think this is a requirement

// Helper function to create a new output value for a
// RockGemmWrapperInterface/RockGemmGemmwrapperInterface Op or LinalgGeneric
// TruncOp, and any corresponding rock.transformOps
Value createNewOutput(Value prevValue,
Contributor:
use rock::transform


// If there is only a single ExtOp use, then we can go ahead and remove
// all of the TransformOps and the original truncf
auto singleExtOp = std::get<1>(tup);
Contributor:
If some ops are not used, they will get removed by the canonicalization stage; I think there's no need to do it explicitly.
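
For illustration, the cleanup can be deferred to a later canonicalization run (sketch; the pipeline variable funcPm is hypothetical):

// Dead trunc/transform ops with no remaining uses are erased by DCE
// during canonicalization, so the pass need not remove them explicitly.
funcPm.addPass(mlir::createCanonicalizerPass());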

@justinrosner (Contributor Author):
Went with a different approach (at the LLVMIR dialect level) to avoid a lot of the issues in dealing with linalg.generics. That new approach is in draft here: #2202
