From 176d321dac42a5a2d37c495136e43909a37012fb Mon Sep 17 00:00:00 2001 From: Roshan Piyush Date: Mon, 19 Jan 2026 17:53:33 +0530 Subject: [PATCH 1/3] Add other ai provider support --- README.md | 2 + deploy/docker/docker-compose.yml | 25 ++- deploy/helm/templates/chatbot/config.yaml | 25 ++- deploy/helm/values.yaml | 44 ++++- deploy/k8s/base/chatbot/config.yaml | 44 +++++ docs/setup.md | 43 +++++ services/chatbot/.env | 2 + services/chatbot/requirements.txt | 8 +- services/chatbot/src/chatbot/chat_api.py | 151 +++++++++++++---- services/chatbot/src/chatbot/chat_service.py | 7 +- services/chatbot/src/chatbot/config.py | 19 ++- .../chatbot/src/chatbot/langgraph_agent.py | 63 ++++++- .../chatbot/src/chatbot/retriever_utils.py | 159 +++++++++++++++--- .../chatbot/src/chatbot/session_service.py | 39 +++-- .../chatbot/src/mcpserver/tool_helpers.py | 18 +- 15 files changed, 568 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 94be9f51..ccbb3169 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,8 @@ Visit [http://localhost:8888](http://localhost:8888) [http://localhost:8025](http://localhost:8025) You can change the smtp configuration if required however all emails with domain **example.com** will still go to mailhog. +For chatbot LLM provider configuration, see [setup instructions](docs/setup.md#chatbot-llm-configuration). + ### Vagrant This option allows you to run crAPI within a virtual machine, thus isolated from diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index da84ffbe..abd759fc 100755 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -172,11 +172,32 @@ services: - API_USER=admin@example.com - API_PASSWORD=Admin!123 - OPENAPI_SPEC=/app/resources/crapi-openapi-spec.json - - DEFAULT_MODEL=gpt-4o-mini + - CHATBOT_LLM_PROVIDER=${CHATBOT_LLM_PROVIDER:-openai} + - CHATBOT_LLM_MODEL=${CHATBOT_LLM_MODEL:-} + - CHATBOT_EMBEDDINGS_MODEL=${CHATBOT_EMBEDDINGS_MODEL:-} + - CHATBOT_EMBEDDINGS_DIMENSIONS=${CHATBOT_EMBEDDINGS_DIMENSIONS:-1536} + - CHATBOT_OPENAI_API_KEY=${CHATBOT_OPENAI_API_KEY:-} + - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-} + - AZURE_OPENAI_API_KEY=${AZURE_OPENAI_API_KEY:-} + - AZURE_AD_TOKEN=${AZURE_AD_TOKEN:-} + - AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_ENDPOINT:-} + - AZURE_OPENAI_API_VERSION=${AZURE_OPENAI_API_VERSION:-2024-02-15-preview} + - AZURE_OPENAI_CHAT_DEPLOYMENT=${AZURE_OPENAI_CHAT_DEPLOYMENT:-} + - AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=${AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT:-} + - GROQ_API_KEY=${GROQ_API_KEY:-} + - MISTRAL_API_KEY=${MISTRAL_API_KEY:-} + - COHERE_API_KEY=${COHERE_API_KEY:-} + - AWS_BEARER_TOKEN_BEDROCK=${AWS_BEARER_TOKEN_BEDROCK:-} + - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-} + - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-} + - AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN:-} + - AWS_REGION=${AWS_REGION:-} + - GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-} + - VERTEX_PROJECT=${VERTEX_PROJECT:-} + - VERTEX_LOCATION=${VERTEX_LOCATION:-} - MAX_CONTENT_LENGTH=50000 - CHROMA_HOST=chromadb - CHROMA_PORT=8000 - # - CHATBOT_OPENAI_API_KEY= depends_on: mongodb: condition: service_healthy diff --git a/deploy/helm/templates/chatbot/config.yaml b/deploy/helm/templates/chatbot/config.yaml index 94be0c0e..c72a7402 100644 --- a/deploy/helm/templates/chatbot/config.yaml +++ b/deploy/helm/templates/chatbot/config.yaml @@ -20,8 +20,29 @@ data: MONGO_DB_USER: {{ .Values.mongodb.config.mongoUser }} MONGO_DB_PASSWORD: {{ .Values.mongodb.config.mongoPassword }} MONGO_DB_NAME: {{ .Values.mongodb.config.mongoDbName }} - CHATBOT_OPENAI_API_KEY: {{ .Values.openAIApiKey }} - DEFAULT_MODEL: {{ .Values.chatbot.config.defaultModel | quote }} + CHATBOT_OPENAI_API_KEY: {{ .Values.chatbotOpenaiApiKey | default .Values.openAIApiKey | quote }} + CHATBOT_LLM_PROVIDER: {{ .Values.chatbotLlmProvider | quote }} + CHATBOT_LLM_MODEL: {{ .Values.chatbotLlmModel | quote }} + CHATBOT_EMBEDDINGS_MODEL: {{ .Values.chatbotEmbeddingsModel | quote }} + CHATBOT_EMBEDDINGS_DIMENSIONS: {{ .Values.chatbotEmbeddingsDimensions | quote }} + ANTHROPIC_API_KEY: {{ .Values.anthropicApiKey | quote }} + AZURE_OPENAI_API_KEY: {{ .Values.azureOpenaiApiKey | quote }} + AZURE_AD_TOKEN: {{ .Values.azureAdToken | quote }} + AZURE_OPENAI_ENDPOINT: {{ .Values.azureOpenaiEndpoint | quote }} + AZURE_OPENAI_API_VERSION: {{ .Values.azureOpenaiApiVersion | quote }} + AZURE_OPENAI_CHAT_DEPLOYMENT: {{ .Values.azureOpenaiChatDeployment | quote }} + AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT: {{ .Values.azureOpenaiEmbeddingsDeployment | quote }} + GROQ_API_KEY: {{ .Values.groqApiKey | quote }} + MISTRAL_API_KEY: {{ .Values.mistralApiKey | quote }} + COHERE_API_KEY: {{ .Values.cohereApiKey | quote }} + AWS_BEARER_TOKEN_BEDROCK: {{ .Values.awsBearerTokenBedrock | quote }} + AWS_ACCESS_KEY_ID: {{ .Values.awsAccessKeyId | quote }} + AWS_SECRET_ACCESS_KEY: {{ .Values.awsSecretAccessKey | quote }} + AWS_SESSION_TOKEN: {{ .Values.awsSessionToken | quote }} + AWS_REGION: {{ .Values.awsRegion | quote }} + GOOGLE_APPLICATION_CREDENTIALS: {{ .Values.googleApplicationCredentials | quote }} + VERTEX_PROJECT: {{ .Values.vertexProject | quote }} + VERTEX_LOCATION: {{ .Values.vertexLocation | quote }} MAX_CONTENT_LENGTH: {{ .Values.chatbot.config.maxContentLength | quote }} CHROMA_HOST: {{ .Values.chromadb.service.name }} CHROMA_PORT: {{ .Values.chromadb.port | quote }} diff --git a/deploy/helm/values.yaml b/deploy/helm/values.yaml index ec842e43..9fb01ae1 100644 --- a/deploy/helm/values.yaml +++ b/deploy/helm/values.yaml @@ -11,9 +11,50 @@ apiGatewayServiceInstall: true apiGatewayPassword: tlsEnabled: false jwtExpiration: 604800000 -openAIApiKey: "" logLevel: INFO +# Chatbot LLM Configuration +chatbotLlmProvider: openai +chatbotLlmModel: "" +chatbotEmbeddingsModel: "" +chatbotEmbeddingsDimensions: 1536 + +# OpenAI API Key +chatbotOpenaiApiKey: "" +openAIApiKey: "" # for backward compatibility + +# Anthropic API Key +anthropicApiKey: "" + +# Azure configuration +azureOpenaiApiKey: "" +azureAdToken: "" +azureOpenaiEndpoint: "" +azureOpenaiApiVersion: 2024-02-15-preview +azureOpenaiChatDeployment: "" +azureOpenaiEmbeddingsDeployment: "" + +# Groq API Key +groqApiKey: "" + +# Mistral API Key +mistralApiKey: "" + +# Cohere API Key +cohereApiKey: "" + +# AWS Bedrock configuration +awsBearerTokenBedrock: "" +awsAccessKeyId: "" +awsSecretAccessKey: "" +awsSessionToken: "" +awsRegion: "" + +# Google Vertex AI configuration +googleApplicationCredentials: "" +vertexProject: "" +vertexLocation: "" + waitForK8sResources: enabled: True image: groundnuty/k8s-wait-for:v2.0 @@ -152,7 +193,6 @@ chatbot: postgresDbDriver: postgres mongoDbDriver: mongodb secretKey: crapi - defaultModel: gpt-4o-mini maxContentLength: 50000 chromaPersistDirectory: /app/vectorstore apiUser: admin@example.com diff --git a/deploy/k8s/base/chatbot/config.yaml b/deploy/k8s/base/chatbot/config.yaml index 04357e68..d0a59cdc 100644 --- a/deploy/k8s/base/chatbot/config.yaml +++ b/deploy/k8s/base/chatbot/config.yaml @@ -4,3 +4,47 @@ metadata: name: crapi-chatbot-configmap labels: app: crapi-chatbot +data: + SERVER_PORT: "5002" + IDENTITY_SERVICE: "crapi-identity:8080" + WEB_SERVICE: "crapi-web" + TLS_ENABLED: "false" + DB_HOST: "postgresdb" + DB_USER: "admin" + DB_PASSWORD: "crapisecretpassword" + DB_NAME: "crapi" + DB_PORT: "5432" + MONGO_DB_HOST: "mongodb" + MONGO_DB_PORT: "27017" + MONGO_DB_USER: "admin" + MONGO_DB_PASSWORD: "crapisecretpassword" + MONGO_DB_NAME: "crapi" + CHATBOT_OPENAI_API_KEY: "" + CHATBOT_LLM_PROVIDER: "openai" + CHATBOT_LLM_MODEL: "" + CHATBOT_EMBEDDINGS_MODEL: "" + CHATBOT_EMBEDDINGS_DIMENSIONS: "1536" + ANTHROPIC_API_KEY: "" + AZURE_OPENAI_API_KEY: "" + AZURE_AD_TOKEN: "" + AZURE_OPENAI_ENDPOINT: "" + AZURE_OPENAI_API_VERSION: "2024-02-15-preview" + AZURE_OPENAI_CHAT_DEPLOYMENT: "" + AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT: "" + GROQ_API_KEY: "" + MISTRAL_API_KEY: "" + COHERE_API_KEY: "" + AWS_BEARER_TOKEN_BEDROCK: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_SESSION_TOKEN: "" + AWS_REGION: "" + GOOGLE_APPLICATION_CREDENTIALS: "" + VERTEX_PROJECT: "" + VERTEX_LOCATION: "" + MAX_CONTENT_LENGTH: "50000" + CHROMA_HOST: "chromadb" + CHROMA_PORT: "8000" + API_USER: "admin@example.com" + API_PASSWORD: "Admin!123" + OPENAPI_SPEC: "/app/resources/crapi-openapi-spec.json" diff --git a/docs/setup.md b/docs/setup.md index d9fd2ec2..2296c477 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -117,6 +117,49 @@ You can use prebuilt images generated by our CI workflow by downloading the dock [http://localhost:8025](http://localhost:8025) You can change the smtp configuration if required however all emails with domain **example.com** will still go to mailhog. +### Chatbot LLM configuration + +The chatbot supports multiple LLM providers. Provider selection and models are set via environment variables. OpenAI and Anthropic keys are set via the API (with env fallback); all other providers use env credentials only. + +Core settings: +- `CHATBOT_LLM_PROVIDER`: `openai`, `anthropic`, `azure_openai`, `bedrock`, `vertex`, `groq`, `mistral`, `cohere` +- `CHATBOT_LLM_MODEL`: optional, provider model name. Defaults per provider: + - openai: `gpt-4o-mini` + - anthropic: `claude-sonnet-4-20250514` + - bedrock: `anthropic.claude-3-sonnet-20240229-v1:0` + - vertex: `gemini-1.5-flash` + - groq: `llama-3.3-70b-versatile` + - mistral: `mistral-large-latest` + - cohere: `command-r-plus` + - azure_openai: uses `AZURE_OPENAI_CHAT_DEPLOYMENT` +- `CHATBOT_EMBEDDINGS_MODEL`: optional embeddings model name (used across providers) +- `CHATBOT_EMBEDDINGS_DIMENSIONS`: optional, defaults to `1536` + +OpenAI (API, optional env fallback): +- POST `/genai/init` with `{"openai_api_key":"..."}` (per-session) +- `CHATBOT_OPENAI_API_KEY` (optional env fallback) + +Anthropic (API, optional env fallback): +- POST `/genai/init` with `{"anthropic_api_key":"..."}` (per-session) +- `ANTHROPIC_API_KEY` (optional env fallback) + +Azure OpenAI (env only): +- `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT` +- Auth: `AZURE_OPENAI_API_KEY` or `AZURE_AD_TOKEN` (managed identity) +- Optional: `AZURE_OPENAI_API_VERSION`, `AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT` + +AWS Bedrock (env only): +- `AWS_REGION` and either: + - `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` (optional), or + - `AWS_BEARER_TOKEN_BEDROCK` + +Google Vertex AI (env only): +- `VERTEX_PROJECT`, `VERTEX_LOCATION` +- `GOOGLE_APPLICATION_CREDENTIALS` (optional if using ADC in GCP environments) + +Groq / Mistral / Cohere (env only): +- `GROQ_API_KEY`, `MISTRAL_API_KEY`, `COHERE_API_KEY` + ### Build it yourself 1. Clone crAPI repository diff --git a/services/chatbot/.env b/services/chatbot/.env index dad7157a..ef3a786a 100644 --- a/services/chatbot/.env +++ b/services/chatbot/.env @@ -19,4 +19,6 @@ export DEFAULT_MODEL=gpt-4o-mini export MAX_CONTENT_LENGTH=50000 export CHROMA_HOST=localhost export CHROMA_PORT=8000 +export CHATBOT_LLM_PROVIDER=openai +export CHATBOT_LLM_MODEL=gpt-4o-mini export CHATBOT_OPENAI_API_KEY= \ No newline at end of file diff --git a/services/chatbot/requirements.txt b/services/chatbot/requirements.txt index 3c85f4c0..ab34e6e2 100644 --- a/services/chatbot/requirements.txt +++ b/services/chatbot/requirements.txt @@ -20,4 +20,10 @@ quart-cors==0.8.0 motor==3.7.1 faiss-cpu==1.12.0 psycopg2-binary==2.9.11 -uvicorn==0.38.0 \ No newline at end of file +uvicorn==0.38.0 +langchain-aws +langchain-google-vertexai +langchain-anthropic +langchain-groq +langchain-mistralai +langchain-cohere \ No newline at end of file diff --git a/services/chatbot/src/chatbot/chat_api.py b/services/chatbot/src/chatbot/chat_api.py index 056e5b6f..b2b5ec00 100644 --- a/services/chatbot/src/chatbot/chat_api.py +++ b/services/chatbot/src/chatbot/chat_api.py @@ -1,4 +1,5 @@ import logging +import os from uuid import uuid4 from quart import Blueprint, jsonify, request @@ -14,33 +15,98 @@ logger = logging.getLogger(__name__) +def _validate_provider_env(provider: str) -> str | None: + if provider == "openai": + return None + if provider == "anthropic": + return None + if provider == "azure_openai": + if not Config.AZURE_OPENAI_API_KEY and not Config.AZURE_AD_TOKEN: + return "Missing AZURE_OPENAI_API_KEY or AZURE_AD_TOKEN" + if not Config.AZURE_OPENAI_ENDPOINT: + return "Missing AZURE_OPENAI_ENDPOINT" + if not Config.AZURE_OPENAI_CHAT_DEPLOYMENT: + return "Missing AZURE_OPENAI_CHAT_DEPLOYMENT" + return None + if provider == "groq": + if not Config.GROQ_API_KEY: + return "Missing GROQ_API_KEY" + return None + if provider == "mistral": + if not Config.MISTRAL_API_KEY: + return "Missing MISTRAL_API_KEY" + return None + if provider == "cohere": + if not Config.COHERE_API_KEY: + return "Missing COHERE_API_KEY" + return None + if provider == "bedrock": + if not os.environ.get("AWS_REGION"): + return "Missing AWS_REGION" + if not Config.AWS_BEARER_TOKEN_BEDROCK: + if not os.environ.get("AWS_ACCESS_KEY_ID") or not os.environ.get( + "AWS_SECRET_ACCESS_KEY" + ): + return "Missing AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY or AWS_BEARER_TOKEN_BEDROCK" + return None + if provider == "vertex": + # GOOGLE_APPLICATION_CREDENTIALS is optional if running in GCP with ADC + if not Config.VERTEX_PROJECT: + return "Missing VERTEX_PROJECT" + if not Config.VERTEX_LOCATION: + return "Missing VERTEX_LOCATION" + return None + return f"Unsupported provider {provider}" + + @chat_bp.route("/init", methods=["POST"]) async def init(): session_id = await get_or_create_session_id() data = await request.get_json() logger.debug("Initializing bot for session %s", session_id) - api_key = await get_api_key(session_id) - if api_key: - logger.info("Model already initialized with OpenAI API Key from environment") - return jsonify({"message": "Model Already Initialized"}), 200 - elif not data: - logger.error("Invalid request") - return jsonify({"message": "Invalid request"}), 400 - elif "openai_api_key" not in data: - logger.error("openai_api_key not provided") - return jsonify({"message": "openai_api_key not provided"}), 400 - openai_api_key: str = data["openai_api_key"] - logger.debug("OpenAI API Key %s", openai_api_key[:5]) - # Save the api key in session - await store_api_key(session_id, openai_api_key) - return jsonify({"message": "Initialized"}), 200 + provider = Config.LLM_PROVIDER + if provider == "openai": + api_key = await get_api_key(session_id) + if api_key: + logger.info("Model already initialized with OpenAI API key") + return jsonify({"message": "Model Already Initialized"}), 200 + if not data: + logger.error("Invalid request") + return jsonify({"message": "Invalid request"}), 400 + if "openai_api_key" not in data: + logger.error("openai_api_key not provided") + return jsonify({"message": "openai_api_key not provided"}), 400 + openai_api_key: str = data["openai_api_key"] + logger.debug("OpenAI API Key %s", openai_api_key[:5]) + await store_api_key(session_id, openai_api_key, provider) + return jsonify({"message": "Initialized"}), 200 + if provider == "anthropic": + api_key = await get_api_key(session_id) + if api_key: + logger.info("Model already initialized with Anthropic API key") + return jsonify({"message": "Model Already Initialized"}), 200 + if not data: + logger.error("Invalid request") + return jsonify({"message": "Invalid request"}), 400 + if "anthropic_api_key" not in data: + logger.error("anthropic_api_key not provided") + return jsonify({"message": "anthropic_api_key not provided"}), 400 + anthropic_api_key: str = data["anthropic_api_key"] + logger.debug("Anthropic API Key %s", anthropic_api_key[:5]) + await store_api_key(session_id, anthropic_api_key, provider) + return jsonify({"message": "Initialized"}), 200 + error = _validate_provider_env(provider) + if error: + logger.error("Provider %s misconfigured: %s", provider, error) + return jsonify({"message": error}), 400 + return jsonify({"message": f"Initialized ({provider})"}), 200 @chat_bp.route("/model", methods=["POST"]) async def model(): session_id = await get_or_create_session_id() data = await request.get_json() - model_name = Config.DEFAULT_MODEL_NAME + model_name = Config.LLM_MODEL_NAME if data and "model_name" in data and data["model_name"]: model_name = data["model_name"] logger.debug("Setting model %s for session %s", model_name, session_id) @@ -51,18 +117,27 @@ async def model(): @chat_bp.route("/ask", methods=["POST"]) async def chat(): session_id = await get_or_create_session_id() - openai_api_key = await get_api_key(session_id) + provider = Config.LLM_PROVIDER + error = _validate_provider_env(provider) + if error: + return jsonify({"message": error}), 400 + provider_api_key = await get_api_key(session_id) model_name = await get_model_name(session_id) user_jwt = await get_user_jwt() - if not openai_api_key: - return jsonify({"message": "Missing OpenAI API key. Please authenticate."}), 400 + if provider in {"openai", "anthropic"} and not provider_api_key: + message = ( + "Missing OpenAI API key. Please authenticate." + if provider == "openai" + else "Missing Anthropic API key. Please authenticate." + ) + return jsonify({"message": message}), 400 data = await request.get_json() message = data.get("message", "").strip() id = data.get("id", uuid4().int & (1 << 63) - 1) if not message: return jsonify({"message": "Message is required", "id": id}), 400 reply, response_id = await process_user_message( - session_id, message, openai_api_key, model_name, user_jwt + session_id, message, provider_api_key, model_name, user_jwt ) return jsonify({"id": response_id, "message": reply}), 200 @@ -71,10 +146,11 @@ async def chat(): async def state(): session_id = await get_or_create_session_id() logger.debug("Checking state for session %s", session_id) - openai_api_key = await get_api_key(session_id) - if openai_api_key: + provider = Config.LLM_PROVIDER + provider_api_key = await get_api_key(session_id) + if provider in {"openai", "anthropic"} and provider_api_key: logger.debug( - "OpenAI API Key for session %s: %s", session_id, openai_api_key[:5] + "Provider API key for session %s: %s", session_id, provider_api_key[:5] ) chat_history = await get_chat_history(session_id) # Limit chat history to last 20 messages @@ -89,26 +165,35 @@ async def state(): ), 200, ) - return ( - jsonify({"initialized": "false", "message": "Model needs to be initialized"}), - 200, - ) + if provider in {"openai", "anthropic"}: + return ( + jsonify( + {"initialized": "false", "message": "Model needs to be initialized"} + ), + 200, + ) + return jsonify({"initialized": "true", "message": "Model initialized"}), 200 @chat_bp.route("/history", methods=["GET"]) async def history(): session_id = await get_or_create_session_id() logger.debug("Checking state for session %s", session_id) - openai_api_key = await get_api_key(session_id) - if openai_api_key: + provider = Config.LLM_PROVIDER + provider_api_key = await get_api_key(session_id) + if provider in {"openai", "anthropic"} and provider_api_key: chat_history = await get_chat_history(session_id) # Limit chat history to last 20 messages chat_history = chat_history[-20:] return jsonify({"chat_history": chat_history}), 200 - return ( - jsonify({"chat_history": []}), - 200, - ) + if provider in {"openai", "anthropic"}: + return ( + jsonify({"chat_history": []}), + 200, + ) + chat_history = await get_chat_history(session_id) + chat_history = chat_history[-20:] if chat_history else [] + return jsonify({"chat_history": chat_history}), 200 @chat_bp.route("/reset", methods=["POST"]) diff --git a/services/chatbot/src/chatbot/chat_service.py b/services/chatbot/src/chatbot/chat_service.py index d6d365f5..807d3536 100644 --- a/services/chatbot/src/chatbot/chat_service.py +++ b/services/chatbot/src/chatbot/chat_service.py @@ -2,6 +2,7 @@ from langgraph.graph.message import Messages +from .config import Config from .extensions import db from .langgraph_agent import execute_langgraph_agent from .retriever_utils import add_to_chroma_collection @@ -41,7 +42,11 @@ async def process_user_message(session_id, user_message, api_key, model_name, us {"id": response_message_id, "role": "assistant", "content": reply.content} ) add_to_chroma_collection( - api_key, session_id, [{"user": user_message}, {"assistant": reply.content}] + api_key, + session_id, + [{"user": user_message}, {"assistant": reply.content}], + Config.LLM_PROVIDER, + model_name, ) # Limit chat history to last 20 messages history = history[-20:] diff --git a/services/chatbot/src/chatbot/config.py b/services/chatbot/src/chatbot/config.py index 79feb359..061792bd 100644 --- a/services/chatbot/src/chatbot/config.py +++ b/services/chatbot/src/chatbot/config.py @@ -10,7 +10,24 @@ class Config: SECRET_KEY = os.getenv("SECRET_KEY", "super-secret") MONGO_URI = MONGO_CONNECTION_URI - DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL", "gpt-4o-mini") + LLM_PROVIDER = os.getenv("CHATBOT_LLM_PROVIDER", "openai").lower() + LLM_MODEL_NAME = os.getenv("CHATBOT_LLM_MODEL", "") + EMBEDDINGS_MODEL = os.getenv("CHATBOT_EMBEDDINGS_MODEL", "") + EMBEDDINGS_DIMENSIONS = int(os.getenv("CHATBOT_EMBEDDINGS_DIMENSIONS", "1536")) + OPENAI_API_KEY = os.getenv("CHATBOT_OPENAI_API_KEY") + ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") + GROQ_API_KEY = os.getenv("GROQ_API_KEY") + MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") + COHERE_API_KEY = os.getenv("COHERE_API_KEY") + AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") + AZURE_AD_TOKEN = os.getenv("AZURE_AD_TOKEN") + AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") + AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview") + AZURE_OPENAI_CHAT_DEPLOYMENT = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT") + AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT") + AWS_BEARER_TOKEN_BEDROCK = os.getenv("AWS_BEARER_TOKEN_BEDROCK") + VERTEX_PROJECT = os.getenv("VERTEX_PROJECT", "") + VERTEX_LOCATION = os.getenv("VERTEX_LOCATION", "") MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 50000)) CHROMA_HOST = CHROMA_HOST CHROMA_PORT = CHROMA_PORT diff --git a/services/chatbot/src/chatbot/langgraph_agent.py b/services/chatbot/src/chatbot/langgraph_agent.py index 4c25e6a5..3e1e4ccf 100644 --- a/services/chatbot/src/chatbot/langgraph_agent.py +++ b/services/chatbot/src/chatbot/langgraph_agent.py @@ -2,13 +2,70 @@ from langchain.agents import create_agent from langchain_community.agent_toolkits import SQLDatabaseToolkit -from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_aws import ChatBedrock +from langchain_cohere import ChatCohere +from langchain_google_vertexai import ChatVertexAI +from langchain_groq import ChatGroq +from langchain_mistralai import ChatMistralAI +from langchain_openai import AzureChatOpenAI, ChatOpenAI from .agent_utils import truncate_tool_messages +from .config import Config from .extensions import postgresdb from .mcp_client import get_mcp_client from .retriever_utils import get_retriever_tool +DEFAULT_MODELS = { + "openai": "gpt-4o-mini", + "anthropic": "claude-sonnet-4-20250514", + "azure_openai": "", # uses deployment name + "bedrock": "anthropic.claude-3-sonnet-20240229-v1:0", + "vertex": "gemini-1.5-flash", + "groq": "llama-3.3-70b-versatile", + "mistral": "mistral-large-latest", + "cohere": "command-r-plus", +} + + +def _get_default_model(provider: str) -> str: + return DEFAULT_MODELS.get(provider, "gpt-4o-mini") + + +def _build_llm(api_key, model_name): + provider = Config.LLM_PROVIDER + model_name = model_name or _get_default_model(provider) + if provider == "openai": + return ChatOpenAI(api_key=api_key, model=model_name) + if provider == "azure_openai": + kwargs = { + "azure_endpoint": Config.AZURE_OPENAI_ENDPOINT, + "api_version": Config.AZURE_OPENAI_API_VERSION, + "azure_deployment": Config.AZURE_OPENAI_CHAT_DEPLOYMENT or model_name, + } + if Config.AZURE_AD_TOKEN: + kwargs["azure_ad_token"] = Config.AZURE_AD_TOKEN + else: + kwargs["api_key"] = Config.AZURE_OPENAI_API_KEY + return AzureChatOpenAI(**kwargs) + if provider == "bedrock": + return ChatBedrock(model_id=model_name) + if provider == "vertex": + return ChatVertexAI( + model_name=model_name, + project=Config.VERTEX_PROJECT or None, + location=Config.VERTEX_LOCATION or None, + ) + if provider == "anthropic": + return ChatAnthropic(api_key=api_key, model=model_name) + if provider == "groq": + return ChatGroq(api_key=Config.GROQ_API_KEY, model=model_name) + if provider == "mistral": + return ChatMistralAI(api_key=Config.MISTRAL_API_KEY, model=model_name) + if provider == "cohere": + return ChatCohere(api_key=Config.COHERE_API_KEY, model=model_name) + raise ValueError(f"Unsupported provider {provider}") + async def build_langgraph_agent(api_key, model_name, user_jwt): system_prompt = textwrap.dedent( @@ -48,13 +105,13 @@ async def build_langgraph_agent(api_key, model_name, user_jwt): Use the tools only if you don't know the answer. """ ) - llm = ChatOpenAI(api_key=api_key, model=model_name) + llm = _build_llm(api_key, model_name) toolkit = SQLDatabaseToolkit(db=postgresdb, llm=llm) mcp_client = get_mcp_client(user_jwt) mcp_tools = await mcp_client.get_tools() db_tools = toolkit.get_tools() tools = mcp_tools + db_tools - retriever_tool = get_retriever_tool(api_key) + retriever_tool = get_retriever_tool(api_key, Config.LLM_PROVIDER, model_name) tools.append(retriever_tool) agent_node = create_agent( model=llm, diff --git a/services/chatbot/src/chatbot/retriever_utils.py b/services/chatbot/src/chatbot/retriever_utils.py index c94657b0..865401ad 100644 --- a/services/chatbot/src/chatbot/retriever_utils.py +++ b/services/chatbot/src/chatbot/retriever_utils.py @@ -1,11 +1,31 @@ +import logging + import chromadb +from langchain_aws import BedrockEmbeddings from langchain_chroma import Chroma as ChromaClient +from langchain_cohere import CohereEmbeddings from langchain_core.documents import Document -from langchain_core.tools import create_retriever_tool -from langchain_openai import OpenAIEmbeddings +from langchain_core.embeddings import Embeddings +from langchain_core.tools import tool +from langchain_google_vertexai import VertexAIEmbeddings +from langchain_mistralai import MistralAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from .config import Config +logger = logging.getLogger(__name__) + + +class ZeroEmbeddings(Embeddings): + def __init__(self, size: int) -> None: + self.size = size + + def embed_documents(self, texts): + return [[0.0] * self.size for _ in texts] + + def embed_query(self, text): + return [0.0] * self.size + def get_chroma_client(): chroma_client = chromadb.HttpClient( @@ -17,28 +37,121 @@ def get_chroma_client(): return chroma_client -def get_embedding_function(api_key): - return OpenAIEmbeddings( - openai_api_key=api_key, - model="text-embedding-3-large", - ) +def _resolve_embeddings_provider(provider: str) -> str: + if provider in {"openai", "azure_openai", "bedrock", "vertex", "mistral", "cohere"}: + return provider + return "none" + + +def _zero_embeddings() -> ZeroEmbeddings: + return ZeroEmbeddings(Config.EMBEDDINGS_DIMENSIONS) -def get_chroma_vectorstore(api_key): +def _default_embeddings_model(provider: str, llm_model: str | None) -> str: + if provider == "openai": + return "text-embedding-3-large" + if provider == "bedrock": + return "amazon.titan-embed-text-v2:0" + if provider == "vertex": + return "text-embedding-004" + if provider == "cohere": + return "embed-english-v3.0" + if provider == "mistral": + return "mistral-embed" + if provider == "azure_openai": + return llm_model or "" + return llm_model or "" + + +def get_embedding_function(api_key, provider: str, llm_model: str | None): + embeddings_provider = _resolve_embeddings_provider(provider) + if embeddings_provider == "openai": + if not api_key: + logger.warning("OpenAI embeddings requested without API key.") + return _zero_embeddings() + return OpenAIEmbeddings( + openai_api_key=api_key, + model=Config.EMBEDDINGS_MODEL + or _default_embeddings_model(embeddings_provider, llm_model), + ) + if embeddings_provider == "azure_openai": + if (not Config.AZURE_OPENAI_API_KEY and not Config.AZURE_AD_TOKEN) or not Config.AZURE_OPENAI_ENDPOINT: + logger.warning("Azure OpenAI embeddings misconfigured.") + return _zero_embeddings() + default_deployment = _default_embeddings_model(embeddings_provider, llm_model) + kwargs = { + "azure_endpoint": Config.AZURE_OPENAI_ENDPOINT, + "api_version": Config.AZURE_OPENAI_API_VERSION, + "azure_deployment": Config.AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT + or Config.EMBEDDINGS_MODEL + or Config.AZURE_OPENAI_CHAT_DEPLOYMENT + or default_deployment, + } + if Config.AZURE_AD_TOKEN: + kwargs["azure_ad_token"] = Config.AZURE_AD_TOKEN + else: + kwargs["api_key"] = Config.AZURE_OPENAI_API_KEY + return AzureOpenAIEmbeddings(**kwargs) + if embeddings_provider == "bedrock": + model_id = ( + Config.EMBEDDINGS_MODEL + or _default_embeddings_model(embeddings_provider, llm_model) + ) + if not model_id: + logger.warning("Bedrock embedding model not configured.") + return _zero_embeddings() + return BedrockEmbeddings(model_id=model_id) + if embeddings_provider == "vertex": + vertex_model = ( + Config.EMBEDDINGS_MODEL + or _default_embeddings_model(embeddings_provider, llm_model) + ) + return VertexAIEmbeddings( + model_name=vertex_model, + project=Config.VERTEX_PROJECT or None, + location=Config.VERTEX_LOCATION or None, + ) + if embeddings_provider == "cohere": + if not Config.COHERE_API_KEY: + logger.warning("Cohere embeddings requested without API key.") + return _zero_embeddings() + return CohereEmbeddings( + cohere_api_key=Config.COHERE_API_KEY, + model=Config.EMBEDDINGS_MODEL + or _default_embeddings_model(embeddings_provider, llm_model), + ) + if embeddings_provider == "mistral": + if not Config.MISTRAL_API_KEY: + logger.warning("Mistral embeddings requested without API key.") + return _zero_embeddings() + return MistralAIEmbeddings( + mistral_api_key=Config.MISTRAL_API_KEY, + model=Config.EMBEDDINGS_MODEL + or _default_embeddings_model(embeddings_provider, llm_model), + ) + logger.warning("Embeddings disabled for provider %s.", provider) + return _zero_embeddings() + + +def get_chroma_vectorstore(api_key, provider: str, llm_model: str | None): chroma_client = get_chroma_client() vectorstore = ChromaClient( client=chroma_client, collection_name="chats", - embedding_function=get_embedding_function(api_key), + embedding_function=get_embedding_function(api_key, provider, llm_model), create_collection_if_not_exists=True, ) return vectorstore def add_to_chroma_collection( - api_key, session_id, new_messages: list[dict[str, str]] + api_key, + session_id, + new_messages: list[dict[str, str]], + provider: str, + llm_model: str | None, ) -> list: - vectorstore = get_chroma_vectorstore(api_key) + vectorstore = get_chroma_vectorstore(api_key, provider, llm_model) print("new_messages", new_messages) # new_messages = [{'user': 'hi'}, {'assistant': 'Hello! How can I assist you today?'}] documents = [] @@ -55,19 +168,23 @@ def add_to_chroma_collection( return res -def get_retriever_tool(api_key): - vectorstore = get_chroma_vectorstore(api_key) +def get_retriever_tool(api_key, provider: str, llm_model: str | None): + vectorstore = get_chroma_vectorstore(api_key, provider, llm_model) retriever = vectorstore.as_retriever() - retriever_tool = create_retriever_tool( - retriever, - name="chat_rag", - description=""" - Use this to answer questions based on user chat history (summarized and semantically indexed). + + @tool + def chat_rag(query: str) -> str: + """Use this to answer questions based on user chat history (summarized and semantically indexed). Use this when the user asks about prior chats, what they asked earlier, or wants a summary of past conversations. Use this tool when the user refers to anything mentioned before, asks for a summary of previous messages or sessions, or references phrases like 'what I said earlier', 'things we discussed', 'my earlier question', 'until now', 'till date', 'all my conversations' or 'previously mentioned'. The chat history is semantically indexed and summarized using vector search. - """, - ) - return retriever_tool + + Args: + query: The search query to find relevant chat history. + """ + docs = retriever.invoke(query) + return "\n\n".join(doc.page_content for doc in docs) + + return chat_rag diff --git a/services/chatbot/src/chatbot/session_service.py b/services/chatbot/src/chatbot/session_service.py index 4d84caf3..c611af94 100644 --- a/services/chatbot/src/chatbot/session_service.py +++ b/services/chatbot/src/chatbot/session_service.py @@ -1,4 +1,3 @@ -import os import uuid from quart import after_this_request, request @@ -24,26 +23,46 @@ def after_index(response): return session_id -async def store_api_key(session_id, api_key): +def _get_api_key_field(provider: str) -> str | None: + if provider == "openai": + return "openai_api_key" + if provider == "anthropic": + return "anthropic_api_key" + return None + + +async def store_api_key(session_id, api_key, provider: str): + key_field = _get_api_key_field(provider) + if not key_field: + return await db.sessions.update_one( - {"session_id": session_id}, {"$set": {"openai_api_key": api_key}}, upsert=True + {"session_id": session_id}, {"$set": {key_field: api_key}}, upsert=True ) async def get_api_key(session_id): - if os.environ.get("CHATBOT_OPENAI_API_KEY"): - return os.environ.get("CHATBOT_OPENAI_API_KEY") + provider = Config.LLM_PROVIDER + key_field = _get_api_key_field(provider) + if provider == "openai" and Config.OPENAI_API_KEY: + return Config.OPENAI_API_KEY + if provider == "anthropic" and Config.ANTHROPIC_API_KEY: + return Config.ANTHROPIC_API_KEY + if not key_field: + return None doc = await db.sessions.find_one({"session_id": session_id}) if not doc: return None - if "openai_api_key" not in doc: + if key_field not in doc: return None - return doc["openai_api_key"] + return doc[key_field] async def delete_api_key(session_id): + updates = {} + for key_field in ("openai_api_key", "anthropic_api_key"): + updates[key_field] = "" await db.sessions.update_one( - {"session_id": session_id}, {"$unset": {"openai_api_key": ""}} + {"session_id": session_id}, {"$unset": updates} ) @@ -56,9 +75,9 @@ async def store_model_name(session_id, model_name): async def get_model_name(session_id): doc = await db.sessions.find_one({"session_id": session_id}) if not doc: - return Config.DEFAULT_MODEL_NAME + return Config.LLM_MODEL_NAME if "model_name" not in doc: - return Config.DEFAULT_MODEL_NAME + return Config.LLM_MODEL_NAME return doc["model_name"] diff --git a/services/chatbot/src/mcpserver/tool_helpers.py b/services/chatbot/src/mcpserver/tool_helpers.py index 6ef1610f..d6d8caa0 100644 --- a/services/chatbot/src/mcpserver/tool_helpers.py +++ b/services/chatbot/src/mcpserver/tool_helpers.py @@ -1,15 +1,23 @@ import os + +from chatbot.config import Config from chatbot.extensions import db async def get_any_api_key(): - if os.environ.get("CHATBOT_OPENAI_API_KEY"): - return os.environ.get("CHATBOT_OPENAI_API_KEY") + provider = Config.LLM_PROVIDER + if provider == "openai" and Config.OPENAI_API_KEY: + return Config.OPENAI_API_KEY + if provider == "anthropic" and Config.ANTHROPIC_API_KEY: + return Config.ANTHROPIC_API_KEY + key_field = "openai_api_key" if provider == "openai" else "anthropic_api_key" + if provider not in {"openai", "anthropic"}: + return None doc = await db.sessions.find_one( - {"openai_api_key": {"$exists": True, "$ne": None}}, {"openai_api_key": 1} + {key_field: {"$exists": True, "$ne": None}}, {key_field: 1} ) - if doc and "openai_api_key" in doc: - return doc["openai_api_key"] + if doc and key_field in doc: + return doc[key_field] return None def fix_array_responses_in_spec(spec): From b40f765255d1772fabaa7e15db696ce01ca70330 Mon Sep 17 00:00:00 2001 From: Roshan Piyush Date: Mon, 19 Jan 2026 18:50:44 +0530 Subject: [PATCH 2/3] Add logging --- services/chatbot/src/chatbot/langgraph_agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/services/chatbot/src/chatbot/langgraph_agent.py b/services/chatbot/src/chatbot/langgraph_agent.py index 3e1e4ccf..e5563c17 100644 --- a/services/chatbot/src/chatbot/langgraph_agent.py +++ b/services/chatbot/src/chatbot/langgraph_agent.py @@ -1,3 +1,4 @@ +import logging import textwrap from langchain.agents import create_agent @@ -16,6 +17,8 @@ from .mcp_client import get_mcp_client from .retriever_utils import get_retriever_tool +logger = logging.getLogger(__name__) + DEFAULT_MODELS = { "openai": "gpt-4o-mini", "anthropic": "claude-sonnet-4-20250514", @@ -35,6 +38,7 @@ def _get_default_model(provider: str) -> str: def _build_llm(api_key, model_name): provider = Config.LLM_PROVIDER model_name = model_name or _get_default_model(provider) + logger.info("Using LLM provider: %s, model: %s", provider, model_name) if provider == "openai": return ChatOpenAI(api_key=api_key, model=model_name) if provider == "azure_openai": From eab5398c84c029b8cd7f6e6c17ef912e7c70196c Mon Sep 17 00:00:00 2001 From: Roshan Piyush Date: Mon, 19 Jan 2026 19:51:27 +0530 Subject: [PATCH 3/3] Make base endpoint configurable --- deploy/docker/docker-compose.yml | 1 + deploy/helm/templates/chatbot/config.yaml | 1 + deploy/helm/values.yaml | 3 +++ deploy/k8s/base/chatbot/config.yaml | 1 + services/chatbot/src/chatbot/config.py | 1 + services/chatbot/src/chatbot/langgraph_agent.py | 6 +++++- services/chatbot/src/chatbot/retriever_utils.py | 11 +++++++---- 7 files changed, 19 insertions(+), 5 deletions(-) diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml index abd759fc..a18ae61a 100755 --- a/deploy/docker/docker-compose.yml +++ b/deploy/docker/docker-compose.yml @@ -177,6 +177,7 @@ services: - CHATBOT_EMBEDDINGS_MODEL=${CHATBOT_EMBEDDINGS_MODEL:-} - CHATBOT_EMBEDDINGS_DIMENSIONS=${CHATBOT_EMBEDDINGS_DIMENSIONS:-1536} - CHATBOT_OPENAI_API_KEY=${CHATBOT_OPENAI_API_KEY:-} + - CHATBOT_OPENAI_BASE_URL=${CHATBOT_OPENAI_BASE_URL:-} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-} - AZURE_OPENAI_API_KEY=${AZURE_OPENAI_API_KEY:-} - AZURE_AD_TOKEN=${AZURE_AD_TOKEN:-} diff --git a/deploy/helm/templates/chatbot/config.yaml b/deploy/helm/templates/chatbot/config.yaml index c72a7402..2ab2db35 100644 --- a/deploy/helm/templates/chatbot/config.yaml +++ b/deploy/helm/templates/chatbot/config.yaml @@ -21,6 +21,7 @@ data: MONGO_DB_PASSWORD: {{ .Values.mongodb.config.mongoPassword }} MONGO_DB_NAME: {{ .Values.mongodb.config.mongoDbName }} CHATBOT_OPENAI_API_KEY: {{ .Values.chatbotOpenaiApiKey | default .Values.openAIApiKey | quote }} + CHATBOT_OPENAI_BASE_URL: {{ .Values.chatbotOpenaiBaseUrl | quote }} CHATBOT_LLM_PROVIDER: {{ .Values.chatbotLlmProvider | quote }} CHATBOT_LLM_MODEL: {{ .Values.chatbotLlmModel | quote }} CHATBOT_EMBEDDINGS_MODEL: {{ .Values.chatbotEmbeddingsModel | quote }} diff --git a/deploy/helm/values.yaml b/deploy/helm/values.yaml index 9fb01ae1..0067a91f 100644 --- a/deploy/helm/values.yaml +++ b/deploy/helm/values.yaml @@ -23,6 +23,9 @@ chatbotEmbeddingsDimensions: 1536 chatbotOpenaiApiKey: "" openAIApiKey: "" # for backward compatibility +# OpenAI Base URL (for DeepSeek or other OpenAI-compatible APIs) +chatbotOpenaiBaseUrl: "" + # Anthropic API Key anthropicApiKey: "" diff --git a/deploy/k8s/base/chatbot/config.yaml b/deploy/k8s/base/chatbot/config.yaml index d0a59cdc..4465bb28 100644 --- a/deploy/k8s/base/chatbot/config.yaml +++ b/deploy/k8s/base/chatbot/config.yaml @@ -20,6 +20,7 @@ data: MONGO_DB_PASSWORD: "crapisecretpassword" MONGO_DB_NAME: "crapi" CHATBOT_OPENAI_API_KEY: "" + CHATBOT_OPENAI_BASE_URL: "" CHATBOT_LLM_PROVIDER: "openai" CHATBOT_LLM_MODEL: "" CHATBOT_EMBEDDINGS_MODEL: "" diff --git a/services/chatbot/src/chatbot/config.py b/services/chatbot/src/chatbot/config.py index 061792bd..b1180883 100644 --- a/services/chatbot/src/chatbot/config.py +++ b/services/chatbot/src/chatbot/config.py @@ -15,6 +15,7 @@ class Config: EMBEDDINGS_MODEL = os.getenv("CHATBOT_EMBEDDINGS_MODEL", "") EMBEDDINGS_DIMENSIONS = int(os.getenv("CHATBOT_EMBEDDINGS_DIMENSIONS", "1536")) OPENAI_API_KEY = os.getenv("CHATBOT_OPENAI_API_KEY") + OPENAI_BASE_URL = os.getenv("CHATBOT_OPENAI_BASE_URL") ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") GROQ_API_KEY = os.getenv("GROQ_API_KEY") MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") diff --git a/services/chatbot/src/chatbot/langgraph_agent.py b/services/chatbot/src/chatbot/langgraph_agent.py index e5563c17..c97198d7 100644 --- a/services/chatbot/src/chatbot/langgraph_agent.py +++ b/services/chatbot/src/chatbot/langgraph_agent.py @@ -40,7 +40,11 @@ def _build_llm(api_key, model_name): model_name = model_name or _get_default_model(provider) logger.info("Using LLM provider: %s, model: %s", provider, model_name) if provider == "openai": - return ChatOpenAI(api_key=api_key, model=model_name) + kwargs = {"api_key": api_key, "model": model_name} + if Config.OPENAI_BASE_URL: + kwargs["base_url"] = Config.OPENAI_BASE_URL + logger.info("Using custom OpenAI base URL: %s", Config.OPENAI_BASE_URL) + return ChatOpenAI(**kwargs) if provider == "azure_openai": kwargs = { "azure_endpoint": Config.AZURE_OPENAI_ENDPOINT, diff --git a/services/chatbot/src/chatbot/retriever_utils.py b/services/chatbot/src/chatbot/retriever_utils.py index 865401ad..b580c7a6 100644 --- a/services/chatbot/src/chatbot/retriever_utils.py +++ b/services/chatbot/src/chatbot/retriever_utils.py @@ -69,11 +69,14 @@ def get_embedding_function(api_key, provider: str, llm_model: str | None): if not api_key: logger.warning("OpenAI embeddings requested without API key.") return _zero_embeddings() - return OpenAIEmbeddings( - openai_api_key=api_key, - model=Config.EMBEDDINGS_MODEL + kwargs = { + "openai_api_key": api_key, + "model": Config.EMBEDDINGS_MODEL or _default_embeddings_model(embeddings_provider, llm_model), - ) + } + if Config.OPENAI_BASE_URL: + kwargs["base_url"] = Config.OPENAI_BASE_URL + return OpenAIEmbeddings(**kwargs) if embeddings_provider == "azure_openai": if (not Config.AZURE_OPENAI_API_KEY and not Config.AZURE_AD_TOKEN) or not Config.AZURE_OPENAI_ENDPOINT: logger.warning("Azure OpenAI embeddings misconfigured.")