diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index c0771efaae26..305669ac548a 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -41,6 +41,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKVIEWTOTRANSFORMPASS #define GEN_PASS_DECL_ROCKDETECTFLASHDECODINGPASS #define GEN_PASS_DECL_ROCKLOWERREDUCEPASS +#define GEN_PASS_DECL_ROCKREMOVEREDUNDANTCASTSPASS #define GEN_PASS_DECL_ROCKPREPARELLVMPASS #define GEN_PASS_DECL_ROCKCHECKRESIDENCYPASS #define GEN_PASS_DECL_ROCKVECTORIZEFUSIONSPASS diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 566d2de04b73..d1d4c8c4098e 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -176,6 +176,20 @@ def RockLowerReducePass : Pass<"rock-lower-reduce", "::mlir::func::FuncOp"> { let dependentDialects = ["rock::RockDialect", "func::FuncDialect", "gpu::GPUDialect"]; } +def RockRemoveRedundantCastsPass + : Pass<"rock-remove-redundant-casts", "::mlir::LLVM::LLVMFuncOp"> { + let summary = "Remove redundant fptrunc/fpext pairs through buffers at LLVM " + "dialect level"; + let description = [{ + Detects patterns at the LLVM dialect level where wider float values are + truncated (llvm.fptrunc) to a narrower type, stored to a buffer, then loaded + and extended (llvm.fpext) back to the original wider type. This pass + redirects the loads to read the wide values directly, eliminating the + fpext and preserving precision. + }]; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + def RockPrepareLLVMPass : Pass<"rock-prepare-llvm", "::mlir::LLVM::LLVMFuncOp"> { let summary = "prepare the generated code for llvm"; let dependentDialects = ["ROCDL::ROCDLDialect"]; diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index a945ed55ca8e..5984fcb1a6c3 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -305,6 +305,7 @@ void rock::buildBackendPipeline(OpPassManager &pm, // descriptors. (Mainly we want the `extractvalue` fold). llvmFuncPm.addPass(createCanonicalizerPass()); llvmFuncPm.addPass(createCSEPass()); + llvmFuncPm.addPass(rock::createRockRemoveRedundantCastsPass()); llvmFuncPm.addPass(rock::createRockPrepareLLVMPass()); if (options.compile) { GpuROCDLAttachTargetOptions opts; diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index 9bdec3f2318d..d164f5a4b19e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms BlockwiseLoadTileToThreadwise.cpp AnnotateLiveness.cpp AddAsyncWait.cpp + RemoveRedundantCasts.cpp LowerRockOpsToROCDLOps.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp new file mode 100644 index 000000000000..fa6fecdade42 --- /dev/null +++ b/mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp @@ -0,0 +1,802 @@ +//===-------------------- RemoveRedundantCasts.cpp ------------------------===// +// +// Copyright 2026 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// +// +// This pass detects patterns at the LLVM dialect level where wider float values +// are truncated (llvm.fptrunc) to a narrower type, stored to a buffer, then +// loaded and extended (llvm.fpext) back to the original wider type. This pass +// creates a parallel wide buffer (if one doesn't exist) and redirects the loads +// to read the wide values directly, eliminating the fpext and preserving +// precision. +// +// Algorithm: +// 1. Find all fptrunc -> store patterns in the function. For each pattern, +// record whether there's already a parallel store of the wide value to +// a separate buffer. +// 2. Find all load -> fpext patterns where the load is from a buffer that +// has fptrunc stores. +// 3. Verify safety for each load+fpext pattern: +// - All stores to the narrow buffer must be from tracked fptrunc patterns +// (i.e., no untracked stores that could write different values) +// - All tracked stores must dominate the load +// - The narrow buffer must be an alloca +// 4. For safe patterns, create a wide buffer and the corresponding stores if +// they don't exist. If a parallel store already exists, reuse it: +// - Create a wide alloca right after the narrow alloca +// - For each fptrunc store, insert a store of the wide value to the +// wide buffer (right after the narrow store, using the same indices) +// 5. Apply the transformation: +// - Redirect the load to read from the wide buffer instead +// - Replace uses of the fpext result with the wide load result +// - Delete the fpext (and the old load/GEP if unused) +// 6. 
Clean up unused narrow buffer operations:
+//      - If the narrow buffer has no remaining uses, erase the fptrunc stores
+//      - These can only be erased if they are not used by any other
+//        operations
+//      - Erase the narrow alloca if it has no remaining uses
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Rock/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace rock {
+#define GEN_PASS_DEF_ROCKREMOVEREDUNDANTCASTSPASS
+#include "mlir/Dialect/Rock/Passes.h.inc"
+} // namespace rock
+} // namespace mlir
+
+#define DEBUG_TYPE "rock-remove-redundant-casts"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+
+// Information about a fptrunc -> store pattern
+struct FPTruncStoreInfo {
+  Value wideValue;     // The input to fptrunc (original wide value)
+  StoreOp narrowStore; // The store of the narrow value
+  Value narrowBuffer;  // Base alloca of narrow store
+  GEPOp narrowGep;     // GEP for narrow store (null if storing directly)
+
+  // These may be null if no parallel store exists yet
+  Value wideBuffer;
+  StoreOp wideStore;
+
+  bool hasParallelStore() const { return wideStore != nullptr; }
+};
+
+// Information about a load + fpext pattern that can potentially be optimized.
+struct LoadFPExtPattern {
+  LoadOp loadOp;      // The load from the narrow buffer
+  FPExtOp fpextOp;    // The fpext that extends the loaded value
+  Value narrowBuffer; // Base pointer of the narrow buffer being loaded
+  GEPOp gepOp;        // The GEP operation (if any) used for indexing
+
+  // All fptrunc stores that contribute to covering the buffer.
+  SmallVector<FPTruncStoreInfo *> matchingStores;
+};
+
+// Get the base pointer from a value, tracing through GEP operations.
+static Value getBasePointer(Value ptr) {
+  while (auto gep = ptr.getDefiningOp<GEPOp>()) {
+    ptr = gep.getBase();
+  }
+  return ptr;
+}
+
+// Get the scalar element type, unwrapping vectors if needed.
+static Type getScalarType(Type type) {
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return vecType.getElementType();
+  return type;
+}
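+
+// Purely as an illustrative sketch (the names and types below are made up and
+// not taken from a real kernel), the shape of IR this pass targets is:
+//
+//   %n = llvm.fptrunc %wide : vector<4xf32> to vector<4xf16>
+//   llvm.store %n, %narrowAlloca : vector<4xf16>, !llvm.ptr<5>
+//   ...
+//   %l = llvm.load %narrowAlloca : !llvm.ptr<5> -> vector<4xf16>
+//   %e = llvm.fpext %l : vector<4xf16> to vector<4xf32>
+//
+// After the rewrite, the uses of %e read a parallel wide copy instead:
+//
+//   llvm.store %wide, %wideAlloca : vector<4xf32>, !llvm.ptr<5>
+//   ...
+//   %e = llvm.load %wideAlloca : !llvm.ptr<5> -> vector<4xf32>
+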
+// Find all fptrunc -> store patterns in the function.
+static SmallVector<FPTruncStoreInfo>
+findFPTruncStorePatterns(LLVMFuncOp funcOp) {
+  SmallVector<FPTruncStoreInfo> results;
+
+  funcOp.walk([&](FPTruncOp fptruncOp) -> WalkResult {
+    LLVM_DEBUG(llvm::dbgs() << "Found fptrunc: " << fptruncOp << "\n");
+    Value wideValue = fptruncOp.getArg();
+    Value narrowValue = fptruncOp.getRes();
+
+    // Find direct stores of the narrow value
+    for (Operation *user : narrowValue.getUsers()) {
+      auto narrowStore = dyn_cast<StoreOp>(user);
+      if (!narrowStore)
+        continue;
+
+      Value narrowPtr = narrowStore.getAddr();
+      Value narrowBuffer = getBasePointer(narrowPtr);
+      GEPOp narrowGep = narrowPtr.getDefiningOp<GEPOp>();
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tFound narrow store: " << narrowStore << "\n");
+
+      FPTruncStoreInfo info;
+      info.wideValue = wideValue;
+      info.narrowStore = narrowStore;
+      info.narrowBuffer = narrowBuffer;
+      info.narrowGep = narrowGep;
+      info.wideBuffer = nullptr;
+      info.wideStore = nullptr;
+
+      // Look for existing parallel wide store
+      for (Operation *wideUser : wideValue.getUsers()) {
+        auto wideStore = dyn_cast<StoreOp>(wideUser);
+        if (!wideStore)
+          continue;
+
+        Value widePtr = wideStore.getAddr();
+        Value wideBuffer = getBasePointer(widePtr);
+        if (wideBuffer == narrowBuffer)
+          continue;
+
+        LLVM_DEBUG(llvm::dbgs() << "\tFound existing parallel wide store: "
+                                << wideStore << "\n");
+        info.wideBuffer = wideBuffer;
+        info.wideStore = wideStore;
+        break;
+      }
+
+      results.push_back(info);
+    }
+
+    return WalkResult::advance();
+  });
+
+  return results;
+}
+
+// Find all load + fpext patterns.
+static SmallVector<LoadFPExtPattern> findLoadFPExtPatterns(LLVMFuncOp funcOp) {
+  SmallVector<LoadFPExtPattern> results;
+
+  funcOp.walk([&](FPExtOp fpextOp) -> WalkResult {
+    Value input = fpextOp.getArg();
+    auto loadOp = input.getDefiningOp<LoadOp>();
+    if (!loadOp)
+      return WalkResult::advance();
+
+    Value loadPtr = loadOp.getAddr();
+    Value narrowBuffer = getBasePointer(loadPtr);
+    GEPOp gepOp = loadPtr.getDefiningOp<GEPOp>();
+    LLVM_DEBUG(llvm::dbgs() << "Found load+fpext pattern:\n");
+    LLVM_DEBUG(llvm::dbgs() << "\tLoad: " << loadOp << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "\tFPExt: " << fpextOp << "\n");
+
+    LoadFPExtPattern pattern;
+    pattern.loadOp = loadOp;
+    pattern.fpextOp = fpextOp;
+    pattern.narrowBuffer = narrowBuffer;
+    pattern.gepOp = gepOp;
+    results.push_back(pattern);
+    return WalkResult::advance();
+  });
+
+  return results;
+}
+
+// Collect all stores that write to a buffer, tracing through GEPs.
+static SmallVector<StoreOp> collectStoresToBuffer(Value buffer) {
+  SmallVector<StoreOp> stores;
+  SmallVector<Value> worklist;
+  worklist.push_back(buffer);
+
+  while (!worklist.empty()) {
+    Value ptr = worklist.pop_back_val();
+    for (Operation *user : ptr.getUsers()) {
+      if (auto store = dyn_cast<StoreOp>(user)) {
+        // Only count stores TO this address, not stores OF this value
+        if (store.getAddr() == ptr) {
+          stores.push_back(store);
+        }
+      } else if (auto gep = dyn_cast<GEPOp>(user)) {
+        // Trace through GEPs that use this as base
+        if (gep.getBase() == ptr) {
+          worklist.push_back(gep.getResult());
+        }
+      }
+    }
+  }
+  return stores;
+}
+
+// Collect all loads that read from a buffer, tracing through GEPs.
+static SmallVector<LoadOp> collectLoadsFromBuffer(Value buffer) {
+  SmallVector<LoadOp> loads;
+  SmallVector<Value> worklist;
+  worklist.push_back(buffer);
+
+  while (!worklist.empty()) {
+    Value ptr = worklist.pop_back_val();
+    for (Operation *user : ptr.getUsers()) {
+      if (auto load = dyn_cast<LoadOp>(user)) {
+        if (load.getAddr() == ptr) {
+          loads.push_back(load);
+        }
+      } else if (auto gep = dyn_cast<GEPOp>(user)) {
+        // Trace through GEPs that use this as base
+        if (gep.getBase() == ptr) {
+          worklist.push_back(gep.getResult());
+        }
+      }
+    }
+  }
+  return loads;
+}
+
+// Check if a load is only used by an fpext operation
+static bool isLoadOnlyUsedByFPExt(LoadOp load) {
+  Value loadResult = load.getRes();
+  // The load result should have exactly one use, and it should be fpext
+  if (!loadResult.hasOneUse())
+    return false;
+  Operation *user = *loadResult.getUsers().begin();
+  return isa<FPExtOp>(user);
+}
+
+// Check if a store is from one of our tracked fptrunc patterns.
+static bool
+isStoreFromFPTruncPattern(StoreOp store,
+                          const SmallVector<FPTruncStoreInfo> &storeInfos) {
+  return llvm::any_of(
+      storeInfos, [&](const auto &info) { return info.narrowStore == store; });
+}
+
+// Represents a range of element indices [start, start + count).
+struct IndexRange {
+  int64_t start;
+  int64_t count;
+  bool isValid() const { return count > 0; }
+  bool isSubsetOf(const IndexRange &other) const {
+    return start >= other.start &&
+           (start + count) <= (other.start + other.count);
+  }
+  bool overlaps(const IndexRange &other) const {
+    return start < (other.start + other.count) &&
+           other.start < (start + count);
+  }
+};
+
+// Get the index range for a memory access. Returns an invalid range if we
+// can't determine it (e.g., dynamic indices).
+static IndexRange getAccessRange(GEPOp gep, Type accessType) {
+  int64_t elementCount = 1;
+  if (auto vecType = dyn_cast<VectorType>(accessType))
+    elementCount = vecType.getNumElements();
+
+  // No GEP means accessing at base (index 0)
+  if (!gep)
+    return {0, elementCount};
+
+  // Check if all indices are constant
+  auto indices = gep.getIndices();
+  if (indices.empty())
+    return {0, elementCount};
+
+  // We only handle single-index GEPs with constant index
+  if (indices.size() != 1)
+    return {-1, 0}; // Invalid
+
+  auto constIdx = dyn_cast<IntegerAttr>(indices[0]);
+  if (!constIdx)
+    return {-1, 0}; // Invalid
+
+  return {constIdx.getInt(), elementCount};
+}
+
+// Get the total size (in elements) of a buffer from its alloca.
+static int64_t getBufferSize(Value buffer) {
+  auto alloca = buffer.getDefiningOp<AllocaOp>();
+  if (!alloca)
+    return -1;
+
+  // Get array size (number of elements allocated)
+  Value arraySizeVal = alloca.getArraySize();
+  if (auto constOp = arraySizeVal.getDefiningOp<LLVM::ConstantOp>()) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
+      return intAttr.getInt();
+  }
+  return -1; // Dynamic or unknown size
+}
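+
+// Worked example (using the illustrative buffers from the sketch above): a
+// 16-element f16 alloca written by four vector<4xf16> stores through GEPs at
+// constant offsets 0, 4, 8, and 12 produces the access ranges [0, 4), [4, 8),
+// [8, 12), and [12, 16), which together cover the whole buffer. A GEP with a
+// dynamic index instead yields an invalid range, which blocks the rewrite.
+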
+// Find all fptrunc stores that cover the load's location. Returns all
+// dominating stores if they collectively cover the entire buffer, otherwise
+// returns an empty list.
+static SmallVector<FPTruncStoreInfo *>
+findMatchingFPTruncStores(LoadFPExtPattern &pattern,
+                          SmallVector<FPTruncStoreInfo> &storeInfos,
+                          DominanceInfo &domInfo) {
+  SmallVector<FPTruncStoreInfo *> dominatingStores;
+  int64_t bufferSize = getBufferSize(pattern.narrowBuffer);
+
+  if (bufferSize <= 0)
+    return dominatingStores;
+
+  // Track which elements are covered
+  std::vector<bool> covered(bufferSize, false);
+
+  for (auto &info : storeInfos) {
+    if (info.narrowBuffer != pattern.narrowBuffer)
+      continue;
+    if (!domInfo.dominates(info.narrowStore.getOperation(),
+                           pattern.loadOp.getOperation()))
+      continue;
+
+    dominatingStores.push_back(&info);
+    IndexRange storeRange =
+        getAccessRange(info.narrowGep, info.narrowStore.getValue().getType());
+    if (!storeRange.isValid())
+      continue;
+
+    // Mark covered elements
+    for (int64_t i = storeRange.start;
+         i < storeRange.start + storeRange.count && i < bufferSize; ++i) {
+      if (i >= 0)
+        covered[i] = true;
+    }
+  }
+
+  // Check if all elements are covered
+  for (bool c : covered) {
+    if (!c)
+      return {}; // Not fully covered, return empty
+  }
+  return dominatingStores;
+}
+
+// Check that no non-fptrunc stores could intervene between the fptrunc stores
+// and the load that would overwrite the fptrunc'd value.
+static bool
+hasNoInterveningStores(LoadFPExtPattern &pattern,
+                       const SmallVector<FPTruncStoreInfo> &storeInfos,
+                       DominanceInfo &domInfo) {
+  IndexRange loadRange =
+      getAccessRange(pattern.gepOp, pattern.loadOp.getRes().getType());
+  SmallVector<StoreOp> allStores = collectStoresToBuffer(pattern.narrowBuffer);
+  for (auto store : allStores) {
+    // Skip fptrunc stores
+    if (isStoreFromFPTruncPattern(store, storeInfos))
+      continue;
+
+    // If load dominates store, the store happens after the load on all paths
+    if (domInfo.dominates(pattern.loadOp.getOperation(), store.getOperation()))
+      continue;
+
+    // This non-fptrunc store could execute before the load on some path.
+    // Check if it could overwrite what the load is reading.
+    GEPOp storeGep = store.getAddr().getDefiningOp<GEPOp>();
+    IndexRange storeRange =
+        getAccessRange(storeGep, store.getValue().getType());
+
+    // If we can determine ranges and they don't overlap, it's safe
+    if (storeRange.isValid() && loadRange.isValid() &&
+        !storeRange.overlaps(loadRange))
+      continue;
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "\tUNSAFE: Non-fptrunc store could overwrite value: " << store
+               << "\n");
+    return false;
+  }
+  return true;
+}
+
+// Verify that a load -> fpext pattern is safe to optimize.
+static FailureOr<SmallVector<FPTruncStoreInfo *>>
+verifySafety(LoadFPExtPattern &pattern,
+             SmallVector<FPTruncStoreInfo> &storeInfos,
+             DominanceInfo &domInfo) {
+  // Check that the narrow buffer is an alloca
+  if (!pattern.narrowBuffer.getDefiningOp<AllocaOp>()) {
+    LLVM_DEBUG(llvm::dbgs() << "\tUNSAFE: Narrow buffer is not an alloca\n");
+    return failure();
+  }
+
+  // Check that the load uses static indices. Dynamic indices suggest the
+  // buffer is accessed in a pattern-dependent way (e.g., in a loop with
+  // runtime indexing), where the precision conversion may be algorithmically
+  // intentional.
+  IndexRange loadRange =
+      getAccessRange(pattern.gepOp, pattern.loadOp.getRes().getType());
+  if (!loadRange.isValid()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "\tUNSAFE: Load has dynamic index - cannot verify access "
+                  "pattern matches stores\n");
+    return failure();
+  }
+
+  // Find all fptrunc stores that cover this load
+  SmallVector<FPTruncStoreInfo *> matchingStores =
+      findMatchingFPTruncStores(pattern, storeInfos, domInfo);
+  if (matchingStores.empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "\tUNSAFE: No matching fptrunc stores found for load\n");
+    return failure();
+  }
+
+  // Check that all matching stores have compatible element types.
+  Type fpextElemType = getScalarType(pattern.fpextOp.getRes().getType());
+  for (auto *store : matchingStores) {
+    Type originalWideElemType = getScalarType(store->wideValue.getType());
+    if (fpextElemType != originalWideElemType) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "\tUNSAFE: Element type mismatch - fpext result element type "
+          << fpextElemType << " != original wide element type "
+          << originalWideElemType << "\n");
+      return failure();
+    }
+
+    // For existing parallel stores, check wide store dominance
+    if (store->hasParallelStore() &&
+        !domInfo.dominates(store->wideStore.getOperation(),
+                           pattern.loadOp.getOperation())) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tUNSAFE: Wide store does not dominate load\n");
+      return failure();
+    }
+
+    // Check that the fptrunc result is only used by the narrow store.
+    // If the narrow value has other uses, we can't eliminate the truncation.
+    auto fptruncOp = store->narrowStore.getValue().getDefiningOp<FPTruncOp>();
+    if (fptruncOp && !fptruncOp.getRes().hasOneUse()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tUNSAFE: fptrunc result has multiple uses - the narrow "
+                    "value is used elsewhere\n");
+      return failure();
+    }
+  }
+
+  // Check that all loads from the narrow buffer are used only by fpext.
+  // If there are loads that use the narrow values directly,
+  // we can't eliminate the narrow buffer.
+  SmallVector<LoadOp> allLoads = collectLoadsFromBuffer(pattern.narrowBuffer);
+  for (LoadOp load : allLoads) {
+    if (!isLoadOnlyUsedByFPExt(load)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tUNSAFE: Buffer has loads that don't go through fpext - "
+                    "narrow values are used directly: "
+                 << load << "\n");
+      return failure();
+    }
+  }
+
+  // Check that no non-fptrunc stores could intervene
+  if (!hasNoInterveningStores(pattern, storeInfos, domInfo))
+    return failure();
+
+  LLVM_DEBUG(llvm::dbgs() << "\tSAFE: All checks passed\n");
+  return matchingStores;
+}
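+
+// For a narrow store at a constant GEP offset, the helper below mirrors the
+// GEP on the wide buffer and stores the wide value right after the narrow
+// store. Illustrative sketch (names are made up, not from a real kernel):
+//
+//   llvm.store %narrow, %narrowGep : vector<4xf16>, !llvm.ptr<5>
+//   %wideGep = llvm.getelementptr %wideBuf[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32
+//   llvm.store %wide, %wideGep : vector<4xf32>, !llvm.ptr<5>
+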
+// Create a wide store for an fptrunc store info, using the given wide buffer.
+static void createWideStore(FPTruncStoreInfo *info, Value wideBuffer,
+                            Type wideElemType, OpBuilder &builder) {
+  builder.setInsertionPointAfter(info->narrowStore);
+  Value widePtr;
+  if (info->narrowGep) {
+    SmallVector<GEPArg> gepArgs;
+    for (auto idx : info->narrowGep.getIndices()) {
+      if (auto constIdx = dyn_cast<IntegerAttr>(idx))
+        gepArgs.push_back(static_cast<int32_t>(constIdx.getInt()));
+      else
+        gepArgs.push_back(cast<Value>(idx));
+    }
+    auto wideGep = GEPOp::create(builder, info->narrowGep.getLoc(),
+                                 wideBuffer.getType(), wideElemType, wideBuffer,
+                                 gepArgs);
+    wideGep.setNoWrapFlags(info->narrowGep.getNoWrapFlags());
+    widePtr = wideGep.getResult();
+  } else {
+    widePtr = wideBuffer;
+  }
+
+  auto wideStore = StoreOp::create(builder, info->narrowStore.getLoc(),
+                                   info->wideValue, widePtr);
+  info->wideBuffer = wideBuffer;
+  info->wideStore = wideStore;
+  LLVM_DEBUG(llvm::dbgs() << "Created wide store: " << wideStore << "\n");
+}
+
+// Create wide buffers and stores for safe patterns that don't already have
+// them. For patterns with existing parallel wide stores, do nothing. For
+// patterns without parallel stores, create a new wide alloca and insert a wide
+// store right after each narrow store.
+static void
+createWideBuffersAndStores(SmallVector<LoadFPExtPattern> &safePatterns,
+                           OpBuilder &builder) {
+  DenseMap<Value, Value> narrowToWideBuffer;
+  DenseSet<FPTruncStoreInfo *> processedStores;
+  for (auto &pattern : safePatterns) {
+    if (pattern.matchingStores.empty())
+      continue;
+
+    Type wideElemType = getScalarType(pattern.fpextOp.getRes().getType());
+    for (FPTruncStoreInfo *info : pattern.matchingStores) {
+      if (processedStores.contains(info))
+        continue;
+      processedStores.insert(info);
+
+      if (info->hasParallelStore())
+        continue;
+
+      // Check if we already created a wide buffer for this narrow buffer
+      auto it = narrowToWideBuffer.find(info->narrowBuffer);
+      if (it != narrowToWideBuffer.end()) {
+        createWideStore(info, it->second, wideElemType, builder);
+        continue;
+      }
+
+      // Need to create new wide buffer
+      auto narrowAlloca = info->narrowBuffer.getDefiningOp<AllocaOp>();
+      if (!narrowAlloca) {
+        LLVM_DEBUG(
+            llvm::dbgs()
+            << "Cannot create wide buffer: narrow buffer is not alloca\n");
+        continue;
+      }
+
+      builder.setInsertionPointAfter(narrowAlloca);
+      auto wideAlloca =
+          AllocaOp::create(builder, narrowAlloca.getLoc(),
+                           narrowAlloca.getResult().getType(), wideElemType,
+                           narrowAlloca.getArraySize());
+
+      LLVM_DEBUG(llvm::dbgs() << "Created wide alloca: " << wideAlloca << "\n");
+      narrowToWideBuffer[info->narrowBuffer] = wideAlloca.getResult();
+      createWideStore(info, wideAlloca.getResult(), wideElemType, builder);
+    }
+  }
+}
+
+// Apply the transformation: redirect loads from narrow buffer to wide buffer,
+// eliminating the fpext operations.
+static void applyTransformation(SmallVector<LoadFPExtPattern> &safePatterns,
+                                OpBuilder &builder) {
+  for (auto &pattern : safePatterns) {
+    // Find the first matching store with a wide buffer
+    Value wideBuffer;
+    for (auto *store : pattern.matchingStores) {
+      if (store->wideBuffer) {
+        wideBuffer = store->wideBuffer;
+        break;
+      }
+    }
+    if (!wideBuffer) {
+      LLVM_DEBUG(llvm::dbgs() << "No wide buffer for pattern, skipping\n");
+      continue;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "Transforming pattern:\n");
+    LLVM_DEBUG(llvm::dbgs() << "  Load: " << pattern.loadOp << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "  FPExt: " << pattern.fpextOp << "\n");
+
+    Type wideType = pattern.fpextOp.getRes().getType();
+    Type wideElemType = getScalarType(wideType);
+    Value newPtr;
+    if (pattern.gepOp) {
+      builder.setInsertionPoint(pattern.gepOp);
+      SmallVector<GEPArg> gepArgs;
+      for (auto idx : pattern.gepOp.getIndices()) {
+        if (auto constIdx = dyn_cast<IntegerAttr>(idx)) {
+          gepArgs.push_back(static_cast<int32_t>(constIdx.getInt()));
+        } else {
+          gepArgs.push_back(cast<Value>(idx));
+        }
+      }
+
+      auto newGep = GEPOp::create(builder, pattern.gepOp.getLoc(),
+                                  wideBuffer.getType(), wideElemType,
+                                  wideBuffer, gepArgs);
+      newGep.setNoWrapFlags(pattern.gepOp.getNoWrapFlags());
+      newPtr = newGep.getResult();
+    } else {
+      newPtr = wideBuffer;
+    }
+
+    builder.setInsertionPoint(pattern.loadOp);
+    unsigned wideAlignment = 4;
+    if (auto vecType = dyn_cast<VectorType>(wideType)) {
+      unsigned elemBits = vecType.getElementType().getIntOrFloatBitWidth();
+      wideAlignment = (elemBits / 8) * vecType.getNumElements();
+      wideAlignment = std::min(wideAlignment, 16u);
+    } else {
+      wideAlignment = wideType.getIntOrFloatBitWidth() / 8;
+    }
+
+    auto newLoad = LoadOp::create(builder, pattern.loadOp.getLoc(), wideType,
+                                  newPtr, wideAlignment);
+
+    // Clean up the load -> fpext
+    pattern.fpextOp.getRes().replaceAllUsesWith(newLoad.getRes());
+    pattern.fpextOp.erase();
+    if (pattern.loadOp.getRes().use_empty()) {
+      pattern.loadOp.erase();
+    }
+
+    if (pattern.gepOp && pattern.gepOp.getRes().use_empty()) {
+      pattern.gepOp.erase();
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "  Transformation complete.\n");
+  }
+}
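+
+// Continuing the illustrative sketch above: once every load of the narrow
+// alloca has been redirected to the wide buffer, the now-dead
+//   llvm.fptrunc -> llvm.getelementptr -> llvm.store -> llvm.alloca
+// chain on the narrow side is expected to be erased by the cleanup below.
+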
+// Clean up unused narrow buffer operations after transformation.
+// If the narrow buffer has no remaining uses, we can remove the fptrunc stores,
+// the fptrunc ops (if only used by the store), and the narrow alloca.
+static void
+cleanupUnusedNarrowBufferOps(SmallVector<LoadFPExtPattern> &safePatterns) {
+  // Collect all narrow buffers and their associated fptrunc stores
+  DenseMap<Value, SmallVector<FPTruncStoreInfo *>> bufferToStores;
+  for (auto &pattern : safePatterns) {
+    for (auto *info : pattern.matchingStores) {
+      bufferToStores[info->narrowBuffer].push_back(info);
+    }
+  }
+
+  // Track what we've already erased to avoid double-erase
+  DenseSet<Operation *> erased;
+  for (auto &[narrowBuffer, stores] : bufferToStores) {
+    // Check if the narrow buffer still has uses
+    // (other than the stores we're about to erase)
+    bool hasOtherUses = false;
+    for (Operation *user : narrowBuffer.getUsers()) {
+      // Check if this user is one of the stores we might erase
+      bool isTrackedStore = false;
+      for (auto *info : stores) {
+        if (info->narrowStore.getOperation() == user) {
+          isTrackedStore = true;
+          break;
+        }
+        if (info->narrowGep && info->narrowGep.getOperation() == user) {
+          isTrackedStore = true;
+          break;
+        }
+      }
+      if (!isTrackedStore) {
+        hasOtherUses = true;
+        break;
+      }
+    }
+
+    if (hasOtherUses) {
+      LLVM_DEBUG(llvm::dbgs() << "Narrow buffer still has other uses, "
+                              << "keeping fptrunc stores\n");
+      continue;
+    }
+
+    // No other uses, so we can clean up the fptrunc stores and related ops
+    LLVM_DEBUG(llvm::dbgs() << "Cleaning up unused narrow buffer ops\n");
+    for (auto *info : stores) {
+      // Capture the fptrunc op before erasing the store
+      auto fptruncOp = info->narrowStore.getValue().getDefiningOp<FPTruncOp>();
+
+      // Erase the narrow store
+      if (!erased.contains(info->narrowStore.getOperation())) {
+        erased.insert(info->narrowStore.getOperation());
+        info->narrowStore.erase();
+        LLVM_DEBUG(llvm::dbgs() << "\tErased narrow store\n");
+      }
+
+      // Erase the GEP if it has no uses
+      if (info->narrowGep && info->narrowGep.getRes().use_empty() &&
+          !erased.contains(info->narrowGep.getOperation())) {
+        erased.insert(info->narrowGep.getOperation());
+        info->narrowGep.erase();
+        LLVM_DEBUG(llvm::dbgs() << "\tErased narrow GEP\n");
+      }
+
+      // Erase the fptrunc if it has no uses
+      if (fptruncOp && fptruncOp.getRes().use_empty() &&
+          !erased.contains(fptruncOp.getOperation())) {
+        erased.insert(fptruncOp.getOperation());
+        fptruncOp.erase();
+        LLVM_DEBUG(llvm::dbgs() << "\tErased fptrunc\n");
+      }
+    }
+
+    // Erase the narrow alloca if it has no uses
+    if (auto narrowAlloca = narrowBuffer.getDefiningOp<AllocaOp>()) {
+      if (narrowAlloca.getResult().use_empty() &&
+          !erased.contains(narrowAlloca.getOperation())) {
+        erased.insert(narrowAlloca.getOperation());
+        narrowAlloca.erase();
+        LLVM_DEBUG(llvm::dbgs() << "\tErased narrow alloca\n");
+      }
+    }
+  }
+}
+
+struct RockRemoveRedundantCastsPass
+    : public rock::impl::RockRemoveRedundantCastsPassBase<
+          RockRemoveRedundantCastsPass> {
+  void runOnOperation() override;
+};
+
+} // end namespace
+
+void RockRemoveRedundantCastsPass::runOnOperation() {
+  LLVMFuncOp funcOp = getOperation();
+  OpBuilder builder(funcOp.getContext());
+
+  LLVM_DEBUG(llvm::dbgs() << "Running RockRemoveRedundantCastsPass on "
+                          << funcOp.getName() << "\n");
+
+  // Step 1: Find all fptrunc -> store patterns
+  SmallVector<FPTruncStoreInfo> storeInfo = findFPTruncStorePatterns(funcOp);
+  if (storeInfo.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "No fptrunc -> store patterns found.\n");
+    return;
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Found " << storeInfo.size()
+                          << " fptrunc -> store patterns.\n");
+
+  // Step 2: Find all load -> fpext patterns
+  SmallVector<LoadFPExtPattern> loadFPExtPatterns =
+      findLoadFPExtPatterns(funcOp);
+  if (loadFPExtPatterns.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "No load+fpext patterns found.\n");
+    return;
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Found " << loadFPExtPatterns.size()
+                          << " load+fpext patterns.\n");
+
+  // Step 3: Verify safety (applicability) for each pattern
+  DominanceInfo domInfo(funcOp);
+  SmallVector<LoadFPExtPattern> safePatterns;
+  for (auto &pattern : loadFPExtPatterns) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Verifying pattern: load=" << pattern.loadOp << "\n");
+    FailureOr<SmallVector<FPTruncStoreInfo *>> result =
+        verifySafety(pattern, storeInfo, domInfo);
+    if (succeeded(result)) {
+      pattern.matchingStores = *result;
+      safePatterns.push_back(pattern);
+    }
+  }
+
+  if (safePatterns.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "No safe patterns to optimize.\n");
+    return;
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Found " << safePatterns.size()
+                          << " safe pattern combination(s) to optimize.\n");
+
+  // Step 4: Create wide buffers and stores for patterns that need them
+  createWideBuffersAndStores(safePatterns, builder);
+
+  // Step 5: Apply transformation (redirect loads to wide buffer)
+  applyTransformation(safePatterns, builder);
+
+  // Step 6: Clean up unused narrow buffer operations
+  cleanupUnusedNarrowBufferOps(safePatterns);
+
+  LLVM_DEBUG(llvm::dbgs() << "Optimized " << safePatterns.size()
+                          << " patterns.\n");
+}
diff --git a/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir
new file mode 100644
index 000000000000..0a6835ccc681
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir
@@ -0,0 +1,253 @@
+// RUN: rocmlir-opt --rock-remove-redundant-casts %s | FileCheck %s
+
+// Parallel wide buffer already exists, so the pass should redirect the load to
+// the wide buffer, eliminating fpext
+// CHECK-LABEL: llvm.func @test_parallel_buffer_exists
+llvm.func @test_parallel_buffer_exists() {
+  %0 = llvm.mlir.constant(16 : i64) : i64
+  %1 = llvm.mlir.constant(4 : i64) : i64
+  %2 = llvm.mlir.constant(0 : i32) : i32
+  %3 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+  %4 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5>
+  %5 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5>
+  %6 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16>
+  llvm.store %6, %4 : vector<4xf16>, !llvm.ptr<5>
+  llvm.store %3, %5 : vector<4xf32>, !llvm.ptr<5>
+  %7 = llvm.getelementptr %4[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
+  %8 = llvm.getelementptr %5[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32
+  %9 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16>
+  llvm.store %9, %7 : vector<4xf16>, !llvm.ptr<5>
+  llvm.store %3, %8 : vector<4xf32>, !llvm.ptr<5>
+  %10 = llvm.getelementptr %4[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
+  %11 = llvm.getelementptr %5[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32
+  %12 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16>
+  llvm.store %12, %10 : vector<4xf16>, !llvm.ptr<5>
+  llvm.store %3, %11 : vector<4xf32>, !llvm.ptr<5>
+  %13 = llvm.getelementptr %4[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
+  %14 = llvm.getelementptr %5[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32
+  %15 = llvm.fptrunc %3 : vector<4xf32> to vector<4xf16>
+  llvm.store %15, %13 : vector<4xf16>, !llvm.ptr<5>
+  llvm.store %3, %14 : vector<4xf32>, !llvm.ptr<5>
+  %16 = llvm.load %4 : !llvm.ptr<5> -> vector<4xf16>
+  %17 = llvm.fpext %16 : vector<4xf16> to vector<4xf32>
+  // CHECK-NOT: llvm.fpext
+  // CHECK: llvm.load {{.*}} -> vector<4xf32>
+  %18 = llvm.fadd %17, %3 : vector<4xf32>
+
+  llvm.return
+}
+
+// No parallel buffer so the pass should create one
+// CHECK-LABEL: llvm.func @test_create_wide_buffer
+llvm.func @test_create_wide_buffer() {
+  %0 = llvm.mlir.constant(16 : i64) : i64
+ %1 = llvm.mlir.constant(dense<2.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + // CHECK-NOT: llvm.alloca {{.*}} x f16 + // CHECK: llvm.alloca {{.*}} x f32 + %3 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %3, %2 : vector<4xf16>, !llvm.ptr<5> + %4 = llvm.getelementptr %2[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %4 : vector<4xf16>, !llvm.ptr<5> + %6 = llvm.getelementptr %2[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5> + %8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %9, %8 : vector<4xf16>, !llvm.ptr<5> + %10 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + // CHECK-NOT: llvm.fpext + %11 = llvm.fpext %10 : vector<4xf16> to vector<4xf32> + %12 = llvm.fadd %11, %1 : vector<4xf32> + llvm.return +} + +// The load reads subset of what was stored +// CHECK-LABEL: llvm.func @test_subset_load +llvm.func @test_subset_load() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<3.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %5 = llvm.getelementptr %2[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %6 = llvm.getelementptr %3[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %7, %5 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %6 : vector<4xf32>, !llvm.ptr<5> + %8 = llvm.getelementptr %2[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.getelementptr %3[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %10 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %10, %8 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %9 : vector<4xf32>, !llvm.ptr<5> + %11 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %12 = llvm.getelementptr %3[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %13 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %13, %11 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %12 : vector<4xf32>, !llvm.ptr<5> + %14 = llvm.load %2 : !llvm.ptr<5> -> vector<2xf16> + // CHECK: llvm.load {{.*}} -> vector<2xf32> + %15 = llvm.fpext %14 : vector<2xf16> to vector<2xf32> + // CHECK-NOT: llvm.fpext + %16 = llvm.mlir.constant(dense<1.000000e+00> : vector<2xf32>) : vector<2xf32> + %17 = llvm.fadd %15, %16 : vector<2xf32> + llvm.return +} + +// Unsafe case: intervening non-fptrunc store +// CHECK-LABEL: llvm.func @test_unsafe_intervening_store +llvm.func @test_unsafe_intervening_store() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<4.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.mlir.constant(dense<9.000000e+00> : vector<4xf16>) : vector<4xf16> + %3 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %4 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %3 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %4 : vector<4xf32>, !llvm.ptr<5> + llvm.store %2, %3 : vector<4xf16>, !llvm.ptr<5> + %6 = llvm.load %3 : !llvm.ptr<5> -> vector<4xf16> + %7 = llvm.fpext %6 : vector<4xf16> to vector<4xf32> + // CHECK: 
llvm.fpext + %8 = llvm.fadd %7, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: fpext to different type than original +// CHECK-LABEL: llvm.func @test_type_mismatch +llvm.func @test_type_mismatch() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<5.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f64 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + %5 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %6 = llvm.fpext %5 : vector<4xf16> to vector<4xf64> + // CHECK: llvm.fpext + %7 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf64>) : vector<4xf64> + %8 = llvm.fadd %6, %7 : vector<4xf64> + llvm.return +} + +// Unsafe case: narrow buffer is not an alloca (function argument) +// CHECK-LABEL: llvm.func @test_not_alloca +llvm.func @test_not_alloca(%narrow_buf: !llvm.ptr<5>) { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<6.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %3 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %3, %narrow_buf : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %2 : vector<4xf32>, !llvm.ptr<5> + %4 = llvm.load %narrow_buf : !llvm.ptr<5> -> vector<4xf16> + %5 = llvm.fpext %4 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %6 = llvm.fadd %5, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: buffer not fully covered by fptrunc stores (partial coverage) +// CHECK-LABEL: llvm.func @test_partial_coverage +llvm.func @test_partial_coverage() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<7.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %4, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %5 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %6 = llvm.fpext %5 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %7 = llvm.fadd %6, %1 : vector<4xf32> + llvm.return +} + +// Safe case: non-overlapping intervening store (store to different indices) +// CHECK-LABEL: llvm.func @test_non_overlapping_store +llvm.func @test_non_overlapping_store() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<8.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.mlir.constant(dense<9.000000e+00> : vector<4xf16>) : vector<4xf16> + %3 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %4 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %5 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %5, %3 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %4 : vector<4xf32>, !llvm.ptr<5> + %6 = llvm.getelementptr %3[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %7 = llvm.getelementptr %4[4] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %8 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %8, %6 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %7 : vector<4xf32>, !llvm.ptr<5> + %9 = llvm.getelementptr %3[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %10 = llvm.getelementptr %4[8] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %11 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %11, %9 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %10 : vector<4xf32>, !llvm.ptr<5> + %12 = llvm.getelementptr %3[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %13 
= llvm.getelementptr %4[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %14 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %14, %12 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %13 : vector<4xf32>, !llvm.ptr<5> + llvm.store %2, %9 : vector<4xf16>, !llvm.ptr<5> + %15 = llvm.load %3 : !llvm.ptr<5> -> vector<4xf16> + %16 = llvm.fpext %15 : vector<4xf16> to vector<4xf32> + // CHECK-NOT: llvm.fpext + // CHECK: llvm.load {{.*}} -> vector<4xf32> + %17 = llvm.fadd %16, %1 : vector<4xf32> + llvm.return +} + +// Unsafe case: fptrunc store does NOT dominate load (store after load) +// CHECK-LABEL: llvm.func @test_store_after_load +llvm.func @test_store_after_load() { + %0 = llvm.mlir.constant(16 : i64) : i64 + %1 = llvm.mlir.constant(dense<10.000000e+00> : vector<4xf32>) : vector<4xf32> + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.load %2 : !llvm.ptr<5> -> vector<4xf16> + %5 = llvm.fpext %4 : vector<4xf16> to vector<4xf32> + // CHECK: llvm.fpext + %6 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16> + llvm.store %6, %2 : vector<4xf16>, !llvm.ptr<5> + llvm.store %1, %3 : vector<4xf32>, !llvm.ptr<5> + %7 = llvm.fadd %5, %1 : vector<4xf32> + llvm.return +} + +// Safe case: scalar types (not vectors) +// CHECK-LABEL: llvm.func @test_scalar_types +llvm.func @test_scalar_types() { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(11.000000e+00 : f32) : f32 + %2 = llvm.alloca %0 x f16 : (i64) -> !llvm.ptr<5> + %3 = llvm.alloca %0 x f32 : (i64) -> !llvm.ptr<5> + %4 = llvm.fptrunc %1 : f32 to f16 + llvm.store %4, %2 : f16, !llvm.ptr<5> + llvm.store %1, %3 : f32, !llvm.ptr<5> + %5 = llvm.getelementptr %2[1] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %6 = llvm.getelementptr %3[1] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %7 = llvm.fptrunc %1 : f32 to f16 + llvm.store %7, %5 : f16, !llvm.ptr<5> + llvm.store %1, %6 : f32, !llvm.ptr<5> + %8 = llvm.getelementptr %2[2] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %9 = llvm.getelementptr %3[2] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %10 = llvm.fptrunc %1 : f32 to f16 + llvm.store %10, %8 : f16, !llvm.ptr<5> + llvm.store %1, %9 : f32, !llvm.ptr<5> + %11 = llvm.getelementptr %2[3] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16 + %12 = llvm.getelementptr %3[3] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f32 + %13 = llvm.fptrunc %1 : f32 to f16 + llvm.store %13, %11 : f16, !llvm.ptr<5> + llvm.store %1, %12 : f32, !llvm.ptr<5> + %14 = llvm.load %2 : !llvm.ptr<5> -> f16 + %15 = llvm.fpext %14 : f16 to f32 + // CHECK-NOT: llvm.fpext + // CHECK: llvm.load {{.*}} -> f32 + %16 = llvm.fadd %15, %1 : f32 + llvm.return +} + diff --git a/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir new file mode 100644 index 000000000000..7e28183b0fe8 --- /dev/null +++ b/mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir @@ -0,0 +1,33 @@ +// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-gen -fut mlir_remove_casts --arch %arch --clone-harness - | rocmlir-driver -kernel-pipeline=migraphx,highlevel -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_remove_casts_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner 
--shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s +// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-driver --arch %arch --kernel-pipeline=migraphx,highlevel - | rocmlir-driver -arch %arch -c -mlir-print-ir-after=rock-remove-redundant-casts 2>&1 | FileCheck %s --check-prefix=NO-FPEXT + +// CHECK: [1 1 1] +// NO-FPEXT-NOT: llvm.fpext +module { +func.func @mlir_remove_casts(%arg0: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>, %arg1: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>, %arg2: !migraphx.shaped<1x64x5x128xf16, 40960x640x128x1>) -> !migraphx.shaped<1x5x64x128xf16, 40960x8192x128x1> attributes {arch="gfx950", kernel = "mixr"} { + %0 = migraphx.literal(dense<8.837890e-02> : tensor<1xf16>) : <1xf16, 0> + %1 = migraphx.transpose %arg0 {permutation = [0, 2, 1, 3]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x64x128xf16, 40960x128x640x1> + %2 = migraphx.transpose %arg1 {permutation = [0, 2, 3, 1]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x128x64xf16, 40960x128x1x640> + %3 = migraphx.transpose %arg2 {permutation = [0, 2, 1, 3]} : <1x64x5x128xf16, 40960x640x128x1> -> <1x5x64x128xf16, 40960x128x640x1> + %4 = migraphx.dot %1, %2 : <1x5x64x128xf16, 40960x128x640x1>, <1x5x128x64xf16, 40960x128x1x640> -> <1x5x64x64xf16, 20480x4096x64x1> + %5 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1xf16, 0> -> <1x5x64x64xf16, 0x0x0x0> + %6 = migraphx.convert %4 {target_type = 2 : i64} : <1x5x64x64xf16, 20480x4096x64x1> to <1x5x64x64xf32, 20480x4096x64x1> + %7 = migraphx.convert %5 {target_type = 2 : i64} : <1x5x64x64xf16, 0x0x0x0> to <1x5x64x64xf32, 0x0x0x0> + %8 = migraphx.mul %6, %7 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 0x0x0x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %9 = migraphx.reshape %8 {dims = [1, 5, 64, 64]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %10 = migraphx.reduce_max %9 {axes = [3]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x1xf32, 320x64x1x1> + %11 = migraphx.reshape %10 {dims = [1, 5, 64, 1]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x1xf32, 320x64x1x1> + %12 = migraphx.multibroadcast %11 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x64xf32, 320x64x1x0> + %13 = migraphx.sub %8, %12 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 320x64x1x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %14 = migraphx.exp %13 : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %15 = migraphx.reshape %14 {dims = [1, 5, 64, 64]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x64xf32, 20480x4096x64x1> + %16 = migraphx.reduce_sum %15 {axes = [3]} : <1x5x64x64xf32, 20480x4096x64x1> -> <1x5x64x1xf32, 320x64x1x1> + %17 = migraphx.reshape %16 {dims = [1, 5, 64, 1]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x1xf32, 320x64x1x1> + %18 = migraphx.multibroadcast %17 {out_dyn_dims = [], out_lens = [1, 5, 64, 64]} : <1x5x64x1xf32, 320x64x1x1> -> <1x5x64x64xf32, 320x64x1x0> + %19 = migraphx.div %14, %18 : <1x5x64x64xf32, 20480x4096x64x1>, <1x5x64x64xf32, 320x64x1x0> -> <1x5x64x64xf32, 20480x4096x64x1> + %20 = migraphx.convert %19 {target_type = 1 : i64} : <1x5x64x64xf32, 20480x4096x64x1> to <1x5x64x64xf16, 20480x4096x64x1> + %21 = migraphx.dot %20, %3 : 
<1x5x64x64xf16, 20480x4096x64x1>, <1x5x64x128xf16, 40960x128x640x1> -> <1x5x64x128xf16, 40960x8192x128x1> + return %21 : !migraphx.shaped<1x5x64x128xf16, 40960x8192x128x1> + } +} + diff --git a/mlir/test/rocmlir-driver/pipelines.mlir b/mlir/test/rocmlir-driver/pipelines.mlir index 3554d4752b1c..a8aa2042985a 100644 --- a/mlir/test/rocmlir-driver/pipelines.mlir +++ b/mlir/test/rocmlir-driver/pipelines.mlir @@ -74,6 +74,7 @@ // BINARY-NEXT:llvm.func(rock-to-rocdl{chipset=gfx90a}), // BINARY-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // BINARY-NEXT:cse, +// BINARY-NEXT:rock-remove-redundant-casts, // BINARY-NEXT:rock-prepare-llvm)), // BINARY-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx90a correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=}, @@ -105,6 +106,7 @@ // BINARY_MI300-NEXT:llvm.func(rock-to-rocdl{chipset=gfx942}), // BINARY_MI300-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // BINARY_MI300-NEXT:cse, +// BINARY_MI300-NEXT:rock-remove-redundant-casts, // BINARY_MI300-NEXT:rock-prepare-llvm)), // BINARY_MI300-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx942 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY_MI300-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=}, @@ -136,6 +138,7 @@ // BINARY_MI350-NEXT:llvm.func(rock-to-rocdl{chipset=gfx950}), // BINARY_MI350-NEXT:llvm.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, // BINARY_MI350-NEXT:cse, +// BINARY_MI350-NEXT:rock-remove-redundant-casts, // BINARY_MI350-NEXT:rock-prepare-llvm)), // BINARY_MI350-NEXT:rocdl-attach-target{O=3 abi=600 chip=gfx950 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}, // BINARY_MI350-NEXT:gpu-module-to-binary{format=fatbin opts= section= toolkit=},