Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
get_platform_choices,
get_platform,
)
from utils.providers.models import get_model_provider
from utils.providers.models import get_model_provider, is_model_available


# ------------------------
Expand Down Expand Up @@ -138,6 +138,20 @@ def _dotted_name(n: ast.AST) -> str:
return ".".join(parts)


def _validate_cfg_models(cfg) -> None:
"""Given a router config, remove all unavailable model choices."""
if (ka_model := cfg.get("ka_model")) and not is_model_available(ka_model):
del cfg["ka_model"]

if models := cfg.get("llm_models"):
remove = [k for k, v in models.items() if not is_model_available(v)]
for k in remove:
del models[k]

if not models:
del cfg["llm_models"]


@dataclass
class Complexity:
has_control_flow: bool
Expand Down Expand Up @@ -475,6 +489,7 @@ def solve(self, problem_path: Path) -> RouteResult:
)
route_conf = cached.get("confidence")
route_cfg = cached.get("config") or {}
_validate_cfg_models(route_cfg)

if strategy is None:
# Try LLM-driven decision
Expand All @@ -487,6 +502,8 @@ def solve(self, problem_path: Path) -> RouteResult:
"route_strategy": strategy,
"confidence": route_conf,
}
route_cfg = cache[code_hash].get("config") or {}
_validate_cfg_models(route_cfg)
_save_router_cache(cache)
except Exception:
# No provider or failure; fall back later
Expand Down
52 changes: 52 additions & 0 deletions tests/fuser/test_auto_agent_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

from Fuser.auto_agent import _validate_cfg_models


class TestValidateCfgModels:
def test_validate_cfg_models_removes_unavailable(self):
"""Test that unavailable models are removed from config."""
cfg = {
"ka_model": "gpt-5",
"llm_models": {
"extract": "gpt-5",
"dispatch": "o4-mini",
},
}

with patch("Fuser.auto_agent.is_model_available", return_value=False):
_validate_cfg_models(cfg)

assert "ka_model" not in cfg
assert "llm_models" not in cfg

def test_validate_cfg_models_keeps_available(self):
"""Test that available models are kept in config."""
cfg = {
"ka_model": "gpt-5",
"llm_models": {
"extract": "gpt-5",
"dispatch": "o4-mini",
},
}

with patch("Fuser.auto_agent.is_model_available", return_value=True):
_validate_cfg_models(cfg)

assert cfg["ka_model"] == "gpt-5"
assert cfg["llm_models"]["extract"] == "gpt-5"
assert cfg["llm_models"]["dispatch"] == "o4-mini"