diff --git a/mlir/include/mlir-c/Dialect/RockEnums.h b/mlir/include/mlir-c/Dialect/RockEnums.h
index 1ad08814e7da..894be4512fc9 100644
--- a/mlir/include/mlir-c/Dialect/RockEnums.h
+++ b/mlir/include/mlir-c/Dialect/RockEnums.h
@@ -18,7 +18,8 @@ extern "C" {
 enum RocmlirTuningParamSetKind {
   RocmlirTuningParamSetKindQuick = 0,
   RocmlirTuningParamSetKindFull = 1,
-  RocmlirTuningParamSetKindExhaustive = 2
+  RocmlirTuningParamSetKindGreedy = 2,
+  RocmlirTuningParamSetKindExhaustive = 3
 };
 typedef enum RocmlirTuningParamSetKind RocmlirTuningParamSetKind;
 
diff --git a/mlir/include/mlir/Dialect/Rock/Passes.h b/mlir/include/mlir/Dialect/Rock/Passes.h
index c0771efaae26..f30dc248c8c8 100644
--- a/mlir/include/mlir/Dialect/Rock/Passes.h
+++ b/mlir/include/mlir/Dialect/Rock/Passes.h
@@ -56,6 +56,7 @@ namespace rock {
 #define GEN_PASS_DECL_ROCKADDASYNCWAITPASS
 #define GEN_PASS_DECL_ROCKADDDIRECTTOLDSALIASINFOPASS
 #define GEN_PASS_DECL_CONVERTROCKOPSTOROCDLOPS
+#define GEN_PASS_DECL_ROCKADDSCHEDGROUPBARRIERSPASS
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Rock/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Rock/Passes.td b/mlir/include/mlir/Dialect/Rock/Passes.td
index 566d2de04b73..a099802f9a98 100644
--- a/mlir/include/mlir/Dialect/Rock/Passes.td
+++ b/mlir/include/mlir/Dialect/Rock/Passes.td
@@ -161,6 +161,26 @@ def RockBufferLoadMergePass : Pass<"rock-buffer-load-merge", "::mlir::func::Func
   let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect"];
 }
 
+def RockAddSchedGroupBarriersPass : Pass<"rock-add-sched-group-barriers", "::mlir::func::FuncOp"> {
+  let summary = "Analyze scf.for loops and insert scheduling group barriers";
+  let description = [{
+    This pass analyzes scf.for operations, counts the following operations per
+    iteration, and inserts scheduling group barriers based on those counts:
+    - Global memory loads (amdgpu.raw_buffer_load, vector.load from global memory)
+    - LDS/workgroup memory reads (memref.load from workgroup address space)
+    - LDS/workgroup memory writes (memref.store to workgroup address space)
+    - MFMA instructions (amdgpu.mfma)
+
+    The counts factor in affine.for loop trip counts to give the total number of
+    operations per scf.for iteration. Based on these counts, scheduling group
+    barriers (ROCDL::SchedGroupBarrier) are inserted to optimize instruction
+    scheduling on AMD GPUs.
+  }];
+  let dependentDialects = ["::mlir::amdgpu::AMDGPUDialect", "::mlir::scf::SCFDialect",
+                           "::mlir::affine::AffineDialect", "::mlir::gpu::GPUDialect",
+                           "::mlir::ROCDL::ROCDLDialect"];
+}
+
 def RockTransformToMemrefPass : Pass<"rock-transform-to-memref", "::mlir::func::FuncOp"> {
   let summary = "convert remaining rock.transform ops to memref.expand/collapse_shape";
   let dependentDialects = ["rock::RockDialect", "affine::AffineDialect", "gpu::GPUDialect", "vector::VectorDialect", "memref::MemRefDialect"];
diff --git a/mlir/lib/CAPI/Dialect/Rock.cpp b/mlir/lib/CAPI/Dialect/Rock.cpp
index d7acb119bd67..1b7b280195d3 100644
--- a/mlir/lib/CAPI/Dialect/Rock.cpp
+++ b/mlir/lib/CAPI/Dialect/Rock.cpp
@@ -42,6 +42,9 @@ mlirRockTuningSpaceCreate(MlirModule module, RocmlirTuningParamSetKind kind) {
   case RocmlirTuningParamSetKindExhaustive:
     ourKind = rock::TuningParamSetKind::Exhaustive;
     break;
+  case RocmlirTuningParamSetKindGreedy:
+    ourKind = rock::TuningParamSetKind::Greedy;
+    break;
   }
   auto mod = unwrap(module);
   rock::TuningParamSpaceSettings settings;
diff --git a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
index a945ed55ca8e..002906a4df98 100644
--- a/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
+++ b/mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp
@@ -237,6 +237,7 @@ void rock::buildKernelPipeline(OpPassManager &pm,
     funcPm.addPass(
         math::createMathExtendToSupportedTypes(extendToLLVMTypesOptions));
     funcPm.addPass(rock::createRockBufferLoadMergePass());
+    funcPm.addPass(rock::createRockAddSchedGroupBarriersPass());
     funcPm.addPass(rock::createRockTransformToMemrefPass());
     funcPm.addPass(rock::createRockEmulateNarrowTypePass());
     funcPm.addPass(rock::createRockPack4BitGpuOpsTo8BitPass());
diff --git a/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp
new file mode 100644
index 000000000000..03290e5e529e
--- /dev/null
+++ b/mlir/lib/Dialect/Rock/Transforms/AddSchedGroupBarriers.cpp
@@ -0,0 +1,393 @@
+//===- AddSchedGroupBarriers.cpp - Add scheduling group barriers ----------===//
+//
+// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Copyright (c) 2025 Advanced Micro Devices Inc.
+//===----------------------------------------------------------------------===//
+//
+// This pass analyzes scf.for loops to count memory operations and MFMA
+// instructions per iteration, then inserts scheduling group barriers based
+// on the counts of:
+// - Global memory loads (amdgpu.raw_buffer_load, vector.load from global)
+// - LDS/workgroup memory reads (memref.load from workgroup address space)
+// - LDS/workgroup memory writes (memref.store to workgroup address space)
+// - MFMA instructions (amdgpu.mfma)
+//
+// The counts factor in affine.for loop trip counts.
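+// For example, for a loop body with G global loads, R LDS reads, W LDS writes
+// and M MFMAs per iteration (no double buffering), the pass emits, per global
+// load: one VMEM-read group, one DS-read group of size ceil(R/G), and
+// alternating DS-write / MFMA groups, with the whole loop body framed by
+// amdgpu.sched_barrier ops.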
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Rock/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "rock-add-sched-group-barriers"
+
+namespace mlir {
+namespace rock {
+#define GEN_PASS_DEF_ROCKADDSCHEDGROUPBARRIERSPASS
+#include "mlir/Dialect/Rock/Passes.h.inc"
+} // namespace rock
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::rock;
+
+namespace {
+
+/// Check if a memref type has workgroup (LDS) address space
+static bool hasWorkgroupAddressSpace(MemRefType memrefType) {
+  auto addrSpace = memrefType.getMemorySpace();
+  if (!addrSpace)
+    return false;
+
+  // Check for gpu.address_space
+  if (auto gpuAddrSpace = dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
+    return gpuAddrSpace.getValue() == gpu::AddressSpace::Workgroup;
+  }
+  return false;
+}
+
+/// Check if a memref type has global address space
+static bool hasGlobalAddressSpace(MemRefType memrefType) {
+  auto addrSpace = memrefType.getMemorySpace();
+  // No address space means global by default
+  if (!addrSpace)
+    return true;
+
+  // Check for gpu.address_space
+  if (auto gpuAddrSpace = dyn_cast<gpu::AddressSpaceAttr>(addrSpace)) {
+    return gpuAddrSpace.getValue() == gpu::AddressSpace::Global;
+  }
+  return false;
+}
+
+/// Check if a value is defined by an arith.select operation, which indicates
+/// double buffering (selecting between two different LDS buffers)
+static bool isDefinedBySelect(Value val) {
+  return val.getDefiningOp<arith::SelectOp>() != nullptr;
+}
+
+/// Get the trip count of an affine.for loop, returns 1 if unknown
+static uint64_t getAffineForTripCount(affine::AffineForOp affineFor) {
+  std::optional<uint64_t> tripCount = affine::getConstantTripCount(affineFor);
+  if (tripCount.has_value()) {
+    return tripCount.value();
+  }
+  // If we can't determine the trip count, return 1 (conservative estimate)
+  return 1;
+}
+
+/// Compute the multiplier for an operation based on enclosing affine.for loops
+/// within the scf.for boundary
+static uint64_t computeAffineLoopMultiplier(Operation *op,
+                                            scf::ForOp boundary) {
+  uint64_t multiplier = 1;
+  Operation *parent = op->getParentOp();
+
+  while (parent && parent != boundary.getOperation()) {
+    if (auto affineFor = dyn_cast<affine::AffineForOp>(parent)) {
+      multiplier *= getAffineForTripCount(affineFor);
+    }
+    parent = parent->getParentOp();
+  }
+
+  return multiplier;
+}
+
+struct ScfForAnalysisResult {
+  uint64_t globalLoads = 0;
+  uint64_t ldsReads = 0;
+  uint64_t ldsWrites = 0;
+  uint64_t matrixMultiplyOps = 0;
+  /// Direct loads from global memory to LDS (amdgpu.gather_to_lds)
+  uint64_t directLoadsToLDS = 0;
+  /// Indicates if the loop uses double buffering (LDS reads/writes use
+  /// arith.select to choose between two buffers)
+  bool isDoubleBuffered = false;
+};
+
+/// Analyze a single scf.for operation
+static ScfForAnalysisResult analyzeScfFor(scf::ForOp forOp) {
+  ScfForAnalysisResult result;
+
+  forOp.walk([&](Operation *op) {
+    uint64_t multiplier = computeAffineLoopMultiplier(op, forOp);
+
+    // Count amdgpu.raw_buffer_load (global loads)
+    if (isa<amdgpu::RawBufferLoadOp>(op)) {
+      result.globalLoads += multiplier;
+      return;
+    }
+
+    // Count vector.load from global memory or workgroup memory (LDS)
+    if (auto vectorLoad = dyn_cast<vector::LoadOp>(op)) {
+      if (auto memrefType =
+              dyn_cast<MemRefType>(vectorLoad.getBase().getType())) {
+        if (hasGlobalAddressSpace(memrefType)) {
+          result.globalLoads += multiplier;
+        } else if (hasWorkgroupAddressSpace(memrefType)) {
+          result.ldsReads += multiplier;
+          // Check for double buffering: if the memref is selected via
+          // arith.select, it indicates alternating between two LDS buffers
+          if (isDefinedBySelect(vectorLoad.getBase())) {
+            result.isDoubleBuffered = true;
+          }
+        }
+      }
+      return;
+    }
+
+    // Count vector.transfer_read from global memory or workgroup memory (LDS)
+    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
+      if (auto memrefType =
+              dyn_cast<MemRefType>(transferRead.getBase().getType())) {
+        if (hasGlobalAddressSpace(memrefType)) {
+          result.globalLoads += multiplier;
+        } else if (hasWorkgroupAddressSpace(memrefType)) {
+          result.ldsReads += multiplier;
+          // Check for double buffering: if the memref is selected via
+          // arith.select, it indicates alternating between two LDS buffers
+          if (isDefinedBySelect(transferRead.getBase())) {
+            result.isDoubleBuffered = true;
+          }
+        }
+      }
+      return;
+    }
+
+    // Count memref.load from workgroup memory (LDS reads)
+    if (auto memrefLoad = dyn_cast<memref::LoadOp>(op)) {
+      if (auto memrefType =
+              dyn_cast<MemRefType>(memrefLoad.getMemRef().getType())) {
+        if (hasWorkgroupAddressSpace(memrefType)) {
+          result.ldsReads += multiplier;
+          // Check for double buffering: if the memref is selected via
+          // arith.select, it indicates alternating between two LDS buffers
+          if (isDefinedBySelect(memrefLoad.getMemRef())) {
+            result.isDoubleBuffered = true;
+          }
+        }
+      }
+      return;
+    }
+
+    // Count memref.store to workgroup memory (LDS writes)
+    if (auto memrefStore = dyn_cast<memref::StoreOp>(op)) {
+      if (auto memrefType =
+              dyn_cast<MemRefType>(memrefStore.getMemRef().getType())) {
+        if (hasWorkgroupAddressSpace(memrefType)) {
+          result.ldsWrites += multiplier;
+          // Check for double buffering: if the memref is selected via
+          // arith.select, it indicates alternating between two LDS buffers
+          if (isDefinedBySelect(memrefStore.getMemRef())) {
+            result.isDoubleBuffered = true;
+          }
+        }
+      }
+      return;
+    }
+
+    // Count vector.transfer_write to workgroup memory (LDS writes)
+    if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
+      if (auto memrefType =
+              dyn_cast<MemRefType>(transferWrite.getBase().getType())) {
+        if (hasWorkgroupAddressSpace(memrefType)) {
+          result.ldsWrites += multiplier;
+          // Check for double buffering: if the memref is selected via
+          // arith.select, it indicates alternating between two LDS buffers
+          if (isDefinedBySelect(transferWrite.getBase())) {
+            result.isDoubleBuffered = true;
+          }
+        }
+      }
+      return;
+    }
+
+    // Count Matrix multiply operations
+    if (isa<amdgpu::MFMAOp>(op) || isa<amdgpu::WMMAOp>(op) ||
+        isa<amdgpu::ScaledMFMAOp>(op)) {
+      result.matrixMultiplyOps += multiplier;
+      return;
+    }
+
+    // Count direct loads from global memory to LDS (amdgpu.gather_to_lds)
+    if (isa<amdgpu::GatherToLDSOp>(op)) {
+      result.directLoadsToLDS += multiplier;
+      return;
+    }
+  });
+
+  return result;
+}
+
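+/// Instruction-kind masks passed to ROCDL::SchedGroupBarrier below. They
+/// follow the llvm.amdgcn.sched.group.barrier encoding:
+///   0x008 - MFMA / WMMA
+///   0x020 - VMEM read
+///   0x100 - DS (LDS) read
+///   0x200 - DS (LDS) write
+/// Each sched.group.barrier asks the scheduler to place `size` instructions
+/// of the given kind at that point in the loop body.
+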
+/// Rewrite pattern to insert scheduling group barriers in scf.for loops
+struct InsertSchedGroupBarrierPattern : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::ForOp op,
+                                PatternRewriter &rw) const override {
+    mlir::Region &region = op.getRegion();
+    Block &block = region.front();
+
+    // Check if SchedBarrierOp already exists (to avoid duplicates)
+    WalkResult result = block.walk([&](Operation *innerOp) {
+      if (isa<amdgpu::SchedBarrierOp>(innerOp)) {
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted())
+      return failure();
+
+    // Analyze the scf.for loop to get operation counts
+    ScfForAnalysisResult analysis = analyzeScfFor(op);
+    bool isDirectToLDS = analysis.directLoadsToLDS > 0;
+    if (isDirectToLDS) {
+      // Direct-to-LDS loads are not supported yet
+      LLVM_DEBUG(llvm::dbgs() << "Direct to LDS is not supported yet\n");
+      return failure();
+    }
+
+    // Skip if no meaningful operations found
+    if (analysis.globalLoads == 0 && analysis.matrixMultiplyOps == 0)
+      return failure();
+
+    // Print analysis results for debugging
+    LLVM_DEBUG({
+      llvm::dbgs() << "=== scf.for Analysis ===\n";
+      llvm::dbgs() << "Location: " << op.getLoc() << "\n";
+      llvm::dbgs() << "Global memory loads per iteration: "
+                   << analysis.globalLoads << "\n";
+      llvm::dbgs() << "Direct loads to LDS per iteration: "
+                   << analysis.directLoadsToLDS << "\n";
+      llvm::dbgs() << "LDS reads per iteration: " << analysis.ldsReads << "\n";
+      llvm::dbgs() << "LDS writes per iteration: " << analysis.ldsWrites
+                   << "\n";
+      llvm::dbgs() << "Matrix multiply operations per iteration: "
+                   << analysis.matrixMultiplyOps << "\n";
+      llvm::dbgs() << "Double buffering detected: "
+                   << (analysis.isDoubleBuffered ? "yes" : "no") << "\n";
+      llvm::dbgs() << "========================\n\n";
+    });
+
+    uint64_t numBufferLoads = analysis.globalLoads;
+    uint64_t numDSReads = analysis.ldsReads;
+    uint64_t numDSWrites = analysis.ldsWrites;
+    uint64_t numMatrixMultiplyOps = analysis.matrixMultiplyOps;
+
+    // Insert sched_barrier at the start of the block
+    rw.setInsertionPointToStart(&block);
+    amdgpu::SchedBarrierOp::create(
+        rw, op.getLoc(),
+        amdgpu::sched_barrier_opt_enumAttr::get(
+            rw.getContext(), amdgpu::sched_barrier_opt_enum::none));
+
+    // Insert sched group barriers before the terminator
+    auto *lastOp = block.getTerminator()->getPrevNode();
+    rw.setInsertionPointAfter(lastOp);
+
+    // Insert sched group barriers based on the analysis
+    if (numBufferLoads > 0 && numMatrixMultiplyOps > 0) {
+      for (uint64_t i = 0; i < numBufferLoads; i++) {
+        uint64_t dsReadsPerLoad = llvm::divideCeil(numDSReads, numBufferLoads);
+        uint64_t dsWritesPerLoad =
+            llvm::divideCeil(numDSWrites, numBufferLoads);
+        uint64_t matrixMultiplyPerLoad =
+            llvm::divideCeil(numMatrixMultiplyOps, numBufferLoads);
+        if (analysis.isDoubleBuffered) {
+          uint64_t dsWritesPerMFMA =
+              llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad);
+          while (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200,
+                                             dsWritesPerMFMA, 0); // DS Writes
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1,
+                                             0); // MFMA
+            matrixMultiplyPerLoad--;
+            dsWritesPerLoad -= dsWritesPerMFMA;
+          }
+          ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1,
+                                           0); // VMEM
+          if (matrixMultiplyPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008,
+                                             matrixMultiplyPerLoad,
+                                             0); // MFMA
+          }
+          if (dsReadsPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100,
+                                             dsReadsPerLoad,
+                                             0); // DS Reads
+          }
+        } else {
+          ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x020, 1,
+                                           0); // VMEM
+          if (dsReadsPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x100,
+                                             dsReadsPerLoad,
+                                             0); // DS Reads
+          }
+          uint64_t dsWritesPerMFMA =
+              llvm::divideCeil(dsWritesPerLoad, matrixMultiplyPerLoad);
+          while (dsWritesPerLoad > 0 && matrixMultiplyPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x200,
+                                             dsWritesPerMFMA,
+                                             0); // DS Writes
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008, 1,
+                                             0); // MFMA
+            matrixMultiplyPerLoad--;
+            dsWritesPerLoad -= dsWritesPerMFMA;
+          }
+          if (matrixMultiplyPerLoad > 0) {
+            ROCDL::SchedGroupBarrier::create(rw, op.getLoc(), 0x008,
+                                             matrixMultiplyPerLoad,
+                                             0); // MFMA
+          }
+        }
+      }
+    }
+
+    // Insert sched_barrier at the end
+    amdgpu::SchedBarrierOp::create(
+        rw, op.getLoc(),
+        amdgpu::sched_barrier_opt_enumAttr::get(
+            rw.getContext(), amdgpu::sched_barrier_opt_enum::none));
+
+    return success();
+  }
+};
+
+struct RockAddSchedGroupBarriersPass final
+    : rock::impl::RockAddSchedGroupBarriersPassBase<
+          RockAddSchedGroupBarriersPass> {
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *ctx = funcOp.getContext();
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<InsertSchedGroupBarrierPattern>(ctx);
+
+    if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+} // end namespace
diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
index 9bdec3f2318d..f4618c8d36af 100644
--- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
   AffixTuningParameters.cpp
   AlignTiling.cpp
   AnalyzeMemoryUse.cpp
+  AddSchedGroupBarriers.cpp
  BlockwiseGemmToThreadwise.cpp
   BufferLoadMerge.cpp
   BufferizableOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
index 156d35af4842..046d032ad78e 100644
--- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
+++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@@ -287,12 +287,13 @@ static LDSLayoutConfigDim getLDSLayoutConfigDim(Type elementType, int64_t kpack,
   LDSLayoutConfigDim cfg;
   int64_t maxVlen = 128 / elementType.getIntOrFloatBitWidth();
   int64_t copyDPerThread = vecDimInfo.inDPerThread;
+  int64_t copyKPerThread = vecDimInfo.inKPerThread;
   bool isKContiguousDim = vecDimInfo.vectorDim == GemmDimension::K;
   // If kpack is less than the hardware max vector length, and we are
   // writing more contiguous kpack elements, there is a possibility to
   // vectorize that we want to preserve (i.e., we favour vectorization over
   // bank conflicts resolution)
-  bool isPossibleToVectorizeD = (kpack < maxVlen && copyDPerThread > 1);
+  bool isPossibleToVectorizeD = (kpack < maxVlen && copyDPerThread > 1) && (copyKPerThread >= kpack);
   cfg.doRotateWithK = isKContiguousDim && !isPossibleToVectorizeD;
   cfg.doSwapThreadIterSubDims = !isKContiguousDim && !isPossibleToVectorizeD;
   cfg.ldsLayoutDxK = false;
diff --git a/mlir/utils/tuna/tuna-script.sh b/mlir/utils/tuna/tuna-script.sh
index cd966654a079..eb5111969f05 100755
--- a/mlir/utils/tuna/tuna-script.sh
+++ b/mlir/utils/tuna/tuna-script.sh
@@ -96,7 +96,7 @@ export TUNA_DIR=/tmp/MITuna
 export ROCMLIR_DIR=$(pwd)/..  # Assumes we're in the build directory
 export OUT_FILE=results.tsv
 export OP=convolution
-export TUNING_SPACE=full
+export TUNING_SPACE=greedy
 export LOAD_FACTOR=
 
 # -c configs