Extend problem key for reduction fusions #2133
base: develop
Changes from all commits
1214ace
540ab2a
eef4b5f
bd3f5c5
85832a2
9c4c9f6
4d92656
f6068c6
33c14bc
0dfaa7c
6325d50
8939734
dcc1231
144a7c8
ac72c15
d3a6a7a
bf18a0b
@@ -11,6 +11,8 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/BufferDependencyAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
#include "mlir/Dialect/Rock/IR/Rock.h"

@@ -22,9 +24,12 @@
#include "mlir/Dialect/Rock/Tuning/RockTuning.h"
#include "mlir/Dialect/Rock/utility/fusionUtils.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"

@@ -742,6 +747,197 @@ extractLayouts(Operation *op, llvm::StringMap<unsigned> &fLayoutMap,
  return success();
}

// Structure to hold information about a single reduction operation
struct ReductionInfo {
  ReduceMethod method;
  int64_t axis;
  int64_t rank;
  int64_t stride; // Stride of the reduction dimension
  bool hasPointwiseBefore;

  bool operator<(const ReductionInfo &other) const {
    // Sort by method first, then rank, then axis, then stride, then
    // hasPointwiseBefore
    if (method != other.method)
      return method < other.method;
    if (rank != other.rank)
      return rank < other.rank;
    if (axis != other.axis)
      return axis < other.axis;
    if (stride != other.stride)
      return stride < other.stride;
    return hasPointwiseBefore > other.hasPointwiseBefore;
  }

  bool operator==(const ReductionInfo &other) const {
    return method == other.method && axis == other.axis && rank == other.rank &&
           stride == other.stride &&
           hasPointwiseBefore == other.hasPointwiseBefore;
  }
};

// Structure to hold fusion information for problem key generation
struct FusionInfo {
  SmallVector<ReductionInfo> reductions;

  bool hasReduction() const { return !reductions.empty(); }
  int numReductionOutputs() const { return reductions.size(); }
Member
nit: numReductionOutputs is ambiguous. It could mean how many reduction operations are being returned from the function, or how many outputs each reduction op has (if there were such an op with multiple outputs).
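One possible reading of the nit, sketched below; the hidden suggested change is not visible here, so the name numReductions is an assumption:

```cpp
// Hypothetical rename sketch: this counts how many fused rock.reduce ops feed
// the problem key, not how many outputs a single reduction op produces.
int numReductions() const { return static_cast<int>(reductions.size()); }
```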
};

// Helper to get the base value (allocation or block argument) from a value
static FailureOr<Value> getBaseValue(Value v) {
Member
This could be moved to …
  FailureOr<memref::AllocOp> maybeAlloc = rock::findMemrefAlloc(v);
  if (succeeded(maybeAlloc)) {
    return maybeAlloc.value().getResult();
  }

  FailureOr<BlockArgument> maybeBlockArg = rock::findBlockArgument(v);
  if (succeeded(maybeBlockArg)) {
    return maybeBlockArg.value();
  }

  return failure();
}

// Helper to trace backwards from a value to see if it reaches the target
// Returns success(hasPointwise) if target is reached, failure otherwise
Member
Suggested change …
static FailureOr<bool> tracesToTarget(Value start, Value target,
                                      const BufferDependencyAnalysis &deps,
                                      DenseSet<Value> &visited) {
  if (!visited.insert(start).second) {
    return failure(); // Avoid cycles
  }

  FailureOr<Value> baseValue = getBaseValue(start);
  if (failed(baseValue))
    return failure(); // Could not find base value

  if (*baseValue == target) {
    return false; // Found target, no pointwise
  }

  // For allocations, use BufferDependencyAnalysis to find writers
  if (auto allocOp = baseValue->getDefiningOp<memref::AllocOp>()) {
    std::optional<SmallVector<OpOperand *>> writers = deps.getWriters(allocOp);
    if (writers) {
      for (OpOperand *writerOperand : *writers) {
        auto genericOp = dyn_cast<linalg::GenericOp>(writerOperand->getOwner());
        if (!genericOp) {
Contributor
Why do we "continue" here if it's not a genericOp?

Contributor
I mean, we could have: out = gemm(...)

Contributor (Author)
My understanding of …

Contributor
It does that, yes, but then if the op is not a linalg.generic you just skip it here. Why do we want to do that?
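A sketch of the alternative being raised here, not part of this PR: dispatch on the kind of writer instead of skipping everything that is not a linalg.generic. The helper name traceWriter and the memref.copy case are illustrative assumptions.

```cpp
// Illustrative sketch only: handle more writer kinds than linalg.generic.
static FailureOr<bool> traceWriter(OpOperand *writerOperand, Value target,
                                   const BufferDependencyAnalysis &deps,
                                   DenseSet<Value> &visited) {
  Operation *writer = writerOperand->getOwner();
  if (auto genericOp = dyn_cast<linalg::GenericOp>(writer)) {
    // Same behaviour as the patch: a generic between the GEMM and the reduce
    // counts as a pointwise stage.
    for (Value input : genericOp.getInputs())
      if (succeeded(tracesToTarget(input, target, deps, visited)))
        return true;
    return failure();
  }
  if (auto copyOp = dyn_cast<memref::CopyOp>(writer)) {
    // A plain copy is not a pointwise stage; keep tracing through its source.
    return tracesToTarget(copyOp.getSource(), target, deps, visited);
  }
  // Any other writer (e.g. the rock.conv/rock.gemm producing the buffer
  // directly) ends the trace without reaching the target.
  return failure();
}
```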
          continue;
        }

        // Trace through inputs of the linalg.generic (assumed to be pointwise)
        for (Value input : genericOp.getInputs()) {
          FailureOr<bool> maybeHasPointwise =
              tracesToTarget(input, target, deps, visited);
          if (succeeded(maybeHasPointwise)) {
            return true; // Found target through pointwise
          }
        }
      }
    }
  }

  return failure();
}

// Find all reductions and check if they trace back to our GEMM output
static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) {
  FusionInfo info;

  // Find the target (allocation or block argument)
  FailureOr<Value> maybeTarget = getBaseValue(gemmResult);
  if (failed(maybeTarget))
    return info; // None found

  Value target = *maybeTarget;

  // Get the parent function
  auto defOp = gemmResult.getDefiningOp();
  auto funcOp = defOp ? rock::getParentFuncOp(defOp) : nullptr;
  if (!funcOp) {
    return info;
  }

  // Walk all reduce operations and check if they trace back to our GEMM.
  // Note, we are assuming that all reduce operations are returned here.
  BufferDependencyAnalysis deps(funcOp);
  funcOp->walk([&](rock::ReduceOp reduceOp) {
    DenseSet<Value> visited;
    FailureOr<bool> maybeHasPointwise =
        tracesToTarget(reduceOp.getIn(), target, deps, visited);

    if (succeeded(maybeHasPointwise)) {
      ReductionInfo redInfo;
      redInfo.method = reduceOp.getReduceMethod();
      redInfo.axis = reduceOp.getAxis().getSExtValue();
      auto memrefType = cast<MemRefType>(reduceOp.getIn().getType());
      redInfo.rank = memrefType.getRank();

      // Extract stride for the reduction dimension
      SmallVector<int64_t> strides;
      int64_t offset;
      if (succeeded(memrefType.getStridesAndOffset(strides, offset))) {
Member
Q: Is there a test with stride != 1?

Contributor
I don't think the stride from the memref is the real stride in rocmlir. We can have transforms that do padding etc.

Contributor
Also, as @pfultz2 suggested, we might want to add a list of axes from the input tensor shape that get reduced, instead of the output axis.

Contributor
My question as well: are we overcomplicating things? Maybe knowing that we are doing a reduction is enough; does knowing the axis help tuning in any way?
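For context on the stride question, a standalone illustration (not part of the patch; the small driver below is an assumption for demonstration) of what the memref-level query used in this patch reports: an identity-layout memref such as the test's reduce input memref<1x32x8192xf16> always yields row-major strides, so the reduced axis comes out as stride 1 regardless of any rock.transform views applied upstream.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // Same shape/element type as the rock.reduce input in the new test below.
  auto f16 = mlir::Float16Type::get(&ctx);
  auto type = mlir::MemRefType::get({1, 32, 8192}, f16);

  llvm::SmallVector<int64_t> strides;
  int64_t offset;
  // Mirrors the call in the patch; an identity layout gives [262144, 8192, 1].
  if (mlir::succeeded(type.getStridesAndOffset(strides, offset)))
    llvm::outs() << "stride of reduced axis 2: " << strides[2] << "\n"; // 1
  return 0;
}
```

That is the reviewers' point: the stride recorded in the key reflects the memref type's layout, not the composed rock.transform chain that produced the value.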
        redInfo.stride = strides[redInfo.axis];
      } else {
        // If we can't determine stride, use dynamic sentinel
        redInfo.stride = ShapedType::kDynamic;
      }

      redInfo.hasPointwiseBefore = *maybeHasPointwise;
      info.reductions.push_back(redInfo);
    }
  });

  // Sort reductions for consistent ordering in problem key
  std::sort(info.reductions.begin(), info.reductions.end());

  return info;
}

// Append fusion information to the problem key string
static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS,
                                   const FusionInfo &fusionInfo) {
  constexpr char sep = ' ';

  if (!fusionInfo.hasReduction())
    return;

  problemOS << sep << "-fusion_reduce" << sep
            << "count=" << fusionInfo.numReductionOutputs();

  // Encode each reduction in format: method:rank:axis:stride[:hasPointwise]
  for (const auto &reduction : fusionInfo.reductions) {
    problemOS << sep;

    // Add reduction method
    switch (reduction.method) {
    case ReduceMethod::Sum:
      problemOS << "sum";
      break;
    case ReduceMethod::Max:
Member
Can there be some other kind of reduce operation, something other than Sum/Max? I suggest adding …
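The concrete suggestion is cut off above; one common way to keep the method encoding total is sketched below. The helper name encodeReduceMethod and the llvm_unreachable fallback are assumptions, not necessarily what the reviewer had in mind.

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"

// Illustrative helper: the exhaustive switch makes the compiler warn when a
// new ReduceMethod enumerator appears, and the fallback catches anything that
// still slips through at runtime.
static llvm::StringRef encodeReduceMethod(ReduceMethod method) {
  switch (method) {
  case ReduceMethod::Sum:
    return "sum";
  case ReduceMethod::Max:
    return "max";
  }
  llvm_unreachable("unhandled ReduceMethod in problem key encoding");
}
```

With such a helper, the switch in the loop above would reduce to `problemOS << encodeReduceMethod(reduction.method);`.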
      problemOS << "max";
      break;
    }

    // Add rank, axis, and stride with colon separators
    problemOS << ":rank" << reduction.rank;
    problemOS << ":axis" << reduction.axis;

    // Add stride (use '?' for dynamic/unknown strides)
    if (reduction.stride == ShapedType::kDynamic) {
      problemOS << ":stride?";
    } else {
      problemOS << ":stride" << reduction.stride;
    }

    // Add pointwise flag for this specific reduction
    if (reduction.hasPointwiseBefore) {
      problemOS << ":hasPointwise";
    }
  }
}

static LogicalResult
getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
                    SmallVectorImpl<char> &out) {

@@ -917,6 +1113,13 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp,
    problemOS << "-k " << headDimQK << sep;
    problemOS << "-gemmO " << headDimV;
  }

  // Analyze and append fusion information
  Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get();
  GemmFeatures features = rock::getFeatures(gemmGemmOp);
  FusionInfo fusionInfo = getFusionInfo(gemmGemmOutput, features);
  appendOutputFusionInfo(problemOS, fusionInfo);

  return success();
}

@@ -1134,6 +1337,12 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF,
    return failure();
  }

  // Analyze and append fusion information
  Value gemmOutput = gemmIF.getOutArgument()->get();
  GemmFeatures features = rock::getFeatures(gemmIF);
  FusionInfo fusionInfo = getFusionInfo(gemmOutput, features);
  appendOutputFusionInfo(problemOS, fusionInfo);

  while (out.back() == sep) {
    // remove trailing whitespace
    out.pop_back();

@@ -0,0 +1,88 @@
// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s

// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1:hasPointwise

#map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 128 + d2, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 256 + d1, d2, d3, d4)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 256 + d2, d3, d4)>
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0, d1, d2, d4)>
#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d0, d1, d4)>
#map7 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 8 + d2, d3, d4)>
#map8 = affine_map<(d0, d1, d2, d3, d4) -> (d1 * 8 + d2)>
#map9 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, 0, 0)>
#map10 = affine_map<(d0, d1, d2, d3) -> (0, d0, d1, d2, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map12 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
#map13 = affine_map<(d0) -> (0, d0 floordiv 8192, (d0 mod 8192) floordiv 1024, (d0 mod 1024) floordiv 32, d0 mod 32)>
#map14 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 1024, (d2 mod 1024) floordiv 32, d2 mod 32)>
#map15 = affine_map<(d0) -> (0, d0, 0)>
#transform_map = #rock.transform_map<#map by [<Unmerge{256, 128, 3, 3} ["exp0", "exp1", "exp2", "exp3"] at [0, 1, 2, 3] -> ["dim0"] at [0]>] bounds = [256, 128, 3, 3] -> [294912]>
#transform_map1 = #rock.transform_map<#map1 by [<Unmerge{128, 32, 32} ["exp1", "exp2", "exp3"] at [1, 2, 3] -> ["dim0"] at [0]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 128, 32, 32] -> [131072]>
#transform_map2 = #rock.transform_map<#map2 by [<PassThrough ["n", "h", "w"] at [0, 3, 4] -> ["n", "h", "w"] at [0, 2, 3]>, <Unmerge{1, 128} ["g", "c"] at [1, 2] -> ["c"] at [1]>] bounds = [1, 1, 128, 32, 32] -> [1, 128, 32, 32]>
#transform_map3 = #rock.transform_map<#map3 by [<PassThrough ["c", "y", "x"] at [2, 3, 4] -> ["c", "y", "x"] at [1, 2, 3]>, <Unmerge{1, 256} ["g", "k"] at [0, 1] -> ["k"] at [0]>] bounds = [1, 256, 128, 3, 3] -> [256, 128, 3, 3]>
#transform_map4 = #rock.transform_map<#map4 by [<PassThrough ["n", "h", "w"] at [0, 3, 4] -> ["n", "h", "w"] at [0, 2, 3]>, <Unmerge{1, 256} ["g", "k"] at [1, 2] -> ["k"] at [1]>] bounds = [1, 1, 256, 32, 32] -> [1, 256, 32, 32]>
#transform_map5 = #rock.transform_map<#map5 by [<PassThrough ["dim1", "dim2", "dim3", "dim0", "dim4"] at [0, 1, 2, 3, 4] -> ["dim1", "dim2", "dim3", "dim0", "dim4"] at [1, 2, 3, 0, 4]>] bounds = [256, 128, 3, 1, 3] -> [1, 256, 128, 3, 3]>
#transform_map6 = #rock.transform_map<#map6 by [<PassThrough ["dim2", "dim3", "dim0", "dim1", "dim4"] at [0, 1, 2, 3, 4] -> ["dim2", "dim3", "dim0", "dim1", "dim4"] at [2, 3, 0, 1, 4]>] bounds = [128, 32, 1, 1, 32] -> [1, 1, 128, 32, 32]>
#transform_map7 = #rock.transform_map<#map7 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <Unmerge{32, 8} ["exp1", "exp2"] at [1, 2] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [3] -> ["dim2"] at [2]>, <PassThrough ["dim3"] at [4] -> ["dim3"] at [3]>] bounds = [1, 32, 8, 32, 32] -> [1, 256, 32, 32]>
#transform_map8 = #rock.transform_map<#map8 by [<Unmerge{32, 8} ["exp1", "exp2"] at [1, 2] -> ["dim0"] at [0]>, <AddDim{1} ["unit0"] at [0] -> [] at []>, <AddDim{1} ["unit3"] at [3] -> [] at []>, <AddDim{1} ["unit4"] at [4] -> [] at []>] bounds = [1, 32, 8, 1, 1] -> [256]>
#transform_map9 = #rock.transform_map<#map9 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [2] -> ["dim2"] at [2]>, <Broadcast{1} ["dim3"] at [3] -> ["dim3"] at [3]>, <Broadcast{1} ["dim4"] at [4] -> ["dim4"] at [4]>] bounds = [1, 32, 8, 32, 32] -> [1, 32, 8, 1, 1]>
#transform_map10 = #rock.transform_map<#map10 by [<Merge{1, 32} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [2]>, <PassThrough ["dim2"] at [2] -> ["dim2"] at [3]>, <PassThrough ["dim3"] at [3] -> ["dim3"] at [4]>] bounds = [32, 8, 32, 32] -> [1, 32, 8, 32, 32]>
#transform_map11 = #rock.transform_map<#map12 by [<Unmerge{32} ["exp1"] at [1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>, <PassThrough ["dim2"] at [3] -> ["dim2"] at [2]>, <PassThrough ["dim3"] at [4] -> ["dim3"] at [3]>, <AddDim{1} ["unit0"] at [0] -> [] at []>] bounds = [1, 32, 8, 32, 32] -> [32, 8, 32, 32]>
#transform_map12 = #rock.transform_map<#map13 by [<Merge{1, 32, 8, 32, 32} ["dim0"] at [0] -> ["col0", "col1", "col2", "col3", "col4"] at [0, 1, 2, 3, 4]>] bounds = [262144] -> [1, 32, 8, 32, 32]>
#transform_map13 = #rock.transform_map<#map14 by [<PassThrough ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>, <Merge{8, 32, 32} ["dim2"] at [2] -> ["col2", "col3", "col4"] at [2, 3, 4]>] bounds = [1, 32, 8192] -> [1, 32, 8, 32, 32]>
#transform_map14 = #rock.transform_map<#map15 by [<Merge{1, 32, 1} ["dim0"] at [0] -> ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32] -> [1, 32, 1]>
module {
  func.func @mlir_reshape_convolution_reshape_broadcast_add_mul_reshape_reduce_sum_reshape_mul_reshape_reduce_sum_reshape(%arg0: memref<131072xf16>, %arg1: memref<294912xf16>, %arg2: memref<256xf16>, %arg3: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg4: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg5: memref<262144xf16>) attributes {arch = "gfx950:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 256 : i64} {
    %cst = arith.constant 1.220700e-04 : f16
    %0 = rock.transform %arg1 by #transform_map : memref<294912xf16> to memref<256x128x3x3xf16>
    %1 = rock.transform %arg0 by #transform_map1 : memref<131072xf16> to memref<1x128x32x32xf16>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x256x32x32xf16>
    %2 = rock.transform %1 by #transform_map2 : memref<1x128x32x32xf16> to memref<1x1x128x32x32xf16>
    %3 = rock.transform %0 by #transform_map3 : memref<256x128x3x3xf16> to memref<1x256x128x3x3xf16>
    %4 = rock.transform %alloc by #transform_map4 : memref<1x256x32x32xf16> to memref<1x1x256x32x32xf16>
    %5 = rock.transform %3 by #transform_map5 : memref<1x256x128x3x3xf16> to memref<256x128x3x1x3xf16>
    %6 = rock.transform %2 by #transform_map6 : memref<1x1x128x32x32xf16> to memref<128x32x1x1x32xf16>
    rock.conv(%3, %2, %4) {dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], output_layout = ["no", "go", "ko", "ho", "wo"], padding = [1 : index, 1 : index, 1 : index, 1 : index], strides = [1 : index, 1 : index]} : memref<1x256x128x3x3xf16>, memref<1x1x128x32x32xf16>, memref<1x1x256x32x32xf16>
    %7 = rock.transform %alloc by #transform_map7 : memref<1x256x32x32xf16> to memref<1x32x8x32x32xf16>
    %8 = rock.transform %arg2 by #transform_map8 : memref<256xf16> to memref<1x32x8x1x1xf16>
    %9 = rock.transform %8 by #transform_map9 : memref<1x32x8x1x1xf16> to memref<1x32x8x32x32xf16>
    %10 = rock.transform %7 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16>
    %11 = rock.transform %9 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
    linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %11 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_0 : memref<32x8x32x32xf16>) {
    ^bb0(%in: f16, %in_5: f16, %out: f16):
      %20 = arith.addf %in, %in_5 : f16
      linalg.yield %20 : f16
    }
    %12 = rock.transform %alloc_0 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
    %13 = rock.transform %12 by #transform_map12 : memref<1x32x8x32x32xf16> to memref<262144xf16>
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
    linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_0 : memref<32x8x32x32xf16>) outs(%alloc_1 : memref<32x8x32x32xf16>) {
    ^bb0(%in: f16, %out: f16):
      %20 = arith.mulf %in, %cst : f16
      linalg.yield %20 : f16
    }
    %14 = rock.transform %alloc_1 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
    %15 = rock.transform %14 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16>
    rock.reduce sum %15 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16>
    %16 = rock.transform %alloc_2 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16>
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16>
    linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_1, %alloc_0 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_3 : memref<32x8x32x32xf16>) {
    ^bb0(%in: f16, %in_5: f16, %out: f16):
      %20 = arith.mulf %in, %in_5 : f16
      linalg.yield %20 : f16
    }
    %17 = rock.transform %alloc_3 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16>
    %18 = rock.transform %17 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16>
    %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16>
    rock.reduce sum %18 into %alloc_4 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16>
    %19 = rock.transform %alloc_4 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16>
    memref.copy %16, %arg3 : memref<32xf16> to memref<32xf16>
    memref.copy %19, %arg4 : memref<32xf16> to memref<32xf16>
    memref.copy %13, %arg5 : memref<262144xf16> to memref<262144xf16>
    return
  }
}