diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index dadaa8f..563e174 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -61,7 +61,7 @@ get_platform_choices, get_platform, ) -from utils.providers.models import get_model_provider +from utils.providers.models import get_model_provider, is_model_available # ------------------------ @@ -138,6 +138,20 @@ def _dotted_name(n: ast.AST) -> str: return ".".join(parts) +def _validate_cfg_models(cfg) -> None: + """Given a router config, remove all unavailable model choices.""" + if (ka_model := cfg.get("ka_model")) and not is_model_available(ka_model): + del cfg["ka_model"] + + if models := cfg.get("llm_models"): + remove = [k for k, v in models.items() if not is_model_available(v)] + for k in remove: + del models[k] + + if not models: + del cfg["llm_models"] + + @dataclass class Complexity: has_control_flow: bool @@ -475,6 +489,7 @@ def solve(self, problem_path: Path) -> RouteResult: ) route_conf = cached.get("confidence") route_cfg = cached.get("config") or {} + _validate_cfg_models(route_cfg) if strategy is None: # Try LLM-driven decision @@ -487,6 +502,8 @@ def solve(self, problem_path: Path) -> RouteResult: "route_strategy": strategy, "confidence": route_conf, } + route_cfg = cache[code_hash].get("config") or {} + _validate_cfg_models(route_cfg) _save_router_cache(cache) except Exception: # No provider or failure; fall back later diff --git a/tests/fuser/test_auto_agent_cache.py b/tests/fuser/test_auto_agent_cache.py new file mode 100644 index 0000000..0b99079 --- /dev/null +++ b/tests/fuser/test_auto_agent_cache.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +from Fuser.auto_agent import _validate_cfg_models + + +class TestValidateCfgModels: + def test_validate_cfg_models_removes_unavailable(self): + """Test that unavailable models are removed from config.""" + cfg = { + "ka_model": "gpt-5", + "llm_models": { + "extract": "gpt-5", + "dispatch": "o4-mini", + }, + } + + with patch("Fuser.auto_agent.is_model_available", return_value=False): + _validate_cfg_models(cfg) + + assert "ka_model" not in cfg + assert "llm_models" not in cfg + + def test_validate_cfg_models_keeps_available(self): + """Test that available models are kept in config.""" + cfg = { + "ka_model": "gpt-5", + "llm_models": { + "extract": "gpt-5", + "dispatch": "o4-mini", + }, + } + + with patch("Fuser.auto_agent.is_model_available", return_value=True): + _validate_cfg_models(cfg) + + assert cfg["ka_model"] == "gpt-5" + assert cfg["llm_models"]["extract"] == "gpt-5" + assert cfg["llm_models"]["dispatch"] == "o4-mini"