From 07a3268dd08dac21f70620bdc07ddbd5ccf9328a Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 11:46:28 -0800 Subject: [PATCH 01/12] NCU profiling wrapper generation and execution --- .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 +++++++++ .../kernel_opt/profiler/ncu_profiler.py | 449 ++++++++++++++++++ .../profiling/__init__.py | 6 + .../profiling/kernel_profiler.py | 89 ++++ .../profiling/ncu_wrapper_generator.py | 154 ++++++ 6 files changed, 925 insertions(+) create mode 100644 kernel_perf_agent/kernel_opt/profiler/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/__init__.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py create mode 100644 triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py new file mode 100644 index 0000000..ed8b50f --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import re +import subprocess +from typing import Any, Dict, Optional + + +# GPU specifications database +# Sources: NVIDIA official specifications, manufacturer datasheets +GPU_SPECS_DATABASE = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} + + +def query_gpu_name() -> Optional[str]: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + gpu_name = result.stdout.strip() + # Normalize whitespace + gpu_name = re.sub(r"\s+", " ", gpu_name) + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: + """ + Format GPU specifications as a readable string for LLM prompts. + + Args: + specs: GPU specifications dictionary from get_gpu_specs() + + Returns: + Formatted string with GPU specifications + + Example: + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(format_gpu_specs_for_prompt(specs)) + GPU: NVIDIA A100 (Ampere) + - Peak Memory Bandwidth: 1555 GB/s + - Peak FP32 Performance: 19.5 TFLOPS + - Peak FP16 Performance: 312.0 TFLOPS + - SM Count: 108 + - L1 Cache: 192 KB per SM + - L2 Cache: 40 MB + """ + lines = [ + f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", + f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", + f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", + ] + + if specs.get("peak_fp16_tflops", 0) > 0: + lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") + + if specs.get("peak_bf16_tflops", 0) > 0: + lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") + + lines.extend( + [ + f"- SM Count: {specs['sm_count']}", + f"- Max Threads per SM: {specs['max_threads_per_sm']}", + f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", + f"- L2 Cache: {specs['l2_cache_mb']} MB", + f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", + ] + ) + + return "\n".join(lines) + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + + # Get specs + specs = get_gpu_specs() + print(f"\n{format_gpu_specs_for_prompt(specs)}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..26e59d2 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,449 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + + +__all__ = [ + "METRICS", + "METRIC_COLUMNS", + "profile_triton_kernel", + "load_ncu_metrics", + "metrics_to_prompt", +] + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# METRICS = ",".join( +# [ +# # Tensor Core (MOST CRITICAL for GEMM) +# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", +# # Memory Bandwidth & Cache +# "dram__throughput.avg.pct_of_peak_sustained_elapsed", +# "l1tex__t_sector_hit_rate.pct", +# "lts__t_sector_hit_rate.pct", +# # Memory Access Efficiency +# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", +# # Stalls & Bottlenecks (H100-compatible stall metrics) +# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", +# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", +# # Occupancy & Resources +# "launch__shared_mem_per_block_allocated", +# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", +# # SM Utilization +# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", +# ] +# ) + +# # CRITICAL GEMM Optimization Metrics (Optimized for Matrix Multiplication) +# METRICS = ",".join( +# [ +# # === Memory Efficiency Metrics (CRITICAL) === +# "dram__throughput.avg.pct_of_peak_sustained_elapsed", # DRAM bandwidth utilization +# "dram__bytes_read.sum", # Total DRAM reads +# "dram__bytes_write.sum", # Total DRAM writes +# "l1tex__t_sector_hit_rate.pct", # L1 cache hit rate (CRITICAL for data reuse) +# "l1tex__throughput.avg.pct_of_peak_sustained_elapsed", # L1 throughput (fixed from _active) +# "lts__t_sector_hit_rate.pct", # L2 cache hit rate +# "lts__throughput.avg.pct_of_peak_sustained_elapsed", # L2 throughput (fixed from _active) +# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", # Memory coalescing (NEW - CRITICAL) +# "l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum", # Shared memory loads (NEW) +# # === Compute Utilization Metrics (CRITICAL for GEMM) === +# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", # Tensor core pipeline (NEW - CRITICAL) +# "smsp__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", # Tensor core inst % +# "smsp__inst_executed_pipe_tensor.sum", # Total tensor core instructions (NEW) +# "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", # FP32 pipeline +# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", # SM active cycles (fixed from .avg only) +# "sm__inst_executed.sum", # Total instructions +# "smsp__inst_executed.avg.per_cycle_active", # IPC (NEW) +# # === Occupancy & Resource Metrics === +# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", # Warp occupancy (fixed from _active) +# "sm__maximum_warps_per_active_cycle_pct", # Max achievable occupancy (NEW) +# "launch__occupancy_limit_blocks", # Block limit +# "launch__occupancy_limit_registers", # Register limit +# "launch__occupancy_limit_shared_mem", # Shared memory limit +# "launch__registers_per_thread", # Registers per thread +# "launch__shared_mem_per_block_allocated", # Shared memory usage (NEW - CRITICAL) +# # === Stall & Latency Metrics === +# "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", # Memory dependency stalls +# "smsp__warp_issue_stalled_mem_throttle_per_warp_active.pct", # Memory throttle stalls (NEW - CRITICAL) +# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", # Short scoreboard stalls +# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", # Long scoreboard stalls +# "smsp__warp_issue_stalled_barrier_per_warp_active.pct", # Barrier stalls +# "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", # Branch stalls +# ] +# ) + + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [ + "sudo", + "-E", + f"--preserve-env={preserve}", + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + + print(f"[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print(f"[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: str = "last", +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name: + "first", "last", "max_cycles" + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = sub[ + sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + if len(matched) > 1: + if select == "first": + row = matched.iloc[[0]] + elif select == "last": + row = matched.iloc[[-1]] + elif ( + select == "max_cycles" + and "sm__cycles_active.avg" in matched.columns + ): + row = matched.sort_values( + "sm__cycles_active.avg", ascending=False + ).head(1) + else: + row = matched.iloc[[-1]] # fallback + else: + row = matched + results.append(row) + + if results: + sub = pd.concat(results, ignore_index=True) + else: + sub = pd.DataFrame(columns=keep_cols) + elif select in ("first", "last", "max_cycles"): + # Apply selection to all rows if no name filter + if len(sub) > 0: + if select == "first": + sub = sub.iloc[[0]] + elif select == "last": + sub = sub.iloc[[-1]] + elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: + sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py new file mode 100644 index 0000000..cb28cc6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -0,0 +1,6 @@ +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .ncu_wrapper_generator import NCUWrapperGenerator +from .kernel_profiler import KernelProfiler + +__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py new file mode 100644 index 0000000..7ba98ef --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -0,0 +1,89 @@ +"""Kernel profiling with NCU.""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from .ncu_wrapper_generator import NCUWrapperGenerator + + +class KernelProfiler: + """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + + def __init__( + self, + logger: logging.Logger, + temp_dir: Path, + logs_dir: Path, + ncu_bin_path: Optional[str] = None, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + temp_dir: Temporary directory for profiling artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + """ + self.logger = logger + self.temp_dir = temp_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.wrapper_generator = NCUWrapperGenerator(logger) + + def profile_kernel( + self, kernel_file: Path, problem_file: Path, round_num: int + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + """ + Profile kernel with NCU. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + + Returns: + Tuple of (metrics_df, metrics_json) or (None, None) on failure + """ + try: + # Create NCU wrapper script + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics JSON + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + with open(metrics_file, "w") as f: + json.dump(metrics_json, f, indent=2) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except Exception as e: + self.logger.error(f"❌ NCU profiling failed: {e}") + return None, None diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py new file mode 100644 index 0000000..96f391f --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -0,0 +1,154 @@ +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from pathlib import Path + + +class NCUWrapperGenerator: + """Generates NCU wrapper scripts for profiling Triton kernels.""" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper generator. + + Args: + logger: Logger instance + """ + self.logger = logger + + def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + """ + Create NCU wrapper script for profiling. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + + Returns: + Path to created wrapper script + """ + wrapper_file = output_dir / "ncu_wrapper.py" + + wrapper_content = f'''"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({str(kernel_file.parent)!r})) +sys.path.insert(0, str({str(problem_file.parent)!r})) + +from {kernel_file.stem} import kernel_function +from {problem_file.stem} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {problem_file.stem} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except: + pass + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + model = Model(*init_inputs) if init_inputs else Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {{}} + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) +''' + + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file From 3c4b124ed826e9b9538f07d7fbe42089de882aa4 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:49:50 -0800 Subject: [PATCH 02/12] Refactor profiling components and add kernel_perf_util --- kernel_perf_util/README.md | 8 + .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 0 .../kernel_opt/profiler/ncu_profiler.py | 0 .../profiling/kernel_profiler.py | 190 +++++++++++++--- .../profiling/ncu_wrapper_generator.py | 210 ++++++++---------- .../profiling/ncu_wrapper_template.j2 | 128 +++++++++++ .../opt_worker_components/README.md | 12 + .../profiling/__init__.py | 6 + .../profiling/kernel_profiler.py | 89 ++++++++ .../profiling/ncu_wrapper_generator.py | 154 +++++++++++++ 11 files changed, 645 insertions(+), 152 deletions(-) create mode 100644 kernel_perf_util/README.md rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/__init__.py (100%) rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/gpu_specs.py (100%) rename {kernel_perf_agent => kernel_perf_util}/kernel_opt/profiler/ncu_profiler.py (100%) create mode 100644 triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 create mode 100644 triton_kernel_agent/opt_worker_components/README.md create mode 100644 triton_kernel_agent/opt_worker_components/profiling/__init__.py create mode 100644 triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py create mode 100644 triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py diff --git a/kernel_perf_util/README.md b/kernel_perf_util/README.md new file mode 100644 index 0000000..5114402 --- /dev/null +++ b/kernel_perf_util/README.md @@ -0,0 +1,8 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization: +- NCU profiling +- RAG-based optimization database +- Performance analysis tools + +These are **library components** that can be used by any worker. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_util/kernel_opt/profiler/__init__.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/__init__.py rename to kernel_perf_util/kernel_opt/profiler/__init__.py diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_util/kernel_opt/profiler/gpu_specs.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/gpu_specs.py rename to kernel_perf_util/kernel_opt/profiler/gpu_specs.py diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py similarity index 100% rename from kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py rename to kernel_perf_util/kernel_opt/profiler/ncu_profiler.py diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 7ba98ef..3fc991a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -2,6 +2,9 @@ import json import logging +import subprocess +import time +from datetime import datetime from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -41,49 +44,170 @@ def __init__( self.ncu_bin_path = ncu_bin_path self.wrapper_generator = NCUWrapperGenerator(logger) + def _get_ncu_version(self) -> Optional[str]: + """ + Get NCU version string. + + Returns: + NCU version string or None if failed + """ + try: + ncu_cmd = self.ncu_bin_path or "ncu" + result = subprocess.run( + [ncu_cmd, "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") + version_line = result.stdout.strip().split('\n')[0] + return version_line + return None + except Exception as e: + self.logger.debug(f"Failed to get NCU version: {e}") + return None + def profile_kernel( - self, kernel_file: Path, problem_file: Path, round_num: int + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 3, ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: """ - Profile kernel with NCU. + Profile kernel with NCU (with retry logic). + + NCU profiling can fail due to GPU contention or transient issues. + This method automatically retries with exponential backoff. Args: kernel_file: Path to kernel file problem_file: Path to problem file round_num: Current optimization round number + max_retries: Maximum number of retry attempts (default: 3) Returns: Tuple of (metrics_df, metrics_json) or (None, None) on failure - """ - try: - # Create NCU wrapper script - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - # Profile with NCU - csv_file = f"ncu_round_{round_num}.csv" - csv_path = profile_triton_kernel( - benchmark_script=wrapper_file, - workdir=self.temp_dir, - out_csv=csv_file, - ncu_bin=self.ncu_bin_path, - launch_count=20, - timeout=120, - ) - - # Load and parse metrics - metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) - - # Save metrics JSON - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - with open(metrics_file, "w") as f: - json.dump(metrics_json, f, indent=2) - - self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json + Example: + >>> profiler = KernelProfiler(logger, temp_dir, logs_dir) + >>> metrics_df, metrics_json = profiler.profile_kernel( + ... Path("kernel.py"), Path("problem.py"), round_num=1 + ... ) + >>> if metrics_json: + ... print(f"DRAM throughput: {metrics_json['dram__throughput']}") + """ + wrapper_file = None + + for attempt in range(1, max_retries + 1): + try: + # Create NCU wrapper script (cached if unchanged) + if wrapper_file is None: + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + self.logger.info( + f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics with metadata + self._save_metrics_with_metadata( + metrics_json, kernel_file, problem_file, round_num + ) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except FileNotFoundError as e: + self.logger.error(f"❌ File not found during profiling: {e}") + return None, None + + except subprocess.TimeoutExpired: + self.logger.error( + f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" + ) + if attempt < max_retries: + wait_time = 2 ** attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + except json.JSONDecodeError as e: + self.logger.error(f"❌ Failed to parse NCU metrics: {e}") + if attempt < max_retries: + wait_time = 2 ** attempt + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + except Exception as e: + self.logger.error( + f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", + exc_info=True + ) + if attempt < max_retries: + wait_time = 2 ** attempt + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + continue + return None, None + + # All retries exhausted + self.logger.error( + f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" + ) + return None, None + + def _save_metrics_with_metadata( + self, + metrics_json: Dict[str, Any], + kernel_file: Path, + problem_file: Path, + round_num: int, + ) -> None: + """ + Save profiling metrics with additional metadata. - except Exception as e: - self.logger.error(f"❌ NCU profiling failed: {e}") - return None, None + Args: + metrics_json: NCU metrics as JSON dict + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + """ + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + + # Build metadata + metadata = { + "metrics": metrics_json, + "metadata": { + "kernel_file": str(kernel_file), + "problem_file": str(problem_file), + "round_num": round_num, + "timestamp": datetime.utcnow().isoformat() + "Z", + "ncu_version": self._get_ncu_version(), + }, + } + + with open(metrics_file, "w") as f: + json.dump(metadata, f, indent=2) + + self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py index 96f391f..9f0e2e5 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -2,11 +2,21 @@ import logging from pathlib import Path +from typing import Optional + +try: + from jinja2 import Template + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False class NCUWrapperGenerator: """Generates NCU wrapper scripts for profiling Triton kernels.""" + # Template file path (relative to this file) + WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" + def __init__(self, logger: logging.Logger): """ Initialize the NCU wrapper generator. @@ -15,140 +25,102 @@ def __init__(self, logger: logging.Logger): logger: Logger instance """ self.logger = logger + self._template_cache: Optional[Template] = None - def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + def _load_template(self) -> Template: + """ + Load the Jinja2 template (cached). + + Returns: + Jinja2 Template object + + Raises: + ImportError: If Jinja2 is not installed + FileNotFoundError: If template file doesn't exist + """ + if self._template_cache is not None: + return self._template_cache + + if not HAS_JINJA2: + raise ImportError( + "Jinja2 is required for wrapper generation. " + "Install it with: pip install jinja2" + ) + + if not self.WRAPPER_TEMPLATE.exists(): + raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") + + self._template_cache = Template(self.WRAPPER_TEMPLATE.read_text()) + return self._template_cache + + def create_ncu_wrapper( + self, + kernel_file: Path, + problem_file: Path, + output_dir: Path, + dtype_inference: bool = True, + model_extraction: bool = True, + use_cache: bool = True, + ) -> Path: """ Create NCU wrapper script for profiling. + The wrapper handles multiple kernel types: + - Standard kernels: kernel_function(*inputs) + - Conv/Linear kernels: Extracts weights from Model + - RMSNorm kernels: Passes init_inputs (features, eps) + Args: kernel_file: Path to kernel file problem_file: Path to problem file output_dir: Directory to write wrapper script + dtype_inference: Enable automatic dtype inference from kernel source (default: True) + model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) + use_cache: Reuse existing wrapper if files haven't changed (default: True) Returns: Path to created wrapper script - """ - wrapper_file = output_dir / "ncu_wrapper.py" - - wrapper_content = f'''"""NCU profiling wrapper.""" -import sys -import torch -import inspect -sys.path.insert(0, str({str(kernel_file.parent)!r})) -sys.path.insert(0, str({str(problem_file.parent)!r})) - -from {kernel_file.stem} import kernel_function -from {problem_file.stem} import get_inputs, get_init_inputs - -# Try to import Model if it exists (for Conv, Linear, etc.) -try: - from {problem_file.stem} import Model - has_model = True -except ImportError: - has_model = False -# Get inputs -inputs = get_inputs() + Raises: + FileNotFoundError: If kernel_file or problem_file doesn't exist + OSError: If output_dir is not writable + """ + # Validate inputs + if not kernel_file.exists(): + raise FileNotFoundError(f"Kernel file not found: {kernel_file}") + if not problem_file.exists(): + raise FileNotFoundError(f"Problem file not found: {problem_file}") -# Get additional initialization inputs (e.g., features, eps for RMSNorm) -init_inputs = get_init_inputs() + # Ensure output directory exists + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) -# Infer required dtype from kernel function signature/docstring -required_dtype = None -try: - # Try to get dtype from kernel function docstring or source - kernel_source = inspect.getsource(kernel_function) - if 'bfloat16' in kernel_source.lower(): - required_dtype = torch.bfloat16 - elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): - required_dtype = torch.float16 - elif 'float32' in kernel_source.lower(): - required_dtype = torch.float32 -except: - pass - -# Prepare inputs: move to CUDA and convert dtype if needed -# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) -cuda_inputs = [] -for inp in inputs: - if isinstance(inp, torch.Tensor): - # Move to CUDA if not already - if not inp.is_cuda: - inp = inp.cuda() - # Convert dtype if required, but ONLY for floating-point tensors - # Preserve integer/bool tensors (e.g., targets for classification) - if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: - inp = inp.to(required_dtype) - cuda_inputs.append(inp) - else: - cuda_inputs.append(inp) - -# Check if this is a conv-like kernel that needs a Model to extract weights -needs_model = False -try: - sig = inspect.signature(kernel_function) - params = list(sig.parameters.keys()) - # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) - if 'weight' in params: - needs_model = True -except: - pass - -if needs_model and has_model and init_inputs: - # Initialize model to extract weight and bias - model = Model(*init_inputs) if init_inputs else Model() - - # Move model to CUDA and convert dtype - model = model.cuda() - if required_dtype is not None: - model = model.to(required_dtype) - - # Extract weight and bias from model - # Check various possible attribute names - weight = None - bias = None - layer = None - for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: - if hasattr(model, attr_name): - layer = getattr(model, attr_name) - if hasattr(layer, 'weight'): - weight = layer.weight - bias = layer.bias if hasattr(layer, 'bias') else None - break - - if weight is not None and layer is not None: - # Build arguments for kernel_function using keyword arguments - # to avoid positional argument misalignment issues - kernel_kwargs = {{}} - - # Add conv/linear-specific parameters if they exist - if hasattr(layer, 'stride'): - stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride - kernel_kwargs['stride'] = stride - if hasattr(layer, 'padding'): - padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding - kernel_kwargs['padding'] = padding - if hasattr(layer, 'dilation'): - dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation - kernel_kwargs['dilation'] = dilation - if hasattr(layer, 'groups'): - kernel_kwargs['groups'] = layer.groups - - # Call kernel with extracted parameters - output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) - else: - # Fallback to original behavior - output = kernel_function(*cuda_inputs, *init_inputs) -else: - # Run kernel with both tensor inputs and initialization inputs - # For example: RMSNorm needs kernel_function(x, features, eps) - # For cross-entropy: kernel_function(predictions, targets) - # where inputs come from get_inputs() and init_inputs from get_init_inputs() - output = kernel_function(*cuda_inputs, *init_inputs) - -print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) -''' + wrapper_file = output_dir / "ncu_wrapper.py" + # Check cache: reuse wrapper if it's fresh + if use_cache and wrapper_file.exists(): + wrapper_mtime = wrapper_file.stat().st_mtime + kernel_mtime = kernel_file.stat().st_mtime + problem_mtime = problem_file.stat().st_mtime + + if wrapper_mtime > kernel_mtime and wrapper_mtime > problem_mtime: + self.logger.info( + f"Reusing cached NCU wrapper (fresher than source files): {wrapper_file}" + ) + return wrapper_file + + # Load template and render + template = self._load_template() + wrapper_content = template.render( + kernel_file_parent=repr(str(kernel_file.parent)), + problem_file_parent=repr(str(problem_file.parent)), + kernel_module=kernel_file.stem, + problem_module=problem_file.stem, + dtype_inference=dtype_inference, + model_extraction=model_extraction, + ) + + # Write wrapper file wrapper_file.write_text(wrapper_content) self.logger.info(f"Created NCU wrapper: {wrapper_file}") return wrapper_file diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 new file mode 100644 index 0000000..310d8bb --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 @@ -0,0 +1,128 @@ +"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({{ kernel_file_parent }})) +sys.path.insert(0, str({{ problem_file_parent }})) + +from {{ kernel_module }} import kernel_function +from {{ problem_module }} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {{ problem_module }} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +{% if dtype_inference %} +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except Exception: + pass +{% else %} +# Dtype inference disabled +required_dtype = None +{% endif %} + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +{% if model_extraction %} +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except Exception: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + try: + model = Model(*init_inputs) if init_inputs else Model() + except TypeError: + # Model initialization failed, use fallback + model = Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {} # Empty dict + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: +{% endif %} + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) diff --git a/triton_kernel_agent/opt_worker_components/README.md b/triton_kernel_agent/opt_worker_components/README.md new file mode 100644 index 0000000..76026b2 --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/README.md @@ -0,0 +1,12 @@ +# Opt Worker Components + +High-level components used by `OptimizationWorker`. + +These components are **thin wrappers** around low-level utilities +from `kernel_perf_agent` that provide: +- Logging integration +- Error handling +- Worker-specific configuration + +## Dependency Flow +worker_components → kernel_perf_agent (implementation) diff --git a/triton_kernel_agent/opt_worker_components/profiling/__init__.py b/triton_kernel_agent/opt_worker_components/profiling/__init__.py new file mode 100644 index 0000000..cb28cc6 --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/__init__.py @@ -0,0 +1,6 @@ +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .ncu_wrapper_generator import NCUWrapperGenerator +from .kernel_profiler import KernelProfiler + +__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py new file mode 100644 index 0000000..7ba98ef --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py @@ -0,0 +1,89 @@ +"""Kernel profiling with NCU.""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from .ncu_wrapper_generator import NCUWrapperGenerator + + +class KernelProfiler: + """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + + def __init__( + self, + logger: logging.Logger, + temp_dir: Path, + logs_dir: Path, + ncu_bin_path: Optional[str] = None, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + temp_dir: Temporary directory for profiling artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + """ + self.logger = logger + self.temp_dir = temp_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.wrapper_generator = NCUWrapperGenerator(logger) + + def profile_kernel( + self, kernel_file: Path, problem_file: Path, round_num: int + ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + """ + Profile kernel with NCU. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number + + Returns: + Tuple of (metrics_df, metrics_json) or (None, None) on failure + """ + try: + # Create NCU wrapper script + wrapper_file = self.wrapper_generator.create_ncu_wrapper( + kernel_file, problem_file, self.temp_dir + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.temp_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=120, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics_json = json.loads(metrics_to_prompt(metrics_df)) + + # Save metrics JSON + metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" + with open(metrics_file, "w") as f: + json.dump(metrics_json, f, indent=2) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return metrics_df, metrics_json + + except Exception as e: + self.logger.error(f"❌ NCU profiling failed: {e}") + return None, None diff --git a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py new file mode 100644 index 0000000..96f391f --- /dev/null +++ b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py @@ -0,0 +1,154 @@ +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from pathlib import Path + + +class NCUWrapperGenerator: + """Generates NCU wrapper scripts for profiling Triton kernels.""" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper generator. + + Args: + logger: Logger instance + """ + self.logger = logger + + def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: + """ + Create NCU wrapper script for profiling. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + + Returns: + Path to created wrapper script + """ + wrapper_file = output_dir / "ncu_wrapper.py" + + wrapper_content = f'''"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({str(kernel_file.parent)!r})) +sys.path.insert(0, str({str(problem_file.parent)!r})) + +from {kernel_file.stem} import kernel_function +from {problem_file.stem} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {problem_file.stem} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except: + pass + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + model = Model(*init_inputs) if init_inputs else Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {{}} + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) +''' + + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file From 11f4e7978ff619b0f6052d875b15a1502e1691b8 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:54:11 -0800 Subject: [PATCH 03/12] Refactor profiling components and add kernel_perf_util --- .../README.md | 0 .../profiling/kernel_profiler.py | 12 +- .../profiling/__init__.py | 6 - .../profiling/kernel_profiler.py | 89 ---------- .../profiling/ncu_wrapper_generator.py | 154 ------------------ 5 files changed, 6 insertions(+), 255 deletions(-) rename triton_kernel_agent/{opt_worker_components => opt_worker_component}/README.md (100%) delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/__init__.py delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py delete mode 100644 triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py diff --git a/triton_kernel_agent/opt_worker_components/README.md b/triton_kernel_agent/opt_worker_component/README.md similarity index 100% rename from triton_kernel_agent/opt_worker_components/README.md rename to triton_kernel_agent/opt_worker_component/README.md diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 3fc991a..61f41b3 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -10,7 +10,7 @@ import pandas as pd -from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( +from kernel_perf_util.kernel_opt.profiler.ncu_profiler import ( load_ncu_metrics, metrics_to_prompt, profile_triton_kernel, @@ -61,7 +61,7 @@ def _get_ncu_version(self) -> Optional[str]: ) if result.returncode == 0: # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") - version_line = result.stdout.strip().split('\n')[0] + version_line = result.stdout.strip().split("\n")[0] return version_line return None except Exception as e: @@ -144,7 +144,7 @@ def profile_kernel( f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" ) if attempt < max_retries: - wait_time = 2 ** attempt # Exponential backoff: 2, 4, 8 seconds + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue @@ -153,7 +153,7 @@ def profile_kernel( except json.JSONDecodeError as e: self.logger.error(f"❌ Failed to parse NCU metrics: {e}") if attempt < max_retries: - wait_time = 2 ** attempt + wait_time = 2**attempt self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue @@ -162,10 +162,10 @@ def profile_kernel( except Exception as e: self.logger.error( f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", - exc_info=True + exc_info=True, ) if attempt < max_retries: - wait_time = 2 ** attempt + wait_time = 2**attempt self.logger.warning(f"Retrying in {wait_time}s...") time.sleep(wait_time) continue diff --git a/triton_kernel_agent/opt_worker_components/profiling/__init__.py b/triton_kernel_agent/opt_worker_components/profiling/__init__.py deleted file mode 100644 index cb28cc6..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Profiling infrastructure for NCU-based kernel analysis.""" - -from .ncu_wrapper_generator import NCUWrapperGenerator -from .kernel_profiler import KernelProfiler - -__all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py deleted file mode 100644 index 7ba98ef..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/kernel_profiler.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Kernel profiling with NCU.""" - -import json -import logging -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -import pandas as pd - -from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( - load_ncu_metrics, - metrics_to_prompt, - profile_triton_kernel, -) - -from .ncu_wrapper_generator import NCUWrapperGenerator - - -class KernelProfiler: - """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" - - def __init__( - self, - logger: logging.Logger, - temp_dir: Path, - logs_dir: Path, - ncu_bin_path: Optional[str] = None, - ): - """ - Initialize the kernel profiler. - - Args: - logger: Logger instance - temp_dir: Temporary directory for profiling artifacts - logs_dir: Directory for saving profiling logs - ncu_bin_path: Path to NCU binary (auto-detect if None) - """ - self.logger = logger - self.temp_dir = temp_dir - self.logs_dir = logs_dir - self.ncu_bin_path = ncu_bin_path - self.wrapper_generator = NCUWrapperGenerator(logger) - - def profile_kernel( - self, kernel_file: Path, problem_file: Path, round_num: int - ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: - """ - Profile kernel with NCU. - - Args: - kernel_file: Path to kernel file - problem_file: Path to problem file - round_num: Current optimization round number - - Returns: - Tuple of (metrics_df, metrics_json) or (None, None) on failure - """ - try: - # Create NCU wrapper script - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - - # Profile with NCU - csv_file = f"ncu_round_{round_num}.csv" - csv_path = profile_triton_kernel( - benchmark_script=wrapper_file, - workdir=self.temp_dir, - out_csv=csv_file, - ncu_bin=self.ncu_bin_path, - launch_count=20, - timeout=120, - ) - - # Load and parse metrics - metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) - - # Save metrics JSON - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - with open(metrics_file, "w") as f: - json.dump(metrics_json, f, indent=2) - - self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json - - except Exception as e: - self.logger.error(f"❌ NCU profiling failed: {e}") - return None, None diff --git a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py deleted file mode 100644 index 96f391f..0000000 --- a/triton_kernel_agent/opt_worker_components/profiling/ncu_wrapper_generator.py +++ /dev/null @@ -1,154 +0,0 @@ -"""NCU wrapper script generation for kernel profiling.""" - -import logging -from pathlib import Path - - -class NCUWrapperGenerator: - """Generates NCU wrapper scripts for profiling Triton kernels.""" - - def __init__(self, logger: logging.Logger): - """ - Initialize the NCU wrapper generator. - - Args: - logger: Logger instance - """ - self.logger = logger - - def create_ncu_wrapper(self, kernel_file: Path, problem_file: Path, output_dir: Path) -> Path: - """ - Create NCU wrapper script for profiling. - - Args: - kernel_file: Path to kernel file - problem_file: Path to problem file - output_dir: Directory to write wrapper script - - Returns: - Path to created wrapper script - """ - wrapper_file = output_dir / "ncu_wrapper.py" - - wrapper_content = f'''"""NCU profiling wrapper.""" -import sys -import torch -import inspect -sys.path.insert(0, str({str(kernel_file.parent)!r})) -sys.path.insert(0, str({str(problem_file.parent)!r})) - -from {kernel_file.stem} import kernel_function -from {problem_file.stem} import get_inputs, get_init_inputs - -# Try to import Model if it exists (for Conv, Linear, etc.) -try: - from {problem_file.stem} import Model - has_model = True -except ImportError: - has_model = False - -# Get inputs -inputs = get_inputs() - -# Get additional initialization inputs (e.g., features, eps for RMSNorm) -init_inputs = get_init_inputs() - -# Infer required dtype from kernel function signature/docstring -required_dtype = None -try: - # Try to get dtype from kernel function docstring or source - kernel_source = inspect.getsource(kernel_function) - if 'bfloat16' in kernel_source.lower(): - required_dtype = torch.bfloat16 - elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): - required_dtype = torch.float16 - elif 'float32' in kernel_source.lower(): - required_dtype = torch.float32 -except: - pass - -# Prepare inputs: move to CUDA and convert dtype if needed -# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) -cuda_inputs = [] -for inp in inputs: - if isinstance(inp, torch.Tensor): - # Move to CUDA if not already - if not inp.is_cuda: - inp = inp.cuda() - # Convert dtype if required, but ONLY for floating-point tensors - # Preserve integer/bool tensors (e.g., targets for classification) - if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: - inp = inp.to(required_dtype) - cuda_inputs.append(inp) - else: - cuda_inputs.append(inp) - -# Check if this is a conv-like kernel that needs a Model to extract weights -needs_model = False -try: - sig = inspect.signature(kernel_function) - params = list(sig.parameters.keys()) - # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) - if 'weight' in params: - needs_model = True -except: - pass - -if needs_model and has_model and init_inputs: - # Initialize model to extract weight and bias - model = Model(*init_inputs) if init_inputs else Model() - - # Move model to CUDA and convert dtype - model = model.cuda() - if required_dtype is not None: - model = model.to(required_dtype) - - # Extract weight and bias from model - # Check various possible attribute names - weight = None - bias = None - layer = None - for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: - if hasattr(model, attr_name): - layer = getattr(model, attr_name) - if hasattr(layer, 'weight'): - weight = layer.weight - bias = layer.bias if hasattr(layer, 'bias') else None - break - - if weight is not None and layer is not None: - # Build arguments for kernel_function using keyword arguments - # to avoid positional argument misalignment issues - kernel_kwargs = {{}} - - # Add conv/linear-specific parameters if they exist - if hasattr(layer, 'stride'): - stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride - kernel_kwargs['stride'] = stride - if hasattr(layer, 'padding'): - padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding - kernel_kwargs['padding'] = padding - if hasattr(layer, 'dilation'): - dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation - kernel_kwargs['dilation'] = dilation - if hasattr(layer, 'groups'): - kernel_kwargs['groups'] = layer.groups - - # Call kernel with extracted parameters - output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) - else: - # Fallback to original behavior - output = kernel_function(*cuda_inputs, *init_inputs) -else: - # Run kernel with both tensor inputs and initialization inputs - # For example: RMSNorm needs kernel_function(x, features, eps) - # For cross-entropy: kernel_function(predictions, targets) - # where inputs come from get_inputs() and init_inputs from get_init_inputs() - output = kernel_function(*cuda_inputs, *init_inputs) - -print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output))) -''' - - wrapper_file.write_text(wrapper_content) - self.logger.info(f"Created NCU wrapper: {wrapper_file}") - return wrapper_file From 251f419741dd3a50233b947b3dab26417e922e2e Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 12:55:10 -0800 Subject: [PATCH 04/12] Refactor profiling components and add kernel_perf_util --- .../kernel_opt/profiler/ncu_profiler.py | 61 ------------------- 1 file changed, 61 deletions(-) diff --git a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py index 26e59d2..0b9168b 100644 --- a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py @@ -64,67 +64,6 @@ ] ) -# METRICS = ",".join( -# [ -# # Tensor Core (MOST CRITICAL for GEMM) -# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", -# # Memory Bandwidth & Cache -# "dram__throughput.avg.pct_of_peak_sustained_elapsed", -# "l1tex__t_sector_hit_rate.pct", -# "lts__t_sector_hit_rate.pct", -# # Memory Access Efficiency -# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", -# # Stalls & Bottlenecks (H100-compatible stall metrics) -# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", -# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", -# # Occupancy & Resources -# "launch__shared_mem_per_block_allocated", -# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", -# # SM Utilization -# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", -# ] -# ) - -# # CRITICAL GEMM Optimization Metrics (Optimized for Matrix Multiplication) -# METRICS = ",".join( -# [ -# # === Memory Efficiency Metrics (CRITICAL) === -# "dram__throughput.avg.pct_of_peak_sustained_elapsed", # DRAM bandwidth utilization -# "dram__bytes_read.sum", # Total DRAM reads -# "dram__bytes_write.sum", # Total DRAM writes -# "l1tex__t_sector_hit_rate.pct", # L1 cache hit rate (CRITICAL for data reuse) -# "l1tex__throughput.avg.pct_of_peak_sustained_elapsed", # L1 throughput (fixed from _active) -# "lts__t_sector_hit_rate.pct", # L2 cache hit rate -# "lts__throughput.avg.pct_of_peak_sustained_elapsed", # L2 throughput (fixed from _active) -# "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", # Memory coalescing (NEW - CRITICAL) -# "l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum", # Shared memory loads (NEW) -# # === Compute Utilization Metrics (CRITICAL for GEMM) === -# "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", # Tensor core pipeline (NEW - CRITICAL) -# "smsp__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", # Tensor core inst % -# "smsp__inst_executed_pipe_tensor.sum", # Total tensor core instructions (NEW) -# "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", # FP32 pipeline -# "sm__cycles_active.avg.pct_of_peak_sustained_elapsed", # SM active cycles (fixed from .avg only) -# "sm__inst_executed.sum", # Total instructions -# "smsp__inst_executed.avg.per_cycle_active", # IPC (NEW) -# # === Occupancy & Resource Metrics === -# "sm__warps_active.avg.pct_of_peak_sustained_elapsed", # Warp occupancy (fixed from _active) -# "sm__maximum_warps_per_active_cycle_pct", # Max achievable occupancy (NEW) -# "launch__occupancy_limit_blocks", # Block limit -# "launch__occupancy_limit_registers", # Register limit -# "launch__occupancy_limit_shared_mem", # Shared memory limit -# "launch__registers_per_thread", # Registers per thread -# "launch__shared_mem_per_block_allocated", # Shared memory usage (NEW - CRITICAL) -# # === Stall & Latency Metrics === -# "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", # Memory dependency stalls -# "smsp__warp_issue_stalled_mem_throttle_per_warp_active.pct", # Memory throttle stalls (NEW - CRITICAL) -# "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", # Short scoreboard stalls -# "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", # Long scoreboard stalls -# "smsp__warp_issue_stalled_barrier_per_warp_active.pct", # Barrier stalls -# "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", # Branch stalls -# ] -# ) - - # List version for convenient header selection METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] From b789660d6897fee27c8793eab9984d25c031a894 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:14:19 -0800 Subject: [PATCH 05/12] update directory name and add package in pyproject --- kernel_perf_agent/README.md | 8 + .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 ++++++++++ .../kernel_opt/profiler/ncu_profiler.py | 388 ++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 624 insertions(+), 1 deletion(-) create mode 100644 kernel_perf_agent/README.md create mode 100644 kernel_perf_agent/kernel_opt/profiler/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py create mode 100644 kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md new file mode 100644 index 0000000..5114402 --- /dev/null +++ b/kernel_perf_agent/README.md @@ -0,0 +1,8 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization: +- NCU profiling +- RAG-based optimization database +- Performance analysis tools + +These are **library components** that can be used by any worker. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py new file mode 100644 index 0000000..ed8b50f --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import re +import subprocess +from typing import Any, Dict, Optional + + +# GPU specifications database +# Sources: NVIDIA official specifications, manufacturer datasheets +GPU_SPECS_DATABASE = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} + + +def query_gpu_name() -> Optional[str]: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + gpu_name = result.stdout.strip() + # Normalize whitespace + gpu_name = re.sub(r"\s+", " ", gpu_name) + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: + """ + Format GPU specifications as a readable string for LLM prompts. + + Args: + specs: GPU specifications dictionary from get_gpu_specs() + + Returns: + Formatted string with GPU specifications + + Example: + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(format_gpu_specs_for_prompt(specs)) + GPU: NVIDIA A100 (Ampere) + - Peak Memory Bandwidth: 1555 GB/s + - Peak FP32 Performance: 19.5 TFLOPS + - Peak FP16 Performance: 312.0 TFLOPS + - SM Count: 108 + - L1 Cache: 192 KB per SM + - L2 Cache: 40 MB + """ + lines = [ + f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", + f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", + f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", + ] + + if specs.get("peak_fp16_tflops", 0) > 0: + lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") + + if specs.get("peak_bf16_tflops", 0) > 0: + lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") + + lines.extend( + [ + f"- SM Count: {specs['sm_count']}", + f"- Max Threads per SM: {specs['max_threads_per_sm']}", + f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", + f"- L2 Cache: {specs['l2_cache_mb']} MB", + f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", + ] + ) + + return "\n".join(lines) + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + + # Get specs + specs = get_gpu_specs() + print(f"\n{format_gpu_specs_for_prompt(specs)}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..0b9168b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + + +__all__ = [ + "METRICS", + "METRIC_COLUMNS", + "profile_triton_kernel", + "load_ncu_metrics", + "metrics_to_prompt", +] + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [ + "sudo", + "-E", + f"--preserve-env={preserve}", + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + + print(f"[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print(f"[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: str = "last", +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name: + "first", "last", "max_cycles" + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = sub[ + sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + if len(matched) > 1: + if select == "first": + row = matched.iloc[[0]] + elif select == "last": + row = matched.iloc[[-1]] + elif ( + select == "max_cycles" + and "sm__cycles_active.avg" in matched.columns + ): + row = matched.sort_values( + "sm__cycles_active.avg", ascending=False + ).head(1) + else: + row = matched.iloc[[-1]] # fallback + else: + row = matched + results.append(row) + + if results: + sub = pd.concat(results, ignore_index=True) + else: + sub = pd.DataFrame(columns=keep_cols) + elif select in ("first", "last", "max_cycles"): + # Apply selection to all rows if no name filter + if len(sub) > 0: + if select == "first": + sub = sub.iloc[[0]] + elif select == "last": + sub = sub.iloc[[-1]] + elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: + sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/pyproject.toml b/pyproject.toml index 2d5382f..a31dc3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ pipeline-ui = "scripts.pipeline_ui:main" "Repository" = "https://github.com/pytorch-labs/KernelAgent" [tool.setuptools.packages.find] -include = ["triton_kernel_agent*", "Fuser*", "scripts", "utils"] +include = ["triton_kernel_agent*", "kernel_perf_agent*", "Fuser*", "scripts", "utils"] [tool.setuptools] include-package-data = true From 4d35d570e3053a5b22c30c1b7c2cc73cd7bb1360 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:18:37 -0800 Subject: [PATCH 06/12] Remove kernel_perf_util directory --- kernel_perf_util/README.md | 8 - .../kernel_opt/profiler/__init__.py | 0 .../kernel_opt/profiler/gpu_specs.py | 227 ---------- .../kernel_opt/profiler/ncu_profiler.py | 388 ------------------ .../profiling/kernel_profiler.py | 2 +- 5 files changed, 1 insertion(+), 624 deletions(-) delete mode 100644 kernel_perf_util/README.md delete mode 100644 kernel_perf_util/kernel_opt/profiler/__init__.py delete mode 100644 kernel_perf_util/kernel_opt/profiler/gpu_specs.py delete mode 100644 kernel_perf_util/kernel_opt/profiler/ncu_profiler.py diff --git a/kernel_perf_util/README.md b/kernel_perf_util/README.md deleted file mode 100644 index 5114402..0000000 --- a/kernel_perf_util/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Kernel Performance Utilities - -Low-level, reusable utilities for kernel optimization: -- NCU profiling -- RAG-based optimization database -- Performance analysis tools - -These are **library components** that can be used by any worker. diff --git a/kernel_perf_util/kernel_opt/profiler/__init__.py b/kernel_perf_util/kernel_opt/profiler/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/kernel_perf_util/kernel_opt/profiler/gpu_specs.py b/kernel_perf_util/kernel_opt/profiler/gpu_specs.py deleted file mode 100644 index ed8b50f..0000000 --- a/kernel_perf_util/kernel_opt/profiler/gpu_specs.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python3 -""" -GPU Specifications Database for Bottleneck Analysis - -This module provides GPU hardware specifications needed for performance analysis -and bottleneck identification. It includes peak compute performance, memory bandwidth, -cache sizes, and SM counts for common NVIDIA GPUs. - -""" - -import re -import subprocess -from typing import Any, Dict, Optional - - -# GPU specifications database -# Sources: NVIDIA official specifications, manufacturer datasheets -GPU_SPECS_DATABASE = { - "NVIDIA A100": { - "name": "NVIDIA A100", - "architecture": "Ampere", - "peak_fp32_tflops": 19.5, - "peak_fp16_tflops": 312.0, - "peak_bf16_tflops": 312.0, - "peak_memory_bw_gbps": 1555, - "sm_count": 108, - "max_threads_per_sm": 2048, - "l1_cache_kb": 192, - "l2_cache_mb": 40, - "memory_gb": 40, - "memory_type": "HBM2e", - }, - "NVIDIA H100": { - "name": "NVIDIA H100", - "architecture": "Hopper", - "peak_fp32_tflops": 51.0, - "peak_fp16_tflops": 989.0, - "peak_bf16_tflops": 989.0, - "peak_memory_bw_gbps": 3352, - "sm_count": 132, - "max_threads_per_sm": 2048, - "l1_cache_kb": 256, - "l2_cache_mb": 50, - "memory_gb": 80, - "memory_type": "HBM3", - }, - "NVIDIA RTX 4090": { - "name": "NVIDIA RTX 4090", - "architecture": "Ada Lovelace", - "peak_fp32_tflops": 82.6, - "peak_fp16_tflops": 165.0, - "peak_bf16_tflops": 165.0, - "peak_memory_bw_gbps": 1008, - "sm_count": 128, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 72, - "memory_gb": 24, - "memory_type": "GDDR6X", - }, - "NVIDIA RTX 5080": { - "name": "NVIDIA RTX 5080", - "architecture": "Blackwell", - "peak_fp32_tflops": 57.0, - "peak_fp16_tflops": 114.0, - "peak_bf16_tflops": 114.0, - "peak_memory_bw_gbps": 960, - "sm_count": 84, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 64, - "memory_gb": 16, - "memory_type": "GDDR7", - }, -} - - -def query_gpu_name() -> Optional[str]: - """ - Query GPU name using nvidia-smi. - - Returns: - GPU name string, or None if query fails - """ - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - gpu_name = result.stdout.strip() - # Normalize whitespace - gpu_name = re.sub(r"\s+", " ", gpu_name) - return gpu_name - except (subprocess.TimeoutExpired, FileNotFoundError, Exception): - pass - return None - - -def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: - """ - Get GPU specifications for bottleneck analysis. - - This function returns hardware specifications needed for performance analysis, - including peak compute performance, memory bandwidth, cache sizes, and SM counts. - - Args: - gpu_name: GPU name (if None, auto-detect with nvidia-smi) - - Returns: - Dictionary with GPU specifications containing: - - name: GPU name - - architecture: GPU architecture (e.g., "Ampere", "Hopper") - - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS - - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS - - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) - - peak_memory_bw_gbps: Peak memory bandwidth in GB/s - - sm_count: Number of streaming multiprocessors - - max_threads_per_sm: Maximum threads per SM - - l1_cache_kb: L1 cache size in KB per SM - - l2_cache_mb: Total L2 cache size in MB - - memory_gb: Total GPU memory in GB - - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") - - Examples: - >>> specs = get_gpu_specs() # Auto-detect - >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") - - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(f"SM Count: {specs['sm_count']}") - """ - # Auto-detect if not provided - if gpu_name is None: - gpu_name = query_gpu_name() - - # Return default if detection failed - if gpu_name is None: - print("⚠️ GPU auto-detection failed, using A100 specs as fallback") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - # Try exact match - if gpu_name in GPU_SPECS_DATABASE: - return GPU_SPECS_DATABASE[gpu_name].copy() - - # Try fuzzy match (contains or partial match) - gpu_name_lower = gpu_name.lower() - for key, specs in GPU_SPECS_DATABASE.items(): - key_lower = key.lower() - # Check if either name contains the other - if gpu_name_lower in key_lower or key_lower in gpu_name_lower: - print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") - return specs.copy() - - # Fallback to A100 specs with warning - print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") - print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - -def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: - """ - Format GPU specifications as a readable string for LLM prompts. - - Args: - specs: GPU specifications dictionary from get_gpu_specs() - - Returns: - Formatted string with GPU specifications - - Example: - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(format_gpu_specs_for_prompt(specs)) - GPU: NVIDIA A100 (Ampere) - - Peak Memory Bandwidth: 1555 GB/s - - Peak FP32 Performance: 19.5 TFLOPS - - Peak FP16 Performance: 312.0 TFLOPS - - SM Count: 108 - - L1 Cache: 192 KB per SM - - L2 Cache: 40 MB - """ - lines = [ - f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", - f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", - f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", - ] - - if specs.get("peak_fp16_tflops", 0) > 0: - lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") - - if specs.get("peak_bf16_tflops", 0) > 0: - lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") - - lines.extend( - [ - f"- SM Count: {specs['sm_count']}", - f"- Max Threads per SM: {specs['max_threads_per_sm']}", - f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", - f"- L2 Cache: {specs['l2_cache_mb']} MB", - f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", - ] - ) - - return "\n".join(lines) - - -if __name__ == "__main__": - print("GPU Specifications Module") - print("=" * 60) - - # Auto-detect GPU - detected_name = query_gpu_name() - if detected_name: - print(f"\nDetected GPU: {detected_name}") - else: - print("\nNo GPU detected (nvidia-smi not available)") - - # Get specs - specs = get_gpu_specs() - print(f"\n{format_gpu_specs_for_prompt(specs)}") - - # Show all available GPUs - print(f"\n{'=' * 60}") - print("Available GPU specifications in database:") - for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): - print(f" - {gpu_name}") diff --git a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py deleted file mode 100644 index 0b9168b..0000000 --- a/kernel_perf_util/kernel_opt/profiler/ncu_profiler.py +++ /dev/null @@ -1,388 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -NCU Profiling Module for Triton Kernels - -This module wraps three tasks: -1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). -2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). -3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). - -""" - -import json -import math -import os -import shutil -import subprocess -import sys -from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union - -import numpy as np -import pandas as pd - - -__all__ = [ - "METRICS", - "METRIC_COLUMNS", - "profile_triton_kernel", - "load_ncu_metrics", - "metrics_to_prompt", -] - -METRICS = ",".join( - [ - "sm__cycles_active.avg", - "sm__warps_active.avg.pct_of_peak_sustained_active", - "launch__occupancy_limit_blocks", - "launch__occupancy_limit_registers", - "launch__occupancy_limit_shared_mem", - "launch__registers_per_thread", - "sm__inst_executed.sum", - "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", - "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", - "dram__bytes_read.sum", - "dram__bytes_write.sum", - "dram__throughput.avg.pct_of_peak_sustained_elapsed", - "dram__bytes.sum.per_second", - "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", - "launch__shared_mem_per_block_allocated", - "l1tex__t_sector_hit_rate.pct", - "l1tex__throughput.avg.pct_of_peak_sustained_active", - "lts__t_sector_hit_rate.pct", - "lts__throughput.avg.pct_of_peak_sustained_active", - "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", - "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", - "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", - "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", - "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", - "smsp__warp_issue_stalled_barrier_per_warp_active.pct", - "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", - "smsp__sass_average_branch_targets_threads_uniform.pct", - ] -) - -# List version for convenient header selection -METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] - - -def profile_triton_kernel( - benchmark_script: Path, - workdir: Path, - out_csv: str = "ncu_output.csv", - python_executable: Optional[str] = None, - ncu_bin: Optional[str] = None, - launch_count: int = 20, - timeout: int = 120, -) -> Path: - """ - Profile a Triton kernel using NCU. - - Args: - benchmark_script: Path to benchmark script that calls the kernel - workdir: Working directory for execution - out_csv: Output CSV filename - python_executable: Python executable to use (default: sys.executable) - ncu_bin: Path to NCU binary (default: auto-detect) - launch_count: Number of kernel launches to profile - timeout: Timeout in seconds for NCU execution - - Returns: - Path to output CSV file - - Raises: - RuntimeError: If NCU profiling fails - FileNotFoundError: If NCU binary or output CSV not found - """ - # Resolve paths - if python_executable is None: - python_executable = sys.executable - - if ncu_bin is None: - ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" - - if not Path(ncu_bin).exists(): - raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") - - csv_path = (workdir / out_csv).resolve() - benchmark_script = benchmark_script.resolve() - - if not benchmark_script.exists(): - raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") - - # Preserve important environment variables - env = os.environ.copy() - - # Add Triton-specific environment variables - env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") - - preserve = ",".join( - [ - "PATH", - "LD_LIBRARY_PATH", - "CUDA_VISIBLE_DEVICES", - "PYTHONPATH", - "TRITON_CACHE_DIR", - "TORCH_EXTENSIONS_DIR", - "CONDA_PREFIX", - "CONDA_DEFAULT_ENV", - ] - ) - - # Build NCU command - cmd = [ - "sudo", - "-E", - f"--preserve-env={preserve}", - ncu_bin, - "--csv", - "--page=raw", - "--kernel-name-base=demangled", - "--target-processes=all", - "--replay-mode=kernel", - "--profile-from-start=on", - f"--log-file={str(csv_path)}", - f"--metrics={METRICS}", - "--launch-skip=0", - f"--launch-count={launch_count}", - python_executable, - str(benchmark_script), - ] - - print(f"[NCU] Running profiling...") - print(f"[NCU] Benchmark: {benchmark_script.name}") - print(f"[NCU] Output: {csv_path}") - print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") - - try: - result = subprocess.run( - cmd, - cwd=str(workdir), - env=env, - capture_output=True, - text=True, - timeout=timeout, - ) - - if result.returncode != 0: - error_msg = result.stderr or result.stdout - raise RuntimeError( - f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" - ) - - if not csv_path.exists(): - raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") - - # Check if CSV has content - csv_size = csv_path.stat().st_size - if csv_size < 100: - raise RuntimeError( - f"NCU CSV file is too small ({csv_size} bytes), likely empty" - ) - - print(f"[NCU] ✓ Profiling completed successfully") - print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") - return csv_path - - except subprocess.TimeoutExpired: - raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") - except Exception as e: - raise RuntimeError(f"NCU profiling failed: {e}") - - -def load_ncu_metrics( - csv_path: Union[str, Path], - columns: Optional[Sequence[str]] = None, - extra_keep: Optional[Sequence[str]] = ("Kernel Name",), - coerce_numeric: bool = True, - name_list: Optional[Sequence[str]] = None, - select: str = "last", -) -> pd.DataFrame: - """ - Load and parse NCU metrics from CSV file. - - Args: - csv_path: Path to NCU CSV output - columns: Specific metric columns to load (default: all METRIC_COLUMNS) - extra_keep: Additional columns to keep (e.g., "Kernel Name") - coerce_numeric: Convert metric values to numeric - name_list: Filter by kernel names (substring match) - select: Selection policy when multiple rows per name: - "first", "last", "max_cycles" - - Returns: - DataFrame with parsed metrics - - Raises: - FileNotFoundError: If CSV file not found - ValueError: If no requested columns found in CSV - """ - csv_path = Path(csv_path) - if not csv_path.exists(): - raise FileNotFoundError(f"CSV not found: {csv_path}") - - df = pd.read_csv(csv_path, comment="=", low_memory=False) - - metric_cols = list(columns) if columns is not None else METRIC_COLUMNS - keep_cols: List[str] = [] - if extra_keep: - keep_cols.extend([c for c in extra_keep if c in df.columns]) - keep_cols.extend([c for c in metric_cols if c in df.columns]) - - if not keep_cols: - raise ValueError("No requested columns found in the CSV header.") - - sub = df[keep_cols].copy() - - # Drop the units row (first row often contains units like "%", "inst", etc.) - if len(sub) > 0: - first_row_str = sub.iloc[0].astype(str).str.lower() - unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") - if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): - sub = sub.iloc[1:].reset_index(drop=True) - - # Coerce metrics to numeric - if coerce_numeric: - metric_in_sub = [c for c in metric_cols if c in sub.columns] - sub[metric_in_sub] = ( - sub[metric_in_sub] - .replace({",": "", "%": ""}, regex=True) - .apply(pd.to_numeric, errors="coerce") - ) - - # Filter by kernel name list if provided - if name_list: - results = [] - for name in name_list: - # Use contains match instead of exact equality (for Triton's long kernel names) - matched = sub[ - sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) - ] - if matched.empty: - continue - if len(matched) > 1: - if select == "first": - row = matched.iloc[[0]] - elif select == "last": - row = matched.iloc[[-1]] - elif ( - select == "max_cycles" - and "sm__cycles_active.avg" in matched.columns - ): - row = matched.sort_values( - "sm__cycles_active.avg", ascending=False - ).head(1) - else: - row = matched.iloc[[-1]] # fallback - else: - row = matched - results.append(row) - - if results: - sub = pd.concat(results, ignore_index=True) - else: - sub = pd.DataFrame(columns=keep_cols) - elif select in ("first", "last", "max_cycles"): - # Apply selection to all rows if no name filter - if len(sub) > 0: - if select == "first": - sub = sub.iloc[[0]] - elif select == "last": - sub = sub.iloc[[-1]] - elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: - sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) - - return sub - - -def metrics_to_prompt( - df: pd.DataFrame, - title: str = "GPU Profiling Metrics:", - key_by: str = "Kernel Name", - round_digits: Optional[int] = 3, - compact: bool = False, - keep_cols: Optional[List[str]] = None, -) -> str: - """ - Convert NCU metrics DataFrame to JSON string for LLM prompts. - - Returns JSON in format: - { - "": { "": , ... } - } - Args: - df: DataFrame with NCU metrics - title: Title for the metrics section (not included in output) - key_by: Column to use as key (usually "Kernel Name") - round_digits: Number of decimal places for rounding - compact: If True, use compact JSON (no indentation) - keep_cols: Specific columns to include in output - - Returns: - JSON string with metrics - """ - - def _safe(v: Any) -> Any: - """Convert values to JSON-safe format.""" - if v is None or (isinstance(v, float) and math.isnan(v)): - return None - if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): - return str(v) - if isinstance(v, np.generic): - v = v.item() - if isinstance(v, float) and math.isinf(v): - return "inf" if v > 0 else "-inf" - if isinstance(v, float) and round_digits is not None: - return round(v, round_digits) - return v - - # Empty table - if df is None or df.empty: - return "{}" - - cols = list(df.columns) - - # Round numeric columns - if round_digits is not None: - num_cols = df.select_dtypes(include="number").columns - if len(num_cols) > 0: - df = df.copy() - df[num_cols] = df[num_cols].round(round_digits) - - # If key column is missing, return a list of rows - if key_by not in cols: - rows = [ - {k: _safe(v) for k, v in rec.items()} - for rec in df.to_dict(orient="records") - ] - return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) - - # Determine value columns - value_cols = [c for c in cols if c != key_by] - if keep_cols is not None: - value_cols = [c for c in value_cols if c in keep_cols] - - data: Dict[str, Any] = {} - for rec in df[[key_by] + value_cols].to_dict(orient="records"): - k = str(rec.pop(key_by)) - val_obj = {ck: _safe(cv) for ck, cv in rec.items()} - if k in data: - # Multiple rows for same key - convert to list - if isinstance(data[k], list): - data[k].append(val_obj) - else: - data[k] = [data[k], val_obj] - else: - data[k] = val_obj - - return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) - - -if __name__ == "__main__": - print("ncu_profiler module loaded.") - print("Import its functions in your scripts:") - print( - " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" - ) diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 61f41b3..693fa8a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -10,7 +10,7 @@ import pandas as pd -from kernel_perf_util.kernel_opt.profiler.ncu_profiler import ( +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( load_ncu_metrics, metrics_to_prompt, profile_triton_kernel, From d87167857f31c5ee87a0a121a36312b18f82d9ba Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:37:08 -0800 Subject: [PATCH 07/12] move gpu spec.py to future PR and fix import --- .../kernel_opt/profiler/gpu_specs.py | 227 ------------------ .../profiling/kernel_profiler.py | 4 +- 2 files changed, 3 insertions(+), 228 deletions(-) delete mode 100644 kernel_perf_agent/kernel_opt/profiler/gpu_specs.py diff --git a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py b/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py deleted file mode 100644 index ed8b50f..0000000 --- a/kernel_perf_agent/kernel_opt/profiler/gpu_specs.py +++ /dev/null @@ -1,227 +0,0 @@ -#!/usr/bin/env python3 -""" -GPU Specifications Database for Bottleneck Analysis - -This module provides GPU hardware specifications needed for performance analysis -and bottleneck identification. It includes peak compute performance, memory bandwidth, -cache sizes, and SM counts for common NVIDIA GPUs. - -""" - -import re -import subprocess -from typing import Any, Dict, Optional - - -# GPU specifications database -# Sources: NVIDIA official specifications, manufacturer datasheets -GPU_SPECS_DATABASE = { - "NVIDIA A100": { - "name": "NVIDIA A100", - "architecture": "Ampere", - "peak_fp32_tflops": 19.5, - "peak_fp16_tflops": 312.0, - "peak_bf16_tflops": 312.0, - "peak_memory_bw_gbps": 1555, - "sm_count": 108, - "max_threads_per_sm": 2048, - "l1_cache_kb": 192, - "l2_cache_mb": 40, - "memory_gb": 40, - "memory_type": "HBM2e", - }, - "NVIDIA H100": { - "name": "NVIDIA H100", - "architecture": "Hopper", - "peak_fp32_tflops": 51.0, - "peak_fp16_tflops": 989.0, - "peak_bf16_tflops": 989.0, - "peak_memory_bw_gbps": 3352, - "sm_count": 132, - "max_threads_per_sm": 2048, - "l1_cache_kb": 256, - "l2_cache_mb": 50, - "memory_gb": 80, - "memory_type": "HBM3", - }, - "NVIDIA RTX 4090": { - "name": "NVIDIA RTX 4090", - "architecture": "Ada Lovelace", - "peak_fp32_tflops": 82.6, - "peak_fp16_tflops": 165.0, - "peak_bf16_tflops": 165.0, - "peak_memory_bw_gbps": 1008, - "sm_count": 128, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 72, - "memory_gb": 24, - "memory_type": "GDDR6X", - }, - "NVIDIA RTX 5080": { - "name": "NVIDIA RTX 5080", - "architecture": "Blackwell", - "peak_fp32_tflops": 57.0, - "peak_fp16_tflops": 114.0, - "peak_bf16_tflops": 114.0, - "peak_memory_bw_gbps": 960, - "sm_count": 84, - "max_threads_per_sm": 1536, - "l1_cache_kb": 128, - "l2_cache_mb": 64, - "memory_gb": 16, - "memory_type": "GDDR7", - }, -} - - -def query_gpu_name() -> Optional[str]: - """ - Query GPU name using nvidia-smi. - - Returns: - GPU name string, or None if query fails - """ - try: - result = subprocess.run( - ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], - capture_output=True, - text=True, - timeout=5, - ) - if result.returncode == 0: - gpu_name = result.stdout.strip() - # Normalize whitespace - gpu_name = re.sub(r"\s+", " ", gpu_name) - return gpu_name - except (subprocess.TimeoutExpired, FileNotFoundError, Exception): - pass - return None - - -def get_gpu_specs(gpu_name: Optional[str] = None) -> Dict[str, Any]: - """ - Get GPU specifications for bottleneck analysis. - - This function returns hardware specifications needed for performance analysis, - including peak compute performance, memory bandwidth, cache sizes, and SM counts. - - Args: - gpu_name: GPU name (if None, auto-detect with nvidia-smi) - - Returns: - Dictionary with GPU specifications containing: - - name: GPU name - - architecture: GPU architecture (e.g., "Ampere", "Hopper") - - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS - - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS - - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) - - peak_memory_bw_gbps: Peak memory bandwidth in GB/s - - sm_count: Number of streaming multiprocessors - - max_threads_per_sm: Maximum threads per SM - - l1_cache_kb: L1 cache size in KB per SM - - l2_cache_mb: Total L2 cache size in MB - - memory_gb: Total GPU memory in GB - - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") - - Examples: - >>> specs = get_gpu_specs() # Auto-detect - >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") - - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(f"SM Count: {specs['sm_count']}") - """ - # Auto-detect if not provided - if gpu_name is None: - gpu_name = query_gpu_name() - - # Return default if detection failed - if gpu_name is None: - print("⚠️ GPU auto-detection failed, using A100 specs as fallback") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - # Try exact match - if gpu_name in GPU_SPECS_DATABASE: - return GPU_SPECS_DATABASE[gpu_name].copy() - - # Try fuzzy match (contains or partial match) - gpu_name_lower = gpu_name.lower() - for key, specs in GPU_SPECS_DATABASE.items(): - key_lower = key.lower() - # Check if either name contains the other - if gpu_name_lower in key_lower or key_lower in gpu_name_lower: - print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") - return specs.copy() - - # Fallback to A100 specs with warning - print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") - print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") - return GPU_SPECS_DATABASE["NVIDIA A100"].copy() - - -def format_gpu_specs_for_prompt(specs: Dict[str, Any]) -> str: - """ - Format GPU specifications as a readable string for LLM prompts. - - Args: - specs: GPU specifications dictionary from get_gpu_specs() - - Returns: - Formatted string with GPU specifications - - Example: - >>> specs = get_gpu_specs("NVIDIA A100") - >>> print(format_gpu_specs_for_prompt(specs)) - GPU: NVIDIA A100 (Ampere) - - Peak Memory Bandwidth: 1555 GB/s - - Peak FP32 Performance: 19.5 TFLOPS - - Peak FP16 Performance: 312.0 TFLOPS - - SM Count: 108 - - L1 Cache: 192 KB per SM - - L2 Cache: 40 MB - """ - lines = [ - f"GPU: {specs['name']} ({specs.get('architecture', 'Unknown')})", - f"- Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s", - f"- Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS", - ] - - if specs.get("peak_fp16_tflops", 0) > 0: - lines.append(f"- Peak FP16 Performance: {specs['peak_fp16_tflops']} TFLOPS") - - if specs.get("peak_bf16_tflops", 0) > 0: - lines.append(f"- Peak BF16 Performance: {specs['peak_bf16_tflops']} TFLOPS") - - lines.extend( - [ - f"- SM Count: {specs['sm_count']}", - f"- Max Threads per SM: {specs['max_threads_per_sm']}", - f"- L1 Cache: {specs['l1_cache_kb']} KB per SM", - f"- L2 Cache: {specs['l2_cache_mb']} MB", - f"- Memory: {specs['memory_gb']} GB {specs.get('memory_type', '')}", - ] - ) - - return "\n".join(lines) - - -if __name__ == "__main__": - print("GPU Specifications Module") - print("=" * 60) - - # Auto-detect GPU - detected_name = query_gpu_name() - if detected_name: - print(f"\nDetected GPU: {detected_name}") - else: - print("\nNo GPU detected (nvidia-smi not available)") - - # Get specs - specs = get_gpu_specs() - print(f"\n{format_gpu_specs_for_prompt(specs)}") - - # Show all available GPUs - print(f"\n{'=' * 60}") - print("Available GPU specifications in database:") - for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): - print(f" - {gpu_name}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 693fa8a..27825ea 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -16,7 +16,9 @@ profile_triton_kernel, ) -from .ncu_wrapper_generator import NCUWrapperGenerator +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_generator import ( + NCUWrapperGenerator, +) class KernelProfiler: From db0c75474f8b964610de73b15c4a40a578125df6 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:44:11 -0800 Subject: [PATCH 08/12] Add copyright header --- kernel_perf_agent/README.md | 7 +------ .../kernel_opt/profiler/__init__.py | 18 ++++++++++++++++++ .../kernel_opt/profiler/ncu_profiler.py | 15 +++++++++++++-- .../opt_worker_component/profiling/__init__.py | 16 +++++++++++++++- .../profiling/kernel_profiler.py | 14 ++++++++++++++ .../profiling/ncu_wrapper_generator.py | 15 +++++++++++++++ 6 files changed, 76 insertions(+), 9 deletions(-) diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md index 5114402..e27dfec 100644 --- a/kernel_perf_agent/README.md +++ b/kernel_perf_agent/README.md @@ -1,8 +1,3 @@ # Kernel Performance Utilities -Low-level, reusable utilities for kernel optimization: -- NCU profiling -- RAG-based optimization database -- Performance analysis tools - -These are **library components** that can be used by any worker. +Low-level, reusable utilities for kernel optimization. diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py index e69de29..d177194 100644 --- a/kernel_perf_agent/kernel_opt/profiler/__init__.py +++ b/kernel_perf_agent/kernel_opt/profiler/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""kernel_perf_agent package.""" + +# Kernel Perf Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 0b9168b..2a09353 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -1,5 +1,16 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ NCU Profiling Module for Triton Kernels diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py index cb28cc6..218ca6a 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -1,6 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Profiling infrastructure for NCU-based kernel analysis.""" -from .ncu_wrapper_generator import NCUWrapperGenerator from .kernel_profiler import KernelProfiler +from .ncu_wrapper_generator import NCUWrapperGenerator __all__ = ["NCUWrapperGenerator", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index 27825ea..e1de43e 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Kernel profiling with NCU.""" import json diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py index 9f0e2e5..64fd1c1 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """NCU wrapper script generation for kernel profiling.""" import logging @@ -6,6 +20,7 @@ try: from jinja2 import Template + HAS_JINJA2 = True except ImportError: HAS_JINJA2 = False From cd297596bbdf9d4a99f5e610abcec89ae48a979c Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Wed, 7 Jan 2026 14:52:31 -0800 Subject: [PATCH 09/12] fix ruff --- kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 2a09353..5362073 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -162,7 +162,7 @@ def profile_triton_kernel( str(benchmark_script), ] - print(f"[NCU] Running profiling...") + print("[NCU] Running profiling...") print(f"[NCU] Benchmark: {benchmark_script.name}") print(f"[NCU] Output: {csv_path}") print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") @@ -193,7 +193,7 @@ def profile_triton_kernel( f"NCU CSV file is too small ({csv_size} bytes), likely empty" ) - print(f"[NCU] ✓ Profiling completed successfully") + print("[NCU] ✓ Profiling completed successfully") print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") return csv_path From bbfa6cd02ab9e496b7330d745b47621ed62ed92a Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Mon, 12 Jan 2026 16:58:32 -0800 Subject: [PATCH 10/12] address previous comments --- .../kernel_opt/profiler/ncu_profiler.py | 152 +++++++---- .../profiling/__init__.py | 4 +- .../profiling/kernel_profiler.py | 241 +++++++++++------- ...er_generator.py => ncu_wrapper_factory.py} | 57 +---- 4 files changed, 265 insertions(+), 189 deletions(-) rename triton_kernel_agent/opt_worker_component/profiling/{ncu_wrapper_generator.py => ncu_wrapper_factory.py} (65%) diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 5362073..710d5c6 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -16,10 +16,11 @@ NCU Profiling Module for Triton Kernels This module wraps three tasks: -1) Collect core metrics for Triton CUDA kernels with Nsight Compute into CSV (`profile_triton_kernel`). +1) Collect core metrics for Triton kernels with Nsight Compute into CSV (`profile_triton_kernel`). 2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). 3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + """ import json @@ -28,20 +29,14 @@ import shutil import subprocess import sys +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import pandas as pd - -__all__ = [ - "METRICS", - "METRIC_COLUMNS", - "profile_triton_kernel", - "load_ncu_metrics", - "metrics_to_prompt", -] +# This selection of metrics is inspried by the CudaForge team (https://github.com/OptimAI-Lab/CudaForge/tree/main) METRICS = ",".join( [ @@ -79,6 +74,14 @@ METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] +class MetricSelectionPolicy(Enum): + """Policy for selecting rows when multiple rows exist for a kernel.""" + + FIRST = "first" + LAST = "last" + MAX_CYCLES = "max_cycles" + + def profile_triton_kernel( benchmark_script: Path, workdir: Path, @@ -203,13 +206,83 @@ def profile_triton_kernel( raise RuntimeError(f"NCU profiling failed: {e}") +def _apply_selection_policy( + df: pd.DataFrame, + policy: MetricSelectionPolicy, +) -> pd.DataFrame: + """ + Apply selection policy to choose a single row from a DataFrame. + + Args: + df: DataFrame with one or more rows + policy: Selection policy to apply + + Returns: + DataFrame with a single row based on the policy + """ + if df.empty: + return df + + if len(df) == 1: + return df + + if policy == MetricSelectionPolicy.FIRST: + return df.iloc[[0]] + elif policy == MetricSelectionPolicy.LAST: + return df.iloc[[-1]] + elif policy == MetricSelectionPolicy.MAX_CYCLES: + if "sm__cycles_active.avg" in df.columns: + return df.sort_values("sm__cycles_active.avg", ascending=False).head(1) + # Fallback to last if cycles column not available + return df.iloc[[-1]] + else: + # Fallback to last for unknown policies + return df.iloc[[-1]] + + +def _filter_by_kernel_names( + df: pd.DataFrame, + name_list: Sequence[str], + policy: MetricSelectionPolicy, + keep_cols: List[str], +) -> pd.DataFrame: + """ + Filter DataFrame by kernel names with substring matching. + + Args: + df: DataFrame with NCU metrics + name_list: List of kernel name substrings to match + policy: Selection policy when multiple rows match + keep_cols: Columns to preserve in empty result + + Returns: + Filtered DataFrame with one row per matched kernel name + """ + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = df[ + df["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + + row = _apply_selection_policy(matched, policy) + results.append(row) + + if results: + return pd.concat(results, ignore_index=True) + else: + return pd.DataFrame(columns=keep_cols) + + def load_ncu_metrics( csv_path: Union[str, Path], columns: Optional[Sequence[str]] = None, extra_keep: Optional[Sequence[str]] = ("Kernel Name",), coerce_numeric: bool = True, name_list: Optional[Sequence[str]] = None, - select: str = "last", + select: Union[str, MetricSelectionPolicy] = MetricSelectionPolicy.LAST, ) -> pd.DataFrame: """ Load and parse NCU metrics from CSV file. @@ -220,20 +293,32 @@ def load_ncu_metrics( extra_keep: Additional columns to keep (e.g., "Kernel Name") coerce_numeric: Convert metric values to numeric name_list: Filter by kernel names (substring match) - select: Selection policy when multiple rows per name: - "first", "last", "max_cycles" + select: Selection policy when multiple rows per name. + Can be MetricSelectionPolicy enum or string ("first", "last", "max_cycles") Returns: DataFrame with parsed metrics Raises: FileNotFoundError: If CSV file not found - ValueError: If no requested columns found in CSV + ValueError: If no requested columns found in CSV or invalid select value """ csv_path = Path(csv_path) if not csv_path.exists(): raise FileNotFoundError(f"CSV not found: {csv_path}") + # Convert string to enum if needed + if isinstance(select, str): + try: + policy = MetricSelectionPolicy(select) + except ValueError: + raise ValueError( + f"Invalid select value: {select}. " + f"Must be one of: {[p.value for p in MetricSelectionPolicy]}" + ) + else: + policy = select + df = pd.read_csv(csv_path, comment="=", low_memory=False) metric_cols = list(columns) if columns is not None else METRIC_COLUMNS @@ -265,45 +350,10 @@ def load_ncu_metrics( # Filter by kernel name list if provided if name_list: - results = [] - for name in name_list: - # Use contains match instead of exact equality (for Triton's long kernel names) - matched = sub[ - sub["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) - ] - if matched.empty: - continue - if len(matched) > 1: - if select == "first": - row = matched.iloc[[0]] - elif select == "last": - row = matched.iloc[[-1]] - elif ( - select == "max_cycles" - and "sm__cycles_active.avg" in matched.columns - ): - row = matched.sort_values( - "sm__cycles_active.avg", ascending=False - ).head(1) - else: - row = matched.iloc[[-1]] # fallback - else: - row = matched - results.append(row) - - if results: - sub = pd.concat(results, ignore_index=True) - else: - sub = pd.DataFrame(columns=keep_cols) - elif select in ("first", "last", "max_cycles"): + sub = _filter_by_kernel_names(sub, name_list, policy, keep_cols) + else: # Apply selection to all rows if no name filter - if len(sub) > 0: - if select == "first": - sub = sub.iloc[[0]] - elif select == "last": - sub = sub.iloc[[-1]] - elif select == "max_cycles" and "sm__cycles_active.avg" in sub.columns: - sub = sub.sort_values("sm__cycles_active.avg", ascending=False).head(1) + sub = _apply_selection_policy(sub, policy) return sub diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py index 218ca6a..8f8d18b 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/__init__.py +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -15,6 +15,6 @@ """Profiling infrastructure for NCU-based kernel analysis.""" from .kernel_profiler import KernelProfiler -from .ncu_wrapper_generator import NCUWrapperGenerator +from .ncu_wrapper_factory import NCUWrapperFactory -__all__ = ["NCUWrapperGenerator", "KernelProfiler"] +__all__ = ["NCUWrapperFactory", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index e1de43e..bcc5a70 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Kernel profiling with NCU.""" +"""Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" import json import logging import subprocess import time +from dataclasses import asdict, dataclass from datetime import datetime +from functools import cached_property from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict import pandas as pd @@ -30,42 +32,84 @@ profile_triton_kernel, ) -from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_generator import ( - NCUWrapperGenerator, +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_factory import ( + NCUWrapperFactory, ) +# Default timeout for NCU profiling in seconds +DEFAULT_NCU_TIMEOUT_SECONDS = 120 + + +@dataclass +class ProfilerMetadata: + """Metadata about a profiling run.""" + + kernel_file: str + problem_file: str + round_num: int + timestamp: str + ncu_version: str | None + + +@dataclass +class ProfilerResults: + """ + Results from a kernel profiling run. + + This dataclass encapsulates both the metrics DataFrame and the parsed + metrics dictionary, along with metadata about the profiling run. + """ + + metrics_df: pd.DataFrame + metrics: Dict[str, Any] + metadata: ProfilerMetadata + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary for JSON serialization.""" + return { + "metrics": self.metrics, + "metadata": asdict(self.metadata), + } + + def to_json(self) -> str: + """Convert results to a JSON string.""" + return json.dumps(self.to_dict(), indent=2) + class KernelProfiler: - """Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" def __init__( self, logger: logging.Logger, - temp_dir: Path, + artifacts_dir: Path, logs_dir: Path, - ncu_bin_path: Optional[str] = None, + ncu_bin_path: str | None = None, + ncu_timeout_seconds: int = DEFAULT_NCU_TIMEOUT_SECONDS, ): """ Initialize the kernel profiler. Args: logger: Logger instance - temp_dir: Temporary directory for profiling artifacts + artifacts_dir: Directory for optimization artifacts logs_dir: Directory for saving profiling logs ncu_bin_path: Path to NCU binary (auto-detect if None) + ncu_timeout_seconds: Timeout for NCU profiling in seconds """ self.logger = logger - self.temp_dir = temp_dir + self.artifacts_dir = artifacts_dir self.logs_dir = logs_dir self.ncu_bin_path = ncu_bin_path - self.wrapper_generator = NCUWrapperGenerator(logger) + self.ncu_timeout_seconds = ncu_timeout_seconds + self.wrapper_factory = NCUWrapperFactory(logger) - def _get_ncu_version(self) -> Optional[str]: + @cached_property + def ncu_version(self) -> str | None: """ - Get NCU version string. + NCU version string (cached). Returns: - NCU version string or None if failed + Version string like "2025.2.1.0" or None if unavailable """ try: ncu_cmd = self.ncu_bin_path or "ncu" @@ -76,21 +120,35 @@ def _get_ncu_version(self) -> Optional[str]: timeout=5, ) if result.returncode == 0: - # Extract version from output (e.g., "NVIDIA Nsight Compute 2024.3.1") - version_line = result.stdout.strip().split("\n")[0] - return version_line + # Extract version from output + # Example: "Version 2025.2.1.0 (build 35987062) (public-release)" + for line in result.stdout.strip().split("\n"): + if line.startswith("Version "): + return line.split()[1] + return None return None except Exception as e: self.logger.debug(f"Failed to get NCU version: {e}") return None + def _wait_with_backoff(self, attempt: int) -> None: + """ + Wait with exponential backoff before retrying. + + Args: + attempt: Current attempt number (1-indexed) + """ + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + def profile_kernel( self, kernel_file: Path, problem_file: Path, round_num: int, - max_retries: int = 3, - ) -> Tuple[Optional[pd.DataFrame], Optional[Dict[str, Any]]]: + max_retries: int = 2, + ) -> ProfilerResults | None: """ Profile kernel with NCU (with retry logic). @@ -100,30 +158,28 @@ def profile_kernel( Args: kernel_file: Path to kernel file problem_file: Path to problem file - round_num: Current optimization round number - max_retries: Maximum number of retry attempts (default: 3) + round_num: Current optimization round number (used for file naming + and tracking which optimization iteration this profiling belongs to) + max_retries: Maximum number of retry attempts (default: 2) Returns: - Tuple of (metrics_df, metrics_json) or (None, None) on failure + ProfilerResults containing metrics and metadata, or None on failure Example: - >>> profiler = KernelProfiler(logger, temp_dir, logs_dir) - >>> metrics_df, metrics_json = profiler.profile_kernel( + >>> profiler = KernelProfiler(logger, artifacts_dir, logs_dir) + >>> results = profiler.profile_kernel( ... Path("kernel.py"), Path("problem.py"), round_num=1 ... ) - >>> if metrics_json: - ... print(f"DRAM throughput: {metrics_json['dram__throughput']}") + >>> if results: + ... print(f"DRAM throughput: {results.metrics['dram__throughput']}") """ - wrapper_file = None + # Create NCU wrapper script + wrapper_file = self.wrapper_factory.create_ncu_wrapper( + kernel_file, problem_file, self.artifacts_dir + ) for attempt in range(1, max_retries + 1): try: - # Create NCU wrapper script (cached if unchanged) - if wrapper_file is None: - wrapper_file = self.wrapper_generator.create_ncu_wrapper( - kernel_file, problem_file, self.temp_dir - ) - self.logger.info( f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." ) @@ -132,98 +188,103 @@ def profile_kernel( csv_file = f"ncu_round_{round_num}.csv" csv_path = profile_triton_kernel( benchmark_script=wrapper_file, - workdir=self.temp_dir, + workdir=self.artifacts_dir, out_csv=csv_file, ncu_bin=self.ncu_bin_path, launch_count=20, - timeout=120, + timeout=self.ncu_timeout_seconds, ) # Load and parse metrics metrics_df = load_ncu_metrics(csv_path, select="last") - metrics_json = json.loads(metrics_to_prompt(metrics_df)) + metrics = json.loads(metrics_to_prompt(metrics_df)) + + # Build ProfilerResults + results = ProfilerResults( + metrics_df=metrics_df, + metrics=metrics, + metadata=ProfilerMetadata( + kernel_file=str(kernel_file), + problem_file=str(problem_file), + round_num=round_num, + timestamp=datetime.utcnow().isoformat() + "Z", + ncu_version=self.ncu_version, + ), + ) # Save metrics with metadata - self._save_metrics_with_metadata( - metrics_json, kernel_file, problem_file, round_num - ) + self._save_profiler_results(results) self.logger.info(f"✅ NCU profiling completed for round {round_num}") - return metrics_df, metrics_json + return results except FileNotFoundError as e: self.logger.error(f"❌ File not found during profiling: {e}") - return None, None + return None except subprocess.TimeoutExpired: - self.logger.error( - f"❌ NCU profiling timed out after 120s (attempt {attempt}/{max_retries})" - ) - if attempt < max_retries: - wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(final attempt {attempt}/{max_retries})" + ) + return None + else: + self.logger.debug( + f"NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(attempt {attempt}/{max_retries})" + ) + self._wait_with_backoff(attempt) continue - return None, None except json.JSONDecodeError as e: - self.logger.error(f"❌ Failed to parse NCU metrics: {e}") - if attempt < max_retries: - wait_time = 2**attempt - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Failed to parse NCU metrics (final attempt): {e}" + ) + return None + else: + self.logger.debug( + f"Failed to parse NCU metrics (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) continue - return None, None except Exception as e: - self.logger.error( - f"❌ Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}", - exc_info=True, - ) - if attempt < max_retries: - wait_time = 2**attempt - self.logger.warning(f"Retrying in {wait_time}s...") - time.sleep(wait_time) + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Unexpected error during profiling (final attempt): {e}", + exc_info=True, + ) + return None + else: + self.logger.debug( + f"Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) continue - return None, None # All retries exhausted self.logger.error( f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" ) - return None, None + return None - def _save_metrics_with_metadata( - self, - metrics_json: Dict[str, Any], - kernel_file: Path, - problem_file: Path, - round_num: int, - ) -> None: + def _save_profiler_results(self, results: ProfilerResults) -> None: """ - Save profiling metrics with additional metadata. + Save profiling results with metadata to a JSON file. Args: - metrics_json: NCU metrics as JSON dict - kernel_file: Path to kernel file - problem_file: Path to problem file - round_num: Current optimization round number + results: ProfilerResults to save """ - metrics_file = self.logs_dir / f"round{round_num:03d}_ncu_metrics.json" - - # Build metadata - metadata = { - "metrics": metrics_json, - "metadata": { - "kernel_file": str(kernel_file), - "problem_file": str(problem_file), - "round_num": round_num, - "timestamp": datetime.utcnow().isoformat() + "Z", - "ncu_version": self._get_ncu_version(), - }, - } + metrics_file = ( + self.logs_dir / f"round{results.metadata.round_num:03d}_ncu_metrics.json" + ) with open(metrics_file, "w") as f: - json.dump(metadata, f, indent=2) + f.write(results.to_json()) self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py similarity index 65% rename from triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py rename to triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py index 64fd1c1..e3c5a3c 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_generator.py +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py @@ -15,58 +15,38 @@ """NCU wrapper script generation for kernel profiling.""" import logging +from functools import cached_property from pathlib import Path -from typing import Optional -try: - from jinja2 import Template +from jinja2 import Template - HAS_JINJA2 = True -except ImportError: - HAS_JINJA2 = False - -class NCUWrapperGenerator: - """Generates NCU wrapper scripts for profiling Triton kernels.""" +class NCUWrapperFactory: + """Factory for creating NCU wrapper scripts for profiling Triton kernels.""" # Template file path (relative to this file) WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" def __init__(self, logger: logging.Logger): """ - Initialize the NCU wrapper generator. + Initialize the NCU wrapper factory. Args: logger: Logger instance """ self.logger = logger - self._template_cache: Optional[Template] = None - def _load_template(self) -> Template: + @cached_property + def template(self) -> Template: """ - Load the Jinja2 template (cached). - - Returns: - Jinja2 Template object + Jinja2 template for wrapper script generation. Raises: - ImportError: If Jinja2 is not installed FileNotFoundError: If template file doesn't exist """ - if self._template_cache is not None: - return self._template_cache - - if not HAS_JINJA2: - raise ImportError( - "Jinja2 is required for wrapper generation. " - "Install it with: pip install jinja2" - ) - if not self.WRAPPER_TEMPLATE.exists(): raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") - - self._template_cache = Template(self.WRAPPER_TEMPLATE.read_text()) - return self._template_cache + return Template(self.WRAPPER_TEMPLATE.read_text()) def create_ncu_wrapper( self, @@ -75,7 +55,6 @@ def create_ncu_wrapper( output_dir: Path, dtype_inference: bool = True, model_extraction: bool = True, - use_cache: bool = True, ) -> Path: """ Create NCU wrapper script for profiling. @@ -91,7 +70,6 @@ def create_ncu_wrapper( output_dir: Directory to write wrapper script dtype_inference: Enable automatic dtype inference from kernel source (default: True) model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) - use_cache: Reuse existing wrapper if files haven't changed (default: True) Returns: Path to created wrapper script @@ -112,21 +90,8 @@ def create_ncu_wrapper( wrapper_file = output_dir / "ncu_wrapper.py" - # Check cache: reuse wrapper if it's fresh - if use_cache and wrapper_file.exists(): - wrapper_mtime = wrapper_file.stat().st_mtime - kernel_mtime = kernel_file.stat().st_mtime - problem_mtime = problem_file.stat().st_mtime - - if wrapper_mtime > kernel_mtime and wrapper_mtime > problem_mtime: - self.logger.info( - f"Reusing cached NCU wrapper (fresher than source files): {wrapper_file}" - ) - return wrapper_file - - # Load template and render - template = self._load_template() - wrapper_content = template.render( + # Render template + wrapper_content = self.template.render( kernel_file_parent=repr(str(kernel_file.parent)), problem_file_parent=repr(str(problem_file.parent)), kernel_module=kernel_file.stem, From 543453a6803ce56a276b4a78540f0e0884b09e45 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Mon, 12 Jan 2026 17:06:54 -0800 Subject: [PATCH 11/12] fix ruff --- .../opt_worker_component/profiling/kernel_profiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py index bcc5a70..682c4e7 100644 --- a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -77,7 +77,6 @@ def to_json(self) -> str: class KernelProfiler: - def __init__( self, logger: logging.Logger, From 16752e5216a097146457d2f017749c5f57622ab9 Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 20 Jan 2026 20:49:05 -0800 Subject: [PATCH 12/12] address reviews to init kernel_perf_agent and gracefully handle sudo --- kernel_perf_agent/__init__.py | 18 +++++ .../kernel_opt/profiler/ncu_profiler.py | 71 ++++++++++++++----- 2 files changed, 71 insertions(+), 18 deletions(-) create mode 100644 kernel_perf_agent/__init__.py diff --git a/kernel_perf_agent/__init__.py b/kernel_perf_agent/__init__.py new file mode 100644 index 0000000..1f49766 --- /dev/null +++ b/kernel_perf_agent/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Kernel Performance Agent package.""" + +# "Kernel Performance Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py index 710d5c6..4ce8568 100644 --- a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -90,6 +90,7 @@ def profile_triton_kernel( ncu_bin: Optional[str] = None, launch_count: int = 20, timeout: int = 120, + use_sudo: bool = False, ) -> Path: """ Profile a Triton kernel using NCU. @@ -102,6 +103,8 @@ def profile_triton_kernel( ncu_bin: Path to NCU binary (default: auto-detect) launch_count: Number of kernel launches to profile timeout: Timeout in seconds for NCU execution + use_sudo: Whether to run NCU with sudo. Can also be enabled via + KERNELAGENT_NCU_USE_SUDO=1 environment variable. Default: False Returns: Path to output CSV file @@ -110,6 +113,9 @@ def profile_triton_kernel( RuntimeError: If NCU profiling fails FileNotFoundError: If NCU binary or output CSV not found """ + # Check for environment variable override + use_sudo = use_sudo or os.environ.get("KERNELAGENT_NCU_USE_SUDO", "0") == "1" + # Resolve paths if python_executable is None: python_executable = sys.executable @@ -146,28 +152,32 @@ def profile_triton_kernel( ) # Build NCU command - cmd = [ - "sudo", - "-E", - f"--preserve-env={preserve}", - ncu_bin, - "--csv", - "--page=raw", - "--kernel-name-base=demangled", - "--target-processes=all", - "--replay-mode=kernel", - "--profile-from-start=on", - f"--log-file={str(csv_path)}", - f"--metrics={METRICS}", - "--launch-skip=0", - f"--launch-count={launch_count}", - python_executable, - str(benchmark_script), - ] + cmd = [] + if use_sudo: + cmd.extend(["sudo", "-E", f"--preserve-env={preserve}"]) + + cmd.extend( + [ + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + ) print("[NCU] Running profiling...") print(f"[NCU] Benchmark: {benchmark_script.name}") print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Using sudo: {use_sudo}") print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") try: @@ -182,6 +192,29 @@ def profile_triton_kernel( if result.returncode != 0: error_msg = result.stderr or result.stdout + + # Check for common permission errors + permission_errors = [ + "permission", + "ERR_NVGPUCTRPERM", + "profiling permissions", + "requires root", + "access denied", + ] + is_permission_error = any( + err.lower() in error_msg.lower() for err in permission_errors + ) + + if is_permission_error and not use_sudo: + raise RuntimeError( + f"NCU requires elevated permissions on this system.\n" + f"Options to resolve:\n" + f" 1. Rerun with use_sudo=True\n" + f" 2. Set environment variable: KERNELAGENT_NCU_USE_SUDO=1\n" + f" 3. Configure driver permissions: " + f"Original error:\n{error_msg[:500]}" + ) + raise RuntimeError( f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" ) @@ -202,6 +235,8 @@ def profile_triton_kernel( except subprocess.TimeoutExpired: raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except RuntimeError: + raise except Exception as e: raise RuntimeError(f"NCU profiling failed: {e}")