diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index dadaa8f..35b3fbe 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -330,6 +330,7 @@ def __init__( dispatch_jobs: int = 2, allow_fallback: bool = True, target_platform: str | None = None, + use_router_cache: bool = True, ) -> None: self.ka_model = ka_model self.ka_num_workers = ka_num_workers @@ -352,6 +353,7 @@ def __init__( self.dispatch_jobs = dispatch_jobs self.allow_fallback = allow_fallback self.platform_config = get_platform(target_platform) + self.use_router_cache = use_router_cache def _solve_with_kernelagent(self, problem_code: str) -> RouteResult: agent = TritonKernelAgent( @@ -461,20 +463,23 @@ def solve(self, problem_path: Path) -> RouteResult: heuristic_prefers_fuser = cx.route_to_fuser() # Cache lookup by content hash to avoid repeated router calls + cache = {} code_hash = _file_sha256_text(code) - cache = _load_router_cache() - cached = cache.get(code_hash) - strategy: str | None = None route_conf: float | None = None route_cfg: dict[str, Any] = {} - if isinstance(cached, dict): - strategy = ( - str(cached.get("route_strategy") or cached.get("route") or "") or None - ) - route_conf = cached.get("confidence") - route_cfg = cached.get("config") or {} + if self.use_router_cache: + cache = _load_router_cache() + cached = cache.get(code_hash) + + if isinstance(cached, dict): + strategy = ( + str(cached.get("route_strategy") or cached.get("route") or "") + or None + ) + route_conf = cached.get("confidence") + route_cfg = cached.get("config") or {} if strategy is None: # Try LLM-driven decision @@ -483,11 +488,12 @@ def solve(self, problem_path: Path) -> RouteResult: problem_path, code, cx ) # Persist in cache for future runs - cache[code_hash] = info.get("parsed") or { - "route_strategy": strategy, - "confidence": route_conf, - } - _save_router_cache(cache) + if self.use_router_cache: + cache[code_hash] = info.get("parsed") or { + "route_strategy": strategy, + "confidence": route_conf, + } + _save_router_cache(cache) except Exception: # No provider or failure; fall back later pass @@ -704,6 +710,11 @@ def main(argv: list[str] | None = None) -> int: p.add_argument("--verify", action="store_true") p.add_argument("--dispatch-jobs", type=int, default=2) p.add_argument("--no-fallback", action="store_true") + p.add_argument( + "--no-router-cache", + action="store_true", + help="Disable router cache (do not read from or write to cache)", + ) p.add_argument( "--target-platform", default="cuda", @@ -741,6 +752,7 @@ def main(argv: list[str] | None = None) -> int: dispatch_jobs=args.dispatch_jobs, allow_fallback=(not args.no_fallback), target_platform=args.target_platform, + use_router_cache=(not args.no_router_cache), ) try: diff --git a/README.md b/README.md index d39fb4b..26302a5 100644 --- a/README.md +++ b/README.md @@ -91,8 +91,10 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`. ```bash python -m Fuser.auto_agent \ --problem /abs/path/to/KernelBench/level1/19_ReLU.py \ - --verify # ensure final composition test runs + --no-router-cache \ # avoid caching or using cached results + --verify # ensure final composition test runs ``` + `--no-router-cache` can be enabled to avoid utilizing any cached router results and prevent writing to the cache. - **Manually run the pipeline (extract → dispatch → compose)** when you want explicit control over models or concurrency: ```bash @@ -144,7 +146,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`. ## Component Details -- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. +- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--no-router-cache` ignore the existing cache and caching new routes. - **Fuser Orchestrator (`Fuser/orchestrator.py`)**: rewrites the PyTorch module into fusable modules, executes them for validation, and packages a tarball of the fused code. Run IDs and directories are managed via `Fuser/paths.py`.