From 1214aceb13c555ca582903794b19874c2e90916f Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 18:41:39 +0000 Subject: [PATCH 01/17] Add gemm tests --- .../gemm-add-reduce-sum.mlir | 15 +++++++++++++ .../gemm-mul-reduce-sum.mlir | 15 +++++++++++++ .../gemm-multi-reduce-layernorm.mlir | 21 +++++++++++++++++++ .../problem-key-tests/gemm-no-fusion.mlir | 12 +++++++++++ .../gemm-passthrough-and-reduce.mlir | 18 ++++++++++++++++ .../gemm-reduce-max-axis2.mlir | 13 ++++++++++++ .../gemm-reduce-sum-axis1.mlir | 13 ++++++++++++ .../gemm-reduce-sum-axis2.mlir | 13 ++++++++++++ 8 files changed, 120 insertions(+) create mode 100644 mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir new file mode 100644 index 000000000000..08c41964a69b --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -0,0 +1,15 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_add_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %bias: !migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %add = migraphx.add %gemm, %bias : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> + %result = migraphx.reduce_sum %add {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir new file mode 100644 index 000000000000..1de2e3abbb62 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -0,0 +1,15 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_mul_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %scale: !migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %mul = migraphx.mul %gemm, %scale : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> + %result = migraphx.reduce_sum %mul {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir new file mode 100644 index 000000000000..2d5dd967cdbf --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -0,0 +1,21 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_multi_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}, + !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + + // First reduction: reduce_sum(x) + %reduce1 = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + + // Second reduction: reduce_sum(x * x) + %square = migraphx.mul %gemm, %gemm : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> + %reduce2 = migraphx.reduce_sum %square {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + + return %reduce1, %reduce2 : !migraphx.shaped<1x128x1xf32, 128x1x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir new file mode 100644 index 000000000000..4a78d930b1c7 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir @@ -0,0 +1,12 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_no_fusion(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %result = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + return %result : !migraphx.shaped<1x128x256xf32, 32768x256x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir new file mode 100644 index 000000000000..54e8542313d8 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -0,0 +1,18 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_passthrough_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.write_access}, + !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + + // Output 1: passthrough the gemm result + // Output 2: reduce_sum + %reduce = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + + return %gemm, %reduce : !migraphx.shaped<1x128x256xf32, 32768x256x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir new file mode 100644 index 000000000000..fe0fae19394c --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -0,0 +1,13 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_reduce_max_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %result = migraphx.reduce_max %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir new file mode 100644 index 000000000000..7a4afb55f0da --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -0,0 +1,13 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_reduce_sum_axis1(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x1x256xf32, 256x256x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %result = migraphx.reduce_sum %gemm {axes = [1 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x1x256xf32, 256x256x1> + return %result : !migraphx.shaped<1x1x256xf32, 256x256x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir new file mode 100644 index 000000000000..a81f84f27f5f --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -0,0 +1,13 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_reduce_sum_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %result = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + From 540ab2ae42cdb758e918d1a7ca31f4102be0460e Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 18:49:46 +0000 Subject: [PATCH 02/17] Add GEG tests --- .../gemm-gemm-add-reduce-sum.mlir | 17 ++++++++++++++ .../gemm-gemm-multi-reduce.mlir | 23 +++++++++++++++++++ .../gemm-gemm-no-fusion.mlir | 14 +++++++++++ .../gemm-gemm-reduce-max-axis2.mlir | 15 ++++++++++++ .../gemm-gemm-reduce-sum-axis1.mlir | 15 ++++++++++++ .../gemm-gemm-reduce-sum-axis2.mlir | 15 ++++++++++++ 6 files changed, 99 insertions(+) create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir create mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir new file mode 100644 index 000000000000..2a15414e47f7 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -0,0 +1,17 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_add_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}, + %bias: !migraphx.shaped<1x128x128xf32, 16384x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + %add = migraphx.add %gemm1, %bias : <1x128x128xf32, 16384x128x1>, <1x128x128xf32, 16384x128x1> -> <1x128x128xf32, 16384x128x1> + %result = migraphx.reduce_sum %add {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir new file mode 100644 index 000000000000..86bd5c5b2957 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir @@ -0,0 +1,23 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_multi_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}, + !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + + // First reduction: reduce_sum(x) + %reduce1 = migraphx.reduce_sum %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> + + // Second reduction: reduce_sum(x * x) + %square = migraphx.mul %gemm1, %gemm1 : <1x128x128xf32, 16384x128x1>, <1x128x128xf32, 16384x128x1> -> <1x128x128xf32, 16384x128x1> + %reduce2 = migraphx.reduce_sum %square {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> + + return %reduce1, %reduce2 : !migraphx.shaped<1x128x1xf32, 128x1x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir new file mode 100644 index 000000000000..c74b3f7f7f9c --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir @@ -0,0 +1,14 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_no_fusion(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x128xf32, 16384x128x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + return %gemm1 : !migraphx.shaped<1x128x128xf32, 16384x128x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir new file mode 100644 index 000000000000..612ad0ea6dc3 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir @@ -0,0 +1,15 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_reduce_max_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + %result = migraphx.reduce_max %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir new file mode 100644 index 000000000000..3bdd9821e928 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -0,0 +1,15 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_reduce_sum_axis1(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x1x128xf32, 128x128x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + %result = migraphx.reduce_sum %gemm1 {axes = [1 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x1x128xf32, 128x128x1> + return %result : !migraphx.shaped<1x1x128xf32, 128x128x1> + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir new file mode 100644 index 000000000000..24dcf45c60ff --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -0,0 +1,15 @@ +// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s + +module { + func.func private @gemm_gemm_reduce_sum_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, + %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, + %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) + -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) + attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { + %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> + %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> + %result = migraphx.reduce_sum %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> + return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + } +} + From eef4b5f417b9c352c85f2e939959d58857b5a053 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 19:03:54 +0000 Subject: [PATCH 03/17] Remove rocmlir-driver command from LIT tests --- .../gemm-add-reduce-sum.mlir | 55 ++++++++++--- .../gemm-gemm-add-reduce-sum.mlir | 71 ++++++++++++++--- .../gemm-gemm-multi-reduce.mlir | 78 ++++++++++++++----- .../gemm-gemm-no-fusion.mlir | 51 +++++++++--- .../gemm-gemm-reduce-max-axis2.mlir | 54 ++++++++++--- .../gemm-gemm-reduce-sum-axis1.mlir | 54 ++++++++++--- .../gemm-gemm-reduce-sum-axis2.mlir | 54 ++++++++++--- .../gemm-mul-reduce-sum.mlir | 54 ++++++++++--- .../gemm-multi-reduce-layernorm.mlir | 61 +++++++++++---- .../problem-key-tests/gemm-no-fusion.mlir | 33 ++++++-- .../gemm-passthrough-and-reduce.mlir | 45 +++++++---- .../gemm-reduce-max-axis2.mlir | 36 +++++++-- .../gemm-reduce-sum-axis1.mlir | 36 +++++++-- .../gemm-reduce-sum-axis2.mlir | 36 +++++++-- 14 files changed, 569 insertions(+), 149 deletions(-) diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir index 08c41964a69b..2160b45e83ef 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -1,15 +1,50 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1) -> (0, d0, d1)> +#map4 = affine_map<(d0, d1) -> (d0, d1)> +#map5 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map6 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 256] -> [32768]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> +#transform_map9 = #rock.transform_map<#map6 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_add_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %bias: !migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %add = migraphx.add %gemm, %bias : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> - %result = migraphx.reduce_sum %add {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_add_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x128x256xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %2 * %1 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %7 = rock.transform %alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %8 = rock.transform %0 by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : memref<128x256xf32>, memref<128x256xf32>) outs(%alloc_0 : memref<128x256xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.addf %in, %in_2 : f32 + linalg.yield %11 : f32 + } + %9 = rock.transform %alloc_0 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %9 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %10 = rock.transform %alloc_1 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %10, %arg3 : memref<128xf32> to memref<128xf32> + return } } + diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir index 2a15414e47f7..49fc37fb31de 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -1,17 +1,64 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0, d1) -> (0, d0, d1)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +#map6 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map7 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 128] -> [16384]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map9 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map10 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 128] -> [1, 128, 128]> +#transform_map11 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 128] -> [128, 128]> +#transform_map12 = #rock.transform_map<#map7 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_gemm_add_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}, - %bias: !migraphx.shaped<1x128x128xf32, 16384x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - %add = migraphx.add %gemm1, %bias : <1x128x128xf32, 16384x128x1>, <1x128x128xf32, 16384x128x1> -> <1x128x128xf32, 16384x128x1> - %result = migraphx.reduce_sum %add {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_gemm_add_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<16384xf32> {mhal.read_access}, %arg4: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg3 by #transform_map : memref<16384xf32> to memref<1x128x128xf32> + %1 = rock.transform %arg2 by #transform_map1 : memref<32768xf32> to memref<1x256x128xf32> + %2 = rock.transform %arg1 by #transform_map2 : memref<16384xf32> to memref<1x64x256xf32> + %3 = rock.transform %arg0 by #transform_map3 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %4 = rock.transform %3 by #transform_map4 : memref<1x128x64xf32> to memref<128x1x64xf32> + %5 = rock.transform %4 by #transform_map5 : memref<128x1x64xf32> to memref<1x128x64xf32> + %6 = rock.transform %2 by #transform_map6 : memref<1x64x256xf32> to memref<64x1x256xf32> + %7 = rock.transform %6 by #transform_map7 : memref<64x1x256xf32> to memref<1x64x256xf32> + %8 = rock.transform %1 by #transform_map8 : memref<1x256x128xf32> to memref<256x1x128xf32> + %9 = rock.transform %8 by #transform_map9 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %3 * %2 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg5: memref<1x128x256xf32>, %arg6: memref<1x128x256xf32>): + memref.copy %arg5, %arg6 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %1 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %10 = rock.transform %alloc by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> + %11 = rock.transform %0 by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32> + linalg.generic {indexing_maps = [#map5, #map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%10, %11 : memref<128x128xf32>, memref<128x128xf32>) outs(%alloc_0 : memref<128x128xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %14 = arith.addf %in, %in_2 : f32 + linalg.yield %14 : f32 + } + %12 = rock.transform %alloc_0 by #transform_map11 : memref<128x128xf32> to memref<1x128x128xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %12 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %13 = rock.transform %alloc_1 by #transform_map12 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %13, %arg4 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir index 86bd5c5b2957..ec43f1df2efc 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir @@ -1,23 +1,65 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#map5 = affine_map<(d0, d1) -> (0, d0, d1)> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +#map7 = affine_map<(d0, d1, d2) -> (d1, d2)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +#transform_map10 = #rock.transform_map<#map5 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 128] -> [1, 128, 128]> +#transform_map11 = #rock.transform_map<#map7 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 128] -> [128, 128]> module { - func.func private @gemm_gemm_multi_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}, - !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - - // First reduction: reduce_sum(x) - %reduce1 = migraphx.reduce_sum %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> - - // Second reduction: reduce_sum(x * x) - %square = migraphx.mul %gemm1, %gemm1 : <1x128x128xf32, 16384x128x1>, <1x128x128xf32, 16384x128x1> -> <1x128x128xf32, 16384x128x1> - %reduce2 = migraphx.reduce_sum %square {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> - - return %reduce1, %reduce2 : !migraphx.shaped<1x128x1xf32, 128x1x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_gemm_multi_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg4: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg5: memref<1x128x256xf32>, %arg6: memref<1x128x256xf32>): + memref.copy %arg5, %arg6 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + %10 = rock.transform %alloc by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32> + linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%10 : memref<128x128xf32>) outs(%alloc_1 : memref<128x128xf32>) { + ^bb0(%in: f32, %out: f32): + %13 = arith.mulf %in, %in : f32 + linalg.yield %13 : f32 + } + %11 = rock.transform %alloc_1 by #transform_map11 : memref<128x128xf32> to memref<1x128x128xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %11 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %12 = rock.transform %alloc_2 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + memref.copy %12, %arg4 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir index c74b3f7f7f9c..3b7f0115c309 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-no-fusion.mlir @@ -1,14 +1,47 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0 floordiv 128, d0 mod 128)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [16384] -> [1, 128, 128]> module { - func.func private @gemm_gemm_no_fusion(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x128xf32, 16384x128x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - return %gemm1 : !migraphx.shaped<1x128x128xf32, 16384x128x1> + func.func private @gemm_gemm_no_fusion(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<16384xf32> {mhal.write_access}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %9 = rock.transform %alloc by #transform_map9 : memref<1x128x128xf32> to memref<16384xf32> + memref.copy %9, %arg3 : memref<16384xf32> to memref<16384xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir index 612ad0ea6dc3..4e74aeb6638b 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir @@ -1,15 +1,49 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_gemm_reduce_max_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - %result = migraphx.reduce_max %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_gemm_reduce_max_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0xFF800000 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce max %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir index 3bdd9821e928..ba1fd3d2f7d2 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -1,15 +1,49 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 1, 128]> module { - func.func private @gemm_gemm_reduce_sum_axis1(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x1x128xf32, 128x128x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - %result = migraphx.reduce_sum %gemm1 {axes = [1 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x1x128xf32, 128x128x1> - return %result : !migraphx.shaped<1x1x128xf32, 128x128x1> + func.func private @gemm_gemm_reduce_sum_axis1(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x128xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x1x128xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x1x128xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir index 24dcf45c60ff..1de4b503cd34 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -1,15 +1,49 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 + + +#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> +#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> +#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_gemm_reduce_sum_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %c: !migraphx.shaped<1x256x128xf32, 32768x128x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm0 = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %gemm1 = migraphx.dot %gemm0, %c : <1x128x256xf32, 32768x256x1>, <1x256x128xf32, 32768x128x1> -> <1x128x128xf32, 16384x128x1> - %result = migraphx.reduce_sum %gemm1 {axes = [2 : i64]} : <1x128x128xf32, 16384x128x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_gemm_reduce_sum_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> + %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> + rock.gemm_elementwise_gemm{ + ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> + ab = elementwise { + ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): + memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> + rock.yield + } + %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> + } {firstGemmIndices = array, storeMethod = #rock} + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir index 1de2e3abbb62..c6a00120ee72 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -1,15 +1,49 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0, d1) -> (0, d0, d1)> +#map4 = affine_map<(d0, d1) -> (d0, d1)> +#map5 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map6 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 256] -> [32768]> +#transform_map1 = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map2 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map6 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map7 = #rock.transform_map<#map3 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map5 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> +#transform_map9 = #rock.transform_map<#map6 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_mul_reduce_sum(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}, - %scale: !migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %mul = migraphx.mul %gemm, %scale : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> - %result = migraphx.reduce_sum %mul {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_mul_reduce_sum(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x128x256xf32> + %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> + %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> + %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> + %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> + %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %2 * %1 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %7 = rock.transform %alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %8 = rock.transform %0 by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%7, %8 : memref<128x256xf32>, memref<128x256xf32>) outs(%alloc_0 : memref<128x256xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %11 = arith.mulf %in, %in_2 : f32 + linalg.yield %11 : f32 + } + %9 = rock.transform %alloc_0 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %9 into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %10 = rock.transform %alloc_1 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %10, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index 2d5dd967cdbf..57263687f3b8 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,21 +1,50 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#map4 = affine_map<(d0, d1) -> (0, d0, d1)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +#map6 = affine_map<(d0, d1, d2) -> (d1, d2)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 256] -> [1, 128, 256]> +#transform_map8 = #rock.transform_map<#map6 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 256] -> [128, 256]> module { - func.func private @gemm_multi_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}, - !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - - // First reduction: reduce_sum(x) - %reduce1 = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - - // Second reduction: reduce_sum(x * x) - %square = migraphx.mul %gemm, %gemm : <1x128x256xf32, 32768x256x1>, <1x128x256xf32, 32768x256x1> -> <1x128x256xf32, 32768x256x1> - %reduce2 = migraphx.reduce_sum %square {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - - return %reduce1, %reduce2 : !migraphx.shaped<1x128x1xf32, 128x1x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_multi_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + %7 = rock.transform %alloc by #transform_map7 : memref<1x128x256xf32> to memref<128x256xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> + linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%7 : memref<128x256xf32>) outs(%alloc_1 : memref<128x256xf32>) { + ^bb0(%in: f32, %out: f32): + %10 = arith.mulf %in, %in : f32 + linalg.yield %10 : f32 + } + %8 = rock.transform %alloc_1 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %8 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %9 = rock.transform %alloc_2 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<128xf32> to memref<128xf32> + memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir index 4a78d930b1c7..811c09132800 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-no-fusion.mlir @@ -1,12 +1,31 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0 floordiv 256, d0 mod 256)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32768] -> [1, 128, 256]> module { - func.func private @gemm_no_fusion(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %result = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - return %result : !migraphx.shaped<1x128x256xf32, 32768x256x1> + func.func private @gemm_no_fusion(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.write_access}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %6 = rock.transform %alloc by #transform_map6 : memref<1x128x256xf32> to memref<32768xf32> + memref.copy %6, %arg2 : memref<32768xf32> to memref<32768xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir index 54e8542313d8..4e0e673580bd 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -1,18 +1,37 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0 floordiv 256, d0 mod 256)> +#map4 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32768] -> [1, 128, 256]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_passthrough_reduce(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x256xf32, 32768x256x1> {mhal.write_access}, - !migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - - // Output 1: passthrough the gemm result - // Output 2: reduce_sum - %reduce = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - - return %gemm, %reduce : !migraphx.shaped<1x128x256xf32, 32768x256x1>, !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_passthrough_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.write_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %6 = rock.transform %alloc by #transform_map6 : memref<1x128x256xf32> to memref<32768xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %7 = rock.transform %alloc_0 by #transform_map7 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<32768xf32> to memref<32768xf32> + memref.copy %7, %arg3 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir index fe0fae19394c..6be9c674815a 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -1,13 +1,33 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_reduce_max_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %result = migraphx.reduce_max %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_reduce_max_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0xFF800000 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce max %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<128xf32> to memref<128xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir index 7a4afb55f0da..726822beff22 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -1,13 +1,33 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [256] -> [1, 1, 256]> module { - func.func private @gemm_reduce_sum_axis1(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x1x256xf32, 256x256x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %result = migraphx.reduce_sum %gemm {axes = [1 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x1x256xf32, 256x256x1> - return %result : !migraphx.shaped<1x1x256xf32, 256x256x1> + func.func private @gemm_reduce_sum_axis1(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<256xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x256xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x1x256xf32> to memref<256xf32> + memref.copy %6, %arg2 : memref<256xf32> to memref<256xf32> + return } } diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir index a81f84f27f5f..5f20b801455e 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -1,13 +1,33 @@ -// RUN: rocmlir-driver -kernel-pipeline=migraphx,highlevel %s | rocmlir-gen --emit-tuning-key - | FileCheck %s +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> module { - func.func private @gemm_reduce_sum_axis2(%a: !migraphx.shaped<1x128x64xf32, 8192x64x1> {mhal.read_access}, - %b: !migraphx.shaped<1x64x256xf32, 16384x256x1> {mhal.read_access}) - -> (!migraphx.shaped<1x128x1xf32, 128x1x1> {mhal.write_access}) - attributes {kernel, arch = "gfx942", num_cu = 120 : i64} { - %gemm = migraphx.dot %a, %b : <1x128x64xf32, 8192x64x1>, <1x64x256xf32, 16384x256x1> -> <1x128x256xf32, 32768x256x1> - %result = migraphx.reduce_sum %gemm {axes = [2 : i64]} : <1x128x256xf32, 32768x256x1> -> <1x128x1xf32, 128x1x1> - return %result : !migraphx.shaped<1x128x1xf32, 128x1x1> + func.func private @gemm_reduce_sum_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %6 = rock.transform %alloc_0 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + memref.copy %6, %arg2 : memref<128xf32> to memref<128xf32> + return } } From bd3f5c5d4dc8275946a89ce6d71b3652cc91c7ce Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 20:23:14 +0000 Subject: [PATCH 04/17] Update RockTuningImpl for reduction fusion detection --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 159 ++++++++++++++++++ .../gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-multi-reduce.mlir | 5 +- .../gemm-gemm-reduce-max-axis2.mlir | 2 +- .../gemm-gemm-reduce-sum-axis1.mlir | 2 +- .../gemm-gemm-reduce-sum-axis2.mlir | 2 +- .../gemm-mul-reduce-sum.mlir | 2 +- .../gemm-multi-reduce-different-axes.mlir | 44 +++++ .../gemm-multi-reduce-layernorm.mlir | 5 +- .../gemm-passthrough-and-reduce.mlir | 2 +- .../gemm-reduce-max-axis2.mlir | 2 +- .../gemm-reduce-sum-axis1.mlir | 2 +- .../gemm-reduce-sum-axis2.mlir | 2 +- 14 files changed, 219 insertions(+), 14 deletions(-) create mode 100644 mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 26fb3b571bdc..cf56a1700f08 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -22,9 +22,11 @@ #include "mlir/Dialect/Rock/Tuning/RockTuning.h" #include "mlir/Dialect/Rock/utility/fusionUtils.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" @@ -742,6 +744,150 @@ extractLayouts(Operation *op, llvm::StringMap &fLayoutMap, return success(); } +// Structure to hold information about a single reduction operation +struct ReductionInfo { + ReduceMethod method; + int64_t axis; + bool hasPointwiseBefore; + + bool operator<(const ReductionInfo &other) const { + // Sort by method first, then axis for consistent ordering + if (method != other.method) + return method < other.method; + return axis < other.axis; + } + + bool operator==(const ReductionInfo &other) const { + return method == other.method && axis == other.axis && + hasPointwiseBefore == other.hasPointwiseBefore; + } +}; + +// Structure to hold fusion information for problem key generation +struct FusionInfo { + SmallVector reductions; + + bool hasReduction() const { return !reductions.empty(); } + bool hasMultipleReductions() const { return reductions.size() > 1; } + int numReductionOutputs() const { return reductions.size(); } +}; + +// Helper to analyze users, following through rock.transform operations +static void analyzeUsers(Value originalGemmResult, GemmFeatures features, + FusionInfo &info) { + // Worklist: + SmallVector> worklist; + worklist.push_back({originalGemmResult, false}); + + // Track visited values to avoid cycles + DenseSet visited; + + while (!worklist.empty()) { + auto [value, hasPointwiseSoFar] = worklist.pop_back_val(); + + if (!visited.insert(value).second) { + continue; // Already visited + } + + for (Operation *user : value.getUsers()) { + // Check for direct reduction + if (auto reduceOp = dyn_cast(user)) { + ReductionInfo redInfo; + redInfo.method = reduceOp.getReduceMethod(); + redInfo.axis = reduceOp.getAxis().getSExtValue(); + redInfo.hasPointwiseBefore = hasPointwiseSoFar; + info.reductions.push_back(redInfo); + continue; + } + + // Follow through rock.transform operations + if (auto transformOp = dyn_cast(user)) { + worklist.push_back({transformOp.getResult(), hasPointwiseSoFar}); + continue; + } + + // Check for linalg.generic (i.e., pointwise operations) + if (auto genericOp = dyn_cast(user)) { + // For memref-based linalg.generic, we need to check if this is reading + // from our value. We only care about cases where our value is an input + // to the pointwise operation. + bool isInput = false; + for (Value input : genericOp.getInputs()) { + if (input == value) { + isInput = true; + break; + } + } + + if (!isInput) { + continue; + } + + // Validate the output fusion + SmallVector> adds; + if (failed(rock::checkValidOutputFusion(genericOp, originalGemmResult, + features, adds))) { + continue; + } + + // We need to follow users of the output memref + auto outputs = genericOp.getOutputs(); + if (!outputs.empty() && outputs[0]) { + worklist.push_back({outputs[0], /*hasPointwiseSoFar=*/true}); + } + continue; + } + } + } +} + +// Analyze fusion patterns for a GEMM operation's output +static FusionInfo analyzeOuputFusionPattern(Value gemmResult, + GemmFeatures features) { + FusionInfo info; + analyzeUsers(gemmResult, features, info); + + // Sort reductions for consistent ordering in problem key + std::sort(info.reductions.begin(), info.reductions.end()); + + return info; +} + +// Append fusion information to the problem key string +static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, + const FusionInfo &fusionInfo) { + constexpr char sep = ' '; + + if (!fusionInfo.hasReduction()) + return; + + problemOS << sep << "-fusion_reduce" << sep << "count=" + << fusionInfo.numReductionOutputs(); + + // Encode each reduction in format: method:axis[:hasPointwise] + for (const auto &reduction : fusionInfo.reductions) { + problemOS << sep; + + // Add reduction method + switch (reduction.method) { + case ReduceMethod::Sum: + problemOS << "sum"; + break; + case ReduceMethod::Max: + problemOS << "max"; + break; + } + + // Add reduction axis with colon separator + problemOS << ":axis" << reduction.axis; + + // Add pointwise flag for this specific reduction + if (reduction.hasPointwiseBefore) { + problemOS << ":hasPointwise"; + } + } +} + static LogicalResult getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, SmallVectorImpl &out) { @@ -917,6 +1063,13 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, problemOS << "-k " << headDimQK << sep; problemOS << "-gemmO " << headDimV; } + + // Analyze and append fusion information + Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get(); + GemmFeatures features = rock::getFeatures(gemmGemmOp); + FusionInfo fusionInfo = analyzeOuputFusionPattern(gemmGemmOutput, features); + appendOutputFusionInfo(problemOS, fusionInfo); + return success(); } @@ -1134,6 +1287,12 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF, return failure(); } + // Analyze and append fusion information + Value gemmOutput = gemmIF.getOutArgument()->get(); + GemmFeatures features = rock::getFeatures(gemmIF); + FusionInfo fusionInfo = analyzeOuputFusionPattern(gemmOutput, features); + appendOutputFusionInfo(problemOS, fusionInfo); + while (out.back() == sep) { // remove trailing whitespace out.pop_back(); diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir index 2160b45e83ef..3fa1abc907b2 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir index 49fc37fb31de..921d85901df0 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir index ec43f1df2efc..4c5dc2342afb 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> @@ -50,7 +50,8 @@ module { %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32> linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%10 : memref<128x128xf32>) outs(%alloc_1 : memref<128x128xf32>) { ^bb0(%in: f32, %out: f32): - %13 = arith.mulf %in, %in : f32 + %cst = arith.constant 2.0 : f32 + %13 = arith.mulf %in, %cst : f32 linalg.yield %13 : f32 } %11 = rock.transform %alloc_1 by #transform_map11 : memref<128x128xf32> to memref<1x128x128xf32> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir index 4e74aeb6638b..0ff7280db28f 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 max:axis #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir index ba1fd3d2f7d2..dd06f2e9f42d 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis1 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir index 1de4b503cd34..251622a43eaa 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir index c6a00120ee72..b1ccc6df6914 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir new file mode 100644 index 000000000000..4f3f29cab340 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir @@ -0,0 +1,44 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis1 sum:axis2 + +#map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +#map3 = affine_map<(d0) -> (0, d0, 0)> +#map4 = affine_map<(d0) -> (0, 0, d0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> +#transform_map2 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> +#transform_map3 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> +#transform_map4 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> +#transform_map5 = #rock.transform_map<#map2 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> +#transform_map6 = #rock.transform_map<#map3 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> +#transform_map7 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [256] -> [1, 1, 256]> +module { + func.func private @gemm_multi_reduce_different_axes(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg3: memref<256xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { + %0 = rock.transform %arg1 by #transform_map : memref<16384xf32> to memref<1x64x256xf32> + %1 = rock.transform %arg0 by #transform_map1 : memref<8192xf32> to memref<1x128x64xf32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x256xf32> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x64xf32> to memref<128x1x64xf32> + %3 = rock.transform %2 by #transform_map3 : memref<128x1x64xf32> to memref<1x128x64xf32> + %4 = rock.transform %0 by #transform_map4 : memref<1x64x256xf32> to memref<64x1x256xf32> + %5 = rock.transform %4 by #transform_map5 : memref<64x1x256xf32> to memref<1x64x256xf32> + rock.gemm %alloc = %1 * %0 storeMethod = set : memref<1x128x256xf32> = memref<1x128x64xf32> * memref<1x64x256xf32> + + // First reduction: sum on axis 1 (row reduction: 1x128x256 -> 1x1x256) + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x1x256xf32> + rock.reduce sum %alloc into %alloc_0 {axis = 1 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x1x256xf32> + %6 = rock.transform %alloc_0 by #transform_map7 : memref<1x1x256xf32> to memref<256xf32> + + // Second reduction: sum on axis 2 (column reduction: 1x128x256 -> 1x128x1) + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> + rock.reduce sum %alloc into %alloc_1 {axis = 2 : index, blockSize = 256 : i32, gridSize = 128 : i32} : memref<1x128x256xf32> into memref<1x128x1xf32> + %7 = rock.transform %alloc_1 by #transform_map6 : memref<1x128x1xf32> to memref<128xf32> + + memref.copy %7, %arg2 : memref<128xf32> to memref<128xf32> + memref.copy %6, %arg3 : memref<256xf32> to memref<256xf32> + return + } +} + diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index 57263687f3b8..20f02ca55ba2 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> @@ -35,7 +35,8 @@ module { %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32> linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%7 : memref<128x256xf32>) outs(%alloc_1 : memref<128x256xf32>) { ^bb0(%in: f32, %out: f32): - %10 = arith.mulf %in, %in : f32 + %cst = arith.constant 2.0 : f32 + %10 = arith.mulf %in, %cst : f32 linalg.yield %10 : f32 } %8 = rock.transform %alloc_1 by #transform_map8 : memref<128x256xf32> to memref<1x128x256xf32> diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir index 4e0e673580bd..896f59f4fad9 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir index 6be9c674815a..c80bb39dcbb3 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir index 726822beff22..7219c3504f48 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir index 5f20b801455e..252e092de222 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> From 85832a2f6702853ae2fefbc5f807eb759e5d3784 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 20:42:14 +0000 Subject: [PATCH 05/17] Update perfRunner to store entire problem config --- mlir/utils/performance/perfRunner.py | 92 +++++++++++++++++++++++----- 1 file changed, 76 insertions(+), 16 deletions(-) diff --git a/mlir/utils/performance/perfRunner.py b/mlir/utils/performance/perfRunner.py index 42d59b4d4b07..4f7305e31871 100644 --- a/mlir/utils/performance/perfRunner.py +++ b/mlir/utils/performance/perfRunner.py @@ -607,6 +607,10 @@ def to_command_line(self): f"-y {self.y} -x {self.x} -p {self.padding_h} -q {self.padding_w} " + f"-u {self.conv_stride_h} -v {self.conv_stride_w} -l {self.dilation_h} " + f"-j {self.dilation_w} -m conv -g {self.group} -t 1") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) def __init__(self, dtype: str, direction: str, filter_layout: str, input_layout: str, output_layout: str, n: int, c: int, hi: int, wi: int, k: int, y: int, x: int, @@ -975,6 +979,11 @@ def from_command_line(cls, argv, arch, num_cu): scale_b_dtype = None trans_scale_a = False trans_scale_b = False + + # Store the original command line for accurate tuning DB lookups + # (including fusion info which we don't parse but need for cache key) + original_command_line = ' '.join(argv) + i = 0 while i < len(argv): opt = argv[i] @@ -983,6 +992,9 @@ def from_command_line(cls, argv, arch, num_cu): scaled_gemm = True i += 1 continue + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break # Handle flags with values if i + 1 >= len(argv): raise ValueError(f"Missing value for argument {opt}") @@ -1020,8 +1032,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete GEMM configuration") - return cls(dtype, out_dtype, g, m, k, n, trans_a, trans_b, scaled_gemm, scale_a_dtype, - scale_b_dtype, trans_scale_a, trans_scale_b, arch, num_cu, perf_config) + config = cls(dtype, out_dtype, g, m, k, n, trans_a, trans_b, scaled_gemm, scale_a_dtype, + scale_b_dtype, trans_scale_a, trans_scale_b, arch, num_cu, perf_config) + # Store the full original command line for tuning DB lookups + config._original_command_line = original_command_line + return config def to_command_line(self): result = (f"-t {self.datatype} -out_datatype {self.out_dtype} " + @@ -1038,6 +1053,10 @@ def to_command_line(self): if self.trans_scale_b: result += f" -transScaleB {str(self.trans_scale_b).lower()}" return result + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) def __init__(self, dtype: str, @@ -1225,9 +1244,16 @@ def from_command_line(cls, argv, arch, num_cu): input_layout = None trans_c = False trans_o = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1280,9 +1306,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete conv+gemm configuration") - return cls(dtype, filter_layout, input_layout, trans_c, trans_o, n, c, hi, wi, k, y, x, o, - conv_stride_h, conv_stride_w, padding_h, padding_w, dilation_h, dilation_w, - group, arch, num_cu, perf_config) + config = cls(dtype, filter_layout, input_layout, trans_c, trans_o, n, c, hi, wi, k, y, x, o, + conv_stride_h, conv_stride_w, padding_h, padding_w, dilation_h, dilation_w, + group, arch, num_cu, perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return (f"-t {self.datatype} " + @@ -1292,6 +1320,10 @@ def to_command_line(self): f"-y {self.y} -x {self.x} -p {self.padding_h} -q {self.padding_w} " + f"-u {self.conv_stride_h} -v {self.conv_stride_w} -l {self.dilation_h} " + f"-j {self.dilation_w} -g {self.group}" + f"-gemmO {str(self.o)}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class GemmGemmConfiguration(PerfConfiguration): @@ -1385,9 +1417,16 @@ def from_command_line(cls, argv, arch, num_cu): trans_b = False trans_c = False trans_o = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1417,8 +1456,10 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete gemm+gemm configuration") - return cls(dtype, g, m, k, n, o, trans_a, trans_b, trans_c, trans_o, arch, num_cu, - perf_config) + config = cls(dtype, g, m, k, n, o, trans_a, trans_b, trans_c, trans_o, arch, num_cu, + perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return (f"-t {self.datatype} " + @@ -1426,6 +1467,10 @@ def to_command_line(self): f"-transC {str(self.trans_c).lower()} -transO {str(self.trans_o).lower()} " + f"-g {self.g} " + f"-m {str(self.m)} -k {str(self.k)} -n {str(self.n)} -gemmO {str(self.o)}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class AttentionConfiguration(PerfConfiguration): @@ -1565,9 +1610,16 @@ def from_command_line(cls, argv, arch, num_cu): split_kv = 1 with_attn_scale = False with_attn_bias = False + + # Store the original command line for accurate tuning DB lookups + original_command_line = ' '.join(argv) + # Please keep this in sync with mlir::rock::getTuningProblemStr() for i in range(0, len(argv), 2): opt = argv[i] + # Fusion info is always at the end, so we can stop parsing here + if opt == '-fusion_reduce': + break val = argv[i + 1] if opt.endswith("-t"): dtype = val @@ -1615,9 +1667,11 @@ def from_command_line(cls, argv, arch, num_cu): if v is None: raise ValueError("Incomplete Attention configuration") - return cls(dtype, g, seq_len_q, seq_len_k, num_heads_q, num_heads_kv, head_dim_qk, - head_dim_v, with_attn_scale, with_attn_bias, trans_q, trans_k, trans_v, trans_o, - causal, return_lse, split_kv, arch, num_cu, perf_config) + config = cls(dtype, g, seq_len_q, seq_len_k, num_heads_q, num_heads_kv, head_dim_qk, + head_dim_v, with_attn_scale, with_attn_bias, trans_q, trans_k, trans_v, trans_o, + causal, return_lse, split_kv, arch, num_cu, perf_config) + config._original_command_line = original_command_line + return config def to_command_line(self): return ( @@ -1630,6 +1684,10 @@ def to_command_line(self): f"-seq_len_q {str(self.seq_len_q)} -seq_len_k {str(self.seq_len_k)} -num_heads_q {str(self.num_heads_q)} -num_heads_kv {str(self.num_heads_kv)} -head_dim_qk {str(self.head_dim_qk)} -head_dim_v {str(self.head_dim_v)} " + f"-with-attn-scale {str(self.with_attn_scale).lower()} " + f"-with-attn-bias {str(self.with_attn_bias).lower()}") + + def to_tuning_key(self): + """Returns the full problem key including fusion info for tuning DB lookups.""" + return getattr(self, '_original_command_line', self.to_command_line()) class RocBLASGemmConfig(GemmConfiguration): @@ -1748,10 +1806,11 @@ def benchmark_mlir(commandline, rocmlir_gen_flags, use_rocprof=False): config = conf_class.from_command_line(commandline, arch, num_cu) - config_str = config.to_command_line() + # Use to_tuning_key() which includes fusion info for accurate DB lookups + config_key = config.to_tuning_key() if hasattr(config, 'to_tuning_key') else config.to_command_line() if tuning_db: - if (arch, config_str) in tuning_db: - config.set_perfconfig(tuning_db[arch, config_str]) + if (arch, config_key) in tuning_db: + config.set_perfconfig(tuning_db[arch, config_key]) else: # Tuning DB present but doesn't contain config, return N/A return config.table_entry(np.nan) @@ -2071,9 +2130,10 @@ def benchmark_fusion_kernels(test_dir, # Find the best perf_config best_perf = "" if tuning_db: - config_str = config.to_command_line() - if (arch, config_str) in tuning_db: - best_perf = tuning_db[arch, config_str] + # Use to_tuning_key() which includes fusion info for accurate DB lookups + config_key = config.to_tuning_key() if hasattr(config, 'to_tuning_key') else config.to_command_line() + if (arch, config_key) in tuning_db: + best_perf = tuning_db[arch, config_key] config.set_perfconfig(best_perf) else: # Tuning DB present but doesn't contain config, add a NaN entry if test_vector not in perf_results: From 9c4c9f60f9d58ac877f45766c8ef28ac2cea25ef Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 21 Nov 2025 21:01:53 +0000 Subject: [PATCH 06/17] Clang-format --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index cf56a1700f08..30caada5f0b7 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Rock/IR/AmdArchDb.h" #include "mlir/Dialect/Rock/IR/GetRockInfo.h" #include "mlir/Dialect/Rock/IR/Rock.h" @@ -22,7 +23,6 @@ #include "mlir/Dialect/Rock/Tuning/RockTuning.h" #include "mlir/Dialect/Rock/utility/fusionUtils.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/ArrayRef.h" @@ -749,16 +749,16 @@ struct ReductionInfo { ReduceMethod method; int64_t axis; bool hasPointwiseBefore; - + bool operator<(const ReductionInfo &other) const { // Sort by method first, then axis for consistent ordering if (method != other.method) return method < other.method; return axis < other.axis; } - + bool operator==(const ReductionInfo &other) const { - return method == other.method && axis == other.axis && + return method == other.method && axis == other.axis && hasPointwiseBefore == other.hasPointwiseBefore; } }; @@ -766,7 +766,7 @@ struct ReductionInfo { // Structure to hold fusion information for problem key generation struct FusionInfo { SmallVector reductions; - + bool hasReduction() const { return !reductions.empty(); } bool hasMultipleReductions() const { return reductions.size() > 1; } int numReductionOutputs() const { return reductions.size(); } @@ -778,17 +778,17 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, // Worklist: SmallVector> worklist; worklist.push_back({originalGemmResult, false}); - + // Track visited values to avoid cycles DenseSet visited; - + while (!worklist.empty()) { auto [value, hasPointwiseSoFar] = worklist.pop_back_val(); - + if (!visited.insert(value).second) { continue; // Already visited } - + for (Operation *user : value.getUsers()) { // Check for direct reduction if (auto reduceOp = dyn_cast(user)) { @@ -799,13 +799,13 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, info.reductions.push_back(redInfo); continue; } - + // Follow through rock.transform operations if (auto transformOp = dyn_cast(user)) { worklist.push_back({transformOp.getResult(), hasPointwiseSoFar}); continue; } - + // Check for linalg.generic (i.e., pointwise operations) if (auto genericOp = dyn_cast(user)) { // For memref-based linalg.generic, we need to check if this is reading @@ -818,14 +818,14 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, break; } } - + if (!isInput) { continue; } - + // Validate the output fusion SmallVector> adds; - if (failed(rock::checkValidOutputFusion(genericOp, originalGemmResult, + if (failed(rock::checkValidOutputFusion(genericOp, originalGemmResult, features, adds))) { continue; } @@ -842,32 +842,32 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, } // Analyze fusion patterns for a GEMM operation's output -static FusionInfo analyzeOuputFusionPattern(Value gemmResult, +static FusionInfo analyzeOuputFusionPattern(Value gemmResult, GemmFeatures features) { FusionInfo info; analyzeUsers(gemmResult, features, info); - + // Sort reductions for consistent ordering in problem key std::sort(info.reductions.begin(), info.reductions.end()); - + return info; } // Append fusion information to the problem key string -static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, - const FusionInfo &fusionInfo) { +static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, + const FusionInfo &fusionInfo) { constexpr char sep = ' '; - + if (!fusionInfo.hasReduction()) return; - - problemOS << sep << "-fusion_reduce" << sep << "count=" - << fusionInfo.numReductionOutputs(); - + + problemOS << sep << "-fusion_reduce" << sep + << "count=" << fusionInfo.numReductionOutputs(); + // Encode each reduction in format: method:axis[:hasPointwise] for (const auto &reduction : fusionInfo.reductions) { problemOS << sep; - + // Add reduction method switch (reduction.method) { case ReduceMethod::Sum: @@ -877,10 +877,10 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, problemOS << "max"; break; } - + // Add reduction axis with colon separator problemOS << ":axis" << reduction.axis; - + // Add pointwise flag for this specific reduction if (reduction.hasPointwiseBefore) { problemOS << ":hasPointwise"; @@ -1063,13 +1063,13 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, problemOS << "-k " << headDimQK << sep; problemOS << "-gemmO " << headDimV; } - + // Analyze and append fusion information Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get(); GemmFeatures features = rock::getFeatures(gemmGemmOp); FusionInfo fusionInfo = analyzeOuputFusionPattern(gemmGemmOutput, features); appendOutputFusionInfo(problemOS, fusionInfo); - + return success(); } From 4d9265633a4a33924deec35a2f8fca89bc5b6bab Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 25 Nov 2025 21:18:14 +0000 Subject: [PATCH 07/17] Add better tracing of reduction linalg generics --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 78 +++++++++++++++++-- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 30caada5f0b7..690abd79a9c5 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -23,8 +23,10 @@ #include "mlir/Dialect/Rock/Tuning/RockTuning.h" #include "mlir/Dialect/Rock/utility/fusionUtils.h" #include "mlir/Dialect/Rock/utility/loweringUtils.h" +#include "mlir/Dialect/Rock/utility/transformMapUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -772,18 +774,57 @@ struct FusionInfo { int numReductionOutputs() const { return reductions.size(); } }; +// Helper to collect all values that are views of the same underlying allocation +static void collectAllViewsOfAlloc(Value value, SmallVectorImpl &views, + DenseSet &visited) { + if (!visited.insert(value).second) { + return; // Already visited + } + + views.push_back(value); + + // Traverse through all users to find view-like operations + for (Operation *user : value.getUsers()) { + if (isa(user)) { + // For view-like operations, recursively collect their results + for (Value result : user->getResults()) { + collectAllViewsOfAlloc(result, views, visited); + } + } + } +} + // Helper to analyze users, following through rock.transform operations static void analyzeUsers(Value originalGemmResult, GemmFeatures features, FusionInfo &info) { - // Worklist: - SmallVector> worklist; - worklist.push_back({originalGemmResult, false}); + OpBuilder b(originalGemmResult.getContext()); + + // First, untransform to get the underlying allocation + Value underlyingAlloc; + ArrayAttr transforms; + bool needs64Bit; + std::tie(underlyingAlloc, transforms, needs64Bit) = + rock::untransform(b, originalGemmResult); + + // Collect all views (transforms) of this underlying allocation + SmallVector allViews; + DenseSet viewVisited; + collectAllViewsOfAlloc(underlyingAlloc, allViews, viewVisited); + + // Now analyze users of all these views + // Worklist: + SmallVector> worklist; + for (Value view : allViews) { + worklist.push_back({view, false, underlyingAlloc}); + } - // Track visited values to avoid cycles + // Track visited values to avoid processing cycles DenseSet visited; while (!worklist.empty()) { - auto [value, hasPointwiseSoFar] = worklist.pop_back_val(); + auto [value, hasPointwiseSoFar, currentUnderlyingAlloc] = + worklist.pop_back_val(); if (!visited.insert(value).second) { continue; // Already visited @@ -802,7 +843,8 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, // Follow through rock.transform operations if (auto transformOp = dyn_cast(user)) { - worklist.push_back({transformOp.getResult(), hasPointwiseSoFar}); + worklist.push_back({transformOp.getResult(), hasPointwiseSoFar, + currentUnderlyingAlloc}); continue; } @@ -825,7 +867,8 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, // Validate the output fusion SmallVector> adds; - if (failed(rock::checkValidOutputFusion(genericOp, originalGemmResult, + if (failed(rock::checkValidOutputFusion(genericOp, + currentUnderlyingAlloc, features, adds))) { continue; } @@ -833,7 +876,26 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, // We need to follow users of the output memref auto outputs = genericOp.getOutputs(); if (!outputs.empty() && outputs[0]) { - worklist.push_back({outputs[0], /*hasPointwiseSoFar=*/true}); + Value outputValue = outputs[0]; + + // Get the underlying allocation for this output + Value outputUnderlyingAlloc; + ArrayAttr outputTransforms; + bool outputNeeds64Bit; + std::tie(outputUnderlyingAlloc, outputTransforms, outputNeeds64Bit) = + rock::untransform(b, outputValue); + + // Collect all views of the output allocation + SmallVector outputViews; + DenseSet outputViewVisited; + collectAllViewsOfAlloc(outputUnderlyingAlloc, outputViews, + outputViewVisited); + + // Add all views of the output to the worklist + for (Value outputView : outputViews) { + worklist.push_back({outputView, /*hasPointwiseSoFar=*/true, + outputUnderlyingAlloc}); + } } continue; } From f6068c6f14234ed9a8295f5e9d874f2aecded9ca Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 25 Nov 2025 21:24:22 +0000 Subject: [PATCH 08/17] Add linalg tracing LIT test --- .../conv-reduce-trace-linalg.mlir | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir diff --git a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir new file mode 100644 index 000000000000..c93a0bcd6081 --- /dev/null +++ b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir @@ -0,0 +1,88 @@ +// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s + +// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:axis2:hasPointwise sum:axis2:hasPointwise + +#map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 128 + d2, d3, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 256 + d1, d2, d3, d4)> +#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 256 + d2, d3, d4)> +#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0, d1, d2, d4)> +#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d0, d1, d4)> +#map7 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 * 8 + d2, d3, d4)> +#map8 = affine_map<(d0, d1, d2, d3, d4) -> (d1 * 8 + d2)> +#map9 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, 0, 0)> +#map10 = affine_map<(d0, d1, d2, d3) -> (0, d0, d1, d2, d3)> +#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map12 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)> +#map13 = affine_map<(d0) -> (0, d0 floordiv 8192, (d0 mod 8192) floordiv 1024, (d0 mod 1024) floordiv 32, d0 mod 32)> +#map14 = affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 1024, (d2 mod 1024) floordiv 32, d2 mod 32)> +#map15 = affine_map<(d0) -> (0, d0, 0)> +#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>] bounds = [256, 128, 3, 3] -> [294912]> +#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 32, 32] -> [131072]> +#transform_map2 = #rock.transform_map<#map2 by [ ["n", "h", "w"] at [0, 2, 3]>, ["c"] at [1]>] bounds = [1, 1, 128, 32, 32] -> [1, 128, 32, 32]> +#transform_map3 = #rock.transform_map<#map3 by [ ["c", "y", "x"] at [1, 2, 3]>, ["k"] at [0]>] bounds = [1, 256, 128, 3, 3] -> [256, 128, 3, 3]> +#transform_map4 = #rock.transform_map<#map4 by [ ["n", "h", "w"] at [0, 2, 3]>, ["k"] at [1]>] bounds = [1, 1, 256, 32, 32] -> [1, 256, 32, 32]> +#transform_map5 = #rock.transform_map<#map5 by [ ["dim1", "dim2", "dim3", "dim0", "dim4"] at [1, 2, 3, 0, 4]>] bounds = [256, 128, 3, 1, 3] -> [1, 256, 128, 3, 3]> +#transform_map6 = #rock.transform_map<#map6 by [ ["dim2", "dim3", "dim0", "dim1", "dim4"] at [2, 3, 0, 1, 4]>] bounds = [128, 32, 1, 1, 32] -> [1, 1, 128, 32, 32]> +#transform_map7 = #rock.transform_map<#map7 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>] bounds = [1, 32, 8, 32, 32] -> [1, 256, 32, 32]> +#transform_map8 = #rock.transform_map<#map8 by [ ["dim0"] at [0]>, [] at []>, [] at []>, [] at []>] bounds = [1, 32, 8, 1, 1] -> [256]> +#transform_map9 = #rock.transform_map<#map9 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>, ["dim4"] at [4]>] bounds = [1, 32, 8, 32, 32] -> [1, 32, 8, 1, 1]> +#transform_map10 = #rock.transform_map<#map10 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>, ["dim2"] at [3]>, ["dim3"] at [4]>] bounds = [32, 8, 32, 32] -> [1, 32, 8, 32, 32]> +#transform_map11 = #rock.transform_map<#map12 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["dim2"] at [2]>, ["dim3"] at [3]>, [] at []>] bounds = [1, 32, 8, 32, 32] -> [32, 8, 32, 32]> +#transform_map12 = #rock.transform_map<#map13 by [ ["col0", "col1", "col2", "col3", "col4"] at [0, 1, 2, 3, 4]>] bounds = [262144] -> [1, 32, 8, 32, 32]> +#transform_map13 = #rock.transform_map<#map14 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, ["col2", "col3", "col4"] at [2, 3, 4]>] bounds = [1, 32, 8192] -> [1, 32, 8, 32, 32]> +#transform_map14 = #rock.transform_map<#map15 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [32] -> [1, 32, 1]> +module { + func.func @mlir_reshape_convolution_reshape_broadcast_add_mul_reshape_reduce_sum_reshape_mul_reshape_reduce_sum_reshape(%arg0: memref<131072xf16>, %arg1: memref<294912xf16>, %arg2: memref<256xf16>, %arg3: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg4: memref<32xf16> {mhal.read_access, rock.prefill = 0.000000e+00 : f16}, %arg5: memref<262144xf16>) attributes {arch = "gfx950:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 256 : i64} { + %cst = arith.constant 1.220700e-04 : f16 + %0 = rock.transform %arg1 by #transform_map : memref<294912xf16> to memref<256x128x3x3xf16> + %1 = rock.transform %arg0 by #transform_map1 : memref<131072xf16> to memref<1x128x32x32xf16> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x256x32x32xf16> + %2 = rock.transform %1 by #transform_map2 : memref<1x128x32x32xf16> to memref<1x1x128x32x32xf16> + %3 = rock.transform %0 by #transform_map3 : memref<256x128x3x3xf16> to memref<1x256x128x3x3xf16> + %4 = rock.transform %alloc by #transform_map4 : memref<1x256x32x32xf16> to memref<1x1x256x32x32xf16> + %5 = rock.transform %3 by #transform_map5 : memref<1x256x128x3x3xf16> to memref<256x128x3x1x3xf16> + %6 = rock.transform %2 by #transform_map6 : memref<1x1x128x32x32xf16> to memref<128x32x1x1x32xf16> + rock.conv(%3, %2, %4) {dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "c", "y", "x"], input_layout = ["ni", "gi", "ci", "hi", "wi"], output_layout = ["no", "go", "ko", "ho", "wo"], padding = [1 : index, 1 : index, 1 : index, 1 : index], strides = [1 : index, 1 : index]} : memref<1x256x128x3x3xf16>, memref<1x1x128x32x32xf16>, memref<1x1x256x32x32xf16> + %7 = rock.transform %alloc by #transform_map7 : memref<1x256x32x32xf16> to memref<1x32x8x32x32xf16> + %8 = rock.transform %arg2 by #transform_map8 : memref<256xf16> to memref<1x32x8x1x1xf16> + %9 = rock.transform %8 by #transform_map9 : memref<1x32x8x1x1xf16> to memref<1x32x8x32x32xf16> + %10 = rock.transform %7 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16> + %11 = rock.transform %9 by #transform_map10 : memref<1x32x8x32x32xf16> to memref<32x8x32x32xf16> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10, %11 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_0 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %in_5: f16, %out: f16): + %20 = arith.addf %in, %in_5 : f16 + linalg.yield %20 : f16 + } + %12 = rock.transform %alloc_0 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %13 = rock.transform %12 by #transform_map12 : memref<1x32x8x32x32xf16> to memref<262144xf16> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_0 : memref<32x8x32x32xf16>) outs(%alloc_1 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %out: f16): + %20 = arith.mulf %in, %cst : f16 + linalg.yield %20 : f16 + } + %14 = rock.transform %alloc_1 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %15 = rock.transform %14 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16> + rock.reduce sum %15 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16> + %16 = rock.transform %alloc_2 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x8x32x32xf16> + linalg.generic {indexing_maps = [#map11, #map11, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc_1, %alloc_0 : memref<32x8x32x32xf16>, memref<32x8x32x32xf16>) outs(%alloc_3 : memref<32x8x32x32xf16>) { + ^bb0(%in: f16, %in_5: f16, %out: f16): + %20 = arith.mulf %in, %in_5 : f16 + linalg.yield %20 : f16 + } + %17 = rock.transform %alloc_3 by #transform_map11 : memref<32x8x32x32xf16> to memref<1x32x8x32x32xf16> + %18 = rock.transform %17 by #transform_map13 : memref<1x32x8x32x32xf16> to memref<1x32x8192xf16> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<1x32x1xf16> + rock.reduce sum %18 into %alloc_4 {axis = 2 : index, blockSize = 256 : i32, gridSize = 1024 : i32} : memref<1x32x8192xf16> into memref<1x32x1xf16> + %19 = rock.transform %alloc_4 by #transform_map14 : memref<1x32x1xf16> to memref<32xf16> + memref.copy %16, %arg3 : memref<32xf16> to memref<32xf16> + memref.copy %19, %arg4 : memref<32xf16> to memref<32xf16> + memref.copy %13, %arg5 : memref<262144xf16> to memref<262144xf16> + return + } +} From 33c14bc395587fa4973d70718330c6f66b03da57 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Tue, 25 Nov 2025 21:55:59 +0000 Subject: [PATCH 09/17] More small fixes --- .../mlir/Dialect/Rock/utility/fusionUtils.h | 4 ++ .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 22 +++++-- mlir/lib/Dialect/Rock/utility/fusionUtils.cpp | 2 +- .../gemm-gemm-multi-reduce.mlir | 66 ------------------- .../gemm-multi-reduce-layernorm.mlir | 2 +- 5 files changed, 21 insertions(+), 75 deletions(-) delete mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir diff --git a/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h b/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h index 0e97832855a5..e50f58f22d5a 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h +++ b/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h @@ -54,6 +54,10 @@ checkValidOutputFusion(linalg::GenericOp genericOp, Value gemmResult, GemmFeatures features, SmallVector> &adds); +// Checks whether an operation is a valid elementwise operation for GEMM output +// fusion (used for both split-K and reduction fusion analysis). +bool validOperationGemmOut(Operation &op); + } // end namespace rock } // end namespace mlir diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 690abd79a9c5..4af8ca466a39 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -865,11 +865,19 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, continue; } - // Validate the output fusion - SmallVector> adds; - if (failed(rock::checkValidOutputFusion(genericOp, - currentUnderlyingAlloc, - features, adds))) { + // For reduction fusion detection, we just need to validate that this is + // an elementwise operation. + bool isValidElementwise = true; + Block &body = genericOp.getRegion().front(); + for (Operation &nestedOp : body.without_terminator()) { + if (!rock::validOperationGemmOut(nestedOp) && + !isa(nestedOp)) { + isValidElementwise = false; + break; + } + } + + if (!isValidElementwise) { continue; } @@ -889,12 +897,12 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, SmallVector outputViews; DenseSet outputViewVisited; collectAllViewsOfAlloc(outputUnderlyingAlloc, outputViews, - outputViewVisited); + outputViewVisited); // Add all views of the output to the worklist for (Value outputView : outputViews) { worklist.push_back({outputView, /*hasPointwiseSoFar=*/true, - outputUnderlyingAlloc}); + outputUnderlyingAlloc}); } } continue; diff --git a/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp b/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp index 68d56776e162..8aa7b144ae90 100644 --- a/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp +++ b/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp @@ -33,7 +33,7 @@ using namespace mlir; using namespace mlir::rock; using namespace arith; -bool validOperationGemmOut(Operation &op) { +bool mlir::rock::validOperationGemmOut(Operation &op) { return isa(op); } diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir deleted file mode 100644 index 4c5dc2342afb..000000000000 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-multi-reduce.mlir +++ /dev/null @@ -1,66 +0,0 @@ -// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s - -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise - -#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> -#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> -#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> -#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map4 = affine_map<(d0) -> (0, d0, 0)> -#map5 = affine_map<(d0, d1) -> (0, d0, d1)> -#map6 = affine_map<(d0, d1) -> (d0, d1)> -#map7 = affine_map<(d0, d1, d2) -> (d1, d2)> -#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> -#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> -#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> -#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> -#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> -#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> -#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> -#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> -#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> -#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> -#transform_map10 = #rock.transform_map<#map5 by [ ["col0", "col1"] at [0, 1]>, ["dim1"] at [2]>] bounds = [128, 128] -> [1, 128, 128]> -#transform_map11 = #rock.transform_map<#map7 by [ ["dim0"] at [0]>, ["dim1"] at [1]>, [] at []>] bounds = [1, 128, 128] -> [128, 128]> -module { - func.func private @gemm_gemm_multi_reduce(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}, %arg4: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0.000000e+00 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { - %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> - %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> - %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> - %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> - %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> - %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> - %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> - %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> - %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> - rock.gemm_elementwise_gemm{ - ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> - ab = elementwise { - ^bb0(%arg5: memref<1x128x256xf32>, %arg6: memref<1x128x256xf32>): - memref.copy %arg5, %arg6 : memref<1x128x256xf32> to memref<1x128x256xf32> - rock.yield - } - %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> - } {firstGemmIndices = array, storeMethod = #rock} - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> - rock.reduce sum %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> - %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> - %10 = rock.transform %alloc by #transform_map10 : memref<1x128x128xf32> to memref<128x128xf32> - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x128xf32> - linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%10 : memref<128x128xf32>) outs(%alloc_1 : memref<128x128xf32>) { - ^bb0(%in: f32, %out: f32): - %cst = arith.constant 2.0 : f32 - %13 = arith.mulf %in, %cst : f32 - linalg.yield %13 : f32 - } - %11 = rock.transform %alloc_1 by #transform_map11 : memref<128x128xf32> to memref<1x128x128xf32> - %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> - rock.reduce sum %11 into %alloc_2 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> - %12 = rock.transform %alloc_2 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> - memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> - memref.copy %12, %arg4 : memref<128xf32> to memref<128xf32> - return - } -} - diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index 20f02ca55ba2..a68ac95a6c35 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2:hasPointwise sum:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> From 0dfaa7cdd8c5243a0df22256fa846d89edaaf234 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Wed, 26 Nov 2025 19:22:55 +0000 Subject: [PATCH 10/17] Attend to copilot review comments --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 14 +++--- .../gemm-gemm-reduce-max-axis2.mlir | 49 ------------------- 2 files changed, 8 insertions(+), 55 deletions(-) delete mode 100644 mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 4af8ca466a39..f465ec5d8c2c 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -753,10 +753,12 @@ struct ReductionInfo { bool hasPointwiseBefore; bool operator<(const ReductionInfo &other) const { - // Sort by method first, then axis for consistent ordering + // Sort by method first, then axis, then hasPointwiseBefore if (method != other.method) return method < other.method; - return axis < other.axis; + if (axis != other.axis) + return axis < other.axis; + return hasPointwiseBefore < other.hasPointwiseBefore; } bool operator==(const ReductionInfo &other) const { @@ -912,8 +914,8 @@ static void analyzeUsers(Value originalGemmResult, GemmFeatures features, } // Analyze fusion patterns for a GEMM operation's output -static FusionInfo analyzeOuputFusionPattern(Value gemmResult, - GemmFeatures features) { +static FusionInfo analyzeOutputFusionPattern(Value gemmResult, + GemmFeatures features) { FusionInfo info; analyzeUsers(gemmResult, features, info); @@ -1137,7 +1139,7 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, // Analyze and append fusion information Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get(); GemmFeatures features = rock::getFeatures(gemmGemmOp); - FusionInfo fusionInfo = analyzeOuputFusionPattern(gemmGemmOutput, features); + FusionInfo fusionInfo = analyzeOutputFusionPattern(gemmGemmOutput, features); appendOutputFusionInfo(problemOS, fusionInfo); return success(); @@ -1360,7 +1362,7 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF, // Analyze and append fusion information Value gemmOutput = gemmIF.getOutArgument()->get(); GemmFeatures features = rock::getFeatures(gemmIF); - FusionInfo fusionInfo = analyzeOuputFusionPattern(gemmOutput, features); + FusionInfo fusionInfo = analyzeOutputFusionPattern(gemmOutput, features); appendOutputFusionInfo(problemOS, fusionInfo); while (out.back() == sep) { diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir deleted file mode 100644 index 0ff7280db28f..000000000000 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-max-axis2.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s - -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 max:axis - - -#map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> -#map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> -#map2 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> -#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map4 = affine_map<(d0) -> (0, d0, 0)> -#transform_map = #rock.transform_map<#map by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 256, 128] -> [32768]> -#transform_map1 = #rock.transform_map<#map1 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 64, 256] -> [16384]> -#transform_map2 = #rock.transform_map<#map2 by [ ["dim0"] at [0]>, [] at []>] bounds = [1, 128, 64] -> [8192]> -#transform_map3 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [128, 1, 64] -> [1, 128, 64]> -#transform_map4 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 128, 64] -> [128, 1, 64]> -#transform_map5 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [64, 1, 256] -> [1, 64, 256]> -#transform_map6 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 64, 256] -> [64, 1, 256]> -#transform_map7 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [256, 1, 128] -> [1, 256, 128]> -#transform_map8 = #rock.transform_map<#map3 by [ ["dim1", "dim0", "dim2"] at [1, 0, 2]>] bounds = [1, 256, 128] -> [256, 1, 128]> -#transform_map9 = #rock.transform_map<#map4 by [ ["col0", "col1", "col2"] at [0, 1, 2]>] bounds = [128] -> [1, 128, 1]> -module { - func.func private @gemm_gemm_reduce_max_axis2(%arg0: memref<8192xf32> {mhal.read_access}, %arg1: memref<16384xf32> {mhal.read_access}, %arg2: memref<32768xf32> {mhal.read_access}, %arg3: memref<128xf32> {mhal.read_access, mhal.write_access, rock.prefill = 0xFF800000 : f32}) attributes {arch = "gfx942", kernel, num_cu = 120 : i64} { - %0 = rock.transform %arg2 by #transform_map : memref<32768xf32> to memref<1x256x128xf32> - %1 = rock.transform %arg1 by #transform_map1 : memref<16384xf32> to memref<1x64x256xf32> - %2 = rock.transform %arg0 by #transform_map2 : memref<8192xf32> to memref<1x128x64xf32> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x128x128xf32> - %3 = rock.transform %2 by #transform_map3 : memref<1x128x64xf32> to memref<128x1x64xf32> - %4 = rock.transform %3 by #transform_map4 : memref<128x1x64xf32> to memref<1x128x64xf32> - %5 = rock.transform %1 by #transform_map5 : memref<1x64x256xf32> to memref<64x1x256xf32> - %6 = rock.transform %5 by #transform_map6 : memref<64x1x256xf32> to memref<1x64x256xf32> - %7 = rock.transform %0 by #transform_map7 : memref<1x256x128xf32> to memref<256x1x128xf32> - %8 = rock.transform %7 by #transform_map8 : memref<256x1x128xf32> to memref<1x256x128xf32> - rock.gemm_elementwise_gemm{ - ab = %2 * %1 : memref<1x128x64xf32>, memref<1x64x256xf32> - ab = elementwise { - ^bb0(%arg4: memref<1x128x256xf32>, %arg5: memref<1x128x256xf32>): - memref.copy %arg4, %arg5 : memref<1x128x256xf32> to memref<1x128x256xf32> - rock.yield - } - %alloc = ab * %0 : memref<1x256x128xf32> -> memref<1x128x128xf32> - } {firstGemmIndices = array, storeMethod = #rock} - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x128x1xf32> - rock.reduce max %alloc into %alloc_0 {axis = 2 : index, blockSize = 256 : i32, gridSize = 64 : i32} : memref<1x128x128xf32> into memref<1x128x1xf32> - %9 = rock.transform %alloc_0 by #transform_map9 : memref<1x128x1xf32> to memref<128xf32> - memref.copy %9, %arg3 : memref<128xf32> to memref<128xf32> - return - } -} - From 6325d506da7b70c06ab84c90687e4b8f92d35ed5 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 28 Nov 2025 18:56:14 +0000 Subject: [PATCH 11/17] Partial review comments --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 197 +++++++----------- .../gemm-multi-reduce-layernorm.mlir | 2 +- 2 files changed, 77 insertions(+), 122 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index f465ec5d8c2c..0aecbf4d272c 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/BufferDependencyAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Rock/IR/AmdArchDb.h" #include "mlir/Dialect/Rock/IR/GetRockInfo.h" @@ -772,152 +773,106 @@ struct FusionInfo { SmallVector reductions; bool hasReduction() const { return !reductions.empty(); } - bool hasMultipleReductions() const { return reductions.size() > 1; } int numReductionOutputs() const { return reductions.size(); } }; -// Helper to collect all values that are views of the same underlying allocation -static void collectAllViewsOfAlloc(Value value, SmallVectorImpl &views, - DenseSet &visited) { - if (!visited.insert(value).second) { - return; // Already visited +// Helper to get the base value (allocation or block argument) from a value +static FailureOr getBaseValue(Value v) { + FailureOr maybeAlloc = rock::findMemrefAlloc(v); + if (succeeded(maybeAlloc)) { + return maybeAlloc.value().getResult(); } - views.push_back(value); - - // Traverse through all users to find view-like operations - for (Operation *user : value.getUsers()) { - if (isa(user)) { - // For view-like operations, recursively collect their results - for (Value result : user->getResults()) { - collectAllViewsOfAlloc(result, views, visited); - } - } + FailureOr maybeBlockArg = rock::findBlockArgument(v); + if (succeeded(maybeBlockArg)) { + return maybeBlockArg.value(); } + + return failure(); } -// Helper to analyze users, following through rock.transform operations -static void analyzeUsers(Value originalGemmResult, GemmFeatures features, - FusionInfo &info) { - OpBuilder b(originalGemmResult.getContext()); - - // First, untransform to get the underlying allocation - Value underlyingAlloc; - ArrayAttr transforms; - bool needs64Bit; - std::tie(underlyingAlloc, transforms, needs64Bit) = - rock::untransform(b, originalGemmResult); - - // Collect all views (transforms) of this underlying allocation - SmallVector allViews; - DenseSet viewVisited; - collectAllViewsOfAlloc(underlyingAlloc, allViews, viewVisited); - - // Now analyze users of all these views - // Worklist: - SmallVector> worklist; - for (Value view : allViews) { - worklist.push_back({view, false, underlyingAlloc}); +// Helper to trace backwards from a value to see if it reaches the target +static std::pair +tracesToTarget(Value start, Value target, const BufferDependencyAnalysis &deps, + DenseSet &visited) { + if (!visited.insert(start).second) { + return {false, false}; // Avoid cycles } - // Track visited values to avoid processing cycles - DenseSet visited; - - while (!worklist.empty()) { - auto [value, hasPointwiseSoFar, currentUnderlyingAlloc] = - worklist.pop_back_val(); - - if (!visited.insert(value).second) { - continue; // Already visited - } - - for (Operation *user : value.getUsers()) { - // Check for direct reduction - if (auto reduceOp = dyn_cast(user)) { - ReductionInfo redInfo; - redInfo.method = reduceOp.getReduceMethod(); - redInfo.axis = reduceOp.getAxis().getSExtValue(); - redInfo.hasPointwiseBefore = hasPointwiseSoFar; - info.reductions.push_back(redInfo); - continue; - } - - // Follow through rock.transform operations - if (auto transformOp = dyn_cast(user)) { - worklist.push_back({transformOp.getResult(), hasPointwiseSoFar, - currentUnderlyingAlloc}); - continue; - } + FailureOr baseValue = getBaseValue(start); + if (failed(baseValue)) + return {false, false}; // Could not find base value - // Check for linalg.generic (i.e., pointwise operations) - if (auto genericOp = dyn_cast(user)) { - // For memref-based linalg.generic, we need to check if this is reading - // from our value. We only care about cases where our value is an input - // to the pointwise operation. - bool isInput = false; - for (Value input : genericOp.getInputs()) { - if (input == value) { - isInput = true; - break; - } - } - - if (!isInput) { - continue; - } + if (*baseValue == target) { + return {true, false}; // Found target, with no pointwise inbetween + } - // For reduction fusion detection, we just need to validate that this is - // an elementwise operation. - bool isValidElementwise = true; - Block &body = genericOp.getRegion().front(); - for (Operation &nestedOp : body.without_terminator()) { - if (!rock::validOperationGemmOut(nestedOp) && - !isa(nestedOp)) { - isValidElementwise = false; - break; - } - } - - if (!isValidElementwise) { + // For allocations, use BufferDependencyAnalysis to find writers + if (auto allocOp = baseValue->getDefiningOp()) { + std::optional> writers = deps.getWriters(allocOp); + if (writers) { + for (OpOperand *writerOperand : *writers) { + auto genericOp = dyn_cast(writerOperand->getOwner()); + if (!genericOp) { continue; } - // We need to follow users of the output memref - auto outputs = genericOp.getOutputs(); - if (!outputs.empty() && outputs[0]) { - Value outputValue = outputs[0]; - - // Get the underlying allocation for this output - Value outputUnderlyingAlloc; - ArrayAttr outputTransforms; - bool outputNeeds64Bit; - std::tie(outputUnderlyingAlloc, outputTransforms, outputNeeds64Bit) = - rock::untransform(b, outputValue); - - // Collect all views of the output allocation - SmallVector outputViews; - DenseSet outputViewVisited; - collectAllViewsOfAlloc(outputUnderlyingAlloc, outputViews, - outputViewVisited); - - // Add all views of the output to the worklist - for (Value outputView : outputViews) { - worklist.push_back({outputView, /*hasPointwiseSoFar=*/true, - outputUnderlyingAlloc}); + // Trace through inputs of the linalg.generic (assumed to be pointwise) + for (Value input : genericOp.getInputs()) { + auto [reachesTarget, hasPointwise] = + tracesToTarget(input, target, deps, visited); + if (reachesTarget) { + return {true, true}; // Found target through pointwise } } - continue; } } } + + return {false, false}; +} + +// Find all reductions and check if they trace back to our GEMM output +static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { + FusionInfo info; + + // Find the target (allocation or block argument) + FailureOr maybeTarget = getBaseValue(gemmResult); + if (failed(maybeTarget)) + return info; // None found + + Value target = *maybeTarget; + + // Get the parent function + auto defOp = gemmResult.getDefiningOp(); + auto funcOp = defOp ? rock::getParentFuncOp(defOp) : nullptr; + if (!funcOp) { + return info; + } + + // Walk all reduce operations and check if they trace back to our GEMM + BufferDependencyAnalysis deps(funcOp); + funcOp->walk([&](rock::ReduceOp reduceOp) { + DenseSet visited; + auto [reachesTarget, hasPointwise] = + tracesToTarget(reduceOp.getIn(), target, deps, visited); + + if (reachesTarget) { + ReductionInfo redInfo; + redInfo.method = reduceOp.getReduceMethod(); + redInfo.axis = reduceOp.getAxis().getSExtValue(); + redInfo.hasPointwiseBefore = hasPointwise; + info.reductions.push_back(redInfo); + } + }); + + return info; } // Analyze fusion patterns for a GEMM operation's output static FusionInfo analyzeOutputFusionPattern(Value gemmResult, GemmFeatures features) { - FusionInfo info; - analyzeUsers(gemmResult, features, info); + FusionInfo info = getFusionInfo(gemmResult, features); // Sort reductions for consistent ordering in problem key std::sort(info.reductions.begin(), info.reductions.end()); diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index a68ac95a6c35..20f02ca55ba2 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2:hasPointwise sum:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> From 8939734d9150897bd1a9d129a94ceea6a21f5f15 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 28 Nov 2025 20:44:14 +0000 Subject: [PATCH 12/17] Add reduction rank --- mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp | 15 ++++++++++----- .../conv-reduce-trace-linalg.mlir | 2 +- .../problem-key-tests/gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-reduce-sum-axis1.mlir | 2 +- .../gemm-gemm-reduce-sum-axis2.mlir | 2 +- .../problem-key-tests/gemm-mul-reduce-sum.mlir | 2 +- .../gemm-multi-reduce-different-axes.mlir | 2 +- .../gemm-multi-reduce-layernorm.mlir | 2 +- .../gemm-passthrough-and-reduce.mlir | 2 +- .../problem-key-tests/gemm-reduce-max-axis2.mlir | 2 +- .../problem-key-tests/gemm-reduce-sum-axis1.mlir | 2 +- .../problem-key-tests/gemm-reduce-sum-axis2.mlir | 2 +- 13 files changed, 22 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 0aecbf4d272c..05bcad2ce699 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -751,20 +751,23 @@ extractLayouts(Operation *op, llvm::StringMap &fLayoutMap, struct ReductionInfo { ReduceMethod method; int64_t axis; + int64_t rank; bool hasPointwiseBefore; bool operator<(const ReductionInfo &other) const { - // Sort by method first, then axis, then hasPointwiseBefore + // Sort by method first, then rank, then axis, then hasPointwiseBefore if (method != other.method) return method < other.method; + if (rank != other.rank) + return rank < other.rank; if (axis != other.axis) return axis < other.axis; - return hasPointwiseBefore < other.hasPointwiseBefore; + return hasPointwiseBefore > other.hasPointwiseBefore; } bool operator==(const ReductionInfo &other) const { return method == other.method && axis == other.axis && - hasPointwiseBefore == other.hasPointwiseBefore; + rank == other.rank && hasPointwiseBefore == other.hasPointwiseBefore; } }; @@ -861,6 +864,7 @@ static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { ReductionInfo redInfo; redInfo.method = reduceOp.getReduceMethod(); redInfo.axis = reduceOp.getAxis().getSExtValue(); + redInfo.rank = cast(reduceOp.getIn().getType()).getRank(); redInfo.hasPointwiseBefore = hasPointwise; info.reductions.push_back(redInfo); } @@ -891,7 +895,7 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, problemOS << sep << "-fusion_reduce" << sep << "count=" << fusionInfo.numReductionOutputs(); - // Encode each reduction in format: method:axis[:hasPointwise] + // Encode each reduction in format: method:rank:axis[:hasPointwise] for (const auto &reduction : fusionInfo.reductions) { problemOS << sep; @@ -905,7 +909,8 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, break; } - // Add reduction axis with colon separator + // Add rank and axis with colon separators + problemOS << ":rank" << reduction.rank; problemOS << ":axis" << reduction.axis; // Add pointwise flag for this specific reduction diff --git a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir index c93a0bcd6081..33c4fa7c99f4 100644 --- a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir +++ b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:axis2:hasPointwise sum:axis2:hasPointwise +// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:hasPointwise sum:rank3:axis2:hasPointwise #map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir index 3fa1abc907b2..3acf4caf9ed1 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir index 921d85901df0..2dd02542f35c 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir index dd06f2e9f42d..59ea78aca59b 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis1 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis1 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir index 251622a43eaa..afad3ff83b00 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:axis2 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir index b1ccc6df6914..3e5dd81c25ae 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir index 4f3f29cab340..c0498491a248 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis1 sum:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis1 sum:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index 20f02ca55ba2..0a6947514cf3 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:axis2 sum:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis2:hasPointwise sum:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir index 896f59f4fad9..401bf6b0620b 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir index c80bb39dcbb3..06c11199e7ce 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir index 7219c3504f48..0edd688c44cb 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis1 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir index 252e092de222..2ef110d7ee90 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> From dcc12314f43191a70fd1d1d747c5aaa3da1f7060 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Fri, 28 Nov 2025 20:48:52 +0000 Subject: [PATCH 13/17] Remove fusionUtil changes --- mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h | 4 ---- mlir/lib/Dialect/Rock/utility/fusionUtils.cpp | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h b/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h index e50f58f22d5a..0e97832855a5 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h +++ b/mlir/include/mlir/Dialect/Rock/utility/fusionUtils.h @@ -54,10 +54,6 @@ checkValidOutputFusion(linalg::GenericOp genericOp, Value gemmResult, GemmFeatures features, SmallVector> &adds); -// Checks whether an operation is a valid elementwise operation for GEMM output -// fusion (used for both split-K and reduction fusion analysis). -bool validOperationGemmOut(Operation &op); - } // end namespace rock } // end namespace mlir diff --git a/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp b/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp index 8aa7b144ae90..68d56776e162 100644 --- a/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp +++ b/mlir/lib/Dialect/Rock/utility/fusionUtils.cpp @@ -33,7 +33,7 @@ using namespace mlir; using namespace mlir::rock; using namespace arith; -bool mlir::rock::validOperationGemmOut(Operation &op) { +bool validOperationGemmOut(Operation &op) { return isa(op); } From 144a7c843d4a689a14f3cebd8ea0a3592708d0a8 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 1 Dec 2025 17:22:14 +0000 Subject: [PATCH 14/17] Attend to review comments --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 05bcad2ce699..7b3facfa8af0 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -795,19 +795,20 @@ static FailureOr getBaseValue(Value v) { } // Helper to trace backwards from a value to see if it reaches the target -static std::pair +// Returns success(hasPointwise) if target is reached, failure otherwise +static FailureOr tracesToTarget(Value start, Value target, const BufferDependencyAnalysis &deps, DenseSet &visited) { if (!visited.insert(start).second) { - return {false, false}; // Avoid cycles + return failure(); // Avoid cycles } FailureOr baseValue = getBaseValue(start); if (failed(baseValue)) - return {false, false}; // Could not find base value + return failure(); // Could not find base value if (*baseValue == target) { - return {true, false}; // Found target, with no pointwise inbetween + return false; // Found target, no pointwise } // For allocations, use BufferDependencyAnalysis to find writers @@ -822,17 +823,17 @@ tracesToTarget(Value start, Value target, const BufferDependencyAnalysis &deps, // Trace through inputs of the linalg.generic (assumed to be pointwise) for (Value input : genericOp.getInputs()) { - auto [reachesTarget, hasPointwise] = + FailureOr maybeHasPointwise = tracesToTarget(input, target, deps, visited); - if (reachesTarget) { - return {true, true}; // Found target through pointwise + if (succeeded(maybeHasPointwise)) { + return true; // Found target through pointwise } } } } } - return {false, false}; + return failure(); } // Find all reductions and check if they trace back to our GEMM output @@ -853,31 +854,24 @@ static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { return info; } - // Walk all reduce operations and check if they trace back to our GEMM + // Walk all reduce operations and check if they trace back to our GEMM. + // Note, we are assuming that all reduce operations are returned here. BufferDependencyAnalysis deps(funcOp); funcOp->walk([&](rock::ReduceOp reduceOp) { DenseSet visited; - auto [reachesTarget, hasPointwise] = + FailureOr maybeHasPointwise = tracesToTarget(reduceOp.getIn(), target, deps, visited); - if (reachesTarget) { + if (succeeded(maybeHasPointwise)) { ReductionInfo redInfo; redInfo.method = reduceOp.getReduceMethod(); redInfo.axis = reduceOp.getAxis().getSExtValue(); redInfo.rank = cast(reduceOp.getIn().getType()).getRank(); - redInfo.hasPointwiseBefore = hasPointwise; + redInfo.hasPointwiseBefore = *maybeHasPointwise; info.reductions.push_back(redInfo); } }); - return info; -} - -// Analyze fusion patterns for a GEMM operation's output -static FusionInfo analyzeOutputFusionPattern(Value gemmResult, - GemmFeatures features) { - FusionInfo info = getFusionInfo(gemmResult, features); - // Sort reductions for consistent ordering in problem key std::sort(info.reductions.begin(), info.reductions.end()); @@ -1099,7 +1093,7 @@ getTuningProblemStr(RockGemmGemmWrapperInterface gemmGemmOp, // Analyze and append fusion information Value gemmGemmOutput = gemmGemmOp.getOutArgument()->get(); GemmFeatures features = rock::getFeatures(gemmGemmOp); - FusionInfo fusionInfo = analyzeOutputFusionPattern(gemmGemmOutput, features); + FusionInfo fusionInfo = getFusionInfo(gemmGemmOutput, features); appendOutputFusionInfo(problemOS, fusionInfo); return success(); @@ -1322,7 +1316,7 @@ static LogicalResult getTuningProblemStr(rock::RockGemmWrapperInterface gemmIF, // Analyze and append fusion information Value gemmOutput = gemmIF.getOutArgument()->get(); GemmFeatures features = rock::getFeatures(gemmIF); - FusionInfo fusionInfo = analyzeOutputFusionPattern(gemmOutput, features); + FusionInfo fusionInfo = getFusionInfo(gemmOutput, features); appendOutputFusionInfo(problemOS, fusionInfo); while (out.back() == sep) { From ac72c153ab8a091f16fed35d96a80147c9cc80ea Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 1 Dec 2025 17:22:39 +0000 Subject: [PATCH 15/17] Clang-format --- mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 7b3facfa8af0..c9a0dcf379c7 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -766,8 +766,8 @@ struct ReductionInfo { } bool operator==(const ReductionInfo &other) const { - return method == other.method && axis == other.axis && - rank == other.rank && hasPointwiseBefore == other.hasPointwiseBefore; + return method == other.method && axis == other.axis && rank == other.rank && + hasPointwiseBefore == other.hasPointwiseBefore; } }; @@ -796,9 +796,9 @@ static FailureOr getBaseValue(Value v) { // Helper to trace backwards from a value to see if it reaches the target // Returns success(hasPointwise) if target is reached, failure otherwise -static FailureOr -tracesToTarget(Value start, Value target, const BufferDependencyAnalysis &deps, - DenseSet &visited) { +static FailureOr tracesToTarget(Value start, Value target, + const BufferDependencyAnalysis &deps, + DenseSet &visited) { if (!visited.insert(start).second) { return failure(); // Avoid cycles } From d3a6a7a8928a118ebc9953db44a4d671c24cd9cd Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 1 Dec 2025 18:01:13 +0000 Subject: [PATCH 16/17] Add strides --- .../Dialect/Rock/Tuning/RockTuningImpl.cpp | 32 ++++++++++++++++--- .../conv-reduce-trace-linalg.mlir | 2 +- .../gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-add-reduce-sum.mlir | 2 +- .../gemm-gemm-reduce-sum-axis1.mlir | 2 +- .../gemm-gemm-reduce-sum-axis2.mlir | 2 +- .../gemm-mul-reduce-sum.mlir | 2 +- .../gemm-multi-reduce-different-axes.mlir | 2 +- .../gemm-multi-reduce-layernorm.mlir | 2 +- .../gemm-passthrough-and-reduce.mlir | 2 +- .../gemm-reduce-max-axis2.mlir | 2 +- .../gemm-reduce-sum-axis1.mlir | 2 +- .../gemm-reduce-sum-axis2.mlir | 2 +- 13 files changed, 39 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index c9a0dcf379c7..854c3d39c6df 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -752,22 +752,25 @@ struct ReductionInfo { ReduceMethod method; int64_t axis; int64_t rank; + int64_t stride; // Stride of the reduction dimension bool hasPointwiseBefore; bool operator<(const ReductionInfo &other) const { - // Sort by method first, then rank, then axis, then hasPointwiseBefore + // Sort by method first, then rank, then axis, then stride, then hasPointwiseBefore if (method != other.method) return method < other.method; if (rank != other.rank) return rank < other.rank; if (axis != other.axis) return axis < other.axis; + if (stride != other.stride) + return stride < other.stride; return hasPointwiseBefore > other.hasPointwiseBefore; } bool operator==(const ReductionInfo &other) const { return method == other.method && axis == other.axis && rank == other.rank && - hasPointwiseBefore == other.hasPointwiseBefore; + stride == other.stride && hasPointwiseBefore == other.hasPointwiseBefore; } }; @@ -866,7 +869,19 @@ static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { ReductionInfo redInfo; redInfo.method = reduceOp.getReduceMethod(); redInfo.axis = reduceOp.getAxis().getSExtValue(); - redInfo.rank = cast(reduceOp.getIn().getType()).getRank(); + auto memrefType = cast(reduceOp.getIn().getType()); + redInfo.rank = memrefType.getRank(); + + // Extract stride for the reduction dimension + SmallVector strides; + int64_t offset; + if (succeeded(memrefType.getStridesAndOffset(strides, offset))) { + redInfo.stride = strides[redInfo.axis]; + } else { + // If we can't determine stride, use dynamic sentinel + redInfo.stride = ShapedType::kDynamic; + } + redInfo.hasPointwiseBefore = *maybeHasPointwise; info.reductions.push_back(redInfo); } @@ -889,7 +904,7 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, problemOS << sep << "-fusion_reduce" << sep << "count=" << fusionInfo.numReductionOutputs(); - // Encode each reduction in format: method:rank:axis[:hasPointwise] + // Encode each reduction in format: method:rank:axis:stride[:hasPointwise] for (const auto &reduction : fusionInfo.reductions) { problemOS << sep; @@ -903,9 +918,16 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, break; } - // Add rank and axis with colon separators + // Add rank, axis, and stride with colon separators problemOS << ":rank" << reduction.rank; problemOS << ":axis" << reduction.axis; + + // Add stride (use '?' for dynamic/unknown strides) + if (reduction.stride == ShapedType::kDynamic) { + problemOS << ":stride?"; + } else { + problemOS << ":stride" << reduction.stride; + } // Add pointwise flag for this specific reduction if (reduction.hasPointwiseBefore) { diff --git a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir index 33c4fa7c99f4..24fbfde2c223 100644 --- a/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir +++ b/mlir/test/fusion/problem-key-tests/conv-reduce-trace-linalg.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:hasPointwise sum:rank3:axis2:hasPointwise +// CHECK: 256 convfp16 -F 1 -f GNC01 -I NGC01 -O NGC01 -n 1 -c 128 -H 32 -W 32 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1:hasPointwise #map = affine_map<(d0, d1, d2, d3) -> (((d0 * 128 + d1) * 3 + d2) * 3 + d3)> #map1 = affine_map<(d0, d1, d2, d3) -> ((d1 * 32 + d2) * 32 + d3)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir index 3acf4caf9ed1..42ccf00c9444 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir index 2dd02542f35c..cd6dad7d7e45 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-add-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir index 59ea78aca59b..d9ca8fcd9b0b 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis1 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis1:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir index afad3ff83b00..c83e6814fac4 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2 +// CHECK: gfx942 120 -t f32 -transA false -transB false -transC false -transO false -g 1 -m 128 -n 256 -k 64 -gemmO 128 -fusion_reduce count=1 sum:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 128 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir index 3e5dd81c25ae..56da4965b50b 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-mul-reduce-sum.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:hasPointwise +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1:hasPointwise #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir index c0498491a248..93e0d793c8cd 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-different-axes.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis1 sum:rank3:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis1:stride256 sum:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir index 0a6947514cf3..224a73582433 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-multi-reduce-layernorm.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis2:hasPointwise sum:rank3:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=2 sum:rank3:axis2:stride1:hasPointwise sum:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir index 401bf6b0620b..939e4c4bb1c0 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-passthrough-and-reduce.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir index 06c11199e7ce..29e633090481 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-max-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:rank3:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 max:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir index 0edd688c44cb..b6cc84d4c4e8 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis1.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis1 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis1:stride256 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> diff --git a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir index 2ef110d7ee90..7a3d54595a59 100644 --- a/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir +++ b/mlir/test/fusion/problem-key-tests/gemm-reduce-sum-axis2.mlir @@ -1,6 +1,6 @@ // RUN: rocmlir-gen --emit-tuning-key %s | FileCheck %s -// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2 +// CHECK: gfx942 120 -t f32 -out_datatype f32 -transA false -transB false -g 1 -m 128 -n 256 -k 64 -fusion_reduce count=1 sum:rank3:axis2:stride1 #map = affine_map<(d0, d1, d2) -> (d1 * 256 + d2)> #map1 = affine_map<(d0, d1, d2) -> (d1 * 64 + d2)> From bf18a0bd9e6ec0e1efdfba0be06dadc73126d387 Mon Sep 17 00:00:00 2001 From: Justin Rosner Date: Mon, 1 Dec 2025 19:00:33 +0000 Subject: [PATCH 17/17] Git clang-format again --- mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp index 854c3d39c6df..45dbb20127c0 100644 --- a/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp +++ b/mlir/lib/Dialect/Rock/Tuning/RockTuningImpl.cpp @@ -756,7 +756,8 @@ struct ReductionInfo { bool hasPointwiseBefore; bool operator<(const ReductionInfo &other) const { - // Sort by method first, then rank, then axis, then stride, then hasPointwiseBefore + // Sort by method first, then rank, then axis, then stride, then + // hasPointwiseBefore if (method != other.method) return method < other.method; if (rank != other.rank) @@ -770,7 +771,8 @@ struct ReductionInfo { bool operator==(const ReductionInfo &other) const { return method == other.method && axis == other.axis && rank == other.rank && - stride == other.stride && hasPointwiseBefore == other.hasPointwiseBefore; + stride == other.stride && + hasPointwiseBefore == other.hasPointwiseBefore; } }; @@ -871,7 +873,7 @@ static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { redInfo.axis = reduceOp.getAxis().getSExtValue(); auto memrefType = cast(reduceOp.getIn().getType()); redInfo.rank = memrefType.getRank(); - + // Extract stride for the reduction dimension SmallVector strides; int64_t offset; @@ -881,7 +883,7 @@ static FusionInfo getFusionInfo(Value gemmResult, GemmFeatures features) { // If we can't determine stride, use dynamic sentinel redInfo.stride = ShapedType::kDynamic; } - + redInfo.hasPointwiseBefore = *maybeHasPointwise; info.reductions.push_back(redInfo); } @@ -921,7 +923,7 @@ static void appendOutputFusionInfo(llvm::raw_svector_ostream &problemOS, // Add rank, axis, and stride with colon separators problemOS << ":rank" << reduction.rank; problemOS << ":axis" << reduction.axis; - + // Add stride (use '?' for dynamic/unknown strides) if (reduction.stride == ShapedType::kDynamic) { problemOS << ":stride?";