From 2c8630dac4b8eec62959828c1679f8e2887b3153 Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Thu, 21 Sep 2023 16:01:33 +0800 Subject: [PATCH 1/8] Prevent the bot from replying multiple replies - Implement asynchronous call in textgen.py - Use asyncio cancel the current task and start a new one each time a text message is received --- cogs/pygbot.py | 29 +++++++++++++++++------ helpers/textgen.py | 59 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/cogs/pygbot.py b/cogs/pygbot.py index 03a69eb..b9cc728 100644 --- a/cogs/pygbot.py +++ b/cogs/pygbot.py @@ -9,6 +9,7 @@ from discord import app_commands from discord.ext import commands import os +import asyncio # load environment STOP_SEQUENCES variables and split them into a list by comma @@ -127,7 +128,7 @@ async def get_memory_for_channel(self, channel_id): name = message[0] channel_ids = str(message[1]) message = message[2] - print(f"{name}: {message}") + #print(f"{name}: {message}") await self.add_history(name, channel_ids, message) # self.memory = self.histories[channel_id] @@ -160,7 +161,7 @@ async def generate_response(self, message, message_content) -> None: name = message.author.display_name memory = await self.get_memory_for_channel(str(channel_id)) stop_sequence = await self.get_stop_sequence_for_channel(channel_id, name) - print(f"stop sequences: {stop_sequence}") + #print(f"stop sequences: {stop_sequence}") formatted_message = f"{name}: {message_content}" MAIN_TEMPLATE = f""" {self.top_character_info} @@ -176,11 +177,11 @@ async def generate_response(self, message, message_content) -> None: conversation = ConversationChain( prompt=PROMPT, llm=self.llm, - verbose=True, + #verbose=True, memory=memory, ) input_dict = {"input": formatted_message, "stop": stop_sequence} - response_text = conversation(input_dict) + response_text = await conversation.acall(input_dict) response = await self.detect_and_replace_out(response_text["response"]) with 
open(self.convo_filename, "a", encoding="utf-8") as f: f.write(f"{message.author.display_name}: {message_content}\n") @@ -199,7 +200,7 @@ async def add_history(self, name, channel_id, message_content) -> None: formatted_message = f"{name}: {message_content}" # add the message to the memory - print(f"adding message to memory: {formatted_message}") + #print(f"adding message to memory: {formatted_message}") memory.add_input_only(formatted_message) return None @@ -209,6 +210,7 @@ def __init__(self, bot): self.bot = bot self.chatlog_dir = bot.chatlog_dir self.chatbot = Chatbot(bot) + self.current_task = None # create chatlog directory if it doesn't exist if not os.path.exists(self.chatlog_dir): @@ -233,8 +235,21 @@ async def chat_command(self, name, channel_id, message_content, message) -> None and self.chatbot.convo_filename != chatlog_filename ): await self.chatbot.set_convo_filename(chatlog_filename) - response = await self.chatbot.generate_response(message, message_content) - return response + + # Check if the task is still running + print(f"The current task is: {self.current_task}") + if self.current_task is not None and not self.current_task.done(): + print("Cancelling previous task") + self.current_task.cancel() + + # Create new task and store in current_task + self.current_task = asyncio.create_task(self.chatbot.generate_response(message, message_content)) + try: + response = await self.current_task + return response + except asyncio.CancelledError: + print("Request cancelled") + return None # No Response Handler @commands.command(name="chatnr") diff --git a/helpers/textgen.py b/helpers/textgen.py index 1365b49..175a407 100644 --- a/helpers/textgen.py +++ b/helpers/textgen.py @@ -3,8 +3,13 @@ from typing import Any, Dict, Iterator, List, Optional import requests +import asyncio +import aiohttp -from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, 
+) from langchain.llms.base import LLM from langchain.pydantic_v1 import Field from langchain.schema.output import GenerationChunk @@ -217,7 +222,7 @@ def _call( print(params["stopping_strings"]) # TODO: Remove this line request = params.copy() request["prompt"] = prompt - print(request) # TODO: Remove this line + #print(request) # TODO: Remove this line response = requests.post(url, json=request) if response.status_code == 200: @@ -229,6 +234,56 @@ def _call( return result + # Implement _acall function from LangChain example github + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call the textgen web API and return the output. + + Args: + prompt: The prompt to use for generation. + stop: A list of strings to stop generation when encountered. + + Returns: + The generated text. + + Example: + .. code-block:: python + + from langchain.llms import TextGen + llm = TextGen(model_url="http://localhost:5000") + llm("Write a story about llamas.") + """ + if self.streaming: + combined_text_output = "" + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + combined_text_output += chunk.text + result = combined_text_output + + else: + # Use aiohttp to call textgen API asynchronously to prevent blocking + async with aiohttp.ClientSession() as session: + url = f"{self.model_url}/api/v1/generate" + params = self._get_parameters(stop) + request = params.copy() + request["prompt"] = prompt + #response = requests.post(url, json=request) + + async with session.post(url, json=request) as response: + if response.status == 200: + result = (await response.json())["results"][0]["text"] + else: + print(f"ERROR: response: {response}") + result = "" + + return result + def _stream( self, prompt: str, From daae24b55b60ec3b698bb21315b3780896a2ccbe Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Tue, 26 Sep 2023 
14:28:47 +0800 Subject: [PATCH 2/8] Fix bot memory 2nd last user reply issue - Bot can now see 2nd last user reply by storing last message before sending another reply - Avoid `NoneType` error in messagehandler.py, I forgot --- cogs/messagehandler.py | 23 ++++++++++++----------- cogs/pygbot.py | 13 +++++++++---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/cogs/messagehandler.py b/cogs/messagehandler.py index 39280fc..40e4ad5 100644 --- a/cogs/messagehandler.py +++ b/cogs/messagehandler.py @@ -249,17 +249,18 @@ async def handle_text_message(self, message, mode=""): message, ) await self.add_message_to_dict(message, message.clean_content) - async with message.channel.typing(): - # If the response is more than 2000 characters, split it - chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)] - for chunk in chunks: - print(chunk) - response_obj = await message.channel.send(chunk) - await self.add_message_to_dict( - response_obj, response_obj.clean_content - ) - # self.bot.sent_last_message[str(message.channel.id)] = True - # await log_message(response_obj) + if response: + async with message.channel.typing(): + # If the response is more than 2000 characters, split it + chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)] + for chunk in chunks: + print(chunk) + response_obj = await message.channel.send(chunk) + await self.add_message_to_dict( + response_obj, response_obj.clean_content + ) + # self.bot.sent_last_message[str(message.channel.id)] = True + # await log_message(response_obj) async def set_listen_only_mode_timer(self, channel_id): # Start the timer diff --git a/cogs/pygbot.py b/cogs/pygbot.py index b9cc728..6621a47 100644 --- a/cogs/pygbot.py +++ b/cogs/pygbot.py @@ -177,7 +177,7 @@ async def generate_response(self, message, message_content) -> None: conversation = ConversationChain( prompt=PROMPT, llm=self.llm, - #verbose=True, + verbose=True, memory=memory, ) input_dict = {"input": formatted_message, 
"stop": stop_sequence} @@ -210,7 +210,10 @@ def __init__(self, bot): self.bot = bot self.chatlog_dir = bot.chatlog_dir self.chatbot = Chatbot(bot) + + # Store current task and last message here self.current_task = None + self.last_message = None # create chatlog directory if it doesn't exist if not os.path.exists(self.chatlog_dir): @@ -237,18 +240,20 @@ async def chat_command(self, name, channel_id, message_content, message) -> None await self.chatbot.set_convo_filename(chatlog_filename) # Check if the task is still running - print(f"The current task is: {self.current_task}") + #print(f"The current task is: {self.current_task}") # for debugging purposes if self.current_task is not None and not self.current_task.done(): - print("Cancelling previous task") + # Cancelling previous task, add last message to the history + await self.chatbot.add_history(name, str(channel_id), self.last_message) self.current_task.cancel() # Create new task and store in current_task + self.last_message = message_content self.current_task = asyncio.create_task(self.chatbot.generate_response(message, message_content)) try: response = await self.current_task return response except asyncio.CancelledError: - print("Request cancelled") + print(f"Cancelled {self.chatbot.char_name}'s current response, regenerate another reply...") return None # No Response Handler From 56065b63510cd88e5b7e11fa744236ad7a2bf9cf Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Wed, 27 Sep 2023 14:20:03 +0000 Subject: [PATCH 3/8] Fix stop sequence for TextGen in `_acall` function - I literally forgot --- helpers/textgen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/helpers/textgen.py b/helpers/textgen.py index 175a407..4212a00 100644 --- a/helpers/textgen.py +++ b/helpers/textgen.py @@ -271,6 +271,9 @@ async def _acall( async with aiohttp.ClientSession() as session: url = f"{self.model_url}/api/v1/generate" params = self._get_parameters(stop) + params["stopping_strings"] = params.pop( + "stop" + ) # 
Rename 'stop' to 'stopping_strings' request = params.copy() request["prompt"] = prompt #response = requests.post(url, json=request) From 66da2291548542c64583cc638420a842438ee75d Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Wed, 4 Oct 2023 21:09:50 +0800 Subject: [PATCH 4/8] Made KoboldAI API compatible - implementation carried over to koboldAI API - so far tested only with koboldcpp --- discordbot.py | 3 +- helpers/koboldai.py | 264 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 helpers/koboldai.py diff --git a/discordbot.py b/discordbot.py index 837f35d..0b8d44e 100644 --- a/discordbot.py +++ b/discordbot.py @@ -6,7 +6,8 @@ from pathlib import Path import base64 from helpers.textgen import TextGen -from langchain.llms import KoboldApiLLM, OpenAI +from helpers.koboldai import KoboldApiLLM +from langchain.llms import OpenAI from discord import app_commands from discord.ext import commands from discord.ext.commands import Bot diff --git a/helpers/koboldai.py b/helpers/koboldai.py new file mode 100644 index 0000000..1ca828b --- /dev/null +++ b/helpers/koboldai.py @@ -0,0 +1,264 @@ +import logging +from typing import Any, Dict, List, Optional + +import requests +import asyncio +import aiohttp + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.llms.base import LLM + +logger = logging.getLogger(__name__) + + +def clean_url(url: str) -> str: + """Remove trailing slash and /api from url if present.""" + if url.endswith("/api"): + return url[:-4] + elif url.endswith("/"): + return url[:-1] + else: + return url + + +class KoboldApiLLM(LLM): + """Kobold API language model. + + It includes several fields that can be used to control the text generation process. + + To use this class, instantiate it with the required parameters and call it with a + prompt to generate text. 
For example: + + kobold = KoboldApiLLM(endpoint="http://localhost:5000") + result = kobold("Write a story about a dragon.") + + This will send a POST request to the Kobold API with the provided prompt and + generate text. + """ + + endpoint: str + """The API endpoint to use for generating text.""" + + use_story: Optional[bool] = False + """ Whether or not to use the story from the KoboldAI GUI when generating text. """ + + use_authors_note: Optional[bool] = False + """Whether to use the author's note from the KoboldAI GUI when generating text. + + This has no effect unless use_story is also enabled. + """ + + use_world_info: Optional[bool] = False + """Whether to use the world info from the KoboldAI GUI when generating text.""" + + use_memory: Optional[bool] = False + """Whether to use the memory from the KoboldAI GUI when generating text.""" + + max_context_length: Optional[int] = 1600 + """Maximum number of tokens to send to the model. + + minimum: 1 + """ + + max_length: Optional[int] = 80 + """Number of tokens to generate. + + maximum: 512 + minimum: 1 + """ + + rep_pen: Optional[float] = 1.12 + """Base repetition penalty value. + + minimum: 1 + """ + + rep_pen_range: Optional[int] = 1024 + """Repetition penalty range. + + minimum: 0 + """ + + rep_pen_slope: Optional[float] = 0.9 + """Repetition penalty slope. + + minimum: 0 + """ + + temperature: Optional[float] = 0.6 + """Temperature value. + + exclusiveMinimum: 0 + """ + + tfs: Optional[float] = 0.9 + """Tail free sampling value. + + maximum: 1 + minimum: 0 + """ + + top_a: Optional[float] = 0.9 + """Top-a sampling value. + + minimum: 0 + """ + + top_p: Optional[float] = 0.95 + """Top-p sampling value. + + maximum: 1 + minimum: 0 + """ + + top_k: Optional[int] = 0 + """Top-k sampling value. + + minimum: 0 + """ + + typical: Optional[float] = 0.5 + """Typical sampling value. 
+ + maximum: 1 + minimum: 0 + """ + + @property + def _llm_type(self) -> str: + return "koboldai" + + # Define a helper method to generate the data dict + def _get_parameters( + self, + prompt: str, + stop: Optional[List[str]] = None) -> Dict[str, Any]: + """Get the parameters to send to the API.""" + data: Dict[str, Any] = { + "prompt": prompt, + "use_story": self.use_story, + "use_authors_note": self.use_authors_note, + "use_world_info": self.use_world_info, + "use_memory": self.use_memory, + "max_context_length": self.max_context_length, + "max_length": self.max_length, + "rep_pen": self.rep_pen, + "rep_pen_range": self.rep_pen_range, + "rep_pen_slope": self.rep_pen_slope, + "temperature": self.temperature, + "tfs": self.tfs, + "top_a": self.top_a, + "top_p": self.top_p, + "top_k": self.top_k, + "typical": self.typical, + } + + if stop: + data["stop_sequence"] = stop + + return data + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call the API and return the output. + + Args: + prompt: The prompt to use for generation. + stop: A list of strings to stop generation when encountered. + + Returns: + The generated text. + + Example: + .. 
code-block:: python + + from langchain.llms import KoboldApiLLM + + llm = KoboldApiLLM(endpoint="http://localhost:5000") + llm("Write a story about dragons.") + """ + data = self._get_parameters(prompt, stop) + + response = requests.post( + f"{clean_url(self.endpoint)}/api/v1/generate", json=data + ) + + response.raise_for_status() + json_response = response.json() + + if ( + "results" in json_response + and len(json_response["results"]) > 0 + and "text" in json_response["results"][0] + ): + text = json_response["results"][0]["text"].strip() + + if stop is not None: + for sequence in stop: + if text.endswith(sequence): + text = text[: -len(sequence)].rstrip() + + return text + else: + raise ValueError( + f"Unexpected response format from Kobold API: {json_response}" + ) + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call the API and return the output. + + Args: + prompt: The prompt to use for generation. + stop: A list of strings to stop generation when encountered. + + Returns: + The generated text. + + Example: + .. 
code-block:: python + + from langchain.llms import KoboldApiLLM + + llm = KoboldApiLLM(endpoint="http://localhost:5000") + llm("Write a story about dragons.") + """ + data = self._get_parameters(prompt, stop) + + # Use aiohttp to call KoboldAI API asynchronously to prevent blocking + async with aiohttp.ClientSession() as session: + async with session.post(f"{clean_url(self.endpoint)}/api/v1/generate", json=data) as response: + + response.raise_for_status() + json_response = await response.json() + + if ( + "results" in json_response + and len(json_response["results"]) > 0 + and "text" in json_response["results"][0] + ): + text = json_response["results"][0]["text"].strip() + + if stop is not None: + for sequence in stop: + if text.endswith(sequence): + text = text[: -len(sequence)].rstrip() + + return text + else: + raise ValueError( + f"Unexpected response format from Kobold API: {json_response}" + ) \ No newline at end of file From a180526bbbd98e8c6015101cbff82ecfdd839887 Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Thu, 5 Oct 2023 21:29:52 +0800 Subject: [PATCH 5/8] Stop generate text via API (koboldcpp only) - Made a `_stop` function in koboldai.py - By invoking `/api/extra/abort` - Official KoboldAI doesn't work --- cogs/pygbot.py | 5 +++++ helpers/koboldai.py | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/cogs/pygbot.py b/cogs/pygbot.py index 6621a47..4e7cef1 100644 --- a/cogs/pygbot.py +++ b/cogs/pygbot.py @@ -244,6 +244,11 @@ async def chat_command(self, name, channel_id, message_content, message) -> None if self.current_task is not None and not self.current_task.done(): # Cancelling previous task, add last message to the history await self.chatbot.add_history(name, str(channel_id), self.last_message) + + # If the llm type is "koboldai", stop the generation from the API + if self.bot.llm._llm_type == "koboldai": + await self.bot.llm._stop() + self.current_task.cancel() # Create new task and store in current_task 
diff --git a/helpers/koboldai.py b/helpers/koboldai.py index 1ca828b..cc7d26b 100644 --- a/helpers/koboldai.py +++ b/helpers/koboldai.py @@ -212,6 +212,7 @@ def _call( f"Unexpected response format from Kobold API: {json_response}" ) + # New function to call KoboldAI API asynchronously async def _acall( self, prompt: str, @@ -261,4 +262,18 @@ async def _acall( else: raise ValueError( f"Unexpected response format from Kobold API: {json_response}" - ) \ No newline at end of file + ) + + async def _stop(self): + """Send abort request to stop ongoing AI generation. + This only applies to koboldcpp. Official KoboldAI API does not support this. + """ + + try: + async with aiohttp.ClientSession() as session: + async with session.post(f"{clean_url(self.endpoint)}/api/extra/abort") as response: + if response.status == 200: + print("Successfully aborted AI generation.") + + except Exception as e: + print(f"Error aborting AI generation: {e}") \ No newline at end of file From 454f121356b1a9c582ab446bfb5b215cd7fe4f1e Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Mon, 9 Oct 2023 22:57:31 +0800 Subject: [PATCH 6/8] Support abrupt replies from multiple channel IDs - Using dict for current_tasks, with channel IDs as a key - Perform a check if the endpoint is using koboldcpp or KoboldAI - TODO: support koboldcpp with multi-user requests --- cogs/pygbot.py | 34 +++++++++++++++++++--------------- discordbot.py | 6 ++++++ helpers/koboldai.py | 14 +++++++++++++- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/cogs/pygbot.py b/cogs/pygbot.py index 4e7cef1..91a75e5 100644 --- a/cogs/pygbot.py +++ b/cogs/pygbot.py @@ -212,8 +212,8 @@ def __init__(self, bot): self.chatbot = Chatbot(bot) # Store current task and last message here - self.current_task = None - self.last_message = None + self.current_tasks = {} + self.last_messages = {} # create chatlog directory if it doesn't exist if not os.path.exists(self.chatlog_dir): @@ -239,23 +239,27 @@ async def chat_command(self, 
name, channel_id, message_content, message) -> None ): await self.chatbot.set_convo_filename(chatlog_filename) - # Check if the task is still running - #print(f"The current task is: {self.current_task}") # for debugging purposes - if self.current_task is not None and not self.current_task.done(): - # Cancelling previous task, add last message to the history - await self.chatbot.add_history(name, str(channel_id), self.last_message) + # Check if the task is still running by channel ID + #print(f"The current task is: {self.current_tasks[channel_id]}") # for debugging purposes + if channel_id in self.current_tasks: + task = self.current_tasks[channel_id] + + if task is not None and not task.done(): + # Cancelling previous task, add last message to the history + await self.chatbot.add_history(name, str(channel_id), self.last_messages[channel_id]) - # If the llm type is "koboldai", stop the generation from the API - if self.bot.llm._llm_type == "koboldai": - await self.bot.llm._stop() + # If the endpoint is koboldcpp, stop the generation + if self.bot.koboldcpp_version >= 1.29: + await self.bot.llm._stop() - self.current_task.cancel() + self.current_task.cancel() + + # Create a new task and last message bounded to the channel ID + self.last_messages[channel_id] = message_content + self.current_tasks[channel_id] = asyncio.create_task(self.chatbot.generate_response(message, message_content)) - # Create new task and store in current_task - self.last_message = message_content - self.current_task = asyncio.create_task(self.chatbot.generate_response(message, message_content)) try: - response = await self.current_task + response = await self.current_tasks[channel_id] return response except asyncio.CancelledError: print(f"Cancelled {self.chatbot.char_name}'s current response, regenerate another reply...") diff --git a/discordbot.py b/discordbot.py index 0b8d44e..803e626 100644 --- a/discordbot.py +++ b/discordbot.py @@ -240,6 +240,12 @@ async def on_ready(): "\n\n\n\nERROR: 
Unable to retrieve channel from .env \nPlease make sure you're using a valid channel ID, not a server ID." ) + # Check if the endpoint is connected to koboldcpp + if bot.llm._llm_type == "koboldai": + bot.koboldcpp_version = bot.llm.check_version() + else: + bot.koboldcpp_version = 0.0 + # COG LOADER async def load_cogs() -> None: diff --git a/helpers/koboldai.py b/helpers/koboldai.py index cc7d26b..8988f90 100644 --- a/helpers/koboldai.py +++ b/helpers/koboldai.py @@ -264,6 +264,18 @@ async def _acall( f"Unexpected response format from Kobold API: {json_response}" ) + def check_version(self) -> float: + """Check the version of the koboldcpp API. To distinguish between KoboldAI and koboldcpp""" + try: + response = requests.get(f"{clean_url(self.endpoint)}/api/extra/version") + response.raise_for_status() + json_response = response.json() + print("The endpoint is running koboldcpp instead of KoboldAI. Stop generation is supported.") + return float(json_response["version"]) + except Exception as e: + print("The endpoint is running KoboldAI instead of koboldcpp. Stop generation is not supported.") + return 0.0 + async def _stop(self): """Send abort request to stop ongoing AI generation. This only applies to koboldcpp. Official KoboldAI API does not support this. @@ -272,7 +284,7 @@ async def _stop(self): try: async with aiohttp.ClientSession() as session: async with session.post(f"{clean_url(self.endpoint)}/api/extra/abort") as response: - if response.status == 200: + if response.status == 200 and response.json()["success"] == True: print("Successfully aborted AI generation.") except Exception as e: From 0170dd34ccd5990f49b9cbf27013cd1ad5c8533c Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Fri, 20 Oct 2023 20:46:55 +0800 Subject: [PATCH 7/8] Implement `genkey` to support multiple channel IDs (NOT TESTED!) 
- using dict to generate unique keys for each of channel IDs - `_stop` function rolled back to normal request POST instead of aiohttp session --- cogs/pygbot.py | 12 +++++++--- discordbot.py | 2 ++ helpers/koboldai.py | 58 ++++++++++++++++++++++++++++++++++++--------- 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/cogs/pygbot.py b/cogs/pygbot.py index 91a75e5..df96151 100644 --- a/cogs/pygbot.py +++ b/cogs/pygbot.py @@ -181,7 +181,13 @@ async def generate_response(self, message, message_content) -> None: memory=memory, ) input_dict = {"input": formatted_message, "stop": stop_sequence} - response_text = await conversation.acall(input_dict) + + # Run the conversation chain + if self.bot.koboldcpp_version >= 1.29: + response_text = await conversation.acall(input_dict,channel_id) + else: + response_text = await conversation.acall(input_dict) + response = await self.detect_and_replace_out(response_text["response"]) with open(self.convo_filename, "a", encoding="utf-8") as f: f.write(f"{message.author.display_name}: {message_content}\n") @@ -248,9 +254,9 @@ async def chat_command(self, name, channel_id, message_content, message) -> None # Cancelling previous task, add last message to the history await self.chatbot.add_history(name, str(channel_id), self.last_messages[channel_id]) - # If the endpoint is koboldcpp, stop the generation + # If the endpoint is koboldcpp, stop the generation by channel ID if self.bot.koboldcpp_version >= 1.29: - await self.bot.llm._stop() + await self.bot.llm._stop(channel_id) self.current_task.cancel() diff --git a/discordbot.py b/discordbot.py index 803e626..4f3c010 100644 --- a/discordbot.py +++ b/discordbot.py @@ -246,6 +246,8 @@ async def on_ready(): else: bot.koboldcpp_version = 0.0 + print(f"KoboldCPP Version: {bot.koboldcpp_version}") + # COG LOADER async def load_cogs() -> None: diff --git a/helpers/koboldai.py b/helpers/koboldai.py index 8988f90..9248279 100644 --- a/helpers/koboldai.py +++ b/helpers/koboldai.py @@ 
-5,6 +5,9 @@ import asyncio import aiohttp +import random +import string + from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -13,7 +16,6 @@ logger = logging.getLogger(__name__) - def clean_url(url: str) -> str: """Remove trailing slash and /api from url if present.""" if url.endswith("/api"): @@ -127,6 +129,10 @@ class KoboldApiLLM(LLM): minimum: 0 """ + # To store genkeys for each generation + genkeys = {} + is_koboldcpp = False + @property def _llm_type(self) -> str: return "koboldai" @@ -218,6 +224,7 @@ async def _acall( prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + channel_id: Optional[str] = None, **kwargs: Any, ) -> str: """Call the API and return the output. @@ -237,7 +244,19 @@ async def _acall( llm = KoboldApiLLM(endpoint="http://localhost:5000") llm("Write a story about dragons.") """ - data = self._get_parameters(prompt, stop) + if self.is_koboldcpp: + # Generate a random 10 character genkey + genkey = "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + print(f"genkey: {genkey}") + + # Store genkeys to dict mapped to channel ID + self.genkeys[channel_id] = genkey + data = self._get_parameters(prompt, stop) + data["genkey"] = genkey + + else: + # Normal for KoboldAI, genkey is not required + data = self._get_parameters(prompt, stop) # Use aiohttp to call KoboldAI API asynchronously to prevent blocking async with aiohttp.ClientSession() as session: @@ -270,22 +289,39 @@ def check_version(self) -> float: response = requests.get(f"{clean_url(self.endpoint)}/api/extra/version") response.raise_for_status() json_response = response.json() - print("The endpoint is running koboldcpp instead of KoboldAI. Stop generation is supported.") + self.is_koboldcpp = True + print("The endpoint is running koboldcpp instead of KoboldAI. 
If you use multiple channel IDs, please pass '--multiuser' to koboldcpp.") return float(json_response["version"]) - except Exception as e: - print("The endpoint is running KoboldAI instead of koboldcpp. Stop generation is not supported.") - return 0.0 + except: + # Try fetching KoboldAI version + try: + response = requests.get(f"{clean_url(self.endpoint)}/api/v1/version") + response.raise_for_status() + json_response = response.json() + self.is_koboldcpp = False + print("The endpoint is running KoboldAI instead of koboldcpp.") + return 0.0 + except: + raise ValueError("The endpoint is not running KoboldAI or koboldcpp.") - async def _stop(self): + + async def _stop(self, channel_id): """Send abort request to stop ongoing AI generation. This only applies to koboldcpp. Official KoboldAI API does not support this. """ + + # Check genkey before cancelling + if channel_id in self.genkeys: + genkey = self.genkeys[channel_id] + + json = {"genkey": genkey} try: - async with aiohttp.ClientSession() as session: - async with session.post(f"{clean_url(self.endpoint)}/api/extra/abort") as response: - if response.status == 200 and response.json()["success"] == True: - print("Successfully aborted AI generation.") + response = requests.post(f"{clean_url(self.endpoint)}/api/extra/abort", json=json) + if response.status_code == 200 and response.json()["success"] == True: + print(f"Successfully aborted AI generation for channel ID of {channel_id}, with genkey: {genkey}") + else: + print("Error aborting AI generation.") except Exception as e: print(f"Error aborting AI generation: {e}") \ No newline at end of file From 2c838958d0e2afb2a5b34ba1cc9f144eab412dbe Mon Sep 17 00:00:00 2001 From: firelightning13 Date: Fri, 20 Oct 2023 20:48:14 +0800 Subject: [PATCH 8/8] Oops --- discordbot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/discordbot.py b/discordbot.py index 4f3c010..30f6ad6 100644 --- a/discordbot.py +++ b/discordbot.py @@ -243,11 +243,10 @@ async def 
on_ready(): # Check if the endpoint is connected to koboldcpp if bot.llm._llm_type == "koboldai": bot.koboldcpp_version = bot.llm.check_version() + print(f"KoboldCPP Version: {bot.koboldcpp_version}") else: bot.koboldcpp_version = 0.0 - print(f"KoboldCPP Version: {bot.koboldcpp_version}") - # COG LOADER async def load_cogs() -> None: