
Commit 9b59641

CUDA: quantized KV support for FA vec (ggml-org#7527)
* CUDA: quantized KV support for FA vec
* try CI fix
* fix commented-out kernel variants
* add q8_0 q4_0 tests
* fix nwarps > batch size
* split fattn compile via extern templates
* fix flake8
* fix metal tests
* fix cmake
* make generate_cu_files.py executable
* add autogenerated .cu files
* fix AMD
* error if type_v != FP16 and not flash_attn
* remove obsolete code
1 parent a323ec6 commit 9b59641
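
The "split fattn compile via extern templates" item is the key build change: each FlashAttention kernel instantiation moves into its own autogenerated translation unit so it can be compiled (or skipped) independently. A minimal sketch of the pattern, with illustrative identifiers rather than the actual ggml-cuda symbols:

// Shared header (e.g. a hypothetical fattn-vec.cuh): declare the template
// once and tell the compiler that selected instantiations are defined in
// some other translation unit.
template <int head_size, ggml_type type_K, ggml_type type_V>
void launch_fattn_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

extern template void launch_fattn_vec<128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0>(
    ggml_backend_cuda_context & ctx, ggml_tensor * dst);

// One autogenerated file under ggml-cuda/template-instances/ then holds
// exactly one explicit instantiation, so each K/V combination is a separate
// nvcc job that the build system can include or drop:
template void launch_fattn_vec<128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0>(
    ggml_backend_cuda_context & ctx, ggml_tensor * dst);

This bounds the compile time of any single .cu file and is what makes the selective globbing in the CMakeLists.txt and Makefile diffs below possible.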

File tree

110 files changed: +2649 additions, -1152 deletions


CMakeLists.txt

Lines changed: 30 additions & 0 deletions
@@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                     "llama: max. batch size for using peer access")
 option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
 option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
+option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile all quants for FlashAttention" OFF)
 
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
@@ -402,6 +403,8 @@ if (LLAMA_CUDA)
 
     file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
     list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
+    file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
 
     add_compile_definitions(GGML_USE_CUDA)
     add_compile_definitions(GGML_CUDA_USE_GRAPHS)
@@ -427,6 +430,18 @@ if (LLAMA_CUDA)
     if (LLAMA_CUDA_NO_PEER_COPY)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()
+    if (LLAMA_CUDA_FA_ALL_QUANTS)
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    else()
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    endif()
 
     if (LLAMA_STATIC)
         if (WIN32)
@@ -571,6 +586,8 @@ if (LLAMA_HIPBLAS)
 
     file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
     list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
+    file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+    list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
     add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
 
@@ -590,6 +607,19 @@ if (LLAMA_HIPBLAS)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()
 
+    if (LLAMA_CUDA_FA_ALL_QUANTS)
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+        list(APPEND GGML_SOURCES_ROCM ${SRCS})
+        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    else()
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+        list(APPEND GGML_SOURCES_ROCM ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+        list(APPEND GGML_SOURCES_ROCM ${SRCS})
+        file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+        list(APPEND GGML_SOURCES_ROCM ${SRCS})
+    endif()
+
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
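
The GGML_CUDA_FA_ALL_QUANTS definition added here is the compile-time switch the CUDA code can branch on when choosing a kernel. A hedged sketch of the consuming side (the dispatcher name is an assumption, not the in-tree symbol):

// Illustrative dispatcher: only instantiations that were actually compiled
// in (per the globs above) may be referenced, so the fallback build can only
// serve K/V pairs of the same, supported type.
static void fattn_vec_dispatch(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
#ifdef GGML_CUDA_FA_ALL_QUANTS
    // every type_K x type_V combination has a compiled instantiation
#else
    GGML_ASSERT(K->type == V->type); // only q4_0/q4_0, q8_0/q8_0, f16/f16 exist
#endif // GGML_CUDA_FA_ALL_QUANTS
    // ... select and launch the matching instantiation ...
}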

Makefile

Lines changed: 18 additions & 3 deletions
@@ -421,6 +421,15 @@ ifdef LLAMA_CUBLAS
 	LLAMA_CUDA := 1
 endif
 
+OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
+ifdef LLAMA_CUDA_FA_ALL_QUANTS
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
+else
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
+	OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+endif # LLAMA_CUDA_FA_ALL_QUANTS
+
 ifdef LLAMA_CUDA
 	ifneq ('', '$(wildcard /opt/cuda)')
 		CUDA_PATH ?= /opt/cuda
@@ -431,6 +440,7 @@ ifdef LLAMA_CUDA
 	MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
 	OBJS += ggml-cuda.o
 	OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+	OBJS += $(OBJS_CUDA_TEMP_INST)
 	MK_NVCCFLAGS += -use_fast_math
 ifdef LLAMA_FATAL_WARNINGS
 	MK_NVCCFLAGS += -Werror all-warnings
@@ -493,7 +503,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
 endif # LLAMA_CUDA_NO_PEER_COPY
 ifdef LLAMA_CUDA_CCBIN
 	MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
-endif
+endif # LLAMA_CUDA_CCBIN
+ifdef LLAMA_CUDA_FA_ALL_QUANTS
+	MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
+endif # LLAMA_CUDA_FA_ALL_QUANTS
 
 ifdef JETSON_EOL_MODULE_DETECT
 define NVCC_COMPILE
@@ -505,7 +518,7 @@ define NVCC_COMPILE
 endef # NVCC_COMPILE
 endif # JETSON_EOL_MODULE_DETECT
 
-ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
 	$(NVCC_COMPILE)
 
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
@@ -585,11 +598,12 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
 endif # LLAMA_CUDA_NO_PEER_COPY
 	OBJS += ggml-cuda.o
 	OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+	OBJS += $(OBJS_CUDA_TEMP_INST)
 
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
 	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
 
-ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh
 	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
 
 endif # LLAMA_HIPBLAS
@@ -749,6 +763,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 clean:
 	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	rm -vrf ggml-cuda/*.o
+	rm -vrf ggml-cuda/template-instances/*.o
 	find examples pocs -type f -name "*.o" -delete
 
 #
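
Note the relaxed pattern rule: ggml-cuda/%.o no longer depends on a matching ggml-cuda/%.cuh, because the autogenerated sources under ggml-cuda/template-instances/ have no companion headers. A hedged sketch of what such a generated file reduces to (file and macro names are illustrative, not verified against the tree):

// e.g. template-instances/fattn-vec-f16-q8_0-q8_0.cu (illustrative name),
// emitted by generate_cu_files.py: include the shared header, expand one
// instantiation, nothing else -- hence no per-file .cuh to depend on.
#include "../fattn-vec-f16.cuh"

DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);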

README.md

Lines changed: 1 addition & 0 deletions
@@ -501,6 +501,7 @@ Building the program with BLAS support may lead to some performance improvements
 | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
 | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
+| LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
 
 - #### hipBLAS
 
ggml-cuda.cu

Lines changed: 6 additions & 2 deletions
@@ -2905,10 +2905,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
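
In this check, op->src[0] is Q, op->src[1] is K, and op->src[2] is V, and ne[0] is the attention head size. A standalone restatement of the new non-AMD rule (hedged sketch, not the in-tree helper):

// Summarizes the FLASH_ATTN_EXT support check after this change.
static bool fattn_ext_supported(const ggml_tensor * op, int cc) {
    const int64_t head_size = op->src[0]->ne[0];  // Q head size
    if (head_size == 128) {
        return true;  // vec kernels: any K/V type, quantized included
    }
    if (head_size == 64 && op->src[1]->type == GGML_TYPE_F16) {
        return true;  // head size 64 still requires an F16 K cache
    }
    // other head sizes: WMMA kernel, needs tensor cores and F16 K/V
    return cc >= CC_VOLTA &&
           op->src[1]->type == GGML_TYPE_F16 &&
           op->src[2]->type == GGML_TYPE_F16;
}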
