Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/BufferDependencyAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
Expand All @@ -22,9 +24,12 @@
#include "mlir/Dialect/Rock/Tuning/RockTuning.h"
#include "mlir/Dialect/Rock/utility/fusionUtils.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -742,6 +747,197 @@ extractLayouts(Operation *op, llvm::StringMap<unsigned> &fLayoutMap,
return success();
}

// Structure to hold information about a single reduction operation
struct ReductionInfo {
ReduceMethod method;
int64_t axis;
int64_t rank;
int64_t stride; // Stride of the reduction dimension
bool hasPointwiseBefore;

bool operator<(const ReductionInfo &other) const {
// Sort by method first, then rank, then axis, then stride, then
// hasPointwiseBefore
if (method != other.method)
return method < other.method;
if (rank != other.rank)
return rank < other.rank;
if (axis != other.axis)
return axis < other.axis;
if (stride != other.stride)
return stride < other.stride;
return hasPointwiseBefore > other.hasPointwiseBefore;
}

bool operator==(const ReductionInfo &other) const {
return method == other.method && axis == other.axis && rank == other.rank &&
stride == other.stride &&
hasPointwiseBefore == other.hasPointwiseBefore;
}
};

// Structure to hold fusion information for problem key generation
struct FusionInfo {
SmallVector<ReductionInfo> reductions;

bool hasReduction() const { return !reductions.empty(); }
int numReductionOutputs() const { return reductions.size(); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
int numReductionOutputs() const { return reductions.size(); }
int numReductions() const { return reductions.size(); }

numReductionOutputs is ambigous. It could mean how many reduction operations are being returned from the function or how many outputs each reduction op has (if there is such an op with multiple outputs).

};

// Helper to get the base value (allocation or block argument) from a value
static FailureOr<Value> getBaseValue(Value v) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be moved to loweringUtils.h

FailureOr<memref::AllocOp> maybeAlloc = rock::findMemrefAlloc(v);
if (succeeded(maybeAlloc)) {
return maybeAlloc.value().getResult();
}

FailureOr<BlockArgument> maybeBlockArg = rock::findBlockArgument(v);
if (succeeded(maybeBlockArg)) {
return maybeBlockArg.value();
}

return failure();
}

// Helper to trace backwards from a value to see if it reaches the target
// Returns success(hasPointwise) if target is reached, failure otherwise
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Returns success(hasPointwise) if target is reached, failure otherwise
// Returns success if target is reached via traversing through pointwise, failure otherwise

static FailureOr<bool> tracesToTarget(Value start, Value target,
const BufferDependencyAnalysis &deps,
DenseSet<Value> &visited) {
if (!visited.insert(start).second) {
return failure(); // Avoid cycles
}

FailureOr<Value> baseValue = getBaseValue(start);
if (failed(baseValue))
return failure(); // Could not find base value

if (*baseValue == target) {
return false; // Found target, no pointwise
}

// For allocations, use BufferDependencyAnalysis to find writers
if (auto allocOp = baseValue->getDefiningOp<memref::AllocOp>()) {
std::optional<SmallVector<OpOperand *>> writers = deps.getWriters(allocOp);
if (writers) {
for (OpOperand *writerOperand : *writers) {
auto genericOp = dyn_cast<linalg::GenericOp>(writerOperand->getOwner());
if (!genericOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we "continue" it if it's not genericOp?
We might be also to find another genericOp in the next one even if this one is not a genericOp, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, we could have:

out = gemm(...)
b = linalg(out)
a = non_linalg(b)
return a

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding of .getWriters() from BufferDependencyAnalysis is that it will do just that. It already has logic that traces through so that any rock.transforms will be skipped. Are there other non-linalg ops that we care about? Wouldn't that mean that it's a non-fusible operation and should be skipped?

Copy link
Contributor

@dhernandez0 dhernandez0 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it does that, yes, but then if the op is not a linalg.generic you just skip it here, why do we want to do that?

continue;
}

// Trace through inputs of the linalg.generic (assumed to be pointwise)
for (Value input : genericOp.getInputs()) {
FailureOr<bool> maybeHasPointwise =
tracesToTarget(input, target, deps, visited);
if (succeeded(maybeHasPointwise)) {
return true; // Found target through pointwise
}
}
}
}
}

return failure();
}

