 #include "common.cuh"
 
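+// Fragment of matrix A for int8 MMA: I = 16 rows by K = 4 packed int32 columns
+// (16 int8 values per row), held in ne = 2 registers per thread of a 32-thread warp.
+// get_i()/get_k() map a register index l and the lane id to the row/column that lane holds.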
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_A_I16K8 {
     static constexpr int I  = 16;
     static constexpr int K  = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
     }
 };
 
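+// Fragment of matrix B for int8 MMA: J = 8 columns by K = 4 packed int32 rows
+// (a 16x8 int8 tile), held in ne = 1 register per thread of a 32-thread warp.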
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret < K);
+        return ret;
+    }
+};
+
 struct mma_int_B_J8K8 {
     static constexpr int J  = 8;
     static constexpr int K  = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
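+    // x += A*B: multiply the 16x16 int8 A fragment by the 16x8 int8 B fragment
+    // and accumulate into this 16x8 int32 C tile.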
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
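+        // A single m16n8k16 instruction covers the whole 16x16 * 16x8 int8 tile product: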
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= CC_AMPERE