Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions Fuser/auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def __init__(
dispatch_jobs: int = 2,
allow_fallback: bool = True,
target_platform: str | None = None,
use_router_cache: bool = True,
) -> None:
self.ka_model = ka_model
self.ka_num_workers = ka_num_workers
Expand All @@ -352,6 +353,7 @@ def __init__(
self.dispatch_jobs = dispatch_jobs
self.allow_fallback = allow_fallback
self.platform_config = get_platform(target_platform)
self.use_router_cache = use_router_cache

def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
agent = TritonKernelAgent(
Expand Down Expand Up @@ -461,20 +463,23 @@ def solve(self, problem_path: Path) -> RouteResult:
heuristic_prefers_fuser = cx.route_to_fuser()

# Cache lookup by content hash to avoid repeated router calls
cache = {}
code_hash = _file_sha256_text(code)
cache = _load_router_cache()
cached = cache.get(code_hash)

strategy: str | None = None
route_conf: float | None = None
route_cfg: dict[str, Any] = {}

if isinstance(cached, dict):
strategy = (
str(cached.get("route_strategy") or cached.get("route") or "") or None
)
route_conf = cached.get("confidence")
route_cfg = cached.get("config") or {}
if self.use_router_cache:
cache = _load_router_cache()
cached = cache.get(code_hash)

if isinstance(cached, dict):
strategy = (
str(cached.get("route_strategy") or cached.get("route") or "")
or None
)
route_conf = cached.get("confidence")
route_cfg = cached.get("config") or {}

if strategy is None:
# Try LLM-driven decision
Expand All @@ -483,11 +488,12 @@ def solve(self, problem_path: Path) -> RouteResult:
problem_path, code, cx
)
# Persist in cache for future runs
cache[code_hash] = info.get("parsed") or {
"route_strategy": strategy,
"confidence": route_conf,
}
_save_router_cache(cache)
if self.use_router_cache:
cache[code_hash] = info.get("parsed") or {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

router-emitted config.llm_models / config.ka_model can be non-configured/inaccessible, gets cached verbatim, and later runs apply it unvalidated. Please consider validating/intersecting against the local registry (utils/providers/available_models.py) + provider availability before (a) applying and/or (b) writing to cache.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I have a local fix I'll push in a separate PR since there's actually some other bugs that get tackled over there

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#79

"route_strategy": strategy,
"confidence": route_conf,
}
_save_router_cache(cache)
except Exception:
# No provider or failure; fall back later
pass
Expand Down Expand Up @@ -704,6 +710,11 @@ def main(argv: list[str] | None = None) -> int:
p.add_argument("--verify", action="store_true")
p.add_argument("--dispatch-jobs", type=int, default=2)
p.add_argument("--no-fallback", action="store_true")
p.add_argument(
"--no-router-cache",
action="store_true",
help="Disable router cache (do not read from or write to cache)",
)
p.add_argument(
"--target-platform",
default="cuda",
Expand Down Expand Up @@ -741,6 +752,7 @@ def main(argv: list[str] | None = None) -> int:
dispatch_jobs=args.dispatch_jobs,
allow_fallback=(not args.no_fallback),
target_platform=args.target_platform,
use_router_cache=(not args.no_router_cache),
)

try:
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.
```bash
python -m Fuser.auto_agent \
--problem /abs/path/to/KernelBench/level1/19_ReLU.py \
--verify # ensure final composition test runs
--no-router-cache \
--verify  # ensure final composition test runs (--no-router-cache avoids reading or writing cached router results)
```
Pass `--no-router-cache` to skip reading any cached router results and to prevent writing new entries to the cache.

- **Manually run the pipeline (extract → dispatch → compose)** when you want explicit control over models or concurrency:
```bash
Expand Down Expand Up @@ -144,7 +146,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`.

## Component Details

- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails.
- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--no-router-cache` to ignore the existing cache and skip caching new routes.

- **Fuser Orchestrator (`Fuser/orchestrator.py`)**: rewrites the PyTorch module into fusable modules, executes them for validation, and packages a tarball of the fused code. Run IDs and directories are managed via `Fuser/paths.py`.

Expand Down