diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 156d35af4842..f20cb83491d4 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -287,12 +287,14 @@ static LDSLayoutConfigDim getLDSLayoutConfigDim(Type elementType, int64_t kpack, LDSLayoutConfigDim cfg; int64_t maxVlen = 128 / elementType.getIntOrFloatBitWidth(); int64_t copyDPerThread = vecDimInfo.inDPerThread; + int64_t copyKPerThread = vecDimInfo.inKPerThread; bool isKContiguousDim = vecDimInfo.vectorDim == GemmDimension::K; // If kpack is less than the hardware max vector length, and we are // writing more contiguous kpack elements, there is a possibility to // vectorize that we want to preserve (i.e., we favour vectorization over // bank conflicts resolution) - bool isPossibleToVectorizeD = (kpack < maxVlen && copyDPerThread > 1); + bool isPossibleToVectorizeD = + (kpack < maxVlen && copyDPerThread > 1) && (copyKPerThread >= kpack); cfg.doRotateWithK = isKContiguousDim && !isPossibleToVectorizeD; cfg.doSwapThreadIterSubDims = !isKContiguousDim && !isPossibleToVectorizeD; cfg.ldsLayoutDxK = false;