diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
index 26fb3b571bdc..45dbb20127c0 100644
--- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
+++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
@@ -11,6 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/BufferDependencyAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Rock/IR/AmdArchDb.h"
 #include "mlir/Dialect/Rock/IR/GetRockInfo.h"
 #include "mlir/Dialect/Rock/IR/Rock.h"
@@ -22,9 +24,12 @@
 #include "mlir/Dialect/Rock/Tuning/RockTuning.h"
 #include "mlir/Dialect/Rock/utility/fusionUtils.h"
 #include "mlir/Dialect/Rock/utility/loweringUtils.h"
+#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -742,6 +747,213 @@ extractLayouts(Operation *op, llvm::StringMap &fLayoutMap,
   return success();
 }
 
+/// Describes one rock.reduce whose input traces back to the output of the
+/// GEMM-like op a tuning problem key is being generated for.
+struct ReductionInfo {
+  ReduceMethod method;
+  int64_t axis;
+  int64_t rank;
+  int64_t stride; // Stride of the reduction dimension
+  bool hasPointwiseBefore;
+
+  /// Orders by method, then rank, then axis, then stride; on a full tie the
+  /// entry with hasPointwiseBefore set sorts first. Sorting with this keeps
+  /// the problem key independent of the order reductions are discovered in.
+  bool operator<(const ReductionInfo &other) const {
+    if (method != other.method)
+      return method < other.method;
+    if (rank != other.rank)
+      return rank < other.rank;
+    if (axis != other.axis)
+      return axis < other.axis;
+    if (stride != other.stride)
+      return stride < other.stride;
+    return hasPointwiseBefore > other.hasPointwiseBefore;
+  }
+
+  bool operator==(const ReductionInfo &other) const {
+    return method == other.method && axis == other.axis &&
+           rank == other.rank && stride == other.stride &&
+           hasPointwiseBefore == other.hasPointwiseBefore;
+  }
+};
+
+/// Fusion information that is appended to the tuning problem key.
+struct FusionInfo {
+  SmallVector<ReductionInfo> reductions;
+
+  bool hasReduction() const { return !reductions.empty(); }
+  int numReductionOutputs() const {
+    // Explicit cast: size() is unsigned and wider than the int returned.
+    return static_cast<int>(reductions.size());
+  }
+};
+
+/// Returns the value backing `v`: the result of the allocation located by
+/// rock::findMemrefAlloc, otherwise the block argument located by
+/// rock::findBlockArgument. Fails when neither lookup succeeds.
+static FailureOr<Value> getBaseValue(Value v) {
+  auto maybeAlloc = rock::findMemrefAlloc(v);
+  if (succeeded(maybeAlloc))
+    return maybeAlloc.value().getResult();
+
+  auto maybeBlockArg = rock::findBlockArgument(v);
+  if (succeeded(maybeBlockArg))
+    return maybeBlockArg.value();
+
+  return failure();
+}
+
+/// Traces backwards from `start` looking for `target`.
+///
+/// Returns false when `start` is itself backed by `target`, true when
+/// `target` is reached through at least one linalg.generic writer, and
+/// failure when `target` is not reached, no base value can be found, or
+/// `start` is already in `visited`.
+static FailureOr<bool> tracesToTarget(Value start, Value target,
+                                      const BufferDependencyAnalysis &deps,
+                                      DenseSet<Value> &visited) {
+  if (!visited.insert(start).second)
+    return failure(); // Avoid cycles
+
+  FailureOr<Value> baseValue = getBaseValue(start);
+  if (failed(baseValue))
+    return failure(); // Could not find base value
+
+  if (*baseValue == target)
+    return false; // Found target, no pointwise
+
+  // Writers are only looked up for allocations.
+  auto allocOp = baseValue->getDefiningOp<memref::AllocOp>();
+  if (!allocOp)
+    return failure();
+
+  auto writers = deps.getWriters(allocOp);
+  if (!writers)
+    return failure();
+
+  for (OpOperand *writerOperand : *writers) {
+    auto genericOp = dyn_cast<linalg::GenericOp>(writerOperand->getOwner());
+    if (!genericOp)
+      continue;
+
+    // Trace through inputs of the linalg.generic (assumed to be pointwise;
+    // its iterator types are not inspected here).
+    for (Value input : genericOp.getInputs()) {
+      if (succeeded(tracesToTarget(input, target, deps, visited)))
+        return true; // Found target through pointwise
+    }
+  }
+
+  return failure();
+}
+
+/// Finds every rock.reduce in the function enclosing `gemmResult` whose input
+/// traces back to the buffer backing `gemmResult`.
+///
+/// `features` is not used; it is kept so existing callers compile unchanged.
+/// The result is empty when the backing buffer or the enclosing function
+/// cannot be determined. Reductions are returned sorted.
+static FusionInfo getFusionInfo(Value gemmResult,
+                                [[maybe_unused]] GemmFeatures features) {
+  FusionInfo info;
+
+  // Find the target (allocation or block argument)
+  FailureOr<Value> maybeTarget = getBaseValue(gemmResult);
+  if (failed(maybeTarget))
+    return info; // None found
+
+  Value target = *maybeTarget;
+
+  // Get the parent function. A `gemmResult` without a defining op yields no
+  // fusion info because the function is looked up through that op.
+  auto defOp = gemmResult.getDefiningOp();
+  auto funcOp = defOp ? rock::getParentFuncOp(defOp) : nullptr;
+  if (!funcOp)
+    return info;
+
+  // Walk all reduce operations and check if they trace back to our GEMM.
+  // Note, we are assuming that all reduce operations are returned here.
+  BufferDependencyAnalysis deps(funcOp);
+  funcOp->walk([&](rock::ReduceOp reduceOp) {
+    DenseSet<Value> visited;
+    FailureOr<bool> maybeHasPointwise =
+        tracesToTarget(reduceOp.getIn(), target, deps, visited);
+    if (failed(maybeHasPointwise))
+      return;
+
+    ReductionInfo redInfo;
+    redInfo.method = reduceOp.getReduceMethod();
+    redInfo.axis = reduceOp.getAxis().getSExtValue();
+    auto memrefType = cast<MemRefType>(reduceOp.getIn().getType());
+    redInfo.rank = memrefType.getRank();
+
+    // Extract the stride of the reduction dimension. The dynamic sentinel is
+    // used when the strides cannot be determined or the axis does not index
+    // them, so `strides` is never read out of bounds.
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    redInfo.stride = ShapedType::kDynamic;
+    if (succeeded(memrefType.getStridesAndOffset(strides, offset)) &&
+        redInfo.axis >= 0 &&
+        redInfo.axis < static_cast<int64_t>(strides.size()))
+      redInfo.stride = strides[redInfo.axis];
+
+    redInfo.hasPointwiseBefore = *maybeHasPointwise;
+    info.reductions.push_back(redInfo);
+  });
+
+  // Sort reductions for consistent ordering in problem key
+  llvm::sort(info.reductions);
+
+  return info;
+}
+
+/// Appends the fusion part of the problem key to `problemOS`.
+///
+/// Writes nothing when there are no reductions. Otherwise writes
+/// " -fusion_reduce count=<N>" followed by one space-separated entry per
+/// reduction: <method>:rank<R>:axis<A>:stride<S>[:hasPointwise], where <S>
+/// is '?' for a dynamic stride.
+static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS,
+                                   const FusionInfo &fusionInfo) {
+  constexpr char sep = ' ';
+
+  if (!fusionInfo.hasReduction())
+    return;
+
+  problemOS << sep << "-fusion_reduce" << sep
+            << "count=" << fusionInfo.numReductionOutputs();
+
+  for (const ReductionInfo &reduction : fusionInfo.reductions) {
+    problemOS << sep;
+
+    // Add reduction method
+    switch (reduction.method) {
+    case ReduceMethod::Sum:
+      problemOS << "sum";
+      break;
+    case ReduceMethod::Max:
+      problemOS << "max";
+      break;
+    }
+
+    // Add rank, axis, and stride with colon separators
+    problemOS << ":rank" << reduction.rank;
+    problemOS << ":axis" << reduction.axis;
+
+    // Add stride (use '?' for dynamic/unknown strides)
+    if (reduction.stride == ShapedType::kDynamic)
+      problemOS << ":stride?";
+    else
+      problemOS << ":stride" << reduction.stride;
+
+    // Add pointwise flag for this specific reduction
+    if (reduction.hasPointwiseBefore)
+      problemOS << ":hasPointwise";
+  }
+}
+
 static LogicalResult
 getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
                     SmallVectorImpl<char> &out) {
@@ -917,6 +1129,13 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
     problemOS << "-k " << headDimQK << sep;
     problemOS << "-gemmO " << headDimV;
   }
