From aefeb311e9c4a39d2c2389e4675d126e6af45c77 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Dec 2025 20:57:15 +0000 Subject: [PATCH 01/15] add schedGroup --- mlir/include/mlir/Dialect/Rock/Passes.h | 1 + mlir/include/mlir/Dialect/Rock/Passes.td | 22 +- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 1 + .../Rock/Transforms/AnalyzeScfForOps.cpp | 294 ++++++++++++++++++ .../Dialect/Rock/Transforms/CMakeLists.txt | 1 + 5 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index c0771efaae26..4591e85aca63 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -56,6 +56,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKADDASYNCWAITPASS #define GEN_PASS_DECL_ROCKADDDIRECTTOLDSALIASINFOPASS #define GEN_PASS_DECL_CONVERTROCKOPSTOROCDLOPS +#define GEN_PASS_DECL_ROCKANALYZESCFFOROPSPASS #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Rock/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 6f2965a15373..7095297c44ef 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -143,7 +143,7 @@ def RockPipelinePass : Pass<"rock-pipeline", "::mlir::func::FuncOp"> { Option<"removeStages", "rock-pipeline-remove-stages", "bool", "true", "Remove pipeline stages once the pipeline pass is completed"> ]; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "ROCDL::ROCDLDialect", "amdgpu::AMDGPUDialect"]; let summary = "Pipeline loops"; } @@ -157,6 +157,26 @@ def RockBufferLoadMergePass : Pass<"rock-buffer-load-merge", "::mlir::func::Func let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect"]; } +def RockAnalyzeScfForOpsPass : Pass<"rock-analyze-scf-for-ops", 
"::mlir::func::FuncOp"> { + let summary = "Analyze scf.for loops and insert scheduling group barriers"; + let description = [{ + This pass analyzes scf.for operations, counts memory operations and MFMA + instructions per iteration, and inserts scheduling group barriers: + - Global memory loads (amdgpu.raw_buffer_load, vector.load from global memory) + - LDS/workgroup memory reads (memref.load from workgroup address space) + - LDS/workgroup memory writes (memref.store to workgroup address space) + - MFMA instructions (amdgpu.mfma) + + The counts factor in affine.for loop trip counts to give the total number of + operations per scf.for iteration. Based on these counts, scheduling group + barriers (ROCDL::SchedGroupBarrier) are inserted to optimize instruction + scheduling on AMD GPUs. + }]; + let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect", "::mlir::scf::SCFDialect", + "::mlir::affine::AffineDialect", "::mlir::gpu::GPUDialect", + "::mlir::ROCDL::ROCDLDialect"]; +} + def RockTransformToMemrefPass : Pass<"rock-transform-to-memref", "::mlir::func::FuncOp"> { let summary = "convert remaining rock.transform ops to memref.expand/collapse_shape"; let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect"]; diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index a945ed55ca8e..22baadff210b 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -237,6 +237,7 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass( math::createMathExtendToSupportedTypes(extendToLLVMTypesOptions)); funcPm.addPass(rock::createRockBufferLoadMergePass()); + funcPm.addPass(rock::createRockAnalyzeScfForOpsPass()); funcPm.addPass(rock::createRockTransformToMemrefPass()); funcPm.addPass(rock::createRockEmulateNarrowTypePass()); funcPm.addPass(rock::createRockPack4BitGpuOpsTo8BitPass()); 
diff --git a/mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp b/mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp new file mode 100644 index 000000000000..bd130e421892 --- /dev/null +++ b/mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp @@ -0,0 +1,294 @@ +//===- AnalyzeScfForOps.cpp - Analyze scf.for memory ops and MFMA ---------===// +// +// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Copyright (c) 2024 Advanced Micro Devices Inc. +//===----------------------------------------------------------------------===// +// +// This pass analyzes scf.for loops to count memory operations and MFMA +// instructions per iteration: +// - Global memory loads (amdgpu.raw_buffer_load, vector.load from global) +// - LDS/workgroup memory reads (memref.load from workgroup address space) +// - LDS/workgroup memory writes (memref.store to workgroup address space) +// - MFMA instructions (amdgpu.mfma) +// +// The counts factor in affine.for loop trip counts. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Rock/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "rock-analyze-scf-for-ops" + +namespace mlir { +namespace rock { +#define GEN_PASS_DEF_ROCKANALYZESCFFOROPSPASS +#include "mlir/Dialect/Rock/Passes.h.inc" +} // namespace rock +} // namespace mlir + +using namespace mlir; +using namespace mlir::rock; + +namespace { + +/// Check if a memref type has workgroup (LDS) address space +static bool hasWorkgroupAddressSpace(MemRefType memrefType) { + auto addrSpace = memrefType.getMemorySpace(); + if (!addrSpace) + return false; + + // Check for gpu.address_space + if (auto gpuAddrSpace = dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) { + return gpuAddrSpace.getValue() == gpu::AddressSpace::Workgroup; + } + return false; +} + +/// Check if a memref type has global address space +static bool hasGlobalAddressSpace(MemRefType memrefType) { + auto addrSpace = memrefType.getMemorySpace(); + // No address space means global by default + if (!addrSpace) + return true; + + // Check for gpu.address_space + if (auto gpuAddrSpace = dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) { + return gpuAddrSpace.getValue() == gpu::AddressSpace::Global; + } + return false; +} + +/// Get the trip count of an affine.for loop, returns 1 if unknown +static uint64_t 
getAffineForTripCount(affine::AffineForOp affineFor) { + std::optional<uint64_t> tripCount = affine::getConstantTripCount(affineFor); + if (tripCount.has_value()) { + return tripCount.value(); + } + // If we can't determine the trip count, return 1 (conservative estimate) + return 1; +} + +/// Compute the multiplier for an operation based on enclosing affine.for loops +/// within the scf.for boundary +static uint64_t computeAffineLoopMultiplier(Operation *op, scf::ForOp boundary) { + uint64_t multiplier = 1; + Operation *parent = op->getParentOp(); + + while (parent && parent != boundary.getOperation()) { + if (auto affineFor = dyn_cast<affine::AffineForOp>(parent)) { + multiplier *= getAffineForTripCount(affineFor); + } + parent = parent->getParentOp(); + } + + return multiplier; +} + +struct ScfForAnalysisResult { + uint64_t globalLoads = 0; + uint64_t ldsReads = 0; + uint64_t ldsWrites = 0; + uint64_t mfmaOps = 0; +}; + +/// Analyze a single scf.for operation +static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { + ScfForAnalysisResult result; + + forOp.walk([&](Operation *op) { + uint64_t multiplier = computeAffineLoopMultiplier(op, forOp); + + // Count amdgpu.raw_buffer_load (global loads) + if (isa<amdgpu::RawBufferLoadOp>(op)) { + result.globalLoads += multiplier; + return; + } + + // Count vector.load from global memory + if (auto vectorLoad = dyn_cast<vector::LoadOp>(op)) { + if (auto memrefType = dyn_cast<MemRefType>(vectorLoad.getBase().getType())) { + if (hasGlobalAddressSpace(memrefType)) { + result.globalLoads += multiplier; + } + } + return; + } + + // Count vector.transfer_read from global memory + if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { + if (auto memrefType = dyn_cast<MemRefType>(transferRead.getBase().getType())) { + if (hasGlobalAddressSpace(memrefType)) { + result.globalLoads += multiplier; + } + } + return; + } + + // Count memref.load from workgroup memory (LDS reads) + if (auto memrefLoad = dyn_cast<memref::LoadOp>(op)) { + if (auto memrefType = dyn_cast<MemRefType>(memrefLoad.getMemRef().getType())) { + if (hasWorkgroupAddressSpace(memrefType)) { + 
+ result.ldsReads += multiplier; + } + } + return; + } + + // Count memref.store to workgroup memory (LDS writes) + if (auto memrefStore = dyn_cast<memref::StoreOp>(op)) { + if (auto memrefType = dyn_cast<MemRefType>(memrefStore.getMemRef().getType())) { + if (hasWorkgroupAddressSpace(memrefType)) { + result.ldsWrites += multiplier; + } + } + return; + } + + // Count vector.transfer_write to workgroup memory (LDS writes) + if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) { + if (auto memrefType = dyn_cast<MemRefType>(transferWrite.getBase().getType())) { + if (hasWorkgroupAddressSpace(memrefType)) { + result.ldsWrites += multiplier; + } + } + return; + } + + // Count amdgpu.mfma operations + if (isa<amdgpu::MFMAOp>(op)) { + result.mfmaOps += multiplier; + return; + } + }); + + return result; +} + +/// Rewrite pattern to insert scheduling group barriers in scf.for loops +struct InsertSchedGroupBarrierPattern : public OpRewritePattern<scf::ForOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rw) const override { + mlir::Region &region = op.getRegion(); + Block &block = region.front(); + + // Check if SchedBarrierOp already exists (to avoid duplicates) + WalkResult result = block.walk([&](Operation *innerOp) { + if (isa<amdgpu::SchedBarrierOp>(innerOp)) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + + // Analyze the scf.for loop to get operation counts + ScfForAnalysisResult analysis = analyzeScfFor(op); + + // Skip if no meaningful operations found + if (analysis.globalLoads == 0 && analysis.mfmaOps == 0) + return failure(); + + // Print analysis results for debugging + LLVM_DEBUG({ + llvm::dbgs() << "=== scf.for Analysis ===\n"; + llvm::dbgs() << "Location: " << op.getLoc() << "\n"; + llvm::dbgs() << "Global memory loads per iteration: " << analysis.globalLoads << "\n"; + llvm::dbgs() << "LDS reads per iteration: " << analysis.ldsReads << "\n"; + llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites << "\n"; + 
llvm::dbgs() << "MFMA operations per iteration: " << analysis.mfmaOps << "\n"; + llvm::dbgs() << "========================\n\n"; + }); + + uint64_t numBufferLoads = analysis.globalLoads; + uint64_t numDSReads = analysis.ldsReads; + uint64_t numDSWrites = analysis.ldsWrites; + uint64_t numMFMA = analysis.mfmaOps; + + // Insert sched_barrier at the start of the block + rw.setInsertionPointToStart(&block); + amdgpu::SchedBarrierOp::create( + rw, op.getLoc(), + amdgpu::sched_barrier_opt_enumAttr::get( + rw.getContext(), amdgpu::sched_barrier_opt_enum::none)); + + // Insert sched group barriers before the terminator + auto *lastOp = block.getTerminator()->getPrevNode(); + rw.setInsertionPointAfter(lastOp); + + // Insert sched group barriers based on the analysis + if (numBufferLoads > 0) { + uint64_t dsReadsPerLoad = numDSReads / numBufferLoads; + uint64_t dsWritesPerLoad = numDSWrites / numBufferLoads; + uint64_t mfmaPerLoad = numMFMA / numBufferLoads; + // Ensure we have at least 3 MFMAs to distribute (for the pattern) + uint64_t remainingMfma = mfmaPerLoad > 3 ? 
mfmaPerLoad - 3 : 0; + + for (uint64_t i = 0; i < numBufferLoads; i++) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA + if (dsReadsPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, + dsReadsPerLoad, 0); // DS Reads + } + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA + if (dsWritesPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, + dsWritesPerLoad, 0); // DS Writes + } + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, 0); // VMEM + if (remainingMfma > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, + remainingMfma, 0); // MFMA + } + } + } + + // Insert sched_barrier at the end + amdgpu::SchedBarrierOp::create( + rw, op.getLoc(), + amdgpu::sched_barrier_opt_enumAttr::get( + rw.getContext(), amdgpu::sched_barrier_opt_enum::none)); + + return success(); + } +}; + +struct RockAnalyzeScfForOpsPass final + : rock::impl::RockAnalyzeScfForOpsPassBase { + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *ctx = funcOp.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // end namespace + diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index 4147eaf69b75..827f02f22b19 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms AffixTuningParameters.cpp AlignTiling.cpp AnalyzeMemoryUse.cpp + AnalyzeScfForOps.cpp BlockwiseGemmToThreadwise.cpp BufferLoadMerge.cpp BufferizableOpInterfaceImpl.cpp From d27c187ae24f40f233c7501feb9c969967550c1e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 23 Dec 2025 23:27:35 +0000 Subject: [PATCH 
02/15] change names --- mlir/include/mlir/Dialect/Rock/Passes.h | 2 +- mlir/include/mlir/Dialect/Rock/Passes.td | 4 +-- mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | 2 +- ...cfForOps.cpp => AddSchedGroupBarriers.cpp} | 32 +++++++++---------- .../Dialect/Rock/Transforms/CMakeLists.txt | 2 +- 5 files changed, 21 insertions(+), 21 deletions(-) rename mlir/lib/Dialect/Rock/Transforms/{AnalyzeScfForOps.cpp => AddSchedGroupBarriers.cpp} (92%) diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h index 4591e85aca63..f30dc248c8c8 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.h +++ b/mlir/include/mlir/Dialect/Rock/Passes.h @@ -56,7 +56,7 @@ namespace rock { #define GEN_PASS_DECL_ROCKADDASYNCWAITPASS #define GEN_PASS_DECL_ROCKADDDIRECTTOLDSALIASINFOPASS #define GEN_PASS_DECL_CONVERTROCKOPSTOROCDLOPS -#define GEN_PASS_DECL_ROCKANALYZESCFFOROPSPASS +#define GEN_PASS_DECL_ROCKADDSCHEDGROUPBARRIERSPASS #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Rock/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td index 7095297c44ef..45690558d38b 100644 --- a/mlir/include/mlir/Dialect/Rock/Passes.td +++ b/mlir/include/mlir/Dialect/Rock/Passes.td @@ -143,7 +143,7 @@ def RockPipelinePass : Pass<"rock-pipeline", "::mlir::func::FuncOp"> { Option<"removeStages", "rock-pipeline-remove-stages", "bool", "true", "Remove pipeline stages once the pipeline pass is completed"> ]; - let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "ROCDL::ROCDLDialect", "amdgpu::AMDGPUDialect"]; + let dependentDialects = ["rock::RockDialect", "affine::AffineDialect"]; let summary = "Pipeline loops"; } @@ -157,7 +157,7 @@ def RockBufferLoadMergePass : Pass<"rock-buffer-load-merge", "::mlir::func::Func let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect"]; } -def RockAnalyzeScfForOpsPass : Pass<"rock-analyze-scf-for-ops", "::mlir::func::FuncOp"> { +def RockAddSchedGroupBarriersPass : 
Pass<"rock-add-sched-group-barriers", "::mlir::func::FuncOp"> { let summary = "Analyze scf.for loops and insert scheduling group barriers"; let description = [{ This pass analyzes scf.for operations, counts memory operations and MFMA diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp index 22baadff210b..002906a4df98 100644 --- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp +++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp @@ -237,7 +237,7 @@ void rock::buildKernelPipeline(OpPassManager &pm, funcPm.addPass( math::createMathExtendToSupportedTypes(extendToLLVMTypesOptions)); funcPm.addPass(rock::createRockBufferLoadMergePass()); - funcPm.addPass(rock::createRockAnalyzeScfForOpsPass()); + funcPm.addPass(rock::createRockAddSchedGroupBarriersPass()); funcPm.addPass(rock::createRockTransformToMemrefPass()); funcPm.addPass(rock::createRockEmulateNarrowTypePass()); funcPm.addPass(rock::createRockPack4BitGpuOpsTo8BitPass()); diff --git a/mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp similarity index 92% rename from mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp rename to mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index bd130e421892..25002b1112d4 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AnalyzeScfForOps.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -1,14 +1,14 @@ -//===- AnalyzeScfForOps.cpp - Analyze scf.for memory ops and MFMA ---------===// +//===- AddSchedGroupBarriers.cpp - Add scheduling group barriers ----------===// // // Part of the rocMLIR Project, under the Apache License v2.0 with LLVM // Exceptions. See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// Copyright (c) 2024 Advanced Micro Devices Inc. +// Copyright (c) 2025 Advanced Micro Devices Inc. 
//===----------------------------------------------------------------------===// // // This pass analyzes scf.for loops to count memory operations and MFMA -// instructions per iteration: +// instructions per iteration, then inserts scheduling group barriers: // - Global memory loads (amdgpu.raw_buffer_load, vector.load from global) // - LDS/workgroup memory reads (memref.load from workgroup address space) // - LDS/workgroup memory writes (memref.store to workgroup address space) @@ -34,13 +34,14 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#define DEBUG_TYPE "rock-analyze-scf-for-ops" +#define DEBUG_TYPE "rock-add-sched-group-barriers" namespace mlir { namespace rock { -#define GEN_PASS_DEF_ROCKANALYZESCFFOROPSPASS +#define GEN_PASS_DEF_ROCKADDSCHEDGROUPBARRIERSPASS #include "mlir/Dialect/Rock/Passes.h.inc" } // namespace rock } // namespace mlir @@ -238,18 +239,13 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { // Insert sched group barriers based on the analysis if (numBufferLoads > 0) { - uint64_t dsReadsPerLoad = numDSReads / numBufferLoads; - uint64_t dsWritesPerLoad = numDSWrites / numBufferLoads; - uint64_t mfmaPerLoad = numMFMA / numBufferLoads; + uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); + uint64_t dsWritesPerLoad = llvm::divideCeil(numDSWrites, numBufferLoads); + uint64_t mfmaPerLoad = llvm::divideCeil(numMFMA, numBufferLoads); // Ensure we have at least 3 MFMAs to distribute (for the pattern) - uint64_t remainingMfma = mfmaPerLoad > 3 ? mfmaPerLoad - 3 : 0; + uint64_t remainingMfma = mfmaPerLoad > 2 ? 
mfmaPerLoad - 2 : 0; for (uint64_t i = 0; i < numBufferLoads; i++) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA - if (dsReadsPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, - dsReadsPerLoad, 0); // DS Reads - } ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA if (dsWritesPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, @@ -261,6 +257,10 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, remainingMfma, 0); // MFMA } + if (dsReadsPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, + dsReadsPerLoad, 0); // DS Reads + } } } @@ -274,8 +274,8 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { } }; -struct RockAnalyzeScfForOpsPass final - : rock::impl::RockAnalyzeScfForOpsPassBase { +struct RockAddSchedGroupBarriersPass final + : rock::impl::RockAddSchedGroupBarriersPassBase { void runOnOperation() override { func::FuncOp funcOp = getOperation(); diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index 827f02f22b19..1a0ae7832559 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -3,7 +3,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms AffixTuningParameters.cpp AlignTiling.cpp AnalyzeMemoryUse.cpp - AnalyzeScfForOps.cpp + AddSchedGroupBarriers.cpp BlockwiseGemmToThreadwise.cpp BufferLoadMerge.cpp BufferizableOpInterfaceImpl.cpp From c9386a1101c12b2f8afcbb02ff027fea21e2b7a7 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Fri, 26 Dec 2025 06:16:53 -0600 Subject: [PATCH 03/15] try exhaustive tune --- mlir/utils/tuna/tuna-script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/utils/tuna/tuna-script.sh b/mlir/utils/tuna/tuna-script.sh index cd966654a079..9feef2108015 100755 --- a/mlir/utils/tuna/tuna-script.sh 
+++ b/mlir/utils/tuna/tuna-script.sh @@ -96,7 +96,7 @@ export TUNA_DIR=/tmp/MITuna export ROCMLIR_DIR=$(pwd)/.. # Assumes we're in the build directory export OUT_FILE=results.tsv export OP=convolution -export TUNING_SPACE=full +export TUNING_SPACE=exhaustive export LOAD_FACTOR= # -c configs From 05e0f09ad4153aa38e19d675b443e162bd21cc07 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 29 Dec 2025 19:11:25 +0000 Subject: [PATCH 04/15] This achieves 160 TFLops --- .../Rock/Transforms/AddSchedGroupBarriers.cpp | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index 25002b1112d4..b2c0a91dfe22 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -239,27 +239,35 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { // Insert sched group barriers based on the analysis if (numBufferLoads > 0) { - uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); - uint64_t dsWritesPerLoad = llvm::divideCeil(numDSWrites, numBufferLoads); - uint64_t mfmaPerLoad = llvm::divideCeil(numMFMA, numBufferLoads); // Ensure we have at least 3 MFMAs to distribute (for the pattern) - uint64_t remainingMfma = mfmaPerLoad > 2 ? 
mfmaPerLoad - 2 : 0; - for (uint64_t i = 0; i < numBufferLoads; i++) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA - if (dsWritesPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, - dsWritesPerLoad, 0); // DS Writes - } - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA + uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); + uint64_t dsWritesPerLoad = + llvm::divideCeil(numDSWrites, numBufferLoads); + uint64_t mfmaPerLoad = llvm::divideCeil(numMFMA, numBufferLoads); ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, 0); // VMEM - if (remainingMfma > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, - remainingMfma, 0); // MFMA - } if (dsReadsPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, - dsReadsPerLoad, 0); // DS Reads + dsReadsPerLoad, + 0); // DS Reads + } + uint64_t mfmaPerDSWrite = + llvm::divideCeil(mfmaPerLoad, dsWritesPerLoad); + if (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 1, + 0); // DS Writes + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, + mfmaPerDSWrite, 0); // MFMA + mfmaPerLoad -= mfmaPerDSWrite; + dsWritesPerLoad--; + } + if (mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, mfmaPerLoad, + 0); // MFMA + } + if (dsWritesPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, + dsWritesPerLoad, 0); // DS Writes } } } From 42f005b3ee5a1dbb20643cf7c48bd9d7e2416d75 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 29 Dec 2025 19:22:22 +0000 Subject: [PATCH 05/15] This is better achieves upto 166 TFLops --- .../Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index b2c0a91dfe22..a82e8f4c6388 100644 
--- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -253,7 +253,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern<scf::ForOp> { } uint64_t mfmaPerDSWrite = llvm::divideCeil(mfmaPerLoad, dsWritesPerLoad); - if (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + while (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 1, 0); // DS Writes ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, @@ -261,14 +261,14 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern<scf::ForOp> { mfmaPerLoad -= mfmaPerDSWrite; dsWritesPerLoad--; } - if (mfmaPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, mfmaPerLoad, - 0); // MFMA - } if (dsWritesPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, dsWritesPerLoad, 0); // DS Writes } + if (mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, mfmaPerLoad, + 0); // MFMA + } } } From e50779e25c31e8477859e7cbe678999b1b6e5c4b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 29 Dec 2025 20:13:27 +0000 Subject: [PATCH 06/15] add logic for both single and double buffered pipelines --- .../Rock/Transforms/AddSchedGroupBarriers.cpp | 103 +++++++++++++----- 1 file changed, 78 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index a82e8f4c6388..dd8995580286 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -18,9 +18,10 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include 
"mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -78,6 +79,12 @@ static bool hasGlobalAddressSpace(MemRefType memrefType) { return false; } +/// Check if a value is defined by an arith.select operation, which indicates +/// double buffering (selecting between two different LDS buffers) +static bool isDefinedBySelect(Value val) { + return val.getDefiningOp<arith::SelectOp>() != nullptr; +} + /// Get the trip count of an affine.for loop, returns 1 if unknown static uint64_t getAffineForTripCount(affine::AffineForOp affineFor) { std::optional<uint64_t> tripCount = affine::getConstantTripCount(affineFor); @@ -109,6 +116,9 @@ struct ScfForAnalysisResult { uint64_t ldsReads = 0; uint64_t ldsWrites = 0; uint64_t mfmaOps = 0; + /// Indicates if the loop uses double buffering (LDS reads/writes use + /// arith.select to choose between two buffers) + bool isDoubleBuffered = false; }; /// Analyze a single scf.for operation @@ -149,6 +159,11 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { if (auto memrefType = dyn_cast<MemRefType>(memrefLoad.getMemRef().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsReads += multiplier; + // Check for double buffering: if the memref is selected via + // arith.select, it indicates alternating between two LDS buffers + if (isDefinedBySelect(memrefLoad.getMemRef())) { + result.isDoubleBuffered = true; + } } } return; @@ -159,6 +174,11 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { if (auto memrefType = dyn_cast<MemRefType>(memrefStore.getMemRef().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsWrites += multiplier; + // Check for double buffering: if the memref is selected via + // arith.select, it indicates alternating between two LDS buffers + if (isDefinedBySelect(memrefStore.getMemRef())) { + result.isDoubleBuffered = true; + } } } return; @@ -169,6 +189,11 @@ static ScfForAnalysisResult 
analyzeScfFor(scf::ForOp forOp) { if (auto memrefType = dyn_cast(transferWrite.getBase().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsWrites += multiplier; + // Check for double buffering: if the memref is selected via + // arith.select, it indicates alternating between two LDS buffers + if (isDefinedBySelect(transferWrite.getBase())) { + result.isDoubleBuffered = true; + } } } return; @@ -218,6 +243,8 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { llvm::dbgs() << "LDS reads per iteration: " << analysis.ldsReads << "\n"; llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites << "\n"; llvm::dbgs() << "MFMA operations per iteration: " << analysis.mfmaOps << "\n"; + llvm::dbgs() << "Double buffering detected: " + << (analysis.isDoubleBuffered ? "yes" : "no") << "\n"; llvm::dbgs() << "========================\n\n"; }); @@ -239,35 +266,61 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { // Insert sched group barriers based on the analysis if (numBufferLoads > 0) { - // Ensure we have at least 3 MFMAs to distribute (for the pattern) for (uint64_t i = 0; i < numBufferLoads; i++) { uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); uint64_t dsWritesPerLoad = llvm::divideCeil(numDSWrites, numBufferLoads); uint64_t mfmaPerLoad = llvm::divideCeil(numMFMA, numBufferLoads); - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, 0); // VMEM - if (dsReadsPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, - dsReadsPerLoad, - 0); // DS Reads - } - uint64_t mfmaPerDSWrite = - llvm::divideCeil(mfmaPerLoad, dsWritesPerLoad); - while (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 1, - 0); // DS Writes - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, - mfmaPerDSWrite, 0); // MFMA - mfmaPerLoad -= mfmaPerDSWrite; - dsWritesPerLoad--; - } - if (dsWritesPerLoad > 0) { - 
ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, - dsWritesPerLoad, 0); // DS Writes - } - if (mfmaPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, mfmaPerLoad, - 0); // MFMA + if (analysis.isDoubleBuffered) { + uint64_t dsWritesPerMFMA = + llvm::divideCeil(dsWritesPerLoad, mfmaPerLoad); + if (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, + dsWritesPerMFMA, 0); // DS Writes + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, + 0); // MFMA + mfmaPerLoad--; + dsWritesPerLoad -= dsWritesPerMFMA; + } + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, + 0); // VMEM + if (mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, + mfmaPerLoad, + 0); // MFMA + } + if (dsReadsPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, + dsReadsPerLoad, + 0); // DS Reads + } + } else { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, + 0); // VMEM + if (dsReadsPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100, + dsReadsPerLoad, + 0); // DS Reads + } + uint64_t mfmaPerDSWrite = + llvm::divideCeil(mfmaPerLoad, dsWritesPerLoad); + while (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 1, + 0); // DS Writes + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, + mfmaPerDSWrite, 0); // MFMA + mfmaPerLoad -= mfmaPerDSWrite; + dsWritesPerLoad--; + } + if (dsWritesPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, + dsWritesPerLoad, 0); // DS Writes + } + if (mfmaPerLoad > 0) { + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, + mfmaPerLoad, + 0); // MFMA + } } } } From 20d234233c0c3f414bb98fcca3a4855c2c6a4d1b Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 29 Dec 2025 20:20:31 +0000 Subject: [PATCH 07/15] swap order --- .../Rock/Transforms/AddSchedGroupBarriers.cpp | 19 ++++++++----------- 1 file changed, 8 
insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index dd8995580286..e88d38b90ea4 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -302,19 +302,16 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { dsReadsPerLoad, 0); // DS Reads } - uint64_t mfmaPerDSWrite = - llvm::divideCeil(mfmaPerLoad, dsWritesPerLoad); + uint64_t dsWritesPerMFMA = + llvm::divideCeil(dsWritesPerLoad, mfmaPerLoad); while (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 1, - 0); // DS Writes - ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, - mfmaPerDSWrite, 0); // MFMA - mfmaPerLoad -= mfmaPerDSWrite; - dsWritesPerLoad--; - } - if (dsWritesPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, - dsWritesPerLoad, 0); // DS Writes + dsWritesPerMFMA, + 0); // DS Writes + ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, + 0); // MFMA + mfmaPerLoad--; + dsWritesPerLoad -= dsWritesPerMFMA; } if (mfmaPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, From 920a35b8006efad255849286a1cf8d7cbfc6ff5c Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 29 Dec 2025 20:21:35 +0000 Subject: [PATCH 08/15] change to greedy --- mlir/utils/tuna/tuna-script.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/utils/tuna/tuna-script.sh b/mlir/utils/tuna/tuna-script.sh index 9feef2108015..eb5111969f05 100755 --- a/mlir/utils/tuna/tuna-script.sh +++ b/mlir/utils/tuna/tuna-script.sh @@ -96,7 +96,7 @@ export TUNA_DIR=/tmp/MITuna export ROCMLIR_DIR=$(pwd)/.. 
# Assumes we're in the build directory export OUT_FILE=results.tsv export OP=convolution -export TUNING_SPACE=exhaustive +export TUNING_SPACE=greedy export LOAD_FACTOR= # -c configs From 94317d6f20edce03deff18cc7fcebee17d64758e Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Dec 2025 13:53:10 +0000 Subject: [PATCH 09/15] Fix bug --- .../Rock/Transforms/AddSchedGroupBarriers.cpp | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index e88d38b90ea4..1eee42ca4b6d 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -115,7 +115,7 @@ struct ScfForAnalysisResult { uint64_t globalLoads = 0; uint64_t ldsReads = 0; uint64_t ldsWrites = 0; - uint64_t mfmaOps = 0; + uint64_t matrixMultiplyOps = 0; /// Indicates if the loop uses double buffering (LDS reads/writes use /// arith.select to choose between two buffers) bool isDoubleBuffered = false; @@ -199,9 +199,10 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { return; } - // Count amdgpu.mfma operations - if (isa(op)) { - result.mfmaOps += multiplier; + // Count Matrix multiply operations + if (isa(op) || isa(op) || + isa(op)) { + result.matrixMultiplyOps += multiplier; return; } }); @@ -232,7 +233,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { ScfForAnalysisResult analysis = analyzeScfFor(op); // Skip if no meaningful operations found - if (analysis.globalLoads == 0 && analysis.mfmaOps == 0) + if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0) return failure(); // Print analysis results for debugging @@ -242,7 +243,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { llvm::dbgs() << "Global memory loads per iteration: " << analysis.globalLoads << "\n"; llvm::dbgs() << "LDS reads per iteration: " << 
analysis.ldsReads << "\n"; llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites << "\n"; - llvm::dbgs() << "MFMA operations per iteration: " << analysis.mfmaOps << "\n"; + llvm::dbgs() << "Matrix multiply operations per iteration: " << analysis.matrixMultiplyOps << "\n"; llvm::dbgs() << "Double buffering detected: " << (analysis.isDoubleBuffered ? "yes" : "no") << "\n"; llvm::dbgs() << "========================\n\n"; @@ -251,7 +252,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { uint64_t numBufferLoads = analysis.globalLoads; uint64_t numDSReads = analysis.ldsReads; uint64_t numDSWrites = analysis.ldsWrites; - uint64_t numMFMA = analysis.mfmaOps; + uint64_t numMatrixMultiplyOps = analysis.matrixMultiplyOps; // Insert sched_barrier at the start of the block rw.setInsertionPointToStart(&block); @@ -265,28 +266,28 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { rw.setInsertionPointAfter(lastOp); // Insert sched group barriers based on the analysis - if (numBufferLoads > 0) { + if (numBufferLoads > 0 && numMatrixMultiplyOps > 0) { for (uint64_t i = 0; i < numBufferLoads; i++) { uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); uint64_t dsWritesPerLoad = llvm::divideCeil(numDSWrites, numBufferLoads); - uint64_t mfmaPerLoad = llvm::divideCeil(numMFMA, numBufferLoads); + uint64_t matrixMultiplyPerLoad = llvm::divideCeil(numMatrixMultiplyOps, numBufferLoads); if (analysis.isDoubleBuffered) { uint64_t dsWritesPerMFMA = - llvm::divideCeil(dsWritesPerLoad, mfmaPerLoad); - if (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad); + if (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, dsWritesPerMFMA, 0); // DS Writes ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA - mfmaPerLoad--; + matrixMultiplyPerLoad--; dsWritesPerLoad -= dsWritesPerMFMA; } 
ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1, 0); // VMEM - if (mfmaPerLoad > 0) { + if (matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, - mfmaPerLoad, + matrixMultiplyPerLoad, 0); // MFMA } if (dsReadsPerLoad > 0) { @@ -303,19 +304,19 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { 0); // DS Reads } uint64_t dsWritesPerMFMA = - llvm::divideCeil(dsWritesPerLoad, mfmaPerLoad); - while (dsWritesPerLoad > 0 && mfmaPerLoad > 0) { + llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad); + while (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, dsWritesPerMFMA, 0); // DS Writes ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, 0); // MFMA - mfmaPerLoad--; + matrixMultiplyPerLoad--; dsWritesPerLoad -= dsWritesPerMFMA; } - if (mfmaPerLoad > 0) { + if (matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, - mfmaPerLoad, + matrixMultiplyPerLoad, 0); // MFMA } } From 3ad9d0dd0928d19d799e08b864f69ee16fe78c35 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 30 Dec 2025 15:22:16 +0000 Subject: [PATCH 10/15] fix bug --- mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index 1eee42ca4b6d..ab4280eb97d1 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -275,7 +275,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { if (analysis.isDoubleBuffered) { uint64_t dsWritesPerMFMA = llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad); - if (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { + while(dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, 
dsWritesPerMFMA, 0); // DS Writes ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, From 03dacc16bea472f66e7eb72ad594aed6b36ed2de Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Mon, 5 Jan 2026 21:47:30 +0000 Subject: [PATCH 11/15] fix bug for LDS bank conflicts --- mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 156d35af4842..046d032ad78e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -287,12 +287,13 @@ static LDSLayoutConfigDim getLDSLayoutConfigDim(Type elementType, int64_t kpack, LDSLayoutConfigDim cfg; int64_t maxVlen = 128 / elementType.getIntOrFloatBitWidth(); int64_t copyDPerThread = vecDimInfo.inDPerThread; + int64_t copyKPerThread = vecDimInfo.inKPerThread; bool isKContiguousDim = vecDimInfo.vectorDim == GemmDimension::K; // If kpack is less than the hardware max vector length, and we are // writing more contiguous kpack elements, there is a possibility to // vectorize that we want to preserve (i.e., we favour vectorization over // bank conflicts resolution) - bool isPossibleToVectorizeD = (kpack < maxVlen && copyDPerThread > 1); + bool isPossibleToVectorizeD = (kpack < maxVlen && copyDPerThread > 1) && (copyKPerThread >= kpack); cfg.doRotateWithK = isKContiguousDim && !isPossibleToVectorizeD; cfg.doSwapThreadIterSubDims = !isKContiguousDim && !isPossibleToVectorizeD; cfg.ldsLayoutDxK = false; From 4bb12667b2634aff46677378a9ced4ece67894cd Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 6 Jan 2026 08:10:13 -0600 Subject: [PATCH 12/15] lower minCU Count for CPX mode --- mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp 
b/mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp index 71b3de8d86fd..aeca723203db 100644 --- a/mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp +++ b/mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp @@ -66,7 +66,7 @@ static constexpr AmdArchInfo GemmFeatures::direct_to_lds_32b, /*waveSize=*/64, /*maxWavesPerEU*/ 8, /*totalSGPRPerEU*/ 800, /*totalVGPRPerEU*/ 512, /*totalSharedMemPerCU*/ 65536, - /*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/80, + /*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/20, /*hasFp8ConversionInstrs=*/true, /*hasOcpFp8ConversionInstrs=*/false, /*hasScaledGemm=*/false, /*maxNumXCC=*/8, /*hasLdsTransposeLoad=*/false), From f2632b77e3b5fc578f528bb22dcc721d821d7c90 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 6 Jan 2026 16:56:01 +0000 Subject: [PATCH 13/15] Add greedy type in API --- mlir/include/mlir-c/Dialect/RockEnums.h | 3 ++- mlir/lib/CAPI/Dialect/Rock.cpp | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Dialect/RockEnums.h b/mlir/include/mlir-c/Dialect/RockEnums.h index 1ad08814e7da..894be4512fc9 100644 --- a/mlir/include/mlir-c/Dialect/RockEnums.h +++ b/mlir/include/mlir-c/Dialect/RockEnums.h @@ -18,7 +18,8 @@ extern "C" { enum RocmlirTuningParamSetKind { RocmlirTuningParamSetKindQuick = 0, RocmlirTuningParamSetKindFull = 1, - RocmlirTuningParamSetKindExhaustive = 2 + RocmlirTuningParamSetKindGreedy = 2, + RocmlirTuningParamSetKindExhaustive = 3 }; typedef enum RocmlirTuningParamSetKind RocmlirTuningParamSetKind; diff --git a/mlir/lib/CAPI/Dialect/Rock.cpp b/mlir/lib/CAPI/Dialect/Rock.cpp index d7acb119bd67..1b7b280195d3 100644 --- a/mlir/lib/CAPI/Dialect/Rock.cpp +++ b/mlir/lib/CAPI/Dialect/Rock.cpp @@ -42,6 +42,9 @@ mlirRockTuningSpaceCreate(MlirModule module, RocmlirTuningParamSetKind kind) { case RocmlirTuningParamSetKindExhaustive: ourKind = rock::TuningParamSetKind::Exhaustive; break; + case RocmlirTuningParamSetKindGreedy: + ourKind = rock::TuningParamSetKind::Greedy; + break; } auto mod 
= unwrap(module); rock::TuningParamSpaceSettings settings; From a6e5534d9e51bc10e0cc4204b2b92c7f979c9871 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 6 Jan 2026 17:17:20 +0000 Subject: [PATCH 14/15] add logic for directToLDS --- .../Rock/Transforms/AddSchedGroupBarriers.cpp | 69 ++++++++++++++----- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index ab4280eb97d1..5d87bf624ded 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -97,7 +97,8 @@ static uint64_t getAffineForTripCount(affine::AffineForOp affineFor) { /// Compute the multiplier for an operation based on enclosing affine.for loops /// within the scf.for boundary -static uint64_t computeAffineLoopMultiplier(Operation *op, scf::ForOp boundary) { +static uint64_t computeAffineLoopMultiplier(Operation *op, + scf::ForOp boundary) { uint64_t multiplier = 1; Operation *parent = op->getParentOp(); @@ -116,6 +117,8 @@ struct ScfForAnalysisResult { uint64_t ldsReads = 0; uint64_t ldsWrites = 0; uint64_t matrixMultiplyOps = 0; + /// Direct loads from global memory to LDS (amdgpu.gather_to_lds) + uint64_t directLoadsToLDS = 0; /// Indicates if the loop uses double buffering (LDS reads/writes use /// arith.select to choose between two buffers) bool isDoubleBuffered = false; @@ -134,21 +137,37 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { return; } - // Count vector.load from global memory + // Count vector.load from global memory or workgroup memory (LDS) if (auto vectorLoad = dyn_cast(op)) { - if (auto memrefType = dyn_cast(vectorLoad.getBase().getType())) { + if (auto memrefType = + dyn_cast(vectorLoad.getBase().getType())) { if (hasGlobalAddressSpace(memrefType)) { result.globalLoads += multiplier; + } else if (hasWorkgroupAddressSpace(memrefType)) { + result.ldsReads 
+= multiplier; + // Check for double buffering: if the memref is selected via + // arith.select, it indicates alternating between two LDS buffers + if (isDefinedBySelect(vectorLoad.getBase())) { + result.isDoubleBuffered = true; + } } } return; } - // Count vector.transfer_read from global memory + // Count vector.transfer_read from global memory or workgroup memory (LDS) if (auto transferRead = dyn_cast(op)) { - if (auto memrefType = dyn_cast(transferRead.getBase().getType())) { + if (auto memrefType = + dyn_cast(transferRead.getBase().getType())) { if (hasGlobalAddressSpace(memrefType)) { result.globalLoads += multiplier; + } else if (hasWorkgroupAddressSpace(memrefType)) { + result.ldsReads += multiplier; + // Check for double buffering: if the memref is selected via + // arith.select, it indicates alternating between two LDS buffers + if (isDefinedBySelect(transferRead.getBase())) { + result.isDoubleBuffered = true; + } } } return; @@ -156,7 +175,8 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { // Count memref.load from workgroup memory (LDS reads) if (auto memrefLoad = dyn_cast(op)) { - if (auto memrefType = dyn_cast(memrefLoad.getMemRef().getType())) { + if (auto memrefType = + dyn_cast(memrefLoad.getMemRef().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsReads += multiplier; // Check for double buffering: if the memref is selected via @@ -171,7 +191,8 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { // Count memref.store to workgroup memory (LDS writes) if (auto memrefStore = dyn_cast(op)) { - if (auto memrefType = dyn_cast(memrefStore.getMemRef().getType())) { + if (auto memrefType = + dyn_cast(memrefStore.getMemRef().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsWrites += multiplier; // Check for double buffering: if the memref is selected via @@ -186,7 +207,8 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { // Count vector.transfer_write to workgroup memory (LDS 
writes) if (auto transferWrite = dyn_cast(op)) { - if (auto memrefType = dyn_cast(transferWrite.getBase().getType())) { + if (auto memrefType = + dyn_cast(transferWrite.getBase().getType())) { if (hasWorkgroupAddressSpace(memrefType)) { result.ldsWrites += multiplier; // Check for double buffering: if the memref is selected via @@ -205,6 +227,12 @@ static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) { result.matrixMultiplyOps += multiplier; return; } + + // Count direct loads from global memory to LDS (amdgpu.gather_to_lds) + if (isa(op)) { + result.directLoadsToLDS += multiplier; + return; + } }); return result; @@ -233,23 +261,29 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { ScfForAnalysisResult analysis = analyzeScfFor(op); // Skip if no meaningful operations found - if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0) + if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0 && + analysis.directLoadsToLDS == 0) return failure(); // Print analysis results for debugging LLVM_DEBUG({ llvm::dbgs() << "=== scf.for Analysis ===\n"; llvm::dbgs() << "Location: " << op.getLoc() << "\n"; - llvm::dbgs() << "Global memory loads per iteration: " << analysis.globalLoads << "\n"; + llvm::dbgs() << "Global memory loads per iteration: " + << analysis.globalLoads << "\n"; + llvm::dbgs() << "Direct loads to LDS per iteration: " + << analysis.directLoadsToLDS << "\n"; llvm::dbgs() << "LDS reads per iteration: " << analysis.ldsReads << "\n"; - llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites << "\n"; - llvm::dbgs() << "Matrix multiply operations per iteration: " << analysis.matrixMultiplyOps << "\n"; + llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites + << "\n"; + llvm::dbgs() << "Matrix multiply operations per iteration: " + << analysis.matrixMultiplyOps << "\n"; llvm::dbgs() << "Double buffering detected: " << (analysis.isDoubleBuffered ? 
"yes" : "no") << "\n"; llvm::dbgs() << "========================\n\n"; }); - uint64_t numBufferLoads = analysis.globalLoads; + uint64_t numBufferLoads = analysis.globalLoads + analysis.directLoadsToLDS; uint64_t numDSReads = analysis.ldsReads; uint64_t numDSWrites = analysis.ldsWrites; uint64_t numMatrixMultiplyOps = analysis.matrixMultiplyOps; @@ -271,11 +305,12 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads); uint64_t dsWritesPerLoad = llvm::divideCeil(numDSWrites, numBufferLoads); - uint64_t matrixMultiplyPerLoad = llvm::divideCeil(numMatrixMultiplyOps, numBufferLoads); + uint64_t matrixMultiplyPerLoad = + llvm::divideCeil(numMatrixMultiplyOps, numBufferLoads); if (analysis.isDoubleBuffered) { uint64_t dsWritesPerMFMA = llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad); - while(dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { + while (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) { ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200, dsWritesPerMFMA, 0); // DS Writes ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1, @@ -334,7 +369,8 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { }; struct RockAddSchedGroupBarriersPass final - : rock::impl::RockAddSchedGroupBarriersPassBase { + : rock::impl::RockAddSchedGroupBarriersPassBase< + RockAddSchedGroupBarriersPass> { void runOnOperation() override { func::FuncOp funcOp = getOperation(); @@ -350,4 +386,3 @@ struct RockAddSchedGroupBarriersPass final }; } // end namespace - From b0e84e4a4f67429eec78d0021f193d5f919f7af0 Mon Sep 17 00:00:00 2001 From: Umang Yadav Date: Tue, 6 Jan 2026 17:24:57 +0000 Subject: [PATCH 15/15] do not use sched group for direct tolds --- .../Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp 
b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp index 5d87bf624ded..03290e5e529e 100644 --- a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp @@ -259,10 +259,15 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { // Analyze the scf.for loop to get operation counts ScfForAnalysisResult analysis = analyzeScfFor(op); + bool isDirectToLDS = analysis.directLoadsToLDS > 0; + if (isDirectToLDS) { + // direct to LDS is not supported yet + LLVM_DEBUG(llvm::dbgs() << "Direct to LDS is not supported yet\n"); + return failure(); + } // Skip if no meaningful operations found - if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0 && - analysis.directLoadsToLDS == 0) + if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0) return failure(); // Print analysis results for debugging @@ -283,7 +288,7 @@ struct InsertSchedGroupBarrierPattern : public OpRewritePattern { llvm::dbgs() << "========================\n\n"; }); - uint64_t numBufferLoads = analysis.globalLoads + analysis.directLoadsToLDS; + uint64_t numBufferLoads = analysis.globalLoads; uint64_t numDSReads = analysis.ldsReads; uint64_t numDSWrites = analysis.ldsWrites; uint64_t numMatrixMultiplyOps = analysis.matrixMultiplyOps;