diff --git a/src/ai_migrate/llm_providers/base.py b/src/ai_migrate/llm_providers/base.py
index 602e3fa..96fd623 100644
--- a/src/ai_migrate/llm_providers/base.py
+++ b/src/ai_migrate/llm_providers/base.py
@@ -13,6 +13,7 @@ async def generate_completion(
         temperature: float = 0.1,
         max_tokens: int = 8192,
         model: str | None = None,
+        response_format: dict[str, Any] | None = None,
     ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
         """Generate a completion from the LLM.

@@ -22,6 +23,8 @@ async def generate_completion(
             temperature: The temperature to use for generation
             max_tokens: The maximum number of tokens to generate
             model: Optional model override
+            response_format: Optional response format specification, e.g. for JSON Schema
+                Example: {"type": "json_schema", "schema": {...}}

         Returns:
             A tuple of (response, messages)
diff --git a/src/ai_migrate/llm_providers/openai.py b/src/ai_migrate/llm_providers/openai.py
index d2b00c5..66d8126 100644
--- a/src/ai_migrate/llm_providers/openai.py
+++ b/src/ai_migrate/llm_providers/openai.py
@@ -30,6 +30,8 @@ async def generate_completion(
         tools: list[ToolDefinition] | None = None,
         temperature: float = 0.1,
         max_tokens: int = 8192,
+        model: str | None = None,
+        response_format: dict[str, Any] | None = None,
     ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
         """Generate a completion

@@ -38,17 +40,24 @@ async def generate_completion(
             tools: Optional tools to provide
             temperature: The temperature to use for generation
             max_tokens: The maximum number of tokens to generate
+            model: Optional model override
+            response_format: Optional response format specification

         Returns:
             A tuple of (response, messages)
         """
-        response = await self._openai_client.chat.completions.create(
-            model="gpt-4o",
-            messages=messages,
-            temperature=temperature,
-            tools=[_tool(t) for t in tools or []],
-            max_tokens=max_tokens,
-        )
+        kwargs = {
+            "model": model or "gpt-4o",
+            "messages": messages,
+            "temperature": temperature,
+            "tools": [_tool(t) for t in tools or []],
+            "max_tokens": max_tokens,
+        }
+
+        if response_format:
+            kwargs["response_format"] = response_format
+
+        response = await self._openai_client.chat.completions.create(**kwargs)
         response = response.model_dump()
         return response, messages
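
For reference, a minimal caller-side sketch of the two new parameters. The `OpenAIProvider` class name, its no-argument constructor, and the `messages` keyword are assumptions inferred from context (the diff only shows the changed hunks). Note that `response_format` is forwarded verbatim to `chat.completions.create`, so its shape must match whatever the installed `openai` SDK expects; the dict below follows the docstring's `{"type": "json_schema", "schema": {...}}` example.

```python
import asyncio

from ai_migrate.llm_providers.openai import OpenAIProvider  # class name assumed

async def main() -> None:
    provider = OpenAIProvider()  # hypothetical construction; real wiring may differ

    # Follows the docstring's example shape; this dict is passed through
    # unmodified to chat.completions.create, so adjust it to match the SDK.
    response_format = {
        "type": "json_schema",
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    }

    response, messages = await provider.generate_completion(
        messages=[{"role": "user", "content": "Reply in JSON."}],
        model="gpt-4o-mini",              # exercises the new model override
        response_format=response_format,  # exercises the new structured-output hook
    )
    # response is a plain dict (the create() result after model_dump()).
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())
```

Omitting both keywords preserves the old behavior: the model falls back to `"gpt-4o"` and no `response_format` key is sent, since it is only added to `kwargs` when truthy.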