// Find all reductions and check if they trace back to our GEMM output
static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) {
FusionInfo info;

// Find the target (allocation or block argument)
FailureOr<Value> maybeTarget = getBaseValue(gemmResult);
if (failed(maybeTarget))
return info; // None found

Value target = *maybeTarget;

// Get the parent function
auto defOp = gemmResult.getDefiningOp();
auto funcOp = defOp ? rock::getParentFuncOp(defOp) : nullptr;
if (!funcOp) {
return info;
}

// Walk all reduce operations and check if they trace back to our GEMM.
// Note, we are assuming that all reduce operations are returned here.
BufferDependencyAnalysis deps(funcOp);
funcOp->walk([&](rock::ReduceOp reduceOp) {
DenseSet<Value> visited;
FailureOr<bool> maybeHasPointwise =
tracesToTarget(reduceOp.getIn(), target, deps, visited);

if (succeeded(maybeHasPointwise)) {
ReductionInfo redInfo;
redInfo.method = reduceOp.getReduceMethod();
redInfo.axis = reduceOp.getAxis().getSExtValue();
auto memrefType = cast<MemRefType>(reduceOp.getIn().getType());
redInfo.rank = memrefType.getRank();

// Extract stride for the reduction dimension
SmallVector<int64_t> strides;
int64_t offset;
if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Is there a test with stride != 1 ?
I think the way we use memref strides are always 1. Layout information is encoded by series of rock.transforms. So i don't think there is an easy way to get that information out.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the stride from memref is the real stride in rocmlir. We can have transforms that do padding etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, as @pfultz2 suggested, we might want to add a list of axes from the input tensor shape that get reduced, instead of the output axis.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my question as well, is if we are overcomplicating things, maybe to know we are doing a reduction is enough, does it help tuning in any way to know the axis?

redInfo.stride = strides[redInfo.axis];
} else {
// If we can't determine stride, use dynamic sentinel
redInfo.stride = ShapedType::kDynamic;
}

redInfo.hasPointwiseBefore = *maybeHasPointwise;
info.reductions.push_back(redInfo);
}
});

// Sort reductions for consistent ordering in problem key
std::sort(info.reductions.begin(), info.reductions.end());

return info;
}

// Append fusion information to the problem key string
static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS,
const FusionInfo &fusionInfo) {
constexpr char sep = ' ';

if (!fusionInfo.hasReduction())
return;

problemOS << sep << "-fusion_reduce" << sep
<< "count=" << fusionInfo.numReductionOutputs();

// Encode each reduction in format: method:rank:axis:stride[:hasPointwise]
for (const auto &reduction : fusionInfo.reductions) {
problemOS << sep;

// Add reduction method
switch (reduction.method) {
case ReduceMethod::Sum:
problemOS << "sum";
break;
case ReduceMethod::Max:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can there be some other kind of reduce operations ? Something other than Sum/Max ?

i suggest adding llvm_unreachable as default case.

problemOS << "max";
break;
}

// Add rank, axis, and stride with colon separators
problemOS << ":rank" << reduction.rank;
problemOS << ":axis" << reduction.axis;

// Add stride (use '?' for dynamic/unknown strides)
if (reduction.stride == ShapedType::kDynamic) {
problemOS << ":stride?";
} else {
problemOS << ":stride" << reduction.stride;
}

// Add pointwise flag for this specific reduction
if (reduction.hasPointwiseBefore) {
problemOS << ":hasPointwise";
}
}
}

static LogicalResult
getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
SmallVectorImpl<char> &out) {
Expand Down Expand Up @@ -917,6 +1113,13 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
problemOS << "-k " << headDimQK << sep;
problemOS << "-gemmO " << headDimV;
}

// Analyze and append fusion information
Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get();
GemmFeatures features = rock::getFeatures(gemmGemmOp);
FusionInfo fusionInfo = getFusionInfo(gemmGemmOutput, features);
appendOutputFusionInfo(problemOS, fusionInfo);

return success();
}

Expand Down Expand Up @@ -1134,6 +1337,12 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF,
return failure();
}

// Analyze and append fusion information
Value gemmOutput = gemmIF.getOutArgument()->get();
GemmFeatures features = rock::getFeatures(gemmIF);
FusionInfo fusionInfo = getFusionInfo(gemmOutput, features);
appendOutputFusionInfo(problemOS, fusionInfo);

