## Problem
The Fuser outputs kernel_function(...) with arguments that don't have a 1:1 mapping to the original Model's parameters. Static/heuristic mapping fails for:
- Fused operations (Conv+BN → fused_weight, fused_bias)
- Reordered arguments
- Computed intermediates
- Non-obvious naming conventions
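To make the Conv+BN case concrete: the fused kernel's `fused_weight`/`fused_bias` are each derived from several original parameters, so no 1:1 mapping exists. A minimal scalar sketch of the standard BN-folding formula (the real helper would operate on tensors; the function name here is illustrative, not from the codebase):

```python
import math

def fuse_conv_bn_scalar(w, b, gamma, beta, mean, var, eps=1e-5):
    # Fold BatchNorm(gamma, beta, mean, var) into a preceding conv's (w, b):
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    scale = gamma / math.sqrt(var + eps)
    fused_weight = w * scale
    fused_bias = (b - mean) * scale + beta
    return fused_weight, fused_bias
```

One kernel argument thus depends on four model parameters plus two buffers, which is exactly what a static name-matching heuristic cannot recover.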
## Solution: LLM Mapping Step
Add a new pipeline step that uses an LLM to generate the argument mapping by inspecting:
- Original KernelBench problem (Model class, get_inputs, get_init_inputs)
- Generated composed kernel (kernel_function signature)
- Subgraphs metadata (operation fusion info)
## Pipeline
extract → dispatch → compose → map_kernelbench (LLM) → format_kernelbench
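The ordering above can be sketched as plain functions over a shared context dict; all bodies are stubs standing in for the real Fuser stages, and every key name here is illustrative:

```python
def extract(ctx):
    ctx["subgraphs"] = ["conv+bn"]          # stub: fusion metadata
    return ctx

def dispatch(ctx):
    ctx["kernels"] = ["triton_conv_bn"]     # stub: per-subgraph kernels
    return ctx

def compose(ctx):
    ctx["kernel_signature"] = "kernel_function(x, fused_weight, fused_bias)"
    return ctx

def map_kernelbench(ctx):
    # the LLM call goes here; stubbed with a fixed mapping
    ctx["mapping"] = {"args": [{"kernel_arg": "x", "source": "input", "expr": "inputs[0]"}]}
    return ctx

def format_kernelbench(ctx):
    ctx["output"] = "kernelbench_model.py"  # stub: rendered final file
    return ctx

def run_pipeline(problem_path):
    ctx = {"problem": problem_path}
    for step in (extract, dispatch, compose, map_kernelbench, format_kernelbench):
        ctx = step(ctx)
    return ctx
```

The point is only that `map_kernelbench` sits between composition and formatting, consuming the composed signature and producing the mapping the formatter needs.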
## LLM Mapping Step
Input to LLM:
- Original Model code
- kernel_function signature + implementation
- Subgraph fusion information
Output from LLM:
```json
{
  "args": [
    {"kernel_arg": "x", "source": "input", "expr": "inputs[0]"},
    {"kernel_arg": "weight", "source": "param", "expr": "model.conv.weight"},
    {"kernel_arg": "fused_bias", "source": "computed", "expr": "fuse_conv_bn_bias(model)"}
  ],
  "helpers_needed": ["fuse_conv_bn_weights"],
  "notes": "BN folded into conv, need to fuse weights at load time"
}
```

## Format Step
Takes the LLM-generated mapping and produces the final `kernelbench_model.py`:
- Generates `ModelNew.__init__` with correct parameter registration
- Generates `forward()` that calls `kernel_function` with mapped args
- Includes any helper functions (weight fusion, etc.)
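A minimal sketch of the formatter's `forward()` generation: the mapping schema matches the JSON above, but the example mapping values and the emitted template are illustrative placeholders, not the actual formatter output.

```python
import json

# Illustrative mapping in the schema shown above
EXAMPLE_MAPPING = json.loads("""
{
  "args": [
    {"kernel_arg": "x", "source": "input", "expr": "inputs[0]"},
    {"kernel_arg": "fused_weight", "source": "computed", "expr": "self.fused_weight"}
  ]
}
""")

def render_forward(mapping):
    # Emit a forward() whose call arguments are the mapping's exprs,
    # in the order kernel_function declares them.
    call_args = ", ".join(a["expr"] for a in mapping["args"])
    return "\n".join([
        "def forward(self, *inputs):",
        f"    return kernel_function({call_args})",
    ])
```

Because the mapping is explicit data rather than heuristic matching, the generated call site can be inspected and diffed directly.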
## Why This Works (or Should)
- LLM understands the transformation - it can read both sides and reason about the correspondence, which is extremely hard for a static method
- Handles edge cases - fused weights, reordered args, computed values
- Produces an explicit mapping - debuggable, auditable
- Reuses existing infrastructure - same LLM dispatch pattern as kernel generation
## Verification
After generation, run:

```python
original_out = Model(*init_inputs).forward(*inputs)
new_out = ModelNew(*init_inputs).forward(*inputs)
assert torch.allclose(original_out, new_out, rtol=1e-3)
```

If verification fails, the mapping step can be re-run with error feedback.
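The verify-then-retry loop can be sketched as below. Plain Python lists stand in for tensors so the sketch runs on its own (the real check would use `torch.allclose`), and the commented-out `run_mapping_step` is a hypothetical stand-in for re-invoking the LLM with the accumulated feedback:

```python
import math

def allclose(a, b, rtol=1e-3):
    # List stand-in for torch.allclose with a relative tolerance
    return all(math.isclose(x, y, rel_tol=rtol) for x, y in zip(a, b))

def map_with_retries(max_attempts=3):
    feedback = None
    for attempt in range(1, max_attempts + 1):
        # mapping = run_mapping_step(feedback)  # real LLM call would go here
        original_out = [1.0, 2.0]
        # Stub: pretend the first mapping attempt produces wrong output
        new_out = [1.0, 2.0] if attempt > 1 else [1.0, 9.0]
        if allclose(original_out, new_out):
            return attempt
        feedback = f"output mismatch on attempt {attempt}"
    raise RuntimeError("argument mapping failed verification after retries")
```

Feeding the mismatch back into the prompt gives the LLM a chance to fix, e.g., a swapped argument order on the next attempt.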
## Implementation Scope
- `kernelbench_mapping.j2` - prompt template for LLM mapping
- `Fuser/kernelbench_mapper.py` - calls LLM, parses mapping JSON
- `Fuser/kernelbench_formatter.py` - applies mapping to generate final code
- Pipeline integration via `--output-format kernelbench` flag
## Prompt Template

You are analyzing a Triton kernel generated from a PyTorch model to create an argument mapping.

## Original PyTorch Model

```python
{{ original_model_code }}
```

## Generated Triton Kernel

```python
{{ kernel_function_code }}
```

## Subgraph Info

{{ subgraph_info }}
## Task
Analyze how `kernel_function` arguments map to the original Model's parameters and inputs.
Output a JSON mapping:

```json
{
  "args": [
    {
      "kernel_arg": "<argument name in kernel_function>",
      "source": "input|param|buffer|computed",
      "expr": "<Python expression to get value from Model instance or inputs>"
    }
  ],
  "weight_fusion": {
    "needed": true|false,
    "description": "<what fusion is needed, e.g., Conv+BN weight folding>"
  },
  "forward_inputs": ["<list of forward() parameter names>"]
}
```

Rules:
- `source: "input"` → comes from forward() arguments, expr like `x` or `inputs[0]`
- `source: "param"` → comes from a model parameter, expr like `self.conv.weight`
- `source: "buffer"` → comes from a model buffer, expr like `self.bn.running_mean`
- `source: "computed"` → requires computation, expr like `self.fused_weight` (will be precomputed)
For fused operations (e.g., Conv+BN), identify which weights need to be fused at load time.
Output only valid JSON, no explanation.
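On the consuming side, `Fuser/kernelbench_mapper.py` has to parse this response. A hypothetical sketch of that parsing step: even though the prompt demands JSON only, it is cheap to tolerate a stray markdown fence before calling `json.loads` (the regex and function name here are illustrative, not the actual mapper code):

```python
import json
import re

def parse_mapping(llm_response):
    text = llm_response.strip()
    # Tolerate a ```json ... ``` wrapper despite the "JSON only" rule
    match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if match:
        text = match.group(1)
    return json.loads(text)  # raises json.JSONDecodeError on bad output
```

A `JSONDecodeError` here is a natural trigger for the error-feedback retry described under Verification.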
---
## Example Usage
```bash
# Full pipeline with KernelBench output
python -m Fuser.pipeline \
  --problem /path/to/kernelbench/level2/problem.py \
  --output-format kernelbench \
  --verify

# Standalone mapping step
python -m Fuser.kernelbench_mapper \
  --problem /path/to/problem.py \
  --composed-kernel /path/to/composed.py \
  --subgraphs /path/to/subgraphs.json \
  --output /path/to/mapping.json
```