diff --git a/mellea/backends/model_ids.py b/mellea/backends/model_ids.py
index 96b8afee..3ffdb4b8 100644
--- a/mellea/backends/model_ids.py
+++ b/mellea/backends/model_ids.py
@@ -17,6 +17,7 @@ class ModelIdentifier:
     ollama_name: str | None = None
     watsonx_name: str | None = None
     mlx_name: str | None = None
+    openai_name: str | None = None
 
     hf_tokenizer_name: str | None = None  # if None, is the same as hf_model_name
 
@@ -134,9 +135,9 @@ class ModelIdentifier:
 
 QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")
 
-######################
-#### OpenAI models ###
-######################
+###########################
+#### OpenAI open models ###
+###########################
 
 OPENAI_GPT_OSS_20B = ModelIdentifier(
     hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
@@ -145,6 +146,12 @@ class ModelIdentifier:
     hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
 )
 
+###########################
+#### OpenAI prop models ###
+###########################
+
+OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")
+
 #####################
 #### Misc models ####
 #####################
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index ba825753..65ab544a 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -6,6 +6,7 @@
 import functools
 import inspect
 import json
+import os
 from collections.abc import Callable, Coroutine
 from copy import deepcopy
 from enum import Enum
@@ -72,7 +73,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):
 
     def __init__(
         self,
-        model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
+        model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
         formatter: Formatter | None = None,
         base_url: str | None = None,
         model_options: dict | None = None,
@@ -142,26 +143,38 @@ def __init__(
 
         self.default_to_constraint_checking_alora = default_to_constraint_checking_alora
 
-        self._model_id = model_id
         match model_id:
             case str():
-                self._hf_model_id = model_id
+                self._model_id = model_id
             case ModelIdentifier():
-                assert model_id.hf_model_name is not None, (
-                    "model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
+                assert model_id.openai_name is not None, (
+                    "model_id is None. This can also happen if the ModelIdentifier has no `openai_name` set."
                 )
-                self._hf_model_id = model_id.hf_model_name
+                self._model_id = model_id.openai_name
 
-        if base_url is None:
-            self._base_url = "http://localhost:11434/v1"  # ollama
-        else:
-            self._base_url = base_url
-        if api_key is None:
-            self._api_key = "ollama"
-        else:
-            self._api_key = api_key
+        # Use provided parameters or fall back to environment variables
+        self._api_key = api_key
+        self._base_url = base_url
 
-        self._server_type = _server_type(self._base_url)
+        # Validate that we have the required configuration
+        if self._api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            raise ValueError(
+                "OPENAI_API_KEY or api_key is required but not set. Please either:\n"
+                " 1. Set the environment variable: export OPENAI_API_KEY='your-key-here'\n"
+                " 2. Pass it as a parameter: OpenAIBackend(api_key='your-key-here')"
+            )
+
+        if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None:
+            FancyLogger.get_logger().warning(
+                "OPENAI_BASE_URL or base_url is not set.\n"
+                "The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`"
+            )
+
+        self._server_type: _ServerType = (
+            _server_type(self._base_url)
+            if self._base_url is not None
+            else _ServerType.OPENAI
+        )  # type: ignore
 
         self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
 
@@ -598,14 +611,38 @@ async def _generate_from_chat_context_standard(
 
         extra_params: dict[str, Any] = {}
         if _format is not None:
-            extra_params["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": _format.__name__,
-                    "schema": _format.model_json_schema(),
-                    "strict": True,
-                },
-            }
+            if self._server_type == _ServerType.OPENAI:
+                # The OpenAI platform requires that additionalProperties=False on all response_format schemas.
+                # However, not all schemas generated by Mellea include additionalProperties.
+                # GenerativeSlot, in particular, does not add this property.
+                # The easiest way to address this disparity between OpenAI and other inference providers is to
+                # monkey-patch the response format exactly when we are actually using the OpenAI server.
+                #
+                # This only addresses the additionalProperties=False constraint.
+                # Other constraints we should be checking/patching are described here:
+                # https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
+                monkey_patched_response_schema = _format.model_json_schema()
+                monkey_patched_response_schema["additionalProperties"] = False
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": monkey_patched_response_schema,
+                        "strict": True,
+                    },
+                }
+            else:
+                FancyLogger.get_logger().warning(
+                    "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on the JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you have found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
+                )
+                extra_params["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": _format.__name__,
+                        "schema": _format.model_json_schema(),
+                        "strict": True,
+                    },
+                }
 
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
@@ -631,15 +668,21 @@ async def _generate_from_chat_context_standard(
         formatted_tools = convert_tools_to_json(tools)
         use_tools = len(formatted_tools) > 0
 
+        # Build optional reasoning parameters
+        # NOTE: the openai SDK doesn't like it if you pass the `reasoning_effort` param to a non-reasoning model, e.g., gpt-4o
+        reasoning_params = {}
+        if thinking is not None:
+            reasoning_params["reasoning_effort"] = thinking
+
         chat_response: Coroutine[
             Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
         ] = self._async_client.chat.completions.create(
-            model=self._hf_model_id,
+            model=self._model_id,
             messages=conversation,  # type: ignore
-            reasoning_effort=thinking,  # type: ignore
             tools=formatted_tools if use_tools else None,  # type: ignore
             # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
             **extra_params,
+            **reasoning_params,  # type: ignore
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
             ),
@@ -807,7 +850,7 @@ async def generate_from_raw(
         try:
             completion_response: Completion = (
                 await self._async_client.completions.create(
-                    model=self._hf_model_id,
+                    model=self._model_id,
                     prompt=prompts,
                     extra_body=extra_body,
                     **self._make_backend_specific_and_remove(
@@ -860,7 +903,10 @@ async def generate_from_raw(
     @property
    def base_model_name(self):
         """Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
-        return self._hf_model_id.split("/")[1]
+        if "/" in self._model_id:
+            return self._model_id.split("/")[1]
+        else:
+            return self._model_id
 
     def add_adapter(self, adapter: OpenAIAdapter):
         """Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -970,22 +1016,3 @@ def list_adapters(self) -> list[str]:
 
         :returns: list of adapter names that are currently registered with this backend
         """
         return list(self._loaded_adapters.keys())
-
-    def apply_chat_template(self, chat: list[dict[str, str]]):
-        """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
-        from transformers import AutoTokenizer
-
-        if not hasattr(self, "_tokenizer"):
-            match _server_type(self._base_url):
-                case _ServerType.LOCALHOST:
-                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
-                        AutoTokenizer.from_pretrained(self._hf_model_id)
-                    )
-                case _ServerType.OPENAI:
-                    raise Exception(
-                        "apply_chat_template is called while targeting a server at openai.com. "
-                        "This is not supported --- openai.com does not support Activated Lora. "
-                        "Use a locally served vllm instance. "
-                    )
-
-        return self._tokenizer.apply_chat_template(chat, tokenize=False)
diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py
index 1848287c..57ca3281 100644
--- a/test/backends/test_openai_ollama.py
+++ b/test/backends/test_openai_ollama.py
@@ -1,6 +1,7 @@
 # test/rits_backend_tests/test_openai_integration.py
 import asyncio
 import os
+from unittest.mock import patch
 
 import openai
 import pydantic
@@ -216,6 +217,80 @@ async def get_client_async():
 
     assert len(backend._client_cache.cache.values()) == 2
 
+
+async def test_reasoning_effort_conditional_passing(backend):
+    """Test that reasoning_effort is only passed to API when not None."""
+    from unittest.mock import AsyncMock, MagicMock, patch
+
+    ctx = ChatContext()
+    ctx = ctx.add(CBlock(value="Test"))
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = "Response"
+    mock_response.choices[0].message.role = "assistant"
+
+    # Test 1: reasoning_effort should NOT be passed when not specified
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert "reasoning_effort" not in call_kwargs, (
+            "reasoning_effort should not be passed when not specified"
+        )
+
+    # Test 2: reasoning_effort SHOULD be passed when specified
+    with patch.object(
+        backend._async_client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_create:
+        mock_create.return_value = mock_response
+        await backend.generate_from_chat_context(
+            CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
+        )
+        call_kwargs = mock_create.call_args.kwargs
+        assert call_kwargs.get("reasoning_effort") == "medium", (
+            "reasoning_effort should be passed with correct value when specified"
+        )
+
+
+def test_api_key_and_base_url_from_parameters():
+    """Test that API key and base URL can be set via parameters."""
+    backend = OpenAIBackend(
+        model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1"
+    )
+    assert backend._api_key == "test-api-key"
+    assert backend._base_url == "https://api.test.com/v1"
+
+
+def test_parameter_overrides_env_variable():
+    """Test that explicit parameters override environment variables."""
+    with patch.dict(
+        os.environ,
+        {"OPENAI_API_KEY": "env-api-key", "OPENAI_BASE_URL": "https://api.env.com/v1"},
+    ):
+        backend = OpenAIBackend(
+            model_id="gpt-4",
+            api_key="param-api-key",
+            base_url="https://api.param.com/v1",
+        )
+        assert backend._api_key == "param-api-key"
+        assert backend._base_url == "https://api.param.com/v1"
+
+
+def test_missing_api_key_raises_error():
+    """Test that missing API key raises ValueError with helpful message."""
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(ValueError) as exc_info:
+            OpenAIBackend(model_id="gpt-4", base_url="https://api.test.com/v1")
+        assert "OPENAI_API_KEY or api_key is required but not set" in str(
+            exc_info.value
+        )
+
+
 if __name__ == "__main__":
     import pytest
diff --git a/test/stdlib_intrinsics/test_rag/test_rag.py b/test/stdlib_intrinsics/test_rag/test_rag.py
index 016b78a0..47b13e02 100644
--- a/test/stdlib_intrinsics/test_rag/test_rag.py
+++ b/test/stdlib_intrinsics/test_rag/test_rag.py
@@ -184,6 +184,7 @@ def test_answer_relevance(backend):
     assert result == answer
 
 
+@pytest.mark.qualitative
 def test_answer_relevance_classifier(backend):
     """Verify that the first phase of the answer relevance flow behaves as expectee."""
     context, answer, docs = _read_input_json("answer_relevance.json")
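
A note for reviewers on the response_format change above: the snippet below is a minimal, self-contained sketch (plain pydantic, no mellea imports) of the additionalProperties patch that the OpenAI branch now applies before sending a json_schema response format. The `Answer` model is a made-up example, not something introduced by this patch.

from pydantic import BaseModel


class Answer(BaseModel):
    """Hypothetical stand-in for a schema passed via `format=`."""

    verdict: str
    confidence: float


schema = Answer.model_json_schema()
# pydantic may omit additionalProperties entirely; the OpenAI platform rejects
# strict json_schema response formats unless it is explicitly set to False.
schema["additionalProperties"] = False

response_format = {
    "type": "json_schema",
    "json_schema": {"name": Answer.__name__, "schema": schema, "strict": True},
}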
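And a minimal usage sketch of the reworked configuration, assuming the import paths shown in this patch (`mellea.backends.openai.OpenAIBackend`, `mellea.backends.model_ids`); the placeholder key and the local model name are illustrative, not values this patch introduces.

import os

from mellea.backends import model_ids
from mellea.backends.openai import OpenAIBackend

# Hosted OpenAI: the key can come from the environment or the api_key parameter;
# with neither set, the new constructor raises ValueError.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder, not a real key
backend = OpenAIBackend(model_id=model_ids.OPENAI_GPT_5_1)

# Local OpenAI-compatible server (e.g. Ollama or vLLM): the old localhost default
# is gone, so base_url must now be passed explicitly.
local_backend = OpenAIBackend(
    model_id="gpt-oss:20b",  # plain string model ids still work
    api_key="ollama",
    base_url="http://localhost:11434/v1",
)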