From 07a3268dd08dac21f70620bdc07ddbd5ccf9328a Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 11:46:28 -0800 Subject: [PATCH 01/22] NCU profiling wrapper generation and execution --- .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 +++++++++ .../kernel_opt/profiler/ncu_profiler.py | 449 ++++++++++++++++++ .../profiling/__init__.py | 6 + .../profiling/kernel_profiler.py | 89 ++++ .../profiling/ncu_wrapper_generator.py | 154 ++++++ 6 files changed, 925 insertions(+) create mode 100644 kernel_perf_agent/kernel_opt/profiler/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py new file mode 100644 index 0000000..ed8b50f --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import re +import subprocess +from typing import Any, Dict, Optional + + +# GPU specifications database +# Sources: NVIDIA official specifications, manufacturer datasheets +GPU_SPECS_DATABASE = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} + + +def query_gpu_name() -> Optional[str]: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + gpu_name = result.stdout.strip() + # Normalize whitespace + gpu_name = re.sub(r"\s+", " ", gpu_name) + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: + """ + Format GPU specifications as a readable string for LLM prompts. + + Args: + specs: GPU specifications dictionary from get_gpu_specs() + + Returns: + Formatted string with GPU specifications + + Example: + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(format_gpu_specs_for_prompt(specs)) + GPU: NVIDIA A100 (Ampere) + - Peak Memory Bandwidth: 1555 GB/s + - Peak FP32 Performance: 19.5 TFLOPS + - Peak FP16 Performance: 312.0 TFLOPS + - SM Count: 108 + - L1 Cache: 192 KB per SM + - L2 Cache: 40 MB + """ + lines = [ + f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", + f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", + f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", + ] + + if specs.get("peak_fp16_tflops", 0) > 0: + lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") + + if specs.get("peak_bf16_tflops", 0) > 0: + lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") + + lines.extend( + [ + f"- SM Count: {specs['sm_count']}", + f"- Max Threads per SM: {specs['max_threads_per_sm']}", + f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", + f"- L2 Cache: {specs['l2_cache_mb']} MB", + f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", + ] + ) + + return "\n".join(lines) + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + + # Get specs + specs = get_gpu_specs() + print(f"\n{format_gpu_specs_for_prompt(specs)}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..26e59d2 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,449 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + + +__all__ = [ + "METRICS", + "METRIC_COLUMNS", + "profile_triton_kernel", + "load_ncu_metrics", + "metrics_to_prompt", +] + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# METRICS = ",".join( +# [ +# # Tensor Core (MOST CRITICAL for GEMM) +# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", +# # Memory Bandwidth & Cache +# "dram__throughput.avg.pct_of_peak_sustained_elapsed", +# "l1tex__t_sector_hit_rate.pct", +# "lts__t_sector_hit_rate.pct", +# # Memory Access Efficiency +# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", +# # Stalls & Bottlenecks (H100-compatible stall metrics) +# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", +# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", +# # Occupancy & Resources +# "launch__shared_mem_per_block_allocated", +# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", +# # SM Utilization +# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", +# ] +# ) + +# # CRITICAL GEMM Optimization Metrics (Optimized for Matrix Multiplication) +# METRICS = ",".join( +# [ +# # === Memory Efficiency Metrics (CRITICAL) === +# "dram__throughput.avg.pct_of_peak_sustained_elapsed", # DRAM bandwidth utilization +# "dram__bytes_read.sum", # Total DRAM reads +# "dram__bytes_write.sum", # Total DRAM writes +# "l1tex__t_sector_hit_rate.pct", # L1 cache hit rate (CRITICAL for data reuse) +# "l1tex__throughput.avg.pct_of_peak_sustained_elapsed", # L1 throughput (fixed from _active) +# "lts__t_sector_hit_rate.pct", # L2 cache hit rate +# "lts__throughput.avg.pct_of_peak_sustained_elapsed", # L2 throughput (fixed from _active) +# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", # Memory coalescing (NEW - CRITICAL) +# "l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum", # Shared memory loads (NEW) +# # === Compute Utilization Metrics (CRITICAL for GEMM) === +# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", # Tensor core pipeline (NEW - CRITICAL) +# "smsp__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", # Tensor core inst % +# "smsp__inst_executed_pipe_tensor.sum", # Total tensor core instructions (NEW) +# "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", # FP32 pipeline +# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", # SM active cycles (fixed from .avg only) +# "sm__inst_executed.sum", # Total instructions +# "smsp__inst_executed.avg.per_cycle_active", # IPC (NEW) +# # === Occupancy & Resource Metrics === +# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", # Warp occupancy (fixed from _active) +# "sm__maximum_warps_per_active_cycle_pct", # Max achievable occupancy (NEW) +# "launch__occupancy_limit_blocks", # Block limit +# "launch__occupancy_limit_registers", # Register limit +# "launch__occupancy_limit_shared_mem", # Shared memory limit +# "launch__registers_per_thread", # Registers per thread +# "launch__shared_mem_per_block_allocated", # Shared memory usage (NEW - CRITICAL) +# # === Stall & Latency Metrics === +# "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", # Memory dependency stalls +# "smsp__warp_issue_stalled_mem_throttle_per_warp_active.pct", # Memory throttle stalls (NEW - CRITICAL) +# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", # Short scoreboard stalls +# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", # Long scoreboard stalls +# "smsp__warp_issue_stalled_barrier_per_warp_active.pct", # Barrier stalls +# "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", # Branch stalls +# ] +# ) + + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [ + "sudo", + "-E", + f"--preserve-env={preserve}", + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + + print(f"[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print(f"[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: str = "last", +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name: + "first", "last", "max_cycles" + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = sub[ + sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + if len(matched) > 1: + if select == "first": + row = matched.iloc[[0]] + elif select == "last": + row = matched.iloc[[-1]] + elif ( + select == "max_cycles" + and "sm__cycles_active.avg" in matched.columns + ): + row = matched.sort_values( + "sm__cycles_active.avg", ascending=False + ).head(1) + else: + row = matched.iloc[[-1]] # fallback + else: + row = matched + results.append(row) + + if results: + sub = pd.concat(results, ignore_index=True) + else: + sub = pd.DataFrame(columns=keep_cols) + elif select in ("first", "last", "max_cycles"): + # Apply selection to all rows if no name filter + if len(sub) > 0: + if select == "first": + sub = sub.iloc[[0]] + elif select == "last": + sub = sub.iloc[[-1]] + elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: + sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py new file mode 100644 index 0000000..cb28cc6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -0,0 +1,6 @@ +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .ncu_wrapper_generator import NCUWrapperGenerator +from .kernel_profiler import KernelProfiler + +__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py new file mode 100644 index 0000000..7ba98ef --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -0,0 +1,89 @@ +"""Kernel profiling with NCU.""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from .ncu_wrapper_generator import NCUWrapperGenerator + + +class KernelProfiler: + """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + + def __init__( + self, + logger: logging.Logger, + temp_dir: Path, + logs_dir: Path, + ncu_bin_path: Optional[str] = None, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + temp_dir: Temporary directory for profiling artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + """ + self.logger = logger + self.temp_dir = temp_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.wrapper_generator = NCUWrapperGenerator(logger) + + def profile_kernel( + self, kernel_file: Path, problem_file: Path, round_num: int + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + """ + Profile kernel with NCU. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + + Returns: + Tuple of (metrics_df, metrics_json) or (None, None) on failure + """ + try: + # Create NCU wrapper script + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics JSON + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + with open(metrics_file, "w") as f: + json.dump(metrics_json, f, indent=2) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except Exception as e: + self.logger.error(f"❌ NCU profiling failed: {e}") + return None, None diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py new file mode 100644 index 0000000..96f391f --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -0,0 +1,154 @@ +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from pathlib import Path + + +class NCUWrapperGenerator: + """Generates NCU wrapper scripts for profiling Triton kernels.""" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper generator. + + Args: + logger: Logger instance + """ + self.logger = logger + + def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + """ + Create NCU wrapper script for profiling. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + + Returns: + Path to created wrapper script + """ + wrapper_file = output_dir / "ncu_wrapper.py" + + wrapper_content = f'''"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({str(kernel_file.parent)!r})) +sys.path.insert(0, str({str(problem_file.parent)!r})) + +from {kernel_file.stem} import kernel_function +from {problem_file.stem} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {problem_file.stem} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except: + pass + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + model = Model(*init_inputs) if init_inputs else Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {{}} + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) +''' + + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file From 3c4b124ed826e9b9538f07d7fbe42089de882aa4 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:49:50 -0800 Subject: [PATCH 02/22] Refactor profiling components and add kernel_perf_util --- kernel_perf_util/README.md | 8 + .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 0 .../kernel_opt/profiler/ncu_profiler.py | 0 .../profiling/kernel_profiler.py | 190 +++++++++++++--- .../profiling/ncu_wrapper_generator.py | 210 ++++++++---------- .../profiling/ncu_wrapper_template.j2 | 128 +++++++++++ .../opt_worker_components/README.md | 12 + .../profiling/__init__.py | 6 + .../profiling/kernel_profiler.py | 89 ++++++++ .../profiling/ncu_wrapper_generator.py | 154 +++++++++++++ 11 files changed, 645 insertions(+), 152 deletions(-) create mode 100644 kernel_perf_util/README.md rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/__init__.py (100%) rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/gpu_specs.py (100%) rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/ncu_profiler.py (100%) create mode 100644 triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 create mode 100644 triton_kernel_agent/opt_worker_components/README.md create mode 100644 triton_kernel_agent/opt_worker_components/profiling/__init__.py create mode 100644 triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py create mode 100644 triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py diff --git a/kernel_perf_util/README.md b/kernel_perf_util/README.md new file mode 100644 index 0000000..5114402 --- /dev/null +++ b/kernel_perf_util/README.md @@ -0,0 +1,8 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization: +- NCU profiling +- RAG-based optimization database +- Performance analysis tools + +These are **library components** that can be used by any worker. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_util/kernel_opt/profiler/__init__.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/__init__.py rename to kernel_perf_util/kernel_opt/profiler/__init__.py diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_util/kernel_opt/profiler/gpu_specs.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/gpu_specs.py rename to kernel_perf_util/kernel_opt/profiler/gpu_specs.py diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py rename to kernel_perf_util/kernel_opt/profiler/ncu_profiler.py diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 7ba98ef..3fc991a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -2,6 +2,9 @@ import json import logging +import subprocess +import time +from datetime import datetime from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -41,49 +44,170 @@ def __init__( self.ncu_bin_path = ncu_bin_path self.wrapper_generator = NCUWrapperGenerator(logger) + def _get_ncu_version(self) -> Optional[str]: + """ + Get NCU version string. + + Returns: + NCU version string or None if failed + """ + try: + ncu_cmd = self.ncu_bin_path or "ncu" + result = subprocess.run( + [ncu_cmd, "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") + version_line = result.stdout.strip().split('\n')[0] + return version_line + return None + except Exception as e: + self.logger.debug(f"Failed to get NCU version: {e}") + return None + def profile_kernel( - self, kernel_file: Path, problem_file: Path, round_num: int + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 3, ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: """ - Profile kernel with NCU. + Profile kernel with NCU (with retry logic). + + NCU profiling can fail due to GPU contention or transient issues. + This method automatically retries with exponential backoff. Args: kernel_file: Path to kernel file problem_file: Path to problem file round_num: Current optimization round number + max_retries: Maximum number of retry attempts (default: 3) Returns: Tuple of (metrics_df, metrics_json) or (None, None) on failure - """ - try: - # Create NCU wrapper script - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - # Profile with NCU - csv_file = f"ncu_round_{round_num}.csv" - csv_path = profile_triton_kernel( - benchmark_script=wrapper_file, - workdir=self.temp_dir, - out_csv=csv_file, - ncu_bin=self.ncu_bin_path, - launch_count=20, - timeout=120, - ) - - # Load and parse metrics - metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) - - # Save metrics JSON - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - with open(metrics_file, "w") as f: - json.dump(metrics_json, f, indent=2) - - self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json + Example: + >>> profiler = KernelProfiler(logger, temp_dir, logs_dir) + >>> metrics_df, metrics_json = profiler.profile_kernel( + ... Path("kernel.py"), Path("problem.py"), round_num=1 + ... ) + >>> if metrics_json: + ... print(f"DRAM throughput: {metrics_json['dram__throughput']}") + """ + wrapper_file = None + + for attempt in range(1, max_retries + 1): + try: + # Create NCU wrapper script (cached if unchanged) + if wrapper_file is None: + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + self.logger.info( + f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics with metadata + self._save_metrics_with_metadata( + metrics_json, kernel_file, problem_file, round_num + ) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except FileNotFoundError as e: + self.logger.error(f"❌ File not found during profiling: {e}") + return None, None + + except subprocess.TimeoutExpired: + self.logger.error( + f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" + ) + if attempt < max_retries: + wait_time = 2 ** attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + except json.JSONDecodeError as e: + self.logger.error(f"❌ Failed to parse NCU metrics: {e}") + if attempt < max_retries: + wait_time = 2 ** attempt + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + except Exception as e: + self.logger.error( + f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", + exc_info=True + ) + if attempt < max_retries: + wait_time = 2 ** attempt + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + # All retries exhausted + self.logger.error( + f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" + ) + return None, None + + def _save_metrics_with_metadata( + self, + metrics_json: Dict[str, Any], + kernel_file: Path, + problem_file: Path, + round_num: int, + ) -> None: + """ + Save profiling metrics with additional metadata. - except Exception as e: - self.logger.error(f"❌ NCU profiling failed: {e}") - return None, None + Args: + metrics_json: NCU metrics as JSON dict + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + """ + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + + # Build metadata + metadata = { + "metrics": metrics_json, + "metadata": { + "kernel_file": str(kernel_file), + "problem_file": str(problem_file), + "round_num": round_num, + "timestamp": datetime.utcnow().isoformat() + "Z", + "ncu_version": self._get_ncu_version(), + }, + } + + with open(metrics_file, "w") as f: + json.dump(metadata, f, indent=2) + + self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py index 96f391f..9f0e2e5 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -2,11 +2,21 @@ import logging from pathlib import Path +from typing import Optional + +try: + from jinja2 import Template + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False class NCUWrapperGenerator: """Generates NCU wrapper scripts for profiling Triton kernels.""" + # Template file path (relative to this file) + WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" + def __init__(self, logger: logging.Logger): """ Initialize the NCU wrapper generator. @@ -15,140 +25,102 @@ def __init__(self, logger: logging.Logger): logger: Logger instance """ self.logger = logger + self._template_cache: Optional[Template] = None - def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + def _load_template(self) -> Template: + """ + Load the Jinja2 template (cached). + + Returns: + Jinja2 Template object + + Raises: + ImportError: If Jinja2 is not installed + FileNotFoundError: If template file doesn't exist + """ + if self._template_cache is not None: + return self._template_cache + + if not HAS_JINJA2: + raise ImportError( + "Jinja2 is required for wrapper generation. " + "Install it with: pip install jinja2" + ) + + if not self.WRAPPER_TEMPLATE.exists(): + raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") + + self._template_cache = Template(self.WRAPPER_TEMPLATE.read_text()) + return self._template_cache + + def create_ncu_wrapper( + self, + kernel_file: Path, + problem_file: Path, + output_dir: Path, + dtype_inference: bool = True, + model_extraction: bool = True, + use_cache: bool = True, + ) -> Path: """ Create NCU wrapper script for profiling. + The wrapper handles multiple kernel types: + - Standard kernels: kernel_function(*inputs) + - Conv/Linear kernels: Extracts weights from Model + - RMSNorm kernels: Passes init_inputs (features, eps) + Args: kernel_file: Path to kernel file problem_file: Path to problem file output_dir: Directory to write wrapper script + dtype_inference: Enable automatic dtype inference from kernel source (default: True) + model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) + use_cache: Reuse existing wrapper if files haven't changed (default: True) Returns: Path to created wrapper script - """ - wrapper_file = output_dir / "ncu_wrapper.py" - - wrapper_content = f'''"""NCU profiling wrapper.""" -import sys -import torch -import inspect -sys.path.insert(0, str({str(kernel_file.parent)!r})) -sys.path.insert(0, str({str(problem_file.parent)!r})) - -from {kernel_file.stem} import kernel_function -from {problem_file.stem} import get_inputs, get_init_inputs - -# Try to import Model if it exists (for Conv, Linear, etc.) -try: - from {problem_file.stem} import Model - has_model = True -except ImportError: - has_model = False -# Get inputs -inputs = get_inputs() + Raises: + FileNotFoundError: If kernel_file or problem_file doesn't exist + OSError: If output_dir is not writable + """ + # Validate inputs + if not kernel_file.exists(): + raise FileNotFoundError(f"Kernel file not found: {kernel_file}") + if not problem_file.exists(): + raise FileNotFoundError(f"Problem file not found: {problem_file}") -# Get additional initialization inputs (e.g., features, eps for RMSNorm) -init_inputs = get_init_inputs() + # Ensure output directory exists + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) -# Infer required dtype from kernel function signature/docstring -required_dtype = None -try: - # Try to get dtype from kernel function docstring or source - kernel_source = inspect.getsource(kernel_function) - if 'bfloat16' in kernel_source.lower(): - required_dtype = torch.bfloat16 - elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): - required_dtype = torch.float16 - elif 'float32' in kernel_source.lower(): - required_dtype = torch.float32 -except: - pass - -# Prepare inputs: move to CUDA and convert dtype if needed -# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) -cuda_inputs = [] -for inp in inputs: - if isinstance(inp, torch.Tensor): - # Move to CUDA if not already - if not inp.is_cuda: - inp = inp.cuda() - # Convert dtype if required, but ONLY for floating-point tensors - # Preserve integer/bool tensors (e.g., targets for classification) - if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: - inp = inp.to(required_dtype) - cuda_inputs.append(inp) - else: - cuda_inputs.append(inp) - -# Check if this is a conv-like kernel that needs a Model to extract weights -needs_model = False -try: - sig = inspect.signature(kernel_function) - params = list(sig.parameters.keys()) - # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) - if 'weight' in params: - needs_model = True -except: - pass - -if needs_model and has_model and init_inputs: - # Initialize model to extract weight and bias - model = Model(*init_inputs) if init_inputs else Model() - - # Move model to CUDA and convert dtype - model = model.cuda() - if required_dtype is not None: - model = model.to(required_dtype) - - # Extract weight and bias from model - # Check various possible attribute names - weight = None - bias = None - layer = None - for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: - if hasattr(model, attr_name): - layer = getattr(model, attr_name) - if hasattr(layer, 'weight'): - weight = layer.weight - bias = layer.bias if hasattr(layer, 'bias') else None - break - - if weight is not None and layer is not None: - # Build arguments for kernel_function using keyword arguments - # to avoid positional argument misalignment issues - kernel_kwargs = {{}} - - # Add conv/linear-specific parameters if they exist - if hasattr(layer, 'stride'): - stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride - kernel_kwargs['stride'] = stride - if hasattr(layer, 'padding'): - padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding - kernel_kwargs['padding'] = padding - if hasattr(layer, 'dilation'): - dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation - kernel_kwargs['dilation'] = dilation - if hasattr(layer, 'groups'): - kernel_kwargs['groups'] = layer.groups - - # Call kernel with extracted parameters - output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) - else: - # Fallback to original behavior - output = kernel_function(*cuda_inputs, *init_inputs) -else: - # Run kernel with both tensor inputs and initialization inputs - # For example: RMSNorm needs kernel_function(x, features, eps) - # For cross-entropy: kernel_function(predictions, targets) - # where inputs come from get_inputs() and init_inputs from get_init_inputs() - output = kernel_function(*cuda_inputs, *init_inputs) - -print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) -''' + wrapper_file = output_dir / "ncu_wrapper.py" + # Check cache: reuse wrapper if it's fresh + if use_cache and wrapper_file.exists(): + wrapper_mtime = wrapper_file.stat().st_mtime + kernel_mtime = kernel_file.stat().st_mtime + problem_mtime = problem_file.stat().st_mtime + + if wrapper_mtime > kernel_mtime and wrapper_mtime > problem_mtime: + self.logger.info( + f"Reusing cached NCU wrapper (fresher than source files): {wrapper_file}" + ) + return wrapper_file + + # Load template and render + template = self._load_template() + wrapper_content = template.render( + kernel_file_parent=repr(str(kernel_file.parent)), + problem_file_parent=repr(str(problem_file.parent)), + kernel_module=kernel_file.stem, + problem_module=problem_file.stem, + dtype_inference=dtype_inference, + model_extraction=model_extraction, + ) + + # Write wrapper file wrapper_file.write_text(wrapper_content) self.logger.info(f"Created NCU wrapper: {wrapper_file}") return wrapper_file diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 new file mode 100644 index 0000000..310d8bb --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 @@ -0,0 +1,128 @@ +"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({{ kernel_file_parent }})) +sys.path.insert(0, str({{ problem_file_parent }})) + +from {{ kernel_module }} import kernel_function +from {{ problem_module }} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {{ problem_module }} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +{% if dtype_inference %} +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except Exception: + pass +{% else %} +# Dtype inference disabled +required_dtype = None +{% endif %} + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +{% if model_extraction %} +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except Exception: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + try: + model = Model(*init_inputs) if init_inputs else Model() + except TypeError: + # Model initialization failed, use fallback + model = Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {} # Empty dict + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: +{% endif %} + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) diff --git a/triton_kernel_agent/opt_worker_components/README.md b/triton_kernel_agent/opt_worker_components/README.md new file mode 100644 index 0000000..76026b2 --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/README.md @@ -0,0 +1,12 @@ +# Opt Worker Components + +High-level components used by `OptimizationWorker`. + +These components are **thin wrappers** around low-level utilities +from `kernel_perf_agent` that provide: +- Logging integration +- Error handling +- Worker-specific configuration + +## Dependency Flow +worker_components → kernel_perf_agent (implementation) diff --git a/triton_kernel_agent/opt_worker_components/profiling/__init__.py b/triton_kernel_agent/opt_worker_components/profiling/__init__.py new file mode 100644 index 0000000..cb28cc6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/__init__.py @@ -0,0 +1,6 @@ +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .ncu_wrapper_generator import NCUWrapperGenerator +from .kernel_profiler import KernelProfiler + +__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py new file mode 100644 index 0000000..7ba98ef --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py @@ -0,0 +1,89 @@ +"""Kernel profiling with NCU.""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from .ncu_wrapper_generator import NCUWrapperGenerator + + +class KernelProfiler: + """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + + def __init__( + self, + logger: logging.Logger, + temp_dir: Path, + logs_dir: Path, + ncu_bin_path: Optional[str] = None, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + temp_dir: Temporary directory for profiling artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + """ + self.logger = logger + self.temp_dir = temp_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.wrapper_generator = NCUWrapperGenerator(logger) + + def profile_kernel( + self, kernel_file: Path, problem_file: Path, round_num: int + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + """ + Profile kernel with NCU. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + + Returns: + Tuple of (metrics_df, metrics_json) or (None, None) on failure + """ + try: + # Create NCU wrapper script + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics JSON + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + with open(metrics_file, "w") as f: + json.dump(metrics_json, f, indent=2) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except Exception as e: + self.logger.error(f"❌ NCU profiling failed: {e}") + return None, None diff --git a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py new file mode 100644 index 0000000..96f391f --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py @@ -0,0 +1,154 @@ +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from pathlib import Path + + +class NCUWrapperGenerator: + """Generates NCU wrapper scripts for profiling Triton kernels.""" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper generator. + + Args: + logger: Logger instance + """ + self.logger = logger + + def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + """ + Create NCU wrapper script for profiling. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + + Returns: + Path to created wrapper script + """ + wrapper_file = output_dir / "ncu_wrapper.py" + + wrapper_content = f'''"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({str(kernel_file.parent)!r})) +sys.path.insert(0, str({str(problem_file.parent)!r})) + +from {kernel_file.stem} import kernel_function +from {problem_file.stem} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {problem_file.stem} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except: + pass + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + model = Model(*init_inputs) if init_inputs else Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {{}} + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) +''' + + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file From 11f4e7978ff619b0f6052d875b15a1502e1691b8 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:54:11 -0800 Subject: [PATCH 03/22] Refactor profiling components and add kernel_perf_util --- .../README.md | 0 .../profiling/kernel_profiler.py | 12 +- .../profiling/__init__.py | 6 - .../profiling/kernel_profiler.py | 89 ---------- .../profiling/ncu_wrapper_generator.py | 154 ------------------ 5 files changed, 6 insertions(+), 255 deletions(-) rename triton_kernel_agent/{opt_worker_components => opt_worker_component}/README.md (100%) delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/__init__.py delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py diff --git a/triton_kernel_agent/opt_worker_components/README.md b/triton_kernel_agent/opt_worker_component/README.md similarity index 100% rename from triton_kernel_agent/opt_worker_components/README.md rename to triton_kernel_agent/opt_worker_component/README.md diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 3fc991a..61f41b3 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -10,7 +10,7 @@ import pandas as pd -from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( +from kernel_perf_util.kernel_opt.profiler.ncu_profiler import ( load_ncu_metrics, metrics_to_prompt, profile_triton_kernel, @@ -61,7 +61,7 @@ def _get_ncu_version(self) -> Optional[str]: ) if result.returncode == 0: # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") - version_line = result.stdout.strip().split('\n')[0] + version_line = result.stdout.strip().split("\n")[0] return version_line return None except Exception as e: @@ -144,7 +144,7 @@ def profile_kernel( f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" ) if attempt < max_retries: - wait_time = 2 ** attempt # Exponential backoff: 2, 4, 8 seconds + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue @@ -153,7 +153,7 @@ def profile_kernel( except json.JSONDecodeError as e: self.logger.error(f"❌ Failed to parse NCU metrics: {e}") if attempt < max_retries: - wait_time = 2 ** attempt + wait_time = 2**attempt self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue @@ -162,10 +162,10 @@ def profile_kernel( except Exception as e: self.logger.error( f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", - exc_info=True + exc_info=True, ) if attempt < max_retries: - wait_time = 2 ** attempt + wait_time = 2**attempt self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue diff --git a/triton_kernel_agent/opt_worker_components/profiling/__init__.py b/triton_kernel_agent/opt_worker_components/profiling/__init__.py deleted file mode 100644 index cb28cc6..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Profiling infrastructure for NCU-based kernel analysis.""" - -from .ncu_wrapper_generator import NCUWrapperGenerator -from .kernel_profiler import KernelProfiler - -__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py deleted file mode 100644 index 7ba98ef..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Kernel profiling with NCU.""" - -import json -import logging -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -import pandas as pd - -from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( - load_ncu_metrics, - metrics_to_prompt, - profile_triton_kernel, -) - -from .ncu_wrapper_generator import NCUWrapperGenerator - - -class KernelProfiler: - """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" - - def __init__( - self, - logger: logging.Logger, - temp_dir: Path, - logs_dir: Path, - ncu_bin_path: Optional[str] = None, - ): - """ - Initialize the kernel profiler. - - Args: - logger: Logger instance - temp_dir: Temporary directory for profiling artifacts - logs_dir: Directory for saving profiling logs - ncu_bin_path: Path to NCU binary (auto-detect if None) - """ - self.logger = logger - self.temp_dir = temp_dir - self.logs_dir = logs_dir - self.ncu_bin_path = ncu_bin_path - self.wrapper_generator = NCUWrapperGenerator(logger) - - def profile_kernel( - self, kernel_file: Path, problem_file: Path, round_num: int - ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: - """ - Profile kernel with NCU. - - Args: - kernel_file: Path to kernel file - problem_file: Path to problem file - round_num: Current optimization round number - - Returns: - Tuple of (metrics_df, metrics_json) or (None, None) on failure - """ - try: - # Create NCU wrapper script - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - - # Profile with NCU - csv_file = f"ncu_round_{round_num}.csv" - csv_path = profile_triton_kernel( - benchmark_script=wrapper_file, - workdir=self.temp_dir, - out_csv=csv_file, - ncu_bin=self.ncu_bin_path, - launch_count=20, - timeout=120, - ) - - # Load and parse metrics - metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) - - # Save metrics JSON - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - with open(metrics_file, "w") as f: - json.dump(metrics_json, f, indent=2) - - self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json - - except Exception as e: - self.logger.error(f"❌ NCU profiling failed: {e}") - return None, None diff --git a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py deleted file mode 100644 index 96f391f..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py +++ /dev/null @@ -1,154 +0,0 @@ -"""NCU wrapper script generation for kernel profiling.""" - -import logging -from pathlib import Path - - -class NCUWrapperGenerator: - """Generates NCU wrapper scripts for profiling Triton kernels.""" - - def __init__(self, logger: logging.Logger): - """ - Initialize the NCU wrapper generator. - - Args: - logger: Logger instance - """ - self.logger = logger - - def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: - """ - Create NCU wrapper script for profiling. - - Args: - kernel_file: Path to kernel file - problem_file: Path to problem file - output_dir: Directory to write wrapper script - - Returns: - Path to created wrapper script - """ - wrapper_file = output_dir / "ncu_wrapper.py" - - wrapper_content = f'''"""NCU profiling wrapper.""" -import sys -import torch -import inspect -sys.path.insert(0, str({str(kernel_file.parent)!r})) -sys.path.insert(0, str({str(problem_file.parent)!r})) - -from {kernel_file.stem} import kernel_function -from {problem_file.stem} import get_inputs, get_init_inputs - -# Try to import Model if it exists (for Conv, Linear, etc.) -try: - from {problem_file.stem} import Model - has_model = True -except ImportError: - has_model = False - -# Get inputs -inputs = get_inputs() - -# Get additional initialization inputs (e.g., features, eps for RMSNorm) -init_inputs = get_init_inputs() - -# Infer required dtype from kernel function signature/docstring -required_dtype = None -try: - # Try to get dtype from kernel function docstring or source - kernel_source = inspect.getsource(kernel_function) - if 'bfloat16' in kernel_source.lower(): - required_dtype = torch.bfloat16 - elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): - required_dtype = torch.float16 - elif 'float32' in kernel_source.lower(): - required_dtype = torch.float32 -except: - pass - -# Prepare inputs: move to CUDA and convert dtype if needed -# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) -cuda_inputs = [] -for inp in inputs: - if isinstance(inp, torch.Tensor): - # Move to CUDA if not already - if not inp.is_cuda: - inp = inp.cuda() - # Convert dtype if required, but ONLY for floating-point tensors - # Preserve integer/bool tensors (e.g., targets for classification) - if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: - inp = inp.to(required_dtype) - cuda_inputs.append(inp) - else: - cuda_inputs.append(inp) - -# Check if this is a conv-like kernel that needs a Model to extract weights -needs_model = False -try: - sig = inspect.signature(kernel_function) - params = list(sig.parameters.keys()) - # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) - if 'weight' in params: - needs_model = True -except: - pass - -if needs_model and has_model and init_inputs: - # Initialize model to extract weight and bias - model = Model(*init_inputs) if init_inputs else Model() - - # Move model to CUDA and convert dtype - model = model.cuda() - if required_dtype is not None: - model = model.to(required_dtype) - - # Extract weight and bias from model - # Check various possible attribute names - weight = None - bias = None - layer = None - for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: - if hasattr(model, attr_name): - layer = getattr(model, attr_name) - if hasattr(layer, 'weight'): - weight = layer.weight - bias = layer.bias if hasattr(layer, 'bias') else None - break - - if weight is not None and layer is not None: - # Build arguments for kernel_function using keyword arguments - # to avoid positional argument misalignment issues - kernel_kwargs = {{}} - - # Add conv/linear-specific parameters if they exist - if hasattr(layer, 'stride'): - stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride - kernel_kwargs['stride'] = stride - if hasattr(layer, 'padding'): - padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding - kernel_kwargs['padding'] = padding - if hasattr(layer, 'dilation'): - dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation - kernel_kwargs['dilation'] = dilation - if hasattr(layer, 'groups'): - kernel_kwargs['groups'] = layer.groups - - # Call kernel with extracted parameters - output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) - else: - # Fallback to original behavior - output = kernel_function(*cuda_inputs, *init_inputs) -else: - # Run kernel with both tensor inputs and initialization inputs - # For example: RMSNorm needs kernel_function(x, features, eps) - # For cross-entropy: kernel_function(predictions, targets) - # where inputs come from get_inputs() and init_inputs from get_init_inputs() - output = kernel_function(*cuda_inputs, *init_inputs) - -print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) -''' - - wrapper_file.write_text(wrapper_content) - self.logger.info(f"Created NCU wrapper: {wrapper_file}") - return wrapper_file From 251f419741dd3a50233b947b3dab26417e922e2e Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:55:10 -0800 Subject: [PATCH 04/22] Refactor profiling components and add kernel_perf_util --- .../kernel_opt/profiler/ncu_profiler.py | 61 ------------------- 1 file changed, 61 deletions(-) diff --git a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py index 26e59d2..0b9168b 100644 --- a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py @@ -64,67 +64,6 @@ ] ) -# METRICS = ",".join( -# [ -# # Tensor Core (MOST CRITICAL for GEMM) -# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", -# # Memory Bandwidth & Cache -# "dram__throughput.avg.pct_of_peak_sustained_elapsed", -# "l1tex__t_sector_hit_rate.pct", -# "lts__t_sector_hit_rate.pct", -# # Memory Access Efficiency -# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", -# # Stalls & Bottlenecks (H100-compatible stall metrics) -# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", -# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", -# # Occupancy & Resources -# "launch__shared_mem_per_block_allocated", -# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", -# # SM Utilization -# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", -# ] -# ) - -# # CRITICAL GEMM Optimization Metrics (Optimized for Matrix Multiplication) -# METRICS = ",".join( -# [ -# # === Memory Efficiency Metrics (CRITICAL) === -# "dram__throughput.avg.pct_of_peak_sustained_elapsed", # DRAM bandwidth utilization -# "dram__bytes_read.sum", # Total DRAM reads -# "dram__bytes_write.sum", # Total DRAM writes -# "l1tex__t_sector_hit_rate.pct", # L1 cache hit rate (CRITICAL for data reuse) -# "l1tex__throughput.avg.pct_of_peak_sustained_elapsed", # L1 throughput (fixed from _active) -# "lts__t_sector_hit_rate.pct", # L2 cache hit rate -# "lts__throughput.avg.pct_of_peak_sustained_elapsed", # L2 throughput (fixed from _active) -# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", # Memory coalescing (NEW - CRITICAL) -# "l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum", # Shared memory loads (NEW) -# # === Compute Utilization Metrics (CRITICAL for GEMM) === -# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", # Tensor core pipeline (NEW - CRITICAL) -# "smsp__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", # Tensor core inst % -# "smsp__inst_executed_pipe_tensor.sum", # Total tensor core instructions (NEW) -# "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", # FP32 pipeline -# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", # SM active cycles (fixed from .avg only) -# "sm__inst_executed.sum", # Total instructions -# "smsp__inst_executed.avg.per_cycle_active", # IPC (NEW) -# # === Occupancy & Resource Metrics === -# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", # Warp occupancy (fixed from _active) -# "sm__maximum_warps_per_active_cycle_pct", # Max achievable occupancy (NEW) -# "launch__occupancy_limit_blocks", # Block limit -# "launch__occupancy_limit_registers", # Register limit -# "launch__occupancy_limit_shared_mem", # Shared memory limit -# "launch__registers_per_thread", # Registers per thread -# "launch__shared_mem_per_block_allocated", # Shared memory usage (NEW - CRITICAL) -# # === Stall & Latency Metrics === -# "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", # Memory dependency stalls -# "smsp__warp_issue_stalled_mem_throttle_per_warp_active.pct", # Memory throttle stalls (NEW - CRITICAL) -# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", # Short scoreboard stalls -# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", # Long scoreboard stalls -# "smsp__warp_issue_stalled_barrier_per_warp_active.pct", # Barrier stalls -# "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", # Branch stalls -# ] -# ) - - # List version for convenient header selection METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] From b789660d6897fee27c8793eab9984d25c031a894 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:14:19 -0800 Subject: [PATCH 05/22] update directory name and add package in pyproject --- kernel_perf_agent/README.md | 8 + .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 ++++++++++ .../kernel_opt/profiler/ncu_profiler.py | 388 ++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 kernel_perf_agent/README.md create mode 100644 kernel_perf_agent/kernel_opt/profiler/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md new file mode 100644 index 0000000..5114402 --- /dev/null +++ b/kernel_perf_agent/README.md @@ -0,0 +1,8 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization: +- NCU profiling +- RAG-based optimization database +- Performance analysis tools + +These are **library components** that can be used by any worker. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py new file mode 100644 index 0000000..ed8b50f --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import re +import subprocess +from typing import Any, Dict, Optional + + +# GPU specifications database +# Sources: NVIDIA official specifications, manufacturer datasheets +GPU_SPECS_DATABASE = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} + + +def query_gpu_name() -> Optional[str]: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + gpu_name = result.stdout.strip() + # Normalize whitespace + gpu_name = re.sub(r"\s+", " ", gpu_name) + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: + """ + Format GPU specifications as a readable string for LLM prompts. + + Args: + specs: GPU specifications dictionary from get_gpu_specs() + + Returns: + Formatted string with GPU specifications + + Example: + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(format_gpu_specs_for_prompt(specs)) + GPU: NVIDIA A100 (Ampere) + - Peak Memory Bandwidth: 1555 GB/s + - Peak FP32 Performance: 19.5 TFLOPS + - Peak FP16 Performance: 312.0 TFLOPS + - SM Count: 108 + - L1 Cache: 192 KB per SM + - L2 Cache: 40 MB + """ + lines = [ + f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", + f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", + f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", + ] + + if specs.get("peak_fp16_tflops", 0) > 0: + lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") + + if specs.get("peak_bf16_tflops", 0) > 0: + lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") + + lines.extend( + [ + f"- SM Count: {specs['sm_count']}", + f"- Max Threads per SM: {specs['max_threads_per_sm']}", + f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", + f"- L2 Cache: {specs['l2_cache_mb']} MB", + f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", + ] + ) + + return "\n".join(lines) + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + + # Get specs + specs = get_gpu_specs() + print(f"\n{format_gpu_specs_for_prompt(specs)}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..0b9168b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + + +__all__ = [ + "METRICS", + "METRIC_COLUMNS", + "profile_triton_kernel", + "load_ncu_metrics", + "metrics_to_prompt", +] + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [ + "sudo", + "-E", + f"--preserve-env={preserve}", + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + + print(f"[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print(f"[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: str = "last", +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name: + "first", "last", "max_cycles" + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = sub[ + sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + if len(matched) > 1: + if select == "first": + row = matched.iloc[[0]] + elif select == "last": + row = matched.iloc[[-1]] + elif ( + select == "max_cycles" + and "sm__cycles_active.avg" in matched.columns + ): + row = matched.sort_values( + "sm__cycles_active.avg", ascending=False + ).head(1) + else: + row = matched.iloc[[-1]] # fallback + else: + row = matched + results.append(row) + + if results: + sub = pd.concat(results, ignore_index=True) + else: + sub = pd.DataFrame(columns=keep_cols) + elif select in ("first", "last", "max_cycles"): + # Apply selection to all rows if no name filter + if len(sub) > 0: + if select == "first": + sub = sub.iloc[[0]] + elif select == "last": + sub = sub.iloc[[-1]] + elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: + sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/pyproject.toml b/pyproject.toml index 2d5382f..a31dc3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ pipeline-ui = "scripts.pipeline_ui:main" "Repository" = "https://github.com/pytorch-labs/KernelAgent" [tool.setuptools.packages.find] -include = ["triton_kernel_agent*", "Fuser*", "scripts", "utils"] +include = ["triton_kernel_agent*", "kernel_perf_agent*", "Fuser*", "scripts", "utils"] [tool.setuptools] include-package-data = true From 4d35d570e3053a5b22c30c1b7c2cc73cd7bb1360 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:18:37 -0800 Subject: [PATCH 06/22] Remove kernel_perf_util directory --- kernel_perf_util/README.md | 8 - .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 ---------- .../kernel_opt/profiler/ncu_profiler.py | 388 ------------------ .../profiling/kernel_profiler.py | 2 +- 5 files changed, 1 insertion(+), 624 deletions(-) delete mode 100644 kernel_perf_util/README.md delete mode 100644 kernel_perf_util/kernel_opt/profiler/__init__.py delete mode 100644 kernel_perf_util/kernel_opt/profiler/gpu_specs.py delete mode 100644 kernel_perf_util/kernel_opt/profiler/ncu_profiler.py diff --git a/kernel_perf_util/README.md b/kernel_perf_util/README.md deleted file mode 100644 index 5114402..0000000 --- a/kernel_perf_util/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Kernel Performance Utilities - -Low-level, reusable utilities for kernel optimization: -- NCU profiling -- RAG-based optimization database -- Performance analysis tools - -These are **library components** that can be used by any worker. diff --git a/kernel_perf_util/kernel_opt/profiler/__init__.py b/kernel_perf_util/kernel_opt/profiler/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/kernel_perf_util/kernel_opt/profiler/gpu_specs.py b/kernel_perf_util/kernel_opt/profiler/gpu_specs.py deleted file mode 100644 index ed8b50f..0000000 --- a/kernel_perf_util/kernel_opt/profiler/gpu_specs.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python3 -""" -GPU Specifications Database for Bottleneck Analysis - -This module provides GPU hardware specifications needed for performance analysis -and bottleneck identification. It includes peak compute performance, memory bandwidth, -cache sizes, and SM counts for common NVIDIA GPUs. - -""" - -import re -import subprocess -from typing import Any, Dict, Optional - - -# GPU specifications database -# Sources: NVIDIA official specifications, manufacturer datasheets -GPU_SPECS_DATABASE = { - "NVIDIA A100": { - "name": "NVIDIA A100", - "architecture": "Ampere", - "peak_fp32_tflops": 19.5, - "peak_fp16_tflops": 312.0, - "peak_bf16_tflops": 312.0, - "peak_memory_bw_gbps": 1555, - "sm_count": 108, - "max_threads_per_sm": 2048, - "l1_cache_kb": 192, - "l2_cache_mb": 40, - "memory_gb": 40, - "memory_type": "HBM2e", - }, - "NVIDIA H100": { - "name": "NVIDIA H100", - "architecture": "Hopper", - "peak_fp32_tflops": 51.0, - "peak_fp16_tflops": 989.0, - "peak_bf16_tflops": 989.0, - "peak_memory_bw_gbps": 3352, - "sm_count": 132, - "max_threads_per_sm": 2048, - "l1_cache_kb": 256, - "l2_cache_mb": 50, - "memory_gb": 80, - "memory_type": "HBM3", - }, - "NVIDIA RTX 4090": { - "name": "NVIDIA RTX 4090", - "architecture": "Ada Lovelace", - "peak_fp32_tflops": 82.6, - "peak_fp16_tflops": 165.0, - "peak_bf16_tflops": 165.0, - "peak_memory_bw_gbps": 1008, - "sm_count": 128, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 72, - "memory_gb": 24, - "memory_type": "GDDR6X", - }, - "NVIDIA RTX 5080": { - "name": "NVIDIA RTX 5080", - "architecture": "Blackwell", - "peak_fp32_tflops": 57.0, - "peak_fp16_tflops": 114.0, - "peak_bf16_tflops": 114.0, - "peak_memory_bw_gbps": 960, - "sm_count": 84, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 64, - "memory_gb": 16, - "memory_type": "GDDR7", - }, -} - - -def query_gpu_name() -> Optional[str]: - """ - Query GPU name using nvidia-smi. - - Returns: - GPU name string, or None if query fails - """ - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - gpu_name = result.stdout.strip() - # Normalize whitespace - gpu_name = re.sub(r"\s+", " ", gpu_name) - return gpu_name - except (subprocess.TimeoutExpired, FileNotFoundError, Exception): - pass - return None - - -def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: - """ - Get GPU specifications for bottleneck analysis. - - This function returns hardware specifications needed for performance analysis, - including peak compute performance, memory bandwidth, cache sizes, and SM counts. - - Args: - gpu_name: GPU name (if None, auto-detect with nvidia-smi) - - Returns: - Dictionary with GPU specifications containing: - - name: GPU name - - architecture: GPU architecture (e.g., "Ampere", "Hopper") - - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS - - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS - - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) - - peak_memory_bw_gbps: Peak memory bandwidth in GB/s - - sm_count: Number of streaming multiprocessors - - max_threads_per_sm: Maximum threads per SM - - l1_cache_kb: L1 cache size in KB per SM - - l2_cache_mb: Total L2 cache size in MB - - memory_gb: Total GPU memory in GB - - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") - - Examples: - >>> specs = get_gpu_specs() # Auto-detect - >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") - - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(f"SM Count: {specs['sm_count']}") - """ - # Auto-detect if not provided - if gpu_name is None: - gpu_name = query_gpu_name() - - # Return default if detection failed - if gpu_name is None: - print("⚠️ GPU auto-detection failed, using A100 specs as fallback") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - # Try exact match - if gpu_name in GPU_SPECS_DATABASE: - return GPU_SPECS_DATABASE[gpu_name].copy() - - # Try fuzzy match (contains or partial match) - gpu_name_lower = gpu_name.lower() - for key, specs in GPU_SPECS_DATABASE.items(): - key_lower = key.lower() - # Check if either name contains the other - if gpu_name_lower in key_lower or key_lower in gpu_name_lower: - print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") - return specs.copy() - - # Fallback to A100 specs with warning - print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") - print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - -def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: - """ - Format GPU specifications as a readable string for LLM prompts. - - Args: - specs: GPU specifications dictionary from get_gpu_specs() - - Returns: - Formatted string with GPU specifications - - Example: - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(format_gpu_specs_for_prompt(specs)) - GPU: NVIDIA A100 (Ampere) - - Peak Memory Bandwidth: 1555 GB/s - - Peak FP32 Performance: 19.5 TFLOPS - - Peak FP16 Performance: 312.0 TFLOPS - - SM Count: 108 - - L1 Cache: 192 KB per SM - - L2 Cache: 40 MB - """ - lines = [ - f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", - f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", - f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", - ] - - if specs.get("peak_fp16_tflops", 0) > 0: - lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") - - if specs.get("peak_bf16_tflops", 0) > 0: - lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") - - lines.extend( - [ - f"- SM Count: {specs['sm_count']}", - f"- Max Threads per SM: {specs['max_threads_per_sm']}", - f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", - f"- L2 Cache: {specs['l2_cache_mb']} MB", - f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", - ] - ) - - return "\n".join(lines) - - -if __name__ == "__main__": - print("GPU Specifications Module") - print("=" * 60) - - # Auto-detect GPU - detected_name = query_gpu_name() - if detected_name: - print(f"\nDetected GPU: {detected_name}") - else: - print("\nNo GPU detected (nvidia-smi not available)") - - # Get specs - specs = get_gpu_specs() - print(f"\n{format_gpu_specs_for_prompt(specs)}") - - # Show all available GPUs - print(f"\n{'=' * 60}") - print("Available GPU specifications in database:") - for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): - print(f" - {gpu_name}") diff --git a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py deleted file mode 100644 index 0b9168b..0000000 --- a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -NCU Profiling Module for Triton Kernels - -This module wraps three tasks: -1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). -2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). -3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). - -""" - -import json -import math -import os -import shutil -import subprocess -import sys -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union - -import numpy as np -import pandas as pd - - -__all__ = [ - "METRICS", - "METRIC_COLUMNS", - "profile_triton_kernel", - "load_ncu_metrics", - "metrics_to_prompt", -] - -METRICS = ",".join( - [ - "sm__cycles_active.avg", - "sm__warps_active.avg.pct_of_peak_sustained_active", - "launch__occupancy_limit_blocks", - "launch__occupancy_limit_registers", - "launch__occupancy_limit_shared_mem", - "launch__registers_per_thread", - "sm__inst_executed.sum", - "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", - "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", - "dram__bytes_read.sum", - "dram__bytes_write.sum", - "dram__throughput.avg.pct_of_peak_sustained_elapsed", - "dram__bytes.sum.per_second", - "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", - "launch__shared_mem_per_block_allocated", - "l1tex__t_sector_hit_rate.pct", - "l1tex__throughput.avg.pct_of_peak_sustained_active", - "lts__t_sector_hit_rate.pct", - "lts__throughput.avg.pct_of_peak_sustained_active", - "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", - "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", - "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", - "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", - "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", - "smsp__warp_issue_stalled_barrier_per_warp_active.pct", - "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", - "smsp__sass_average_branch_targets_threads_uniform.pct", - ] -) - -# List version for convenient header selection -METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] - - -def profile_triton_kernel( - benchmark_script: Path, - workdir: Path, - out_csv: str = "ncu_output.csv", - python_executable: Optional[str] = None, - ncu_bin: Optional[str] = None, - launch_count: int = 20, - timeout: int = 120, -) -> Path: - """ - Profile a Triton kernel using NCU. - - Args: - benchmark_script: Path to benchmark script that calls the kernel - workdir: Working directory for execution - out_csv: Output CSV filename - python_executable: Python executable to use (default: sys.executable) - ncu_bin: Path to NCU binary (default: auto-detect) - launch_count: Number of kernel launches to profile - timeout: Timeout in seconds for NCU execution - - Returns: - Path to output CSV file - - Raises: - RuntimeError: If NCU profiling fails - FileNotFoundError: If NCU binary or output CSV not found - """ - # Resolve paths - if python_executable is None: - python_executable = sys.executable - - if ncu_bin is None: - ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" - - if not Path(ncu_bin).exists(): - raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") - - csv_path = (workdir / out_csv).resolve() - benchmark_script = benchmark_script.resolve() - - if not benchmark_script.exists(): - raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") - - # Preserve important environment variables - env = os.environ.copy() - - # Add Triton-specific environment variables - env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") - - preserve = ",".join( - [ - "PATH", - "LD_LIBRARY_PATH", - "CUDA_VISIBLE_DEVICES", - "PYTHONPATH", - "TRITON_CACHE_DIR", - "TORCH_EXTENSIONS_DIR", - "CONDA_PREFIX", - "CONDA_DEFAULT_ENV", - ] - ) - - # Build NCU command - cmd = [ - "sudo", - "-E", - f"--preserve-env={preserve}", - ncu_bin, - "--csv", - "--page=raw", - "--kernel-name-base=demangled", - "--target-processes=all", - "--replay-mode=kernel", - "--profile-from-start=on", - f"--log-file={str(csv_path)}", - f"--metrics={METRICS}", - "--launch-skip=0", - f"--launch-count={launch_count}", - python_executable, - str(benchmark_script), - ] - - print(f"[NCU] Running profiling...") - print(f"[NCU] Benchmark: {benchmark_script.name}") - print(f"[NCU] Output: {csv_path}") - print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") - - try: - result = subprocess.run( - cmd, - cwd=str(workdir), - env=env, - capture_output=True, - text=True, - timeout=timeout, - ) - - if result.returncode != 0: - error_msg = result.stderr or result.stdout - raise RuntimeError( - f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" - ) - - if not csv_path.exists(): - raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") - - # Check if CSV has content - csv_size = csv_path.stat().st_size - if csv_size < 100: - raise RuntimeError( - f"NCU CSV file is too small ({csv_size} bytes), likely empty" - ) - - print(f"[NCU] ✓ Profiling completed successfully") - print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") - return csv_path - - except subprocess.TimeoutExpired: - raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") - except Exception as e: - raise RuntimeError(f"NCU profiling failed: {e}") - - -def load_ncu_metrics( - csv_path: Union[str, Path], - columns: Optional[Sequence[str]] = None, - extra_keep: Optional[Sequence[str]] = ("Kernel Name",), - coerce_numeric: bool = True, - name_list: Optional[Sequence[str]] = None, - select: str = "last", -) -> pd.DataFrame: - """ - Load and parse NCU metrics from CSV file. - - Args: - csv_path: Path to NCU CSV output - columns: Specific metric columns to load (default: all METRIC_COLUMNS) - extra_keep: Additional columns to keep (e.g., "Kernel Name") - coerce_numeric: Convert metric values to numeric - name_list: Filter by kernel names (substring match) - select: Selection policy when multiple rows per name: - "first", "last", "max_cycles" - - Returns: - DataFrame with parsed metrics - - Raises: - FileNotFoundError: If CSV file not found - ValueError: If no requested columns found in CSV - """ - csv_path = Path(csv_path) - if not csv_path.exists(): - raise FileNotFoundError(f"CSV not found: {csv_path}") - - df = pd.read_csv(csv_path, comment="=", low_memory=False) - - metric_cols = list(columns) if columns is not None else METRIC_COLUMNS - keep_cols: List[str] = [] - if extra_keep: - keep_cols.extend([c for c in extra_keep if c in df.columns]) - keep_cols.extend([c for c in metric_cols if c in df.columns]) - - if not keep_cols: - raise ValueError("No requested columns found in the CSV header.") - - sub = df[keep_cols].copy() - - # Drop the units row (first row often contains units like "%", "inst", etc.) - if len(sub) > 0: - first_row_str = sub.iloc[0].astype(str).str.lower() - unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") - if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): - sub = sub.iloc[1:].reset_index(drop=True) - - # Coerce metrics to numeric - if coerce_numeric: - metric_in_sub = [c for c in metric_cols if c in sub.columns] - sub[metric_in_sub] = ( - sub[metric_in_sub] - .replace({",": "", "%": ""}, regex=True) - .apply(pd.to_numeric, errors="coerce") - ) - - # Filter by kernel name list if provided - if name_list: - results = [] - for name in name_list: - # Use contains match instead of exact equality (for Triton's long kernel names) - matched = sub[ - sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) - ] - if matched.empty: - continue - if len(matched) > 1: - if select == "first": - row = matched.iloc[[0]] - elif select == "last": - row = matched.iloc[[-1]] - elif ( - select == "max_cycles" - and "sm__cycles_active.avg" in matched.columns - ): - row = matched.sort_values( - "sm__cycles_active.avg", ascending=False - ).head(1) - else: - row = matched.iloc[[-1]] # fallback - else: - row = matched - results.append(row) - - if results: - sub = pd.concat(results, ignore_index=True) - else: - sub = pd.DataFrame(columns=keep_cols) - elif select in ("first", "last", "max_cycles"): - # Apply selection to all rows if no name filter - if len(sub) > 0: - if select == "first": - sub = sub.iloc[[0]] - elif select == "last": - sub = sub.iloc[[-1]] - elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: - sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) - - return sub - - -def metrics_to_prompt( - df: pd.DataFrame, - title: str = "GPU Profiling Metrics:", - key_by: str = "Kernel Name", - round_digits: Optional[int] = 3, - compact: bool = False, - keep_cols: Optional[List[str]] = None, -) -> str: - """ - Convert NCU metrics DataFrame to JSON string for LLM prompts. - - Returns JSON in format: - { - "": { "": , ... } - } - Args: - df: DataFrame with NCU metrics - title: Title for the metrics section (not included in output) - key_by: Column to use as key (usually "Kernel Name") - round_digits: Number of decimal places for rounding - compact: If True, use compact JSON (no indentation) - keep_cols: Specific columns to include in output - - Returns: - JSON string with metrics - """ - - def _safe(v: Any) -> Any: - """Convert values to JSON-safe format.""" - if v is None or (isinstance(v, float) and math.isnan(v)): - return None - if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): - return str(v) - if isinstance(v, np.generic): - v = v.item() - if isinstance(v, float) and math.isinf(v): - return "inf" if v > 0 else "-inf" - if isinstance(v, float) and round_digits is not None: - return round(v, round_digits) - return v - - # Empty table - if df is None or df.empty: - return "{}" - - cols = list(df.columns) - - # Round numeric columns - if round_digits is not None: - num_cols = df.select_dtypes(include="number").columns - if len(num_cols) > 0: - df = df.copy() - df[num_cols] = df[num_cols].round(round_digits) - - # If key column is missing, return a list of rows - if key_by not in cols: - rows = [ - {k: _safe(v) for k, v in rec.items()} - for rec in df.to_dict(orient="records") - ] - return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) - - # Determine value columns - value_cols = [c for c in cols if c != key_by] - if keep_cols is not None: - value_cols = [c for c in value_cols if c in keep_cols] - - data: Dict[str, Any] = {} - for rec in df[[key_by] + value_cols].to_dict(orient="records"): - k = str(rec.pop(key_by)) - val_obj = {ck: _safe(cv) for ck, cv in rec.items()} - if k in data: - # Multiple rows for same key - convert to list - if isinstance(data[k], list): - data[k].append(val_obj) - else: - data[k] = [data[k], val_obj] - else: - data[k] = val_obj - - return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) - - -if __name__ == "__main__": - print("ncu_profiler module loaded.") - print("Import its functions in your scripts:") - print( - " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" - ) diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 61f41b3..693fa8a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -10,7 +10,7 @@ import pandas as pd -from kernel_perf_util.kernel_opt.profiler.ncu_profiler import ( +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( load_ncu_metrics, metrics_to_prompt, profile_triton_kernel, From d87167857f31c5ee87a0a121a36312b18f82d9ba Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:37:08 -0800 Subject: [PATCH 07/22] move gpu spec.py to future PR and fix import --- .../kernel_opt/profiler/gpu_specs.py | 227 ------------------ .../profiling/kernel_profiler.py | 4 +- 2 files changed, 3 insertions(+), 228 deletions(-) delete mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py deleted file mode 100644 index ed8b50f..0000000 --- a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python3 -""" -GPU Specifications Database for Bottleneck Analysis - -This module provides GPU hardware specifications needed for performance analysis -and bottleneck identification. It includes peak compute performance, memory bandwidth, -cache sizes, and SM counts for common NVIDIA GPUs. - -""" - -import re -import subprocess -from typing import Any, Dict, Optional - - -# GPU specifications database -# Sources: NVIDIA official specifications, manufacturer datasheets -GPU_SPECS_DATABASE = { - "NVIDIA A100": { - "name": "NVIDIA A100", - "architecture": "Ampere", - "peak_fp32_tflops": 19.5, - "peak_fp16_tflops": 312.0, - "peak_bf16_tflops": 312.0, - "peak_memory_bw_gbps": 1555, - "sm_count": 108, - "max_threads_per_sm": 2048, - "l1_cache_kb": 192, - "l2_cache_mb": 40, - "memory_gb": 40, - "memory_type": "HBM2e", - }, - "NVIDIA H100": { - "name": "NVIDIA H100", - "architecture": "Hopper", - "peak_fp32_tflops": 51.0, - "peak_fp16_tflops": 989.0, - "peak_bf16_tflops": 989.0, - "peak_memory_bw_gbps": 3352, - "sm_count": 132, - "max_threads_per_sm": 2048, - "l1_cache_kb": 256, - "l2_cache_mb": 50, - "memory_gb": 80, - "memory_type": "HBM3", - }, - "NVIDIA RTX 4090": { - "name": "NVIDIA RTX 4090", - "architecture": "Ada Lovelace", - "peak_fp32_tflops": 82.6, - "peak_fp16_tflops": 165.0, - "peak_bf16_tflops": 165.0, - "peak_memory_bw_gbps": 1008, - "sm_count": 128, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 72, - "memory_gb": 24, - "memory_type": "GDDR6X", - }, - "NVIDIA RTX 5080": { - "name": "NVIDIA RTX 5080", - "architecture": "Blackwell", - "peak_fp32_tflops": 57.0, - "peak_fp16_tflops": 114.0, - "peak_bf16_tflops": 114.0, - "peak_memory_bw_gbps": 960, - "sm_count": 84, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 64, - "memory_gb": 16, - "memory_type": "GDDR7", - }, -} - - -def query_gpu_name() -> Optional[str]: - """ - Query GPU name using nvidia-smi. - - Returns: - GPU name string, or None if query fails - """ - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - gpu_name = result.stdout.strip() - # Normalize whitespace - gpu_name = re.sub(r"\s+", " ", gpu_name) - return gpu_name - except (subprocess.TimeoutExpired, FileNotFoundError, Exception): - pass - return None - - -def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: - """ - Get GPU specifications for bottleneck analysis. - - This function returns hardware specifications needed for performance analysis, - including peak compute performance, memory bandwidth, cache sizes, and SM counts. - - Args: - gpu_name: GPU name (if None, auto-detect with nvidia-smi) - - Returns: - Dictionary with GPU specifications containing: - - name: GPU name - - architecture: GPU architecture (e.g., "Ampere", "Hopper") - - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS - - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS - - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) - - peak_memory_bw_gbps: Peak memory bandwidth in GB/s - - sm_count: Number of streaming multiprocessors - - max_threads_per_sm: Maximum threads per SM - - l1_cache_kb: L1 cache size in KB per SM - - l2_cache_mb: Total L2 cache size in MB - - memory_gb: Total GPU memory in GB - - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") - - Examples: - >>> specs = get_gpu_specs() # Auto-detect - >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") - - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(f"SM Count: {specs['sm_count']}") - """ - # Auto-detect if not provided - if gpu_name is None: - gpu_name = query_gpu_name() - - # Return default if detection failed - if gpu_name is None: - print("⚠️ GPU auto-detection failed, using A100 specs as fallback") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - # Try exact match - if gpu_name in GPU_SPECS_DATABASE: - return GPU_SPECS_DATABASE[gpu_name].copy() - - # Try fuzzy match (contains or partial match) - gpu_name_lower = gpu_name.lower() - for key, specs in GPU_SPECS_DATABASE.items(): - key_lower = key.lower() - # Check if either name contains the other - if gpu_name_lower in key_lower or key_lower in gpu_name_lower: - print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") - return specs.copy() - - # Fallback to A100 specs with warning - print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") - print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - -def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: - """ - Format GPU specifications as a readable string for LLM prompts. - - Args: - specs: GPU specifications dictionary from get_gpu_specs() - - Returns: - Formatted string with GPU specifications - - Example: - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(format_gpu_specs_for_prompt(specs)) - GPU: NVIDIA A100 (Ampere) - - Peak Memory Bandwidth: 1555 GB/s - - Peak FP32 Performance: 19.5 TFLOPS - - Peak FP16 Performance: 312.0 TFLOPS - - SM Count: 108 - - L1 Cache: 192 KB per SM - - L2 Cache: 40 MB - """ - lines = [ - f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", - f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", - f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", - ] - - if specs.get("peak_fp16_tflops", 0) > 0: - lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") - - if specs.get("peak_bf16_tflops", 0) > 0: - lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") - - lines.extend( - [ - f"- SM Count: {specs['sm_count']}", - f"- Max Threads per SM: {specs['max_threads_per_sm']}", - f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", - f"- L2 Cache: {specs['l2_cache_mb']} MB", - f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", - ] - ) - - return "\n".join(lines) - - -if __name__ == "__main__": - print("GPU Specifications Module") - print("=" * 60) - - # Auto-detect GPU - detected_name = query_gpu_name() - if detected_name: - print(f"\nDetected GPU: {detected_name}") - else: - print("\nNo GPU detected (nvidia-smi not available)") - - # Get specs - specs = get_gpu_specs() - print(f"\n{format_gpu_specs_for_prompt(specs)}") - - # Show all available GPUs - print(f"\n{'=' * 60}") - print("Available GPU specifications in database:") - for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): - print(f" - {gpu_name}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 693fa8a..27825ea 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -16,7 +16,9 @@ profile_triton_kernel, ) -from .ncu_wrapper_generator import NCUWrapperGenerator +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_generator import ( + NCUWrapperGenerator, +) class KernelProfiler: From db0c75474f8b964610de73b15c4a40a578125df6 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:44:11 -0800 Subject: [PATCH 08/22] Add copyright header --- kernel_perf_agent/README.md | 7 +------ .../kernel_opt/profiler/__init__.py | 18 ++++++++++++++++++ .../kernel_opt/profiler/ncu_profiler.py | 15 +++++++++++++-- .../opt_worker_component/profiling/__init__.py | 16 +++++++++++++++- .../profiling/kernel_profiler.py | 14 ++++++++++++++ .../profiling/ncu_wrapper_generator.py | 15 +++++++++++++++ 6 files changed, 76 insertions(+), 9 deletions(-) diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md index 5114402..e27dfec 100644 --- a/kernel_perf_agent/README.md +++ b/kernel_perf_agent/README.md @@ -1,8 +1,3 @@ # Kernel Performance Utilities -Low-level, reusable utilities for kernel optimization: -- NCU profiling -- RAG-based optimization database -- Performance analysis tools - -These are **library components** that can be used by any worker. +Low-level, reusable utilities for kernel optimization. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py index e69de29..d177194 100644 --- a/kernel_perf_agent/kernel_opt/profiler/__init__.py +++ b/kernel_perf_agent/kernel_opt/profiler/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""kernel_perf_agent package.""" + +# Kernel Perf Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 0b9168b..2a09353 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -1,5 +1,16 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ NCU Profiling Module for Triton Kernels diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py index cb28cc6..218ca6a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -1,6 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Profiling infrastructure for NCU-based kernel analysis.""" -from .ncu_wrapper_generator import NCUWrapperGenerator from .kernel_profiler import KernelProfiler +from .ncu_wrapper_generator import NCUWrapperGenerator __all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 27825ea..e1de43e 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Kernel profiling with NCU.""" import json diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py index 9f0e2e5..64fd1c1 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """NCU wrapper script generation for kernel profiling.""" import logging @@ -6,6 +20,7 @@ try: from jinja2 import Template + HAS_JINJA2 = True except ImportError: HAS_JINJA2 = False From cd297596bbdf9d4a99f5e610abcec89ae48a979c Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:52:31 -0800 Subject: [PATCH 09/22] fix ruff --- kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 2a09353..5362073 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -162,7 +162,7 @@ def profile_triton_kernel( str(benchmark_script), ] - print(f"[NCU] Running profiling...") + print("[NCU] Running profiling...") print(f"[NCU] Benchmark: {benchmark_script.name}") print(f"[NCU] Output: {csv_path}") print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") @@ -193,7 +193,7 @@ def profile_triton_kernel( f"NCU CSV file is too small ({csv_size} bytes), likely empty" ) - print(f"[NCU] ✓ Profiling completed successfully") + print("[NCU] ✓ Profiling completed successfully") print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") return csv_path From bbfa6cd02ab9e496b7330d745b47621ed62ed92a Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Mon, 12 Jan 2026 16:58:32 -0800 Subject: [PATCH 10/22] address previous comments --- .../kernel_opt/profiler/ncu_profiler.py | 152 +++++++---- .../profiling/__init__.py | 4 +- .../profiling/kernel_profiler.py | 241 +++++++++++------- ...er_generator.py => ncu_wrapper_factory.py} | 57 +---- 4 files changed, 265 insertions(+), 189 deletions(-) rename triton_kernel_agent/opt_worker_component/profiling/{ncu_wrapper_generator.py => ncu_wrapper_factory.py} (65%) diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 5362073..710d5c6 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -16,10 +16,11 @@ NCU Profiling Module for Triton Kernels This module wraps three tasks: -1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +1) Collect core metrics for Triton kernels with Nsight Compute into CSV (`profile_triton_kernel`). 2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). 3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + """ import json @@ -28,20 +29,14 @@ import shutil import subprocess import sys +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import pandas as pd - -__all__ = [ - "METRICS", - "METRIC_COLUMNS", - "profile_triton_kernel", - "load_ncu_metrics", - "metrics_to_prompt", -] +# This selection of metrics is inspried by the CudaForge team (https://github.com/OptimAI-Lab/CudaForge/tree/main) METRICS = ",".join( [ @@ -79,6 +74,14 @@ METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] +class MetricSelectionPolicy(Enum): + """Policy for selecting rows when multiple rows exist for a kernel.""" + + FIRST = "first" + LAST = "last" + MAX_CYCLES = "max_cycles" + + def profile_triton_kernel( benchmark_script: Path, workdir: Path, @@ -203,13 +206,83 @@ def profile_triton_kernel( raise RuntimeError(f"NCU profiling failed: {e}") +def _apply_selection_policy( + df: pd.DataFrame, + policy: MetricSelectionPolicy, +) -> pd.DataFrame: + """ + Apply selection policy to choose a single row from a DataFrame. + + Args: + df: DataFrame with one or more rows + policy: Selection policy to apply + + Returns: + DataFrame with a single row based on the policy + """ + if df.empty: + return df + + if len(df) == 1: + return df + + if policy == MetricSelectionPolicy.FIRST: + return df.iloc[[0]] + elif policy == MetricSelectionPolicy.LAST: + return df.iloc[[-1]] + elif policy == MetricSelectionPolicy.MAX_CYCLES: + if "sm__cycles_active.avg" in df.columns: + return df.sort_values("sm__cycles_active.avg", ascending=False).head(1) + # Fallback to last if cycles column not available + return df.iloc[[-1]] + else: + # Fallback to last for unknown policies + return df.iloc[[-1]] + + +def _filter_by_kernel_names( + df: pd.DataFrame, + name_list: Sequence[str], + policy: MetricSelectionPolicy, + keep_cols: List[str], +) -> pd.DataFrame: + """ + Filter DataFrame by kernel names with substring matching. + + Args: + df: DataFrame with NCU metrics + name_list: List of kernel name substrings to match + policy: Selection policy when multiple rows match + keep_cols: Columns to preserve in empty result + + Returns: + Filtered DataFrame with one row per matched kernel name + """ + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = df[ + df["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + + row = _apply_selection_policy(matched, policy) + results.append(row) + + if results: + return pd.concat(results, ignore_index=True) + else: + return pd.DataFrame(columns=keep_cols) + + def load_ncu_metrics( csv_path: Union[str, Path], columns: Optional[Sequence[str]] = None, extra_keep: Optional[Sequence[str]] = ("Kernel Name",), coerce_numeric: bool = True, name_list: Optional[Sequence[str]] = None, - select: str = "last", + select: Union[str, MetricSelectionPolicy] = MetricSelectionPolicy.LAST, ) -> pd.DataFrame: """ Load and parse NCU metrics from CSV file. @@ -220,20 +293,32 @@ def load_ncu_metrics( extra_keep: Additional columns to keep (e.g., "Kernel Name") coerce_numeric: Convert metric values to numeric name_list: Filter by kernel names (substring match) - select: Selection policy when multiple rows per name: - "first", "last", "max_cycles" + select: Selection policy when multiple rows per name. + Can be MetricSelectionPolicy enum or string ("first", "last", "max_cycles") Returns: DataFrame with parsed metrics Raises: FileNotFoundError: If CSV file not found - ValueError: If no requested columns found in CSV + ValueError: If no requested columns found in CSV or invalid select value """ csv_path = Path(csv_path) if not csv_path.exists(): raise FileNotFoundError(f"CSV not found: {csv_path}") + # Convert string to enum if needed + if isinstance(select, str): + try: + policy = MetricSelectionPolicy(select) + except ValueError: + raise ValueError( + f"Invalid select value: {select}. " + f"Must be one of: {[p.value for p in MetricSelectionPolicy]}" + ) + else: + policy = select + df = pd.read_csv(csv_path, comment="=", low_memory=False) metric_cols = list(columns) if columns is not None else METRIC_COLUMNS @@ -265,45 +350,10 @@ def load_ncu_metrics( # Filter by kernel name list if provided if name_list: - results = [] - for name in name_list: - # Use contains match instead of exact equality (for Triton's long kernel names) - matched = sub[ - sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) - ] - if matched.empty: - continue - if len(matched) > 1: - if select == "first": - row = matched.iloc[[0]] - elif select == "last": - row = matched.iloc[[-1]] - elif ( - select == "max_cycles" - and "sm__cycles_active.avg" in matched.columns - ): - row = matched.sort_values( - "sm__cycles_active.avg", ascending=False - ).head(1) - else: - row = matched.iloc[[-1]] # fallback - else: - row = matched - results.append(row) - - if results: - sub = pd.concat(results, ignore_index=True) - else: - sub = pd.DataFrame(columns=keep_cols) - elif select in ("first", "last", "max_cycles"): + sub = _filter_by_kernel_names(sub, name_list, policy, keep_cols) + else: # Apply selection to all rows if no name filter - if len(sub) > 0: - if select == "first": - sub = sub.iloc[[0]] - elif select == "last": - sub = sub.iloc[[-1]] - elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: - sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + sub = _apply_selection_policy(sub, policy) return sub diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py index 218ca6a..8f8d18b 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -15,6 +15,6 @@ """Profiling infrastructure for NCU-based kernel analysis.""" from .kernel_profiler import KernelProfiler -from .ncu_wrapper_generator import NCUWrapperGenerator +from .ncu_wrapper_factory import NCUWrapperFactory -__all__ = ["NCUWrapperGenerator", "KernelProfiler"] +__all__ = ["NCUWrapperFactory", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index e1de43e..bcc5a70 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Kernel profiling with NCU.""" +"""Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" import json import logging import subprocess import time +from dataclasses import asdict, dataclass from datetime import datetime +from functools import cached_property from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict import pandas as pd @@ -30,42 +32,84 @@ profile_triton_kernel, ) -from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_generator import ( - NCUWrapperGenerator, +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_factory import ( + NCUWrapperFactory, ) +# Default timeout for NCU profiling in seconds +DEFAULT_NCU_TIMEOUT_SECONDS = 120 + + +@dataclass +class ProfilerMetadata: + """Metadata about a profiling run.""" + + kernel_file: str + problem_file: str + round_num: int + timestamp: str + ncu_version: str | None + + +@dataclass +class ProfilerResults: + """ + Results from a kernel profiling run. + + This dataclass encapsulates both the metrics DataFrame and the parsed + metrics dictionary, along with metadata about the profiling run. + """ + + metrics_df: pd.DataFrame + metrics: Dict[str, Any] + metadata: ProfilerMetadata + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary for JSON serialization.""" + return { + "metrics": self.metrics, + "metadata": asdict(self.metadata), + } + + def to_json(self) -> str: + """Convert results to a JSON string.""" + return json.dumps(self.to_dict(), indent=2) + class KernelProfiler: - """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" def __init__( self, logger: logging.Logger, - temp_dir: Path, + artifacts_dir: Path, logs_dir: Path, - ncu_bin_path: Optional[str] = None, + ncu_bin_path: str | None = None, + ncu_timeout_seconds: int = DEFAULT_NCU_TIMEOUT_SECONDS, ): """ Initialize the kernel profiler. Args: logger: Logger instance - temp_dir: Temporary directory for profiling artifacts + artifacts_dir: Directory for optimization artifacts logs_dir: Directory for saving profiling logs ncu_bin_path: Path to NCU binary (auto-detect if None) + ncu_timeout_seconds: Timeout for NCU profiling in seconds """ self.logger = logger - self.temp_dir = temp_dir + self.artifacts_dir = artifacts_dir self.logs_dir = logs_dir self.ncu_bin_path = ncu_bin_path - self.wrapper_generator = NCUWrapperGenerator(logger) + self.ncu_timeout_seconds = ncu_timeout_seconds + self.wrapper_factory = NCUWrapperFactory(logger) - def _get_ncu_version(self) -> Optional[str]: + @cached_property + def ncu_version(self) -> str | None: """ - Get NCU version string. + NCU version string (cached). Returns: - NCU version string or None if failed + Version string like "2025.2.1.0" or None if unavailable """ try: ncu_cmd = self.ncu_bin_path or "ncu" @@ -76,21 +120,35 @@ def _get_ncu_version(self) -> Optional[str]: timeout=5, ) if result.returncode == 0: - # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") - version_line = result.stdout.strip().split("\n")[0] - return version_line + # Extract version from output + # Example: "Version 2025.2.1.0 (build 35987062) (public-release)" + for line in result.stdout.strip().split("\n"): + if line.startswith("Version "): + return line.split()[1] + return None return None except Exception as e: self.logger.debug(f"Failed to get NCU version: {e}") return None + def _wait_with_backoff(self, attempt: int) -> None: + """ + Wait with exponential backoff before retrying. + + Args: + attempt: Current attempt number (1-indexed) + """ + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + def profile_kernel( self, kernel_file: Path, problem_file: Path, round_num: int, - max_retries: int = 3, - ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + max_retries: int = 2, + ) -> ProfilerResults | None: """ Profile kernel with NCU (with retry logic). @@ -100,30 +158,28 @@ def profile_kernel( Args: kernel_file: Path to kernel file problem_file: Path to problem file - round_num: Current optimization round number - max_retries: Maximum number of retry attempts (default: 3) + round_num: Current optimization round number (used for file naming + and tracking which optimization iteration this profiling belongs to) + max_retries: Maximum number of retry attempts (default: 2) Returns: - Tuple of (metrics_df, metrics_json) or (None, None) on failure + ProfilerResults containing metrics and metadata, or None on failure Example: - >>> profiler = KernelProfiler(logger, temp_dir, logs_dir) - >>> metrics_df, metrics_json = profiler.profile_kernel( + >>> profiler = KernelProfiler(logger, artifacts_dir, logs_dir) + >>> results = profiler.profile_kernel( ... Path("kernel.py"), Path("problem.py"), round_num=1 ... ) - >>> if metrics_json: - ... print(f"DRAM throughput: {metrics_json['dram__throughput']}") + >>> if results: + ... print(f"DRAM throughput: {results.metrics['dram__throughput']}") """ - wrapper_file = None + # Create NCU wrapper script + wrapper_file = self.wrapper_factory.create_ncu_wrapper( + kernel_file, problem_file, self.artifacts_dir + ) for attempt in range(1, max_retries + 1): try: - # Create NCU wrapper script (cached if unchanged) - if wrapper_file is None: - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - self.logger.info( f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." ) @@ -132,98 +188,103 @@ def profile_kernel( csv_file = f"ncu_round_{round_num}.csv" csv_path = profile_triton_kernel( benchmark_script=wrapper_file, - workdir=self.temp_dir, + workdir=self.artifacts_dir, out_csv=csv_file, ncu_bin=self.ncu_bin_path, launch_count=20, - timeout=120, + timeout=self.ncu_timeout_seconds, ) # Load and parse metrics metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) + metrics = json.loads(metrics_to_prompt(metrics_df)) + + # Build ProfilerResults + results = ProfilerResults( + metrics_df=metrics_df, + metrics=metrics, + metadata=ProfilerMetadata( + kernel_file=str(kernel_file), + problem_file=str(problem_file), + round_num=round_num, + timestamp=datetime.utcnow().isoformat() + "Z", + ncu_version=self.ncu_version, + ), + ) # Save metrics with metadata - self._save_metrics_with_metadata( - metrics_json, kernel_file, problem_file, round_num - ) + self._save_profiler_results(results) self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json + return results except FileNotFoundError as e: self.logger.error(f"❌ File not found during profiling: {e}") - return None, None + return None except subprocess.TimeoutExpired: - self.logger.error( - f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" - ) - if attempt < max_retries: - wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(final attempt {attempt}/{max_retries})" + ) + return None + else: + self.logger.debug( + f"NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(attempt {attempt}/{max_retries})" + ) + self._wait_with_backoff(attempt) continue - return None, None except json.JSONDecodeError as e: - self.logger.error(f"❌ Failed to parse NCU metrics: {e}") - if attempt < max_retries: - wait_time = 2**attempt - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Failed to parse NCU metrics (final attempt): {e}" + ) + return None + else: + self.logger.debug( + f"Failed to parse NCU metrics (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) continue - return None, None except Exception as e: - self.logger.error( - f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", - exc_info=True, - ) - if attempt < max_retries: - wait_time = 2**attempt - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Unexpected error during profiling (final attempt): {e}", + exc_info=True, + ) + return None + else: + self.logger.debug( + f"Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) continue - return None, None # All retries exhausted self.logger.error( f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" ) - return None, None + return None - def _save_metrics_with_metadata( - self, - metrics_json: Dict[str, Any], - kernel_file: Path, - problem_file: Path, - round_num: int, - ) -> None: + def _save_profiler_results(self, results: ProfilerResults) -> None: """ - Save profiling metrics with additional metadata. + Save profiling results with metadata to a JSON file. Args: - metrics_json: NCU metrics as JSON dict - kernel_file: Path to kernel file - problem_file: Path to problem file - round_num: Current optimization round number + results: ProfilerResults to save """ - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - - # Build metadata - metadata = { - "metrics": metrics_json, - "metadata": { - "kernel_file": str(kernel_file), - "problem_file": str(problem_file), - "round_num": round_num, - "timestamp": datetime.utcnow().isoformat() + "Z", - "ncu_version": self._get_ncu_version(), - }, - } + metrics_file = ( + self.logs_dir / f"round{results.metadata.round_num:03d}_ncu_metrics.json" + ) with open(metrics_file, "w") as f: - json.dump(metadata, f, indent=2) + f.write(results.to_json()) self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py similarity index 65% rename from triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py rename to triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py index 64fd1c1..e3c5a3c 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py @@ -15,58 +15,38 @@ """NCU wrapper script generation for kernel profiling.""" import logging +from functools import cached_property from pathlib import Path -from typing import Optional -try: - from jinja2 import Template +from jinja2 import Template - HAS_JINJA2 = True -except ImportError: - HAS_JINJA2 = False - -class NCUWrapperGenerator: - """Generates NCU wrapper scripts for profiling Triton kernels.""" +class NCUWrapperFactory: + """Factory for creating NCU wrapper scripts for profiling Triton kernels.""" # Template file path (relative to this file) WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" def __init__(self, logger: logging.Logger): """ - Initialize the NCU wrapper generator. + Initialize the NCU wrapper factory. Args: logger: Logger instance """ self.logger = logger - self._template_cache: Optional[Template] = None - def _load_template(self) -> Template: + @cached_property + def template(self) -> Template: """ - Load the Jinja2 template (cached). - - Returns: - Jinja2 Template object + Jinja2 template for wrapper script generation. Raises: - ImportError: If Jinja2 is not installed FileNotFoundError: If template file doesn't exist """ - if self._template_cache is not None: - return self._template_cache - - if not HAS_JINJA2: - raise ImportError( - "Jinja2 is required for wrapper generation. " - "Install it with: pip install jinja2" - ) - if not self.WRAPPER_TEMPLATE.exists(): raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") - - self._template_cache = Template(self.WRAPPER_TEMPLATE.read_text()) - return self._template_cache + return Template(self.WRAPPER_TEMPLATE.read_text()) def create_ncu_wrapper( self, @@ -75,7 +55,6 @@ def create_ncu_wrapper( output_dir: Path, dtype_inference: bool = True, model_extraction: bool = True, - use_cache: bool = True, ) -> Path: """ Create NCU wrapper script for profiling. @@ -91,7 +70,6 @@ def create_ncu_wrapper( output_dir: Directory to write wrapper script dtype_inference: Enable automatic dtype inference from kernel source (default: True) model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) - use_cache: Reuse existing wrapper if files haven't changed (default: True) Returns: Path to created wrapper script @@ -112,21 +90,8 @@ def create_ncu_wrapper( wrapper_file = output_dir / "ncu_wrapper.py" - # Check cache: reuse wrapper if it's fresh - if use_cache and wrapper_file.exists(): - wrapper_mtime = wrapper_file.stat().st_mtime - kernel_mtime = kernel_file.stat().st_mtime - problem_mtime = problem_file.stat().st_mtime - - if wrapper_mtime > kernel_mtime and wrapper_mtime > problem_mtime: - self.logger.info( - f"Reusing cached NCU wrapper (fresher than source files): {wrapper_file}" - ) - return wrapper_file - - # Load template and render - template = self._load_template() - wrapper_content = template.render( + # Render template + wrapper_content = self.template.render( kernel_file_parent=repr(str(kernel_file.parent)), problem_file_parent=repr(str(problem_file.parent)), kernel_module=kernel_file.stem, From 543453a6803ce56a276b4a78540f0e0884b09e45 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Mon, 12 Jan 2026 17:06:54 -0800 Subject: [PATCH 11/22] fix ruff --- .../opt_worker_component/profiling/kernel_profiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index bcc5a70..682c4e7 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -77,7 +77,6 @@ def to_json(self) -> str: class KernelProfiler: - def __init__( self, logger: logging.Logger, From 706c9cc4f14b6fc230250a5e7200c52cc9d4445d Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 8 Jan 2026 08:20:42 -0800 Subject: [PATCH 12/22] Add unified benchmarking module for kernel performance measurement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidates previous kernel_benchmark.py and pytorch_benchmark.py into a streamlined 3-file architecture with clear separation of concerns: Architecture: - benchmark.py (299 lines): Main Benchmark class with simplified API - benchmark_kernel(): Always uses subprocess for crash protection - benchmark_pytorch(): Always uses direct mode for stable code - BenchmarkLockManager: GPU lock management for multi-worker scenarios - timing.py (437 lines): Complete timing infrastructure - Timing: time_with_cuda_events(), time_with_triton_do_bench() - Loading: prepare_pytorch_model(), load_kernel_function() - Stats: compute_timing_stats() with essential metrics (mean/std/min/max) - kernel_subprocess.py (442 lines): Subprocess runner for kernel isolation - Crash protection for potentially buggy kernels - Clean CUDA state between runs - Timeout handling Key improvements: - Eliminated string code generation (was generating Python as strings) - Removed unnecessary statistics (median, p25/p75/p95/p99) - Removed confusing use_subprocess parameter (behavior now deterministic) - Fixed dtype bug causing incorrect speedup measurements - Reduced from 5 files to 3 files with clearer naming - Code reduction: ~1,400 lines → 1,178 lines Simple API: bench = Benchmark(logger, temp_dir, lock, worker_id) pytorch_result = bench.benchmark_pytorch(problem_file) kernel_result = bench.benchmark_kernel(kernel_file, problem_file) speedup = pytorch_result['stats']['mean'] / kernel_result['time_ms'] --- .../benchmarking/__init__.py | 44 ++ .../benchmarking/benchmark.py | 299 ++++++++++++ .../benchmarking/kernel_subprocess.py | 442 ++++++++++++++++++ .../benchmarking/timing.py | 437 +++++++++++++++++ 4 files changed, 1222 insertions(+) create mode 100644 triton_kernel_agent/opt_worker_component/benchmarking/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py create mode 100644 triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py create mode 100644 triton_kernel_agent/opt_worker_component/benchmarking/timing.py diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py new file mode 100644 index 0000000..aa90643 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -0,0 +1,44 @@ +"""Benchmarking system for kernel performance measurement. + +Simplified structure with just 3 files: +- benchmark.py: Main Benchmark class and BenchmarkLockManager +- timing.py: All utilities (timing + model/kernel loading) +- kernel_subprocess.py: Subprocess runner for kernel isolation +""" + +# Core benchmarking +from .benchmark import Benchmark, BenchmarkLockManager + +# All utilities from timing module +from .timing import ( + # Model/kernel loading + CompilationError, + import_module, + load_kernel_function, + load_problem_interface, + prepare_inputs, + prepare_pytorch_model, + # Timing functions + clear_l2_cache, + compute_timing_stats, + time_with_cuda_events, + time_with_triton_do_bench, +) + +__all__ = [ + # Core benchmarking + "Benchmark", + "BenchmarkLockManager", + # Model/kernel loading + "CompilationError", + "import_module", + "load_kernel_function", + "load_problem_interface", + "prepare_inputs", + "prepare_pytorch_model", + # Timing utilities + "clear_l2_cache", + "compute_timing_stats", + "time_with_cuda_events", + "time_with_triton_do_bench", +] diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py new file mode 100644 index 0000000..4af02e4 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -0,0 +1,299 @@ +"""Unified benchmarking for Triton kernels and PyTorch baselines. + +This module consolidates kernel and PyTorch benchmarking with improved timing +utilities, L2 cache clearing, and comprehensive statistics. +""" + +import json +import logging +import subprocess +import sys +import traceback +from pathlib import Path +from typing import Any, Callable, Optional + +import torch + +from triton_kernel_agent.opt_worker_component.benchmarking.timing import ( + compute_timing_stats, + load_kernel_function, + prepare_pytorch_model, + time_with_cuda_events, + time_with_triton_do_bench, +) + + +class BenchmarkLockManager: + """Manages GPU benchmarking locks to prevent resource contention.""" + + def __init__(self, lock: Optional[Any], worker_id: int, logger: logging.Logger): + """Initialize the lock manager. + + Args: + lock: Shared multiprocessing lock (or None if no locking needed) + worker_id: Worker ID for logging + logger: Logger instance + """ + self.lock = lock + self.worker_id = worker_id + self.logger = logger + + def __enter__(self): + """Acquire the benchmarking lock.""" + if self.lock: + self.logger.info( + f"⏳ Waiting for benchmark lock (worker {self.worker_id})..." + ) + self.lock.acquire() + self.logger.info(f"🔓 Acquired benchmark lock (worker {self.worker_id})") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release the benchmarking lock.""" + if self.lock: + try: + self.lock.release() + self.logger.info( + f"🔒 Released benchmark lock (worker {self.worker_id})" + ) + except Exception as e: + self.logger.warning(f"Failed to release benchmark lock: {e}") + return False + + +class Benchmark: + """Unified benchmark for Triton kernels and PyTorch baselines. + + Supports two modes: + 1. Subprocess mode: Runs benchmarks in isolated processes (for compatibility) + 2. Direct mode: Uses in-process timing utilities (faster, more flexible) + """ + + def __init__( + self, + logger: logging.Logger, + temp_dir: Path, + benchmark_lock: Optional[Any] = None, + worker_id: int = 0, + warmup: int = 25, + repeat: int = 100, + timing_method: str = "cuda_event", + ): + """Initialize the benchmark. + + Args: + logger: Logger instance + temp_dir: Temporary directory for benchmark artifacts + benchmark_lock: Shared lock to serialize GPU benchmarking + worker_id: Worker ID + warmup: Number of warmup iterations (or warmup time in ms for do_bench) + repeat: Number of repeat iterations (or rep time in ms for do_bench) + timing_method: Timing method ("cuda_event", "do_bench", "host_time") + """ + self.logger = logger + self.temp_dir = temp_dir + self.lock_manager = BenchmarkLockManager(benchmark_lock, worker_id, logger) + self.warmup = warmup + self.repeat = repeat + self.timing_method = timing_method + + def benchmark_kernel( + self, + kernel_file: Path, + problem_file: Path, + baseline_file: Optional[Path] = None, + ) -> dict[str, Any]: + """Benchmark Triton kernel performance using subprocess isolation. + + Always uses subprocess for crash protection of potentially buggy kernels. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + baseline_file: Path to baseline kernel (optional) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time (for backward compatibility) + - speedup: Speedup vs baseline + """ + return self._benchmark_kernel_subprocess( + kernel_file, problem_file, baseline_file + ) + + def _benchmark_kernel_subprocess( + self, + kernel_file: Path, + problem_file: Path, + baseline_file: Optional[Path] = None, + ) -> dict[str, Any]: + """Benchmark kernel using subprocess (existing workflow).""" + try: + with self.lock_manager: + results_json = self.temp_dir / "benchmark_results.json" + + # Use kernel_subprocess.py for subprocess isolation + benchmark_script = Path(__file__).parent / "kernel_subprocess.py" + + cmd = [ + sys.executable, + str(benchmark_script), + "--problem", + str(problem_file), + "--kernel", + str(kernel_file), + "--warmup", + str(self.warmup), + "--repeat", + str(self.repeat), + "--json", + str(results_json), + "--quiet", + ] + + if baseline_file: + cmd.extend(["--baseline"]) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, + ) + + if result.returncode != 0: + error_msg = ( + result.stderr.strip() + or result.stdout.strip() + or "Unknown error" + ) + self.logger.error(f"Kernel benchmark failed: {error_msg}") + return {"time_ms": float("inf"), "speedup": 0.0} + + with open(results_json, "r") as f: + results = json.load(f) + + kernel_name = kernel_file.stem + kernel_results = results.get("kernels", {}).get(kernel_name, {}) + + return { + "time_ms": kernel_results.get("time_ms", float("inf")), + "speedup": kernel_results.get("speedup", 1.0), + } + + except Exception as e: + self.logger.error(f"Kernel benchmark failed: {e}") + return {"time_ms": float("inf"), "speedup": 0.0} + + def benchmark_pytorch( + self, + problem_file: Path, + dtype: Optional[torch.dtype] = None, + ) -> dict[str, Any]: + """Benchmark PyTorch baseline using direct in-process timing. + + Always uses direct mode (PyTorch is stable, doesn't need subprocess isolation). + + Args: + problem_file: Path to problem file (must define Model class and get_inputs()) + dtype: Data type to use (default: auto-detect based on model parameters) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - stats: Full timing statistics (mean, std, min, max, all_times, etc.) + """ + try: + with self.lock_manager: + # Load model and inputs + model, inputs = prepare_pytorch_model( + problem_file=problem_file, + device="cuda", + dtype=dtype, + ) + + # Time using configured method + if self.timing_method == "do_bench": + times = time_with_triton_do_bench( + lambda: model(*inputs), + [], + warmup=self.warmup, + rep=self.repeat, + verbose=False, + ) + else: # cuda_event + times = time_with_cuda_events( + lambda: model(*inputs), + [], + num_warmup=self.warmup, + num_trials=self.repeat, + clear_cache=True, + verbose=False, + ) + + stats = compute_timing_stats(times) + + return { + "time_ms": stats["mean"], + "stats": stats, + } + + except Exception as e: + self.logger.error(f"PyTorch baseline benchmark failed: {e}") + self.logger.error(traceback.format_exc()) + return {"time_ms": float("inf")} + + def benchmark_function( + self, + fn: Callable, + args: list[Any], + name: str = "function", + ) -> dict[str, Any]: + """Benchmark an arbitrary function with lock protection. + + Args: + fn: Function to benchmark + args: Arguments to pass to function + name: Name for logging + + Returns: + Dictionary with timing statistics + """ + try: + with self.lock_manager: + self.logger.info(f"Benchmarking {name}...") + + # Time using configured method + if self.timing_method == "do_bench": + times = time_with_triton_do_bench( + fn, + args, + warmup=self.warmup, + rep=self.repeat, + verbose=False, + ) + else: # cuda_event + times = time_with_cuda_events( + fn, + args, + num_warmup=self.warmup, + num_trials=self.repeat, + clear_cache=True, + verbose=False, + ) + + stats = compute_timing_stats(times) + + self.logger.info( + f"{name}: {stats['mean']:.3f} ± {stats['std']:.3f} ms " + f"(min={stats['min']:.3f}, max={stats['max']:.3f})" + ) + + return { + "time_ms": stats["mean"], + "stats": stats, + } + + except Exception as e: + self.logger.error(f"Benchmark failed for {name}: {e}") + return {"time_ms": float("inf")} diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py new file mode 100644 index 0000000..9cfd947 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -0,0 +1,442 @@ +""" +kernel_subprocess.py + +A task-agnostic, profiling-only benchmark script for Triton kernels. +This script ONLY benchmarks candidate kernels without correctness checks. +Assumes correctness has been verified upstream. + +Design: +- Skips correctness verification (assumes already verified) +- Only runs candidate kernels (not reference) +- Fast profiling for iterative optimization loops +- Uses shared utilities from timing.py (no code duplication) + +Usage: + python kernel_subprocess.py \\ + --problem problem.py \\ + --kernel candidate_kernel.py \\ + --warmup 25 --repeat 100 \\ + --device cuda --dtype bfloat16 +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +# Import shared utilities from timing module (avoid duplication) +from timing import ( + CompilationError, + import_module, + load_kernel_function, + load_problem_interface, + prepare_inputs, +) +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import triton.testing as tt + + +def _run_once( + fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, name: str +) -> torch.Tensor: + """Run kernel once to get output shape/dtype info. + + Args: + fn: Kernel function + inputs: Input tensors + init_inputs: Initialization inputs (e.g., features, eps) + name: Name for logging + + Returns: + Output tensor + + Raises: + Exception if kernel fails to run + """ + try: + with torch.inference_mode(): + out = fn(*inputs, *init_inputs) + return out + except Exception as exc: + raise RuntimeError(f"{name} failed to execute: {exc}") from exc + + +def benchmark( + fn: Callable, + inputs: Tuple[torch.Tensor, ...], + init_inputs: list, + name: str, + warmup: int = 25, + rep: int = 100, +) -> float: + """Benchmark a kernel function using triton.testing.do_bench. + + Args: + fn: Kernel function to benchmark + inputs: Input tensors + init_inputs: Initialization inputs (e.g., features, eps) + name: Name for logging + warmup: Number of warmup iterations + rep: Number of measurement iterations + + Returns: + Mean latency in milliseconds + """ + try: + ms = tt.do_bench( + lambda: fn(*inputs, *init_inputs), + warmup=warmup, + rep=rep, + return_mode="mean", + ) + print(f"{name}: {ms:.4f} ms (mean over {rep} runs)") + return ms + except Exception as exc: + print(f"❌ {name}: Benchmark failed: {exc}") + return float("inf") + + +def main(): + parser = argparse.ArgumentParser( + description="Task-agnostic Triton kernel benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # File paths + parser.add_argument( + "--problem", + type=Path, + required=True, + help="Path to problem file (must define Model and get_inputs)", + ) + parser.add_argument( + "--kernel", + type=Path, + required=True, + help="Path to kernel file (must define kernel_function)", + ) + parser.add_argument( + "--baseline", + action="store_true", + help="Include PyTorch reference model in benchmark", + ) + + # Benchmark configuration + parser.add_argument("--warmup", type=int, default=25) + parser.add_argument("--repeat", type=int, default=100) + + # Problem configuration + parser.add_argument("--size", type=int, default=4096, help="Problem size N") + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help="Device to use", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + help="Data type", + ) + + # Output options + parser.add_argument("--json", type=Path, help="Save results to JSON file") + parser.add_argument( + "--quiet", + action="store_true", + help="Suppress non-essential output", + ) + + args = parser.parse_args() + + # Resolve paths + args.problem = args.problem.resolve() + args.kernel = args.kernel.resolve() + + # Setup device and dtype + device = torch.device(args.device) + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + dtype = dtype_map[args.dtype] + + if not args.quiet: + print("=" * 80) + print("TRITON KERNEL PROFILING") + print("=" * 80) + print(f"Problem: {args.problem.name}") + print(f"Size: {args.size}") + print(f"Device: {device}") + print(f"Dtype: {dtype}") + print(f"Warmup: {args.warmup}, Repeat: {args.repeat}") + print() + + # Import problem module using shared utility + try: + Model, get_inputs, get_init_inputs = load_problem_interface(args.problem) + except Exception as exc: + print(f"❌ Failed to import problem file: {exc}") + sys.exit(1) + + # Check for optional benchmark config + try: + problem_mod = import_module(args.problem, "problem") + get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) + except: + get_benchmark_config = None + + # Override benchmark config if provided by problem + if get_benchmark_config is not None: + config = get_benchmark_config() + args.warmup = config.get("warmup", args.warmup) + args.repeat = config.get("repeat", args.repeat) + if not args.quiet: + cfg_msg = ( + f"Using problem-specific config: " + f"warmup={args.warmup}, repeat={args.repeat}" + ) + print(cfg_msg) + + # Generate inputs using shared utility + try: + inputs = prepare_inputs(get_inputs, device=device, dtype=dtype) + + # Get initialization inputs (e.g., features, eps for RMSNorm) + init_inputs = [] + if get_init_inputs is not None: + init_inputs = get_init_inputs() + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + except Exception as exc: + print(f"❌ Failed to generate inputs: {exc}") + sys.exit(1) + + # Create reference model only if baseline is requested + model = None + if args.baseline: + try: + # Initialize model with init_inputs if provided + if init_inputs: + model = Model(*init_inputs).to(device=device, dtype=dtype) + else: + model = Model().to(device=device, dtype=dtype) + model.eval() + # Run once to get output shape + out = _run_once(model, inputs, [], "Reference") + if not args.quiet: + print(f"Reference output shape: {out.shape}, dtype: {out.dtype}") + print() + except Exception as exc: + print(f"❌ Failed to create reference model: {exc}") + sys.exit(1) + + # Results tracking + results: Dict[str, Any] = { + "problem": str(args.problem), + "size": args.size, + "device": str(device), + "dtype": str(dtype), + "warmup": args.warmup, + "repeat": args.repeat, + "kernels": {}, + } + + baseline_time = None + + # Benchmark PyTorch baseline if requested + if args.baseline and model is not None: + if not args.quiet: + print("1. PyTorch Reference") + baseline_time = benchmark( + model, inputs, [], "PyTorch", args.warmup, args.repeat + ) + results["kernels"]["pytorch_reference"] = { + "time_ms": baseline_time, + "speedup": 1.0, + } + if not args.quiet: + print() + + # Benchmark candidate kernel + kernel_name = args.kernel.stem + + if not args.quiet: + idx = 2 if args.baseline else 1 + print(f"{idx}. Candidate: {kernel_name}") + + # Import kernel using shared utility + try: + kernel_function = load_kernel_function(args.kernel) + except Exception as exc: + print(f"❌ Failed to import {kernel_name}: {exc}") + results["kernels"][kernel_name] = { + "time_ms": float("inf"), + "error": str(exc), + } + if args.json: + args.json.parent.mkdir(parents=True, exist_ok=True) + with open(args.json, "w") as f: + json.dump(results, f, indent=2) + sys.exit(1) + + # Check if kernel expects weight/bias parameters (e.g., Conv, Linear) + # If so, extract them from a Model instance + import inspect + + needs_model = False + try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if "weight" in params: + needs_model = True + except: + pass + + # Prepare kernel arguments + kernel_args = inputs + kernel_init_args = init_inputs + + if needs_model and Model is not None: + try: + # Initialize model to extract weight and bias + if init_inputs: + extract_model = Model(*init_inputs).to(device=device, dtype=dtype) + else: + extract_model = Model().to(device=device, dtype=dtype) + + # Extract weight and bias from model layer + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in [ + "conv1", + "conv2", + "conv3", + "conv1d", + "conv2d", + "conv", + "conv3d", + "linear", + "fc", + ]: + if hasattr(extract_model, attr_name): + layer = getattr(extract_model, attr_name) + if hasattr(layer, "weight"): + weight = layer.weight + bias = layer.bias if hasattr(layer, "bias") else None + break + + if weight is not None and layer is not None: + # Build kwargs for kernel_function + kernel_kwargs = {} + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, "stride"): + stride = ( + layer.stride[0] + if isinstance(layer.stride, (tuple, list)) + else layer.stride + ) + kernel_kwargs["stride"] = stride + if hasattr(layer, "padding"): + padding = ( + layer.padding[0] + if isinstance(layer.padding, (tuple, list)) + else layer.padding + ) + kernel_kwargs["padding"] = padding + if hasattr(layer, "dilation"): + dilation = ( + layer.dilation[0] + if isinstance(layer.dilation, (tuple, list)) + else layer.dilation + ) + kernel_kwargs["dilation"] = dilation + if hasattr(layer, "groups"): + kernel_kwargs["groups"] = layer.groups + + # Capture original kernel function to avoid recursion + original_kernel_function = kernel_function + + # Prepare wrapper function that passes weight/bias + def kernel_with_model(*args, **kwargs): + return original_kernel_function( + args[0], weight, bias, **kernel_kwargs + ) + + # Update kernel function and clear init_inputs (already handled) + kernel_function = kernel_with_model + kernel_init_args = [] + except Exception as exc: + if not args.quiet: + print(f"⚠️ Warning: Failed to extract model parameters: {exc}") + print(f" Falling back to direct kernel invocation") + + # Run once to verify it executes + try: + out = _run_once(kernel_function, kernel_args, kernel_init_args, kernel_name) + if not args.quiet: + print(f"✓ {kernel_name} executes successfully") + print(f" Output shape: {out.shape}, dtype: {out.dtype}") + except Exception as exc: + print(f"❌ {kernel_name} failed: {exc}") + results["kernels"][kernel_name] = { + "time_ms": float("inf"), + "error": str(exc), + } + if args.json: + args.json.parent.mkdir(parents=True, exist_ok=True) + with open(args.json, "w") as f: + json.dump(results, f, indent=2) + sys.exit(1) + + # Benchmark + kernel_time = benchmark( + kernel_function, + kernel_args, + kernel_init_args, + kernel_name, + args.warmup, + args.repeat, + ) + + results["kernels"][kernel_name] = { + "time_ms": kernel_time, + "path": str(args.kernel), + } + + if baseline_time is not None and kernel_time != float("inf"): + speedup = baseline_time / kernel_time + results["kernels"][kernel_name]["speedup"] = speedup + if not args.quiet: + print(f"Speedup vs PyTorch: {speedup:.2f}x") + + # Save JSON results + if args.json: + args.json.parent.mkdir(parents=True, exist_ok=True) + with open(args.json, "w") as f: + json.dump(results, f, indent=2) + print(f"Results saved to: {args.json}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nBenchmark interrupted") + sys.exit(130) + except Exception as exc: + print(f"❌ Unexpected error: {exc}") + sys.exit(1) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py new file mode 100644 index 0000000..88fa7cc --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -0,0 +1,437 @@ +"""Core timing and model loading utilities for kernel benchmarking. + +This module consolidates: +- Timing functions (CUDA events, do_bench, host timing) +- Model/kernel loading utilities +- Statistics computation + +Inspired by KernelBench's timing.py with improvements for production use. +""" + +import hashlib +import importlib.util +import sys +import time +from pathlib import Path +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch + + +# ============================================================================= +# Model and Kernel Loading Utilities +# ============================================================================= + + +class CompilationError(RuntimeError): + """Raised when a kernel or problem file fails to compile/import.""" + + pass + + +def import_module(path: Path, module_name: Optional[str] = None): + """Dynamically import a Python file. + + Args: + path: Path to the Python file + module_name: Optional name for the module (auto-generated if None) + + Returns: + The imported module + + Raises: + FileNotFoundError: If path doesn't exist + CompilationError: If import fails + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if module_name is None: + module_name = f"mod_{hashlib.md5(str(path).encode()).hexdigest()}" + + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise CompilationError(f"Failed to create spec for {path}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + spec.loader.exec_module(module) + except Exception as exc: + raise CompilationError(f"Failed to import {path}: {exc}") from exc + + return module + + +def load_problem_interface( + problem_file: Path, +) -> Tuple[type, Callable, Optional[Callable]]: + """Load the standard problem interface from a problem file. + + Args: + problem_file: Path to problem file + + Returns: + Tuple of (Model class, get_inputs function, get_init_inputs function) + + Raises: + CompilationError: If problem file doesn't define required interface + """ + module = import_module(problem_file, "problem") + + Model = getattr(module, "Model", None) + get_inputs = getattr(module, "get_inputs", None) + get_init_inputs = getattr(module, "get_init_inputs", None) + + if Model is None: + raise CompilationError("Problem file must define 'Model' class") + if get_inputs is None: + raise CompilationError("Problem file must define 'get_inputs()' function") + + return Model, get_inputs, get_init_inputs + + +def prepare_inputs( + get_inputs: Callable, + device: torch.device | str = "cuda", + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[torch.Tensor, ...]: + """Prepare inputs by converting to target device and dtype. + + Args: + get_inputs: Function that returns inputs + device: Target device + dtype: Target dtype for floating-point tensors + + Returns: + Tuple of prepared inputs + """ + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Convert inputs to target device and dtype + # IMPORTANT: Only convert floating-point tensors; preserve integer/bool tensors + converted_inputs = [] + for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to device + inp = inp.to(device=device) + # Convert dtype ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if inp.is_floating_point(): + inp = inp.to(dtype=dtype) + converted_inputs.append(inp) + + return tuple(converted_inputs) + + +def prepare_pytorch_model( + problem_file: Path, + device: torch.device | str = "cuda", + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]]: + """Prepare PyTorch model and inputs for benchmarking. + + This handles the full workflow: + 1. Load problem interface (Model, get_inputs, get_init_inputs) + 2. Initialize model with init inputs + 3. Move model to device + 4. Handle dtype conversion based on whether model has parameters + + Args: + problem_file: Path to problem file + device: Target device + dtype: Target dtype (auto-detected if None) + + Returns: + Tuple of (model, inputs) ready for benchmarking + """ + # Load problem interface + Model, get_inputs, get_init_inputs = load_problem_interface(problem_file) + + # Get initialization inputs (e.g., features, eps for RMSNorm) + init_inputs = [] + if get_init_inputs is not None: + init_inputs = get_init_inputs() + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + # Initialize model + if init_inputs: + model = Model(*init_inputs) + else: + model = Model() + + # Move model to CUDA + model = model.cuda() + + # Check if model has trainable parameters + has_parameters = any(p.numel() > 0 for p in model.parameters()) + + # Get inputs + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Determine target dtype + # Default to bfloat16 unless explicitly specified or model is a loss function + target_dtype = dtype or torch.bfloat16 + + # Check if this is actually a loss function (not just a model without parameters) + is_loss_function = isinstance( + model, + ( + torch.nn.CrossEntropyLoss, + torch.nn.MSELoss, + torch.nn.BCELoss, + torch.nn.BCEWithLogitsLoss, + torch.nn.NLLLoss, + torch.nn.L1Loss, + torch.nn.SmoothL1Loss, + ), + ) + + # Handle dtype conversion based on model type + if has_parameters or not is_loss_function: + # Models with parameters (Conv, Linear, etc.) OR compute operations (matmul, etc.) + # → use bfloat16 (or user-specified dtype) + if has_parameters: + model = model.to(target_dtype) + inputs = [ + ( + inp.cuda().to(target_dtype) + if isinstance(inp, torch.Tensor) and inp.is_floating_point() + else inp.cuda() if isinstance(inp, torch.Tensor) else inp + ) + for inp in inputs + ] + else: + # Loss functions (no parameters) → use float32 for compatibility + # PyTorch cross_entropy doesn't support bf16 on CUDA + processed_inputs = [] + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor): + if i == 0 and inp.is_floating_point(): + # First input (predictions) - convert to float32 for compatibility + processed_inputs.append(inp.cuda().to(torch.float32)) + else: + # Other inputs (like targets) - just move to CUDA, preserve dtype + processed_inputs.append(inp.cuda()) + else: + processed_inputs.append(inp) + inputs = processed_inputs + + return model, tuple(inputs) + + +def load_kernel_function(kernel_file: Path) -> Callable: + """Load kernel_function from a kernel file. + + Args: + kernel_file: Path to kernel file + + Returns: + The kernel_function callable + + Raises: + CompilationError: If kernel file doesn't define kernel_function + """ + module = import_module(kernel_file, "kernel") + + kernel_function = getattr(module, "kernel_function", None) + if kernel_function is None: + raise CompilationError( + f"Kernel file {kernel_file.name} must define 'kernel_function'" + ) + + return kernel_function + + +# ============================================================================= +# Timing Utilities +# ============================================================================= + + +def clear_l2_cache(device: torch.device | str = "cuda") -> None: + """Clear L2 cache by thrashing with a large tensor. + + This ensures we measure cold cache performance, which is more representative + of real-world scenarios where data isn't already cached. + + Reference: KernelBench timing.py + L2 cache sizes: A100=40MB, H100=50MB, H200=90MB, RTX4090=72MB, L40S=48MB + We overwrite >256MB to fully thrash L2 cache. + + Args: + device: CUDA device to use + """ + # 32 * 1024 * 1024 * 8B = 256MB - enough to thrash most GPU L2 caches + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device=device) + dummy.fill_(42) # Write to tensor to ensure cache thrashing + del dummy + + +def time_with_cuda_events( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 3, + num_trials: int = 10, + clear_cache: bool = True, + discard_first: int = 0, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time a CUDA kernel using CUDA events for accurate device-side timing. + + This measures actual GPU execution time without host-side overhead. + Each trial clears L2 cache to measure cold-cache performance. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + num_trials: Number of timing trials + clear_cache: Whether to clear L2 cache between trials + discard_first: Number of initial trials to discard + verbose: Print per-trial timing info + device: CUDA device to use (None = current device) + + Returns: + List of elapsed times in milliseconds (length = num_trials) + """ + if device is None: + device = torch.cuda.current_device() + + with torch.cuda.device(device): + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + torch.cuda.synchronize(device=device) + + torch.cuda.empty_cache() + + if verbose: + print( + f"[Timing] Device: {torch.cuda.get_device_name(device)}, " + f"warmup={num_warmup}, trials={num_trials}" + ) + + elapsed_times: list[float] = [] + + # Timing trials + for trial in range(num_trials + discard_first): + torch.cuda.synchronize(device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if clear_cache: + clear_l2_cache(device=device) + + start_event.record() + kernel_fn(*args) + end_event.record() + + torch.cuda.synchronize(device=device) + elapsed_time_ms = start_event.elapsed_time(end_event) + + if trial >= discard_first: + if verbose: + print( + f" Trial {trial - discard_first + 1}: {elapsed_time_ms:.3f} ms" + ) + elapsed_times.append(elapsed_time_ms) + + return elapsed_times + + +def time_with_triton_do_bench( + kernel_fn: Callable, + args: list[Any], + warmup: int = 25, + rep: int = 100, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time using Triton's do_bench with adaptive trial count. + + Triton's do_bench automatically determines the number of trials based on + warmup/rep time budgets. This is convenient but gives less control. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + warmup: Warmup time budget in milliseconds + rep: Repetition time budget in milliseconds + verbose: Print timing info + device: CUDA device to use + + Returns: + List of all trial times in milliseconds + """ + if device is None: + device = torch.cuda.current_device() + + try: + from triton import testing as triton_testing + except ImportError: + raise ImportError("Triton is required for time_with_triton_do_bench") + + with torch.cuda.device(device): + if verbose: + print( + f"[Timing] Using triton.do_bench on {torch.cuda.get_device_name(device)}" + ) + + wrapped_fn = lambda: kernel_fn(*args) + times = triton_testing.do_bench( + fn=wrapped_fn, + warmup=warmup, + rep=rep, + grad_to_none=None, + quantiles=None, + return_mode="all", + ) + + return times + + +def compute_timing_stats( + elapsed_times: list[float], + device: Optional[torch.device | str] = None, +) -> dict[str, Any]: + """Compute essential timing statistics. + + Args: + elapsed_times: List of elapsed times in milliseconds + device: CUDA device (for recording hardware info) + + Returns: + Dictionary with timing statistics: + - mean: Mean time in ms + - std: Standard deviation in ms + - min: Minimum time in ms + - max: Maximum time in ms + - num_trials: Number of trials + - all_times: All trial times + - hardware: GPU name (if device provided) + """ + times_array = np.array(elapsed_times) + + stats = { + "mean": float(np.mean(times_array)), + "std": float(np.std(times_array)), + "min": float(np.min(times_array)), + "max": float(np.max(times_array)), + "num_trials": len(elapsed_times), + "all_times": [float(t) for t in elapsed_times], + } + + if device is not None: + stats["hardware"] = torch.cuda.get_device_name(device=device) + stats["device"] = str(device) + + return stats From 4febdd65c413869050e824147dba83bb6e5feee1 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 8 Jan 2026 08:29:15 -0800 Subject: [PATCH 13/22] Introducing benchmarking infra for kernel performance --- .../benchmarking/__init__.py | 26 ++++++--- .../benchmarking/benchmark.py | 14 +++++ .../benchmarking/kernel_subprocess.py | 57 ++++++++++--------- .../benchmarking/timing.py | 31 +++++----- 4 files changed, 79 insertions(+), 49 deletions(-) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py index aa90643..b6a761c 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -1,9 +1,19 @@ -"""Benchmarking system for kernel performance measurement. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarking infrastructure for kernel performance measurement. -Simplified structure with just 3 files: -- benchmark.py: Main Benchmark class and BenchmarkLockManager -- timing.py: All utilities (timing + model/kernel loading) -- kernel_subprocess.py: Subprocess runner for kernel isolation """ # Core benchmarking @@ -11,16 +21,16 @@ # All utilities from timing module from .timing import ( + # Timing functions + clear_l2_cache, # Model/kernel loading CompilationError, + compute_timing_stats, import_module, load_kernel_function, load_problem_interface, prepare_inputs, prepare_pytorch_model, - # Timing functions - clear_l2_cache, - compute_timing_stats, time_with_cuda_events, time_with_triton_do_bench, ) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py index 4af02e4..adcf0ae 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unified benchmarking for Triton kernels and PyTorch baselines. This module consolidates kernel and PyTorch benchmarking with improved timing diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py index 9cfd947..5c8c1e9 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -1,22 +1,27 @@ -""" -kernel_subprocess.py +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" A task-agnostic, profiling-only benchmark script for Triton kernels. This script ONLY benchmarks candidate kernels without correctness checks. Assumes correctness has been verified upstream. Design: - Skips correctness verification (assumes already verified) -- Only runs candidate kernels (not reference) +- Only runs candidate kernels - Fast profiling for iterative optimization loops -- Uses shared utilities from timing.py (no code duplication) - -Usage: - python kernel_subprocess.py \\ - --problem problem.py \\ - --kernel candidate_kernel.py \\ - --warmup 25 --repeat 100 \\ - --device cuda --dtype bfloat16 +- Uses shared utilities from timing.py """ from __future__ import annotations @@ -320,22 +325,20 @@ def main(): weight = None bias = None layer = None - for attr_name in [ - "conv1", - "conv2", - "conv3", - "conv1d", - "conv2d", - "conv", - "conv3d", - "linear", - "fc", - ]: - if hasattr(extract_model, attr_name): - layer = getattr(extract_model, attr_name) - if hasattr(layer, "weight"): - weight = layer.weight - bias = layer.bias if hasattr(layer, "bias") else None + for name, module in extract_model.named_modules(): + if isinstance( + module, + ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.Linear, + ), + ): + if hasattr(module, "weight") and module.weight is not None: + layer = module + weight = module.weight + bias = getattr(module, "bias", None) break if weight is not None and layer is not None: diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py index 88fa7cc..f096e3c 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Core timing and model loading utilities for kernel benchmarking. This module consolidates: @@ -5,7 +19,7 @@ - Model/kernel loading utilities - Statistics computation -Inspired by KernelBench's timing.py with improvements for production use. +Inspired by KernelBench's timing.py """ import hashlib @@ -180,19 +194,8 @@ def prepare_pytorch_model( # Default to bfloat16 unless explicitly specified or model is a loss function target_dtype = dtype or torch.bfloat16 - # Check if this is actually a loss function (not just a model without parameters) - is_loss_function = isinstance( - model, - ( - torch.nn.CrossEntropyLoss, - torch.nn.MSELoss, - torch.nn.BCELoss, - torch.nn.BCEWithLogitsLoss, - torch.nn.NLLLoss, - torch.nn.L1Loss, - torch.nn.SmoothL1Loss, - ), - ) + # Check if this is actually a loss function + is_loss_function = isinstance(model, torch.nn.modules.loss._Loss) # Handle dtype conversion based on model type if has_parameters or not is_loss_function: From d92a7b7740601bef9f532cfd02d195c3216e1df5 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Fri, 9 Jan 2026 11:41:51 -0800 Subject: [PATCH 14/22] fix ruff --- .../benchmarking/benchmark.py | 1 - .../benchmarking/kernel_subprocess.py | 9 ++-- .../benchmarking/timing.py | 42 ++++++++++++++++++- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py index adcf0ae..707632b 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -30,7 +30,6 @@ from triton_kernel_agent.opt_worker_component.benchmarking.timing import ( compute_timing_stats, - load_kernel_function, prepare_pytorch_model, time_with_cuda_events, time_with_triton_do_bench, diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py index 5c8c1e9..015db24 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -33,13 +33,12 @@ # Import shared utilities from timing module (avoid duplication) from timing import ( - CompilationError, import_module, load_kernel_function, load_problem_interface, prepare_inputs, ) -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Tuple import torch import triton.testing as tt @@ -197,7 +196,7 @@ def main(): try: problem_mod = import_module(args.problem, "problem") get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) - except: + except Exception: get_benchmark_config = None # Override benchmark config if provided by problem @@ -305,7 +304,7 @@ def main(): # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) if "weight" in params: needs_model = True - except: + except Exception: pass # Prepare kernel arguments @@ -385,7 +384,7 @@ def kernel_with_model(*args, **kwargs): except Exception as exc: if not args.quiet: print(f"⚠️ Warning: Failed to extract model parameters: {exc}") - print(f" Falling back to direct kernel invocation") + print(" Falling back to direct kernel invocation") # Run once to verify it executes try: diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py index f096e3c..223e60d 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -25,7 +25,6 @@ import hashlib import importlib.util import sys -import time from pathlib import Path from typing import Any, Callable, Optional, Tuple @@ -351,6 +350,43 @@ def time_with_cuda_events( return elapsed_times +def time_with_inductor_benchmarker( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 25, + verbose: bool = False, +) -> float: + """Time using PyTorch Inductor's benchmarker (simplest approach). + + This is a thin wrapper around torch._inductor.runtime.benchmarking.benchmarker, + which handles CUDA synchronization and timing internally. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + verbose: Print timing info + + Returns: + Elapsed time in milliseconds (single value, not a list) + + Note: + This uses a private PyTorch API (_inductor) which may change without notice. + """ + from torch._inductor.runtime.benchmarking import benchmarker + + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + + ms = benchmarker.benchmark_gpu(lambda: kernel_fn(*args)) + + if verbose: + print(f"[Timing] Inductor benchmarker: {ms:.4f} ms") + + return ms + + def time_with_triton_do_bench( kernel_fn: Callable, args: list[Any], @@ -389,7 +425,9 @@ def time_with_triton_do_bench( f"[Timing] Using triton.do_bench on {torch.cuda.get_device_name(device)}" ) - wrapped_fn = lambda: kernel_fn(*args) + def wrapped_fn(): + return kernel_fn(*args) + times = triton_testing.do_bench( fn=wrapped_fn, warmup=warmup, From 2994315671d6f88af9b4be739826945292d9656f Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Fri, 9 Jan 2026 11:49:12 -0800 Subject: [PATCH 15/22] fix ruff --- .../opt_worker_component/benchmarking/__init__.py | 4 +--- .../opt_worker_component/benchmarking/timing.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py index b6a761c..cfd8ea8 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Benchmarking infrastructure for kernel performance measurement. - -""" +"""Benchmarking infrastructure for kernel performance measurement.""" # Core benchmarking from .benchmark import Benchmark, BenchmarkLockManager diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py index 223e60d..6701a60 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -206,7 +206,9 @@ def prepare_pytorch_model( ( inp.cuda().to(target_dtype) if isinstance(inp, torch.Tensor) and inp.is_floating_point() - else inp.cuda() if isinstance(inp, torch.Tensor) else inp + else inp.cuda() + if isinstance(inp, torch.Tensor) + else inp ) for inp in inputs ] From 1378fc33b7575c39b767a9cb386fa42a2b1a3b9d Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 14 Jan 2026 10:28:23 -0800 Subject: [PATCH 16/22] address comments --- .../benchmarking/__init__.py | 37 +- .../benchmarking/benchmark.py | 112 +---- .../benchmarking/kernel_subprocess.py | 451 ++++++++---------- .../benchmarking/timing.py | 31 +- 4 files changed, 226 insertions(+), 405 deletions(-) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py index cfd8ea8..61e1d08 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -14,39 +14,6 @@ """Benchmarking infrastructure for kernel performance measurement.""" -# Core benchmarking -from .benchmark import Benchmark, BenchmarkLockManager +from .benchmark import Benchmark -# All utilities from timing module -from .timing import ( - # Timing functions - clear_l2_cache, - # Model/kernel loading - CompilationError, - compute_timing_stats, - import_module, - load_kernel_function, - load_problem_interface, - prepare_inputs, - prepare_pytorch_model, - time_with_cuda_events, - time_with_triton_do_bench, -) - -__all__ = [ - # Core benchmarking - "Benchmark", - "BenchmarkLockManager", - # Model/kernel loading - "CompilationError", - "import_module", - "load_kernel_function", - "load_problem_interface", - "prepare_inputs", - "prepare_pytorch_model", - # Timing utilities - "clear_l2_cache", - "compute_timing_stats", - "time_with_cuda_events", - "time_with_triton_do_bench", -] +__all__ = ["Benchmark"] diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py index 707632b..8ee39e4 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -24,7 +24,7 @@ import sys import traceback from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Optional import torch @@ -39,11 +39,11 @@ class BenchmarkLockManager: """Manages GPU benchmarking locks to prevent resource contention.""" - def __init__(self, lock: Optional[Any], worker_id: int, logger: logging.Logger): + def __init__(self, lock: Any, worker_id: int, logger: logging.Logger): """Initialize the lock manager. Args: - lock: Shared multiprocessing lock (or None if no locking needed) + lock: Shared multiprocessing lock for serializing GPU access worker_id: Worker ID for logging logger: Logger instance """ @@ -53,24 +53,18 @@ def __init__(self, lock: Optional[Any], worker_id: int, logger: logging.Logger): def __enter__(self): """Acquire the benchmarking lock.""" - if self.lock: - self.logger.info( - f"⏳ Waiting for benchmark lock (worker {self.worker_id})..." - ) - self.lock.acquire() - self.logger.info(f"🔓 Acquired benchmark lock (worker {self.worker_id})") + self.logger.info(f"⏳ Waiting for benchmark lock (worker {self.worker_id})...") + self.lock.acquire() + self.logger.info(f"🔓 Acquired benchmark lock (worker {self.worker_id})") return self def __exit__(self, exc_type, exc_val, exc_tb): """Release the benchmarking lock.""" - if self.lock: - try: - self.lock.release() - self.logger.info( - f"🔒 Released benchmark lock (worker {self.worker_id})" - ) - except Exception as e: - self.logger.warning(f"Failed to release benchmark lock: {e}") + try: + self.lock.release() + self.logger.info(f"🔒 Released benchmark lock (worker {self.worker_id})") + except Exception as e: + self.logger.warning(f"Failed to release benchmark lock: {e}") return False @@ -85,8 +79,8 @@ class Benchmark: def __init__( self, logger: logging.Logger, - temp_dir: Path, - benchmark_lock: Optional[Any] = None, + artifacts_dir: Path, + benchmark_lock: Any, worker_id: int = 0, warmup: int = 25, repeat: int = 100, @@ -96,7 +90,7 @@ def __init__( Args: logger: Logger instance - temp_dir: Temporary directory for benchmark artifacts + artifacts_dir: Directory for benchmark artifacts benchmark_lock: Shared lock to serialize GPU benchmarking worker_id: Worker ID warmup: Number of warmup iterations (or warmup time in ms for do_bench) @@ -104,7 +98,7 @@ def __init__( timing_method: Timing method ("cuda_event", "do_bench", "host_time") """ self.logger = logger - self.temp_dir = temp_dir + self.artifacts_dir = artifacts_dir self.lock_manager = BenchmarkLockManager(benchmark_lock, worker_id, logger) self.warmup = warmup self.repeat = repeat @@ -118,7 +112,7 @@ def benchmark_kernel( ) -> dict[str, Any]: """Benchmark Triton kernel performance using subprocess isolation. - Always uses subprocess for crash protection of potentially buggy kernels. + Uses subprocess for crash protection of potentially buggy kernels. Args: kernel_file: Path to kernel file @@ -127,25 +121,12 @@ def benchmark_kernel( Returns: Dictionary with benchmark results: - - time_ms: Mean time (for backward compatibility) + - time_ms: Mean time in ms - speedup: Speedup vs baseline """ - return self._benchmark_kernel_subprocess( - kernel_file, problem_file, baseline_file - ) - - def _benchmark_kernel_subprocess( - self, - kernel_file: Path, - problem_file: Path, - baseline_file: Optional[Path] = None, - ) -> dict[str, Any]: - """Benchmark kernel using subprocess (existing workflow).""" try: with self.lock_manager: - results_json = self.temp_dir / "benchmark_results.json" - - # Use kernel_subprocess.py for subprocess isolation + results_json = self.artifacts_dir / "benchmark_results.json" benchmark_script = Path(__file__).parent / "kernel_subprocess.py" cmd = [ @@ -218,14 +199,12 @@ def benchmark_pytorch( """ try: with self.lock_manager: - # Load model and inputs model, inputs = prepare_pytorch_model( problem_file=problem_file, device="cuda", dtype=dtype, ) - # Time using configured method if self.timing_method == "do_bench": times = time_with_triton_do_bench( lambda: model(*inputs), @@ -255,58 +234,3 @@ def benchmark_pytorch( self.logger.error(f"PyTorch baseline benchmark failed: {e}") self.logger.error(traceback.format_exc()) return {"time_ms": float("inf")} - - def benchmark_function( - self, - fn: Callable, - args: list[Any], - name: str = "function", - ) -> dict[str, Any]: - """Benchmark an arbitrary function with lock protection. - - Args: - fn: Function to benchmark - args: Arguments to pass to function - name: Name for logging - - Returns: - Dictionary with timing statistics - """ - try: - with self.lock_manager: - self.logger.info(f"Benchmarking {name}...") - - # Time using configured method - if self.timing_method == "do_bench": - times = time_with_triton_do_bench( - fn, - args, - warmup=self.warmup, - rep=self.repeat, - verbose=False, - ) - else: # cuda_event - times = time_with_cuda_events( - fn, - args, - num_warmup=self.warmup, - num_trials=self.repeat, - clear_cache=True, - verbose=False, - ) - - stats = compute_timing_stats(times) - - self.logger.info( - f"{name}: {stats['mean']:.3f} ± {stats['std']:.3f} ms " - f"(min={stats['min']:.3f}, max={stats['max']:.3f})" - ) - - return { - "time_ms": stats["mean"], - "stats": stats, - } - - except Exception as e: - self.logger.error(f"Benchmark failed for {name}: {e}") - return {"time_ms": float("inf")} diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py index 015db24..733216c 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -27,49 +27,82 @@ from __future__ import annotations import argparse +import inspect import json import sys from pathlib import Path -# Import shared utilities from timing module (avoid duplication) from timing import ( import_module, load_kernel_function, load_problem_interface, prepare_inputs, ) -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Tuple import torch import triton.testing as tt -def _run_once( - fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, name: str -) -> torch.Tensor: - """Run kernel once to get output shape/dtype info. +def _extract_model_params( + model: torch.nn.Module, +) -> tuple[torch.Tensor | None, torch.Tensor | None, dict[str, Any]]: + """Extract weight, bias, and layer parameters from a PyTorch model. + + Searches for Conv or Linear layers and extracts their parameters. Args: - fn: Kernel function - inputs: Input tensors - init_inputs: Initialization inputs (e.g., features, eps) - name: Name for logging + model: PyTorch model to extract parameters from Returns: - Output tensor - - Raises: - Exception if kernel fails to run + Tuple of (weight, bias, layer_kwargs) where layer_kwargs contains + stride, padding, dilation, groups if applicable """ + for _, module in model.named_modules(): + if isinstance( + module, + (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear), + ): + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight + bias = getattr(module, "bias", None) + + layer_kwargs: dict[str, Any] = {} + if hasattr(module, "stride"): + stride = module.stride + layer_kwargs["stride"] = ( + stride[0] if isinstance(stride, (tuple, list)) else stride + ) + if hasattr(module, "padding"): + padding = module.padding + layer_kwargs["padding"] = ( + padding[0] if isinstance(padding, (tuple, list)) else padding + ) + if hasattr(module, "dilation"): + dilation = module.dilation + layer_kwargs["dilation"] = ( + dilation[0] if isinstance(dilation, (tuple, list)) else dilation + ) + if hasattr(module, "groups"): + layer_kwargs["groups"] = module.groups + + return weight, bias, layer_kwargs + + return None, None, {} + + +def _run_once( + fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, name: str +) -> torch.Tensor: + """Run kernel once to verify execution and get output shape/dtype.""" try: with torch.inference_mode(): - out = fn(*inputs, *init_inputs) - return out + return fn(*inputs, *init_inputs) except Exception as exc: raise RuntimeError(f"{name} failed to execute: {exc}") from exc -def benchmark( +def _benchmark( fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, @@ -77,19 +110,7 @@ def benchmark( warmup: int = 25, rep: int = 100, ) -> float: - """Benchmark a kernel function using triton.testing.do_bench. - - Args: - fn: Kernel function to benchmark - inputs: Input tensors - init_inputs: Initialization inputs (e.g., features, eps) - name: Name for logging - warmup: Number of warmup iterations - rep: Number of measurement iterations - - Returns: - Mean latency in milliseconds - """ + """Benchmark a kernel function using triton.testing.do_bench.""" try: ms = tt.do_bench( lambda: fn(*inputs, *init_inputs), @@ -104,14 +125,14 @@ def benchmark( return float("inf") -def main(): +def _parse_args() -> argparse.Namespace: + """Parse command line arguments.""" parser = argparse.ArgumentParser( description="Task-agnostic Triton kernel benchmark", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) - # File paths parser.add_argument( "--problem", type=Path, @@ -129,43 +150,143 @@ def main(): action="store_true", help="Include PyTorch reference model in benchmark", ) - - # Benchmark configuration parser.add_argument("--warmup", type=int, default=25) parser.add_argument("--repeat", type=int, default=100) - - # Problem configuration parser.add_argument("--size", type=int, default=4096, help="Problem size N") - parser.add_argument( - "--device", - type=str, - default="cuda", - choices=["cuda", "cpu"], - help="Device to use", - ) + parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) parser.add_argument( "--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"], - help="Data type", ) - - # Output options parser.add_argument("--json", type=Path, help="Save results to JSON file") - parser.add_argument( - "--quiet", - action="store_true", - help="Suppress non-essential output", - ) + parser.add_argument("--quiet", action="store_true") args = parser.parse_args() - - # Resolve paths args.problem = args.problem.resolve() args.kernel = args.kernel.resolve() + return args + + +def _load_problem( + args: argparse.Namespace, + device: torch.device, + dtype: torch.dtype, +) -> tuple[type, tuple, list, torch.nn.Module | None]: + """Load problem interface, prepare inputs, and optionally create baseline model. + + Returns: + Tuple of (Model class, inputs, init_inputs, baseline_model or None) + """ + Model, get_inputs, get_init_inputs = load_problem_interface(args.problem) + + # Check for optional benchmark config override + try: + problem_mod = import_module(args.problem, "problem") + get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) + if get_benchmark_config is not None: + config = get_benchmark_config() + args.warmup = config.get("warmup", args.warmup) + args.repeat = config.get("repeat", args.repeat) + if not args.quiet: + print( + f"Using problem-specific config: " + f"warmup={args.warmup}, repeat={args.repeat}" + ) + except Exception: + pass + + inputs = prepare_inputs(get_inputs, device=device, dtype=dtype) + + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + # Create baseline model if requested + baseline_model = None + if args.baseline: + baseline_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + baseline_model.eval() + out = _run_once(baseline_model, inputs, [], "Reference") + if not args.quiet: + print(f"Reference output shape: {out.shape}, dtype: {out.dtype}") + print() + + return Model, inputs, init_inputs, baseline_model + + +def _prepare_kernel( + kernel_file: Path, + Model: type, + baseline_model: torch.nn.Module | None, + init_inputs: list, + device: torch.device, + dtype: torch.dtype, + quiet: bool = False, +) -> tuple[Callable, tuple, list]: + """Load kernel and wrap it with model parameters if needed. + + Returns: + Tuple of (kernel_function, kernel_args, kernel_init_args) + """ + kernel_function = load_kernel_function(kernel_file) + + # Check if kernel expects weight/bias parameters + needs_model = False + try: + sig = inspect.signature(kernel_function) + if "weight" in sig.parameters: + needs_model = True + except Exception: + pass + + kernel_init_args = init_inputs + + if needs_model and Model is not None: + try: + # Reuse baseline model if available + extract_model = baseline_model + if extract_model is None: + extract_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + + weight, bias, kernel_kwargs = _extract_model_params(extract_model) + + if weight is not None: + original_fn = kernel_function + + def kernel_with_model(*args, **kwargs): + return original_fn(args[0], weight, bias, **kernel_kwargs) + + kernel_function = kernel_with_model + kernel_init_args = [] + except Exception as exc: + if not quiet: + print(f"⚠️ Warning: Failed to extract model parameters: {exc}") + print(" Falling back to direct kernel invocation") + + return kernel_function, kernel_init_args + + +def _save_results(results: dict[str, Any], path: Path) -> None: + """Save benchmark results to JSON file.""" + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(results, f, indent=2) + print(f"Results saved to: {path}") + + +def main(): + args = _parse_args() - # Setup device and dtype device = torch.device(args.device) dtype_map = { "float32": torch.float32, @@ -180,72 +301,18 @@ def main(): print("=" * 80) print(f"Problem: {args.problem.name}") print(f"Size: {args.size}") - print(f"Device: {device}") - print(f"Dtype: {dtype}") + print(f"Device: {device}, Dtype: {dtype}") print(f"Warmup: {args.warmup}, Repeat: {args.repeat}") print() - # Import problem module using shared utility + # Load problem and prepare inputs try: - Model, get_inputs, get_init_inputs = load_problem_interface(args.problem) + Model, inputs, init_inputs, baseline_model = _load_problem(args, device, dtype) except Exception as exc: - print(f"❌ Failed to import problem file: {exc}") + print(f"❌ Failed to load problem: {exc}") sys.exit(1) - # Check for optional benchmark config - try: - problem_mod = import_module(args.problem, "problem") - get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) - except Exception: - get_benchmark_config = None - - # Override benchmark config if provided by problem - if get_benchmark_config is not None: - config = get_benchmark_config() - args.warmup = config.get("warmup", args.warmup) - args.repeat = config.get("repeat", args.repeat) - if not args.quiet: - cfg_msg = ( - f"Using problem-specific config: " - f"warmup={args.warmup}, repeat={args.repeat}" - ) - print(cfg_msg) - - # Generate inputs using shared utility - try: - inputs = prepare_inputs(get_inputs, device=device, dtype=dtype) - - # Get initialization inputs (e.g., features, eps for RMSNorm) - init_inputs = [] - if get_init_inputs is not None: - init_inputs = get_init_inputs() - if not isinstance(init_inputs, (tuple, list)): - init_inputs = [init_inputs] - except Exception as exc: - print(f"❌ Failed to generate inputs: {exc}") - sys.exit(1) - - # Create reference model only if baseline is requested - model = None - if args.baseline: - try: - # Initialize model with init_inputs if provided - if init_inputs: - model = Model(*init_inputs).to(device=device, dtype=dtype) - else: - model = Model().to(device=device, dtype=dtype) - model.eval() - # Run once to get output shape - out = _run_once(model, inputs, [], "Reference") - if not args.quiet: - print(f"Reference output shape: {out.shape}, dtype: {out.dtype}") - print() - except Exception as exc: - print(f"❌ Failed to create reference model: {exc}") - sys.exit(1) - - # Results tracking - results: Dict[str, Any] = { + results: dict[str, Any] = { "problem": str(args.problem), "size": args.size, "device": str(device), @@ -255,14 +322,13 @@ def main(): "kernels": {}, } + # Benchmark baseline (if requested) baseline_time = None - - # Benchmark PyTorch baseline if requested - if args.baseline and model is not None: + if baseline_model is not None: if not args.quiet: print("1. PyTorch Reference") - baseline_time = benchmark( - model, inputs, [], "PyTorch", args.warmup, args.repeat + baseline_time = _benchmark( + baseline_model, inputs, [], "PyTorch", args.warmup, args.repeat ) results["kernels"]["pytorch_reference"] = { "time_ms": baseline_time, @@ -271,166 +337,51 @@ def main(): if not args.quiet: print() - # Benchmark candidate kernel + # Load and prepare kernel kernel_name = args.kernel.stem - if not args.quiet: idx = 2 if args.baseline else 1 print(f"{idx}. Candidate: {kernel_name}") - # Import kernel using shared utility try: - kernel_function = load_kernel_function(args.kernel) + kernel_fn, kernel_init_args = _prepare_kernel( + args.kernel, Model, baseline_model, init_inputs, device, dtype, args.quiet + ) except Exception as exc: - print(f"❌ Failed to import {kernel_name}: {exc}") - results["kernels"][kernel_name] = { - "time_ms": float("inf"), - "error": str(exc), - } + print(f"❌ Failed to load kernel: {exc}") + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} if args.json: - args.json.parent.mkdir(parents=True, exist_ok=True) - with open(args.json, "w") as f: - json.dump(results, f, indent=2) + _save_results(results, args.json) sys.exit(1) - # Check if kernel expects weight/bias parameters (e.g., Conv, Linear) - # If so, extract them from a Model instance - import inspect - - needs_model = False + # Verify kernel executes try: - sig = inspect.signature(kernel_function) - params = list(sig.parameters.keys()) - # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) - if "weight" in params: - needs_model = True - except Exception: - pass - - # Prepare kernel arguments - kernel_args = inputs - kernel_init_args = init_inputs - - if needs_model and Model is not None: - try: - # Initialize model to extract weight and bias - if init_inputs: - extract_model = Model(*init_inputs).to(device=device, dtype=dtype) - else: - extract_model = Model().to(device=device, dtype=dtype) - - # Extract weight and bias from model layer - # Check various possible attribute names - weight = None - bias = None - layer = None - for name, module in extract_model.named_modules(): - if isinstance( - module, - ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d, - torch.nn.Linear, - ), - ): - if hasattr(module, "weight") and module.weight is not None: - layer = module - weight = module.weight - bias = getattr(module, "bias", None) - break - - if weight is not None and layer is not None: - # Build kwargs for kernel_function - kernel_kwargs = {} - - # Add conv/linear-specific parameters if they exist - if hasattr(layer, "stride"): - stride = ( - layer.stride[0] - if isinstance(layer.stride, (tuple, list)) - else layer.stride - ) - kernel_kwargs["stride"] = stride - if hasattr(layer, "padding"): - padding = ( - layer.padding[0] - if isinstance(layer.padding, (tuple, list)) - else layer.padding - ) - kernel_kwargs["padding"] = padding - if hasattr(layer, "dilation"): - dilation = ( - layer.dilation[0] - if isinstance(layer.dilation, (tuple, list)) - else layer.dilation - ) - kernel_kwargs["dilation"] = dilation - if hasattr(layer, "groups"): - kernel_kwargs["groups"] = layer.groups - - # Capture original kernel function to avoid recursion - original_kernel_function = kernel_function - - # Prepare wrapper function that passes weight/bias - def kernel_with_model(*args, **kwargs): - return original_kernel_function( - args[0], weight, bias, **kernel_kwargs - ) - - # Update kernel function and clear init_inputs (already handled) - kernel_function = kernel_with_model - kernel_init_args = [] - except Exception as exc: - if not args.quiet: - print(f"⚠️ Warning: Failed to extract model parameters: {exc}") - print(" Falling back to direct kernel invocation") - - # Run once to verify it executes - try: - out = _run_once(kernel_function, kernel_args, kernel_init_args, kernel_name) + out = _run_once(kernel_fn, inputs, kernel_init_args, kernel_name) if not args.quiet: print(f"✓ {kernel_name} executes successfully") print(f" Output shape: {out.shape}, dtype: {out.dtype}") except Exception as exc: print(f"❌ {kernel_name} failed: {exc}") - results["kernels"][kernel_name] = { - "time_ms": float("inf"), - "error": str(exc), - } + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} if args.json: - args.json.parent.mkdir(parents=True, exist_ok=True) - with open(args.json, "w") as f: - json.dump(results, f, indent=2) + _save_results(results, args.json) sys.exit(1) - # Benchmark - kernel_time = benchmark( - kernel_function, - kernel_args, - kernel_init_args, - kernel_name, - args.warmup, - args.repeat, + # Benchmark kernel + kernel_time = _benchmark( + kernel_fn, inputs, kernel_init_args, kernel_name, args.warmup, args.repeat ) + results["kernels"][kernel_name] = {"time_ms": kernel_time, "path": str(args.kernel)} - results["kernels"][kernel_name] = { - "time_ms": kernel_time, - "path": str(args.kernel), - } - + # Calculate speedup if baseline_time is not None and kernel_time != float("inf"): speedup = baseline_time / kernel_time results["kernels"][kernel_name]["speedup"] = speedup if not args.quiet: print(f"Speedup vs PyTorch: {speedup:.2f}x") - # Save JSON results if args.json: - args.json.parent.mkdir(parents=True, exist_ok=True) - with open(args.json, "w") as f: - json.dump(results, f, indent=2) - print(f"Results saved to: {args.json}") + _save_results(results, args.json) if __name__ == "__main__": diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py index 6701a60..0e0730c 100644 --- a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -130,9 +130,7 @@ def prepare_inputs( converted_inputs = [] for inp in inputs: if isinstance(inp, torch.Tensor): - # Move to device inp = inp.to(device=device) - # Convert dtype ONLY for floating-point tensors # Preserve integer/bool tensors (e.g., targets for classification) if inp.is_floating_point(): inp = inp.to(dtype=dtype) @@ -162,41 +160,25 @@ def prepare_pytorch_model( Returns: Tuple of (model, inputs) ready for benchmarking """ - # Load problem interface Model, get_inputs, get_init_inputs = load_problem_interface(problem_file) # Get initialization inputs (e.g., features, eps for RMSNorm) - init_inputs = [] - if get_init_inputs is not None: - init_inputs = get_init_inputs() - if not isinstance(init_inputs, (tuple, list)): - init_inputs = [init_inputs] - - # Initialize model - if init_inputs: - model = Model(*init_inputs) - else: - model = Model() + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] - # Move model to CUDA + model = Model(*init_inputs) if init_inputs else Model() model = model.cuda() - - # Check if model has trainable parameters has_parameters = any(p.numel() > 0 for p in model.parameters()) - # Get inputs inputs = get_inputs() if not isinstance(inputs, (tuple, list)): inputs = (inputs,) - # Determine target dtype # Default to bfloat16 unless explicitly specified or model is a loss function target_dtype = dtype or torch.bfloat16 - - # Check if this is actually a loss function is_loss_function = isinstance(model, torch.nn.modules.loss._Loss) - # Handle dtype conversion based on model type if has_parameters or not is_loss_function: # Models with parameters (Conv, Linear, etc.) OR compute operations (matmul, etc.) # → use bfloat16 (or user-specified dtype) @@ -416,10 +398,7 @@ def time_with_triton_do_bench( if device is None: device = torch.cuda.current_device() - try: - from triton import testing as triton_testing - except ImportError: - raise ImportError("Triton is required for time_with_triton_do_bench") + import triton.testing as triton_testing with torch.cuda.device(device): if verbose: From 45fec80e7d0532fb64144f3dc588d03899ebc17f Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Sun, 11 Jan 2026 08:47:10 -0800 Subject: [PATCH 17/22] Diagnose module - prompt constructor --- .../diagnose_prompt/judger_prompts.py | 552 ++++++++++++++++++ .../diagnose_prompt/prompt_manager.py | 114 ++++ .../rewrite_prompt_template.py | 22 + 3 files changed, 688 insertions(+) create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py new file mode 100644 index 0000000..27af71e --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt Builder for Hardware Bottleneck Diagnosis + +This module provides prompt templates and builder functions for the Judge LLM +that analyzes NCU profiling metrics to identify performance bottlenecks and +provide specific optimization recommendations. + +The Judge uses a dual-bottleneck framework based on NCU hardware profiling: +- bottleneck_1 (Primary): Highest-impact performance issue +- bottleneck_2 (Secondary): Different category issue that also limits performance + +Both bottlenecks are selected from NCU hardware profiling categories: +- memory-bound +- compute-bound +- occupancy-limited +- latency-bound +""" + +from typing import Any, Dict, Optional, Tuple + + +# System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) +JUDGE_SYSTEM_PROMPT = """You are a senior GPU performance engineer. Analyze the target GPU spec, the current kernel, and the Nsight Compute (NCU) profiling metrics. Identify EXACTLY TWO DISTINCT bottlenecks from the hardware profiling data, and propose specific optimization methods for each. Be surgical and metrics-driven. + +## Bottleneck Categories (NCU Hardware Profiling) + +Analyze fundamental resource utilization using NCU profiling data: + +- **memory-bound**: DRAM throughput >50% of peak, L1 hit rate <60%, L2 hit rate <70%, memory coalescing <80%, long scoreboard stalls >25% +- **compute-bound**: DRAM throughput <40%, compute/pipe utilization >50%, memory stalls <15%, eligible warps >4/cycle +- **occupancy-limited**: Achieved occupancy <50%, registers/thread >64, shared memory >48KB/block, check launch__occupancy_limit_* for limiter +- **latency-bound**: Total stalls >35%, long scoreboard >20%, short scoreboard >15%, eligible warps <2/cycle, BUT DRAM throughput <50% (latency, not bandwidth) + +## Rules (STRICT) + +- Return EXACTLY TWO DISTINCT bottlenecks with DIFFERENT categories +- Both bottlenecks must be from: {memory-bound, compute-bound, occupancy-limited, latency-bound} +- For each bottleneck, cite 3-4 specific NCU metric values that reveal the issue +- Propose ONE actionable optimization method per bottleneck +- Keep fields brief; avoid lists of alternatives, disclaimers, or generic advice + +## Output Format (JSON - STRICT) + +```json +{ + "bottleneck_1": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + }, + "bottleneck_2": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + } +} +``` + +## Important Notes + +- bottleneck_1 is the PRIMARY (highest-impact) issue +- bottleneck_2 is the SECONDARY issue (different category from bottleneck_1) +- They should be independently addressable (fixing one doesn't automatically fix the other) + +Follow the Rules exactly. Return JSON in the specified format. +""" + + +def build_judge_optimization_prompt( + kernel_code: str, + problem_description: str, + ncu_metrics: Dict[str, Any], + gpu_specs: Dict[str, Any], +) -> Tuple[str, str]: + """ + Build system and user prompts for Judge to analyze bottleneck. + + This function constructs detailed prompts for the Judge LLM that include: + - The kernel code being analyzed + - The original problem description + - Complete NCU profiling metrics + - GPU hardware specifications + + Args: + kernel_code: Current Triton kernel code + problem_description: Original problem description + ncu_metrics: NCU profiling metrics as a dictionary (from metrics_to_prompt) + gpu_specs: GPU specifications (from get_gpu_specs) + + Returns: + Tuple of (system_prompt, user_prompt) + + Example: + >>> sys_prompt, user_prompt = build_judge_optimization_prompt( + ... kernel_code=kernel_code, + ... problem_description=problem_desc, + ... ncu_metrics=ncu_metrics, + ... gpu_specs=gpu_specs, + ... ) + >>> response = llm.call([ + ... {"role": "system", "content": sys_prompt}, + ... {"role": "user", "content": user_prompt} + ... ]) + """ + if not ncu_metrics: + raise ValueError("NCU metrics are empty - cannot build judge prompt") + + first_kernel = list(ncu_metrics.values())[0] if ncu_metrics else {} + + def get_metric(key: str, default: str = "N/A") -> str: + val = first_kernel.get(key, default) + if isinstance(val, (int, float)): + return f"{val:.2f}" + return str(val) + + # Build user prompt using list-join pattern (similar to Fuser/prompting.py) + parts: list[str] = [] + + # Problem Description + parts.append("## Problem Description") + parts.append("") + parts.append(problem_description) + + # Current Kernel Code + parts.append("") + parts.append("## Current Kernel Code") + parts.append("") + parts.append("```python") + parts.append(kernel_code) + parts.append("```") + + # GPU Hardware Specifications + parts.append("") + parts.append("## GPU Hardware Specifications") + parts.append("") + parts.append(f"- **Name:** {gpu_specs.get('name', 'Unknown')}") + parts.append(f"- **Architecture:** {gpu_specs.get('architecture', 'Unknown')}") + parts.append( + f"- **Peak Memory Bandwidth:** {gpu_specs.get('peak_memory_bw_gbps', 'N/A')} GB/s" + ) + parts.append( + f"- **Peak FP32 Performance:** {gpu_specs.get('peak_fp32_tflops', 'N/A')} TFLOPS" + ) + parts.append( + f"- **Peak FP16 Performance:** {gpu_specs.get('peak_fp16_tflops', 'N/A')} TFLOPS" + ) + parts.append(f"- **SM Count:** {gpu_specs.get('sm_count', 'N/A')}") + parts.append( + f"- **Max Threads per SM:** {gpu_specs.get('max_threads_per_sm', 'N/A')}" + ) + parts.append(f"- **L1 Cache per SM:** {gpu_specs.get('l1_cache_kb', 'N/A')} KB") + parts.append(f"- **L2 Cache (Total):** {gpu_specs.get('l2_cache_mb', 'N/A')} MB") + parts.append( + f"- **Memory Size:** {gpu_specs.get('memory_gb', 'N/A')} GB {gpu_specs.get('memory_type', '')}" + ) + + # NCU Profiling Metrics + parts.append("") + parts.append("## NCU Profiling Metrics") + + # SM & Compute Utilization + parts.append("") + parts.append("### SM & Compute Utilization") + parts.append(f"- **SM Cycles Active:** {get_metric('sm__cycles_active.avg')}") + parts.append( + f"- **Warp Active:** {get_metric('sm__warps_active.avg.pct_of_peak_sustained_active')}%" + ) + parts.append( + f"- **Total Instructions Executed:** {get_metric('sm__inst_executed.sum')}" + ) + parts.append( + f"- **Tensor Core Utilization:** {get_metric('sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active')}%" + ) + parts.append( + f"- **Tensor Core Pipeline Active:** {get_metric('sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed')}%" + ) + + # Memory Bandwidth & Cache + parts.append("") + parts.append("### Memory Bandwidth & Cache") + parts.append( + f"- **DRAM Throughput:** {get_metric('dram__throughput.avg.pct_of_peak_sustained_elapsed')}%" + ) + parts.append( + f"- **DRAM Bandwidth:** {get_metric('dram__bytes.sum.per_second')} bytes/sec" + ) + parts.append( + f"- **GPU DRAM Throughput:** {get_metric('gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed')}%" + ) + parts.append(f"- **DRAM Bytes Read:** {get_metric('dram__bytes_read.sum')} bytes") + parts.append(f"- **DRAM Bytes Write:** {get_metric('dram__bytes_write.sum')} bytes") + parts.append( + f"- **L1 Cache Hit Rate:** {get_metric('l1tex__t_sector_hit_rate.pct')}%" + ) + parts.append( + f"- **L1 Throughput:** {get_metric('l1tex__throughput.avg.pct_of_peak_sustained_active')}%" + ) + parts.append( + f"- **L2 Cache Hit Rate:** {get_metric('lts__t_sector_hit_rate.pct')}%" + ) + parts.append( + f"- **L2 Throughput:** {get_metric('lts__throughput.avg.pct_of_peak_sustained_active')}%" + ) + + # Memory Access Patterns + parts.append("") + parts.append("### Memory Access Patterns") + parts.append( + f"- **Memory Coalescing:** {get_metric('smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct')}%" + ) + parts.append( + f"- **Branch Uniformity:** {get_metric('smsp__sass_average_branch_targets_threads_uniform.pct')}%" + ) + + # Occupancy & Resources + parts.append("") + parts.append("### Occupancy & Resources") + parts.append( + f"- **Occupancy Limited By Blocks:** {get_metric('launch__occupancy_limit_blocks')}" + ) + parts.append( + f"- **Occupancy Limited By Registers:** {get_metric('launch__occupancy_limit_registers')}" + ) + parts.append( + f"- **Occupancy Limited By Shared Memory:** {get_metric('launch__occupancy_limit_shared_mem')}" + ) + parts.append( + f"- **Registers per Thread:** {get_metric('launch__registers_per_thread')}" + ) + parts.append( + f"- **Shared Memory per Block:** {get_metric('launch__shared_mem_per_block_allocated')} bytes" + ) + + # Stall Metrics + parts.append("") + parts.append("### Stall Metrics (Warp Issue Stalls)") + parts.append( + f"- **Short Scoreboard Stalls:** {get_metric('smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct')}%" + ) + parts.append( + f"- **Long Scoreboard Stalls:** {get_metric('smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct')}%" + ) + parts.append( + f"- **Barrier Stalls:** {get_metric('smsp__warp_issue_stalled_barrier_per_warp_active.pct')}%" + ) + parts.append( + f"- **Branch Resolving Stalls:** {get_metric('smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct')}%" + ) + + # Task instructions + parts.append("") + parts.append("## Your Task") + parts.append("") + parts.append( + "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:" + ) + parts.append("1. **Bottleneck 1 (Primary)**: The highest-impact performance issue") + parts.append( + "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance" + ) + parts.append("") + parts.append( + "For each bottleneck, cite 3-4 specific metrics that reveal the issue, " + "and recommend ONE actionable optimization." + ) + parts.append("") + parts.append( + "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt." + ) + + user_prompt = "\n".join(parts) + return JUDGE_SYSTEM_PROMPT, user_prompt + + +def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: + """ + Extract and parse JSON from Judge LLM response. + + This function handles various response formats and provides fallback strategies + for robust JSON extraction. Expects dual-bottleneck format with bottleneck_1 + and bottleneck_2 fields. + + Args: + response_text: Raw text response from Judge LLM + + Returns: + Parsed JSON dictionary with bottleneck_1 and bottleneck_2, + or None if extraction fails + + Example: + >>> response = llm.call(judge_prompts) + >>> analysis = extract_judge_response(response) + >>> if analysis: + ... print(f"Bottleneck 1: {analysis['bottleneck_1']['category']}") + ... print(f"Bottleneck 2: {analysis['bottleneck_2']['category']}") + """ + import json + import re + + # Strategy 1: Find JSON in code block + match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL) + if match: + try: + data = json.loads(match.group(1)) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except json.JSONDecodeError: + pass + + # Strategy 2: Find first { ... } block with "bottleneck_1" field + match = re.search(r'\{[^}]*"bottleneck_1"[^}]*\}', response_text, re.DOTALL) + if match: + try: + # Extract the full JSON object (may be nested) + start_pos = response_text.find("{", match.start()) + brace_count = 0 + end_pos = start_pos + + for i in range(start_pos, len(response_text)): + if response_text[i] == "{": + brace_count += 1 + elif response_text[i] == "}": + brace_count -= 1 + if brace_count == 0: + end_pos = i + 1 + break + + json_str = response_text[start_pos:end_pos] + data = json.loads(json_str) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 3: Find any JSON object with dual-bottleneck structure + match = re.search( + r'\{\s*"bottleneck_1"\s*:\s*\{.*?\}\s*,\s*"bottleneck_2"\s*:\s*\{.*?\}\s*\}', + response_text, + re.DOTALL, + ) + if match: + try: + return json.loads(match.group(0)) + except json.JSONDecodeError: + pass + + # Strategy 4: Backward compatibility - single-bottleneck format + match = re.search(r'\{[^}]*"bottleneck"[^}]*\}', response_text, re.DOTALL) + if match: + try: + old_format = json.loads(match.group(0)) + if "bottleneck" in old_format: + # Convert old format to dual-bottleneck format + return { + "bottleneck_1": { + "category": old_format.get("bottleneck", "unknown"), + "root_cause": old_format.get("root_cause", ""), + "suggestion": old_format.get("suggestion", ""), + "priority_metrics": old_format.get("priority_metrics", []), + "expected_improvement": old_format.get( + "expected_improvement", "" + ), + }, + "bottleneck_2": { + "category": "latency-bound", + "root_cause": "Secondary bottleneck inferred from single-bottleneck response", + "suggestion": "Review stall metrics for additional optimization opportunities", + "priority_metrics": [], + "expected_improvement": "Requires further profiling analysis", + }, + } + except json.JSONDecodeError: + pass + + # Strategy 5: Return None if all strategies fail + return None + + +def validate_judge_response(analysis: Dict[str, Any]) -> bool: + """ + Validate that Judge response contains required fields for dual-bottleneck format. + + This function validates the dual-bottleneck format with bottleneck_1 and + bottleneck_2 fields. Both bottlenecks use NCU hardware profiling categories. + + Args: + analysis: Parsed JSON from Judge response + + Returns: + True if response is valid, False otherwise + + Example: + >>> if validate_judge_response(analysis): + ... print("Valid dual-bottleneck response!") + ... else: + ... print("Invalid response - missing required fields") + """ + # Check for dual-bottleneck format + if "bottleneck_1" in analysis and "bottleneck_2" in analysis: + return _validate_bottleneck_entry( + analysis["bottleneck_1"] + ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) + + # Backward compatibility: Check for old single-bottleneck format + if "bottleneck" in analysis: + required_fields = [ + "bottleneck", + "root_cause", + "suggestion", + "priority_metrics", + "expected_improvement", + ] + + for field in required_fields: + if field not in analysis: + return False + + valid_bottlenecks = [ + "memory-bound", + "compute-bound", + "occupancy-limited", + "latency-bound", + ] + if analysis["bottleneck"] not in valid_bottlenecks: + return False + + if not isinstance(analysis["priority_metrics"], list): + return False + + for field in ["root_cause", "suggestion", "expected_improvement"]: + if ( + not isinstance(analysis[field], str) + or len(analysis[field].strip()) < 10 + ): + return False + + return True + + return False + + +def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool: + """ + Validate a single bottleneck entry (bottleneck_1 or bottleneck_2). + + Both bottlenecks use NCU hardware profiling categories: + memory-bound, compute-bound, occupancy-limited, latency-bound + + Args: + bottleneck: Bottleneck dictionary to validate + + Returns: + True if valid, False otherwise + """ + required_fields = [ + "category", + "root_cause", + "suggestion", + "priority_metrics", + "expected_improvement", + ] + + for field in required_fields: + if field not in bottleneck: + return False + + # NCU hardware profiling categories only + valid_categories = [ + "memory-bound", + "compute-bound", + "occupancy-limited", + "latency-bound", + ] + + if bottleneck["category"] not in valid_categories: + return False + + if not isinstance(bottleneck["priority_metrics"], list): + return False + + for field in ["root_cause", "suggestion", "expected_improvement"]: + if not isinstance(bottleneck[field], str) or len(bottleneck[field].strip()) < 5: + return False + + return True + + +if __name__ == "__main__": + print("Judge Prompts Module") + print("=" * 60) + print("\nThis module provides prompt templates for hardware bottleneck analysis.") + print("\nExample usage:") + print( + """ + from kernel_perf_agent.kernel_opt.diagnose_prompt.judger_prompts import ( + build_judge_optimization_prompt, + extract_judge_response, + validate_judge_response, + ) + from kernel_perf_agent.kernel_opt.profiler.gpu_specs import get_gpu_specs + from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + ) + import json + + # Get GPU specs + gpu_specs = get_gpu_specs() + + # Load NCU metrics + metrics_df = load_ncu_metrics("ncu_baseline.csv") + ncu_metrics = json.loads(metrics_to_prompt(metrics_df)) + + # Build prompts + sys_prompt, user_prompt = build_judge_optimization_prompt( + kernel_code=kernel_code, + problem_description=problem_description, + ncu_metrics=ncu_metrics, + gpu_specs=gpu_specs, + ) + + # Call LLM + response = llm.call([ + {"role": "system", "content": sys_prompt}, + {"role": "user", "content": user_prompt} + ]) + + # Extract and validate + analysis = extract_judge_response(response) + if analysis and validate_judge_response(analysis): + print(f"Bottleneck 1: {analysis['bottleneck_1']['category']}") + print(f"Bottleneck 2: {analysis['bottleneck_2']['category']}") + """ + ) diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py new file mode 100644 index 0000000..b118cca --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py @@ -0,0 +1,114 @@ +"""Prompt management.""" + +import ast +import inspect +from pathlib import Path +from typing import Callable, Tuple + +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy, OptNode +from kernel_perf_agent.kernel_opt.prompts.rewrite_prompt_template import ( + REWRITE_PROMPT_TEMPLATE, +) +from kernel_perf_agent.kernel_opt.utils.parser_util import ( + get_module_path, + remove_decorators_from_file, +) + + +class PromptManager: + """Manages prompt construction.""" + + def __init__( + self, + func_source_code: str, + func_prompt: str, + opt_prompt: str, + model: str, + dsl: str, + kernel_name: str, + database: OptHierarchy, + opt_node: OptNode, + module_path: Path, + debug: bool, + ): + """Initialize prompt manager. + :param func: Function to optimize + :param func_prompt: Function prompt + :param opt_prompt: Optimization prompt + :param model: LLM model to use + :param dsl: Target DSL (e.g., "triton") + :param kernel_name: Name of the kernel (defaults to function name) + :param database: Knowledge database of kernel optimizations + :param opt_node: The most relevant optimization node in database + :param module_path: Path to the module containing the function + :param debug: Whether to print debug information + """ + + self.func_source_code = func_source_code + self.func_prompt = func_prompt + self.opt_prompt = opt_prompt + self.model = model + self.dsl = dsl + self.kernel_name = kernel_name + self.database = database + self.opt_node = opt_node + self.module_path = module_path + self.debug = debug + + def build_rewrite_prompt(self) -> Tuple[str, str]: + """Build rewrite prompt.""" + + # Get context by traversing opt_node to all leaf nodes + context = "" + leaf = False + cur_level = [self.opt_node] + while cur_level: + child_level = [] + for node in cur_level: + # Leaf nodes are code examples + if not leaf and not node.opt_children: + leaf = True + context += """ +Here are code examples before and after the optimization: +""" + context += node.opt_desc + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + cur_level = child_level + + debug_str = "" + # if self.debug: + # debug_str += f""" + # ****** Context ****** : + # {context} + # """ + # if str(self.module_path) != "": + # debug_context_path = self.module_path / "debug_output" / "context.log" + # with open(str(debug_context_path), "w") as file: + # file.write(debug_str) + # # file.write("****** Context ****** : \n") + # # file.write(context) + + # Rewriting the kernels at the same DSL level as the input. + prompt = REWRITE_PROMPT_TEMPLATE.format( + dsl=self.dsl, + kernel_name=self.kernel_name, + func_prompt=self.func_prompt, + input_kernel=self.func_source_code, + opt_prompt=self.opt_prompt, + context=context, + ) + + if self.debug: + debug_str += f""" +****** Prompt ****** : +{prompt} +""" + # if str(self.module_path) != "": + # debug_prompt_path = self.module_path / "debug_output" / "prompt.log" + # with open(str(debug_prompt_path), "w") as file: + # file.write("****** Prompt ****** : \n") + # file.write(prompt) + + return prompt, debug_str diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py new file mode 100644 index 0000000..c4d8fe4 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py @@ -0,0 +1,22 @@ +"""Rewrite Prompt template.""" + +REWRITE_PROMPT_TEMPLATE = """ +You are a professional performance engineer who is an expert in rewriting {dsl} kernels to improve their performance. + +Your task is to rewrite the following {dsl} kernel to integrate the specific optimization. +The kernel name is {kernel_name}. +The function of this kernel is {func_prompt}. +The kernel source code is: +{input_kernel} + +The required optimization to integrate is: +{opt_prompt} + +Here are the necessary context about the specific optimization: +{context} + +IMPORTANT: +1. Rewrite the given kernel at {dsl} level. +2. Generate the complete implementation that contains both the host code and the kernel code. +3. Please use markdown formatting (like ```python) in your output to wrap the code that you generate. +""" From b640cde36ed2a7eae37beec7d068b35f69e9415d Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 13 Jan 2026 11:17:09 -0800 Subject: [PATCH 18/22] Refactors the diagnose_prompt module into a modular architecture --- .../kernel_opt/diagnose_prompt/__init__.py | 34 +++ .../kernel_opt/diagnose_prompt/gpu_specs.py | 196 +++++++++++++ .../diagnose_prompt/judger_prompts.py | 263 ++---------------- .../diagnose_prompt/metric_schema.py | 151 ++++++++++ .../diagnose_prompt/prompt_manager.py | 114 -------- .../rewrite_prompt_template.py | 22 -- 6 files changed, 407 insertions(+), 373 deletions(-) create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py delete mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py delete mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py new file mode 100644 index 0000000..34e8023 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diagnose Prompt Module for Hardware Bottleneck Analysis. + +This module provides prompt building utilities for the Judge LLM that +analyzes NCU profiling metrics to identify performance bottlenecks. +""" + +from .gpu_specs import get_gpu_specs +from .judger_prompts import ( + build_judge_optimization_prompt, + extract_judge_response, + validate_judge_response, +) + +__all__ = [ + "get_gpu_specs", + "build_judge_optimization_prompt", + "extract_judge_response", + "validate_judge_response", +] diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py new file mode 100644 index 0000000..0309bb3 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import subprocess +from typing import Any, Dict, Optional + +# GPU specifications database +# Sources: NVIDIA official specifications, manufacturer datasheets +GPU_SPECS_DATABASE = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} + + +def query_gpu_name() -> Optional[str]: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Take only the first GPU (nvidia-smi returns one line per GPU) + gpu_name = result.stdout.strip().split("\n")[0].strip() + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + + # Get specs + specs = get_gpu_specs() + print( + f"\nUsing specs for: {specs['name']} ({specs.get('architecture', 'Unknown')})" + ) + print(f" - Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s") + print(f" - Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS") + print(f" - SM Count: {specs['sm_count']}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py index 27af71e..37b55ec 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -28,10 +28,21 @@ - compute-bound - occupancy-limited - latency-bound + +Metric definitions are in metric_schema.py and rendering logic is in section_renderers.py. """ from typing import Any, Dict, Optional, Tuple +from .section_renderers import ( + create_metric_getter, + render_gpu_specs, + render_kernel_code, + render_ncu_metrics, + render_problem_description, + render_task_instructions, +) + # System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) JUDGE_SYSTEM_PROMPT = """You are a senior GPU performance engineer. Analyze the target GPU spec, the current kernel, and the Nsight Compute (NCU) profiling metrics. Identify EXACTLY TWO DISTINCT bottlenecks from the hardware profiling data, and propose specific optimization methods for each. Be surgical and metrics-driven. @@ -40,12 +51,11 @@ Analyze fundamental resource utilization using NCU profiling data: -- **memory-bound**: DRAM throughput >50% of peak, L1 hit rate <60%, L2 hit rate <70%, memory coalescing <80%, long scoreboard stalls >25% -- **compute-bound**: DRAM throughput <40%, compute/pipe utilization >50%, memory stalls <15%, eligible warps >4/cycle -- **occupancy-limited**: Achieved occupancy <50%, registers/thread >64, shared memory >48KB/block, check launch__occupancy_limit_* for limiter -- **latency-bound**: Total stalls >35%, long scoreboard >20%, short scoreboard >15%, eligible warps <2/cycle, BUT DRAM throughput <50% (latency, not bandwidth) - -## Rules (STRICT) +## Bottleneck Categories (Indicators Only) +- **memory-bound**: High DRAM throughput (>60%), low L1/L2 hit rates (<70%), high memory stalls (>30%) +- **compute-bound**: Low DRAM throughput (<40%), high compute utilization (>60%), low memory stalls (<20%) +- **occupancy-limited**: Low warp active (<50%), high register usage (>100/thread), shared memory pressure (>80%) +- **latency-bound**: High total stalls (>40%), memory dependency stalls dominate, long scoreboard stalls - Return EXACTLY TWO DISTINCT bottlenecks with DIFFERENT categories - Both bottlenecks must be from: {memory-bound, compute-bound, occupancy-limited, latency-bound} @@ -123,168 +133,19 @@ def build_judge_optimization_prompt( if not ncu_metrics: raise ValueError("NCU metrics are empty - cannot build judge prompt") + # Extract first kernel's metrics for the metric getter first_kernel = list(ncu_metrics.values())[0] if ncu_metrics else {} + get_metric_fn = create_metric_getter(first_kernel) - def get_metric(key: str, default: str = "N/A") -> str: - val = first_kernel.get(key, default) - if isinstance(val, (int, float)): - return f"{val:.2f}" - return str(val) - - # Build user prompt using list-join pattern (similar to Fuser/prompting.py) + # Build user prompt using modular section renderers parts: list[str] = [] - # Problem Description - parts.append("## Problem Description") - parts.append("") - parts.append(problem_description) - - # Current Kernel Code - parts.append("") - parts.append("## Current Kernel Code") - parts.append("") - parts.append("```python") - parts.append(kernel_code) - parts.append("```") - - # GPU Hardware Specifications - parts.append("") - parts.append("## GPU Hardware Specifications") - parts.append("") - parts.append(f"- **Name:** {gpu_specs.get('name', 'Unknown')}") - parts.append(f"- **Architecture:** {gpu_specs.get('architecture', 'Unknown')}") - parts.append( - f"- **Peak Memory Bandwidth:** {gpu_specs.get('peak_memory_bw_gbps', 'N/A')} GB/s" - ) - parts.append( - f"- **Peak FP32 Performance:** {gpu_specs.get('peak_fp32_tflops', 'N/A')} TFLOPS" - ) - parts.append( - f"- **Peak FP16 Performance:** {gpu_specs.get('peak_fp16_tflops', 'N/A')} TFLOPS" - ) - parts.append(f"- **SM Count:** {gpu_specs.get('sm_count', 'N/A')}") - parts.append( - f"- **Max Threads per SM:** {gpu_specs.get('max_threads_per_sm', 'N/A')}" - ) - parts.append(f"- **L1 Cache per SM:** {gpu_specs.get('l1_cache_kb', 'N/A')} KB") - parts.append(f"- **L2 Cache (Total):** {gpu_specs.get('l2_cache_mb', 'N/A')} MB") - parts.append( - f"- **Memory Size:** {gpu_specs.get('memory_gb', 'N/A')} GB {gpu_specs.get('memory_type', '')}" - ) - - # NCU Profiling Metrics - parts.append("") - parts.append("## NCU Profiling Metrics") - - # SM & Compute Utilization - parts.append("") - parts.append("### SM & Compute Utilization") - parts.append(f"- **SM Cycles Active:** {get_metric('sm__cycles_active.avg')}") - parts.append( - f"- **Warp Active:** {get_metric('sm__warps_active.avg.pct_of_peak_sustained_active')}%" - ) - parts.append( - f"- **Total Instructions Executed:** {get_metric('sm__inst_executed.sum')}" - ) - parts.append( - f"- **Tensor Core Utilization:** {get_metric('sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active')}%" - ) - parts.append( - f"- **Tensor Core Pipeline Active:** {get_metric('sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed')}%" - ) - - # Memory Bandwidth & Cache - parts.append("") - parts.append("### Memory Bandwidth & Cache") - parts.append( - f"- **DRAM Throughput:** {get_metric('dram__throughput.avg.pct_of_peak_sustained_elapsed')}%" - ) - parts.append( - f"- **DRAM Bandwidth:** {get_metric('dram__bytes.sum.per_second')} bytes/sec" - ) - parts.append( - f"- **GPU DRAM Throughput:** {get_metric('gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed')}%" - ) - parts.append(f"- **DRAM Bytes Read:** {get_metric('dram__bytes_read.sum')} bytes") - parts.append(f"- **DRAM Bytes Write:** {get_metric('dram__bytes_write.sum')} bytes") - parts.append( - f"- **L1 Cache Hit Rate:** {get_metric('l1tex__t_sector_hit_rate.pct')}%" - ) - parts.append( - f"- **L1 Throughput:** {get_metric('l1tex__throughput.avg.pct_of_peak_sustained_active')}%" - ) - parts.append( - f"- **L2 Cache Hit Rate:** {get_metric('lts__t_sector_hit_rate.pct')}%" - ) - parts.append( - f"- **L2 Throughput:** {get_metric('lts__throughput.avg.pct_of_peak_sustained_active')}%" - ) - - # Memory Access Patterns - parts.append("") - parts.append("### Memory Access Patterns") - parts.append( - f"- **Memory Coalescing:** {get_metric('smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct')}%" - ) - parts.append( - f"- **Branch Uniformity:** {get_metric('smsp__sass_average_branch_targets_threads_uniform.pct')}%" - ) - - # Occupancy & Resources - parts.append("") - parts.append("### Occupancy & Resources") - parts.append( - f"- **Occupancy Limited By Blocks:** {get_metric('launch__occupancy_limit_blocks')}" - ) - parts.append( - f"- **Occupancy Limited By Registers:** {get_metric('launch__occupancy_limit_registers')}" - ) - parts.append( - f"- **Occupancy Limited By Shared Memory:** {get_metric('launch__occupancy_limit_shared_mem')}" - ) - parts.append( - f"- **Registers per Thread:** {get_metric('launch__registers_per_thread')}" - ) - parts.append( - f"- **Shared Memory per Block:** {get_metric('launch__shared_mem_per_block_allocated')} bytes" - ) - - # Stall Metrics - parts.append("") - parts.append("### Stall Metrics (Warp Issue Stalls)") - parts.append( - f"- **Short Scoreboard Stalls:** {get_metric('smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct')}%" - ) - parts.append( - f"- **Long Scoreboard Stalls:** {get_metric('smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct')}%" - ) - parts.append( - f"- **Barrier Stalls:** {get_metric('smsp__warp_issue_stalled_barrier_per_warp_active.pct')}%" - ) - parts.append( - f"- **Branch Resolving Stalls:** {get_metric('smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct')}%" - ) - - # Task instructions - parts.append("") - parts.append("## Your Task") - parts.append("") - parts.append( - "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:" - ) - parts.append("1. **Bottleneck 1 (Primary)**: The highest-impact performance issue") - parts.append( - "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance" - ) - parts.append("") - parts.append( - "For each bottleneck, cite 3-4 specific metrics that reveal the issue, " - "and recommend ONE actionable optimization." - ) - parts.append("") - parts.append( - "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt." - ) + # Compose sections using renderers + parts.extend(render_problem_description(problem_description)) + parts.extend(render_kernel_code(kernel_code)) + parts.extend(render_gpu_specs(gpu_specs)) + parts.extend(render_ncu_metrics(ncu_metrics, get_metric_fn)) + parts.extend(render_task_instructions()) user_prompt = "\n".join(parts) return JUDGE_SYSTEM_PROMPT, user_prompt @@ -362,35 +223,7 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError: pass - # Strategy 4: Backward compatibility - single-bottleneck format - match = re.search(r'\{[^}]*"bottleneck"[^}]*\}', response_text, re.DOTALL) - if match: - try: - old_format = json.loads(match.group(0)) - if "bottleneck" in old_format: - # Convert old format to dual-bottleneck format - return { - "bottleneck_1": { - "category": old_format.get("bottleneck", "unknown"), - "root_cause": old_format.get("root_cause", ""), - "suggestion": old_format.get("suggestion", ""), - "priority_metrics": old_format.get("priority_metrics", []), - "expected_improvement": old_format.get( - "expected_improvement", "" - ), - }, - "bottleneck_2": { - "category": "latency-bound", - "root_cause": "Secondary bottleneck inferred from single-bottleneck response", - "suggestion": "Review stall metrics for additional optimization opportunities", - "priority_metrics": [], - "expected_improvement": "Requires further profiling analysis", - }, - } - except json.JSONDecodeError: - pass - - # Strategy 5: Return None if all strategies fail + # Return None if all strategies fail return None @@ -506,47 +339,3 @@ def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool: if __name__ == "__main__": print("Judge Prompts Module") print("=" * 60) - print("\nThis module provides prompt templates for hardware bottleneck analysis.") - print("\nExample usage:") - print( - """ - from kernel_perf_agent.kernel_opt.diagnose_prompt.judger_prompts import ( - build_judge_optimization_prompt, - extract_judge_response, - validate_judge_response, - ) - from kernel_perf_agent.kernel_opt.profiler.gpu_specs import get_gpu_specs - from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( - load_ncu_metrics, - metrics_to_prompt, - ) - import json - - # Get GPU specs - gpu_specs = get_gpu_specs() - - # Load NCU metrics - metrics_df = load_ncu_metrics("ncu_baseline.csv") - ncu_metrics = json.loads(metrics_to_prompt(metrics_df)) - - # Build prompts - sys_prompt, user_prompt = build_judge_optimization_prompt( - kernel_code=kernel_code, - problem_description=problem_description, - ncu_metrics=ncu_metrics, - gpu_specs=gpu_specs, - ) - - # Call LLM - response = llm.call([ - {"role": "system", "content": sys_prompt}, - {"role": "user", "content": user_prompt} - ]) - - # Extract and validate - analysis = extract_judge_response(response) - if analysis and validate_judge_response(analysis): - print(f"Bottleneck 1: {analysis['bottleneck_1']['category']}") - print(f"Bottleneck 2: {analysis['bottleneck_2']['category']}") - """ - ) diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py new file mode 100644 index 0000000..64d1d67 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Metric Schema Definitions for NCU Profiling and GPU Specifications. + +This module provides the single source of truth for: +- NCU profiling metric definitions (keys, labels, units) +- GPU specification field definitions + +Schema Format: List of tuples (display_label, key, unit_suffix) +- display_label: Human-readable name shown in prompts +- key: NCU metric key or GPU spec dictionary key +- unit_suffix: Unit to append after value (e.g., "%", " GB/s", " bytes") +""" + +from typing import Dict, List, Tuple + +# Type alias for metric definition: (label, key, unit) +MetricDef = Tuple[str, str, str] + +# ============================================================================= +# GPU Specification Fields +# ============================================================================= + +GPU_SPEC_FIELDS: List[MetricDef] = [ + ("Name", "name", ""), + ("Architecture", "architecture", ""), + ("Peak Memory Bandwidth", "peak_memory_bw_gbps", " GB/s"), + ("Peak FP32 Performance", "peak_fp32_tflops", " TFLOPS"), + ("Peak FP16 Performance", "peak_fp16_tflops", " TFLOPS"), + ("SM Count", "sm_count", ""), + ("Max Threads per SM", "max_threads_per_sm", ""), + ("L1 Cache per SM", "l1_cache_kb", " KB"), + ("L2 Cache (Total)", "l2_cache_mb", " MB"), +] + +# Special case: Memory Size has two fields combined +GPU_MEMORY_FIELDS: List[Tuple[str, str, str, str]] = [ + # (label, size_key, type_key, size_unit) + ("Memory Size", "memory_gb", "memory_type", " GB"), +] + +# ============================================================================= +# NCU Profiling Metric Sections +# ============================================================================= + +NCU_METRIC_SECTIONS: Dict[str, List[MetricDef]] = { + "SM & Compute Utilization": [ + ("SM Cycles Active", "sm__cycles_active.avg", ""), + ("Warp Active", "sm__warps_active.avg.pct_of_peak_sustained_active", "%"), + ("Total Instructions Executed", "sm__inst_executed.sum", ""), + ( + "Tensor Core Utilization", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "%", + ), + ( + "Tensor Core Pipeline Active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ], + "Memory Bandwidth & Cache": [ + ( + "DRAM Throughput", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bandwidth", "dram__bytes.sum.per_second", " bytes/sec"), + ( + "GPU DRAM Throughput", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bytes Read", "dram__bytes_read.sum", " bytes"), + ("DRAM Bytes Write", "dram__bytes_write.sum", " bytes"), + ("L1 Cache Hit Rate", "l1tex__t_sector_hit_rate.pct", "%"), + ( + "L1 Throughput", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ("L2 Cache Hit Rate", "lts__t_sector_hit_rate.pct", "%"), + ( + "L2 Throughput", + "lts__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ], + "Memory Access Patterns": [ + ( + "Memory Coalescing", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "%", + ), + ( + "Branch Uniformity", + "smsp__sass_average_branch_targets_threads_uniform.pct", + "%", + ), + ], + "Occupancy & Resources": [ + ("Occupancy Limited By Blocks", "launch__occupancy_limit_blocks", ""), + ("Occupancy Limited By Registers", "launch__occupancy_limit_registers", ""), + ( + "Occupancy Limited By Shared Memory", + "launch__occupancy_limit_shared_mem", + "", + ), + ("Registers per Thread", "launch__registers_per_thread", ""), + ( + "Shared Memory per Block", + "launch__shared_mem_per_block_allocated", + " bytes", + ), + ], + "Stall Metrics (Warp Issue Stalls)": [ + ( + "Short Scoreboard Stalls", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Long Scoreboard Stalls", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Barrier Stalls", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "%", + ), + ( + "Branch Resolving Stalls", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "%", + ), + ], +} diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py deleted file mode 100644 index b118cca..0000000 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/prompt_manager.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Prompt management.""" - -import ast -import inspect -from pathlib import Path -from typing import Callable, Tuple - -from kernel_perf_agent.kernel_opt.database.base import OptHierarchy, OptNode -from kernel_perf_agent.kernel_opt.prompts.rewrite_prompt_template import ( - REWRITE_PROMPT_TEMPLATE, -) -from kernel_perf_agent.kernel_opt.utils.parser_util import ( - get_module_path, - remove_decorators_from_file, -) - - -class PromptManager: - """Manages prompt construction.""" - - def __init__( - self, - func_source_code: str, - func_prompt: str, - opt_prompt: str, - model: str, - dsl: str, - kernel_name: str, - database: OptHierarchy, - opt_node: OptNode, - module_path: Path, - debug: bool, - ): - """Initialize prompt manager. - :param func: Function to optimize - :param func_prompt: Function prompt - :param opt_prompt: Optimization prompt - :param model: LLM model to use - :param dsl: Target DSL (e.g., "triton") - :param kernel_name: Name of the kernel (defaults to function name) - :param database: Knowledge database of kernel optimizations - :param opt_node: The most relevant optimization node in database - :param module_path: Path to the module containing the function - :param debug: Whether to print debug information - """ - - self.func_source_code = func_source_code - self.func_prompt = func_prompt - self.opt_prompt = opt_prompt - self.model = model - self.dsl = dsl - self.kernel_name = kernel_name - self.database = database - self.opt_node = opt_node - self.module_path = module_path - self.debug = debug - - def build_rewrite_prompt(self) -> Tuple[str, str]: - """Build rewrite prompt.""" - - # Get context by traversing opt_node to all leaf nodes - context = "" - leaf = False - cur_level = [self.opt_node] - while cur_level: - child_level = [] - for node in cur_level: - # Leaf nodes are code examples - if not leaf and not node.opt_children: - leaf = True - context += """ -Here are code examples before and after the optimization: -""" - context += node.opt_desc - for child in node.opt_children: - if child not in child_level: - child_level.append(child) - cur_level = child_level - - debug_str = "" - # if self.debug: - # debug_str += f""" - # ****** Context ****** : - # {context} - # """ - # if str(self.module_path) != "": - # debug_context_path = self.module_path / "debug_output" / "context.log" - # with open(str(debug_context_path), "w") as file: - # file.write(debug_str) - # # file.write("****** Context ****** : \n") - # # file.write(context) - - # Rewriting the kernels at the same DSL level as the input. - prompt = REWRITE_PROMPT_TEMPLATE.format( - dsl=self.dsl, - kernel_name=self.kernel_name, - func_prompt=self.func_prompt, - input_kernel=self.func_source_code, - opt_prompt=self.opt_prompt, - context=context, - ) - - if self.debug: - debug_str += f""" -****** Prompt ****** : -{prompt} -""" - # if str(self.module_path) != "": - # debug_prompt_path = self.module_path / "debug_output" / "prompt.log" - # with open(str(debug_prompt_path), "w") as file: - # file.write("****** Prompt ****** : \n") - # file.write(prompt) - - return prompt, debug_str diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py deleted file mode 100644 index c4d8fe4..0000000 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/rewrite_prompt_template.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Rewrite Prompt template.""" - -REWRITE_PROMPT_TEMPLATE = """ -You are a professional performance engineer who is an expert in rewriting {dsl} kernels to improve their performance. - -Your task is to rewrite the following {dsl} kernel to integrate the specific optimization. -The kernel name is {kernel_name}. -The function of this kernel is {func_prompt}. -The kernel source code is: -{input_kernel} - -The required optimization to integrate is: -{opt_prompt} - -Here are the necessary context about the specific optimization: -{context} - -IMPORTANT: -1. Rewrite the given kernel at {dsl} level. -2. Generate the complete implementation that contains both the host code and the kernel code. -3. Please use markdown formatting (like ```python) in your output to wrap the code that you generate. -""" From e952123de63f8dbd82f3c39a66bd7a5d013c5ed5 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 13 Jan 2026 12:29:41 -0800 Subject: [PATCH 19/22] fix diff issue --- .../diagnose_prompt/judger_prompts.py | 211 +++++++++--------- 1 file changed, 100 insertions(+), 111 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py index 37b55ec..b98e4d0 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -29,19 +29,95 @@ - occupancy-limited - latency-bound -Metric definitions are in metric_schema.py and rendering logic is in section_renderers.py. +Metric definitions are in metric_schema.py. """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple -from .section_renderers import ( - create_metric_getter, - render_gpu_specs, - render_kernel_code, - render_ncu_metrics, - render_problem_description, - render_task_instructions, -) +from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS + + + +# ============================================================================= +# Section Renderers +# ============================================================================= + + +def render_problem_description(problem_description: str) -> List[str]: + """Render the problem description section.""" + return ["## Problem Description", "", problem_description] + + +def render_kernel_code(kernel_code: str, language: str = "python") -> List[str]: + """Render the kernel code section with syntax highlighting.""" + return ["", "## Current Kernel Code", "", f"```{language}", kernel_code, "```"] + + +def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]: + """Render the GPU hardware specifications section.""" + lines = ["", "## GPU Hardware Specifications", ""] + + for label, key, unit in GPU_SPEC_FIELDS: + value = gpu_specs.get(key, "N/A") + lines.append(f"- **{label}:** {value}{unit}") + + for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS: + size_value = gpu_specs.get(size_key, "N/A") + type_value = gpu_specs.get(type_key, "") + lines.append(f"- **{label}:** {size_value}{size_unit} {type_value}") + + return lines + + +def render_ncu_metrics( + ncu_metrics: Dict[str, Any], + get_metric_fn: Callable[[str, str], str], +) -> List[str]: + """Render the NCU profiling metrics section.""" + lines = ["", "## NCU Profiling Metrics"] + + for section_name, metrics in NCU_METRIC_SECTIONS.items(): + lines.append("") + lines.append(f"### {section_name}") + for label, key, unit in metrics: + value = get_metric_fn(key, "N/A") + lines.append(f"- **{label}:** {value}{unit}") + + return lines + + +def render_task_instructions() -> List[str]: + """Render the task instructions section for dual-bottleneck analysis.""" + return [ + "", + "## Your Task", + "", + "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:", + "1. **Bottleneck 1 (Primary)**: The highest-impact performance issue", + "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance", + "", + "For each bottleneck, cite 3-4 specific metrics that reveal the issue, " + "and recommend ONE actionable optimization.", + "", + "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt.", + ] + + +def create_metric_getter(kernel_metrics: Dict[str, Any]) -> Callable[[str, str], str]: + """Create a metric getter function for a specific kernel's metrics.""" + + def get_metric(key: str, default: str = "N/A") -> str: + val = kernel_metrics.get(key, default) + if isinstance(val, (int, float)): + return f"{val:.2f}" + return str(val) + + return get_metric + + +# ============================================================================= +# Bottleneck Analysis +# ============================================================================= # System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) @@ -228,114 +304,27 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: def validate_judge_response(analysis: Dict[str, Any]) -> bool: - """ - Validate that Judge response contains required fields for dual-bottleneck format. - - This function validates the dual-bottleneck format with bottleneck_1 and - bottleneck_2 fields. Both bottlenecks use NCU hardware profiling categories. - - Args: - analysis: Parsed JSON from Judge response - - Returns: - True if response is valid, False otherwise - - Example: - >>> if validate_judge_response(analysis): - ... print("Valid dual-bottleneck response!") - ... else: - ... print("Invalid response - missing required fields") - """ - # Check for dual-bottleneck format - if "bottleneck_1" in analysis and "bottleneck_2" in analysis: - return _validate_bottleneck_entry( - analysis["bottleneck_1"] - ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) - - # Backward compatibility: Check for old single-bottleneck format - if "bottleneck" in analysis: - required_fields = [ - "bottleneck", - "root_cause", - "suggestion", - "priority_metrics", - "expected_improvement", - ] - - for field in required_fields: - if field not in analysis: - return False - - valid_bottlenecks = [ - "memory-bound", - "compute-bound", - "occupancy-limited", - "latency-bound", - ] - if analysis["bottleneck"] not in valid_bottlenecks: - return False - - if not isinstance(analysis["priority_metrics"], list): - return False - - for field in ["root_cause", "suggestion", "expected_improvement"]: - if ( - not isinstance(analysis[field], str) - or len(analysis[field].strip()) < 10 - ): - return False + """Validate that Judge response contains required dual-bottleneck fields.""" + if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis: + return False + return _validate_bottleneck_entry( + analysis["bottleneck_1"] + ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) - return True - return False +VALID_CATEGORIES = frozenset(["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"]) def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool: - """ - Validate a single bottleneck entry (bottleneck_1 or bottleneck_2). - - Both bottlenecks use NCU hardware profiling categories: - memory-bound, compute-bound, occupancy-limited, latency-bound - - Args: - bottleneck: Bottleneck dictionary to validate - - Returns: - True if valid, False otherwise - """ - required_fields = [ - "category", - "root_cause", - "suggestion", - "priority_metrics", - "expected_improvement", - ] - - for field in required_fields: - if field not in bottleneck: - return False - - # NCU hardware profiling categories only - valid_categories = [ - "memory-bound", - "compute-bound", - "occupancy-limited", - "latency-bound", - ] - - if bottleneck["category"] not in valid_categories: + """Validate a single bottleneck entry.""" + required = ["category", "root_cause", "suggestion", "priority_metrics", "expected_improvement"] + if not all(f in bottleneck for f in required): + return False + if bottleneck["category"] not in VALID_CATEGORIES: return False - if not isinstance(bottleneck["priority_metrics"], list): return False - - for field in ["root_cause", "suggestion", "expected_improvement"]: - if not isinstance(bottleneck[field], str) or len(bottleneck[field].strip()) < 5: + for f in ["root_cause", "suggestion", "expected_improvement"]: + if not isinstance(bottleneck[f], str) or len(bottleneck[f].strip()) < 5: return False - return True - - -if __name__ == "__main__": - print("Judge Prompts Module") - print("=" * 60) From e7ba29a6d4b6b380269bf71aa832123b923de6b0 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 13 Jan 2026 12:42:25 -0800 Subject: [PATCH 20/22] fix ruff issue --- .../kernel_opt/diagnose_prompt/judger_prompts.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py index b98e4d0..4f75951 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -37,7 +37,6 @@ from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS - # ============================================================================= # Section Renderers # ============================================================================= @@ -312,12 +311,20 @@ def validate_judge_response(analysis: Dict[str, Any]) -> bool: ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) -VALID_CATEGORIES = frozenset(["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"]) +VALID_CATEGORIES = frozenset( + ["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"] +) def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool: """Validate a single bottleneck entry.""" - required = ["category", "root_cause", "suggestion", "priority_metrics", "expected_improvement"] + required = [ + "category", + "root_cause", + "suggestion", + "priority_metrics", + "expected_improvement", + ] if not all(f in bottleneck for f in required): return False if bottleneck["category"] not in VALID_CATEGORIES: From 72ac4d193663e39879cf6fe8a0ebcee69f34832f Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 15 Jan 2026 10:21:38 -0800 Subject: [PATCH 21/22] fix --- .../kernel_opt/diagnose_prompt/gpu_specs.py | 78 +++--------------- .../diagnose_prompt/gpu_specs_database.py | 82 +++++++++++++++++++ .../diagnose_prompt/judger_prompts.py | 41 +++++----- 3 files changed, 116 insertions(+), 85 deletions(-) create mode 100644 kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py index 0309bb3..e465116 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py @@ -22,71 +22,16 @@ """ import subprocess -from typing import Any, Dict, Optional - -# GPU specifications database -# Sources: NVIDIA official specifications, manufacturer datasheets -GPU_SPECS_DATABASE = { - "NVIDIA A100": { - "name": "NVIDIA A100", - "architecture": "Ampere", - "peak_fp32_tflops": 19.5, - "peak_fp16_tflops": 312.0, - "peak_bf16_tflops": 312.0, - "peak_memory_bw_gbps": 1555, - "sm_count": 108, - "max_threads_per_sm": 2048, - "l1_cache_kb": 192, - "l2_cache_mb": 40, - "memory_gb": 40, - "memory_type": "HBM2e", - }, - "NVIDIA H100": { - "name": "NVIDIA H100", - "architecture": "Hopper", - "peak_fp32_tflops": 51.0, - "peak_fp16_tflops": 989.0, - "peak_bf16_tflops": 989.0, - "peak_memory_bw_gbps": 3352, - "sm_count": 132, - "max_threads_per_sm": 2048, - "l1_cache_kb": 256, - "l2_cache_mb": 50, - "memory_gb": 80, - "memory_type": "HBM3", - }, - "NVIDIA RTX 4090": { - "name": "NVIDIA RTX 4090", - "architecture": "Ada Lovelace", - "peak_fp32_tflops": 82.6, - "peak_fp16_tflops": 165.0, - "peak_bf16_tflops": 165.0, - "peak_memory_bw_gbps": 1008, - "sm_count": 128, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 72, - "memory_gb": 24, - "memory_type": "GDDR6X", - }, - "NVIDIA RTX 5080": { - "name": "NVIDIA RTX 5080", - "architecture": "Blackwell", - "peak_fp32_tflops": 57.0, - "peak_fp16_tflops": 114.0, - "peak_bf16_tflops": 114.0, - "peak_memory_bw_gbps": 960, - "sm_count": 84, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 64, - "memory_gb": 16, - "memory_type": "GDDR7", - }, -} - - -def query_gpu_name() -> Optional[str]: +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import ( + GPU_SPECS_DATABASE, +) + +__all__ = ["GPU_SPECS_DATABASE", "query_gpu_name", "get_gpu_specs"] + + +def query_gpu_name() -> str | None: """ Query GPU name using nvidia-smi. @@ -109,7 +54,7 @@ def query_gpu_name() -> Optional[str]: return None -def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: +def get_gpu_specs(gpu_name: str | None = None) -> dict[str, Any]: """ Get GPU specifications for bottleneck analysis. @@ -179,6 +124,7 @@ def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: print(f"\nDetected GPU: {detected_name}") else: print("\nNo GPU detected (nvidia-smi not available)") + exit() # Get specs specs = get_gpu_specs() diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py new file mode 100644 index 0000000..d5e4586 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database + +This module contains the GPU hardware specifications database used for +performance analysis and bottleneck identification. Separated into its +own file to allow easier module overriding. + +Sources: NVIDIA official specifications, manufacturer datasheets +""" + +GPU_SPECS_DATABASE: dict[str, dict[str, object]] = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py index 4f75951..27bea50 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -32,7 +32,7 @@ Metric definitions are in metric_schema.py. """ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS @@ -42,23 +42,26 @@ # ============================================================================= -def render_problem_description(problem_description: str) -> List[str]: +def render_problem_description(problem_description: str) -> list[str]: """Render the problem description section.""" return ["## Problem Description", "", problem_description] -def render_kernel_code(kernel_code: str, language: str = "python") -> List[str]: +def render_kernel_code(kernel_code: str, language: str = "python") -> list[str]: """Render the kernel code section with syntax highlighting.""" return ["", "## Current Kernel Code", "", f"```{language}", kernel_code, "```"] -def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]: +def render_gpu_specs(gpu_specs: dict[str, Any]) -> list[str]: """Render the GPU hardware specifications section.""" lines = ["", "## GPU Hardware Specifications", ""] for label, key, unit in GPU_SPEC_FIELDS: value = gpu_specs.get(key, "N/A") - lines.append(f"- **{label}:** {value}{unit}") + if value == "N/A": + lines.append(f"- **{label}:** N/A") + else: + lines.append(f"- **{label}:** {value}{unit}") for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS: size_value = gpu_specs.get(size_key, "N/A") @@ -69,9 +72,9 @@ def render_gpu_specs(gpu_specs: Dict[str, Any]) -> List[str]: def render_ncu_metrics( - ncu_metrics: Dict[str, Any], + ncu_metrics: dict[str, Any], get_metric_fn: Callable[[str, str], str], -) -> List[str]: +) -> list[str]: """Render the NCU profiling metrics section.""" lines = ["", "## NCU Profiling Metrics"] @@ -85,7 +88,7 @@ def render_ncu_metrics( return lines -def render_task_instructions() -> List[str]: +def render_task_instructions() -> list[str]: """Render the task instructions section for dual-bottleneck analysis.""" return [ "", @@ -102,7 +105,7 @@ def render_task_instructions() -> List[str]: ] -def create_metric_getter(kernel_metrics: Dict[str, Any]) -> Callable[[str, str], str]: +def create_metric_getter(kernel_metrics: dict[str, Any]) -> Callable[[str, str], str]: """Create a metric getter function for a specific kernel's metrics.""" def get_metric(key: str, default: str = "N/A") -> str: @@ -172,9 +175,9 @@ def get_metric(key: str, default: str = "N/A") -> str: def build_judge_optimization_prompt( kernel_code: str, problem_description: str, - ncu_metrics: Dict[str, Any], - gpu_specs: Dict[str, Any], -) -> Tuple[str, str]: + ncu_metrics: dict[str, Any], + gpu_specs: dict[str, Any], +) -> tuple[str, str]: """ Build system and user prompts for Judge to analyze bottleneck. @@ -209,7 +212,7 @@ def build_judge_optimization_prompt( raise ValueError("NCU metrics are empty - cannot build judge prompt") # Extract first kernel's metrics for the metric getter - first_kernel = list(ncu_metrics.values())[0] if ncu_metrics else {} + first_kernel = list(ncu_metrics.values())[0] get_metric_fn = create_metric_getter(first_kernel) # Build user prompt using modular section renderers @@ -226,7 +229,7 @@ def build_judge_optimization_prompt( return JUDGE_SYSTEM_PROMPT, user_prompt -def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: +def extract_judge_response(response_text: str) -> dict[str, Any] | None: """ Extract and parse JSON from Judge LLM response. @@ -302,7 +305,7 @@ def extract_judge_response(response_text: str) -> Optional[Dict[str, Any]]: return None -def validate_judge_response(analysis: Dict[str, Any]) -> bool: +def validate_judge_response(analysis: dict[str, Any]) -> bool: """Validate that Judge response contains required dual-bottleneck fields.""" if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis: return False @@ -311,12 +314,12 @@ def validate_judge_response(analysis: Dict[str, Any]) -> bool: ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) -VALID_CATEGORIES = frozenset( - ["memory-bound", "compute-bound", "occupancy-limited", "latency-bound"] -) +VALID_CATEGORIES = { + "memory-bound", "compute-bound", "occupancy-limited", "latency-bound" +} -def _validate_bottleneck_entry(bottleneck: Dict[str, Any]) -> bool: +def _validate_bottleneck_entry(bottleneck: dict[str, Any]) -> bool: """Validate a single bottleneck entry.""" required = [ "category", From e2c599e5081cbf986790ae9636b1e2cad873c0be Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Thu, 15 Jan 2026 11:38:57 -0800 Subject: [PATCH 22/22] fix ruff --- .../kernel_opt/diagnose_prompt/judger_prompts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py index 27bea50..e07749e 100644 --- a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -315,7 +315,10 @@ def validate_judge_response(analysis: dict[str, Any]) -> bool: VALID_CATEGORIES = { - "memory-bound", "compute-bound", "occupancy-limited", "latency-bound" + "memory-bound", + "compute-bound", + "occupancy-limited", + "latency-bound", }