
Conversation

@kaiming-cheng
Contributor

This PR adds a new Jinja2 template for bottleneck-guided kernel optimization. The bottleneck diagnosis module was introduced in a previous PR.

The kernel optimization prompt is integrated into the existing PromptManager class. The current template specifies the optimization goal (a 1.25x speedup over PyTorch eager) and the format for the optimization strategy.

Usage

from triton_kernel_agent.prompt_manager import PromptManager

pm = PromptManager()
prompt = pm.render_kernel_optimization_prompt(
    kernel_code=code,
    problem_description=desc,
    bottleneck_analysis=bottleneck,
    bottleneck_id=1,  # or 2 for secondary bottleneck
    gpu_specs=specs,
    pytorch_baseline_ms=baseline_ms,
)
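For orientation, here is a minimal sketch of how a Jinja2-backed render method like this could be wired up inside PromptManager. The class and method names come from the usage above; the templates/ directory and the kernel_optimization.j2 filename are illustrative assumptions, not necessarily the layout used in this PR.

    # Hypothetical sketch of a Jinja2-backed PromptManager; the template
    # directory and file name below are assumptions for illustration only.
    from pathlib import Path
    from jinja2 import Environment, FileSystemLoader, StrictUndefined

    class PromptManager:
        def __init__(self, template_dir: Path = Path(__file__).parent / "templates"):
            # StrictUndefined makes a missing template variable raise an error
            # instead of silently rendering an empty string.
            self._env = Environment(
                loader=FileSystemLoader(str(template_dir)),
                undefined=StrictUndefined,
            )

        def render_kernel_optimization_prompt(self, **kwargs) -> str:
            # "kernel_optimization.j2" is an assumed file name.
            template = self._env.get_template("kernel_optimization.j2")
            return template.render(**kwargs)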

@meta-cla bot added the CLA Signed label on Jan 13, 2026
@kaiming-cheng kaiming-cheng changed the title Add Kernel Optimization Template to PromptManager [Optimization 4/n] Add Kernel Optimization Template to PromptManager Jan 13, 2026
Kaiming Cheng added 23 commits January 15, 2026 11:44
Consolidates previous kernel_benchmark.py and pytorch_benchmark.py into a
streamlined 3-file architecture with clear separation of concerns:

Architecture:
- benchmark.py (299 lines): Main Benchmark class with simplified API
  - benchmark_kernel(): Always uses subprocess for crash protection
  - benchmark_pytorch(): Always uses direct mode for stable code
  - BenchmarkLockManager: GPU lock management for multi-worker scenarios

- timing.py (437 lines): Complete timing infrastructure (see the sketch after this list)
  - Timing: time_with_cuda_events(), time_with_triton_do_bench()
  - Loading: prepare_pytorch_model(), load_kernel_function()
  - Stats: compute_timing_stats() with essential metrics (mean/std/min/max)

- kernel_subprocess.py (442 lines): Subprocess runner for kernel isolation (see the sketch after the Simple API example)
  - Crash protection for potentially buggy kernels
  - Clean CUDA state between runs
  - Timeout handling
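As a rough sketch of what CUDA-event timing plus the mean/std/min/max stats could look like: the function names below mirror the ones listed for timing.py, but the bodies are assumptions, not the code in this commit.

    # Illustrative sketch only: CUDA-event timing and basic stats, assuming the
    # signatures loosely match time_with_cuda_events() / compute_timing_stats().
    import statistics
    import torch

    def time_with_cuda_events(fn, warmup: int = 10, iters: int = 100) -> list[float]:
        """Return per-iteration runtimes in milliseconds using CUDA events."""
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()

        times_ms = []
        for _ in range(iters):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times_ms.append(start.elapsed_time(end))  # elapsed_time() returns ms
        return times_ms

    def compute_timing_stats(times_ms: list[float]) -> dict:
        """Essential metrics only (mean/std/min/max), as described above."""
        return {
            "mean": statistics.fmean(times_ms),
            "std": statistics.stdev(times_ms) if len(times_ms) > 1 else 0.0,
            "min": min(times_ms),
            "max": max(times_ms),
        }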

Key improvements:
- Eliminated string code generation (was generating Python as strings)
- Removed unnecessary statistics (median, p25/p75/p95/p99)
- Removed confusing use_subprocess parameter (behavior now deterministic)
- Fixed dtype bug causing incorrect speedup measurements
- Reduced from 5 files to 3 files with clearer naming
- Code reduction: ~1,400 lines → 1,178 lines

Simple API:
  bench = Benchmark(logger, temp_dir, lock, worker_id)
  pytorch_result = bench.benchmark_pytorch(problem_file)
  kernel_result = bench.benchmark_kernel(kernel_file, problem_file)
  speedup = pytorch_result['stats']['mean'] / kernel_result['time_ms']
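The subprocess isolation path behind benchmark_kernel() could, in rough outline, look like the following. The helper name, JSON protocol, and result format are assumptions for illustration; only the idea (crash protection, fresh CUDA state per run, hard timeout) comes from the commit message above.

    # Rough sketch of subprocess-based kernel benchmarking: crash protection via
    # a separate process, a clean CUDA context per run, and a hard timeout.
    # The helper name and JSON protocol here are assumptions, not this PR's code.
    import json
    import subprocess
    import sys

    def run_kernel_benchmark_isolated(kernel_file: str, problem_file: str,
                                      timeout_s: float = 120.0) -> dict:
        # A small runner script (kernel_subprocess.py in this PR) would load the
        # kernel, time it, and print a JSON result on stdout.
        cmd = [sys.executable, "kernel_subprocess.py", kernel_file, problem_file]
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True,
                                  timeout=timeout_s)
        except subprocess.TimeoutExpired:
            return {"success": False, "error": "timeout"}

        if proc.returncode != 0:
            # A segfault or CUDA crash in the kernel only kills the child process.
            return {"success": False, "error": proc.stderr.strip()}
        return {"success": True, **json.loads(proc.stdout)}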
@kaiming-cheng kaiming-cheng force-pushed the kaiming/opt_component_4_clean branch from 2937faa to d5e6edc Compare January 15, 2026 19:49