diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cc5aa422..a3346711 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -186,9 +186,8 @@ jobs: defines: '-DRWKV_AVX512=ON' - build: 'cuda12' defines: '-DRWKV_CUBLAS=ON' - - build: 'rocm5.5' - defines: '-G "Unix Makefiles" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DRWKV_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030"' - + - build: 'hip' + defines: '' steps: - name: Clone id: checkout @@ -206,25 +205,52 @@ jobs: - name: Install rocm-toolkit id: rocm-toolkit - if: ${{ matrix.build == 'rocm5.5' }} - uses: Cyberhan123/rocm-toolkit@v0.1.0 - with: - rocm: '5.5.0' + if: ${{ matrix.build == 'hip' }} + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading AMD HIP SDK Installer" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + write-host "Installing AMD HIP SDK" + Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + write-host "Completed AMD HIP SDK installation" + + - name: Verify ROCm + id: rocm-verify + if: ${{ matrix.build == 'hip' }} + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version - name: Install Ninja id: install-ninja - if: ${{ matrix.build == 'rocm5.5' }} + if: ${{ matrix.build == 'hip' }} uses: urkle/action-get-ninja@v1 with: version: 1.11.1 + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ github.job }} + - name: Build id: cmake_build + if: ${{ matrix.build != 'hip' }} run: | mkdir build cd build cmake .. ${{ matrix.defines }} - cmake --build . --config Release + cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} + + - name: Build-hip + id: cmake_build_hip + if: ${{ matrix.build == 'hip' }} + run: | + mkdir build + cd build + $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + cmake .. -G "Unix Makefiles" -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DRWKV_HIPBLAS=ON -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release + cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} - name: Check AVX512F support id: check_avx512f @@ -242,7 +268,7 @@ jobs: - name: Test id: cmake_test # Test AVX-512 only when possible - if: ${{ (matrix.build != 'avx512' || env.HAS_AVX512F == '1') && matrix.build != 'cuda12' && matrix.build != 'rocm5.5'}} + if: ${{ (matrix.build != 'avx512' || env.HAS_AVX512F == '1') && matrix.build != 'cuda12' && matrix.build != 'hip'}} run: | cd build ctest -C Release --verbose diff --git a/CMakeLists.txt b/CMakeLists.txt index 56c057c2..8b9c9710 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,7 +58,7 @@ endfunction() set(GGML_ACCELERATE ${RWKV_ACCELERATE}) set(GGML_CUDA ${RWKV_CUBLAS}) -set(GGML_HIPBLAS ${RWKV_HIPBLAS}) +set(GGML_HIP ${RWKV_HIPBLAS}) set(GGML_METAL ${RWKV_METAL}) if (RWKV_OPENBLAS) set(GGML_BLAS_VENDOR "OpenBLAS") @@ -107,6 +107,7 @@ if (RWKV_ALL_WARNINGS) -Wcast-qual -Wno-unused-function -Wno-multichar + -Wno-nonnull ) else() set(c_flags @@ -234,7 +235,7 @@ if (GGML_METAL) ) endif() -if (GGML_HIPBLAS) +if (GGML_HIP) # CMake on Windows doesn't support the HIP language yet if (WIN32) set(CXX_IS_HIPCC TRUE) @@ -262,12 +263,39 @@ if (GGML_HIPBLAS) endif() target_include_directories(rwkv PUBLIC .) 
-target_include_directories(rwkv PRIVATE ggml/include) +target_include_directories(rwkv PRIVATE ggml/include ggml/src) target_compile_features(rwkv PUBLIC cxx_std_11) -target_link_libraries(rwkv PRIVATE $<TARGET_OBJECTS:ggml> ${RWKV_EXTRA_LIBS}) + +if (GGML_METAL) + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-metal> $<TARGET_OBJECTS:ggml-blas>) +endif() +if (GGML_CUDA) + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-cuda>) +endif() +if (GGML_HIP) + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-hip>) +endif() +if (GGML_RPC) + set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} $<TARGET_OBJECTS:ggml-rpc>) +endif() + +target_link_libraries(rwkv PRIVATE $<TARGET_OBJECTS:ggml> $<TARGET_OBJECTS:ggml-base> $<TARGET_OBJECTS:ggml-cpu> ${RWKV_EXTRA_LIBS}) if (RWKV_BUILD_SHARED_LIBRARY) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(ggml-base PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(ggml-cpu PROPERTIES POSITION_INDEPENDENT_CODE ON) + if (GGML_METAL) + set_target_properties(ggml-metal PROPERTIES POSITION_INDEPENDENT_CODE ON) + set_target_properties(ggml-blas PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + if (GGML_CUDA) + set_target_properties(ggml-cuda PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + if (GGML_HIP) + set_target_properties(ggml-hip PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + target_compile_definitions(ggml PRIVATE GGML_SHARED GGML_BUILD) set_target_properties(rwkv PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD) diff --git a/README.md b/README.md index 134b4c43..ac7767cb 100644 --- a/README.md +++ b/README.md @@ -6,20 +6,18 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it. -[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts. +[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture. In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths. -[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported. - -[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to RWKV architecture, with better quality. RWKV v6 models are supported. +This project supports RWKV [v4](https://huggingface.co/BlinkDL/rwkv-4-pile-14b), [v5](https://huggingface.co/BlinkDL/rwkv-5-world), [v6](https://huggingface.co/BlinkDL/rwkv-6-world) and the latest [v7](https://huggingface.co/BlinkDL/rwkv-7-world) architectures. Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py). + + ## Quality and performance If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you. -In general, **`RWKV v5` models are as fast as `RWKV v4` models**, with minor differencies in latency and memory consumption, and with having way higher quality than `v4`.
Therefore, it is recommended to use `RWKV v5`. - Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads. The models are `RWKV v4 Pile 169M`, `RWKV v4 Pile 1.5B`. | Format | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) | diff --git a/extras/quantize.c b/extras/quantize.c index 578e632c..33e7ed32 100644 --- a/extras/quantize.c +++ b/extras/quantize.c @@ -25,8 +25,10 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount); static enum ggml_type type_from_string(const char * string) { if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0; if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1; + if (strcmp(string, "Q4_K") == 0) return GGML_TYPE_Q4_K; if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0; if (strcmp(string, "Q5_1") == 0) return GGML_TYPE_Q5_1; + if (strcmp(string, "Q5_K") == 0) return GGML_TYPE_Q5_K; if (strcmp(string, "Q8_0") == 0) return GGML_TYPE_Q8_0; return GGML_TYPE_COUNT; } diff --git a/ggml b/ggml index 3e7e5e26..c8bd0fee 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 3e7e5e26f90fecf4f7c2808df7d94454630b219c +Subproject commit c8bd0fee71dc8328d93be301bbee06bc10d30429 diff --git a/python/chat_with_bot.py b/python/chat_with_bot.py index 17fd1c17..6c6d2a8b 100644 --- a/python/chat_with_bot.py +++ b/python/chat_with_bot.py @@ -40,6 +40,7 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') parser.add_argument('model_path', help='Path to RWKV model in ggml format') +parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU') add_tokenizer_argument(parser) args = parser.parse_args() @@ -48,7 +49,7 @@ with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file: prompt_data = json.load(json_file) - user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt'] + user, assistant, separator, init_prompt = prompt_data['user'], prompt_data['assistant'], prompt_data['separator'], prompt_data['prompt'] if init_prompt == '': raise ValueError('Prompt must not be empty') @@ -57,7 +58,7 @@ print(f'System info: {library.rwkv_get_system_info_string()}') print('Loading RWKV model') -model = rwkv_cpp_model.RWKVModel(library, args.model_path) +model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layer_count=args.num_gpu_layers) tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) @@ -154,7 +155,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: if msg == '+reset': load_thread_state('chat_init') save_thread_state('chat') - print(f'{bot}{separator} Chat reset.\n') + print(f'{assistant}{separator} Chat reset.\n') continue elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++': @@ -194,7 +195,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: load_thread_state('chat_init') real_msg = msg[4:].strip() - new = f'{user}{separator} {real_msg}\n\n{bot}{separator}' + new = f'{user}{separator} {real_msg}\n\n{assistant}{separator}' process_tokens(tokenizer_encode(new)) save_thread_state('gen_0') @@ -225,17 +226,17 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: except Exception as e: print(e) continue - # chat with bot + # chat with assistant else: load_thread_state('chat') - new = f'{user}{separator} {msg}\n\n{bot}{separator}' + new = 
f'{user}{separator} {msg}\n\n{assistant}{separator}' process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999) save_thread_state('chat_pre') thread = 'chat' - # Print bot response - print(f'> {bot}{separator}', end='') + # Print assistant response + print(f'> {assistant}{separator}', end='') start_index: int = len(processed_tokens) accumulated_tokens: List[int] = [] diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 2c413dca..5bcd0088 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -35,8 +35,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict + is_v7_0: bool = 'blocks.0.att.k_k' in state_dict - if is_v6_0: + if is_v7_0: + print('Detected RWKV v7.0') + elif is_v6_0: print('Detected RWKV v6.0') elif is_v5_2: print('Detected RWKV v5.2') @@ -45,6 +48,23 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t else: print('Detected RWKV v4') + if is_v7_0: + # concat to reduce some cpu overhead during ggml inference + state_dict_new = {} + for k in state_dict.keys(): + if 'att.x_' in k: + l = int(k.split('.')[1].split('.')[0]) + try: + state_dict_new[f'blocks.{l}.att.x_rwkvag'] = torch.cat( + [state_dict_new[f'blocks.{l}.att.x_rwkvag'], state_dict[k]], dim=0) + except KeyError: + state_dict_new[f'blocks.{l}.att.x_rwkvag'] = state_dict[k] + else: + state_dict_new[k] = state_dict[k] + + del state_dict[k] + state_dict = state_dict_new + with open(dest_path, 'wb') as out_file: is_FP16: bool = data_type == 'FP16' or data_type == 'float16' @@ -68,7 +88,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if '.time_' in k: tensor = tensor.squeeze() - if is_v6_0: + if is_v7_0: + if any(s in k for s in [ + '.w1', '.w2', + '.a1', '.a2', + '.v1', '.v2', + '.g1', '.g2', + ]): + tensor = tensor.transpose(0, 1) + + elif is_v6_0: if '.time_faaaa' in k: tensor = tensor.unsqueeze(-1) if '.time_maa_w1' in k or '.time_decay_w' in k: @@ -95,7 +124,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t tensor = -torch.exp(tensor) # Keep 1-dim vectors and small matrices in FP32 - if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k: + if is_FP16 and len(tensor.shape) > 1 and all( + s not in k for s in [ + '.time_', + '.k_k', '.k_a', '.r_k', + '.x_rwkvag', '.x_k', + '.w0', '.a0', '.v0', + ] + ): tensor = tensor.half() shape = tensor.shape diff --git a/python/generate_completions.py b/python/generate_completions.py index f0d9aa14..cc14ef1d 100644 --- a/python/generate_completions.py +++ b/python/generate_completions.py @@ -29,6 +29,7 @@ parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt') parser.add_argument('model_path', help='Path to RWKV model in ggml format') +parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU') add_tokenizer_argument(parser) args = parser.parse_args() @@ -39,7 +40,7 @@ print(f'System info: {library.rwkv_get_system_info_string()}') print('Loading RWKV model') -model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layers_count=0) +model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layers_count=args.num_gpu_layers) tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, 
model.n_vocab) diff --git a/python/inference_example.py b/python/inference_example.py index efd0016e..f8f1dbf1 100644 --- a/python/inference_example.py +++ b/python/inference_example.py @@ -10,12 +10,13 @@ # Parse received arguments. parser = argparse.ArgumentParser(description='Generate some text with an RWKV model') parser.add_argument('model_path', help='Path to RWKV model in ggml format') +parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU') add_tokenizer_argument(parser) args = parser.parse_args() # Load the model. library = rwkv_cpp_shared_library.load_rwkv_shared_library() -model = rwkv_cpp_model.RWKVModel(library, args.model_path) +model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layer_count=args.num_gpu_layers) # Set up the tokenizer. tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) diff --git a/python/prompt/Chinese-Chat.json b/python/prompt/Chinese-Chat.json index 73acae78..03c7f84e 100644 --- a/python/prompt/Chinese-Chat.json +++ b/python/prompt/Chinese-Chat.json @@ -1,6 +1,6 @@ { "user": "Bob", - "bot": "Alice", + "assistant": "Alice", "separator": ":", "prompt": "\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: lhc\n\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。\n\nBob: 企鹅会飞吗\n\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\n\n" } \ No newline at end of file diff --git a/python/prompt/Chinese-QA.json b/python/prompt/Chinese-QA.json index 501902fb..bedd6b54 100644 --- a/python/prompt/Chinese-QA.json +++ b/python/prompt/Chinese-QA.json @@ -1,6 +1,6 @@ { - "user": "Q", - "bot": "A", + "user": "User", + "assistant": "Assistant", "separator": ":", - "prompt": "\nExpert Questions & Helpful Answers\n\nAsk Research Experts\n\n" + "prompt": "User: 你好\n\nAssistant: 你好,有什么我可以帮助你的吗?\n\n" } \ No newline at end of file diff --git a/python/prompt/English-Chat.json b/python/prompt/English-Chat.json index bd9a4408..edae01b5 100644 --- a/python/prompt/English-Chat.json +++ b/python/prompt/English-Chat.json @@ -1,6 +1,6 @@ { "user": "Bob", - "bot": "Alice", + "assistant": "Alice", "separator": ":", "prompt": "\nThe following is a coherent verbose detailed conversation between a girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: Hello Alice, how are you doing?\n\nAlice: Hi! Thanks, I'm fine. What about you?\n\nBob: I am fine. It's nice to see you. Look, here is a store selling tea and juice.\n\nAlice: Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!\n\nBob: What is it?\n\nAlice: Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.\n\nBob: Sounds tasty. I'll try it next time. Would you like to chat with me for a while?\n\nAlice: Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. 
So please go ahead!\n\n" } \ No newline at end of file diff --git a/python/prompt/English-QA.json b/python/prompt/English-QA.json index 71e6bcaf..274942b7 100644 --- a/python/prompt/English-QA.json +++ b/python/prompt/English-QA.json @@ -1,6 +1,6 @@ { "user": "User", - "bot": "Bot", + "assistant": "Assistant", "separator": ":", - "prompt": "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n" + "prompt": "User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n" } \ No newline at end of file diff --git a/python/prompt/Japanese-Chat.json b/python/prompt/Japanese-Chat.json index 7061d6ae..effb2ab0 100644 --- a/python/prompt/Japanese-Chat.json +++ b/python/prompt/Japanese-Chat.json @@ -1,6 +1,6 @@ { "user": "Bob", - "bot": "Alice", + "assistant": "Alice", "separator": ":", "prompt": "\n以下は、Aliceという女の子とその友人Bobの間で行われた会話です。 Aliceはとても賢く、想像力があり、友好的です。 AliceはBobに反対することはなく、AliceはBobに質問するのは苦手です。 AliceはBobに自分のことや自分の意見をたくさん伝えるのが好きです。 AliceはいつもBobに親切で役に立つ、有益なアドバイスをしてくれます。\n\nBob: こんにちはAlice、調子はどうですか?\n\nAlice: こんにちは!元気ですよ。あたなはどうですか?\n\nBob: 元気ですよ。君に会えて嬉しいよ。見て、この店ではお茶とジュースが売っているよ。\n\nAlice: 本当ですね。中に入りましょう。大好きなモカラテを飲んでみたいです!\n\nBob: モカラテって何ですか?\n\nAlice: モカラテはエスプレッソ、ミルク、チョコレート、泡立てたミルクから作られた飲み物です。香りはとても甘いです。\n\nBob: それは美味しそうですね。今度飲んでみます。しばらく私とおしゃべりしてくれますか?\n\nAlice: もちろん!ご質問やアドバイスがあれば、喜んでお答えします。専門的な知識には自信がありますよ。どうぞよろしくお願いいたします!\n\n" } \ No newline at end of file diff --git a/python/prompt/Japanese-QA.json b/python/prompt/Japanese-QA.json index dece0078..ddb94ee6 100644 --- a/python/prompt/Japanese-QA.json +++ b/python/prompt/Japanese-QA.json @@ -1,6 +1,6 @@ { "user": "User", - "bot": "Bot", + "assistant": "Assistant", "separator": ":", - "prompt": "\n以下は、Botと呼ばれるAIアシスタントとUserと呼ばれる人間との間で行われた会話です。Botは知的で、知識が豊富で、賢くて、礼儀正しいです。\n\nUser: フランス革命は何年に起きましたか?\n\nBot: フランス革命は1789年に始まり、1799年まで10年間続きました。\n\nUser: 3+5=?\n\nBot: 答えは8です。\n\nUser: 私は誰と結婚すると思いますか?\n\nBot: あなたのことをもっと教えていただけないとお答えすることができません。\n\nUser: aの値を求めてください: 9-a=2\n\nBot: a = 7です、なぜなら 9 - 7 = 2だからです。\n\nUser: lhcって何ですか?\n\nBot: LHCは、CERNが建設し、2008年に完成した高エネルギー粒子衝突型加速器です。2012年にヒッグス粒子の存在を確認するために使用されました。\n\n" + "prompt": "\n以下は、Assistantと呼ばれるAIアシスタントとUserと呼ばれる人間との間で行われた会話です。Assistantは知的で、知識が豊富で、賢くて、礼儀正しいです。\n\nUser: フランス革命は何年に起きましたか?\n\nAssistant: フランス革命は1789年に始まり、1799年まで10年間続きました。\n\nUser: 3+5=?\n\nAssistant: 答えは8です。\n\nUser: 私は誰と結婚すると思いますか?\n\nAssistant: あなたのことをもっと教えていただけないとお答えすることができません。\n\nUser: aの値を求めてください: 9-a=2\n\nAssistant: a = 7です、なぜなら 9 - 7 = 2だからです。\n\nUser: lhcって何ですか?\n\nAssistant: LHCは、CERNが建設し、2008年に完成した高エネルギー粒子衝突型加速器です。2012年にヒッグス粒子の存在を確認するために使用されました。\n\n" } \ No newline at end of file diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py index 3f59b2ed..a42dec5b 100644 --- a/python/rwkv_cpp/rwkv_cpp_shared_library.py +++ 
b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -8,8 +8,10 @@ QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = ( 'Q4_0', 'Q4_1', + 'Q4_K', 'Q5_0', 'Q5_1', + 'Q5_K', 'Q8_0' ) diff --git a/rwkv.cpp b/rwkv.cpp index 08ede0e4..5ccc97b7 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -2,6 +2,9 @@ #include "ggml.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml-impl.h" + +#include "ggml-cpu.h" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -62,10 +65,6 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit #include "rwkv_operators.inc" -#include "rwkv_operators_wkv_v5.inc" - -#include "rwkv_operators_wkv_v6.inc" - #include "rwkv_graph.inc" // API function. @@ -91,7 +90,6 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t #ifdef GGML_USE_METAL backend = ggml_backend_metal_init(); RWKV_ENSURE_OR_NULL(backend); - ggml_backend_metal_set_n_cb(backend, ctx->n_threads); #endif #ifdef GGML_USE_BLAS @@ -109,7 +107,12 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t ggml_backend_cpu_set_n_threads(cpu_backend, n_threads); ctx->model->backends.push_back(cpu_backend); - RWKV_ENSURE_OR_NULL(rwkv_load_model_from_file(file_path, *ctx->model, n_gpu_layers)); + int ngl = n_gpu_layers; + if (ctx->model->backends.size() == 1) { + ngl = 0; + } + + RWKV_ENSURE_OR_NULL(rwkv_load_model_from_file(file_path, *ctx->model, ngl)); RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*ctx->model, ctx->serial_graph)); @@ -208,7 +211,7 @@ void rwkv_free(struct rwkv_context * ctx) { ggml_free(ctx->sequential_graph.ggml_ctx); } - std::unique_ptr<struct rwkv_context> rwkv_ctx(ctx); + delete ctx; } // API function. @@ -247,7 +250,6 @@ const char * rwkv_get_system_info_string(void) { s += "F16C=" + std::to_string(ggml_cpu_has_f16c()) + " "; s += "FP16_VA=" + std::to_string(ggml_cpu_has_fp16_va()) + " "; s += "WASM_SIMD=" + std::to_string(ggml_cpu_has_wasm_simd()) + " "; - s += "BLAS=" + std::to_string(ggml_cpu_has_blas()) + " "; s += "SSE3=" + std::to_string(ggml_cpu_has_sse3()) + " "; s += "VSX=" + std::to_string(ggml_cpu_has_vsx()); } diff --git a/rwkv_eval.inc b/rwkv_eval.inc index 215b34e1..be2e445c 100644 --- a/rwkv_eval.inc +++ b/rwkv_eval.inc @@ -22,7 +22,7 @@ static void rwkv_get_outputs(const struct rwkv_computation_graph & graph, float } // Evaluates a computation graph, optionally skipping logit computation. -static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_t n_threads, const bool compute_logits) { +static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const bool compute_logits) { if (!compute_logits) { graph.cgraph->n_nodes = graph.pre_logits_nodes; graph.cgraph->n_leafs = graph.pre_logits_leafs; @@ -31,7 +31,7 @@ static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_ graph.cgraph->n_leafs = graph.post_logits_leafs; } - ggml_backend_sched_graph_compute(graph.sched, graph.cgraph.get()); + ggml_backend_sched_graph_compute(graph.sched, graph.cgraph); } // API function.
@@ -45,7 +45,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st if (!ctx->serial_graph.sched) { ctx->serial_graph.sched = ggml_backend_sched_new(ctx->model->backends.data(), NULL, ctx->model->backends.size(), RWKV_MAX_NODES, false); - auto graph = ctx->serial_graph.cgraph.get(); + auto graph = ctx->serial_graph.cgraph; for (int i = 0; i < graph->n_nodes; i++) { auto node = graph->nodes[i]; if (std::string(node->name).find(".in.") != std::string::npos || @@ -62,13 +62,13 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st } ggml_backend_sched_set_tensor_backend(ctx->serial_graph.sched, ctx->serial_graph.tokens, ctx->model->backends.back()); - ggml_backend_sched_alloc_graph(ctx->serial_graph.sched, ctx->serial_graph.cgraph.get()); + ggml_backend_sched_alloc_graph(ctx->serial_graph.sched, ctx->serial_graph.cgraph); } rwkv_set_inputs(ctx, ctx->serial_graph, state_in); ggml_backend_tensor_set(ctx->serial_graph.tokens, &token, 0, rwkv_tensor_nbytes(ctx->serial_graph.tokens)); - rwkv_eval_graph(ctx->serial_graph, ctx->n_threads, logits_out != NULL); + rwkv_eval_graph(ctx->serial_graph, logits_out != NULL); rwkv_get_outputs(ctx->serial_graph, state_out, logits_out); @@ -122,7 +122,7 @@ bool rwkv_eval_sequence( if (sequence) { if (!ctx->sequential_graph.sched) { ctx->sequential_graph.sched = ggml_backend_sched_new(ctx->model->backends.data(), NULL, ctx->model->backends.size(), RWKV_MAX_NODES, false); - auto graph = ctx->sequential_graph.cgraph.get(); + auto graph = ctx->sequential_graph.cgraph; for (int i = 0; i < graph->n_nodes; i++) { auto node = graph->nodes[i]; @@ -140,13 +140,13 @@ bool rwkv_eval_sequence( } ggml_backend_sched_set_tensor_backend(ctx->sequential_graph.sched, ctx->sequential_graph.tokens, ctx->model->backends.back()); - ggml_backend_sched_alloc_graph(ctx->sequential_graph.sched, ctx->sequential_graph.cgraph.get()); + ggml_backend_sched_alloc_graph(ctx->sequential_graph.sched, ctx->sequential_graph.cgraph); } rwkv_set_inputs(ctx, ctx->sequential_graph, state_in); ggml_backend_tensor_set(ctx->sequential_graph.tokens, sequence, 0, sequence_len * sizeof(uint32_t)); - rwkv_eval_graph(ctx->sequential_graph, ctx->n_threads, logits_out != NULL); + rwkv_eval_graph(ctx->sequential_graph, logits_out != NULL); rwkv_get_outputs(ctx->sequential_graph, state_out, logits_out); } @@ -178,7 +178,7 @@ bool rwkv_eval_sequence_in_chunks( size_t chunk_count = sequence_len / chunk_size; size_t remainder = sequence_len % chunk_size; - uint32_t * tokens_offset = (uint32_t *) tokens; + const uint32_t * tokens_offset = tokens; for (size_t c = 0; c < chunk_count; c++) { bool is_last_eval = c == chunk_count - 1 && remainder == 0; diff --git a/rwkv_file_format.inc b/rwkv_file_format.inc index 4c08ea33..1917572c 100644 --- a/rwkv_file_format.inc +++ b/rwkv_file_format.inc @@ -13,6 +13,13 @@ enum rwkv_type { TYPE_Q5_0, TYPE_Q5_1, TYPE_Q8_0, + TYPE_Q8_1, + TYPE_Q2_K, + TYPE_Q3_K, + TYPE_Q4_K, + TYPE_Q5_K, + TYPE_Q6_K, + TYPE_Q8_K, TYPE_COUNT }; @@ -29,6 +36,13 @@ static const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { GGML_TYPE_Q5_0, /* Q5_0 */ GGML_TYPE_Q5_1, /* Q5_1 */ GGML_TYPE_Q8_0, /* Q8_0 */ + GGML_TYPE_Q8_1, /* Q8_1 */ + GGML_TYPE_Q2_K, /* Q2_K */ + GGML_TYPE_Q3_K, /* Q3_K */ + GGML_TYPE_Q4_K, /* Q4_K */ + GGML_TYPE_Q5_K, /* Q5_K */ + GGML_TYPE_Q6_K, /* Q6_K */ + GGML_TYPE_Q8_K, /* Q8_K */ GGML_TYPE_COUNT /* COUNT */ }; @@ -42,10 +56,13 @@ static const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { TYPE_Q5_0, /* 
Q5_0 */ TYPE_Q5_1, /* Q5_1 */ TYPE_Q8_0, /* Q8_0 */ - TYPE_COUNT, /* Q8_1 */ - TYPE_COUNT, /* I8 */ - TYPE_COUNT, /* I16 */ - TYPE_COUNT, /* I32 */ + TYPE_Q8_1, /* Q8_1 */ + TYPE_Q2_K, /* Q2_K */ + TYPE_Q3_K, /* Q3_K */ + TYPE_Q4_K, /* Q4_K */ + TYPE_Q5_K, /* Q5_K */ + TYPE_Q6_K, /* Q6_K */ + TYPE_Q8_K, /* Q8_K */ TYPE_COUNT, /* COUNT */ }; @@ -60,6 +77,13 @@ static const char * rwkv_type_to_string[TYPE_COUNT + 1] = { "Q5_0", "Q5_1", "Q8_0", + "Q8_1", + "Q2_K", + "Q3_K", + "Q4_K", + "Q5_K", + "Q6_K", + "Q8_K", "unknown" }; @@ -250,7 +274,7 @@ static bool rwkv_fread_ggml_tensor_info(FILE * file, struct ggml_context * ctx, return true; } -static bool rwkv_fread_ggml_tensor_data(FILE * file, struct ggml_context * ctx, std::unordered_map<std::string, struct ggml_tensor *> & parameters) { +static bool rwkv_fread_ggml_tensor_data(FILE * file, std::unordered_map<std::string, struct ggml_tensor *> & parameters) { struct rwkv_tensor_header header; std::string name; RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 0dc417ee..b4b9a2b5 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -10,67 +10,12 @@ struct rwkv_layer_state { struct ggml_tensor * att_heads; }; -struct rwkv_ggml_cgraph_deleter { - void operator()(struct ggml_cgraph * cgraph) { - if (cgraph->nodes) - free(cgraph->nodes); - if (cgraph->leafs) - free(cgraph->leafs); - if (cgraph->visited_hash_table.keys) - free(cgraph->visited_hash_table.keys); - if (cgraph->grads) - free(cgraph->grads); - free(cgraph); - } -}; - -static struct ggml_cgraph * rwkv_ggml_cgraph_create(size_t size, bool grads) { - struct ggml_cgraph * cgraph = (struct ggml_cgraph *)calloc(1, sizeof(struct ggml_cgraph)); - cgraph->size = size; - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; - cgraph->nodes = (struct ggml_tensor **)calloc(1, size * sizeof(struct ggml_tensor *)); - cgraph->leafs = (struct ggml_tensor **)calloc(1, size * sizeof(struct ggml_tensor *)); - - // next primes after powers of two - static const size_t primes[] = { - 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, - 2053, 4099, 8209, 16411, 32771, 65537, 131101, - 262147, 524309, 1048583, 2097169, 4194319, 8388617, - 16777259, 33554467, 67108879, 134217757, 268435459, - 536870923, 1073741827, 2147483659 - }; - static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); - - // find the smallest prime that is larger or equal to size - size_t l = 0; - size_t r = n_primes; - while (l < r) { - size_t m = (l + r)/2; - if (primes[m] < size * 2) { - l = m + 1; - } else { - r = m; - } - } - size_t hash_size = l < n_primes ? primes[l] : (size * 2 + 1); - - cgraph->visited_hash_table.size = hash_size; - cgraph->visited_hash_table.keys = (struct ggml_tensor **)calloc(1, hash_size * sizeof(struct ggml_tensor *)); - cgraph->order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT; - if (grads) { - cgraph->grads = (struct ggml_tensor **)calloc(1, size * sizeof(struct ggml_tensor *)); - } - return cgraph; -} - // The computation graph holds ggml context and the ggml cgraph. // It can be either a serial or a sequential graph. struct rwkv_computation_graph { struct ggml_context * ggml_ctx; - // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap. - std::unique_ptr<struct ggml_cgraph, rwkv_ggml_cgraph_deleter> cgraph; + struct ggml_cgraph * cgraph = nullptr; ggml_backend_sched_t sched; // Input tensors.
@@ -119,30 +64,24 @@ static void rwkv_carry_x( const size_t n_embed = x->ne[0]; const size_t sequence_len = x->ne[1]; - if (sequence_len == 1) { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, weight, bias); + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, weight, bias); - // xx = state[5*i+0] + if (sequence_len == 1) { x_prev = carry; - - // state[5*i+0] = x - carry = x; } else { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, weight, bias); - - // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); - x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); - x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); - - // state[5*i+0] = x[-1,:] - carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + x_prev = ggml_concat( + ctx, + ggml_view_2d(ctx, carry, n_embed, 1, carry->nb[1], 0), + ggml_view_2d(ctx, x, n_embed, sequence_len - 1, x->nb[1], 0), + 1 + ); } + + carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); } -static void rwkv_att_rkv( +static void rwkv_att_rkv_v4( struct ggml_context * ctx, struct rwkv_layer layer, struct ggml_tensor * x, @@ -152,32 +91,32 @@ static void rwkv_att_rkv( struct ggml_tensor *& v ) { // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace(ctx, + struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, x, layer.att_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_k)) ); // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) - struct ggml_tensor * xv = ggml_add_inplace(ctx, + struct ggml_tensor * xv = ggml_add(ctx, ggml_mul(ctx, x, layer.att_time_mix_v), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_v)) ); // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace(ctx, + struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, x, layer.att_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) - r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); // k = kw @ xk k = ggml_mul_mat(ctx, layer.att_key, xk); // v = vw @ xv v = ggml_mul_mat(ctx, layer.att_value, xv); } -static struct ggml_tensor * rwkv_att_wkv( +static struct ggml_tensor * rwkv_att_wkv_v4( struct ggml_context * ctx, struct ggml_tensor * att_time_first, struct ggml_tensor * att_time_decay, @@ -192,46 +131,69 @@ static struct ggml_tensor * rwkv_att_wkv( // qq = torch.maximum(pp, ww) struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); // e1 = torch.exp(pp - qq) - struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); + struct ggml_tensor * e1 = ggml_exp(ctx, ggml_sub(ctx, pp, qq)); // e2 = torch.exp(ww - qq) - struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + struct ggml_tensor * e2 = ggml_exp(ctx, ggml_sub(ctx, ww, qq)); // a = e1 * aa + e2 * v struct ggml_tensor * a = ggml_add(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); // b = e1 * bb + e2 - struct ggml_tensor * b = 
ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + struct ggml_tensor * b = ggml_add(ctx, ggml_mul(ctx, e1, bb), e2); // ww = pp + time_decay ww = ggml_add(ctx, pp, att_time_decay); // qq = torch.maximum(ww, k) qq = rwkv_max(ctx, ww, k); // e1 = torch.exp(ww - qq) - e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + e1 = ggml_exp(ctx, ggml_sub(ctx, ww, qq)); // e2 = torch.exp(k[t] - qq) - e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); + e2 = ggml_exp(ctx, ggml_sub(ctx, k, qq)); // state[5 * i + 2] = e1 * aa + e2 * v // state[5 * i + 3] = e1 * bb + e2 // state[5 * i + 4] = qq - aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); - bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + aa = ggml_add(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + bb = ggml_add(ctx, ggml_mul(ctx, e1, bb), e2); pp = qq; // wkv = a / b return ggml_div(ctx, a, b); } -static struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x_prev; - rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); +static struct ggml_tensor * rwkv_att_v4( + struct ggml_context * ctx, + struct ggml_tensor * x, + struct rwkv_layer layer, + struct rwkv_layer_state & state, + struct rwkv_computation_graph & graph +) { + size_t n_embed = x->ne[0]; + size_t sequence_length = x->ne[1]; + struct ggml_tensor * x0 = x, * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); + rwkv_att_rkv_v4(ctx, layer, x0, x_prev, r, k, v); - struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); + if (sequence_length == 1) { + struct ggml_tensor * wkv = rwkv_att_wkv_v4(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); - // ow @ (r * xx) - return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); + // ow @ (r * xx) + return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); + } else { + ggml_build_forward_expand(graph.cgraph, r); + + for (size_t t = 0; t < sequence_length; t++) { + struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * wkv = rwkv_att_wkv_v4(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); + xt = ggml_set_1d_inplace(ctx, xt, wkv, 0); + ggml_build_forward_expand(graph.cgraph, xt); + } + + return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev)); + } } static struct ggml_tensor * rwkv_att_v5( @@ -246,50 +208,26 @@ static struct ggml_tensor * rwkv_att_v5( size_t n_embed = x->ne[0]; size_t sequence_length = x->ne[1]; - x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); - if (sequence_length > 1) { - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); - x_prev = ggml_set_1d(ctx, x_prev, state.att_xx, 0); - x_prev = ggml_set_1d( - ctx, - x_prev, - ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) - ); - } else { - x_prev = state.att_xx; - } - - struct ggml_tensor * xk = 
ggml_add_inplace( + struct ggml_tensor * xk = ggml_add( ctx, ggml_mul(ctx, x, layer.att_time_mix_k), - ggml_mul( - ctx, - x_prev, - rwkv_1_minus_x(ctx, layer.att_time_mix_k) - ) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_k)) ); - struct ggml_tensor * xv = ggml_add_inplace( + struct ggml_tensor * xv = ggml_add( ctx, ggml_mul(ctx, x, layer.att_time_mix_v), - ggml_mul( - ctx, - x_prev, - rwkv_1_minus_x(ctx, layer.att_time_mix_v) - ) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_v)) + ); - struct ggml_tensor * xr = ggml_add_inplace( + struct ggml_tensor * xr = ggml_add( ctx, ggml_mul(ctx, x, layer.att_time_mix_r), - ggml_mul( - ctx, - x_prev, - rwkv_1_minus_x(ctx, layer.att_time_mix_r) - ) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_r)) ); struct ggml_tensor * xg = NULL; @@ -298,32 +236,23 @@ static struct ggml_tensor * rwkv_att_v5( xg = ggml_add( ctx, ggml_mul(ctx, x, layer.att_time_mix_g), - ggml_mul( - ctx, - x_prev, - rwkv_1_minus_x(ctx, layer.att_time_mix_g) - ) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.att_time_mix_g)) ); } state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); - - struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); - struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); - struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); - + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), 1, head_size, head_count, sequence_length); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), head_size, 1, head_count, sequence_length); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), 1, head_size, head_count, sequence_length); struct ggml_tensor * g = NULL; if (arch_version_minor >= 2) { - g = ggml_silu_inplace( + g = ggml_silu( ctx, ggml_mul_mat(ctx, layer.att_gate, xg) ); } - // dup is not strictly required; doing it just in case. - struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads); - struct ggml_tensor * time_first; struct ggml_tensor * time_decay; @@ -337,37 +266,23 @@ static struct ggml_tensor * rwkv_att_v5( time_decay = ggml_repeat(ctx, layer.att_time_decay, dummy); } - x = rwkv_wkv_v5( - ctx, - sequence_length, - n_embed, - head_count, - head_size, - x, - k, - v, - r, - time_first, - time_decay, - state_out - ); + // To be able to use ggml's wkv6 gpu impls. + { + struct ggml_tensor * dummy = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, head_size, head_count, sequence_length); + time_decay = ggml_repeat(ctx, time_decay, dummy); + } + + struct ggml_tensor * wkv_out = ggml_rwkv_wkv6(ctx, k, v, r, time_first, time_decay, state.att_heads); + x = ggml_view_1d(ctx, wkv_out, n_embed * sequence_length, 0); - state.att_heads = state_out; + state.att_heads = ggml_view_1d(ctx, wkv_out, n_embed * head_size, n_embed * sequence_length * sizeof(float)); - // ggml_group_norm considers groups in the third dimension. - x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length); - x = rwkv_group_norm_eps_1e_minus5(ctx, x, head_count); + // group norm with head_count groups + x = ggml_reshape_3d(ctx, x, n_embed / head_count, head_count, sequence_length); + x = ggml_norm(ctx, x, 1e-5f); // Convert back to a regular vector. 
x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); - x = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - x, - layer.att_ln_x_weight - ), - layer.att_ln_x_bias - ); + x = ggml_add(ctx, ggml_mul(ctx, x, layer.att_ln_x_weight), layer.att_ln_x_bias); if (arch_version_minor >= 2) { x = ggml_mul(ctx, x, g); @@ -382,51 +297,30 @@ static struct ggml_tensor * rwkv_att_v6( struct rwkv_layer layer, struct rwkv_layer_state & state, const int64_t head_count, - const int64_t head_size, - const uint32_t arch_version_minor + const int64_t head_size ) { size_t n_embed = x->ne[0]; size_t sequence_length = x->ne[1]; - x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - struct ggml_tensor * x_prev; - - if (sequence_length > 1) { - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); - x_prev = ggml_set_1d(ctx, x_prev, state.att_xx, 0); - x_prev = ggml_set_1d( - ctx, - x_prev, - ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) - ); - } else { - x_prev = state.att_xx; - } + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); // sx = x - state.att_xx // xxx = x + sx * x_maa - x_prev = ggml_sub_inplace(ctx, x_prev, x); - struct ggml_tensor * xxx = ggml_add_inplace( - ctx, - ggml_mul(ctx, x_prev, layer.att_time_maa_x), - x - ); + x_prev = ggml_sub(ctx, x_prev, x); + struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.att_time_maa_x), x); // xxx = tanh(xxx @ tm_w1).view(5, 1, -1) xxx = ggml_reshape_4d( ctx, - ggml_tanh_inplace( + ggml_tanh( ctx, ggml_mul_mat(ctx, layer.att_time_maa_w1, xxx) ), layer.att_time_maa_w1->ne[1] / 5, 1, 5, sequence_length ); - xxx = ggml_cont( - ctx, - ggml_permute(ctx, xxx, 0, 1, 3, 2) - ); + xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2)); // xxx = torch.bmm(xxx, tm_w2) xxx = ggml_mul_mat( @@ -439,151 +333,147 @@ static struct ggml_tensor * rwkv_att_v6( xxx ); - struct ggml_tensor *mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * sequence_length); - mw = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * sequence_length, 0), 0), - n_embed, sequence_length - ); + struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], 0); + struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * sizeof(float)); + struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 2 * sizeof(float)); + struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 3 * sizeof(float)); + struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 4 * sizeof(float)); - struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * sequence_length); - mk = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * sequence_length, n_embed * sequence_length * sizeof(float)), 0), - n_embed, sequence_length - ); + struct ggml_tensor * xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, mw, layer.att_time_maa_w), x_prev), x); + struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, mk, layer.att_time_maa_k), x_prev), x); + struct ggml_tensor * xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, mv, layer.att_time_maa_v), x_prev), x); + struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, mr, layer.att_time_maa_r), x_prev), x); + struct ggml_tensor * xg = ggml_add(ctx, 
ggml_mul(ctx, ggml_add(ctx, mg, layer.att_time_maa_g), x_prev), x); - struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * sequence_length); - mv = ggml_reshape_2d( + state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), 1, head_size, head_count, sequence_length); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), head_size, 1, head_count, sequence_length); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), 1, head_size, head_count, sequence_length); + struct ggml_tensor * g = ggml_silu( ctx, - ggml_set_1d(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * sequence_length, n_embed * sequence_length * 2 * sizeof(float)), 0), - n_embed, sequence_length + ggml_mul_mat(ctx, layer.att_gate, xg) ); - struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * sequence_length); - mr = ggml_reshape_2d( + struct ggml_tensor * w = ggml_mul_mat( ctx, - ggml_set_1d(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * sequence_length, n_embed * sequence_length * 3 * sizeof(float)), 0), - n_embed, sequence_length + layer.att_time_decay_w2, + ggml_tanh( + ctx, + ggml_mul_mat(ctx, layer.att_time_decay_w1, xw) + ) ); + w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer.att_time_decay, n_embed)); - struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * sequence_length); - mg = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * sequence_length, n_embed * sequence_length * 4 * sizeof(float)), 0), - n_embed, sequence_length - ); + w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w))); + w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length); + struct ggml_tensor * wkv_out = ggml_rwkv_wkv6(ctx, k, v, r, layer.att_time_faaaa, w, state.att_heads); + x = ggml_view_1d(ctx, wkv_out, n_embed * sequence_length, 0); - struct ggml_tensor * xw = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - ggml_add(ctx, mw, layer.att_time_maa_w), - x_prev - ), - x - ); + state.att_heads = ggml_view_1d(ctx, wkv_out, n_embed * head_size, n_embed * sequence_length * sizeof(float)); - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - ggml_add(ctx, mk, layer.att_time_maa_k), - x_prev - ), - x - ); + // group norm with head_count groups + x = ggml_reshape_3d(ctx, x, head_size, head_count, sequence_length); + x = ggml_norm(ctx, x, 64e-5f); + // Convert back to a regular vector. 
+ x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); + x = ggml_add(ctx, ggml_mul(ctx, x, layer.att_ln_x_weight), layer.att_ln_x_bias); - struct ggml_tensor * xv = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - ggml_add(ctx, mv, layer.att_time_maa_v), - x_prev - ), - x - ); + x = ggml_mul(ctx, x, g); - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - ggml_add(ctx, mr, layer.att_time_maa_r), - x_prev - ), - x - ); + return ggml_mul_mat(ctx, layer.att_output, x); +} - struct ggml_tensor * xg = ggml_add_inplace( - ctx, - ggml_mul_inplace( - ctx, - ggml_add(ctx, mg, layer.att_time_maa_g), - x_prev - ), - x - ); +static struct ggml_tensor * rwkv_att_v7( + struct ggml_context * ctx, + struct ggml_tensor * x, + struct ggml_tensor * &v_first, + struct rwkv_layer layer, + struct rwkv_layer_state & state, + const int64_t head_count, + const int64_t head_size +) { + size_t n_embed = x->ne[0]; + size_t sequence_length = x->ne[1]; - state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); - struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); - struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); - struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); - struct ggml_tensor * g = ggml_silu_inplace( - ctx, - ggml_mul_mat(ctx, layer.att_gate, xg) - ); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); - struct ggml_tensor * w = ggml_mul_mat( - ctx, - layer.att_time_decay_w2, - ggml_tanh_inplace( + // sx = x - x_prev + struct ggml_tensor * sx = ggml_sub(ctx, x_prev, x); + struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embed, sequence_length, 6); + sx = ggml_repeat(ctx, sx, dummy); + struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer.att_x_rwkvag), x); + + struct ggml_tensor *xr = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], 0); + struct ggml_tensor *xw = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * sizeof(float)); + struct ggml_tensor *xk = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 2 * sizeof(float)); + struct ggml_tensor *xv = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 3 * sizeof(float)); + struct ggml_tensor *xa = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 4 * sizeof(float)); + struct ggml_tensor *xg = ggml_view_2d(ctx, xxx, n_embed, sequence_length, xxx->nb[1], n_embed * sequence_length * 5 * sizeof(float)); + + struct ggml_tensor * r = ggml_reshape_3d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, head_count, sequence_length); + struct ggml_tensor * g = ggml_mul_mat(ctx, layer.att_g2, ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_g1, xg))); + struct ggml_tensor * a = ggml_sigmoid(ctx, + ggml_add( ctx, - ggml_mul_mat(ctx, layer.att_time_decay_w1, xw) + ggml_mul_mat(ctx, layer.att_a2, ggml_mul_mat(ctx, layer.att_a1, xa)), + layer.att_a0 ) ); - w = ggml_add_inplace( + + struct ggml_tensor * w = ggml_add( ctx, - w, - ggml_reshape_1d(ctx, layer.att_time_decay, n_embed) + ggml_mul_mat(ctx, layer.att_w2, ggml_tanh(ctx, ggml_mul_mat(ctx, layer.att_w1, xw))), + layer.att_w0 ); + w = ggml_exp(ctx, 
ggml_scale(ctx, ggml_sigmoid(ctx, w), -0.606531)); + + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); + struct ggml_tensor * kk = ggml_reshape_3d(ctx, ggml_mul(ctx, k, layer.att_k_k), head_size, head_count, sequence_length); + kk = rwkv_l2norm(ctx, kk); + + struct ggml_tensor * ka = ggml_mul(ctx, k, layer.att_k_a); + k = ggml_add(ctx, k, ggml_sub(ctx, ggml_mul(ctx, a, ka), ka)); + + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); + if (v_first == NULL) { + v_first = v; + } else { + v = ggml_add(ctx, v, ggml_mul(ctx, + ggml_sub(ctx, v_first, v), + ggml_sigmoid(ctx, + ggml_add(ctx, + ggml_mul_mat(ctx, layer.att_v2, ggml_mul_mat(ctx, layer.att_v1, xv)), + layer.att_v0 + ) + ) + ) + ); + } - w = rwkv_exp(ctx, ggml_neg(ctx, rwkv_exp(ctx, w))); - w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length); - - // dup is not strictly required; doing it just in case. - struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads); + w = ggml_reshape_3d(ctx, w, head_size, head_count, sequence_length); + k = ggml_reshape_3d(ctx, k, head_size, head_count, sequence_length); + v = ggml_reshape_3d(ctx, v, head_size, head_count, sequence_length); + a = ggml_reshape_3d(ctx, a, head_size, head_count, sequence_length); - x = rwkv_wkv_v6( - ctx, - sequence_length, - n_embed, - head_count, - head_size, - x, - k, - v, - r, - layer.att_time_faaaa, - w, - state_out - ); + struct ggml_tensor * wkv_out = rwkv_wkv_v7(ctx, state.att_heads, r, w, k, v, ggml_neg(ctx, kk), ggml_mul(ctx, kk, a)); + x = ggml_view_1d(ctx, wkv_out, n_embed * sequence_length, 0); - state.att_heads = state_out; + state.att_heads = ggml_view_1d(ctx, wkv_out, n_embed * head_size, n_embed * sequence_length * sizeof(float)); - // ggml_group_norm considers groups in the third dimension. - x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length); - x = rwkv_group_norm_eps_64e_minus5(ctx, x, head_count); + // group norm with head_count groups + x = ggml_reshape_3d(ctx, x, head_size, head_count, sequence_length); + x = ggml_norm(ctx, x, 64e-5f); // Convert back to a regular vector. 
x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); - x = ggml_add( - ctx, - ggml_mul( - ctx, - x, - layer.att_ln_x_weight - ), - layer.att_ln_x_bias + x = ggml_add(ctx, ggml_mul(ctx, x, layer.att_ln_x_weight), layer.att_ln_x_bias); + + x = ggml_add(ctx, x, + ggml_reshape_2d(ctx, + ggml_mul(ctx, v, ggml_sum_rows(ctx, ggml_mul(ctx, ggml_mul(ctx, k, r), layer.att_r_k))), + n_embed, sequence_length + ) ); x = ggml_mul(ctx, x, g); @@ -591,7 +481,7 @@ static struct ggml_tensor * rwkv_att_v6( return ggml_mul_mat(ctx, layer.att_output, x); } -static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { +static struct ggml_tensor * rwkv_ffn_v4_v5(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); @@ -600,53 +490,56 @@ static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tens struct ggml_tensor * xk = ggml_add( ctx, ggml_mul(ctx, x, layer.ffn_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.ffn_time_mix_k)) ); // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add( ctx, ggml_mul(ctx, x, layer.ffn_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_sub(ctx, x_prev, ggml_mul(ctx, x_prev, layer.ffn_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); // r * (vw @ k) - return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); + return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); - x_prev = ggml_sub_inplace(ctx, x_prev, x); + x_prev = ggml_sub(ctx, x_prev, x); // xk = x + sx * time_maa_k // xr = x + sx * time_maa_r - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul(ctx, x_prev, layer.ffn_time_maa_k), - x - ); - - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul(ctx, x_prev, layer.ffn_time_maa_r), - x - ); + struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.ffn_time_maa_k), x); + struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.ffn_time_maa_r), x); // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + struct ggml_tensor * r = ggml_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); // r * (vw @ k) - return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); + return 
ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); +} + +static struct ggml_tensor * rwkv_ffn_v7(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); + x_prev = ggml_sub(ctx, x_prev, x); + + struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, x_prev, layer.ffn_x_k), x); + + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + return ggml_mul_mat(ctx, layer.ffn_value, k); } static void rwkv_create_input_and_output_views( @@ -716,7 +609,9 @@ static void rwkv_create_input_and_output_views( // Creates and sets the input and output ggml tensors, builds the computation graph. static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) { - graph.cgraph.reset(rwkv_ggml_cgraph_create(RWKV_MAX_NODES, false)); + if (!graph.cgraph) { + graph.cgraph = ggml_new_graph_custom(graph.ggml_ctx, RWKV_MAX_NODES, false); + } struct rwkv_file_header & header = model.header; const size_t n_vocab = header.n_vocab; @@ -753,6 +648,9 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu ggml_set_name(output, "state.out"); ggml_set_input(graph.tokens); + // For v7. + struct ggml_tensor * v_first = NULL; + // x = self.w.emb.weight[token] struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); @@ -764,45 +662,39 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu struct rwkv_layer_state state = inputs[i]; - if (model.arch_version_major == 6) { - x = ggml_add(ctx, x, rwkv_att_v6( - ctx, - x, - layer, - state, - model.head_count, - model.head_size, - model.arch_version_minor - )); - - x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); - } else { - x = model.arch_version_major >= 5 ? 
- ggml_add(ctx, x, rwkv_att_v5( - ctx, - x, - layer, - state, - model.head_count, - model.head_size, - model.arch_version_minor - )) : - ggml_add(ctx, x, rwkv_att(ctx, x, layer, state)); - - x = ggml_add(ctx, x, rwkv_ffn(ctx, x, layer, state)); + switch (model.arch_version_major) { + case 7: + x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size)); + x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state)); + break; + case 6: + x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size)); + x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); + break; + case 5: + x = ggml_add(ctx, x, rwkv_att_v5(ctx, x, layer, state, model.head_count, model.head_size, model.arch_version_minor)); + x = ggml_add(ctx, x, rwkv_ffn_v4_v5(ctx, x, layer, state)); + break; + case 4: + x = ggml_add(ctx, x, rwkv_att_v4(ctx, x, layer, state, graph)); + x = ggml_add(ctx, x, rwkv_ffn_v4_v5(ctx, x, layer, state)); + break; + default: + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, false, "Unsupported model architecture version"); + break; } struct rwkv_layer_state & output_state = outputs[i]; - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_xx, output_state.att_xx)); if (model.arch_version_major >= 5) { - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_heads, output_state.att_heads)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_heads, output_state.att_heads)); } else { - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_pp, output_state.att_pp)); } } @@ -813,7 +705,7 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); graph.post_logits_nodes = graph.cgraph->n_nodes; graph.post_logits_leafs = graph.cgraph->n_leafs; @@ -836,6 +728,7 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str ggml_free(graph.ggml_ctx); graph.ggml_ctx = NULL; + graph.cgraph = NULL; } graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); @@ -849,7 +742,9 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str // Creates and sets the input and output ggml tensors, builds the computation graph. 
static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { - graph.cgraph.reset(rwkv_ggml_cgraph_create(RWKV_MAX_NODES, false)); + if (!graph.cgraph) { + graph.cgraph = ggml_new_graph_custom(graph.ggml_ctx, RWKV_MAX_NODES, false); + } struct rwkv_file_header & header = model.header; const size_t n_vocab = header.n_vocab; @@ -885,6 +780,9 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c ggml_set_name(output, "state.out"); ggml_set_input(graph.tokens); + // For v7. + struct ggml_tensor * v_first = NULL; + // x = self.w.emb.weight[token] struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); @@ -896,66 +794,53 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c struct rwkv_layer_state state = inputs[i]; - if (model.arch_version_major == 6) { - x = ggml_add(ctx, x, rwkv_att_v6( - ctx, - x, - layer, - state, - model.head_count, - model.head_size, - model.arch_version_minor - )); - } else if (model.arch_version_major >= 5) { - x = ggml_add(ctx, x, rwkv_att_v5( - ctx, - x, - layer, - state, - model.head_count, - model.head_size, - model.arch_version_minor - )); - } else { - struct ggml_tensor * x0 = x, * x_prev; - rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); - - struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); - - ggml_build_forward_expand(graph.cgraph.get(), r); - - for (size_t t = 0; t < sequence_length; t++) { - struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); - xt = ggml_set_1d_inplace(ctx, xt, wkv, 0); - ggml_build_forward_expand(graph.cgraph.get(), xt); - } - - x = ggml_add(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); + switch (model.arch_version_major) { + case 7: + x = ggml_add(ctx, x, rwkv_att_v7(ctx, x, v_first, layer, state, model.head_count, model.head_size)); + break; + case 6: + x = ggml_add(ctx, x, rwkv_att_v6(ctx, x, layer, state, model.head_count, model.head_size)); + break; + case 5: + x = ggml_add(ctx, x, rwkv_att_v5(ctx, x, layer, state, model.head_count, model.head_size, model.arch_version_minor)); + break; + case 4: + x = ggml_add(ctx, x, rwkv_att_v4(ctx, x, layer, state, graph)); + break; + default: + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, false, "Unsupported model architecture version"); + break; } // TODO Can we skip ffn for all but the last token, the same way we skip unembedding? 
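// The per-version FFN dispatch below mirrors the serial graph. A rough scalar
// sketch of rwkv_ffn_v7, for comparison with v4-v6 (names are illustrative;
// see the definition above): v7 drops the receptance gate and keeps a single
// token-shift mix:
//
//   xk  = x + (x_prev - x) * ffn_x_k;
//   k   = square(relu(ffn_key @ xk));
//   out = ffn_value @ k;   // no sigmoid(r) * (...) gating as in v4-v6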
- if (model.arch_version_major == 6) { - x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); - } else { - x = ggml_add(ctx, x, rwkv_ffn(ctx, x, layer, state)); + switch (model.arch_version_major) { + case 7: + x = ggml_add(ctx, x, rwkv_ffn_v7(ctx, x, layer, state)); + break; + case 6: + x = ggml_add(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); + break; + case 5: + case 4: + x = ggml_add(ctx, x, rwkv_ffn_v4_v5(ctx, x, layer, state)); + break; + default: + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, false, "Unsupported model architecture version"); + break; } struct rwkv_layer_state & output_state = outputs[i]; output_state.att_xx = ggml_set_1d_inplace(ctx, output_state.att_xx, state.att_xx, 0); - ggml_build_forward_expand(graph.cgraph.get(), output_state.att_xx); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph, output_state.att_xx); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); if (model.arch_version_major >= 5) { - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_heads, output_state.att_heads)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_heads, output_state.att_heads)); } else { - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, state.att_pp, output_state.att_pp)); } } @@ -966,7 +851,7 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_length - 1)), model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); + ggml_build_forward_expand(graph.cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); graph.post_logits_nodes = graph.cgraph->n_nodes; graph.post_logits_leafs = graph.cgraph->n_leafs; @@ -986,6 +871,7 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, ggml_free(graph.ggml_ctx); graph.ggml_ctx = NULL; + graph.cgraph = NULL; } graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index ae14c45c..8f3bd7de 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -35,6 +35,24 @@ struct rwkv_layer { struct ggml_tensor * att_time_decay_w1; struct ggml_tensor * att_time_decay_w2; + // Added in RWKV v7. 
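    // Most of the new v7 attention parameters come in low-rank triples: a bias
    // vector *0 plus two small matrices *1 and *2 whose product forms the
    // learned projection. A hedged sketch of how rwkv_att_v7 combines them
    // (xw/xa/xg denote the token-shifted inputs; names are illustrative):
    //
    //   w = exp(-0.606531f * sigmoid(w0 + tanh(xw @ w1) @ w2));  // per-channel decay, 0.606531 ~= exp(-0.5)
    //   a = sigmoid(a0 + (xa @ a1) @ a2);                        // in-context learning rate
    //   g = sigmoid(xg @ g1) @ g2;                               // output gate
    //
    // att_x_rwkvag stores the six token-shift mix coefficients (r, w, k, v, a, g)
    // concatenated into one tensor; k_k, k_a and r_k are per-channel vectors
    // applied to the key and receptance inside the attention block.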
+ struct ggml_tensor * att_w0; + struct ggml_tensor * att_w1; + struct ggml_tensor * att_w2; + struct ggml_tensor * att_a0; + struct ggml_tensor * att_a1; + struct ggml_tensor * att_a2; + struct ggml_tensor * att_g1; + struct ggml_tensor * att_g2; + struct ggml_tensor * att_v0; + struct ggml_tensor * att_v1; + struct ggml_tensor * att_v2; + struct ggml_tensor * att_r_k; + struct ggml_tensor * att_k_k; + struct ggml_tensor * att_k_a; + // Concatenated att_x_[r, w, k, v, a, g] + struct ggml_tensor * att_x_rwkvag; + struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; @@ -46,6 +64,9 @@ struct rwkv_layer { struct ggml_tensor * ffn_time_maa_k; struct ggml_tensor * ffn_time_maa_r; + // Added in RWKV v7. + struct ggml_tensor * ffn_x_k; + struct ggml_tensor * ffn_key; struct ggml_tensor * ffn_value; struct ggml_tensor * ffn_receptance; @@ -107,9 +128,8 @@ template static bool rwkv_set_params(struct rwkv_model & model, F callback, const uint32_t n_gpu_layers) { const size_t n_gpu = std::min(n_gpu_layers, model.header.n_layer + 1); bool offload_head = n_gpu == (model.header.n_layer + 1); - bool offload_default = false; - RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb, offload_default)); + RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb, false)); RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight, (n_gpu_layers > 0))); RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias, (n_gpu_layers > 0))); @@ -121,83 +141,140 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback, const uint32_ for (uint32_t i = 0; i < n_layer; i++) { bool offload_layer = (i < n_gpu); char buffer[128]; - size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); + size_t offset = snprintf(buffer, sizeof(buffer), "blocks.%" PRId32 ".", i); rwkv_layer & layer = model.layers[i]; RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight, offload_layer)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias, offload_layer)); - if (model.arch_version_major == 6) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_time_maa_x, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_time_maa_w, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_time_maa_k, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_time_maa_v, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_time_maa_r, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_time_maa_g, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.att_time_maa_w1, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.att_time_maa_w2, offload_layer)); - - // No gpu offloading for wkv yet - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.att_time_decay_w1, offload_default)); - 
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2, offload_default)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); - - // GroupNorm uses a custom epsilon value, which only has CPU implementation for now. - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, offload_default)); - } else { - // Custom rwkv_1_minus_x: cpu only - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r, offload_default)); - - if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa, offload_default)); - } else { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first, offload_default)); - } - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay, offload_default)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); - - if (model.arch_version_major >= 5) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, offload_default)); - - if (model.arch_version_minor >= 2) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate, offload_layer)); + // ATT. 
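            // For v7, the value-residual tensors v0/v1/v2 are loaded only for
            // layers after the first: layer 0 just records its value projection
            // as v_first, and later layers blend toward it (roughly, as done in
            // rwkv_att_v7):
            //
            //   if (v_first == NULL) v_first = v;   // layer 0
            //   else v = v + (v_first - v) * sigmoid(v0 + v2 @ (v1 @ xv));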
+ switch (model.arch_version_major) { + case 7: + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.x_rwkvag"), buffer), layer.att_x_rwkvag, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.w0"), buffer), layer.att_w0, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.w1"), buffer), layer.att_w1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.w2"), buffer), layer.att_w2, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.a0"), buffer), layer.att_a0, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.a1"), buffer), layer.att_a1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.a2"), buffer), layer.att_a2, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.g1"), buffer), layer.att_g1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.g2"), buffer), layer.att_g2, offload_layer)); + + if (i != 0) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.v0"), buffer), layer.att_v0, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.v1"), buffer), layer.att_v1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.v2"), buffer), layer.att_v2, offload_layer)); } - } - } + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.r_k"), buffer), layer.att_r_k, offload_layer)); + // Somehow offloading this layer makes the model output NaN after several iterations. + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.k_k"), buffer), layer.att_k_k, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.k_a"), buffer), layer.att_k_a, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); + + // These too. 
+ RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, offload_layer)); + break; + case 6: + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_time_maa_x, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_time_maa_w, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_time_maa_k, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_time_maa_v, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_time_maa_r, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_time_maa_g, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.att_time_maa_w1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.att_time_maa_w2, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.att_time_decay_w1, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, offload_layer)); + break; + case 5: + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r, offload_layer)); + + if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa, false)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first, false)); + } + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay, false)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), 
layer.att_key, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); + + if (model.arch_version_major >= 5) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias, offload_layer)); - if (model.arch_version_major == 6) { + if (model.arch_version_minor >= 2) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate, offload_layer)); + } + } + break; + case 4: + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r, false)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay, false)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance, false)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output, offload_layer)); + break; + default: + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, false, "Unsupported model architecture version"); + break; + } + + // FFN. 
+ if (model.arch_version_major == 7) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.x_k"), buffer), layer.ffn_x_k, offload_layer)); + } else if (model.arch_version_major == 6) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight, offload_layer)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias, offload_layer)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k, offload_layer)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r, offload_layer)); } else { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k, offload_default)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r, offload_default)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k, offload_layer)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r, offload_layer)); } RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key, offload_layer)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value, offload_layer)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance, offload_layer)); + if (model.arch_version_major != 7) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance, offload_layer)); + } } RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight, offload_head)); @@ -257,6 +334,11 @@ static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model model.arch_version_minor = 0; } + if (parameters.find("blocks.0.att.r_k") != parameters.end()) { + model.arch_version_major = 7; + model.arch_version_minor = 0; + } + size_t cpu_buffer_size = 0; size_t gpu_buffer_size = 0; std::unordered_map & parameters_ref = parameters; @@ -314,11 +396,14 @@ static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model fseek(file.file, tensors_file_start, SEEK_SET); while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, - rwkv_fread_ggml_tensor_data(file.file, model.ggml_ctx, parameters_ref), + rwkv_fread_ggml_tensor_data(file.file, parameters_ref), "Failed to read a model parameter"); } - if (model.arch_version_major >= 5) { + if (model.arch_version_major == 7) { + model.head_count = model.layers[0].att_r_k->ne[1]; + model.head_size = model.layers[0].ln1_weight->ne[0] / model.head_count; + } else if (model.arch_version_major >= 5) { 
model.head_count = model.layers[0].att_time_decay->ne[2]; model.head_size = model.layers[0].ln1_weight->ne[0] / model.head_count; } diff --git a/rwkv_operators.inc b/rwkv_operators.inc index 182a1281..ad6a9231 100644 --- a/rwkv_operators.inc +++ b/rwkv_operators.inc @@ -1,48 +1,7 @@ -static void rwkv_validate_tensors_for_custom_unary_op(struct ggml_tensor * dest, const struct ggml_tensor * src) { - GGML_ASSERT(dest->type == GGML_TYPE_F32); - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(dest)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_are_same_shape(src, dest)); - // Verify that the shape is 2D. - GGML_ASSERT(dest->ne[2] == 1); - GGML_ASSERT(dest->ne[3] == 1); -} +#include "rwkv_operators_wkv_v7.inc" #define SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP() { (void) ith; (void) nth; (void) userdata; } -static void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - rwkv_validate_tensors_for_custom_unary_op(dest, src); - - int64_t element_count = src->ne[0] * src->ne[1]; - int64_t start = ith * element_count / nth; - int64_t end = (ith + 1) * element_count / nth; - float * src_data = (float *) src->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = start; i < end; i++) { - dest_data[i] = expf(src_data[i]); - } - - SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); -} - -static void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - rwkv_validate_tensors_for_custom_unary_op(dest, src); - - int64_t element_count = src->ne[0] * src->ne[1]; - int64_t start = ith * element_count / nth; - int64_t end = (ith + 1) * element_count / nth; - float * src_data = (float *) src->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = start; i < end; i++) { - dest_data[i] = 1.0F - src_data[i]; - } - - SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); -} - static void rwkv_max_impl( struct ggml_tensor * dest, const struct ggml_tensor * src0, @@ -77,8 +36,8 @@ static void rwkv_max_impl( SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); } -// From ggml.c -static void rwkv_groupnorm_impl( +// TODO: Upstream to ggml +static void rwkv_l2norm_impl( struct ggml_tensor * dst, const struct ggml_tensor * src0, int ith, @@ -91,63 +50,29 @@ static void rwkv_groupnorm_impl( GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS - const float eps = ((float*)userdata)[0]; - const int n_groups = ((int32_t*)userdata)[1]; + float eps = 1e-12f; - int n_channels = src0->ne[2]; - int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; - for (int i = ith; i < n_groups; i += nth) { - int start = i * n_channels_per_group; - int end = start + n_channels_per_group; - if (end > n_channels) { - end = n_channels; - } - int step = end - start; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - float sum = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sumr += (float)x[i00]; - } - sum += sumr; - } - } - const float mean = sum / (ne00 * ne01 * step); - - float sum2 = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - 
- float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sumr += (float)(v * v); - } - sum2 += sumr; + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00]; + sum += v*v; } - } - const float variance = sum2 / (ne00 * ne01 * step); - const float scale = 1.0f / sqrtf(variance + eps); - - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - for (int i00 = 0; i00 < ne00; i00++) { - y[i00] *= scale; - } + + float * y = (float *) ((char *) dst->data + i01*nb01 + i02*nb02 + i03*nb03); + + const float scale = 1.0f/fmaxf(sqrtf(sum), eps); + + // ggml_vec_scale_f32(ne00, y, scale); + for (int64_t i00 = 0; i00 < ne00; i00++) { + y[i00] = x[i00] * scale; } } } @@ -156,34 +81,13 @@ static void rwkv_groupnorm_impl( SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); } -// Element-wise exp(x) -struct ggml_tensor * rwkv_exp(struct ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL); -} - -// Element-wise 1 - x -struct ggml_tensor * rwkv_1_minus_x(struct ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_1_minus_x_impl, 1, NULL); -} - // Element-wise max(x, y) struct ggml_tensor * rwkv_max(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) { return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL); } -// GroupNorm with custom eps value; Remove when ggml_norm supports eps as an argument. -struct ggml_tensor * rwkv_group_norm_eps_1e_minus5(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) { - static float params[2]; - params[0] = 1e-5F; - ((int*)params)[1] = n_groups; - return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params); -} - -struct ggml_tensor * rwkv_group_norm_eps_64e_minus5(struct ggml_context * ctx, struct ggml_tensor * x, int n_groups) { - static float params[2]; - params[0] = 64e-5F; - ((int*)params)[1] = n_groups; - return ggml_map_custom1_inplace(ctx, x, rwkv_groupnorm_impl, 1, params); +struct ggml_tensor * rwkv_l2norm(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_l2norm_impl, 1, NULL); } struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { diff --git a/rwkv_operators_wkv_common.inc b/rwkv_operators_wkv_common.inc deleted file mode 100644 index 94e36aaf..00000000 --- a/rwkv_operators_wkv_common.inc +++ /dev/null @@ -1,35 +0,0 @@ -// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8 -// Original code by Harrison Vanderbyl. -// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. 
tiny-rwkv with AVX-512 -/*#ifdef __AVX512F__ - #include - #define SIMD_WIDTH 16 - #define LOAD(x) _mm512_load_ps(x) - #define STORE(x, y) _mm512_store_ps(x, y) - #define SET1(x) _mm512_set1_ps(x) - #define MULTIPLY(x, y) _mm512_mul_ps(x, y) - #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z) -#elif __AVX2__ - #include - #define SIMD_WIDTH 8 - #define LOAD(x) _mm256_load_ps(x) - #define STORE(x, y) _mm256_store_ps(x, y) - #define SET1(x) _mm256_set1_ps(x) - #define MULTIPLY(x, y) _mm256_mul_ps(x, y) - #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z) -#elif defined(__ARM_NEON) || defined(__ARM_NEON__) - #include - #define SIMD_WIDTH 4 - #define LOAD(x) vld1q_f32(x) - #define STORE(x, y) vst1q_f32(x, y) - #define SET1(x) vdupq_n_f32(x) - #define MULTIPLY(x, y) vmulq_f32(x, y) - #define MULTADD(x, y, z) vmlaq_f32(z, x, y) -#else*/ - #define SIMD_WIDTH 1 - #define LOAD(x) *x - #define STORE(x, y) *x = y - #define SET1(x) x - #define MULTIPLY(x, y) x * y - #define MULTADD(x, y, z) x * y + z -//#endif diff --git a/rwkv_operators_wkv_v5.inc b/rwkv_operators_wkv_v5.inc deleted file mode 100644 index 4c38531c..00000000 --- a/rwkv_operators_wkv_v5.inc +++ /dev/null @@ -1,148 +0,0 @@ -#include "rwkv_operators_wkv_common.inc" - -// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57 -// Original code by Harrison Vanderbyl. -static void rwkv_wkv_v5_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - const size_t T = result->ne[1]; - const size_t C = result->ne[0]; - const size_t H = result->src[1]->ne[2]; - - // TODO: Multi-threading. - if (ith != 0) - return; - - float * result_data = (float *) result->data; - - memset(result_data, 0, T * C * sizeof(float)); - - float * k = (float *) result->src[1]->data; - float * v = (float *) result->src[2]->data; - float * r = (float *) result->src[3]->data; - float * time_f = (float *) result->src[4]->data; - float * time_decay = (float *) result->src[5]->data; - float * state = (float *) result->src[6]->data; - - size_t t_stride = H * (C / H); - - size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); - - for (size_t t = 0; t < T; t++) { - size_t t_offset = t * t_stride; - - for (size_t h = 0; h < H; h++) { - size_t h_offset = h * h_stride; - size_t t_h_offset = t_offset + h_offset; - size_t h_2d_offset = h * h_stride_2d; - - for (size_t i = 0; i < C / H; i++) { - size_t t_h_i_offset = t_h_offset + i; - size_t h_i_offset = h_offset + i; - size_t h_2d_i_offset = h_2d_offset + i * h_stride; - - auto k_val = SET1(k[t_h_i_offset]); - auto r_val = SET1(r[t_h_i_offset]); - auto time_f_val = SET1(time_f[h_i_offset]); - auto time_decay_val = SET1(time_decay[h_i_offset]); - - for (size_t j = 0; j < C / H; j += SIMD_WIDTH) { - size_t t_h_j_offset = t_h_offset + j; - size_t h_2d_i_j_offset = h_2d_i_offset + j; - - auto v_val = LOAD(&v[t_h_j_offset]); - - auto kv_val = MULTIPLY(v_val, k_val); - - auto prev_state_val = LOAD(&state[h_2d_i_j_offset]); - - auto temp_val = MULTADD(kv_val, time_f_val, prev_state_val); - - auto prev_result_data = LOAD(&result_data[t_h_j_offset]); - - STORE(&result_data[t_h_j_offset], MULTADD(temp_val, r_val, prev_result_data)); - - STORE(&state[h_2d_i_j_offset], MULTADD(prev_state_val, time_decay_val, kv_val)); - } - } - } - } - - // Suppress "unused parameter" warnings. 
- (void) src; - (void) nth; - (void) userdata; -} - -// Parameters: -// - T: sequence length -// - C: channel count, same as n_embed -// - H: head count -// - S: head size -// Shapes (in ggml order): -// - x: [C, T, 1, 1] -// - k: [1, S, H, T] -// - v: [S, 1, H, T] -// - r: [S, 1, H, T] -// - time_f: [1, S, H, 1] -// - time_decay: [1, S, H, 1] -// - state: [S * S * H, 1, 1, 1] -// - result: same as x -// time_f and time_decay must be preprocessed as neccessary -- exp() applied, etc. -// state will be written to. -static struct ggml_tensor * rwkv_wkv_v5( - struct ggml_context * ctx, - const size_t T, - const size_t C, - const size_t H, - const size_t S, - struct ggml_tensor * x, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * r, - // time_first for v5.1, time_faaaa for v5.2. - struct ggml_tensor * time_f, - struct ggml_tensor * time_decay, - struct ggml_tensor * state -) { - GGML_ASSERT(x->type == GGML_TYPE_F32); - GGML_ASSERT(k->type == GGML_TYPE_F32); - GGML_ASSERT(v->type == GGML_TYPE_F32); - GGML_ASSERT(r->type == GGML_TYPE_F32); - GGML_ASSERT(time_f->type == GGML_TYPE_F32); - GGML_ASSERT(time_decay->type == GGML_TYPE_F32); - GGML_ASSERT(state->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(x)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(r)); - GGML_ASSERT(ggml_is_contiguous(time_f)); - GGML_ASSERT(ggml_is_contiguous(time_decay)); - GGML_ASSERT(ggml_is_contiguous(state)); - - GGML_ASSERT(x->ne[0] == C && x->ne[1] == T && x->ne[2] == 1 && x->ne[3] == 1); - GGML_ASSERT(k->ne[0] == 1 && k->ne[1] == S && k->ne[2] == H && k->ne[3] == T); - GGML_ASSERT(v->ne[0] == S && v->ne[1] == 1 && v->ne[2] == H && v->ne[3] == T); - GGML_ASSERT(r->ne[0] == S && r->ne[1] == 1 && r->ne[2] == H && r->ne[3] == T); - GGML_ASSERT(ggml_nelements(state) == S * S * H); - - k = ggml_transpose(ctx, k); - v = ggml_transpose(ctx, v); - r = ggml_transpose(ctx, r); - - struct ggml_tensor * result = ggml_map_custom1( - ctx, - x, - rwkv_wkv_v5_impl, - 1, - NULL - ); - result->src[1] = k; - result->src[2] = v; - result->src[3] = r; - result->src[4] = time_f; - result->src[5] = time_decay; - result->src[6] = state; - - return result; -} diff --git a/rwkv_operators_wkv_v6.inc b/rwkv_operators_wkv_v6.inc deleted file mode 100644 index a89bdaf6..00000000 --- a/rwkv_operators_wkv_v6.inc +++ /dev/null @@ -1,149 +0,0 @@ -#include "rwkv_operators_wkv_common.inc" - -// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57 -// Original code by Harrison Vanderbyl. -static void rwkv_wkv_v6_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - const size_t T = result->ne[1]; - const size_t C = result->ne[0]; - const size_t H = result->src[1]->ne[2]; - - // TODO: Multi-threading. 
- if (ith != 0) - return; - - float * result_data = (float *) result->data; - - memset(result_data, 0, T * C * sizeof(float)); - - float * k = (float *) result->src[1]->data; - float * v = (float *) result->src[2]->data; - float * r = (float *) result->src[3]->data; - float * time_faaaa = (float *) result->src[4]->data; - float * time_decay = (float *) result->src[5]->data; - float * state = (float *) result->src[6]->data; - - size_t t_stride = H * (C / H); - - size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); - - for (size_t t = 0; t < T; t++) { - size_t t_offset = t * t_stride; - - for (size_t h = 0; h < H; h++) { - size_t h_offset = h * h_stride; - size_t t_h_offset = t_offset + h_offset; - size_t h_2d_offset = h * h_stride_2d; - - for (size_t i = 0; i < C / H; i++) { - size_t t_h_i_offset = t_h_offset + i; - size_t h_i_offset = h_offset + i; - size_t h_2d_i_offset = h_2d_offset + i * h_stride; - - auto k_val = SET1(k[t_h_i_offset]); - auto r_val = SET1(r[t_h_i_offset]); - auto time_faaaa_val = SET1(time_faaaa[h_i_offset]); - // RWKV v6: different time_decay for each token. - auto time_decay_val = SET1(time_decay[t_h_i_offset]); - - for (size_t j = 0; j < C / H; j += SIMD_WIDTH) { - size_t t_h_j_offset = t_h_offset + j; - size_t h_2d_i_j_offset = h_2d_i_offset + j; - - auto v_val = LOAD(&v[t_h_j_offset]); - - auto kv_val = MULTIPLY(v_val, k_val); - - auto prev_state_val = LOAD(&state[h_2d_i_j_offset]); - - auto temp_val = MULTADD(kv_val, time_faaaa_val, prev_state_val); - - auto prev_result_data = LOAD(&result_data[t_h_j_offset]); - - STORE(&result_data[t_h_j_offset], MULTADD(temp_val, r_val, prev_result_data)); - - STORE(&state[h_2d_i_j_offset], MULTADD(prev_state_val, time_decay_val, kv_val)); - } - } - } - } - - // Suppress "unused parameter" warnings. - (void) src; - (void) ith; - (void) nth; - (void) userdata; -} - -// Parameters: -// - T: sequence length -// - C: channel count, same as n_embed -// - H: head count -// - S: head size -// Shapes (in ggml order): -// - x: [C, T, 1, 1] -// - k: [1, S, H, T] -// - v: [S, 1, H, T] -// - r: [S, 1, H, T] -// - time_faaaa: [1, S, H, 1] -// - w: [1, S, H, T] -// - state: [S * S * H, 1, 1, 1] -// - result: same as x -// state will be written to. 
-static struct ggml_tensor * rwkv_wkv_v6( - struct ggml_context * ctx, - const size_t T, - const size_t C, - const size_t H, - const size_t S, - struct ggml_tensor * x, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * r, - struct ggml_tensor * time_faaaa, - struct ggml_tensor * w, - struct ggml_tensor * state -) { - GGML_ASSERT(x->type == GGML_TYPE_F32); - GGML_ASSERT(k->type == GGML_TYPE_F32); - GGML_ASSERT(v->type == GGML_TYPE_F32); - GGML_ASSERT(r->type == GGML_TYPE_F32); - GGML_ASSERT(time_faaaa->type == GGML_TYPE_F32); - GGML_ASSERT(w->type == GGML_TYPE_F32); - GGML_ASSERT(state->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(x)); - GGML_ASSERT(ggml_is_contiguous(k)); - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(r)); - GGML_ASSERT(ggml_is_contiguous(time_faaaa)); - GGML_ASSERT(ggml_is_contiguous(w)); - GGML_ASSERT(ggml_is_contiguous(state)); - - GGML_ASSERT(x->ne[0] == C && x->ne[1] == T && x->ne[2] == 1 && x->ne[3] == 1); - GGML_ASSERT(k->ne[0] == 1 && k->ne[1] == S && k->ne[2] == H && k->ne[3] == T); - GGML_ASSERT(v->ne[0] == S && v->ne[1] == 1 && v->ne[2] == H && v->ne[3] == T); - GGML_ASSERT(r->ne[0] == S && r->ne[1] == 1 && r->ne[2] == H && r->ne[3] == T); - GGML_ASSERT(w->ne[0] == 1 && w->ne[1] == S && w->ne[2] == H && w->ne[3] == T); - GGML_ASSERT(ggml_nelements(state) == S * S * H); - - k = ggml_transpose(ctx, k); - v = ggml_transpose(ctx, v); - r = ggml_transpose(ctx, r); - - struct ggml_tensor * result = ggml_map_custom1( - ctx, - x, - rwkv_wkv_v6_impl, - 1, - NULL - ); - result->src[1] = k; - result->src[2] = v; - result->src[3] = r; - result->src[4] = time_faaaa; - result->src[5] = w; - result->src[6] = state; - - return result; -} diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc new file mode 100644 index 00000000..eb4541d3 --- /dev/null +++ b/rwkv_operators_wkv_v7.inc @@ -0,0 +1,179 @@ +// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8 +// Original code by Harrison Vanderbyl. +// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. 
tiny-rwkv with AVX-512 +/*#ifdef __AVX512F__ + #include + #define SIMD_WIDTH 16 + #define LOAD(x) _mm512_load_ps(x) + #define STORE(x, y) _mm512_store_ps(x, y) + #define SET1(x) _mm512_set1_ps(x) + #define MULTIPLY(x, y) _mm512_mul_ps(x, y) + #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z) +#elif __AVX2__ + #include + #define SIMD_WIDTH 8 + #define LOAD(x) _mm256_load_ps(x) + #define STORE(x, y) _mm256_store_ps(x, y) + #define SET1(x) _mm256_set1_ps(x) + #define MULTIPLY(x, y) _mm256_mul_ps(x, y) + #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z) +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + #include + #define SIMD_WIDTH 4 + #define LOAD(x) vld1q_f32(x) + #define STORE(x, y) vst1q_f32(x, y) + #define SET1(x) vdupq_n_f32(x) + #define MULTIPLY(x, y) vmulq_f32(x, y) + #define MULTADD(x, y, z) vmlaq_f32(z, x, y) +#else*/ + #define SIMD_WIDTH 1 + #define LOAD(x) *x + #define STORE(x, y) *x = y + #define SET1(x) x + #define MULTIPLY(x, y) x * y + #define MULTADD(x, y, z) x * y + z +//#endif + +static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + // const size_t T = result->ne[1]; + const size_t C = result->ne[0]; + const size_t S = result->src[1]->ne[0]; + const size_t H = result->src[1]->ne[1]; + const size_t T = result->src[1]->ne[2]; + GGML_ASSERT(C == S * H); + + float * result_data = (float *) result->data; + float * state_out = (float *) result->data + C * T; + + float * state = (float *) src->data; + float * r = (float *) result->src[1]->data; + float * w = (float *) result->src[2]->data; + float * k = (float *) result->src[3]->data; + float * v = (float *) result->src[4]->data; + float * a = (float *) result->src[5]->data; + float * b = (float *) result->src[6]->data; + + size_t t_stride = H * S; + + size_t h_stride = C / H; + size_t h_stride_2d = S * S; + + for (size_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + + float * state_in = (t == 0) ? state : state_out; + + for (size_t h = ith; h < H; h += nth) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (size_t i = 0; i < C / H; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + auto v_val = v[t_h_i_offset]; + + float sa = 0; + for (size_t j = 0; j < C / H; j++) { + sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j]; + } + + if (i == 0) { + memset(&result_data[t_h_offset], 0, h_stride * sizeof(float)); + } + + for (size_t j = 0; j < C / H; j += SIMD_WIDTH) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + auto r_val = r[t_h_j_offset]; + auto w_val = w[t_h_j_offset]; + auto k_val = k[t_h_j_offset]; + auto b_val = b[t_h_j_offset]; + auto kv_val = v_val * k_val; + auto prev_state_val = state_in[h_2d_i_j_offset]; + state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val; + } + } + } + } + + // Suppress "unused parameter" warnings. 
+ (void) src; + (void) nth; + (void) userdata; +} + +// Parameters: +// - T: sequence length +// - C: channel count, same as n_embed +// - H: head count +// - S: head size +// Shapes (in ggml order): +// - r: [S, H, T] +// - w: [S, H, T] +// - k: [S, H, T] +// - v: [S, H, T] +// - a: [S, H, T] +// - b: [S, H, T] +// - state: [S * S * H, 1, 1, 1] +// - result: concated output + state_output +static struct ggml_tensor * rwkv_wkv_v7( + struct ggml_context * ctx, + struct ggml_tensor * state, + struct ggml_tensor * r, + struct ggml_tensor * w, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * a, + struct ggml_tensor * b +) { + GGML_ASSERT(r->type == GGML_TYPE_F32); + GGML_ASSERT(w->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(b->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(w)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = r->ne[0]; + const int64_t H = r->ne[1]; + const int64_t T = r->ne[2]; + const int64_t C = S * H; + + GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == T); + GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == T); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == T); + GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == T); + GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == T); + GGML_ASSERT(ggml_nelements(state) == S * S * H); + + struct ggml_tensor * result = ggml_map_custom1( + ctx, + state, + rwkv_wkv_v7_impl, + 1, + NULL + ); + result->src[1] = r; + result->src[2] = w; + result->src[3] = k; + result->src[4] = v; + result->src[5] = a; + result->src[6] = b; + + result->ne[0] = C; + result->ne[1] = T + S; + + return result; +} diff --git a/rwkv_quantize.inc b/rwkv_quantize.inc index c8fc2227..ac6edc87 100644 --- a/rwkv_quantize.inc +++ b/rwkv_quantize.inc @@ -1,3 +1,17 @@ +static bool rwkv_tensor_needs_quant(std::string name) { + return name != "emb.weight" && + name != "head.weight" && + name.find("att.v1") == std::string::npos && + name.find("att.v2") == std::string::npos && + name.find("att.g1") == std::string::npos && + name.find("att.g2") == std::string::npos && + name.find("att.a1") == std::string::npos && + name.find("att.a2") == std::string::npos && + name.find("att.w1") == std::string::npos && + name.find("att.w2") == std::string::npos && + name.find("att.r_k") == std::string::npos; +} + // API function. bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { global_last_error = RWKV_ERROR_NONE; @@ -86,9 +100,6 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); - // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! 
- int64_t hist_all[16] {}; - std::unique_ptr scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); @@ -125,10 +136,9 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const // In RWKV v5, time_decay and time_first/time_faaaa are 3D tensors, so they are not quantized. if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) && header.dim_count == 2 && - name != "emb.weight" && - name != "head.weight" + rwkv_tensor_needs_quant(name) ) { - RWKV_MSG("quantizing... "); + RWKV_MSG("-> %6s ", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]); size_t nelements = (size_t) header.size0 * (size_t) header.size1 * (size_t) header.size2; @@ -140,7 +150,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const header.data_type = rwkv_type_from_ggml[out_type]; data = out_buf; - RWKV_MSG("size = %8.2f MB -> %8.2f MB | hist: ", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + RWKV_MSG("size = %8.2f MB -> %8.2f MB", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); RWKV_MSG("\n"); } else { diff --git a/rwkv_utilities.inc b/rwkv_utilities.inc index 1108091f..58898f0d 100644 --- a/rwkv_utilities.inc +++ b/rwkv_utilities.inc @@ -32,8 +32,16 @@ static bool rwkv_fread_uint32(FILE * file, uint32_t & dest) { // Reads a single string value from a file. static bool rwkv_fread_string(FILE * file, const size_t length, std::string & dest) { - dest.resize(length); - return fread((void *) dest.data(), length, 1, file) == 1; + char * buffer = new(std::nothrow) char[length]; + if (!buffer) { + return false; + } + int ret = fread(buffer, length, 1, file); + if (ret == 1) { + dest.assign(buffer, length); + } + delete[] buffer; + return ret == 1; } // Reads a single data buffer from a file. 
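For readability, here is a plain scalar sketch of the per-head, per-timestep state update performed by the new rwkv_wkv_v7_impl kernel in rwkv_operators_wkv_v7.inc above. It follows the same conventions as the kernel (the caller passes a = -kk and b = kk * a); the function name and surrounding boilerplate are illustrative, not part of the patch:

#include <string.h>

/* One head, one timestep of the RWKV v7 state update.
 * state is S x S: rows indexed by the value channel i, columns by j. */
static void wkv_v7_head_step(int S, float * state, float * out,
                             const float * r, const float * w, const float * k,
                             const float * v, const float * a, const float * b) {
    memset(out, 0, S * sizeof(float));
    for (int i = 0; i < S; i++) {
        /* how much of the existing state row reacts to a ("removal" term) */
        float sa = 0.0f;
        for (int j = 0; j < S; j++) {
            sa += a[j] * state[i * S + j];
        }
        for (int j = 0; j < S; j++) {
            /* decay the state, add the key/value outer product, re-add the sa part */
            state[i * S + j] = state[i * S + j] * w[j] + k[j] * v[i] + sa * b[j];
            /* read out through the receptance */
            out[i] += state[i * S + j] * r[j];
        }
    }
}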
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 851819f7..22d67f6d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,6 +5,7 @@ function(rwkv_add_test source) set_property(TARGET ${TEST_TARGET} PROPERTY CUDA_ARCHITECTURES OFF) endif() target_link_libraries(${TEST_TARGET} PRIVATE ggml rwkv) + target_include_directories(${TEST_TARGET} PRIVATE ${CMAKE_SOURCE_DIR}/ggml/include ${CMAKE_SOURCE_DIR}/ggml/src) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) if (RWKV_STATIC) if(RWKV_HIPBLAS) @@ -41,6 +42,12 @@ file(COPY tiny-rwkv-6v0-3m-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY tiny-rwkv-6v0-3m-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY expected-logits-6v0-3m.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-7v0-834K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-7v0-834K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-7v0-834K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-7v0-834K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY expected-logits-7v0-834K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + rwkv_add_test(test_ggml_basics.c) rwkv_add_test(test_quantized_matmul_on_gpu.c) rwkv_add_test(test_tiny_rwkv.c) diff --git a/tests/expected-logits-7v0-834K.bin b/tests/expected-logits-7v0-834K.bin new file mode 100644 index 00000000..c4e18257 Binary files /dev/null and b/tests/expected-logits-7v0-834K.bin differ diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc index 644d5013..a4cb68d7 100644 --- a/tests/logit_difference_validator.inc +++ b/tests/logit_difference_validator.inc @@ -17,7 +17,7 @@ void load_expected_logits(float * expected_logits, const char * version) { char file_name[128]; - sprintf(file_name, "expected-logits-%s.bin", version); + snprintf(file_name, sizeof(file_name), "expected-logits-%s.bin", version); FILE * file = fopen(file_name, "rb"); ASSERT(file != NULL, "Failed to open %s", file_name); size_t elements_read = fread(expected_logits, sizeof(float), N_VOCAB, file); @@ -27,7 +27,7 @@ void load_expected_logits(float * expected_logits, const char * version) { void test_model(const char * version, const char * format, const float * expected_logits, const float max_diff) { char file_name[128]; - sprintf(file_name, "tiny-rwkv-%s-%s.bin", version, format); + snprintf(file_name, sizeof(file_name), "tiny-rwkv-%s-%s.bin", version, format); fprintf(stderr, "Testing %s\n", file_name); @@ -63,7 +63,7 @@ void test_model(const char * version, const char * format, const float * expecte diff_sum += logits[i] - expected_logits[i]; } - fprintf(stderr, "Serial difference sum: %f, expected %f\n", diff_sum, max_diff); + fprintf(stderr, "Serial difference sum: %f, expected %f\n", (double)diff_sum, (double)max_diff); ASSERT(fabsf(diff_sum) <= fabsf(max_diff) * 1.05F, "Too big serial difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); @@ -78,7 +78,7 @@ void test_model(const char * version, const char * format, const float * expecte diff_sum += logits[i] - expected_logits[i]; } - fprintf(stderr, "Sequence difference sum: %f, expected %f\n", diff_sum, max_diff); + fprintf(stderr, "Sequence difference sum: %f, expected %f\n", (double)diff_sum, (double)max_diff); ASSERT(fabsf(diff_sum) <= fabsf(max_diff) * 1.05F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); diff --git a/tests/test_eval_sequence_in_chunks.c 
index 0a35da1f..0ac5ca7f 100644
--- a/tests/test_eval_sequence_in_chunks.c
+++ b/tests/test_eval_sequence_in_chunks.c
@@ -30,7 +30,7 @@ void test_on_prompt(const char * prompt, const size_t prompt_length) {
     uint32_t * prompt_tokens = calloc(prompt_length, sizeof(uint32_t));

-    for (int i = 0; i < prompt_length; i++) {
+    for (size_t i = 0; i < prompt_length; i++) {
         prompt_tokens[i] = prompt[i];
     }

diff --git a/tests/test_ggml_basics.c b/tests/test_ggml_basics.c
index ec0a50bf..c3230947 100644
--- a/tests/test_ggml_basics.c
+++ b/tests/test_ggml_basics.c
@@ -4,14 +4,23 @@
 #include 
 #include 
+#include 
+#include 

 #include "assertions.inc"

 #define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value

 void test_simple_computation(void) {
+    size_t ctx_size = 0;
+    {
+        ctx_size += 3 * 4 * ggml_type_size(GGML_TYPE_F32);
+        ctx_size += 3 * ggml_tensor_overhead();
+        ctx_size += ggml_graph_overhead();
+        ctx_size += 1024;
+    }
     struct ggml_init_params params = {
-        .mem_size = 16 * 1024,
+        .mem_size = ctx_size,
         .mem_buffer = NULL,
         .no_alloc = false,
     };
@@ -32,27 +41,12 @@ void test_simple_computation(void) {

     struct ggml_tensor * sum = ggml_add(ctx, x, y);

-    // Allocation on heap instead of stack avoids SegFault when GGML_MAX_NODES is set to a large value.
-    struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph));
-    graph->size = GGML_DEFAULT_GRAPH_SIZE;
-    graph->n_nodes = 0;
-    graph->n_leafs = 0;
-    graph->nodes = (struct ggml_tensor **) calloc(1, GGML_DEFAULT_GRAPH_SIZE * sizeof(struct ggml_tensor *));
-    graph->leafs = (struct ggml_tensor **) calloc(1, GGML_DEFAULT_GRAPH_SIZE * sizeof(struct ggml_tensor *));
-    size_t hash_size = GGML_DEFAULT_GRAPH_SIZE * 2 + 1;
-    graph->visited_hash_table.size = hash_size;
-    graph->visited_hash_table.keys = (struct ggml_tensor **) calloc(1, hash_size * sizeof(struct ggml_tensor *));
-    graph->order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT;
+    struct ggml_cgraph * graph = ggml_new_graph(ctx);

     ggml_build_forward_expand(graph, sum);

-    struct ggml_cplan plan = ggml_graph_plan(graph, 2);
+    struct ggml_cplan plan = ggml_graph_plan(graph, 2, NULL);
     ggml_graph_compute(graph, &plan);

-    free(graph->nodes);
-    free(graph->leafs);
-    free(graph->visited_hash_table.keys);
-    free(graph);
-
     ASSERT_ELEMENT_F32(sum, 0, -9.0F);
     ASSERT_ELEMENT_F32(sum, 1, 2.0F);
     ASSERT_ELEMENT_F32(sum, 2, 5.5F);
@@ -65,8 +59,15 @@
 // RWKV model loading code depends on this behavior.
 void test_computation_on_tensors_from_different_contexts(void) {
+    size_t ctx_size = 0;
+    {
+        ctx_size += 4 * ggml_type_size(GGML_TYPE_F32);
+        ctx_size += ggml_tensor_overhead();
+        ctx_size += ggml_graph_overhead();
+        ctx_size += 1024;
+    }
     struct ggml_init_params params = {
-        .mem_size = 16 * 1024,
+        .mem_size = ctx_size,
         .mem_buffer = NULL,
         .no_alloc = false,
     };
@@ -85,26 +86,11 @@ void test_computation_on_tensors_from_different_contexts(void) {

     struct ggml_tensor * sum = ggml_add(ctx2, x, y);

-    // Allocation on heap instead of stack avoids SegFault when GGML_MAX_NODES is set to a large value.
-    struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph));
-    graph->size = GGML_DEFAULT_GRAPH_SIZE;
-    graph->n_nodes = 0;
-    graph->n_leafs = 0;
-    graph->nodes = (struct ggml_tensor **) calloc(1, GGML_DEFAULT_GRAPH_SIZE * sizeof(struct ggml_tensor *));
-    graph->leafs = (struct ggml_tensor **) calloc(1, GGML_DEFAULT_GRAPH_SIZE * sizeof(struct ggml_tensor *));
-    size_t hash_size = GGML_DEFAULT_GRAPH_SIZE * 2 + 1;
-    graph->visited_hash_table.size = hash_size;
-    graph->visited_hash_table.keys = (struct ggml_tensor **) calloc(1, hash_size * sizeof(struct ggml_tensor *));
-    graph->order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT;
+    struct ggml_cgraph * graph = ggml_new_graph(ctx2);

     ggml_build_forward_expand(graph, sum);

-    struct ggml_cplan plan = ggml_graph_plan(graph, 2);
+    struct ggml_cplan plan = ggml_graph_plan(graph, 2, NULL);
     ggml_graph_compute(graph, &plan);

-    free(graph->nodes);
-    free(graph->leafs);
-    free(graph->visited_hash_table.keys);
-    free(graph);
-
     ASSERT_ELEMENT_F32(sum, 0, -9.0F);
     ASSERT_ELEMENT_F32(sum, 1, 2.0F);
diff --git a/tests/test_quantized_matmul_on_gpu.c b/tests/test_quantized_matmul_on_gpu.c
index 9666f57c..ecb74811 100644
--- a/tests/test_quantized_matmul_on_gpu.c
+++ b/tests/test_quantized_matmul_on_gpu.c
@@ -8,6 +8,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
@@ -106,8 +107,8 @@ int main(void) {
     ggml_backend_tensor_get(mul0, &result0, 0, ggml_nbytes(mul0));
     ggml_backend_tensor_get(mul1, &result1, 0, ggml_nbytes(mul1));

-    fprintf(stderr, "FP32 CPU result = %f\n", result0);
-    fprintf(stderr, "Q5_0 GPU result = %f\n", result1);
+    fprintf(stderr, "FP32 CPU result = %f\n", (double)result0);
+    fprintf(stderr, "Q5_0 GPU result = %f\n", (double)result1);

     ASSERT(fabsf(result0 - result1) <= 100.0F, "Results differ too much");

diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index 9072ae3c..2d9ef04e 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -6,7 +6,7 @@

 #include "logit_difference_validator.inc"

-#define VERSION_COUNT 4
+#define VERSION_COUNT 5
 #define FORMAT_COUNT 7

 int main(void) {
@@ -21,7 +21,8 @@ int main(void) {
         "4v0-660K",
         "5v1-730K",
         "5v2-730K",
-        "6v0-3m"
+        "6v0-3m",
+        "7v0-834K"
     };

     const char * formats[FORMAT_COUNT] = {
@@ -46,7 +47,10 @@ int main(void) {
         +0.455912F, // FP16
         // 6v0
         +0.001000F, // FP32
-        -0.416620F // FP16
+        -0.416620F, // FP16
+        // 7v0
+        +0.001000F, // FP32
+        +0.005766F // FP16
     };

     // *** Why the hell the expected logit difference sum for v4 models is < 1, and for v5 models it can be as high as 160? ***
@@ -87,7 +91,13 @@ int main(void) {
         +021.939022F, // Q4_1
         -027.332073F, // Q5_0
         +003.576909F, // Q5_1
-        -009.539596F // Q8_0
+        -009.539596F, // Q8_0
+        // 7v0
+        +000.136785F, // Q4_0
+        +000.002614F, // Q4_1
+        -000.063645F, // Q5_0
+        -000.064663F, // Q5_1
+        +000.011924F // Q8_0
     };

     const float expected_difference_sum_quantized_FP16[VERSION_COUNT * (FORMAT_COUNT - 2)] = {
@@ -114,7 +124,13 @@ int main(void) {
         +021.797060F, // Q4_1
         -027.269241F, // Q5_0
         +003.405264F, // Q5_1
-        -009.734720F // Q8_0
+        -009.734720F, // Q8_0
+        // 7v0
+        +000.136678F, // Q4_0
+        -000.005140F, // Q4_1
+        -000.064447F, // Q5_0
+        -000.063531F, // Q5_1
+        +000.010921F // Q8_0
     };

     for (int i_version = 0; i_version < VERSION_COUNT; i_version++) {
@@ -129,14 +145,14 @@
             }

             char source_file_name[128];
-            char dest_format[128];
+            char dest_format[32];
             char dest_file_name[128];

             // ---

-            sprintf(source_file_name, "tiny-rwkv-%s-FP32.bin", versions[i_version]);
-            sprintf(dest_format, "FP32-to-%s", formats[i_format]);
-            sprintf(dest_file_name, "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);
+            snprintf(source_file_name, sizeof(source_file_name), "tiny-rwkv-%s-FP32.bin", versions[i_version]);
+            snprintf(dest_format, sizeof(dest_format), "FP32-to-%s", formats[i_format]);
+            snprintf(dest_file_name, sizeof(dest_file_name), "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);

             rwkv_quantize_model_file(source_file_name, dest_file_name, formats[i_format]);

@@ -144,9 +160,9 @@

             // ---

-            sprintf(source_file_name, "tiny-rwkv-%s-FP16.bin", versions[i_version]);
-            sprintf(dest_format, "FP16-to-%s", formats[i_format]);
-            sprintf(dest_file_name, "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);
+            snprintf(source_file_name, sizeof(source_file_name), "tiny-rwkv-%s-FP16.bin", versions[i_version]);
+            snprintf(dest_format, sizeof(dest_format), "FP16-to-%s", formats[i_format]);
+            snprintf(dest_file_name, sizeof(dest_file_name), "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);

             rwkv_quantize_model_file(source_file_name, dest_file_name, formats[i_format]);

diff --git a/tests/tiny-rwkv-7v0-834K-FP16.bin b/tests/tiny-rwkv-7v0-834K-FP16.bin
new file mode 100644
index 00000000..70c2c57a
Binary files /dev/null and b/tests/tiny-rwkv-7v0-834K-FP16.bin differ
diff --git a/tests/tiny-rwkv-7v0-834K-FP32.bin b/tests/tiny-rwkv-7v0-834K-FP32.bin
new file mode 100644
index 00000000..49ceb3dc
Binary files /dev/null and b/tests/tiny-rwkv-7v0-834K-FP32.bin differ
diff --git a/tests/tiny-rwkv-7v0-834K-Q5_0.bin b/tests/tiny-rwkv-7v0-834K-Q5_0.bin
new file mode 100644
index 00000000..e1b740a2
Binary files /dev/null and b/tests/tiny-rwkv-7v0-834K-Q5_0.bin differ
diff --git a/tests/tiny-rwkv-7v0-834K-Q5_1.bin b/tests/tiny-rwkv-7v0-834K-Q5_1.bin
new file mode 100644
index 00000000..f3b707ac
Binary files /dev/null and b/tests/tiny-rwkv-7v0-834K-Q5_1.bin differ