9 changes: 9 additions & 0 deletions Fuser/auto_agent.py
@@ -330,6 +330,7 @@ def __init__(
dispatch_jobs: int = 2,
allow_fallback: bool = True,
target_platform: str | None = None,
disable_cuda_math: bool = False,
) -> None:
self.ka_model = ka_model
self.ka_num_workers = ka_num_workers
@@ -352,6 +353,7 @@ def __init__(
self.dispatch_jobs = dispatch_jobs
self.allow_fallback = allow_fallback
self.platform_config = get_platform(target_platform)
self.disable_cuda_math = disable_cuda_math

def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
agent = TritonKernelAgent(
@@ -360,6 +362,7 @@ def _solve_with_kernelagent(self, problem_code: str) -> RouteResult:
model_name=self.ka_model,
high_reasoning_effort=self.ka_high_reasoning,
target_platform=self.platform_config,
disable_cuda_math=self.disable_cuda_math,
)
try:
# Ensure exceptions in KernelAgent do not abort routing; return a structured failure
@@ -710,6 +713,11 @@ def main(argv: list[str] | None = None) -> int:
choices=get_platform_choices(),
help="Target platform (default: cuda)",
)
p.add_argument(
"--disable-cuda-math",
action="store_true",
help="Disable cuSolver/cuBLAS library usage in generated kernels",
)
args = p.parse_args(argv)

# Load environment variables from .env file
@@ -741,6 +749,7 @@ def main(argv: list[str] | None = None) -> int:
dispatch_jobs=args.dispatch_jobs,
allow_fallback=(not args.no_fallback),
target_platform=args.target_platform,
disable_cuda_math=args.disable_cuda_math,
)

try:
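For orientation, a minimal sketch of exercising the new flag through this file's main() entry point. Only --disable-cuda-math itself is taken from this diff; the import path assumes the repository root is on sys.path, and any other arguments the script requires are left at their defaults:

from Fuser.auto_agent import main

# --disable-cuda-math is a store_true flag: passing it sets
# args.disable_cuda_math = True, which main() forwards into the agent.
exit_code = main(["--disable-cuda-math"])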
8 changes: 8 additions & 0 deletions Fuser/dispatch_kernel_agent.py
@@ -333,6 +333,7 @@ def run(
jobs: int = 1,
target_platform: str = "cuda",
max_iters: int = 10,
disable_cuda_math: bool = False,
) -> Path:
"""Dispatch subgraphs to KernelAgent with optional parallelism.

@@ -369,6 +370,7 @@ def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, Any]]:
max_rounds=max_iters,
model_name=agent_model,
target_platform=platform,
disable_cuda_math=disable_cuda_math,
)
try:
result = local_agent.generate_kernel(
@@ -457,6 +459,11 @@ def main(argv: list[str] | None = None) -> int:
choices=get_platform_choices(),
help="Target platform (default: cuda)",
)
p.add_argument(
"--disable-cuda-math",
action="store_true",
help="Disable cuSolver/cuBLAS library usage in generated kernels",
)
args = p.parse_args(argv)

subgraphs_path = Path(args.subgraphs).resolve()
@@ -485,6 +492,7 @@ def main(argv: list[str] | None = None) -> int:
agent_model=args.agent_model,
jobs=jobs_val,
target_platform=args.target_platform,
disable_cuda_math=args.disable_cuda_math,
)
print(str(summary_path))
return 0
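The dispatcher can be driven the same way; a hedged sketch (the --subgraphs spelling is an assumption inferred from args.subgraphs above, and the JSON filename is illustrative):

from Fuser.dispatch_kernel_agent import main

# With the flag present, run() constructs every per-subgraph
# TritonKernelAgent with disable_cuda_math=True.
main(["--subgraphs", "subgraphs.json", "--disable-cuda-math"])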
8 changes: 7 additions & 1 deletion triton_kernel_agent/agent.py
@@ -41,6 +41,7 @@ def __init__(
high_reasoning_effort: bool = True,
preferred_provider: BaseProvider | None = None,
target_platform: PlatformConfig | None = None,
disable_cuda_math: bool = False,
):
"""
Initialize the Triton Kernel Agent.
@@ -52,6 +53,7 @@ def __init__(
model_name: OpenAI model to use (loaded from .env if None)
high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
target_platform: Target platform (PlatformConfig)
disable_cuda_math: If True, disables cuSolver/cuBLAS library usage
"""
# Load environment variables
load_dotenv()
@@ -87,6 +89,7 @@ def __init__(
self._platform_config = (
target_platform if target_platform else get_platform("cuda")
)
self.disable_cuda_math = disable_cuda_math

# Setup main logger
self._setup_logging()
@@ -103,6 +106,7 @@ def __init__(
openai_model=self.model_name,
high_reasoning_effort=self.high_reasoning_effort,
target_platform=self._platform_config.name,
disable_cuda_math=self.disable_cuda_math,
)

def _setup_logging(self):
@@ -339,7 +343,9 @@ def _generate_kernel_seeds(

# Create prompt with Triton guidelines using template
prompt = self.prompt_manager.render_kernel_generation_prompt(
problem_description=problem_description,
test_code=test_code,
disable_cuda_math=self.disable_cuda_math,
)

kernels = []
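A hedged construction sketch using only parameters visible in this diff (the import path is assumed from the file layout; other constructor arguments keep their defaults):

from triton_kernel_agent.agent import TritonKernelAgent

# The flag is stored on the agent, forwarded to its WorkerManager, and
# injected into the seed-generation prompt built in _generate_kernel_seeds.
agent = TritonKernelAgent(
    model_name="gpt-5",
    high_reasoning_effort=True,
    disable_cuda_math=True,
)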
6 changes: 6 additions & 0 deletions triton_kernel_agent/manager.py
@@ -38,6 +38,7 @@ def __init__(
openai_model: str = "gpt-5",
high_reasoning_effort: bool = True,
target_platform: str = "cuda",
disable_cuda_math: bool = False,
):
"""
Initialize the worker manager.
@@ -51,6 +52,7 @@ def __init__(
openai_model: OpenAI model name
high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
target_platform: Target platform ('cuda' or 'xpu')
disable_cuda_math: If True, disables cuSolver/cuBLAS library usage
"""
self.num_workers = num_workers
self.max_rounds = max_rounds
@@ -59,6 +61,7 @@ def __init__(
self.openai_model = openai_model
self.high_reasoning_effort = high_reasoning_effort
self.target_platform = target_platform
self.disable_cuda_math = disable_cuda_math

# Setup logging
if log_dir is None:
@@ -164,6 +167,7 @@ def run_verification(
self.openai_model,
self.high_reasoning_effort,
self.target_platform,
self.disable_cuda_math,
)

process = mp.Process(target=worker_process, args=args)
@@ -228,6 +232,7 @@ def worker_process(
openai_model: str,
high_reasoning_effort: bool,
target_platform: str,
disable_cuda_math: bool = False,
):
"""
Worker process for kernel verification and refinement.
@@ -247,6 +252,7 @@ def worker_process(
openai_model=openai_model,
high_reasoning_effort=high_reasoning_effort,
target_platform=target_platform,
disable_cuda_math=disable_cuda_math,
)

result = worker.run(
6 changes: 6 additions & 0 deletions triton_kernel_agent/prompt_manager.py
@@ -124,6 +124,7 @@ def render_kernel_generation_prompt(
problem_description: str,
test_code: str,
triton_guidelines: str | None = None,
disable_cuda_math: bool = False,
) -> str:
"""
Render the kernel generation prompt.
@@ -132,6 +133,7 @@
problem_description: Description of the kernel to generate
test_code: Test code that the kernel must pass
triton_guidelines: Optional guidelines (if None, loads from template)
disable_cuda_math: If True, disables cuSolver/cuBLAS library usage

Returns:
Rendered prompt string
@@ -147,6 +149,7 @@
test_code=test_code,
triton_guidelines=triton_guidelines,
kernel_guidance=self.target_platform.kernel_guidance,
disable_cuda_math=disable_cuda_math,
)

def render_kernel_refinement_prompt(
@@ -157,6 +160,7 @@
error_info: dict[str, str],
history_context: str | None = None,
triton_guidelines: str | None = None,
disable_cuda_math: bool = False,
) -> str:
"""
Render the kernel refinement prompt.
@@ -168,6 +172,7 @@
error_info: Dictionary with error information (stdout, stderr)
history_context: Optional context from previous attempts
triton_guidelines: Optional guidelines (if None, loads from template)
disable_cuda_math: If True, disables cuSolver/cuBLAS library usage

Returns:
Rendered prompt string
@@ -186,6 +191,7 @@
history_context=history_context,
triton_guidelines=triton_guidelines,
kernel_guidance=self.target_platform.kernel_guidance,
disable_cuda_math=disable_cuda_math,
)

def render_triton_guidelines(self) -> str:
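A rendering sketch against the signature above. The PromptManager constructor is not shown in this diff, so pm stands for an already-constructed instance, and the problem/test strings are illustrative:

# With disable_cuda_math=True the template emits the extra cuSolver/cuBLAS
# prohibition added to kernel_generation.j2 below.
prompt = pm.render_kernel_generation_prompt(
    problem_description="Elementwise vector add",
    test_code="# test body elided",
    disable_cuda_math=True,
)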
13 changes: 8 additions & 5 deletions triton_kernel_agent/templates/kernel_generation.j2
@@ -55,11 +55,14 @@ CRITICAL REQUIREMENTS - NO CHEATING ALLOWED:
- The actual computation MUST be implemented using Triton operations (tl.load, tl.store, tl.sum, etc.)
- The wrapper function can use PyTorch for tensor creation, memory allocation, and result formatting, but the core computation MUST happen in the Triton kernel
- DO NOT import or instantiate torch.nn modules, call torch.nn.functional (including aliases like F.*), or use PyTorch activations/pooling helpers to satisfy the requirements
{% if disable_cuda_math %}
- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements
{% endif %}

**WHAT IS REQUIRED:**
- Write actual Triton kernel code using triton.language operations
- Use tl.load() and tl.store() for memory access
- Use tl.sum(), tl.max(), tl.min() etc. for reductions
- Use proper Triton indexing with tl.program_id(), tl.arange(), etc.
- Implement the algorithm step-by-step in Triton operations, not PyTorch shortcuts

@@ -86,14 +89,14 @@ def _actual_kernel(ptr_a, ptr_b, ptr_out, n_elements, BLOCK_SIZE: tl.constexpr):
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

# Load using Triton
a = tl.load(ptr_a + offsets, mask=mask)
b = tl.load(ptr_b + offsets, mask=mask)

# Compute using Triton operations (NOT torch.add!)
result = a + b

# Store using Triton
tl.store(ptr_out + offsets, result, mask=mask)

@@ -102,7 +105,7 @@ def kernel_function(tensor_a, tensor_b):
# PyTorch operations allowed here for setup
output = torch.empty_like(tensor_a)
n_elements = tensor_a.numel()

# Calculate grid and launch Triton kernel
BLOCK_SIZE = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
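The new guard is an ordinary Jinja2 conditional; a self-contained sketch of its behavior:

from jinja2 import Template

# The prohibition line is emitted only when the flag is truthy; when it is
# absent or False the guard renders to nothing.
guard = Template(
    "{% if disable_cuda_math %}"
    "- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements"
    "{% endif %}"
)
assert "cuSolver" in guard.render(disable_cuda_math=True)
assert guard.render(disable_cuda_math=False) == ""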
5 changes: 4 additions & 1 deletion triton_kernel_agent/templates/kernel_refinement.j2
@@ -65,11 +65,14 @@ CRITICAL REQUIREMENTS - NO CHEATING ALLOWED:
- The actual computation MUST be implemented using Triton operations (tl.load, tl.store, tl.sum, etc.)
- The wrapper function can use PyTorch for tensor creation, memory allocation, and result formatting, but the core computation MUST happen in the Triton kernel
- DO NOT import or instantiate torch.nn modules, call torch.nn.functional (including aliases like F.*), or use PyTorch activations/pooling helpers to satisfy the requirements
{% if disable_cuda_math %}
- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements
{% endif %}

**WHAT IS REQUIRED:**
- Write actual Triton kernel code using triton.language operations
- Use tl.load() and tl.store() for memory access
- Use tl.sum(), tl.max(), tl.min() etc. for reductions
- Use proper Triton indexing with tl.program_id(), tl.arange(), etc.
- Implement the algorithm step-by-step in Triton operations, not PyTorch shortcuts
- The kernel should be decorated with @triton.jit, not the wrapper function
4 changes: 4 additions & 0 deletions triton_kernel_agent/worker.py
@@ -132,6 +132,7 @@ def __init__(
openai_model: str = "gpt-5",
high_reasoning_effort: bool = True,
target_platform: str = "cuda",
disable_cuda_math: bool = False,
):
"""
Initialize a verification worker.
@@ -146,6 +147,7 @@ def __init__(
openai_model: Model name for refinement
high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
target_platform: Target platform (default: cuda)
disable_cuda_math: If True, disables cuSolver/cuBLAS library usage
"""
self.worker_id = worker_id
self.workdir = Path(workdir)
@@ -155,6 +157,7 @@ def __init__(
self.openai_model = openai_model
self.high_reasoning_effort = high_reasoning_effort
self._platform_config = get_platform(target_platform)
self.disable_cuda_math = disable_cuda_math

# Setup files
self.kernel_file = self.workdir / "kernel.py"
@@ -363,6 +366,7 @@ def _refine_kernel(
kernel_code=kernel_code,
error_info=error_info,
history_context=history_context,
disable_cuda_math=self.disable_cuda_math,
)

# Call LLM API