13 changes: 10 additions & 3 deletions mellea/backends/model_ids.py
@@ -17,6 +17,7 @@ class ModelIdentifier:
ollama_name: str | None = None
watsonx_name: str | None = None
mlx_name: str | None = None
openai_name: str | None = None

hf_tokenizer_name: str | None = None # if None, is the same as hf_model_name

@@ -134,9 +135,9 @@ class ModelIdentifier:

QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")

######################
#### OpenAI models ###
######################
###########################
#### OpenAI open models ###
###########################

OPENAI_GPT_OSS_20B = ModelIdentifier(
hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
@@ -145,6 +146,12 @@ class ModelIdentifier:
hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
)

###########################
#### OpenAI prop models ###
###########################

OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")

#####################
#### Misc models ####
#####################
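Taken together with the backend changes below (the `ModelIdentifier` branch of `OpenAIBackend.__init__` now reads `openai_name`), an identifier like `OPENAI_GPT_5_1` would be used roughly as follows. This is an illustrative sketch, not code from the PR; the import paths are inferred from the file paths in this diff, and the `api_key` value is a placeholder.

```python
from mellea.backends import model_ids
from mellea.backends.openai import OpenAIBackend

# Sketch: OPENAI_GPT_5_1 only sets `openai_name` ("gpt-5.1"), so the backend
# resolves it through the new `openai_name` branch. With an api_key and no
# base_url, the updated __init__ targets the OpenAI platform.
backend = OpenAIBackend(
    model_id=model_ids.OPENAI_GPT_5_1,
    api_key="sk-...",  # placeholder
)
```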
88 changes: 64 additions & 24 deletions mellea/backends/openai.py
@@ -72,7 +72,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
@@ -142,26 +142,30 @@ def __init__(

self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

self._model_id = model_id
match model_id:
case str():
self._hf_model_id = model_id
self._model_id = model_id
case ModelIdentifier():
assert model_id.hf_model_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
assert model_id.openai_name is not None, (
"model_id.openai_name is None. This happens when the ModelIdentifier has no `openai_name` set."
)
self._hf_model_id = model_id.hf_model_name
self._model_id = model_id.openai_name

if base_url is None:
self._base_url = "http://localhost:11434/v1" # ollama
else:
self._base_url = base_url
if api_key is None:
FancyLogger.get_logger().warning(
"You are using an OpenAI backend with no api_key. Because no API key was provided, mellea assumes you intend to use the openai-compatible interface to your local ollama instance. If you intend to use OpenAI's platform you must specify your API key when instantiating your Mellea session/backend object."
)
self._base_url: str | None = "http://localhost:11434/v1" # ollama
self._api_key = "ollama"
Comment on lines 154 to 159

Contributor: I think we are close to defaults that make sense here. I think if the user specifies a base_url we should always use that base_url (even if no api_key is set). I also wonder if we should default the api_key to "ollama" in those situations.

Otherwise, we have no way to target arbitrary localhost ports that don't require an api_key.

For example (and this isn't the best since it uses LiteLLM, and we have a separate backend for that), LiteLLM has a proxy that you can run locally. This proxy stores the api_key information itself, so you can target an arbitrary localhost port without an api_key.

My proposed solution would be to just set the parameter default values to work for the ollama version (i.e., api_key="ollama" and base_url="http://localhost:11434/v1"). Then users can override these values. I think this would also allow users to explicitly set api_key / base_url to None and have the underlying OpenAI SDK still automatically pick up their env vars (without the risk of users accidentally incurring expenses).

Contributor: Consensus: just pass the args through to the OpenAI SDK. Don't do argument handling such as this.
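A minimal sketch of the reviewer's proposal (hypothetical; only the parameter names `model_id`, `base_url`, and `api_key` come from the diff, everything else is illustrative): put the ollama-friendly values in the signature defaults and pass both straight through to the OpenAI SDK with no conditional handling.

```python
# Hypothetical sketch of the reviewer's proposed defaults -- not the code in this PR.
import openai


class OpenAIBackendSketch:
    def __init__(
        self,
        model_id: str = "gpt-5.1",
        base_url: str | None = "http://localhost:11434/v1",  # ollama default
        api_key: str | None = "ollama",  # ollama accepts any placeholder key
    ):
        self._model_id = model_id
        # No argument juggling: pass the values straight through. Callers who
        # explicitly set both to None let the OpenAI SDK fall back to the
        # OPENAI_BASE_URL / OPENAI_API_KEY environment variables.
        self._client = openai.OpenAI(base_url=base_url, api_key=api_key)
```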

else:
self._base_url = base_url
self._api_key = api_key

self._server_type = _server_type(self._base_url)
self._server_type: _ServerType = (
_server_type(self._base_url)
if self._base_url is not None
else _ServerType.OPENAI
) # type: ignore

self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)

@@ -598,14 +602,38 @@ async def _generate_from_chat_context_standard(

extra_params: dict[str, Any] = {}
if _format is not None:
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
if self._server_type == _ServerType.OPENAI:
# The OpenAI platform requires that additionalProperties=False on all response_format schemas.
# However, not all schemas generated by Mellea include additionalProperties.
# GenerativeSlot, in particular, does not add this property.
# The easiest way to address this disparity between OpenAI and other inference providers is to
# monkey-patch the response format exactly when we are actually using the OpenAI server.
#
# This only addresses the additionalProperties=False constraint.
# Other constraints we should be checking/patching are described here:
# https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
monkey_patched_response_schema = _format.model_json_schema()
monkey_patched_response_schema["additionalProperties"] = False
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": monkey_patched_response_schema,
"strict": True,
},
}
else:
FancyLogger().get_logger().warning(
"Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on support JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
)
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}

# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
@@ -631,15 +659,21 @@ async def _generate_from_chat_context_standard(
formatted_tools = convert_tools_to_json(tools)
use_tools = len(formatted_tools) > 0

# Build optional reasoning parameters
# NOTE: the OpenAI SDK doesn't like it if you pass the `reasoning_effort` param to a non-reasoning model, e.g., gpt-4o
reasoning_params = {}
if thinking is not None:
reasoning_params["reasoning_effort"] = thinking

chat_response: Coroutine[
Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
] = self._async_client.chat.completions.create(
model=self._hf_model_id,
model=self._model_id,
messages=conversation, # type: ignore
reasoning_effort=thinking, # type: ignore
tools=formatted_tools if use_tools else None, # type: ignore
# parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
**extra_params,
**reasoning_params, # type: ignore
**self._make_backend_specific_and_remove(
model_opts, is_chat_context=ctx.is_chat_context
),
@@ -807,7 +841,7 @@ async def generate_from_raw(
try:
completion_response: Completion = (
await self._async_client.completions.create(
model=self._hf_model_id,
model=self._model_id,
prompt=prompts,
extra_body=extra_body,
**self._make_backend_specific_and_remove(
@@ -860,7 +894,10 @@ async def generate_from_raw(
@property
def base_model_name(self):
"""Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
return self._hf_model_id.split("/")[1]
if "/" in self._model_id:
return self._model_id.split("/")[1]
else:
return self._model_id

def add_adapter(self, adapter: OpenAIAdapter):
"""Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -976,10 +1013,13 @@ def apply_chat_template(self, chat: list[dict[str, str]]):
from transformers import AutoTokenizer

if not hasattr(self, "_tokenizer"):
assert self._base_url, (
"The OpenAI Platform does not support adapters. You must specify a _base_url when using adapters."
)
match _server_type(self._base_url):
case _ServerType.LOCALHOST:
self._tokenizer: "PreTrainedTokenizer" = ( # noqa: UP037
AutoTokenizer.from_pretrained(self._hf_model_id)
AutoTokenizer.from_pretrained(self._model_id)
)
case _ServerType.OPENAI:
raise Exception(
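As a standalone illustration of the `additionalProperties` patch above: a minimal sketch assuming the output schema comes from a plain Pydantic model which, like the GenerativeSlot schemas mentioned in the hunk comments, does not emit `additionalProperties` on its own. The `Sentiment` model is made up for the example.

```python
from pydantic import BaseModel


class Sentiment(BaseModel):
    """Made-up stand-in for a Mellea-generated output schema."""

    label: str
    score: float


schema = Sentiment.model_json_schema()
# Pydantic v2 omits additionalProperties by default, but OpenAI's structured
# outputs (a strict json_schema response_format) require it to be exactly False.
schema["additionalProperties"] = False

response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": Sentiment.__name__,
        "schema": schema,
        "strict": True,
    },
}
```

The other constraints mentioned in the hunk comments would need the same kind of check; see the structured-outputs guide linked there.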
40 changes: 40 additions & 0 deletions test/backends/test_openai_ollama.py
@@ -216,6 +216,46 @@ async def get_client_async():
assert len(backend._client_cache.cache.values()) == 2


async def test_reasoning_effort_conditional_passing(backend):
"""Test that reasoning_effort is only passed to API when not None."""
from unittest.mock import AsyncMock, MagicMock, patch

ctx = ChatContext()
ctx = ctx.add(CBlock(value="Test"))

mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = "Response"
mock_response.choices[0].message.role = "assistant"

# Test 1: reasoning_effort should NOT be passed when not specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={}
)
call_kwargs = mock_create.call_args.kwargs
assert "reasoning_effort" not in call_kwargs, (
"reasoning_effort should not be passed when not specified"
)

# Test 2: reasoning_effort SHOULD be passed when specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs.get("reasoning_effort") == "medium", (
"reasoning_effort should be passed with correct value when specified"
)


if __name__ == "__main__":
import pytest
