236 changes: 230 additions & 6 deletions nova/core/ai_client.py
@@ -4,7 +4,9 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from datetime import datetime

from nova.core.metrics import ContextAnalysis, get_metrics_collector
from nova.models.config import AIModelConfig

logger = logging.getLogger(__name__)
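# Reviewer note: nova.core.metrics is not part of this diff. The call sites
# below assume roughly the interface sketched here; the field names and
# signatures are guesses for review purposes, not the real module.
#
#     from dataclasses import dataclass
#     from datetime import datetime
#     from typing import Any
#
#     @dataclass
#     class ContextAnalysis:
#         estimated_tokens: int      # assumed field
#         max_context_tokens: int    # assumed field
#         usage_ratio: float         # assumed field
#
#     class MetricsCollector:
#         def analyze_context(self, messages: list[dict[str, str]],
#                             max_context_tokens: int) -> ContextAnalysis: ...
#         def print_context_warning(self, analysis: ContextAnalysis) -> None: ...
#         def log_request(self, provider: str, model: str,
#                         messages: list[dict[str, str]], temperature: float,
#                         max_tokens: int, **kwargs: Any) -> None: ...
#         def create_metrics(self, conversation_id: str, provider: str, model: str,
#                            messages: list[dict[str, str]], response: str,
#                            request_start: datetime, response_complete: datetime,
#                            context_analysis: ContextAnalysis, input_tokens: int,
#                            output_tokens: int, temperature: float,
#                            max_tokens: int, **kwargs: Any) -> Any: ...
#         def log_response(self, provider: str, model: str, response: str,
#                          metrics: Any) -> None: ...
#         def _estimate_tokens(self, messages: list[dict[str, str]]) -> int: ...
#
#     def get_metrics_collector() -> MetricsCollector: ...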
@@ -39,6 +41,7 @@ class BaseAIClient(ABC):

def __init__(self, config: AIModelConfig):
self.config = config
self.metrics_collector = get_metrics_collector()

@abstractmethod
async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> str:
@@ -62,6 +65,29 @@ async def list_models(self) -> list[str]:
"""List available models for this provider"""
pass

def _analyze_context(self, messages: list[dict[str, str]]) -> ContextAnalysis:
"""Analyze context window usage for this provider"""
# Get max context tokens for this model (provider-specific)
max_context = self._get_max_context_tokens()

context_analysis = self.metrics_collector.analyze_context(
messages=messages, max_context_tokens=max_context
)

# Print warnings if context is getting full
self.metrics_collector.print_context_warning(context_analysis)

return context_analysis

@abstractmethod
def _get_max_context_tokens(self) -> int:
"""Get maximum context tokens for the current model"""
pass

def _get_provider_name(self) -> str:
"""Get the provider name for metrics"""
return self.__class__.__name__.replace("Client", "").lower()
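# e.g. OpenAIClient -> "openai", AnthropicClient -> "anthropic",
# OllamaClient -> "ollama"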


class OpenAIClient(BaseAIClient):
"""OpenAI API client"""
@@ -85,8 +111,25 @@ def validate_config(self) -> bool:
return False
return True

async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> str:
async def generate_response(
self, messages: list[dict[str, str]], conversation_id: str = "unknown", **kwargs
) -> str:
"""Generate response using OpenAI API"""
# Analyze context before making request
context_analysis = self._analyze_context(messages)

# Log request details
self.metrics_collector.log_request(
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

request_start = datetime.now()

try:
response = await self.client.chat.completions.create(
model=self.config.model_name,
@@ -95,7 +138,41 @@ async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> s
temperature=self.config.temperature,
**kwargs,
)
return response.choices[0].message.content

response_complete = datetime.now()
response_content = response.choices[0].message.content

# Extract token usage if available
usage = response.usage
input_tokens = usage.prompt_tokens if usage else 0
output_tokens = usage.completion_tokens if usage else 0

# Create metrics
metrics = self.metrics_collector.create_metrics(
conversation_id=conversation_id,
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
response=response_content,
request_start=request_start,
response_complete=response_complete,
context_analysis=context_analysis,
input_tokens=input_tokens,
output_tokens=output_tokens,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

# Log response details
self.metrics_collector.log_response(
provider=self._get_provider_name(),
model=self.config.model_name,
response=response_content,
metrics=metrics,
)

return response_content

except Exception as e:
self._handle_api_error(e)
@@ -129,6 +206,18 @@ async def list_models(self) -> list[str]:
except Exception as e:
self._handle_api_error(e)

def _get_max_context_tokens(self) -> int:
"""Get maximum context tokens for OpenAI models"""
model_limits = {
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-turbo": 128000,
"gpt-4o": 128000,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
}
return model_limits.get(self.config.model_name, 8192) # Default to GPT-4 limit
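# Note: names missing from the table fall back to 8192, so dated snapshots
# (e.g. "gpt-4o-2024-08-06") resolve to the 8k default unless an entry is added.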

def _handle_api_error(self, error: Exception) -> None:
"""Convert OpenAI errors to our standard errors"""
import openai
@@ -167,8 +256,25 @@ def validate_config(self) -> bool:
return False
return True

async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> str:
async def generate_response(
self, messages: list[dict[str, str]], conversation_id: str = "unknown", **kwargs
) -> str:
"""Generate response using Anthropic API"""
# Analyze context before making request
context_analysis = self._analyze_context(messages)

# Log request details
self.metrics_collector.log_request(
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

request_start = datetime.now()

try:
# Convert messages to Anthropic format
anthropic_messages = self._convert_messages(messages)
@@ -180,7 +286,41 @@ async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> s
temperature=self.config.temperature,
**kwargs,
)
return response.content[0].text

response_complete = datetime.now()
response_content = response.content[0].text

# Extract token usage if available
usage = response.usage
input_tokens = usage.input_tokens if usage else 0
output_tokens = usage.output_tokens if usage else 0

# Create metrics
metrics = self.metrics_collector.create_metrics(
conversation_id=conversation_id,
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
response=response_content,
request_start=request_start,
response_complete=response_complete,
context_analysis=context_analysis,
input_tokens=input_tokens,
output_tokens=output_tokens,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

# Log response details
self.metrics_collector.log_response(
provider=self._get_provider_name(),
model=self.config.model_name,
response=response_content,
metrics=metrics,
)

return response_content

except Exception as e:
self._handle_api_error(e)
@@ -231,6 +371,17 @@ def _convert_messages(self, messages: list[dict[str, str]]) -> list[dict[str, st
)
return converted

def _get_max_context_tokens(self) -> int:
"""Get maximum context tokens for Anthropic models"""
model_limits = {
"claude-3-5-sonnet-20241022": 200000,
"claude-3-5-haiku-20241022": 200000,
"claude-3-opus-20240229": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-haiku-20240307": 200000,
}
return model_limits.get(self.config.model_name, 200000) # Default to 200k

def _handle_api_error(self, error: Exception) -> None:
"""Convert Anthropic errors to our standard errors"""
import anthropic
@@ -265,8 +416,25 @@ def validate_config(self) -> bool:
# Ollama doesn't require API key, just check if server is reachable
return True

async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> str:
async def generate_response(
self, messages: list[dict[str, str]], conversation_id: str = "unknown", **kwargs
) -> str:
"""Generate response using Ollama API"""
# Analyze context before making request
context_analysis = self._analyze_context(messages)

# Log request details
self.metrics_collector.log_request(
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

request_start = datetime.now()

try:
response = await self.client.chat(
model=self.config.model_name,
@@ -277,7 +445,42 @@ async def generate_response(self, messages: list[dict[str, str]], **kwargs) -> s
},
**kwargs,
)
return response["message"]["content"]

response_complete = datetime.now()
response_content = response["message"]["content"]

# Ollama doesn't provide token usage, so we estimate
input_tokens = self.metrics_collector._estimate_tokens(messages)
output_tokens = self.metrics_collector._estimate_tokens(
[{"content": response_content}]
)

# Create metrics
metrics = self.metrics_collector.create_metrics(
conversation_id=conversation_id,
provider=self._get_provider_name(),
model=self.config.model_name,
messages=messages,
response=response_content,
request_start=request_start,
response_complete=response_complete,
context_analysis=context_analysis,
input_tokens=input_tokens,
output_tokens=output_tokens,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
**kwargs,
)

# Log response details
self.metrics_collector.log_response(
provider=self._get_provider_name(),
model=self.config.model_name,
response=response_content,
metrics=metrics,
)

return response_content

except Exception as e:
self._handle_api_error(e)
@@ -313,6 +516,27 @@ async def list_models(self) -> list[str]:
except Exception as e:
self._handle_api_error(e)

def _get_max_context_tokens(self) -> int:
"""Get maximum context tokens for Ollama models"""
# Common context limits for popular Ollama models
model_limits = {
"llama2": 4096,
"llama2:13b": 4096,
"llama2:70b": 4096,
"mistral": 8192,
"mistral:7b": 8192,
"codellama": 16384,
"codellama:13b": 16384,
"codellama:34b": 16384,
"vicuna": 2048,
"orca-mini": 2048,
}
# Extract base model name (remove size suffix)
base_model = self.config.model_name.split(":")[0]
return model_limits.get(
self.config.model_name, model_limits.get(base_model, 4096)
)
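# Illustration of the lookup order above (tags are examples only):
#   "mistral"       -> exact match          -> 8192
#   "codellama:13b" -> exact match          -> 16384
#   "codellama:7b"  -> base "codellama" hit -> 16384
#   "phi3"          -> no entry anywhere    -> 4096 default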

def _handle_api_error(self, error: Exception) -> None:
"""Convert Ollama errors to our standard errors"""
if "connection" in str(error).lower():
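Reviewer note: a minimal usage sketch of the new generate_response signature. The AIModelConfig field names are inferred from the self.config.* reads above; the provider and api_key fields, the direct construction, and the key value are placeholders, not taken from this diff.

import asyncio

from nova.core.ai_client import OpenAIClient
from nova.models.config import AIModelConfig


async def main() -> None:
    # Field names below are assumptions drawn from how the clients read the config.
    config = AIModelConfig(
        provider="openai",
        model_name="gpt-4o",
        temperature=0.2,
        max_tokens=1024,
        api_key="sk-...",  # placeholder
    )
    client = OpenAIClient(config)

    # conversation_id is the new optional argument; it tags the metrics the
    # collector records for this request and defaults to "unknown".
    reply = await client.generate_response(
        [{"role": "user", "content": "Summarize this change in one sentence."}],
        conversation_id="demo-conversation",
    )
    print(reply)


asyncio.run(main())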