+
+  // Analyze and append fusion information
+  Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get();
+  GemmFeatures features = rock::getFeatures(gemmGemmOp);
+  FusionInfo fusionInfo = getFusionInfo(gemmGemmOutput, features);
+  appendOutputFusionInfo(problemOS, fusionInfo);
+
   return success();
 }
 
@@ -1134,6 +1353,12 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF,
     return failure();
   }
 
+  // Analyze and append fusion information
+  Value gemmOutput = gemmIF.getOutArgument()->get();
+  GemmFeatures features = rock::getFeatures(gemmIF);
+  FusionInfo fusionInfo = getFusionInfo(gemmOutput, features);
+  appendOutputFusionInfo(problemOS, fusionInfo);
+
   while (out.back() == sep) {
     // remove trailing whitespace
     out.pop_back();
diff --git a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir
new file mode 100644
index 
000000000000..24fbfde2c223 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir @@ -0,0 +1,88 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1:hasPointwise + +#map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 128 + d2, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 256 + d1, d2, d3, d4)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 256 + d2, d3, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d0, d1, d4)> +#map7 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 8 + d2, d3, d4)> +#map8 = affine_map<(d0, d1, d2, d3, d4) -> (d1 * 8 + d2)> +#map9 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, 0, 0)> +#map10 = affine_map<(d0, d1, d2, d3) -> (0, d0, d1, d2, d3)> +#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map12 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)> +#map13 = affine_map<(d0) -> (0, d0 floordiv 8192, (d0 mod 8192) floordiv 1024, (d0 mod 1024) floordiv 32, d0 mod 32)> +#map14 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 1024, (d2 mod 1024) floordiv 32, d2 mod 32)> +#map15 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>] bounds = [256, 128, 3, 3] -> [294912]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 32, 32] -> [131072]> +#transform_map2 = #rock.transform_map<#map2 by [ ["n", "h", "w"] at [0, 2, 3]>, ["c"] at [1]>] bounds = [1, 1, 128, 32, 32] -> [1, 128, 32, 32]> +#transform_map3 = #rock.transform_map<#map3 by [ ["c", "y", "x"] at [1, 
2, 3]>, ["k"] at [0]>] bounds = [1, 256, 128, 3, 3] -> [256, 128, 3, 3]> +#transform_map4 = #rock.transform_map<#map4 by [ ["n", "h", "w"] at [0, 2, 3]>, ["k"] at [1]>] bounds = [1, 1, 256, 32, 32] -> [1, 256, 32, 32]> +#transform_map5 = #rock.transform_map<#map5 by [ ["dim1", "dim2", "dim3", "dim0", "dim4"] at [1, 2, 3, 0, 4]>] bounds = [256, 128, 3, 1, 3] -> [1, 256, 128, 3, 3]> +#transform_map6 = #rock.transform_map<#map6 by [ ["dim2", "dim3", "dim0", "dim1", "dim4"] at [2, 3, 0, 1, 4]>] bounds = [128, 32, 1, 1, 32] -> [1, 1, 128, 32, 32]> +#transform_map7 = #rock.transform_map<#map7 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>] bounds = [1, 32, 8, 32, 32] -> [1, 256, 32, 32]> +#transform_map8 = #rock.transform_map<#map8 by [ ["dim0"] at [0]>, [] at []>, [] at []>, [] at []>] bounds = [1, 32, 8, 1, 1] -> [256]> +#transform_map9 = #rock.transform_map<#map9 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>, ["dim4"] at [4]>] bounds = [1, 32, 8, 32, 32] -> [1, 32, 8, 1, 1]> +#transform_map10 = #rock.transform_map<#map10 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>, ["dim3"] at [4]>] bounds = [32, 8, 32, 32] -> [1, 32, 8, 32, 32]> +#transform_map11 = #rock.transform_map<#map12 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>, [] at []>] bounds = [1, 32, 8, 32, 32] -> [32, 8, 32, 32]> +#transform_map12 = #rock.transform_map<#map13 by [ ["col0", "col1", "col2", "col3", "col4"] at [0, 1, 2, 3, 4]>] bounds = [262144] -> [1, 32, 8, 32, 32]> +#transform_map13 = #rock.transform_map<#map14 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["col2", "col3", "col4"] at [2, 3, 4]>] bounds = [1, 32, 8192] -> [1, 32, 8, 32, 32]> +#transform_map14 = #rock.transform_map<#map15 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32] -> [1, 32, 1]> +module { + func.func 
@mlir_reshape_convolution_reshape_broadcast_add_mul_reshape_reduce_sum_reshape_mul_reshape_reduce_sum_reshape(%arg0: memref<131072xf16>, %arg1: memref<294912xf16>, %arg2: memref<256xf16>, %arg3: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg4: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg5: memref<262144xf16>) attributes {arch = "gfx950:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 256 : i64} { + %cst = arith.constant 1.220700e-04 : f16 + %0 = rock.transform %arg1 by #transform_map : memref<294912xf16> to memref<256x128x3x3xf16> + %1 = rock.transform %arg0 by #transform_map1 : memref<131072xf16> to memref<1x128x32x32xf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x256x32x32xf16> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x32x32xf16> to memref<1x1x128x32x32xf16> + %3 = rock.transform %0 by #transform_map3 : memref<256x128x3x3xf16> to memref<1x256x128x3x3xf16> + %4 = rock.transform %alloc by #transform_map4 : memref<1x256x32x32xf16> to memref<1x1x256x32x32xf16> + %5 = rock.transform %3 by #transform_map5 : memref<1x256x128x3x3xf16> to memref<256x128x3x1x3xf16> + %6 = rock.transform %2 by #transform_map6 : memref<1x1x128x32x32xf16> to memref<128x32x1x1x32xf16> + rock.conv(%3, %2, %4) {dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], output_layout = ["no", "go", "ko", "ho", "wo"], padding = [1 : index, 1 : index, 1 : index, 1 : index], strides = [1 : index, 1 : index]} : memref<1x256x128x3x3xf16>, memref<1x1x128x32x32xf16>, memref<1x1x256x32x32xf16> + %7 = rock.transform %alloc by #transform_map7 : memref<1x256x32x32xf16> to memref<1x32x8x32x32xf16> + %8 = rock.transform %arg2 by #transform_map8 : memref<256xf16> to memref<1x32x8x1x1xf16> + %9 = rock.transform %8 by #transform_map9 : memref<1x32x8x1x1xf16> to memref<1x32x8x32x32xf16> + %10 = rock.transform %7 by #transform_map10 : 
memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16> + %11 = rock.transform %9 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %11 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_0 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %in_5: f16, %out: f16): + %20 = arith.addf %in, %in_5 : f16 + linalg.yield %20 : f16 + } + %12 = rock.transform %alloc_0 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %13 = rock.transform %12 by #transform_map12 : memref<1x32x8x32x32xf16> to memref<262144xf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_0 : memref<32x8x32x32xf16>) outs(%alloc_1 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %out: f16): + %20 = arith.mulf %in, %cst : f16 + linalg.yield %20 : f16 + } + %14 = rock.transform %alloc_1 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %15 = rock.transform %14 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16> + rock.reduce sum %15 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16> + %16 = rock.transform %alloc_2 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_1, %alloc_0 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_3 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %in_5: f16, 
%out: f16): + %20 = arith.mulf %in, %in_5 : f16 + linalg.yield %20 : f16 + } + %17 = rock.transform %alloc_3 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %18 = rock.transform %17 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16> + rock.reduce sum %18 into %alloc_4 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16> + %19 = rock.transform %alloc_4 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16> + memref.copy %16, %arg3 : memref<32xf16> to memref<32xf16> + memref.copy %19, %arg4 : memref<32xf16> to memref<32xf16> + memref.copy %13, %arg5 : memref<262144xf16> to memref<262144xf16> + return + } +} diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir new file mode 100644 index 000000000000..42ccf00c9444 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -0,0 +1,50 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1) -> (0, d0, d1)> +#map4 = affine_map<(d0, d1) -> (d0, d1)> +#map5 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map6 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 256] -> [32768]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = 
#rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> +#transform_map9 = #rock.transform_map<#map6 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_add_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x128x256xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %2 * %1 storeMethod = set : 
memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %7 = rock.transform %alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %8 = rock.transform %0 by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : memref<128x256xf32>, memref<128x256xf32>) outs(%alloc_0 : memref<128x256xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.addf %in, %in_2 : f32 + linalg.yield %11 : f32 + } + %9 = rock.transform %alloc_0 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %9 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %10 = rock.transform %alloc_1 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %10, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir new file mode 100644 index 000000000000..cd6dad7d7e45 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -0,0 +1,64 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0, d1) -> (0, d0, d1)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +#map6 = affine_map<(d0, d1, d2) -> (d1, 
d2)> +#map7 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 128] -> [16384]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map9 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map10 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 128] -> [1, 128, 128]> +#transform_map11 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 128] -> [128, 128]> +#transform_map12 = #rock.transform_map<#map7 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_gemm_add_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<16384xf32> {mhal.read_access}, %arg4: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) 
attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg3 by #transform_map : memref<16384xf32> to memref<1x128x128xf32> + %1 = rock.transform %arg2 by #transform_map1 : memref<32768xf32> to memref<1x256x128xf32> + %2 = rock.transform %arg1 by #transform_map2 : memref<16384xf32> to memref<1x64x256xf32> + %3 = rock.transform %arg0 by #transform_map3 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %4 = rock.transform %3 by #transform_map4 : memref<1x128x64xf32> to memref<128x1x64xf32> + %5 = rock.transform %4 by #transform_map5 : memref<128x1x64xf32> to memref<1x128x64xf32> + %6 = rock.transform %2 by #transform_map6 : memref<1x64x256xf32> to memref<64x1x256xf32> + %7 = rock.transform %6 by #transform_map7 : memref<64x1x256xf32> to memref<1x64x256xf32> + %8 = rock.transform %1 by #transform_map8 : memref<1x256x128xf32> to memref<256x1x128xf32> + %9 = rock.transform %8 by #transform_map9 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %3 * %2 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg5: memref<1x128x256xf32>, %arg6: memref<1x128x256xf32>): + memref.copy %arg5, %arg6 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %1 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %10 = rock.transform %alloc by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> + %11 = rock.transform %0 by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32> + linalg.generic {indexing_maps = [#map5, #map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%10, %11 : memref<128x128xf32>, memref<128x128xf32>) outs(%alloc_0 : memref<128x128xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %14 = arith.addf %in, %in_2 : f32 + linalg.yield %14 : f32 
+ } + %12 = rock.transform %alloc_0 by #transform_map11 : memref<128x128xf32> to memref<1x128x128xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %12 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %13 = rock.transform %alloc_1 by #transform_map12 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %13, %arg4 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir new file mode 100644 index 000000000000..3b7f0115c309 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir @@ -0,0 +1,47 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0 floordiv 128, d0 mod 128)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = 
#rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [16384] -> [1, 128, 128]> +module { + func.func private @gemm_gemm_no_fusion(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<16384xf32> {mhal.write_access}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> 
memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %9 = rock.transform %alloc by #transform_map9 : memref<1x128x128xf32> to memref<16384xf32> + memref.copy %9, %arg3 : memref<16384xf32> to memref<16384xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir new file mode 100644 index 000000000000..d9ca8fcd9b0b --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -0,0 +1,49 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis1:stride1 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 
128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 1, 128]> +module { + func.func private @gemm_gemm_reduce_sum_axis1(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x128xf32> + rock.reduce 
sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x1x128xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x1x128xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir new file mode 100644 index 000000000000..c83e6814fac4 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -0,0 +1,49 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:stride1 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ 
["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_gemm_reduce_sum_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() 
{alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir new file mode 100644 index 000000000000..56da4965b50b --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -0,0 +1,49 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1) -> (0, d0, d1)> +#map4 = affine_map<(d0, d1) -> (d0, d1)> +#map5 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map6 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 256] -> [32768]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] 
at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> +#transform_map9 = #rock.transform_map<#map6 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_mul_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x128x256xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %2 * %1 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %7 = rock.transform %alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %8 = rock.transform %0 by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : 
memref<128x256xf32>, memref<128x256xf32>) outs(%alloc_0 : memref<128x256xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.mulf %in, %in_2 : f32 + linalg.yield %11 : f32 + } + %9 = rock.transform %alloc_0 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %9 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %10 = rock.transform %alloc_1 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %10, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir new file mode 100644 index 000000000000..93e0d793c8cd --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir @@ -0,0 +1,44 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis1:stride256 sum:rank3:axis2:stride1 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#map4 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by 
[ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [256] -> [1, 1, 256]> +module { + func.func private @gemm_multi_reduce_different_axes(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg3: memref<256xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + + // First reduction: sum on axis 1 (row reduction: 1x128x256 -> 1x1x256) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x256xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> + %6 = rock.transform %alloc_0 by #transform_map7 : memref<1x1x256xf32> to memref<256xf32> + + // Second 
reduction: sum on axis 2 (column reduction: 1x128x256 -> 1x128x1) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %7 = rock.transform %alloc_1 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + + memref.copy %7, %arg2 : memref<128xf32> to memref<128xf32> + memref.copy %6, %arg3 : memref<256xf32> to memref<256xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir new file mode 100644 index 000000000000..224a73582433 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -0,0 +1,51 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#map4 = affine_map<(d0, d1) -> (0, d0, d1)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +#map6 = affine_map<(d0, d1, d2) -> (d1, d2)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] 
-> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> +module { + func.func private @gemm_multi_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + %7 = rock.transform 
%alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%7 : memref<128x256xf32>) outs(%alloc_1 : memref<128x256xf32>) { + ^bb0(%in: f32, %out: f32): + %cst = arith.constant 2.0 : f32 + %10 = arith.mulf %in, %cst : f32 + linalg.yield %10 : f32 + } + %8 = rock.transform %alloc_1 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %8 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_2 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<128xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir new file mode 100644 index 000000000000..811c09132800 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir @@ -0,0 +1,31 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0 floordiv 256, d0 mod 256)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> 
+#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32768] -> [1, 128, 256]> +module { + func.func private @gemm_no_fusion(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.write_access}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %6 = rock.transform %alloc by #transform_map6 : memref<1x128x256xf32> to memref<32768xf32> + memref.copy %6, %arg2 : memref<32768xf32> to memref<32768xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir new file mode 100644 index 000000000000..939e4c4bb1c0 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -0,0 +1,37 @@ +// RUN: rocmlir-gen 
--emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0 floordiv 256, d0 mod 256)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32768] -> [1, 128, 256]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_passthrough_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.write_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : 
i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %6 = rock.transform %alloc by #transform_map6 : memref<1x128x256xf32> to memref<32768xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %7 = rock.transform %alloc_0 by #transform_map7 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<32768xf32> to memref<32768xf32> + memref.copy %7, %arg3 : memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir new file mode 100644 index 000000000000..29e633090481 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -0,0 +1,33 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:rank3:axis2:stride1 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = 
#rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_reduce_max_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0xFF800000 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce max %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : 
memref<128xf32> to memref<128xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir new file mode 100644 index 000000000000..b6cc84d4c4e8 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -0,0 +1,33 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis1:stride256 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [256] -> [1, 1, 256]> +module { + func.func private @gemm_reduce_sum_axis1(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<256xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + 
%1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x256xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x1x256xf32> to memref<256xf32> + memref.copy %6, %arg2 : memref<256xf32> to memref<256xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir new file mode 100644 index 000000000000..7a3d54595a59 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -0,0 +1,33 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 
by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +module { + func.func private @gemm_reduce_sum_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<128xf32> to 
memref<128xf32> + return + } +} + diff --git a/mlir/utils/performance/perfRunner.py b/mlir/utils/performance/perfRunner.py index 42d59b4d4b07..4f7305e31871 100644 --- a/mlir/utils/performance/perfRunner.py +++ b/mlir/utils/performance/perfRunner.py @@ -607,6 +607,10 @@ def to_command_line(self): f"-y {self.y} -x {self.x} -p {self.padding_h} -q {self.padding_w} " + f"-u {self.conv_stride_h} -v {self.conv_stride_w} -l {self.dilation_h} " + f"-j {self.dilation_w} -m conv -g {self.group} -t 1") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) def __init__(self, dtype: str, direction: str, filter_layout: str, input_layout: str, output_layout: str, n: int, c: int, hi: int, wi: int, k: int, y: int, x: int, @@ -975,6 +979,11 @@ def from_command_line(cls, argv, arch, num_cu): scale_b_dtype = None trans_scale_a = False trans_scale_b = False + + # Store the original command line for accurate tuning DB lookups + # (including fusion info which we don't parse but need for cache key) + original_command_line = ' '.join(argv) + i = 0 while i < len(argv): opt = argv[i] @@ -983,6 +992,9 @@ def from_command_line(cls, argv, arch, num_cu): scaled_gemm = True i += 1 continue + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break # Handle flags with values if i + 1 >= len(argv): raise ValueError(f"Missing value for argument {opt}") @@ -1020,8 +1032,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete GEMM configuration") - return cls(dtype, out_dtype, g, m, k, n, trans_a, trans_b, scaled_gemm, scale_a_dtype, - scale_b_dtype, trans_scale_a, trans_scale_b, arch, num_cu, perf_config) + config = cls(dtype, out_dtype, g, m, k, n, trans_a, trans_b, scaled_gemm, scale_a_dtype, + scale_b_dtype, trans_scale_a, trans_scale_b, arch, num_cu, perf_config) + # Store the full 
original command line for tuning DB lookups + config._original_command_line = original_command_line + return config def to_command_line(self): result = (f"-t {self.datatype} -out_datatype {self.out_dtype} " + @@ -1038,6 +1053,10 @@ def to_command_line(self): if self.trans_scale_b: result += f" -transScaleB {str(self.trans_scale_b).lower()}" return result + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) def __init__(self, dtype: str, @@ -1225,9 +1244,16 @@ def from_command_line(cls, argv, arch, num_cu): input_layout = None trans_c = False trans_o = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1280,9 +1306,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete conv+gemm configuration") - return cls(dtype, filter_layout, input_layout, trans_c, trans_o, n, c, hi, wi, k, y, x, o, - conv_stride_h, conv_stride_w, padding_h, padding_w, dilation_h, dilation_w, - group, arch, num_cu, perf_config) + config = cls(dtype, filter_layout, input_layout, trans_c, trans_o, n, c, hi, wi, k, y, x, o, + conv_stride_h, conv_stride_w, padding_h, padding_w, dilation_h, dilation_w, + group, arch, num_cu, perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return (f"-t {self.datatype} " + @@ -1292,6 +1320,10 @@ def to_command_line(self): f"-y {self.y} -x {self.x} -p {self.padding_h} -q {self.padding_w} " + f"-u {self.conv_stride_h} -v {self.conv_stride_w} -l {self.dilation_h} " + f"-j {self.dilation_w} -g {self.group}" 
+ f"-gemmO {str(self.o)}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class GemmGemmConfiguration(PerfConfiguration): @@ -1385,9 +1417,16 @@ def from_command_line(cls, argv, arch, num_cu): trans_b = False trans_c = False trans_o = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1417,8 +1456,10 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete gemm+gemm configuration") - return cls(dtype, g, m, k, n, o, trans_a, trans_b, trans_c, trans_o, arch, num_cu, - perf_config) + config = cls(dtype, g, m, k, n, o, trans_a, trans_b, trans_c, trans_o, arch, num_cu, + perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return (f"-t {self.datatype} " + @@ -1426,6 +1467,10 @@ def to_command_line(self): f"-transC {str(self.trans_c).lower()} -transO {str(self.trans_o).lower()} " + f"-g {self.g} " + f"-m {str(self.m)} -k {str(self.k)} -n {str(self.n)} -gemmO {str(self.o)}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class AttentionConfiguration(PerfConfiguration): @@ -1565,9 +1610,16 @@ def from_command_line(cls, argv, arch, num_cu): split_kv = 1 with_attn_scale = False with_attn_bias = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() 
for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1615,9 +1667,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete Attention configuration") - return cls(dtype, g, seq_len_q, seq_len_k, num_heads_q, num_heads_kv, head_dim_qk, - head_dim_v, with_attn_scale, with_attn_bias, trans_q, trans_k, trans_v, trans_o, - causal, return_lse, split_kv, arch, num_cu, perf_config) + config = cls(dtype, g, seq_len_q, seq_len_k, num_heads_q, num_heads_kv, head_dim_qk, + head_dim_v, with_attn_scale, with_attn_bias, trans_q, trans_k, trans_v, trans_o, + causal, return_lse, split_kv, arch, num_cu, perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return ( @@ -1630,6 +1684,10 @@ def to_command_line(self): f"-seq_len_q {str(self.seq_len_q)} -seq_len_k {str(self.seq_len_k)} -num_heads_q {str(self.num_heads_q)} -num_heads_kv {str(self.num_heads_kv)} -head_dim_qk {str(self.head_dim_qk)} -head_dim_v {str(self.head_dim_v)} " + f"-with-attn-scale {str(self.with_attn_scale).lower()} " + f"-with-attn-bias {str(self.with_attn_bias).lower()}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class RocBLASGemmConfig(GemmConfiguration): @@ -1748,10 +1806,11 @@ def benchmark_mlir(commandline, rocmlir_gen_flags, use_rocprof=False): config = conf_class.from_command_line(commandline, arch, num_cu) - config_str = config.to_command_line() + # Use to_tuning_key() which includes fusion info for accurate DB lookups + config_key = config.to_tuning_key() if hasattr(config, 'to_tuning_key') else config.to_command_line() if tuning_db: - if (arch, config_str) in tuning_db: - config.set_perfconfig(tuning_db[arch, 
config_str]) + if (arch, config_key) in tuning_db: + config.set_perfconfig(tuning_db[arch, config_key]) else: # Tuning DB present but doesn't contain config, return N/A return config.table_entry(np.nan) @@ -2071,9 +2130,10 @@ def benchmark_fusion_kernels(test_dir, # Find the best perf_config best_perf = "" if tuning_db: - config_str = config.to_command_line() - if (arch, config_str) in tuning_db: - best_perf = tuning_db[arch, config_str] + # Use to_tuning_key() which includes fusion info for accurate DB lookups + config_key = config.to_tuning_key() if hasattr(config, 'to_tuning_key') else config.to_command_line() + if (arch, config_key) in tuning_db: + best_perf = tuning_db[arch, config_key] config.set_perfconfig(best_perf) else: # Tuning DB present but doesn't contain config, add a NaN entry if test_vector not in perf_results: