From a06573e1ffc0109acdbd7f6d7872448a72e6e37f Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 22 Jan 2026 15:52:05 -0800 Subject: [PATCH 1/2] Disable cuSolver --- triton_kernel_agent/templates/kernel_generation.j2 | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/triton_kernel_agent/templates/kernel_generation.j2 b/triton_kernel_agent/templates/kernel_generation.j2 index 9764af2..932bf89 100644 --- a/triton_kernel_agent/templates/kernel_generation.j2 +++ b/triton_kernel_agent/templates/kernel_generation.j2 @@ -55,11 +55,12 @@ CRITICAL REQUIREMENTS - NO CHEATING ALLOWED: - The actual computation MUST be implemented using Triton operations (tl.load, tl.store, tl.sum, etc.) - The wrapper function can use PyTorch for tensor creation, memory allocation, and result formatting, but the core computation MUST happen in the Triton kernel - DO NOT import or instantiate torch.nn modules, call torch.nn.functional (including aliases like F.*), or use PyTorch activations/pooling helpers to satisfy the requirements +- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements **WHAT IS REQUIRED:** - Write actual Triton kernel code using triton.language operations - Use tl.load() and tl.store() for memory access -- Use tl.sum(), tl.max(), tl.min() etc. for reductions +- Use tl.sum(), tl.max(), tl.min() etc. for reductions - Use proper Triton indexing with tl.program_id(), tl.arange(), etc. - Implement the algorithm step-by-step in Triton operations, not PyTorch shortcuts @@ -86,14 +87,14 @@ def _actual_kernel(ptr_a, ptr_b, ptr_out, n_elements, BLOCK_SIZE: tl.constexpr): block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - + # Load using Triton a = tl.load(ptr_a + offsets, mask=mask) b = tl.load(ptr_b + offsets, mask=mask) - + # Compute using Triton operations (NOT torch.add!) 
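    # (the '+' below operates on tl blocks and compiles to a Triton
    # elementwise op; no PyTorch, cuBLAS, or cuSolver routine is involved)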
result = a + b - + # Store using Triton tl.store(ptr_out + offsets, result, mask=mask) @@ -102,7 +103,7 @@ def kernel_function(tensor_a, tensor_b): # PyTorch operations allowed here for setup output = torch.empty_like(tensor_a) n_elements = tensor_a.numel() - + # Calculate grid and launch Triton kernel BLOCK_SIZE = 1024 grid = (triton.cdiv(n_elements, BLOCK_SIZE),) From 402aabfb388037159c5ef849e853656e8f7d9149 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Fri, 23 Jan 2026 15:46:22 -0800 Subject: [PATCH 2/2] Wire the option to disable cuSolver --- Fuser/auto_agent.py | 9 +++++++++ Fuser/dispatch_kernel_agent.py | 8 ++++++++ triton_kernel_agent/agent.py | 8 +++++++- triton_kernel_agent/manager.py | 6 ++++++ triton_kernel_agent/prompt_manager.py | 6 ++++++ triton_kernel_agent/templates/kernel_generation.j2 | 2 ++ triton_kernel_agent/templates/kernel_refinement.j2 | 5 ++++- triton_kernel_agent/worker.py | 4 ++++ 8 files changed, 46 insertions(+), 2 deletions(-) diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index dadaa8f..faee5dc 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -330,6 +330,7 @@ def __init__( dispatch_jobs: int = 2, allow_fallback: bool = True, target_platform: str | None = None, + disable_cuda_math: bool = False, ) -> None: self.ka_model = ka_model self.ka_num_workers = ka_num_workers @@ -352,6 +353,7 @@ def __init__( self.dispatch_jobs = dispatch_jobs self.allow_fallback = allow_fallback self.platform_config = get_platform(target_platform) + self.disable_cuda_math = disable_cuda_math def _solve_with_kernelagent(self, problem_code: str) -> RouteResult: agent = TritonKernelAgent( @@ -360,6 +362,7 @@ def _solve_with_kernelagent(self, problem_code: str) -> RouteResult: model_name=self.ka_model, high_reasoning_effort=self.ka_high_reasoning, target_platform=self.platform_config, + disable_cuda_math=self.disable_cuda_math, ) try: # Ensure exceptions in KernelAgent do not abort routing; return a structured failure @@ -710,6 +713,11 @@ def main(argv: list[str] | None = None) -> int: choices=get_platform_choices(), help="Target platform (default: cuda)", ) + p.add_argument( + "--disable-cuda-math", + action="store_true", + help="Disable cuSolver/cuBLAS library usage in generated kernels", + ) args = p.parse_args(argv) # Load environment variables from .env file @@ -741,6 +749,7 @@ def main(argv: list[str] | None = None) -> int: dispatch_jobs=args.dispatch_jobs, allow_fallback=(not args.no_fallback), target_platform=args.target_platform, + disable_cuda_math=args.disable_cuda_math, ) try: diff --git a/Fuser/dispatch_kernel_agent.py b/Fuser/dispatch_kernel_agent.py index 6e233ff..9c7132a 100644 --- a/Fuser/dispatch_kernel_agent.py +++ b/Fuser/dispatch_kernel_agent.py @@ -333,6 +333,7 @@ def run( jobs: int = 1, target_platform: str = "cuda", max_iters: int = 10, + disable_cuda_math: bool = False, ) -> Path: """Dispatch subgraphs to KernelAgent with optional parallelism. 
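
    When disable_cuda_math is True, each dispatched TritonKernelAgent receives
    the flag, so its generation/refinement prompts forbid cuSolver/cuBLAS use.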
@@ -369,6 +370,7 @@ def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, An max_rounds=max_iters, model_name=agent_model, target_platform=platform, + disable_cuda_math=disable_cuda_math, ) try: result = local_agent.generate_kernel( @@ -457,6 +459,11 @@ def main(argv: list[str] | None = None) -> int: choices=get_platform_choices(), help="Target platform (default: cuda)", ) + p.add_argument( + "--disable-cuda-math", + action="store_true", + help="Disable cuSolver/cuBLAS library usage in generated kernels", + ) args = p.parse_args(argv) subgraphs_path = Path(args.subgraphs).resolve() @@ -485,6 +492,7 @@ def main(argv: list[str] | None = None) -> int: agent_model=args.agent_model, jobs=jobs_val, target_platform=args.target_platform, + disable_cuda_math=args.disable_cuda_math, ) print(str(summary_path)) return 0 diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index a85dbe4..5572a36 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -41,6 +41,7 @@ def __init__( high_reasoning_effort: bool = True, preferred_provider: BaseProvider | None = None, target_platform: PlatformConfig | None = None, + disable_cuda_math: bool = False, ): """ Initialize the Triton Kernel Agent. @@ -52,6 +53,7 @@ def __init__( model_name: OpenAI model to use (loaded from .env if None) high_reasoning_effort: Whether to use high reasoning effort for OpenAI models target_platform: Target platform PlatformConfig + disable_cuda_math: If True, disables cuSolver/cuBLAS library usage """ # Load environment variables load_dotenv() @@ -87,6 +89,7 @@ def __init__( self._platform_config = ( target_platform if target_platform else get_platform("cuda") ) + self.disable_cuda_math = disable_cuda_math # Setup main logger self._setup_logging() @@ -103,6 +106,7 @@ def __init__( openai_model=self.model_name, high_reasoning_effort=self.high_reasoning_effort, target_platform=self._platform_config.name, + disable_cuda_math=self.disable_cuda_math, ) def _setup_logging(self): @@ -339,7 +343,9 @@ def _generate_kernel_seeds( # Create prompt with Triton guidelines using template prompt = self.prompt_manager.render_kernel_generation_prompt( - problem_description=problem_description, test_code=test_code + problem_description=problem_description, + test_code=test_code, + disable_cuda_math=self.disable_cuda_math, ) kernels = [] diff --git a/triton_kernel_agent/manager.py b/triton_kernel_agent/manager.py index 2b191d9..940057a 100644 --- a/triton_kernel_agent/manager.py +++ b/triton_kernel_agent/manager.py @@ -38,6 +38,7 @@ def __init__( openai_model: str = "gpt-5", high_reasoning_effort: bool = True, target_platform: str = "cuda", + disable_cuda_math: bool = False, ): """ Initialize the worker manager. 
@@ -51,6 +52,7 @@ def __init__( openai_model: OpenAI model name high_reasoning_effort: Whether to use high reasoning effort for OpenAI models target_platform: Target platform ('cuda' or 'xpu') + disable_cuda_math: If True, disables cuSolver/cuBLAS library usage """ self.num_workers = num_workers self.max_rounds = max_rounds @@ -59,6 +61,7 @@ def __init__( self.openai_model = openai_model self.high_reasoning_effort = high_reasoning_effort self.target_platform = target_platform + self.disable_cuda_math = disable_cuda_math # Setup logging if log_dir is None: @@ -164,6 +167,7 @@ def run_verification( self.openai_model, self.high_reasoning_effort, self.target_platform, + self.disable_cuda_math, ) process = mp.Process(target=worker_process, args=args) @@ -228,6 +232,7 @@ def worker_process( openai_model: str, high_reasoning_effort: bool, target_platform: str, + disable_cuda_math: bool = False, ): """ Worker process for kernel verification and refinement. @@ -247,6 +252,7 @@ def worker_process( openai_model=openai_model, high_reasoning_effort=high_reasoning_effort, target_platform=target_platform, + disable_cuda_math=disable_cuda_math, ) result = worker.run( diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py index 7c22009..c56d1f9 100644 --- a/triton_kernel_agent/prompt_manager.py +++ b/triton_kernel_agent/prompt_manager.py @@ -124,6 +124,7 @@ def render_kernel_generation_prompt( problem_description: str, test_code: str, triton_guidelines: str | None = None, + disable_cuda_math: bool = False, ) -> str: """ Render the kernel generation prompt. @@ -132,6 +133,7 @@ def render_kernel_generation_prompt( problem_description: Description of the kernel to generate test_code: Test code that the kernel must pass triton_guidelines: Optional guidelines (if None, loads from template) + disable_cuda_math: If True, disables cuSolver/cuBLAS library usage Returns: Rendered prompt string @@ -147,6 +149,7 @@ def render_kernel_generation_prompt( test_code=test_code, triton_guidelines=triton_guidelines, kernel_guidance=self.target_platform.kernel_guidance, + disable_cuda_math=disable_cuda_math, ) def render_kernel_refinement_prompt( @@ -157,6 +160,7 @@ def render_kernel_refinement_prompt( error_info: dict[str, str], history_context: str | None = None, triton_guidelines: str | None = None, + disable_cuda_math: bool = False, ) -> str: """ Render the kernel refinement prompt. @@ -168,6 +172,7 @@ def render_kernel_refinement_prompt( error_info: Dictionary with error information (stdout, stderr) history_context: Optional context from previous attempts triton_guidelines: Optional guidelines (if None, loads from template) + disable_cuda_math: If True, disables cuSolver/cuBLAS library usage Returns: Rendered prompt string @@ -186,6 +191,7 @@ def render_kernel_refinement_prompt( history_context=history_context, triton_guidelines=triton_guidelines, kernel_guidance=self.target_platform.kernel_guidance, + disable_cuda_math=disable_cuda_math, ) def render_triton_guidelines(self) -> str: diff --git a/triton_kernel_agent/templates/kernel_generation.j2 b/triton_kernel_agent/templates/kernel_generation.j2 index 932bf89..68a0f09 100644 --- a/triton_kernel_agent/templates/kernel_generation.j2 +++ b/triton_kernel_agent/templates/kernel_generation.j2 @@ -55,7 +55,9 @@ CRITICAL REQUIREMENTS - NO CHEATING ALLOWED: - The actual computation MUST be implemented using Triton operations (tl.load, tl.store, tl.sum, etc.) 
- The wrapper function can use PyTorch for tensor creation, memory allocation, and result formatting, but the core computation MUST happen in the Triton kernel
- DO NOT import or instantiate torch.nn modules, call torch.nn.functional (including aliases like F.*), or use PyTorch activations/pooling helpers to satisfy the requirements
+{% if disable_cuda_math %}
- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements
+{% endif %}

**WHAT IS REQUIRED:**
- Write actual Triton kernel code using triton.language operations
diff --git a/triton_kernel_agent/templates/kernel_refinement.j2 b/triton_kernel_agent/templates/kernel_refinement.j2
index 3cd4f3e..463b15b 100644
--- a/triton_kernel_agent/templates/kernel_refinement.j2
+++ b/triton_kernel_agent/templates/kernel_refinement.j2
@@ -65,11 +65,14 @@ CRITICAL REQUIREMENTS - NO CHEATING ALLOWED:
- The actual computation MUST be implemented using Triton operations (tl.load, tl.store, tl.sum, etc.)
- The wrapper function can use PyTorch for tensor creation, memory allocation, and result formatting, but the core computation MUST happen in the Triton kernel
- DO NOT import or instantiate torch.nn modules, call torch.nn.functional (including aliases like F.*), or use PyTorch activations/pooling helpers to satisfy the requirements
+{% if disable_cuda_math %}
+- DO NOT rely on cuSolver/cuBLAS libraries to satisfy the requirements
+{% endif %}

**WHAT IS REQUIRED:**
- Write actual Triton kernel code using triton.language operations
- Use tl.load() and tl.store() for memory access
-- Use tl.sum(), tl.max(), tl.min() etc. for reductions
+- Use tl.sum(), tl.max(), tl.min() etc. for reductions
- Use proper Triton indexing with tl.program_id(), tl.arange(), etc.
- Implement the algorithm step-by-step in Triton operations, not PyTorch shortcuts
- The kernel should be decorated with @triton.jit, not the wrapper function
diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py
index ef3b73c..76fcd49 100644
--- a/triton_kernel_agent/worker.py
+++ b/triton_kernel_agent/worker.py
@@ -132,6 +132,7 @@ def __init__(
        openai_model: str = "gpt-5",
        high_reasoning_effort: bool = True,
        target_platform: str = "cuda",
+        disable_cuda_math: bool = False,
    ):
        """
        Initialize a verification worker.
@@ -146,6 +147,7 @@
            openai_model: Model name for refinement
            high_reasoning_effort: Whether to use high reasoning effort for OpenAI models
            target_platform: Target platform (default: cuda)
+            disable_cuda_math: If True, disables cuSolver/cuBLAS library usage
        """
        self.worker_id = worker_id
        self.workdir = Path(workdir)
@@ -155,6 +157,7 @@
        self.openai_model = openai_model
        self.high_reasoning_effort = high_reasoning_effort
        self._platform_config = get_platform(target_platform)
+        self.disable_cuda_math = disable_cuda_math

        # Setup files
        self.kernel_file = self.workdir / "kernel.py"
@@ -363,6 +366,7 @@
            kernel_code=kernel_code,
            error_info=error_info,
            history_context=history_context,
+            disable_cuda_math=self.disable_cuda_math,
        )

        # Call LLM API
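
---

Usage sketch: exercising the new option end to end. The CLI flags below are
taken from the diff; the Python import path and the constructor values are
assumptions for illustration, not confirmed by this patch.

    # CLI: both entry points gained the same flag
    #   python Fuser/auto_agent.py --disable-cuda-math ...
    #   python Fuser/dispatch_kernel_agent.py --disable-cuda-math ...

    # Python: assumed import path; num_workers/model_name values are illustrative
    from triton_kernel_agent.agent import TritonKernelAgent

    agent = TritonKernelAgent(
        num_workers=4,
        model_name="gpt-5",
        disable_cuda_math=True,  # prompts now include the cuSolver/cuBLAS prohibition
    )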