while (out.back() == sep) {
// remove trailing whitespace
out.pop_back();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s

// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1:hasPointwise

#map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 128 + d2, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 256 + d1, d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 256 + d2, d3, d4)>
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0, d1, d2, d4)>
#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d0, d1, d4)>
#map7 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 8 + d2, d3, d4)>
#map8 = affine_map<(d0, d1, d2, d3, d4) -> (d1 * 8 + d2)>
#map9 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, 0, 0)>
#map10 = affine_map<(d0, d1, d2, d3) -> (0, d0, d1, d2, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map12 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
#map13 = affine_map<(d0) -> (0, d0 floordiv 8192, (d0 mod 8192) floordiv 1024, (d0 mod 1024) floordiv 32, d0 mod 32)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 1024, (d2 mod 1024) floordiv 32, d2 mod 32)>
#map15 = affine_map<(d0) -> (0, d0, 0)>
#transform_map = #rock.transform_map<#map by [<Unmerge{256, 128, 3, 3} ["exp0", "exp1", "exp2", "exp3"] at [0, 1, 2, 3] -> ["dim0"] at [0]>] bounds = [256, 128, 3, 3] -> [294912]>
#transform_map1 = #rock.transform_map<#map1 by [<Unmerge{128, 32, 32} ["exp1", "exp2", "exp3"] at [1, 2, 3] -> ["dim0"] at [0]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 128, 32, 32] -> [131072]>
#transform_map2 = #rock.transform_map<#map2 by [<PassThrough ["n", "h", "w"] at [0, 3, 4] -> ["n", "h", "w"] at [0, 2, 3]>, <Unmerge{1, 128} ["g", "c"] at [1, 2] -> ["c"] at [1]>] bounds = [1, 1, 128, 32, 32] -> [1, 128, 32, 32]>
#transform_map3 = #rock.transform_map<#map3 by [<PassThrough ["c", "y", "x"] at [2, 3, 4] -> ["c", "y", "x"] at [1, 2, 3]>, <Unmerge{1, 256} ["g", "k"] at [0, 1] -> ["k"] at [0]>] bounds = [1, 256, 128, 3, 3] -> [256, 128, 3, 3]>
#transform_map4 = #rock.transform_map<#map4 by [<PassThrough ["n", "h", "w"] at [0, 3, 4] -> ["n", "h", "w"] at [0, 2, 3]>, <Unmerge{1, 256} ["g", "k"] at [1, 2] -> ["k"] at [1]>] bounds = [1, 1, 256, 32, 32] -> [1, 256, 32, 32]>
#transform_map5 = #rock.transform_map<#map5 by [<PassThrough ["dim1", "dim2", "dim3", "dim0", "dim4"] at [0, 1, 2, 3, 4] -> ["dim1", "dim2", "dim3", "dim0", "dim4"] at [1, 2, 3, 0, 4]>] bounds = [256, 128, 3, 1, 3] -> [1, 256, 128, 3, 3]>
#transform_map6 = #rock.transform_map<#map6 by [<PassThrough ["dim2", "dim3", "dim0", "dim1", "dim4"] at [0, 1, 2, 3, 4] -> ["dim2", "dim3", "dim0", "dim1", "dim4"] at [2, 3, 0, 1, 4]>] bounds = [128, 32, 1, 1, 32] -> [1, 1, 128, 32, 32]>
#transform_map7 = #rock.transform_map<#map7 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <Unmerge{32, 8} ["exp1", "exp2"] at [1, 2] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [3] -> ["dim2"] at [2]>, <PassThrough ["dim3"] at [4] -> ["dim3"] at [3]>] bounds = [1, 32, 8, 32, 32] -> [1, 256, 32, 32]>
#transform_map8 = #rock.transform_map<#map8 by [<Unmerge{32, 8} ["exp1", "exp2"] at [1, 2] -> ["dim0"] at [0]>, <AddDim{1} ["unit0"] at [0] -> [] at []>, <AddDim{1} ["unit3"] at [3] -> [] at []>, <AddDim{1} ["unit4"] at [4] -> [] at []>] bounds = [1, 32, 8, 1, 1] -> [256]>
#transform_map9 = #rock.transform_map<#map9 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [2] -> ["dim2"] at [2]>, <Broadcast{1} ["dim3"] at [3] -> ["dim3"] at [3]>, <Broadcast{1} ["dim4"] at [4] -> ["dim4"] at [4]>] bounds = [1, 32, 8, 32, 32] -> [1, 32, 8, 1, 1]>
#transform_map10 = #rock.transform_map<#map10 by [<Merge{1, 32} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [2]>, <PassThrough ["dim2"] at [2] -> ["dim2"] at [3]>, <PassThrough ["dim3"] at [3] -> ["dim3"] at [4]>] bounds = [32, 8, 32, 32] -> [1, 32, 8, 32, 32]>
#transform_map11 = #rock.transform_map<#map12 by [<Unmerge{32} ["exp1"] at [1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [3] -> ["dim2"] at [2]>, <PassThrough ["dim3"] at [4] -> ["dim3"] at [3]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 32, 8, 32, 32] -> [32, 8, 32, 32]>
#transform_map12 = #rock.transform_map<#map13 by [<Merge{1, 32, 8, 32, 32} ["dim0"] at [0] -> ["col0", "col1", "col2", "col3", "col4"] at [0, 1, 2, 3, 4]>] bounds = [262144] -> [1, 32, 8, 32, 32]>
#transform_map13 = #rock.transform_map<#map14 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>, <Merge{8, 32, 32} ["dim2"] at [2] -> ["col2", "col3", "col4"] at [2, 3, 4]>] bounds = [1, 32, 8192] -> [1, 32, 8, 32, 32]>
#transform_map14 = #rock.transform_map<#map15 by [<Merge{1, 32, 1} ["dim0"] at [0] -> ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32] -> [1, 32, 1]>
module {
func.func @mlir_reshape_convolution_reshape_broadcast_add_mul_reshape_reduce_sum_reshape_mul_reshape_reduce_sum_reshape(%arg0: memref<131072xf16>, %arg1: memref<294912xf16>, %arg2: memref<256xf16>, %arg3: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg4: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg5: memref<262144xf16>) attributes {arch = "gfx950:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 256 : i64} {
%cst = arith.constant 1.220700e-04 : f16
%0 = rock.transform %arg1 by #transform_map : memref<294912xf16> to memref<256x128x3x3xf16>
%1 = rock.transform %arg0 by #transform_map1 : memref<131072xf16> to memref<1x128x32x32xf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x256x32x32xf16>
%2 = rock.transform %1 by #transform_map2 : memref<1x128x32x32xf16> to memref<1x1x128x32x32xf16>
%3 = rock.transform %0 by #transform_map3 : memref<256x128x3x3xf16> to memref<1x256x128x3x3xf16>
%4 = rock.transform %alloc by #transform_map4 : memref<1x256x32x32xf16> to memref<1x1x256x32x32xf16>
%5 = rock.transform %3 by #transform_map5 : memref<1x256x128x3x3xf16> to memref<256x128x3x1x3xf16>
%6 = rock.transform %2 by #transform_map6 : memref<1x1x128x32x32xf16> to memref<128x32x1x1x32xf16>
rock.conv(%3, %2, %4) {dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], output_layout = ["no", "go", "ko", "ho", "wo"], padding = [1 : index, 1 : index, 1 : index, 1 : index], strides = [1 : index, 1 : index]} : memref<1x256x128x3x3xf16>, memref<1x1x128x32x32xf16>, memref<1x1x256x32x32xf16>
%7 = rock.transform %alloc by #transform_map7 : memref<1x256x32x32xf16> to memref<1x32x8x32x32xf16>
%8 = rock.transform %arg2 by #transform_map8 : memref<256xf16> to memref<1x32x8x1x1xf16>
%9 = rock.transform %8 by #transform_map9 : memref<1x32x8x1x1xf16> to memref<1x32x8x32x32xf16>
%10 = rock.transform %7 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16>
%11 = rock.transform %9 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16>
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %11 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_0 : memref<32x8x32x32xf16>) {
^bb0(%in: f16, %in_5: f16, %out: f16):
%20 = arith.addf %in, %in_5 : f16
linalg.yield %20 : f16
}
%12 = rock.transform %alloc_0 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
%13 = rock.transform %12 by #transform_map12 : memref<1x32x8x32x32xf16> to memref<262144xf16>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_0 : memref<32x8x32x32xf16>) outs(%alloc_1 : memref<32x8x32x32xf16>) {
^bb0(%in: f16, %out: f16):
%20 = arith.mulf %in, %cst : f16
linalg.yield %20 : f16
}
%14 = rock.transform %alloc_1 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
%15 = rock.transform %14 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16>
rock.reduce sum %15 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16>
%16 = rock.transform %alloc_2 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16>
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_1, %alloc_0 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_3 : memref<32x8x32x32xf16>) {
^bb0(%in: f16, %in_5: f16, %out: f16):
%20 = arith.mulf %in, %in_5 : f16
linalg.yield %20 : f16
}
%17 = rock.transform %alloc_3 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
%18 = rock.transform %17 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16>
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16>
rock.reduce sum %18 into %alloc_4 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16>
%19 = rock.transform %alloc_4 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16>
memref.copy %16, %arg3 : memref<32xf16> to memref<32xf16>
memref.copy %19, %arg4 : memref<32xf16> to memref<32xf16>
memref.copy %13, %arg5 : memref<262144xf16> to memref<262144xf16>
return
}
}
Loading
Loading