From 1fafdc063a973e7e3c2bcc213ead5488dcf065d4 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 5 Jan 2026 21:41:34 +0000 Subject: [PATCH 01/12] Initial truncf finding --- mlir/include/mlir/Dialect/Rock/Passes.h | 1 + mlir/include/mlir/Dialect/Rock/Passes.td | 17 +++ mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 1 + .../Dialect/Rock/Transforms/CMakeLists.txt | 1 + .../Rock/Transforms/RemoveRedundantCasts.cpp | 120 ++++++++++++++++++ 5 files changed, 140 insertions(+) create mode 100644 mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index c0771efaae26..33130b81f36d 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -51,6 +51,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS #define GEN_PASS_DECL_ROCKFINDFIRSTGEMMINDEXPASS #define GEN_PASS_DECL_ROCKREMOVEOUTPUTALLOCPASS +#define GEN_PASS_DECL_ROCKREMOVEREDUNDANTCASTSPASS #define GEN_PASS_DECL_ROCKBLOCKWISELOADTILETOTHREADWISEPASS #define GEN_PASS_DECL_ROCKANNOTATELIVENESSPASS #define GEN_PASS_DECL_ROCKADDASYNCWAITPASS diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 6f2965a15373..6607510e6ed1 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -152,6 +152,23 @@ def RockVectorizeFusionsPass : Pass<"rock-vectorize-fusions", "::mlir::func::Fun let summary = "Vectorize affine element-wise loops"; } +def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> { + let summary = "Remove redundant truncf/extf pairs through buffers"; + let description = [{ + Detects patterns where wider float values are truncated to a narrower + float type, stored to a buffer, then loaded and extended back to the + original wider input type. Replaces the extf uses with the original wide + values, preserving precision. + }]; + let dependentDialects = [ + "rock::RockDialect", + "linalg::LinalgDialect", + "arith::ArithDialect", + "memref::MemRefDialect", + "vector::VectorDialect" + ]; +} + def RockBufferLoadMergePass : Pass<"rock-buffer-load-merge", "::mlir::func::FuncOp"> { let summary = "Merge identical memory loads to buffers only read. 
Assumes noalias."; let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect"]; diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index a945ed55ca8e..38fb4faa9488 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -196,6 +196,7 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(createCanonicalizerPass()); funcPm.addPass(createConvertLinalgToAffineLoopsPass()); funcPm.addPass(rock::createRockVectorizeFusionsPass()); + funcPm.addPass(rock::createRockRemoveRedundantCastsPass()); funcPm.addPass(rock::createRockAddAsyncWaitPass()); // We run reuse LDS before the output swizzle pass because it uses a // heuristic to determine whether to swizzle or not, and that heuristic diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index 4147eaf69b75..79e7f63c7cc4 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms BlockwiseLoadTileToThreadwise.cpp AnnotateLiveness.cpp AddAsyncWait.cpp + RemoveRedundantCasts.cpp LowerRockOpsToROCDLOps.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp new file mode 100644 index 000000000000..f595be2bd6ac --- /dev/null +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -0,0 +1,120 @@ +//===--------------------- RemoveRedundantCasts.cpp -----------------------===// +// +// Copyright 2026 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// +// +// This pass detects patterns where wider float values are truncated to a +// narrower float type, stored to a buffer, then loaded and extended back to the +// original wider input type. Replaces the extf uses with the original wide +// values, preserving precision. 
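+//
+// For illustration, a minimal sketch of the pattern this pass targets (the
+// buffer, index, and value names here are hypothetical, not taken from any
+// real kernel):
+//
+//   %narrow = arith.truncf %wide : f32 to f16
+//   memref.store %narrow, %buf[%i] : memref<16xf16>
+//   ...
+//   %reload = memref.load %buf[%i] : memref<16xf16>
+//   %ext = arith.extf %reload : f16 to f32
+//
+// After the rewrite, uses of %ext read the original %wide value instead.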
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/Rock/Passes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace rock { +#define GEN_PASS_DEF_ROCKREMOVEREDUNDANTCASTSPASS +#include "mlir/Dialect/Rock/Passes.h.inc" +} // namespace rock +} // namespace mlir + +#define DEBUG_TYPE "rock-remove-redundant-casts" + +using namespace mlir; +using namespace mlir::rock; + +namespace { + +struct TruncfStoreInfo { + arith::TruncFOp truncfOp; + Value wideValue; + Operation *storeOp; + Value targetBuffer; + SmallVector storeIndices; +}; + +struct LoadExtfInfo { + Operation *loadOp; + arith::ExtFOp extfOp; + SmallVector loadIndices; +}; + +// Collect all arith.truncf operations in the function that convert from +// a wider float type to a narrower float type. +SmallVector findAllTruncfOps(func::FuncOp funcOp) { + SmallVector truncfOps; + + funcOp.walk([&](arith::TruncFOp truncfOp) -> WalkResult { + Type inputType = getElementTypeOrSelf(truncfOp.getIn().getType()); + Type outputType = getElementTypeOrSelf(truncfOp.getOut().getType()); + + // Check that this is a narrowing conversion (truncf) + if (outputType.getIntOrFloatBitWidth() >= inputType.getIntOrFloatBitWidth()) + return WalkResult::advance(); + + LLVM_DEBUG(llvm::dbgs() << "Found truncf: " << truncfOp << "\n"); + truncfOps.push_back(truncfOp); + return WalkResult::advance(); + }); + + LLVM_DEBUG(llvm::dbgs() << "Total truncf operations found: " + << truncfOps.size() << "\n"); + return truncfOps; +} + +struct RockRemoveRedundantCastsPass + : public rock::impl::RockRemoveRedundantCastsPassBase< + RockRemoveRedundantCastsPass> { + void runOnOperation() override; +}; + +} // end namespace + +void RockRemoveRedundantCastsPass::runOnOperation() { + func::FuncOp funcOp = getOperation(); + + // Step 1: Find all truncf operations (f32 -> narrow float) + SmallVector truncfOps = findAllTruncfOps(funcOp); + + if (truncfOps.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No truncf operations found, nothing to do.\n"); + return; + } + + LLVM_DEBUG(llvm::dbgs() << "Found " << truncfOps.size() + << " truncf operations to analyze.\n"); + + // TODO: Implement remaining steps of the algorithm: + // Step 2: Check for direct stores + // Step 3: Find direct extf readers + // Step 4: Verify safety + // Step 5: Apply the optimization +} From 625ba462f9911961523181832e05cb0a65b725a2 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 5 Jan 2026 22:10:45 +0000 Subject: [PATCH 02/12] Add in logic so that we are only finding truncfs with direct stores --- .../Rock/Transforms/RemoveRedundantCasts.cpp | 83 +++++++++++++++---- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index f595be2bd6ac..eb0b2934081a 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -67,27 +67,79 @@ struct LoadExtfInfo { 
SmallVector loadIndices; }; -// Collect all arith.truncf operations in the function that convert from -// a wider float type to a narrower float type. -SmallVector findAllTruncfOps(func::FuncOp funcOp) { - SmallVector truncfOps; +// Helper to check if a store operation directly stores a value. +// Returns the target buffer and indices if it's a supported store type. +static FailureOr>> +getStoreBufferAndIndices(Operation *op, Value storedValue) { + if (auto inBoundsStore = dyn_cast(op)) { + if (inBoundsStore.getData() == storedValue) { + return std::pair>( + inBoundsStore.getDest(), + SmallVector(inBoundsStore.getCoords())); + } + } else if (auto vectorStore = dyn_cast(op)) { + if (vectorStore.getValueToStore() == storedValue) { + return std::pair>( + vectorStore.getBase(), + SmallVector(vectorStore.getIndices())); + } + } else if (auto memrefStore = dyn_cast(op)) { + if (memrefStore.getValue() == storedValue) { + return std::pair>( + memrefStore.getMemRef(), + SmallVector(memrefStore.getIndices())); + } + } + return failure(); +} + +// Find all arith.truncf operations that are directly stored to a buffer. +// A "direct store" means the truncf result is used immediately by a store +// operation with no intermediate operations modifying the value: +// Valid: truncf -> store +// Invalid: truncf -> other_op -> store +SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { + SmallVector results; funcOp.walk([&](arith::TruncFOp truncfOp) -> WalkResult { Type inputType = getElementTypeOrSelf(truncfOp.getIn().getType()); Type outputType = getElementTypeOrSelf(truncfOp.getOut().getType()); - // Check that this is a narrowing conversion (truncf) + // Step 1: Check that this is a narrowing conversion (truncf) if (outputType.getIntOrFloatBitWidth() >= inputType.getIntOrFloatBitWidth()) return WalkResult::advance(); LLVM_DEBUG(llvm::dbgs() << "Found truncf: " << truncfOp << "\n"); - truncfOps.push_back(truncfOp); + + // Step 2: Check for direct stores of the truncf result + Value truncfResult = truncfOp.getOut(); + Value wideValue = truncfOp.getIn(); + + for (Operation *user : truncfResult.getUsers()) { + FailureOr>> storeInfo = + getStoreBufferAndIndices(user, truncfResult); + if (failed(storeInfo)) + continue; + + LLVM_DEBUG(llvm::dbgs() << " Found direct store: " << *user << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Target buffer: " << storeInfo->first + << "\n"); + + TruncfStoreInfo info; + info.truncfOp = truncfOp; + info.wideValue = wideValue; + info.storeOp = user; + info.targetBuffer = storeInfo->first; + info.storeIndices = std::move(storeInfo->second); + results.push_back(info); + } + return WalkResult::advance(); }); - LLVM_DEBUG(llvm::dbgs() << "Total truncf operations found: " - << truncfOps.size() << "\n"); - return truncfOps; + LLVM_DEBUG(llvm::dbgs() << "Total truncf -> store pairs found: " + << results.size() << "\n"); + return results; } struct RockRemoveRedundantCastsPass @@ -101,19 +153,18 @@ struct RockRemoveRedundantCastsPass void RockRemoveRedundantCastsPass::runOnOperation() { func::FuncOp funcOp = getOperation(); - // Step 1: Find all truncf operations (f32 -> narrow float) - SmallVector truncfOps = findAllTruncfOps(funcOp); + SmallVector truncfStores = findTruncfWithDirectStores(funcOp); - if (truncfOps.empty()) { - LLVM_DEBUG(llvm::dbgs() << "No truncf operations found, nothing to do.\n"); + if (truncfStores.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No truncf -> store patterns found, nothing to do.\n"); return; } - LLVM_DEBUG(llvm::dbgs() << "Found " << truncfOps.size() - 
<< " truncf operations to analyze.\n"); + LLVM_DEBUG(llvm::dbgs() << "Found " << truncfStores.size() + << " truncf -> store patterns to analyze.\n"); // TODO: Implement remaining steps of the algorithm: - // Step 2: Check for direct stores // Step 3: Find direct extf readers // Step 4: Verify safety // Step 5: Apply the optimization From f48f74a9049d77339451bc2d180d23dba1a3e9c1 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 5 Jan 2026 22:17:04 +0000 Subject: [PATCH 03/12] Minor comment and debug message fixes --- mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index eb0b2934081a..fd54488409e8 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -20,6 +20,10 @@ // original wider input type. Replaces the extf uses with the original wide // values, preserving precision. // +// Note: The simpler truncf -> extf folding with no loads/stores is already +// handled by arith.truncf canonicalization patterns. This pass specifically +// deals with the more complex case where the values are stored to buffers. +// //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" @@ -122,8 +126,6 @@ SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { continue; LLVM_DEBUG(llvm::dbgs() << " Found direct store: " << *user << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Target buffer: " << storeInfo->first - << "\n"); TruncfStoreInfo info; info.truncfOp = truncfOp; @@ -137,8 +139,6 @@ SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { return WalkResult::advance(); }); - LLVM_DEBUG(llvm::dbgs() << "Total truncf -> store pairs found: " - << results.size() << "\n"); return results; } From 6f9114a914ea7134ce470cff1a7c1e90628660d3 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 6 Jan 2026 20:29:39 +0000 Subject: [PATCH 04/12] Add detection for extf ops --- .../Rock/Transforms/RemoveRedundantCasts.cpp | 97 ++++++++++++++++--- 1 file changed, 86 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index fd54488409e8..5af23b88253e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -97,11 +97,74 @@ getStoreBufferAndIndices(Operation *op, Value storedValue) { return failure(); } +// Helper to check if a load operation reads from a specific buffer. +// Returns the loaded value and indices if it's a supported load type. +static FailureOr>> +getLoadResultAndIndices(Operation *op, Value expectedBuffer) { + if (auto inBoundsLoad = dyn_cast(op)) { + if (inBoundsLoad.getSource() == expectedBuffer) { + return std::pair>( + inBoundsLoad.getResult(), + SmallVector(inBoundsLoad.getCoords())); + } + } else if (auto transferRead = dyn_cast(op)) { + if (transferRead.getBase() == expectedBuffer) { + return std::pair>( + transferRead.getResult(), + SmallVector(transferRead.getIndices())); + } + } else if (auto memrefLoad = dyn_cast(op)) { + if (memrefLoad.getMemRef() == expectedBuffer) { + return std::pair>( + memrefLoad.getResult(), + SmallVector(memrefLoad.getIndices())); + } + } + return failure(); +} + +// Find all load -> extf patterns from a given buffer. 
+// A "direct extf" means the load result is used immediately by an extf +// operation with no intermediate operations modifying the value. +SmallVector findDirectExtfReaders(Value narrowBuffer, + Type wideType) { + SmallVector results; + + // Iterate over direct users of the buffer + for (Operation *user : narrowBuffer.getUsers()) { + // Check if this user is a load from our buffer + FailureOr>> loadInfo = + getLoadResultAndIndices(user, narrowBuffer); + if (failed(loadInfo)) + continue; + + Value loadResult = loadInfo->first; + + // Check if the load result is used directly by an arith.extf + for (Operation *loadUser : loadResult.getUsers()) { + auto extfOp = dyn_cast(loadUser); + if (!extfOp) + continue; + + // Verify the extf output type matches the expected wide type + Type extfOutputType = getElementTypeOrSelf(extfOp.getOut().getType()); + if (extfOutputType != wideType) + continue; + + LoadExtfInfo info; + info.loadOp = user; + info.extfOp = extfOp; + info.loadIndices = std::move(loadInfo->second); + results.push_back(info); + } + } + + return results; +} + // Find all arith.truncf operations that are directly stored to a buffer. // A "direct store" means the truncf result is used immediately by a store -// operation with no intermediate operations modifying the value: -// Valid: truncf -> store -// Invalid: truncf -> other_op -> store +// operation with no intermediate operations modifying the value. SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { SmallVector results; @@ -109,13 +172,11 @@ SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { Type inputType = getElementTypeOrSelf(truncfOp.getIn().getType()); Type outputType = getElementTypeOrSelf(truncfOp.getOut().getType()); - // Step 1: Check that this is a narrowing conversion (truncf) + // Check that this is a narrowing conversion (truncf) if (outputType.getIntOrFloatBitWidth() >= inputType.getIntOrFloatBitWidth()) return WalkResult::advance(); - LLVM_DEBUG(llvm::dbgs() << "Found truncf: " << truncfOp << "\n"); - - // Step 2: Check for direct stores of the truncf result + // Check for direct stores of the truncf result Value truncfResult = truncfOp.getOut(); Value wideValue = truncfOp.getIn(); @@ -125,8 +186,6 @@ SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { if (failed(storeInfo)) continue; - LLVM_DEBUG(llvm::dbgs() << " Found direct store: " << *user << "\n"); - TruncfStoreInfo info; info.truncfOp = truncfOp; info.wideValue = wideValue; @@ -164,8 +223,24 @@ void RockRemoveRedundantCastsPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Found " << truncfStores.size() << " truncf -> store patterns to analyze.\n"); + // For each truncf -> store pair, find load -> extf readers + for (const TruncfStoreInfo &truncfStore : truncfStores) { + Type wideType = getElementTypeOrSelf(truncfStore.wideValue.getType()); + LLVM_DEBUG(llvm::dbgs() << "Analyzing buffer: " << truncfStore.targetBuffer + << "\n"); + SmallVector extfReaders = + findDirectExtfReaders(truncfStore.targetBuffer, wideType); + + if (extfReaders.empty()) { + LLVM_DEBUG(llvm::dbgs() << "\tNo load -> extf readers found.\n"); + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "\tFound " << extfReaders.size() + << " load -> extf readers.\n"); + } + // TODO: Implement remaining steps of the algorithm: - // Step 3: Find direct extf readers - // Step 4: Verify safety + // Step 4: Verify safety (dominance, no intervening writes, same indices) // Step 5: Apply the optimization } From a8a7cda469419a3e44b25d78e9931673115e970a Mon Sep 17 00:00:00 
2001 From: Justin Rosner Date: Tue, 6 Jan 2026 21:20:43 +0000 Subject: [PATCH 05/12] Partial verification of store/load chains --- .../Rock/Transforms/RemoveRedundantCasts.cpp | 118 +++++++++++++++++- 1 file changed, 114 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index 5af23b88253e..89c6c9d856b5 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -40,6 +40,7 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" @@ -71,6 +72,94 @@ struct LoadExtfInfo { SmallVector loadIndices; }; +// A verified candidate for optimization - a (truncf->store, load->extf) pair +// that has passed all safety checks. +struct OptimizationCandidate { + TruncfStoreInfo truncfStore; + LoadExtfInfo loadExtf; +}; + +// Check if there are any other stores to the buffer that could interfere. +// Returns true if there are no intervening writes. +static bool hasNoInterveningWrites(Value buffer, Operation *ourStore) { + // Conservative check: ensure our store is the ONLY store to this buffer. + // This handles the common case where a buffer is written once and read + // multiple times. + for (Operation *user : buffer.getUsers()) { + // Skip our own store + if (user == ourStore) + continue; + + // Check if this user is a store operation + if (isa(user)) { + LLVM_DEBUG(llvm::dbgs() + << "\t\tFound another store to buffer: " << *user << "\n"); + return false; + } + } + return true; +} + +// Find the ancestor of 'op' that is a direct child of 'block'. +static Operation *getAncestorInBlock(Operation *op, Block *block) { + while (op && op->getBlock() != block) + op = op->getParentOp(); + return op; +} + +// Check if storeOp's enclosing operation dominates loadOp's enclosing operation. +// For ops in nested regions, finds their ancestors at a common nesting level +// and checks dominance between those ancestors. +static bool storeEnclosingOpDominatesLoad(Operation *storeOp, Operation *loadOp, + DominanceInfo &domInfo) { + // If they're in the same block, use direct dominance + if (storeOp->getBlock() == loadOp->getBlock()) + return domInfo.properlyDominates(storeOp, loadOp); + + // Find a common ancestor block by walking up from the load + for (Operation *loadWalk = loadOp; loadWalk; + loadWalk = loadWalk->getParentOp()) { + Block *block = loadWalk->getBlock(); + if (Operation *storeAncestor = getAncestorInBlock(storeOp, block)) { + // Found common block - check if store's ancestor dominates load's + return domInfo.properlyDominates(storeAncestor, loadWalk); + } + } + + return false; +} + +// Verify that a (truncf->store, load->extf) pair is safe to optimize. +// Returns true if all safety conditions are met. +static bool verifySafety(const TruncfStoreInfo &truncfStore, + const LoadExtfInfo &loadExtf, DominanceInfo &domInfo) { + // 4a. Store's enclosing op must dominate load's enclosing op + if (!storeEnclosingOpDominatesLoad(truncfStore.storeOp, loadExtf.loadOp, + domInfo)) { + LLVM_DEBUG(llvm::dbgs() << "\t\tStore does not dominate load\n"); + return false; + } + + // 4b. No intervening writes to the buffer + // If our truncf store is the ONLY store to the buffer, then any value + // read from the buffer must be a value we wrote - no need to match indices. 
+ if (!hasNoInterveningWrites(truncfStore.targetBuffer, truncfStore.storeOp)) { + LLVM_DEBUG(llvm::dbgs() << "\t\tBuffer has other stores\n"); + return false; + } + // Note: We skip explicit index matching (4c) because hasNoInterveningWrites + // guarantees our store is the only writer. Combined with the dominance check, + // this means any loaded value must have come from our truncf store. + + // Note: We don't check if the wide value dominates the extf (4d) because + // the wide value is typically defined inside a loop body and won't be + // accessible at the extf location. Instead, Step 5 will create a shadow + // buffer to store the wide value and redirect reads there. + + return true; +} + // Helper to check if a store operation directly stores a value. // Returns the target buffer and indices if it's a supported store type. static FailureOr>> @@ -223,7 +312,11 @@ void RockRemoveRedundantCastsPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Found " << truncfStores.size() << " truncf -> store patterns to analyze.\n"); - // For each truncf -> store pair, find load -> extf readers + // Collect verified optimization candidates + SmallVector candidates; + + // For each truncf -> store pair, find load -> extf readers and verify safety + DominanceInfo domInfo(funcOp); for (const TruncfStoreInfo &truncfStore : truncfStores) { Type wideType = getElementTypeOrSelf(truncfStore.wideValue.getType()); LLVM_DEBUG(llvm::dbgs() << "Analyzing buffer: " << truncfStore.targetBuffer @@ -238,9 +331,26 @@ void RockRemoveRedundantCastsPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "\tFound " << extfReaders.size() << " load -> extf readers.\n"); + + // Verify safety of each load -> extf reader + for (const LoadExtfInfo &loadExtf : extfReaders) { + LLVM_DEBUG(llvm::dbgs() << "\tVerifying: load=" << *loadExtf.loadOp + << ", extf=" << loadExtf.extfOp << "\n"); + + if (verifySafety(truncfStore, loadExtf, domInfo)) { + LLVM_DEBUG(llvm::dbgs() << "\t\tSafety verified!\n"); + candidates.push_back({truncfStore, loadExtf}); + } + } } - // TODO: Implement remaining steps of the algorithm: - // Step 4: Verify safety (dominance, no intervening writes, same indices) - // Step 5: Apply the optimization + if (candidates.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No safe optimization candidates found.\n"); + return; + } + + LLVM_DEBUG(llvm::dbgs() << "Found " << candidates.size() + << " safe truncation/extension candidates.\n"); + + // TODO: Step 5: Apply the optimization } From 90266321e3010fb52d694af91a689a27e9406139 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 6 Jan 2026 21:52:16 +0000 Subject: [PATCH 06/12] Initial attempt at LLVMIR level transformation --- mlir/include/mlir/Dialect/Rock/Passes.h | 2 +- mlir/include/mlir/Dialect/Rock/Passes.td | 30 +- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 2 +- .../Rock/Transforms/RemoveRedundantCasts.cpp | 866 +++++++++++++----- 4 files changed, 646 insertions(+), 254 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index 33130b81f36d..305669ac548a 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -41,6 +41,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKVIEWTOTRANSFORMPASS #define GEN_PASS_DECL_ROCKDETECTFLASHDECODINGPASS #define GEN_PASS_DECL_ROCKLOWERREDUCEPASS +#define GEN_PASS_DECL_ROCKREMOVEREDUNDANTCASTSPASS #define GEN_PASS_DECL_ROCKPREPARELLVMPASS #define GEN_PASS_DECL_ROCKCHECKRESIDENCYPASS #define GEN_PASS_DECL_ROCKVECTORIZEFUSIONSPASS @@ 
-51,7 +52,6 @@ namespace rock { #define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS #define GEN_PASS_DECL_ROCKFINDFIRSTGEMMINDEXPASS #define GEN_PASS_DECL_ROCKREMOVEOUTPUTALLOCPASS -#define GEN_PASS_DECL_ROCKREMOVEREDUNDANTCASTSPASS #define GEN_PASS_DECL_ROCKBLOCKWISELOADTILETOTHREADWISEPASS #define GEN_PASS_DECL_ROCKANNOTATELIVENESSPASS #define GEN_PASS_DECL_ROCKADDASYNCWAITPASS diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 6607510e6ed1..3b911d11e28b 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -152,23 +152,6 @@ def RockVectorizeFusionsPass : Pass<"rock-vectorize-fusions", "::mlir::func::Fun let summary = "Vectorize affine element-wise loops"; } -def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::func::FuncOp"> { - let summary = "Remove redundant truncf/extf pairs through buffers"; - let description = [{ - Detects patterns where wider float values are truncated to a narrower - float type, stored to a buffer, then loaded and extended back to the - original wider input type. Replaces the extf uses with the original wide - values, preserving precision. - }]; - let dependentDialects = [ - "rock::RockDialect", - "linalg::LinalgDialect", - "arith::ArithDialect", - "memref::MemRefDialect", - "vector::VectorDialect" - ]; -} - def RockBufferLoadMergePass : Pass<"rock-buffer-load-merge", "::mlir::func::FuncOp"> { let summary = "Merge identical memory loads to buffers only read. Assumes noalias."; let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect"]; @@ -189,6 +172,19 @@ def RockLowerReducePass : Pass<"rock-lower-reduce", "::mlir::func::FuncOp"> { let dependentDialects = ["rock::RockDialect", "func::FuncDialect", "gpu::GPUDialect"]; } +def RockRemoveRedundantCastsPass + : Pass<"rock-remove-redundant-casts", "::mlir::LLVM::LLVMFuncOp"> { + let summary = "Remove redundant fptrunc/fpext pairs through buffers at LLVM dialect level"; + let description = [{ + Detects patterns at the LLVM dialect level where wider float values are + truncated (llvm.fptrunc) to a narrower type, stored to a buffer, then loaded + and extended (llvm.fpext) back to the original wider type. This pass + redirects the loads to read the wide values directly, eliminating the + fpext and preserving precision. + }]; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + def RockPrepareLLVMPass : Pass<"rock-prepare-llvm", "::mlir::LLVM::LLVMFuncOp"> { let summary = "prepare the generated code for llvm"; let dependentDialects = ["ROCDL::ROCDLDialect"]; diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index 38fb4faa9488..5984fcb1a6c3 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -196,7 +196,6 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass(createCanonicalizerPass()); funcPm.addPass(createConvertLinalgToAffineLoopsPass()); funcPm.addPass(rock::createRockVectorizeFusionsPass()); - funcPm.addPass(rock::createRockRemoveRedundantCastsPass()); funcPm.addPass(rock::createRockAddAsyncWaitPass()); // We run reuse LDS before the output swizzle pass because it uses a // heuristic to determine whether to swizzle or not, and that heuristic @@ -306,6 +305,7 @@ void rock::buildBackendPipeline(OpPassManager &pm, // descriptors. (Mainly we want the `extractvalue` fold). 
llvmFuncPm.addPass(createCanonicalizerPass()); llvmFuncPm.addPass(createCSEPass()); + llvmFuncPm.addPass(rock::createRockRemoveRedundantCastsPass()); llvmFuncPm.addPass(rock::createRockPrepareLLVMPass()); if (options.compile) { GpuROCDLAttachTargetOptions opts; diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index 89c6c9d856b5..9e9f481e18cb 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -1,4 +1,4 @@ -//===--------------------- RemoveRedundantCasts.cpp -----------------------===// +//===-------------------- RemoveRedundantCasts.cpp ------------------------===// // // Copyright 2026 The MLIR Authors. // @@ -15,29 +15,46 @@ // limitations under the License. //===----------------------------------------------------------------------===// // -// This pass detects patterns where wider float values are truncated to a -// narrower float type, stored to a buffer, then loaded and extended back to the -// original wider input type. Replaces the extf uses with the original wide -// values, preserving precision. +// This pass detects patterns at the LLVM dialect level where wider float values +// are truncated (llvm.fptrunc) to a narrower type, stored to a buffer, then +// loaded and extended (llvm.fpext) back to the original wider type. This pass +// creates a parallel wide buffer (if one doesn't exist) and redirects the loads +// to read the wide values directly, eliminating the fpext and preserving +// precision. // -// Note: The simpler truncf -> extf folding with no loads/stores is already -// handled by arith.truncf canonicalization patterns. This pass specifically -// deals with the more complex case where the values are stored to buffers. +// Algorithm: +// 1. Find all fptrunc -> store patterns in the function. For each pattern, +// record whether there's already a parallel store of the wide value to +// a separate buffer. +// 2. Find all load -> fpext patterns where the load is from a buffer that +// has fptrunc stores. +// 3. Verify safety for each load+fpext pattern: +// - All stores to the narrow buffer must be from tracked fptrunc patterns +// (i.e., no untracked stores that could write different values) +// - All tracked stores must dominate the load +// - The narrow buffer must be an alloca +// 4. For safe patterns, create a wide buffer and the corresponding stores if +// they don't exist. If a parallel store already exists, reuse it: +// - Create a wide alloca right after the narrow alloca +// - For each fptrunc store, insert a store of the wide value to the +// wide buffer (right after the narrow store, using the same indices) +// 5. Apply the transformation: +// - Redirect the load to read from the wide buffer instead +// - Replace uses of the fpext result with the wide load result +// - Delete the fpext (and the old load/GEP if unused) +// 6. 
Clean up unused narrow buffer operations: +// - If the narrow buffer has no remaining uses, erase the fptrunc stores +// - These can only be erased if they are not used by any other +// operations +// - Erase the narrow alloca if it has no remaining uses // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Rock/Passes.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dominance.h" -#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -54,240 +71,615 @@ namespace rock { #define DEBUG_TYPE "rock-remove-redundant-casts" using namespace mlir; -using namespace mlir::rock; +using namespace mlir::LLVM; namespace { -struct TruncfStoreInfo { - arith::TruncFOp truncfOp; - Value wideValue; - Operation *storeOp; - Value targetBuffer; - SmallVector storeIndices; -}; +// Information about a fptrunc -> store pattern +struct FPTruncStoreInfo { + Value wideValue; // The input to fptrunc (original wide value) + StoreOp narrowStore; // The store of the narrow value + Value narrowBuffer; // Base alloca of narrow store + GEPOp narrowGep; // GEP for narrow store (null if storing directly) + + // These may be null if no parallel store exists yet + Value wideBuffer; + StoreOp wideStore; -struct LoadExtfInfo { - Operation *loadOp; - arith::ExtFOp extfOp; - SmallVector loadIndices; + bool hasParallelStore() const { return wideStore != nullptr; } }; -// A verified candidate for optimization - a (truncf->store, load->extf) pair -// that has passed all safety checks. -struct OptimizationCandidate { - TruncfStoreInfo truncfStore; - LoadExtfInfo loadExtf; +// Information about a load + fpext pattern that can potentially be optimized. +struct LoadFPExtPattern { + LoadOp loadOp; // The load from narrow buffer + FPExtOp fpextOp; // The fpext that extends the loaded value + Value narrowBuffer; // Base pointer of the narrow buffer being loaded + GEPOp gepOp; // The GEP operation (if any) used for indexing + + // All fptrunc stores that contribute to covering the buffer. + SmallVector matchingStores; }; -// Check if there are any other stores to the buffer that could interfere. -// Returns true if there are no intervening writes. -static bool hasNoInterveningWrites(Value buffer, Operation *ourStore) { - // Conservative check: ensure our store is the ONLY store to this buffer. - // This handles the common case where a buffer is written once and read - // multiple times. - for (Operation *user : buffer.getUsers()) { - // Skip our own store - if (user == ourStore) - continue; +// Get the base pointer from a value, tracing through GEP operations. +static Value getBasePointer(Value ptr) { + while (auto gep = ptr.getDefiningOp()) { + ptr = gep.getBase(); + } + return ptr; +} + +// Get the scalar element type, unwrapping vectors if needed. +static Type getScalarType(Type type) { + if (auto vecType = dyn_cast(type)) + return vecType.getElementType(); + return type; +} + +// Find all fptrunc -> store patterns in the function. 
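+// For illustration, a sketch of the narrow half of the pattern in LLVM
+// dialect IR (pointer and value names are hypothetical):
+//
+//   %narrow = llvm.fptrunc %wide : f32 to f16
+//   %ngep   = llvm.getelementptr %nbuf[4] : (!llvm.ptr) -> !llvm.ptr, f16
+//   llvm.store %narrow, %ngep : f16, !llvm.ptr
+//
+// The matching load + fpext half is found separately by
+// findLoadFPExtPatterns below.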
+static SmallVector +findFPTruncStorePatterns(LLVMFuncOp funcOp) { + SmallVector results; + + funcOp.walk([&](FPTruncOp fptruncOp) -> WalkResult { + LLVM_DEBUG(llvm::dbgs() << "Found fptrunc: " << fptruncOp << "\n"); + Value wideValue = fptruncOp.getArg(); + Value narrowValue = fptruncOp.getRes(); + + // Find direct stores of the narrow value + for (Operation *user : narrowValue.getUsers()) { + auto narrowStore = dyn_cast(user); + if (!narrowStore) + continue; + + Value narrowPtr = narrowStore.getAddr(); + Value narrowBuffer = getBasePointer(narrowPtr); + GEPOp narrowGep = narrowPtr.getDefiningOp(); - // Check if this user is a store operation - if (isa(user)) { LLVM_DEBUG(llvm::dbgs() - << "\t\tFound another store to buffer: " << *user << "\n"); - return false; + << "\tFound narrow store: " << narrowStore << "\n"); + + FPTruncStoreInfo info; + info.wideValue = wideValue; + info.narrowStore = narrowStore; + info.narrowBuffer = narrowBuffer; + info.narrowGep = narrowGep; + info.wideBuffer = nullptr; + info.wideStore = nullptr; + + // Look for existing parallel wide store + for (Operation *wideUser : wideValue.getUsers()) { + auto wideStore = dyn_cast(wideUser); + if (!wideStore) + continue; + + Value widePtr = wideStore.getAddr(); + Value wideBuffer = getBasePointer(widePtr); + + if (wideBuffer == narrowBuffer) + continue; + + LLVM_DEBUG(llvm::dbgs() << "\tFound existing parallel wide store: " + << wideStore << "\n"); + + info.wideBuffer = wideBuffer; + info.wideStore = wideStore; + break; + } + + results.push_back(info); } - } - return true; + + return WalkResult::advance(); + }); + + return results; } -// Find the ancestor of 'op' that is a direct child of 'block'. -static Operation *getAncestorInBlock(Operation *op, Block *block) { - while (op && op->getBlock() != block) - op = op->getParentOp(); - return op; +// Find all load + fpext patterns. +static SmallVector findLoadFPExtPatterns(LLVMFuncOp funcOp) { + SmallVector results; + + funcOp.walk([&](FPExtOp fpextOp) -> WalkResult { + Value input = fpextOp.getArg(); + auto loadOp = input.getDefiningOp(); + if (!loadOp) + return WalkResult::advance(); + + Value loadPtr = loadOp.getAddr(); + Value narrowBuffer = getBasePointer(loadPtr); + GEPOp gepOp = loadPtr.getDefiningOp(); + + LLVM_DEBUG(llvm::dbgs() << "Found load+fpext pattern:\n"); + LLVM_DEBUG(llvm::dbgs() << "\tLoad: " << loadOp << "\n"); + LLVM_DEBUG(llvm::dbgs() << "\tFPExt: " << fpextOp << "\n"); + + LoadFPExtPattern pattern; + pattern.loadOp = loadOp; + pattern.fpextOp = fpextOp; + pattern.narrowBuffer = narrowBuffer; + pattern.gepOp = gepOp; + results.push_back(pattern); + return WalkResult::advance(); + }); + + return results; } -// Check if storeOp's enclosing operation dominates loadOp's enclosing operation. -// For ops in nested regions, finds their ancestors at a common nesting level -// and checks dominance between those ancestors. 
-static bool storeEnclosingOpDominatesLoad(Operation *storeOp, Operation *loadOp, - DominanceInfo &domInfo) { - // If they're in the same block, use direct dominance - if (storeOp->getBlock() == loadOp->getBlock()) - return domInfo.properlyDominates(storeOp, loadOp); - - // Find a common ancestor block by walking up from the load - for (Operation *loadWalk = loadOp; loadWalk; - loadWalk = loadWalk->getParentOp()) { - Block *block = loadWalk->getBlock(); - if (Operation *storeAncestor = getAncestorInBlock(storeOp, block)) { - // Found common block - check if store's ancestor dominates load's - return domInfo.properlyDominates(storeAncestor, loadWalk); +// Collect all stores that write to a buffer, tracing through GEPs. +static SmallVector collectStoresToBuffer(Value buffer) { + SmallVector stores; + SmallVector worklist; + worklist.push_back(buffer); + + while (!worklist.empty()) { + Value ptr = worklist.pop_back_val(); + for (Operation *user : ptr.getUsers()) { + if (auto store = dyn_cast(user)) { + // Only count stores TO this address, not stores OF this value + if (store.getAddr() == ptr) { + stores.push_back(store); + } + } else if (auto gep = dyn_cast(user)) { + // Trace through GEPs that use this as base + if (gep.getBase() == ptr) { + worklist.push_back(gep.getResult()); + } + } } } + return stores; +} - return false; +// Check if a store is from one of our tracked fptrunc patterns. +static bool +isStoreFromFPTruncPattern(StoreOp store, + const SmallVector &storeInfos) { + return llvm::any_of( + storeInfos, [&](const auto &info) { return info.narrowStore == store; }); } -// Verify that a (truncf->store, load->extf) pair is safe to optimize. -// Returns true if all safety conditions are met. -static bool verifySafety(const TruncfStoreInfo &truncfStore, - const LoadExtfInfo &loadExtf, DominanceInfo &domInfo) { - // 4a. Store's enclosing op must dominate load's enclosing op - if (!storeEnclosingOpDominatesLoad(truncfStore.storeOp, loadExtf.loadOp, - domInfo)) { - LLVM_DEBUG(llvm::dbgs() << "\t\tStore does not dominate load\n"); - return false; +// Represents a range of element indices [start, start + count). +struct IndexRange { + int64_t start; + int64_t count; + + bool isValid() const { return count > 0; } + + bool isSubsetOf(const IndexRange &other) const { + return start >= other.start && + (start + count) <= (other.start + other.count); } - // 4b. No intervening writes to the buffer - // If our truncf store is the ONLY store to the buffer, then any value - // read from the buffer must be a value we wrote - no need to match indices. - if (!hasNoInterveningWrites(truncfStore.targetBuffer, truncfStore.storeOp)) { - LLVM_DEBUG(llvm::dbgs() << "\t\tBuffer has other stores\n"); - return false; + bool overlaps(const IndexRange &other) const { + return start < (other.start + other.count) && other.start < (start + count); } - // Note: We skip explicit index matching (4c) because hasNoInterveningWrites - // guarantees our store is the only writer. Combined with the dominance check, - // this means any loaded value must have come from our truncf store. +}; - // Note: We don't check if the wide value dominates the extf (4d) because - // the wide value is typically defined inside a loop body and won't be - // accessible at the extf location. Instead, Step 5 will create a shadow - // buffer to store the wide value and redirect reads there. +// Get the index range for a memory access. Returns an invalid range if we can't +// determine it (e.g., dynamic indices). 
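+// For example, a GEP with constant index 4 feeding a load or store of
+// vector<4xf16> covers the element range [4, 8); a dynamic or multi-index GEP
+// yields an invalid range.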
+static IndexRange getAccessRange(GEPOp gep, Type accessType) { + int64_t elementCount = 1; + if (auto vecType = dyn_cast(accessType)) + elementCount = vecType.getNumElements(); - return true; + // No GEP means accessing at base (index 0) + if (!gep) + return {0, elementCount}; + + // Check if all indices are constant + auto indices = gep.getIndices(); + if (indices.empty()) + return {0, elementCount}; + + // We only handle single-index GEPs with constant index + if (indices.size() != 1) + return {-1, 0}; // Invalid + + auto constIdx = dyn_cast(indices[0]); + if (!constIdx) + return {-1, 0}; // Invalid + + return {constIdx.getInt(), elementCount}; } -// Helper to check if a store operation directly stores a value. -// Returns the target buffer and indices if it's a supported store type. -static FailureOr>> -getStoreBufferAndIndices(Operation *op, Value storedValue) { - if (auto inBoundsStore = dyn_cast(op)) { - if (inBoundsStore.getData() == storedValue) { - return std::pair>( - inBoundsStore.getDest(), - SmallVector(inBoundsStore.getCoords())); - } - } else if (auto vectorStore = dyn_cast(op)) { - if (vectorStore.getValueToStore() == storedValue) { - return std::pair>( - vectorStore.getBase(), - SmallVector(vectorStore.getIndices())); - } - } else if (auto memrefStore = dyn_cast(op)) { - if (memrefStore.getValue() == storedValue) { - return std::pair>( - memrefStore.getMemRef(), - SmallVector(memrefStore.getIndices())); +// Get the total size (in elements) of a buffer from its alloca. +static int64_t getBufferSize(Value buffer) { + auto alloca = buffer.getDefiningOp(); + if (!alloca) + return -1; + + // Get array size (number of elements allocated) + Value arraySizeVal = alloca.getArraySize(); + if (auto constOp = arraySizeVal.getDefiningOp()) { + if (auto intAttr = dyn_cast(constOp.getValue())) + return intAttr.getInt(); + } + return -1; // Dynamic or unknown size +} + +// Find all fptrunc stores that cover the load's location. Returns all +// dominating stores if they collectively cover the entire buffer, otherwise +// returns an empty list. +static SmallVector +findMatchingFPTruncStores(LoadFPExtPattern &pattern, + SmallVector &storeInfos, + DominanceInfo &domInfo) { + SmallVector dominatingStores; + int64_t bufferSize = getBufferSize(pattern.narrowBuffer); + + if (bufferSize <= 0) + return dominatingStores; + + // Track which elements are covered + std::vector covered(bufferSize, false); + + for (auto &info : storeInfos) { + if (info.narrowBuffer != pattern.narrowBuffer) + continue; + if (!domInfo.dominates(info.narrowStore.getOperation(), + pattern.loadOp.getOperation())) + continue; + + dominatingStores.push_back(&info); + + IndexRange storeRange = + getAccessRange(info.narrowGep, info.narrowStore.getValue().getType()); + if (!storeRange.isValid()) + continue; + + // Mark covered elements + for (int64_t i = storeRange.start; + i < storeRange.start + storeRange.count && i < bufferSize; ++i) { + if (i >= 0) + covered[i] = true; } } - return failure(); + + // Check if all elements are covered + for (bool c : covered) { + if (!c) + return {}; // Not fully covered, return empty + } + return dominatingStores; +} + +// Check that no non-fptrunc stores could intervene between the fptrunc stores +// and the load that would overwrite the fptrunc'd value. 
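+// For example, an unrelated llvm.store into an overlapping index range that
+// is not dominated by the load makes the rewrite unsafe, while a store to a
+// provably disjoint constant range (or one the load dominates) does not.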
+static bool +hasNoInterveningStores(LoadFPExtPattern &pattern, + const SmallVector &storeInfos, + DominanceInfo &domInfo) { + IndexRange loadRange = + getAccessRange(pattern.gepOp, pattern.loadOp.getRes().getType()); + + SmallVector allStores = collectStoresToBuffer(pattern.narrowBuffer); + + for (auto store : allStores) { + // Skip fptrunc stores + if (isStoreFromFPTruncPattern(store, storeInfos)) + continue; + + // If load dominates store, the store happens after the load on all paths + if (domInfo.dominates(pattern.loadOp.getOperation(), store.getOperation())) + continue; + + // This non-fptrunc store could execute before the load on some path. + // Check if it could overwrite what the load is reading. + GEPOp storeGep = store.getAddr().getDefiningOp(); + IndexRange storeRange = + getAccessRange(storeGep, store.getValue().getType()); + + // If we can determine ranges and they don't overlap, it's safe + if (storeRange.isValid() && loadRange.isValid() && + !storeRange.overlaps(loadRange)) + continue; + + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: Non-fptrunc store could overwrite value: " << store + << "\n"); + return false; + } + return true; } -// Helper to check if a load operation reads from a specific buffer. -// Returns the loaded value and indices if it's a supported load type. -static FailureOr>> -getLoadResultAndIndices(Operation *op, Value expectedBuffer) { - if (auto inBoundsLoad = dyn_cast(op)) { - if (inBoundsLoad.getSource() == expectedBuffer) { - return std::pair>( - inBoundsLoad.getResult(), - SmallVector(inBoundsLoad.getCoords())); +// Verify that a load -> fpext pattern is safe to optimize. +static FailureOr> +verifySafety(LoadFPExtPattern &pattern, + SmallVector &storeInfos, + DominanceInfo &domInfo) { + // Check that the narrow buffer is an alloca + if (!pattern.narrowBuffer.getDefiningOp()) { + LLVM_DEBUG(llvm::dbgs() << "\tUNSAFE: Narrow buffer is not an alloca\n"); + return failure(); + } + + // Find all fptrunc stores that cover this load + SmallVector matchingStores = + findMatchingFPTruncStores(pattern, storeInfos, domInfo); + if (matchingStores.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: No matching fptrunc stores found for load\n"); + return failure(); + } + + // Check that all matching stores have compatible element types + Type fpextElemType = getScalarType(pattern.fpextOp.getRes().getType()); + for (auto *store : matchingStores) { + Type originalWideElemType = getScalarType(store->wideValue.getType()); + if (fpextElemType != originalWideElemType) { + LLVM_DEBUG( + llvm::dbgs() + << "\tUNSAFE: Element type mismatch - fpext result element type " + << fpextElemType << " != original wide element type " + << originalWideElemType << "\n"); + return failure(); } - } else if (auto transferRead = dyn_cast(op)) { - if (transferRead.getBase() == expectedBuffer) { - return std::pair>( - transferRead.getResult(), - SmallVector(transferRead.getIndices())); + + // For existing parallel stores, check wide store dominance + if (store->hasParallelStore() && + !domInfo.dominates(store->wideStore.getOperation(), + pattern.loadOp.getOperation())) { + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: Wide store does not dominate load\n"); + return failure(); } - } else if (auto memrefLoad = dyn_cast(op)) { - if (memrefLoad.getMemRef() == expectedBuffer) { - return std::pair>( - memrefLoad.getResult(), - SmallVector(memrefLoad.getIndices())); + } + + // Check that no non-fptrunc stores could intervene + if (!hasNoInterveningStores(pattern, storeInfos, domInfo)) + return failure(); + 
+ LLVM_DEBUG(llvm::dbgs() << "\tSAFE: All checks passed\n"); + return matchingStores; +} + +// Create a wide store for an fptrunc store info, using the given wide buffer. +static void createWideStore(FPTruncStoreInfo *info, Value wideBuffer, + Type wideElemType, OpBuilder &builder) { + builder.setInsertionPointAfter(info->narrowStore); + + Value widePtr; + if (info->narrowGep) { + SmallVector gepArgs; + for (auto idx : info->narrowGep.getIndices()) { + if (auto constIdx = dyn_cast(idx)) + gepArgs.push_back(static_cast(constIdx.getInt())); + else + gepArgs.push_back(cast(idx)); } + auto wideGep = GEPOp::create(builder, info->narrowGep.getLoc(), + info->narrowGep.getType(), wideElemType, + wideBuffer, gepArgs); + wideGep.setNoWrapFlags(info->narrowGep.getNoWrapFlags()); + widePtr = wideGep.getResult(); + } else { + widePtr = wideBuffer; } - return failure(); + + auto wideStore = StoreOp::create(builder, info->narrowStore.getLoc(), + info->wideValue, widePtr); + info->wideBuffer = wideBuffer; + info->wideStore = wideStore; + LLVM_DEBUG(llvm::dbgs() << "Created wide store: " << wideStore << "\n"); } -// Find all load -> extf patterns from a given buffer. -// A "direct extf" means the load result is used immediately by an extf -// operation with no intermediate operations modifying the value. -SmallVector findDirectExtfReaders(Value narrowBuffer, - Type wideType) { - SmallVector results; - - // Iterate over direct users of the buffer - for (Operation *user : narrowBuffer.getUsers()) { - // Check if this user is a load from our buffer - FailureOr>> loadInfo = - getLoadResultAndIndices(user, narrowBuffer); - if (failed(loadInfo)) +// Create wide buffers and stores for safe patterns that don't already have +// them. For patterns with existing parallel wide stores, do nothing. For +// patterns without parallel stores, create a new wide alloca and insert a wide +// store right after each narrow store. 
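+// For illustration (hypothetical names), starting from the existing narrow
+// buffer and store:
+//   %nbuf = llvm.alloca %c8 x f16 : (i64) -> !llvm.ptr
+//   ...
+//   llvm.store %narrow, %ngep : f16, !llvm.ptr
+// this inserts a wide twin of each:
+//   %wbuf = llvm.alloca %c8 x f32 : (i64) -> !llvm.ptr  // right after %nbuf
+//   llvm.store %wide, %wgep : f32, !llvm.ptr            // right after the narrow store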
+static void +createWideBuffersAndStores(SmallVector &safePatterns, + OpBuilder &builder) { + DenseMap narrowToWideBuffer; + DenseSet processedStores; + + for (auto &pattern : safePatterns) { + if (pattern.matchingStores.empty()) continue; - Value loadResult = loadInfo->first; + Type wideElemType = getScalarType(pattern.fpextOp.getRes().getType()); - // Check if the load result is used directly by an arith.extf - for (Operation *loadUser : loadResult.getUsers()) { - auto extfOp = dyn_cast(loadUser); - if (!extfOp) + for (FPTruncStoreInfo *info : pattern.matchingStores) { + if (processedStores.contains(info)) continue; + processedStores.insert(info); - // Verify the extf output type matches the expected wide type - Type extfOutputType = getElementTypeOrSelf(extfOp.getOut().getType()); - if (extfOutputType != wideType) + if (info->hasParallelStore()) continue; - LoadExtfInfo info; - info.loadOp = user; - info.extfOp = extfOp; - info.loadIndices = std::move(loadInfo->second); - results.push_back(info); + // Check if we already created a wide buffer for this narrow buffer + auto it = narrowToWideBuffer.find(info->narrowBuffer); + if (it != narrowToWideBuffer.end()) { + createWideStore(info, it->second, wideElemType, builder); + continue; + } + + // Need to create new wide buffer + auto narrowAlloca = info->narrowBuffer.getDefiningOp(); + if (!narrowAlloca) { + LLVM_DEBUG( + llvm::dbgs() + << "Cannot create wide buffer: narrow buffer is not alloca\n"); + continue; + } + + builder.setInsertionPointAfter(narrowAlloca); + auto wideAlloca = + AllocaOp::create(builder, narrowAlloca.getLoc(), + LLVM::LLVMPointerType::get(builder.getContext()), + wideElemType, narrowAlloca.getArraySize()); + + LLVM_DEBUG(llvm::dbgs() << "Created wide alloca: " << wideAlloca << "\n"); + + narrowToWideBuffer[info->narrowBuffer] = wideAlloca.getResult(); + createWideStore(info, wideAlloca.getResult(), wideElemType, builder); } } - - return results; } -// Find all arith.truncf operations that are directly stored to a buffer. -// A "direct store" means the truncf result is used immediately by a store -// operation with no intermediate operations modifying the value. -SmallVector findTruncfWithDirectStores(func::FuncOp funcOp) { - SmallVector results; +// Apply the transformation: redirect loads from narrow buffer to wide buffer, +// eliminating the fpext operations. 
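+// For illustration (hypothetical names), this turns
+//   %r = llvm.load %ngep : !llvm.ptr -> f16
+//   %e = llvm.fpext %r : f16 to f32
+// into a single wide load whose result replaces every use of %e:
+//   %wide_reload = llvm.load %wgep : !llvm.ptr -> f32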
+static void applyTransformation(SmallVector &safePatterns, + OpBuilder &builder) { + for (auto &pattern : safePatterns) { + // Find the first matching store with a wide buffer + Value wideBuffer; + for (auto *store : pattern.matchingStores) { + if (store->wideBuffer) { + wideBuffer = store->wideBuffer; + break; + } + } + if (!wideBuffer) { + LLVM_DEBUG(llvm::dbgs() << "No wide buffer for pattern, skipping\n"); + continue; + } - funcOp.walk([&](arith::TruncFOp truncfOp) -> WalkResult { - Type inputType = getElementTypeOrSelf(truncfOp.getIn().getType()); - Type outputType = getElementTypeOrSelf(truncfOp.getOut().getType()); + LLVM_DEBUG(llvm::dbgs() << "Transforming pattern:\n"); + LLVM_DEBUG(llvm::dbgs() << " Load: " << pattern.loadOp << "\n"); + LLVM_DEBUG(llvm::dbgs() << " FPExt: " << pattern.fpextOp << "\n"); - // Check that this is a narrowing conversion (truncf) - if (outputType.getIntOrFloatBitWidth() >= inputType.getIntOrFloatBitWidth()) - return WalkResult::advance(); + Type wideType = pattern.fpextOp.getRes().getType(); + Type wideElemType = getScalarType(wideType); - // Check for direct stores of the truncf result - Value truncfResult = truncfOp.getOut(); - Value wideValue = truncfOp.getIn(); + Value newPtr; + if (pattern.gepOp) { + builder.setInsertionPoint(pattern.gepOp); - for (Operation *user : truncfResult.getUsers()) { - FailureOr>> storeInfo = - getStoreBufferAndIndices(user, truncfResult); - if (failed(storeInfo)) - continue; + SmallVector gepArgs; + for (auto idx : pattern.gepOp.getIndices()) { + if (auto constIdx = dyn_cast(idx)) { + gepArgs.push_back(static_cast(constIdx.getInt())); + } else { + gepArgs.push_back(cast(idx)); + } + } - TruncfStoreInfo info; - info.truncfOp = truncfOp; - info.wideValue = wideValue; - info.storeOp = user; - info.targetBuffer = storeInfo->first; - info.storeIndices = std::move(storeInfo->second); - results.push_back(info); + auto newGep = GEPOp::create(builder, pattern.gepOp.getLoc(), + pattern.gepOp.getType(), wideElemType, + wideBuffer, gepArgs); + newGep.setNoWrapFlags(pattern.gepOp.getNoWrapFlags()); + newPtr = newGep.getResult(); + } else { + newPtr = wideBuffer; } - return WalkResult::advance(); - }); + builder.setInsertionPoint(pattern.loadOp); - return results; + unsigned wideAlignment = 4; + if (auto vecType = dyn_cast(wideType)) { + unsigned elemBits = vecType.getElementType().getIntOrFloatBitWidth(); + wideAlignment = (elemBits / 8) * vecType.getNumElements(); + wideAlignment = std::min(wideAlignment, 16u); + } else { + wideAlignment = wideType.getIntOrFloatBitWidth() / 8; + } + + auto newLoad = LoadOp::create(builder, pattern.loadOp.getLoc(), wideType, + newPtr, wideAlignment); + + // Clean up the load -> fpext + pattern.fpextOp.getRes().replaceAllUsesWith(newLoad.getRes()); + pattern.fpextOp.erase(); + + if (pattern.loadOp.getRes().use_empty()) { + pattern.loadOp.erase(); + } + + if (pattern.gepOp && pattern.gepOp.getRes().use_empty()) { + pattern.gepOp.erase(); + } + + LLVM_DEBUG(llvm::dbgs() << " Transformation complete.\n"); + } +} + +// Clean up unused narrow buffer operations after transformation. +// If the narrow buffer has no remaining uses, we can remove the fptrunc stores, +// the fptrunc ops (if only used by the store), and the narrow alloca. 
+static void +cleanupUnusedNarrowBufferOps(SmallVector &safePatterns) { + // Collect all narrow buffers and their associated fptrunc stores + DenseMap> bufferToStores; + for (auto &pattern : safePatterns) { + for (auto *info : pattern.matchingStores) { + bufferToStores[info->narrowBuffer].push_back(info); + } + } + + // Track what we've already erased to avoid double-erase + DenseSet erased; + + for (auto &[narrowBuffer, stores] : bufferToStores) { + // Check if the narrow buffer still has uses + // (other than the stores we're about to erase) + bool hasOtherUses = false; + for (Operation *user : narrowBuffer.getUsers()) { + // Check if this user is one of the stores we might erase + bool isTrackedStore = false; + for (auto *info : stores) { + if (info->narrowStore.getOperation() == user) { + isTrackedStore = true; + break; + } + if (info->narrowGep && info->narrowGep.getOperation() == user) { + isTrackedStore = true; + break; + } + } + if (!isTrackedStore) { + hasOtherUses = true; + break; + } + } + + if (hasOtherUses) { + LLVM_DEBUG(llvm::dbgs() << "Narrow buffer still has other uses, " + << "keeping fptrunc stores\n"); + continue; + } + + // No other uses, so we can clean up the fptrunc stores and related ops + LLVM_DEBUG(llvm::dbgs() << "Cleaning up unused narrow buffer ops\n"); + + for (auto *info : stores) { + // Capture the fptrunc op before erasing the store + auto fptruncOp = info->narrowStore.getValue().getDefiningOp(); + + // Erase the narrow store + if (!erased.contains(info->narrowStore.getOperation())) { + erased.insert(info->narrowStore.getOperation()); + info->narrowStore.erase(); + LLVM_DEBUG(llvm::dbgs() << "\tErased narrow store\n"); + } + + // Erase the GEP if it has no uses + if (info->narrowGep && info->narrowGep.getRes().use_empty() && + !erased.contains(info->narrowGep.getOperation())) { + erased.insert(info->narrowGep.getOperation()); + info->narrowGep.erase(); + LLVM_DEBUG(llvm::dbgs() << "\tErased narrow GEP\n"); + } + + // Erase the fptrunc if it has no uses + if (fptruncOp && fptruncOp.getRes().use_empty() && + !erased.contains(fptruncOp.getOperation())) { + erased.insert(fptruncOp.getOperation()); + fptruncOp.erase(); + LLVM_DEBUG(llvm::dbgs() << "\tErased fptrunc\n"); + } + } + + // Erase the narrow alloca if it has no uses + if (auto narrowAlloca = narrowBuffer.getDefiningOp()) { + if (narrowAlloca.getResult().use_empty() && + !erased.contains(narrowAlloca.getOperation())) { + erased.insert(narrowAlloca.getOperation()); + narrowAlloca.erase(); + LLVM_DEBUG(llvm::dbgs() << "\tErased narrow alloca\n"); + } + } + } } struct RockRemoveRedundantCastsPass @@ -299,58 +691,62 @@ struct RockRemoveRedundantCastsPass } // end namespace void RockRemoveRedundantCastsPass::runOnOperation() { - func::FuncOp funcOp = getOperation(); + LLVMFuncOp funcOp = getOperation(); + OpBuilder builder(funcOp.getContext()); - SmallVector truncfStores = findTruncfWithDirectStores(funcOp); + LLVM_DEBUG(llvm::dbgs() << "Running RockRemoveRedundantCastsPass on " + << funcOp.getName() << "\n"); - if (truncfStores.empty()) { - LLVM_DEBUG(llvm::dbgs() - << "No truncf -> store patterns found, nothing to do.\n"); + // Step 1: Find all fptrunc -> store patterns + SmallVector storeInfo = findFPTruncStorePatterns(funcOp); + if (storeInfo.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No fptrunc -> store patterns found.\n"); return; } + LLVM_DEBUG(llvm::dbgs() << "Found " << storeInfo.size() + << " fptrunc -> store patterns.\n"); + + // Step 2: Find all load -> fpext patterns + SmallVector 
loadFPExtPatterns = + findLoadFPExtPatterns(funcOp); + if (loadFPExtPatterns.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No load+fpext patterns found.\n"); + return; + } + LLVM_DEBUG(llvm::dbgs() << "Found " << loadFPExtPatterns.size() + << " load+fpext patterns.\n"); - LLVM_DEBUG(llvm::dbgs() << "Found " << truncfStores.size() - << " truncf -> store patterns to analyze.\n"); - - // Collect verified optimization candidates - SmallVector candidates; - - // For each truncf -> store pair, find load -> extf readers and verify safety + // Step 3: Verify safety (applicability) for each pattern DominanceInfo domInfo(funcOp); - for (const TruncfStoreInfo &truncfStore : truncfStores) { - Type wideType = getElementTypeOrSelf(truncfStore.wideValue.getType()); - LLVM_DEBUG(llvm::dbgs() << "Analyzing buffer: " << truncfStore.targetBuffer - << "\n"); - SmallVector extfReaders = - findDirectExtfReaders(truncfStore.targetBuffer, wideType); - - if (extfReaders.empty()) { - LLVM_DEBUG(llvm::dbgs() << "\tNo load -> extf readers found.\n"); - continue; - } - - LLVM_DEBUG(llvm::dbgs() << "\tFound " << extfReaders.size() - << " load -> extf readers.\n"); - - // Verify safety of each load -> extf reader - for (const LoadExtfInfo &loadExtf : extfReaders) { - LLVM_DEBUG(llvm::dbgs() << "\tVerifying: load=" << *loadExtf.loadOp - << ", extf=" << loadExtf.extfOp << "\n"); - - if (verifySafety(truncfStore, loadExtf, domInfo)) { - LLVM_DEBUG(llvm::dbgs() << "\t\tSafety verified!\n"); - candidates.push_back({truncfStore, loadExtf}); - } + SmallVector safePatterns; + for (auto &pattern : loadFPExtPatterns) { + LLVM_DEBUG(llvm::dbgs() + << "Verifying pattern: load=" << pattern.loadOp << "\n"); + FailureOr> result = + verifySafety(pattern, storeInfo, domInfo); + if (succeeded(result)) { + pattern.matchingStores = *result; + safePatterns.push_back(pattern); } } - if (candidates.empty()) { - LLVM_DEBUG(llvm::dbgs() << "No safe optimization candidates found.\n"); + if (safePatterns.empty()) { + LLVM_DEBUG(llvm::dbgs() << "No safe patterns to optimize.\n"); return; } - LLVM_DEBUG(llvm::dbgs() << "Found " << candidates.size() - << " safe truncation/extension candidates.\n"); + LLVM_DEBUG(llvm::dbgs() << "Found " << safePatterns.size() + << " safe pattern combination(s) to optimize.\n"); + + // Step 4: Create wide buffers and stores for patterns that need them + createWideBuffersAndStores(safePatterns, builder); + + // Step 5: Apply transformation (redirect loads to wide buffer) + applyTransformation(safePatterns, builder); + + // Step 6: Clean up unused narrow buffer operations + cleanupUnusedNarrowBufferOps(safePatterns); - // TODO: Step 5: Apply the optimization + LLVM_DEBUG(llvm::dbgs() << "Optimized " << safePatterns.size() + << " patterns.\n"); } From 9a0cbb2757f5523135d57af96fffa838c33eb71b Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 14:35:45 +0000 Subject: [PATCH 07/12] Add E2E test --- .../pr-e2e/mixr-remove-redundant-casts.mlir | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir diff --git a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir new file mode 100644 index 000000000000..3e99eaf73994 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir @@ -0,0 +1,30 @@ +// RUN: rocmlir-gen -fut mlir_remove_casts --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | 
rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_remove_casts_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s + +// CHECK: [1 1 1] +module { +func.func @mlir_remove_casts(%arg0: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>, %arg1: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>, %arg2: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>) -> !migraphx.shaped<1x5x75352x128xf16, 48225280x9645056x128x1> attributes {kernel = "mixr"} { + %0 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 0> + %1 = migraphx.transpose %arg0 {permutation = [0, 2, 1, 3]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x75352x128xf16, 48225280x128x640x1> + %2 = migraphx.transpose %arg1 {permutation = [0, 2, 3, 1]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x128x75352xf16, 48225280x128x1x640> + %3 = migraphx.transpose %arg2 {permutation = [0, 2, 1, 3]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x75352x128xf16, 48225280x128x640x1> + %4 = migraphx.dot %1, %2 : <1x5x75352x128xf16, 48225280x128x640x1>, <1x5x128x75352xf16, 48225280x128x1x640> -> <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> + %5 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1xf16, 0> -> <1x5x75352x75352xf16, 0x0x0x0> + %6 = migraphx.convert %4 {target_type = 2 : i64} : <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> to <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %7 = migraphx.convert %5 {target_type = 2 : i64} : <1x5x75352x75352xf16, 0x0x0x0> to <1x5x75352x75352xf32, 0x0x0x0> + %8 = migraphx.mul %6, %7 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 0x0x0x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %9 = migraphx.reshape %8 {dims = [1, 5, 75352, 75352]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %10 = migraphx.reduce_max %9 {axes = [3]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x1xf32, 376760x75352x1x1> + %11 = migraphx.reshape %10 {dims = [1, 5, 75352, 1]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x1xf32, 376760x75352x1x1> + %12 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x75352xf32, 376760x75352x1x0> + %13 = migraphx.sub %8, %12 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 376760x75352x1x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %14 = migraphx.exp %13 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %15 = migraphx.reshape %14 {dims = [1, 5, 75352, 75352]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %16 = migraphx.reduce_sum %15 {axes = [3]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x1xf32, 376760x75352x1x1> + %17 = migraphx.reshape %16 {dims = [1, 5, 75352, 1]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x1xf32, 376760x75352x1x1> + %18 = 
migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x75352xf32, 376760x75352x1x0> + %19 = migraphx.div %14, %18 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 376760x75352x1x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> + %20 = migraphx.convert %19 {target_type = 1 : i64} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> to <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> + %21 = migraphx.dot %20, %3 : <1x5x75352x75352xf16, 28389619520x5677923904x75352x1>, <1x5x75352x128xf16, 48225280x128x640x1> -> <1x5x75352x128xf16, 48225280x9645056x128x1> + return %21 : !migraphx.shaped<1x5x75352x128xf16, 48225280x9645056x128x1> + } +} \ No newline at end of file From 1e9671ae19bb66f3f55e430420bcd2e6728f82f8 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 17:21:25 +0000 Subject: [PATCH 08/12] More LIT tests --- .../LLVMIR/remove-redundant-casts.mlir | 252 ++++++++++++++++++ .../pr-e2e/mixr-remove-redundant-casts.mlir | 52 ++-- 2 files changed, 279 insertions(+), 25 deletions(-) create mode 100644 mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir diff --git a/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir new file mode 100644 index 000000000000..aeb976b0e904 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir @@ -0,0 +1,252 @@ +// RUN: rocmlir-opt --rock-remove-redundant-casts %s | FileCheck %s + +// Parallel wide buffer already exists, so the pass should redirect the load to +// the wide buffer, eliminating fpext +// CHECK-LABEL: llvm.func @test_parallel_buffer_exists +llvm.func @test_parallel_buffer_exists() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(4 : i64) : i64 + %2 = llvm.mlir.constant(0 : i32) : i32 + %3 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + %4 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %5 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %6 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16> + llvm.store %6, %4 : vector<4xf16>, !llvm.ptr<5> + llvm.store %3, %5 : vector<4xf32>, !llvm.ptr<5> + %7 = llvm.getelementptr %4[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %8 = llvm.getelementptr %5[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %9 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16> + llvm.store %9, %7 : vector<4xf16>, !llvm.ptr<5> + llvm.store %3, %8 : vector<4xf32>, !llvm.ptr<5> + %10 = llvm.getelementptr %4[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %11 = llvm.getelementptr %5[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %12 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16> + llvm.store %12, %10 : vector<4xf16>, !llvm.ptr<5> + llvm.store %3, %11 : vector<4xf32>, !llvm.ptr<5> + %13 = llvm.getelementptr %4[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %14 = llvm.getelementptr %5[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %15 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16> + llvm.store %15, %13 : vector<4xf16>, !llvm.ptr<5> + llvm.store %3, %14 : vector<4xf32>, !llvm.ptr<5> + %16 = llvm.load %4 : !llvm.ptr<5> -> vector<4xf16> + %17 = llvm.fpext %16 : vector<4xf16> to vector<4xf32> + // CHECK-NOT: llvm.fpext + // CHECK: llvm.load {{.*}} -> vector<4xf32> + %18 = llvm.fadd %17, %3 : vector<4xf32> + + llvm.return +} + +// No parallel buffer so the pass should create one +// CHECK-LABEL: llvm.func @test_create_wide_buffer +llvm.func @test_create_wide_buffer() { + %0 = llvm.mlir.constant(16 : i64) 
: i64 + %1 = llvm.mlir.constant(dense<2.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + // CHECK-NOT: llvm.alloca {{.*}} x f16 + // CHECK: llvm.alloca {{.*}} x f32 + %3 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %3, %2 : vector<4xf16>, !llvm.ptr<5> + %4 = llvm.getelementptr %2[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %4 : vector<4xf16>, !llvm.ptr<5> + %6 = llvm.getelementptr %2[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5> + %8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %9, %8 : vector<4xf16>, !llvm.ptr<5> + %10 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + // CHECK-NOT: llvm.fpext + %11 = llvm.fpext %10 : vector<4xf16> to vector<4xf32> + %12 = llvm.fadd %11, %1 : vector<4xf32> + llvm.return +} + +// The load reads subset of what was stored +// CHECK-LABEL: llvm.func @test_subset_load +llvm.func @test_subset_load() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<3.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %5 = llvm.getelementptr %2[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %6 = llvm.getelementptr %3[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %7, %5 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %6 : vector<4xf32>, !llvm.ptr<5> + %8 = llvm.getelementptr %2[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.getelementptr %3[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %10 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %10, %8 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %9 : vector<4xf32>, !llvm.ptr<5> + %11 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %12 = llvm.getelementptr %3[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %13 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %13, %11 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %12 : vector<4xf32>, !llvm.ptr<5> + %14 = llvm.load %2 : !llvm.ptr<5> -> vector<2xf16> + // CHECK: llvm.load {{.*}} -> vector<2xf32> + %15 = llvm.fpext %14 : vector<2xf16> to vector<2xf32> + // CHECK-NOT: llvm.fpext + %16 = llvm.mlir.constant(dense<1.000000e+00> : vector<2xf32>) : vector<2xf32> + %17 = llvm.fadd %15, %16 : vector<2xf32> + llvm.return +} + +// Unsafe case: intervening non-fptrunc store +// CHECK-LABEL: llvm.func @test_unsafe_intervening_store +llvm.func @test_unsafe_intervening_store() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<4.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.mlir.constant(dense<9.000000e+00> : vector<4xf16>) : vector<4xf16> + %3 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %4 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %3 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %4 : vector<4xf32>, !llvm.ptr<5> + llvm.store %2, %3 : vector<4xf16>, !llvm.ptr<5> + %6 = llvm.load %3 : !llvm.ptr<5> -> vector<4xf16> + %7 = llvm.fpext %6 : vector<4xf16> to vector<4xf32> + // CHECK: 
llvm.fpext + %8 = llvm.fadd %7, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: fpext to different type than original +// CHECK-LABEL: llvm.func @test_type_mismatch +llvm.func @test_type_mismatch() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<5.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f64 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + %5 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %6 = llvm.fpext %5 : vector<4xf16> to vector<4xf64> + // CHECK: llvm.fpext + %7 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf64>) : vector<4xf64> + %8 = llvm.fadd %6, %7 : vector<4xf64> + llvm.return +} + +// Unsafe case: narrow buffer is not an alloca (function argument) +// CHECK-LABEL: llvm.func @test_not_alloca +llvm.func @test_not_alloca(%narrow_buf: !llvm.ptr<5>) { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<6.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %3 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %3, %narrow_buf : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %2 : vector<4xf32>, !llvm.ptr<5> + %4 = llvm.load %narrow_buf : !llvm.ptr<5> -> vector<4xf16> + %5 = llvm.fpext %4 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %6 = llvm.fadd %5, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: buffer not fully covered by fptrunc stores (partial coverage) +// CHECK-LABEL: llvm.func @test_partial_coverage +llvm.func @test_partial_coverage() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<7.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %5 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %6 = llvm.fpext %5 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %7 = llvm.fadd %6, %1 : vector<4xf32> + llvm.return +} + +// Safe case: non-overlapping intervening store (store to different indices) +// CHECK-LABEL: llvm.func @test_non_overlapping_store +llvm.func @test_non_overlapping_store() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<8.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.mlir.constant(dense<9.000000e+00> : vector<4xf16>) : vector<4xf16> + %3 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %4 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %3 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %4 : vector<4xf32>, !llvm.ptr<5> + %6 = llvm.getelementptr %3[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %7 = llvm.getelementptr %4[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %8 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %8, %6 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %7 : vector<4xf32>, !llvm.ptr<5> + %9 = llvm.getelementptr %3[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %10 = llvm.getelementptr %4[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %11 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %11, %9 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %10 : vector<4xf32>, !llvm.ptr<5> + %12 = llvm.getelementptr %3[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %13 
= llvm.getelementptr %4[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %14 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %14, %12 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %13 : vector<4xf32>, !llvm.ptr<5> + llvm.store %2, %9 : vector<4xf16>, !llvm.ptr<5> + %15 = llvm.load %3 : !llvm.ptr<5> -> vector<4xf16> + %16 = llvm.fpext %15 : vector<4xf16> to vector<4xf32> + // CHECK-NOT: llvm.fpext + // CHECK: llvm.load {{.*}} -> vector<4xf32> + %17 = llvm.fadd %16, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: fptrunc store does NOT dominate load (store after load) +// CHECK-LABEL: llvm.func @test_store_after_load +llvm.func @test_store_after_load() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<10.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %5 = llvm.fpext %4 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %6 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %6, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %7 = llvm.fadd %5, %1 : vector<4xf32> + llvm.return +} + +// Safe case: scalar types (not vectors) +// CHECK-LABEL: llvm.func @test_scalar_types +llvm.func @test_scalar_types() { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(11.000000e+00 : f32) : f32 + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : f32 to f16 + llvm.store %4, %2 : f16, !llvm.ptr<5> + llvm.store %1, %3 : f32, !llvm.ptr<5> + %5 = llvm.getelementptr %2[1] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %6 = llvm.getelementptr %3[1] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %7 = llvm.fptrunc %1 : f32 to f16 + llvm.store %7, %5 : f16, !llvm.ptr<5> + llvm.store %1, %6 : f32, !llvm.ptr<5> + %8 = llvm.getelementptr %2[2] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.getelementptr %3[2] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %10 = llvm.fptrunc %1 : f32 to f16 + llvm.store %10, %8 : f16, !llvm.ptr<5> + llvm.store %1, %9 : f32, !llvm.ptr<5> + %11 = llvm.getelementptr %2[3] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %12 = llvm.getelementptr %3[3] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %13 = llvm.fptrunc %1 : f32 to f16 + llvm.store %13, %11 : f16, !llvm.ptr<5> + llvm.store %1, %12 : f32, !llvm.ptr<5> + %14 = llvm.load %2 : !llvm.ptr<5> -> f16 + %15 = llvm.fpext %14 : f16 to f32 + // CHECK-NOT: llvm.fpext + // CHECK: llvm.load {{.*}} -> f32 + %16 = llvm.fadd %15, %1 : f32 + llvm.return +} diff --git a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir index 3e99eaf73994..40bee18d2a5b 100644 --- a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir @@ -1,30 +1,32 @@ -// RUN: rocmlir-gen -fut mlir_remove_casts --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_remove_casts_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-gen -fut mlir_remove_casts --arch %arch --clone-harness - | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_remove_casts_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-driver --arch %arch --kernel-pipeline=migraphx,highlevel - | rocmlir-driver -arch %arch -c -mlir-print-ir-after=rock-remove-redundant-casts 2>&1 | FileCheck %s --check-prefix=NO-FPEXT // CHECK: [1 1 1] +// NO-FPEXT-NOT: llvm.fpext module { -func.func @mlir_remove_casts(%arg0: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>, %arg1: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>, %arg2: !migraphx.shaped<1x75352x5x128xf16, 48225280x640x128x1>) -> !migraphx.shaped<1x5x75352x128xf16, 48225280x9645056x128x1> attributes {kernel = "mixr"} { +func.func @mlir_remove_casts(%arg0: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>, %arg1: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>, %arg2: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>) -> !migraphx.shaped<1x5x64x128xf16, 40960x8192x128x1> attributes {arch="gfx950", kernel = "mixr"} { %0 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 0> - %1 = migraphx.transpose %arg0 {permutation = [0, 2, 1, 3]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x75352x128xf16, 48225280x128x640x1> - %2 = migraphx.transpose %arg1 {permutation = [0, 2, 3, 1]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x128x75352xf16, 48225280x128x1x640> - %3 = migraphx.transpose %arg2 {permutation = [0, 2, 1, 3]} : <1x75352x5x128xf16, 48225280x640x128x1> -> <1x5x75352x128xf16, 48225280x128x640x1> - %4 = migraphx.dot %1, %2 : <1x5x75352x128xf16, 48225280x128x640x1>, <1x5x128x75352xf16, 48225280x128x1x640> -> <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> - %5 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1xf16, 0> -> <1x5x75352x75352xf16, 0x0x0x0> - %6 = migraphx.convert %4 {target_type = 2 : i64} : <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> to <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %7 = migraphx.convert %5 {target_type = 2 : i64} : <1x5x75352x75352xf16, 0x0x0x0> to <1x5x75352x75352xf32, 0x0x0x0> - %8 = migraphx.mul %6, %7 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 0x0x0x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %9 = migraphx.reshape %8 {dims = [1, 5, 75352, 75352]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %10 = migraphx.reduce_max %9 {axes = [3]} : <1x5x75352x75352xf32, 
28389619520x5677923904x75352x1> -> <1x5x75352x1xf32, 376760x75352x1x1> - %11 = migraphx.reshape %10 {dims = [1, 5, 75352, 1]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x1xf32, 376760x75352x1x1> - %12 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x75352xf32, 376760x75352x1x0> - %13 = migraphx.sub %8, %12 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 376760x75352x1x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %14 = migraphx.exp %13 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %15 = migraphx.reshape %14 {dims = [1, 5, 75352, 75352]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %16 = migraphx.reduce_sum %15 {axes = [3]} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> -> <1x5x75352x1xf32, 376760x75352x1x1> - %17 = migraphx.reshape %16 {dims = [1, 5, 75352, 1]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x1xf32, 376760x75352x1x1> - %18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 5, 75352, 75352]} : <1x5x75352x1xf32, 376760x75352x1x1> -> <1x5x75352x75352xf32, 376760x75352x1x0> - %19 = migraphx.div %14, %18 : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1>, <1x5x75352x75352xf32, 376760x75352x1x0> -> <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> - %20 = migraphx.convert %19 {target_type = 1 : i64} : <1x5x75352x75352xf32, 28389619520x5677923904x75352x1> to <1x5x75352x75352xf16, 28389619520x5677923904x75352x1> - %21 = migraphx.dot %20, %3 : <1x5x75352x75352xf16, 28389619520x5677923904x75352x1>, <1x5x75352x128xf16, 48225280x128x640x1> -> <1x5x75352x128xf16, 48225280x9645056x128x1> - return %21 : !migraphx.shaped<1x5x75352x128xf16, 48225280x9645056x128x1> + %1 = migraphx.transpose %arg0 {permutation = [0, 2, 1, 3]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x64x128xf16, 40960x128x640x1> + %2 = migraphx.transpose %arg1 {permutation = [0, 2, 3, 1]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x128x64xf16, 40960x128x1x640> + %3 = migraphx.transpose %arg2 {permutation = [0, 2, 1, 3]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x64x128xf16, 40960x128x640x1> + %4 = migraphx.dot %1, %2 : <1x5x64x128xf16, 40960x128x640x1>, <1x5x128x64xf16, 40960x128x1x640> -> <1x5x64x64xf16, 20480x4096x64x1> + %5 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1xf16, 0> -> <1x5x64x64xf16, 0x0x0x0> + %6 = migraphx.convert %4 {target_type = 2 : i64} : <1x5x64x64xf16, 20480x4096x64x1> to <1x5x64x64xf32, 20480x4096x64x1> + %7 = migraphx.convert %5 {target_type = 2 : i64} : <1x5x64x64xf16, 0x0x0x0> to <1x5x64x64xf32, 0x0x0x0> + %8 = migraphx.mul %6, %7 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 0x0x0x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %9 = migraphx.reshape %8 {dims = [1, 5, 64, 64]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %10 = migraphx.reduce_max %9 {axes = [3]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x1xf32, 320x64x1x1> + %11 = migraphx.reshape %10 {dims = [1, 5, 64, 1]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x1xf32, 320x64x1x1> + %12 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x64xf32, 320x64x1x0> + %13 = migraphx.sub %8, %12 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 320x64x1x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %14 = migraphx.exp %13 
: <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %15 = migraphx.reshape %14 {dims = [1, 5, 64, 64]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %16 = migraphx.reduce_sum %15 {axes = [3]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x1xf32, 320x64x1x1> + %17 = migraphx.reshape %16 {dims = [1, 5, 64, 1]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x1xf32, 320x64x1x1> + %18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x64xf32, 320x64x1x0> + %19 = migraphx.div %14, %18 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 320x64x1x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %20 = migraphx.convert %19 {target_type = 1 : i64} : <1x5x64x64xf32, 20480x4096x64x1> to <1x5x64x64xf16, 20480x4096x64x1> + %21 = migraphx.dot %20, %3 : <1x5x64x64xf16, 20480x4096x64x1>, <1x5x64x128xf16, 40960x128x640x1> -> <1x5x64x128xf16, 40960x8192x128x1> + return %21 : !migraphx.shaped<1x5x64x128xf16, 40960x8192x128x1> } -} \ No newline at end of file +} From 800c5dcb1332978f20476575d2d3121f0ce18ea1 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 17:22:50 +0000 Subject: [PATCH 09/12] Add newline --- mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir | 1 + mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir | 1 + 2 files changed, 2 insertions(+) diff --git a/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir index aeb976b0e904..0a6835ccc681 100644 --- a/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir +++ b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir @@ -250,3 +250,4 @@ llvm.func @test_scalar_types() { %16 = llvm.fadd %15, %1 : f32 llvm.return } + diff --git a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir index 40bee18d2a5b..7e28183b0fe8 100644 --- a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir +++ b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir @@ -30,3 +30,4 @@ func.func @mlir_remove_casts(%arg0: !migraphx.shaped<1x64x5x128xf16, 40960x640x1 return %21 : !migraphx.shaped<1x5x64x128xf16, 40960x8192x128x1> } } + From 3ee0bf8edae352eb89595883cdd239f628d1f21b Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 17:24:18 +0000 Subject: [PATCH 10/12] Clang-format --- mlir/include/mlir/Dialect/Rock/Passes.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 3b911d11e28b..1014ea9ea99d 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -174,7 +174,8 @@ def RockLowerReducePass : Pass<"rock-lower-reduce", "::mlir::func::FuncOp"> { def RockRemoveRedundantCastsPass : Pass<"rock-remove-redundant-casts", "::mlir::LLVM::LLVMFuncOp"> { - let summary = "Remove redundant fptrunc/fpext pairs through buffers at LLVM dialect level"; + let summary = "Remove redundant fptrunc/fpext pairs through buffers at LLVM " + "dialect level"; let description = [{ Detects patterns at the LLVM dialect level where wider float values are truncated (llvm.fptrunc) to a narrower type, stored to a buffer, then loaded From 713d5776f6f63f5df08037cc90e0496ebe9e4d87 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 17:29:20 +0000 Subject: [PATCH 11/12] Remove some extra lines --- .../Rock/Transforms/RemoveRedundantCasts.cpp | 20 ------------------- 1 file changed, 20 deletions(-) 
diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index 9e9f481e18cb..ec2387284a4f 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -134,7 +134,6 @@ findFPTruncStorePatterns(LLVMFuncOp funcOp) { Value narrowPtr = narrowStore.getAddr(); Value narrowBuffer = getBasePointer(narrowPtr); GEPOp narrowGep = narrowPtr.getDefiningOp(); - LLVM_DEBUG(llvm::dbgs() << "\tFound narrow store: " << narrowStore << "\n"); @@ -154,13 +153,11 @@ findFPTruncStorePatterns(LLVMFuncOp funcOp) { Value widePtr = wideStore.getAddr(); Value wideBuffer = getBasePointer(widePtr); - if (wideBuffer == narrowBuffer) continue; LLVM_DEBUG(llvm::dbgs() << "\tFound existing parallel wide store: " << wideStore << "\n"); - info.wideBuffer = wideBuffer; info.wideStore = wideStore; break; @@ -188,7 +185,6 @@ static SmallVector findLoadFPExtPatterns(LLVMFuncOp funcOp) { Value loadPtr = loadOp.getAddr(); Value narrowBuffer = getBasePointer(loadPtr); GEPOp gepOp = loadPtr.getDefiningOp(); - LLVM_DEBUG(llvm::dbgs() << "Found load+fpext pattern:\n"); LLVM_DEBUG(llvm::dbgs() << "\tLoad: " << loadOp << "\n"); LLVM_DEBUG(llvm::dbgs() << "\tFPExt: " << fpextOp << "\n"); @@ -242,14 +238,11 @@ isStoreFromFPTruncPattern(StoreOp store, struct IndexRange { int64_t start; int64_t count; - bool isValid() const { return count > 0; } - bool isSubsetOf(const IndexRange &other) const { return start >= other.start && (start + count) <= (other.start + other.count); } - bool overlaps(const IndexRange &other) const { return start < (other.start + other.count) && other.start < (start + count); } @@ -321,7 +314,6 @@ findMatchingFPTruncStores(LoadFPExtPattern &pattern, continue; dominatingStores.push_back(&info); - IndexRange storeRange = getAccessRange(info.narrowGep, info.narrowStore.getValue().getType()); if (!storeRange.isValid()) @@ -351,9 +343,7 @@ hasNoInterveningStores(LoadFPExtPattern &pattern, DominanceInfo &domInfo) { IndexRange loadRange = getAccessRange(pattern.gepOp, pattern.loadOp.getRes().getType()); - SmallVector allStores = collectStoresToBuffer(pattern.narrowBuffer); - for (auto store : allStores) { // Skip fptrunc stores if (isStoreFromFPTruncPattern(store, storeInfos)) @@ -437,7 +427,6 @@ verifySafety(LoadFPExtPattern &pattern, static void createWideStore(FPTruncStoreInfo *info, Value wideBuffer, Type wideElemType, OpBuilder &builder) { builder.setInsertionPointAfter(info->narrowStore); - Value widePtr; if (info->narrowGep) { SmallVector gepArgs; @@ -472,13 +461,11 @@ createWideBuffersAndStores(SmallVector &safePatterns, OpBuilder &builder) { DenseMap narrowToWideBuffer; DenseSet processedStores; - for (auto &pattern : safePatterns) { if (pattern.matchingStores.empty()) continue; Type wideElemType = getScalarType(pattern.fpextOp.getRes().getType()); - for (FPTruncStoreInfo *info : pattern.matchingStores) { if (processedStores.contains(info)) continue; @@ -510,7 +497,6 @@ createWideBuffersAndStores(SmallVector &safePatterns, wideElemType, narrowAlloca.getArraySize()); LLVM_DEBUG(llvm::dbgs() << "Created wide alloca: " << wideAlloca << "\n"); - narrowToWideBuffer[info->narrowBuffer] = wideAlloca.getResult(); createWideStore(info, wideAlloca.getResult(), wideElemType, builder); } @@ -541,11 +527,9 @@ static void applyTransformation(SmallVector &safePatterns, Type wideType = pattern.fpextOp.getRes().getType(); Type wideElemType = getScalarType(wideType); - Value newPtr; if 
(pattern.gepOp) { builder.setInsertionPoint(pattern.gepOp); - SmallVector gepArgs; for (auto idx : pattern.gepOp.getIndices()) { if (auto constIdx = dyn_cast(idx)) { @@ -565,7 +549,6 @@ static void applyTransformation(SmallVector &safePatterns, } builder.setInsertionPoint(pattern.loadOp); - unsigned wideAlignment = 4; if (auto vecType = dyn_cast(wideType)) { unsigned elemBits = vecType.getElementType().getIntOrFloatBitWidth(); @@ -581,7 +564,6 @@ static void applyTransformation(SmallVector &safePatterns, // Clean up the load -> fpext pattern.fpextOp.getRes().replaceAllUsesWith(newLoad.getRes()); pattern.fpextOp.erase(); - if (pattern.loadOp.getRes().use_empty()) { pattern.loadOp.erase(); } @@ -609,7 +591,6 @@ cleanupUnusedNarrowBufferOps(SmallVector &safePatterns) { // Track what we've already erased to avoid double-erase DenseSet erased; - for (auto &[narrowBuffer, stores] : bufferToStores) { // Check if the narrow buffer still has uses // (other than the stores we're about to erase) @@ -641,7 +622,6 @@ cleanupUnusedNarrowBufferOps(SmallVector &safePatterns) { // No other uses, so we can clean up the fptrunc stores and related ops LLVM_DEBUG(llvm::dbgs() << "Cleaning up unused narrow buffer ops\n"); - for (auto *info : stores) { // Capture the fptrunc op before erasing the store auto fptruncOp = info->narrowStore.getValue().getDefiningOp(); From 107b461ce5ca2e7fed4ca5bb30cab8bba48480e9 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Thu, 8 Jan 2026 22:07:59 +0000 Subject: [PATCH 12/12] Conservative checks --- .../Rock/Transforms/RemoveRedundantCasts.cpp | 82 +++++++++++++++++-- mlir/test/rocmlir-driver/pipelines.mlir | 3 + 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp index ec2387284a4f..fa6fecdade42 100644 --- a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -226,6 +226,40 @@ static SmallVector collectStoresToBuffer(Value buffer) { return stores; } +// Collect all loads that read from a buffer, tracing through GEPs. +static SmallVector collectLoadsFromBuffer(Value buffer) { + SmallVector loads; + SmallVector worklist; + worklist.push_back(buffer); + + while (!worklist.empty()) { + Value ptr = worklist.pop_back_val(); + for (Operation *user : ptr.getUsers()) { + if (auto load = dyn_cast(user)) { + if (load.getAddr() == ptr) { + loads.push_back(load); + } + } else if (auto gep = dyn_cast(user)) { + // Trace through GEPs that use this as base + if (gep.getBase() == ptr) { + worklist.push_back(gep.getResult()); + } + } + } + } + return loads; +} + +// Check if a load is only used by an fpext operation +static bool isLoadOnlyUsedByFPExt(LoadOp load) { + Value loadResult = load.getRes(); + // The load result should have exactly one use, and it should be fpext + if (!loadResult.hasOneUse()) + return false; + Operation *user = *loadResult.getUsers().begin(); + return isa(user); +} + // Check if a store is from one of our tracked fptrunc patterns. static bool isStoreFromFPTruncPattern(StoreOp store, @@ -383,6 +417,18 @@ verifySafety(LoadFPExtPattern &pattern, return failure(); } + // Check that the load uses static indices. Dynamic indices suggest the buffer + // is accessed in a pattern-dependent way (e.g., in a loop with runtime + // indexing), where the precision conversion may be algorithmically intentional. 
+ IndexRange loadRange = + getAccessRange(pattern.gepOp, pattern.loadOp.getRes().getType()); + if (!loadRange.isValid()) { + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: Load has dynamic index - cannot verify access " + "pattern matches stores\n"); + return failure(); + } + // Find all fptrunc stores that cover this load SmallVector matchingStores = findMatchingFPTruncStores(pattern, storeInfos, domInfo); @@ -392,7 +438,7 @@ verifySafety(LoadFPExtPattern &pattern, return failure(); } - // Check that all matching stores have compatible element types + // Check that all matching stores have compatible element types. Type fpextElemType = getScalarType(pattern.fpextOp.getRes().getType()); for (auto *store : matchingStores) { Type originalWideElemType = getScalarType(store->wideValue.getType()); @@ -413,6 +459,30 @@ verifySafety(LoadFPExtPattern &pattern, << "\tUNSAFE: Wide store does not dominate load\n"); return failure(); } + + // Check that the fptrunc result is only used by the narrow store. + // If the f16 value has other uses, we can't eliminate the truncation. + auto fptruncOp = store->narrowStore.getValue().getDefiningOp(); + if (fptruncOp && !fptruncOp.getRes().hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: fptrunc result has multiple uses - the f16 " + "value is used elsewhere\n"); + return failure(); + } + } + + // Check that all loads from the narrow buffer are used only by fpext. + // If there are loads that use the f16 values directly, + // we can't eliminate the narrow buffer. + SmallVector allLoads = collectLoadsFromBuffer(pattern.narrowBuffer); + for (LoadOp load : allLoads) { + if (!isLoadOnlyUsedByFPExt(load)) { + LLVM_DEBUG(llvm::dbgs() + << "\tUNSAFE: Buffer has loads that don't go through fpext - " + "f16 values are used directly: " + << load << "\n"); + return failure(); + } } // Check that no non-fptrunc stores could intervene @@ -437,8 +507,8 @@ static void createWideStore(FPTruncStoreInfo *info, Value wideBuffer, gepArgs.push_back(cast(idx)); } auto wideGep = GEPOp::create(builder, info->narrowGep.getLoc(), - info->narrowGep.getType(), wideElemType, - wideBuffer, gepArgs); + wideBuffer.getType(), wideElemType, wideBuffer, + gepArgs); wideGep.setNoWrapFlags(info->narrowGep.getNoWrapFlags()); widePtr = wideGep.getResult(); } else { @@ -493,8 +563,8 @@ createWideBuffersAndStores(SmallVector &safePatterns, builder.setInsertionPointAfter(narrowAlloca); auto wideAlloca = AllocaOp::create(builder, narrowAlloca.getLoc(), - LLVM::LLVMPointerType::get(builder.getContext()), - wideElemType, narrowAlloca.getArraySize()); + narrowAlloca.getResult().getType(), wideElemType, + narrowAlloca.getArraySize()); LLVM_DEBUG(llvm::dbgs() << "Created wide alloca: " << wideAlloca << "\n"); narrowToWideBuffer[info->narrowBuffer] = wideAlloca.getResult(); @@ -540,7 +610,7 @@ static void applyTransformation(SmallVector &safePatterns, } auto newGep = GEPOp::create(builder, pattern.gepOp.getLoc(), - pattern.gepOp.getType(), wideElemType, + wideBuffer.getType(), wideElemType, wideBuffer, gepArgs); newGep.setNoWrapFlags(pattern.gepOp.getNoWrapFlags()); newPtr = newGep.getResult(); diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir index 3554d4752b1c..a8aa2042985a 100644 --- a/mlir/test/rocmlir-driver/pipelines.mlir +++ b/mlir/test/rocmlir-driver/pipelines.mlir @@ -74,6 +74,7 @@ // BINARY-NEXT:llvm.func(rock-to-rocdl{chipset=gfx90a}), // BINARY-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal 
test-convergence=false top-down=true}, // BINARY-NEXT:cse, +// BINARY-NEXT:rock-remove-redundant-casts, // BINARY-NEXT:rock-prepare-llvm)), // BINARY-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx90a correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=}, @@ -105,6 +106,7 @@ // BINARY_MI300-NEXT:llvm.func(rock-to-rocdl{chipset=gfx942}), // BINARY_MI300-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // BINARY_MI300-NEXT:cse, +// BINARY_MI300-NEXT:rock-remove-redundant-casts, // BINARY_MI300-NEXT:rock-prepare-llvm)), // BINARY_MI300-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx942 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY_MI300-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=}, @@ -136,6 +138,7 @@ // BINARY_MI350-NEXT:llvm.func(rock-to-rocdl{chipset=gfx950}), // BINARY_MI350-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // BINARY_MI350-NEXT:cse, +// BINARY_MI350-NEXT:rock-remove-redundant-casts, // BINARY_MI350-NEXT:rock-prepare-llvm)), // BINARY_MI350-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx950 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY_MI350-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=},
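The pipelines.mlir checks above pin down where the pass runs in the binary pipeline: inside the llvm.func pass list, after canonicalize/cse and before rock-prepare-llvm. In isolation it can be exercised with rocmlir-opt --rock-remove-redundant-casts, as in the LIT test RUN line from PATCH 08. To illustrate the conservative checks added in PATCH 12, here is a hedged sketch (hypothetical function, not part of the patch's test suite) of a case the new isLoadOnlyUsedByFPExt check is meant to reject, because the reloaded f16 value is consumed directly rather than through llvm.fpext:

llvm.func @narrow_value_used_directly() {
  %size = llvm.mlir.constant(4 : i64) : i64
  %val = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
  %buf = llvm.alloca %size x f16 : (i64) -> !llvm.ptr<5>
  %narrow = llvm.fptrunc %val : vector<4xf32> to vector<4xf16>
  llvm.store %narrow, %buf : vector<4xf16>, !llvm.ptr<5>
  %reload = llvm.load %buf : !llvm.ptr<5> -> vector<4xf16>
  // No llvm.fpext here: the f16 values themselves are used, so the pass must
  // keep the narrow buffer and the fptrunc store untouched.
  %use = llvm.fadd %reload, %reload : vector<4xf16>
  llvm.return
}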