diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md new file mode 100644 index 0000000..e27dfec --- /dev/null +++ b/kernel_perf_agent/README.md @@ -0,0 +1,3 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization. diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py new file mode 100644 index 0000000..34e8023 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diagnose Prompt Module for Hardware Bottleneck Analysis. + +This module provides prompt building utilities for the Judge LLM that +analyzes NCU profiling metrics to identify performance bottlenecks. +""" + +from .gpu_specs import get_gpu_specs +from .judger_prompts import ( + build_judge_optimization_prompt, + extract_judge_response, + validate_judge_response, +) + +__all__ = [ + "get_gpu_specs", + "build_judge_optimization_prompt", + "extract_judge_response", + "validate_judge_response", +] diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py new file mode 100644 index 0000000..e465116 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import subprocess +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import ( + GPU_SPECS_DATABASE, +) + +__all__ = ["GPU_SPECS_DATABASE", "query_gpu_name", "get_gpu_specs"] + + +def query_gpu_name() -> str | None: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Take only the first GPU (nvidia-smi returns one line per GPU) + gpu_name = result.stdout.strip().split("\n")[0].strip() + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: str | None = None) -> dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + exit() + + # Get specs + specs = get_gpu_specs() + print( + f"\nUsing specs for: {specs['name']} ({specs.get('architecture', 'Unknown')})" + ) + print(f" - Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s") + print(f" - Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS") + print(f" - SM Count: {specs['sm_count']}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py new file mode 100644 index 0000000..d5e4586 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database + +This module contains the GPU hardware specifications database used for +performance analysis and bottleneck identification. Separated into its +own file to allow easier module overriding. + +Sources: NVIDIA official specifications, manufacturer datasheets +""" + +GPU_SPECS_DATABASE: dict[str, dict[str, object]] = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py new file mode 100644 index 0000000..e07749e --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt Builder for Hardware Bottleneck Diagnosis + +This module provides prompt templates and builder functions for the Judge LLM +that analyzes NCU profiling metrics to identify performance bottlenecks and +provide specific optimization recommendations. + +The Judge uses a dual-bottleneck framework based on NCU hardware profiling: +- bottleneck_1 (Primary): Highest-impact performance issue +- bottleneck_2 (Secondary): Different category issue that also limits performance + +Both bottlenecks are selected from NCU hardware profiling categories: +- memory-bound +- compute-bound +- occupancy-limited +- latency-bound + +Metric definitions are in metric_schema.py. +""" + +from typing import Any, Callable + +from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS + + +# ============================================================================= +# Section Renderers +# ============================================================================= + + +def render_problem_description(problem_description: str) -> list[str]: + """Render the problem description section.""" + return ["## Problem Description", "", problem_description] + + +def render_kernel_code(kernel_code: str, language: str = "python") -> list[str]: + """Render the kernel code section with syntax highlighting.""" + return ["", "## Current Kernel Code", "", f"```{language}", kernel_code, "```"] + + +def render_gpu_specs(gpu_specs: dict[str, Any]) -> list[str]: + """Render the GPU hardware specifications section.""" + lines = ["", "## GPU Hardware Specifications", ""] + + for label, key, unit in GPU_SPEC_FIELDS: + value = gpu_specs.get(key, "N/A") + if value == "N/A": + lines.append(f"- **{label}:** N/A") + else: + lines.append(f"- **{label}:** {value}{unit}") + + for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS: + size_value = gpu_specs.get(size_key, "N/A") + type_value = gpu_specs.get(type_key, "") + lines.append(f"- **{label}:** {size_value}{size_unit} {type_value}") + + return lines + + +def render_ncu_metrics( + ncu_metrics: dict[str, Any], + get_metric_fn: Callable[[str, str], str], +) -> list[str]: + """Render the NCU profiling metrics section.""" + lines = ["", "## NCU Profiling Metrics"] + + for section_name, metrics in NCU_METRIC_SECTIONS.items(): + lines.append("") + lines.append(f"### {section_name}") + for label, key, unit in metrics: + value = get_metric_fn(key, "N/A") + lines.append(f"- **{label}:** {value}{unit}") + + return lines + + +def render_task_instructions() -> list[str]: + """Render the task instructions section for dual-bottleneck analysis.""" + return [ + "", + "## Your Task", + "", + "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:", + "1. **Bottleneck 1 (Primary)**: The highest-impact performance issue", + "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance", + "", + "For each bottleneck, cite 3-4 specific metrics that reveal the issue, " + "and recommend ONE actionable optimization.", + "", + "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt.", + ] + + +def create_metric_getter(kernel_metrics: dict[str, Any]) -> Callable[[str, str], str]: + """Create a metric getter function for a specific kernel's metrics.""" + + def get_metric(key: str, default: str = "N/A") -> str: + val = kernel_metrics.get(key, default) + if isinstance(val, (int, float)): + return f"{val:.2f}" + return str(val) + + return get_metric + + +# ============================================================================= +# Bottleneck Analysis +# ============================================================================= + + +# System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) +JUDGE_SYSTEM_PROMPT = """You are a senior GPU performance engineer. Analyze the target GPU spec, the current kernel, and the Nsight Compute (NCU) profiling metrics. Identify EXACTLY TWO DISTINCT bottlenecks from the hardware profiling data, and propose specific optimization methods for each. Be surgical and metrics-driven. + +## Bottleneck Categories (NCU Hardware Profiling) + +Analyze fundamental resource utilization using NCU profiling data: + +## Bottleneck Categories (Indicators Only) +- **memory-bound**: High DRAM throughput (>60%), low L1/L2 hit rates (<70%), high memory stalls (>30%) +- **compute-bound**: Low DRAM throughput (<40%), high compute utilization (>60%), low memory stalls (<20%) +- **occupancy-limited**: Low warp active (<50%), high register usage (>100/thread), shared memory pressure (>80%) +- **latency-bound**: High total stalls (>40%), memory dependency stalls dominate, long scoreboard stalls + +- Return EXACTLY TWO DISTINCT bottlenecks with DIFFERENT categories +- Both bottlenecks must be from: {memory-bound, compute-bound, occupancy-limited, latency-bound} +- For each bottleneck, cite 3-4 specific NCU metric values that reveal the issue +- Propose ONE actionable optimization method per bottleneck +- Keep fields brief; avoid lists of alternatives, disclaimers, or generic advice + +## Output Format (JSON - STRICT) + +```json +{ + "bottleneck_1": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + }, + "bottleneck_2": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + } +} +``` + +## Important Notes + +- bottleneck_1 is the PRIMARY (highest-impact) issue +- bottleneck_2 is the SECONDARY issue (different category from bottleneck_1) +- They should be independently addressable (fixing one doesn't automatically fix the other) + +Follow the Rules exactly. Return JSON in the specified format. +""" + + +def build_judge_optimization_prompt( + kernel_code: str, + problem_description: str, + ncu_metrics: dict[str, Any], + gpu_specs: dict[str, Any], +) -> tuple[str, str]: + """ + Build system and user prompts for Judge to analyze bottleneck. + + This function constructs detailed prompts for the Judge LLM that include: + - The kernel code being analyzed + - The original problem description + - Complete NCU profiling metrics + - GPU hardware specifications + + Args: + kernel_code: Current Triton kernel code + problem_description: Original problem description + ncu_metrics: NCU profiling metrics as a dictionary (from metrics_to_prompt) + gpu_specs: GPU specifications (from get_gpu_specs) + + Returns: + Tuple of (system_prompt, user_prompt) + + Example: + >>> sys_prompt, user_prompt = build_judge_optimization_prompt( + ... kernel_code=kernel_code, + ... problem_description=problem_desc, + ... ncu_metrics=ncu_metrics, + ... gpu_specs=gpu_specs, + ... ) + >>> response = llm.call([ + ... {"role": "system", "content": sys_prompt}, + ... {"role": "user", "content": user_prompt} + ... ]) + """ + if not ncu_metrics: + raise ValueError("NCU metrics are empty - cannot build judge prompt") + + # Extract first kernel's metrics for the metric getter + first_kernel = list(ncu_metrics.values())[0] + get_metric_fn = create_metric_getter(first_kernel) + + # Build user prompt using modular section renderers + parts: list[str] = [] + + # Compose sections using renderers + parts.extend(render_problem_description(problem_description)) + parts.extend(render_kernel_code(kernel_code)) + parts.extend(render_gpu_specs(gpu_specs)) + parts.extend(render_ncu_metrics(ncu_metrics, get_metric_fn)) + parts.extend(render_task_instructions()) + + user_prompt = "\n".join(parts) + return JUDGE_SYSTEM_PROMPT, user_prompt + + +def extract_judge_response(response_text: str) -> dict[str, Any] | None: + """ + Extract and parse JSON from Judge LLM response. + + This function handles various response formats and provides fallback strategies + for robust JSON extraction. Expects dual-bottleneck format with bottleneck_1 + and bottleneck_2 fields. + + Args: + response_text: Raw text response from Judge LLM + + Returns: + Parsed JSON dictionary with bottleneck_1 and bottleneck_2, + or None if extraction fails + + Example: + >>> response = llm.call(judge_prompts) + >>> analysis = extract_judge_response(response) + >>> if analysis: + ... print(f"Bottleneck 1: {analysis['bottleneck_1']['category']}") + ... print(f"Bottleneck 2: {analysis['bottleneck_2']['category']}") + """ + import json + import re + + # Strategy 1: Find JSON in code block + match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL) + if match: + try: + data = json.loads(match.group(1)) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except json.JSONDecodeError: + pass + + # Strategy 2: Find first { ... } block with "bottleneck_1" field + match = re.search(r'\{[^}]*"bottleneck_1"[^}]*\}', response_text, re.DOTALL) + if match: + try: + # Extract the full JSON object (may be nested) + start_pos = response_text.find("{", match.start()) + brace_count = 0 + end_pos = start_pos + + for i in range(start_pos, len(response_text)): + if response_text[i] == "{": + brace_count += 1 + elif response_text[i] == "}": + brace_count -= 1 + if brace_count == 0: + end_pos = i + 1 + break + + json_str = response_text[start_pos:end_pos] + data = json.loads(json_str) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 3: Find any JSON object with dual-bottleneck structure + match = re.search( + r'\{\s*"bottleneck_1"\s*:\s*\{.*?\}\s*,\s*"bottleneck_2"\s*:\s*\{.*?\}\s*\}', + response_text, + re.DOTALL, + ) + if match: + try: + return json.loads(match.group(0)) + except json.JSONDecodeError: + pass + + # Return None if all strategies fail + return None + + +def validate_judge_response(analysis: dict[str, Any]) -> bool: + """Validate that Judge response contains required dual-bottleneck fields.""" + if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis: + return False + return _validate_bottleneck_entry( + analysis["bottleneck_1"] + ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) + + +VALID_CATEGORIES = { + "memory-bound", + "compute-bound", + "occupancy-limited", + "latency-bound", +} + + +def _validate_bottleneck_entry(bottleneck: dict[str, Any]) -> bool: + """Validate a single bottleneck entry.""" + required = [ + "category", + "root_cause", + "suggestion", + "priority_metrics", + "expected_improvement", + ] + if not all(f in bottleneck for f in required): + return False + if bottleneck["category"] not in VALID_CATEGORIES: + return False + if not isinstance(bottleneck["priority_metrics"], list): + return False + for f in ["root_cause", "suggestion", "expected_improvement"]: + if not isinstance(bottleneck[f], str) or len(bottleneck[f].strip()) < 5: + return False + return True diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py new file mode 100644 index 0000000..64d1d67 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Metric Schema Definitions for NCU Profiling and GPU Specifications. + +This module provides the single source of truth for: +- NCU profiling metric definitions (keys, labels, units) +- GPU specification field definitions + +Schema Format: List of tuples (display_label, key, unit_suffix) +- display_label: Human-readable name shown in prompts +- key: NCU metric key or GPU spec dictionary key +- unit_suffix: Unit to append after value (e.g., "%", " GB/s", " bytes") +""" + +from typing import Dict, List, Tuple + +# Type alias for metric definition: (label, key, unit) +MetricDef = Tuple[str, str, str] + +# ============================================================================= +# GPU Specification Fields +# ============================================================================= + +GPU_SPEC_FIELDS: List[MetricDef] = [ + ("Name", "name", ""), + ("Architecture", "architecture", ""), + ("Peak Memory Bandwidth", "peak_memory_bw_gbps", " GB/s"), + ("Peak FP32 Performance", "peak_fp32_tflops", " TFLOPS"), + ("Peak FP16 Performance", "peak_fp16_tflops", " TFLOPS"), + ("SM Count", "sm_count", ""), + ("Max Threads per SM", "max_threads_per_sm", ""), + ("L1 Cache per SM", "l1_cache_kb", " KB"), + ("L2 Cache (Total)", "l2_cache_mb", " MB"), +] + +# Special case: Memory Size has two fields combined +GPU_MEMORY_FIELDS: List[Tuple[str, str, str, str]] = [ + # (label, size_key, type_key, size_unit) + ("Memory Size", "memory_gb", "memory_type", " GB"), +] + +# ============================================================================= +# NCU Profiling Metric Sections +# ============================================================================= + +NCU_METRIC_SECTIONS: Dict[str, List[MetricDef]] = { + "SM & Compute Utilization": [ + ("SM Cycles Active", "sm__cycles_active.avg", ""), + ("Warp Active", "sm__warps_active.avg.pct_of_peak_sustained_active", "%"), + ("Total Instructions Executed", "sm__inst_executed.sum", ""), + ( + "Tensor Core Utilization", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "%", + ), + ( + "Tensor Core Pipeline Active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ], + "Memory Bandwidth & Cache": [ + ( + "DRAM Throughput", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bandwidth", "dram__bytes.sum.per_second", " bytes/sec"), + ( + "GPU DRAM Throughput", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bytes Read", "dram__bytes_read.sum", " bytes"), + ("DRAM Bytes Write", "dram__bytes_write.sum", " bytes"), + ("L1 Cache Hit Rate", "l1tex__t_sector_hit_rate.pct", "%"), + ( + "L1 Throughput", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ("L2 Cache Hit Rate", "lts__t_sector_hit_rate.pct", "%"), + ( + "L2 Throughput", + "lts__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ], + "Memory Access Patterns": [ + ( + "Memory Coalescing", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "%", + ), + ( + "Branch Uniformity", + "smsp__sass_average_branch_targets_threads_uniform.pct", + "%", + ), + ], + "Occupancy & Resources": [ + ("Occupancy Limited By Blocks", "launch__occupancy_limit_blocks", ""), + ("Occupancy Limited By Registers", "launch__occupancy_limit_registers", ""), + ( + "Occupancy Limited By Shared Memory", + "launch__occupancy_limit_shared_mem", + "", + ), + ("Registers per Thread", "launch__registers_per_thread", ""), + ( + "Shared Memory per Block", + "launch__shared_mem_per_block_allocated", + " bytes", + ), + ], + "Stall Metrics (Warp Issue Stalls)": [ + ( + "Short Scoreboard Stalls", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Long Scoreboard Stalls", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Barrier Stalls", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "%", + ), + ( + "Branch Resolving Stalls", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "%", + ), + ], +} diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..d177194 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""kernel_perf_agent package.""" + +# Kernel Perf Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..710d5c6 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,449 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + +# This selection of metrics is inspried by the CudaForge team (https://github.com/OptimAI-Lab/CudaForge/tree/main) + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +class MetricSelectionPolicy(Enum): + """Policy for selecting rows when multiple rows exist for a kernel.""" + + FIRST = "first" + LAST = "last" + MAX_CYCLES = "max_cycles" + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [ + "sudo", + "-E", + f"--preserve-env={preserve}", + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + + print("[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print("[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def _apply_selection_policy( + df: pd.DataFrame, + policy: MetricSelectionPolicy, +) -> pd.DataFrame: + """ + Apply selection policy to choose a single row from a DataFrame. + + Args: + df: DataFrame with one or more rows + policy: Selection policy to apply + + Returns: + DataFrame with a single row based on the policy + """ + if df.empty: + return df + + if len(df) == 1: + return df + + if policy == MetricSelectionPolicy.FIRST: + return df.iloc[[0]] + elif policy == MetricSelectionPolicy.LAST: + return df.iloc[[-1]] + elif policy == MetricSelectionPolicy.MAX_CYCLES: + if "sm__cycles_active.avg" in df.columns: + return df.sort_values("sm__cycles_active.avg", ascending=False).head(1) + # Fallback to last if cycles column not available + return df.iloc[[-1]] + else: + # Fallback to last for unknown policies + return df.iloc[[-1]] + + +def _filter_by_kernel_names( + df: pd.DataFrame, + name_list: Sequence[str], + policy: MetricSelectionPolicy, + keep_cols: List[str], +) -> pd.DataFrame: + """ + Filter DataFrame by kernel names with substring matching. + + Args: + df: DataFrame with NCU metrics + name_list: List of kernel name substrings to match + policy: Selection policy when multiple rows match + keep_cols: Columns to preserve in empty result + + Returns: + Filtered DataFrame with one row per matched kernel name + """ + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = df[ + df["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + + row = _apply_selection_policy(matched, policy) + results.append(row) + + if results: + return pd.concat(results, ignore_index=True) + else: + return pd.DataFrame(columns=keep_cols) + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: Union[str, MetricSelectionPolicy] = MetricSelectionPolicy.LAST, +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name. + Can be MetricSelectionPolicy enum or string ("first", "last", "max_cycles") + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV or invalid select value + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + # Convert string to enum if needed + if isinstance(select, str): + try: + policy = MetricSelectionPolicy(select) + except ValueError: + raise ValueError( + f"Invalid select value: {select}. " + f"Must be one of: {[p.value for p in MetricSelectionPolicy]}" + ) + else: + policy = select + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + sub = _filter_by_kernel_names(sub, name_list, policy, keep_cols) + else: + # Apply selection to all rows if no name filter + sub = _apply_selection_policy(sub, policy) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/pyproject.toml b/pyproject.toml index 2d5382f..a31dc3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ pipeline-ui = "scripts.pipeline_ui:main" "Repository" = "https://github.com/pytorch-labs/KernelAgent" [tool.setuptools.packages.find] -include = ["triton_kernel_agent*", "Fuser*", "scripts", "utils"] +include = ["triton_kernel_agent*", "kernel_perf_agent*", "Fuser*", "scripts", "utils"] [tool.setuptools] include-package-data = true diff --git a/triton_kernel_agent/opt_worker_component/README.md b/triton_kernel_agent/opt_worker_component/README.md new file mode 100644 index 0000000..76026b2 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/README.md @@ -0,0 +1,12 @@ +# Opt Worker Components + +High-level components used by `OptimizationWorker`. + +These components are **thin wrappers** around low-level utilities +from `kernel_perf_agent` that provide: +- Logging integration +- Error handling +- Worker-specific configuration + +## Dependency Flow +worker_components → kernel_perf_agent (implementation) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py new file mode 100644 index 0000000..61e1d08 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarking infrastructure for kernel performance measurement.""" + +from .benchmark import Benchmark + +__all__ = ["Benchmark"] diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py new file mode 100644 index 0000000..8ee39e4 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified benchmarking for Triton kernels and PyTorch baselines. + +This module consolidates kernel and PyTorch benchmarking with improved timing +utilities, L2 cache clearing, and comprehensive statistics. +""" + +import json +import logging +import subprocess +import sys +import traceback +from pathlib import Path +from typing import Any, Optional + +import torch + +from triton_kernel_agent.opt_worker_component.benchmarking.timing import ( + compute_timing_stats, + prepare_pytorch_model, + time_with_cuda_events, + time_with_triton_do_bench, +) + + +class BenchmarkLockManager: + """Manages GPU benchmarking locks to prevent resource contention.""" + + def __init__(self, lock: Any, worker_id: int, logger: logging.Logger): + """Initialize the lock manager. + + Args: + lock: Shared multiprocessing lock for serializing GPU access + worker_id: Worker ID for logging + logger: Logger instance + """ + self.lock = lock + self.worker_id = worker_id + self.logger = logger + + def __enter__(self): + """Acquire the benchmarking lock.""" + self.logger.info(f"⏳ Waiting for benchmark lock (worker {self.worker_id})...") + self.lock.acquire() + self.logger.info(f"🔓 Acquired benchmark lock (worker {self.worker_id})") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release the benchmarking lock.""" + try: + self.lock.release() + self.logger.info(f"🔒 Released benchmark lock (worker {self.worker_id})") + except Exception as e: + self.logger.warning(f"Failed to release benchmark lock: {e}") + return False + + +class Benchmark: + """Unified benchmark for Triton kernels and PyTorch baselines. + + Supports two modes: + 1. Subprocess mode: Runs benchmarks in isolated processes (for compatibility) + 2. Direct mode: Uses in-process timing utilities (faster, more flexible) + """ + + def __init__( + self, + logger: logging.Logger, + artifacts_dir: Path, + benchmark_lock: Any, + worker_id: int = 0, + warmup: int = 25, + repeat: int = 100, + timing_method: str = "cuda_event", + ): + """Initialize the benchmark. + + Args: + logger: Logger instance + artifacts_dir: Directory for benchmark artifacts + benchmark_lock: Shared lock to serialize GPU benchmarking + worker_id: Worker ID + warmup: Number of warmup iterations (or warmup time in ms for do_bench) + repeat: Number of repeat iterations (or rep time in ms for do_bench) + timing_method: Timing method ("cuda_event", "do_bench", "host_time") + """ + self.logger = logger + self.artifacts_dir = artifacts_dir + self.lock_manager = BenchmarkLockManager(benchmark_lock, worker_id, logger) + self.warmup = warmup + self.repeat = repeat + self.timing_method = timing_method + + def benchmark_kernel( + self, + kernel_file: Path, + problem_file: Path, + baseline_file: Optional[Path] = None, + ) -> dict[str, Any]: + """Benchmark Triton kernel performance using subprocess isolation. + + Uses subprocess for crash protection of potentially buggy kernels. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + baseline_file: Path to baseline kernel (optional) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - speedup: Speedup vs baseline + """ + try: + with self.lock_manager: + results_json = self.artifacts_dir / "benchmark_results.json" + benchmark_script = Path(__file__).parent / "kernel_subprocess.py" + + cmd = [ + sys.executable, + str(benchmark_script), + "--problem", + str(problem_file), + "--kernel", + str(kernel_file), + "--warmup", + str(self.warmup), + "--repeat", + str(self.repeat), + "--json", + str(results_json), + "--quiet", + ] + + if baseline_file: + cmd.extend(["--baseline"]) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, + ) + + if result.returncode != 0: + error_msg = ( + result.stderr.strip() + or result.stdout.strip() + or "Unknown error" + ) + self.logger.error(f"Kernel benchmark failed: {error_msg}") + return {"time_ms": float("inf"), "speedup": 0.0} + + with open(results_json, "r") as f: + results = json.load(f) + + kernel_name = kernel_file.stem + kernel_results = results.get("kernels", {}).get(kernel_name, {}) + + return { + "time_ms": kernel_results.get("time_ms", float("inf")), + "speedup": kernel_results.get("speedup", 1.0), + } + + except Exception as e: + self.logger.error(f"Kernel benchmark failed: {e}") + return {"time_ms": float("inf"), "speedup": 0.0} + + def benchmark_pytorch( + self, + problem_file: Path, + dtype: Optional[torch.dtype] = None, + ) -> dict[str, Any]: + """Benchmark PyTorch baseline using direct in-process timing. + + Always uses direct mode (PyTorch is stable, doesn't need subprocess isolation). + + Args: + problem_file: Path to problem file (must define Model class and get_inputs()) + dtype: Data type to use (default: auto-detect based on model parameters) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - stats: Full timing statistics (mean, std, min, max, all_times, etc.) + """ + try: + with self.lock_manager: + model, inputs = prepare_pytorch_model( + problem_file=problem_file, + device="cuda", + dtype=dtype, + ) + + if self.timing_method == "do_bench": + times = time_with_triton_do_bench( + lambda: model(*inputs), + [], + warmup=self.warmup, + rep=self.repeat, + verbose=False, + ) + else: # cuda_event + times = time_with_cuda_events( + lambda: model(*inputs), + [], + num_warmup=self.warmup, + num_trials=self.repeat, + clear_cache=True, + verbose=False, + ) + + stats = compute_timing_stats(times) + + return { + "time_ms": stats["mean"], + "stats": stats, + } + + except Exception as e: + self.logger.error(f"PyTorch baseline benchmark failed: {e}") + self.logger.error(traceback.format_exc()) + return {"time_ms": float("inf")} diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py new file mode 100644 index 0000000..733216c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A task-agnostic, profiling-only benchmark script for Triton kernels. +This script ONLY benchmarks candidate kernels without correctness checks. +Assumes correctness has been verified upstream. + +Design: +- Skips correctness verification (assumes already verified) +- Only runs candidate kernels +- Fast profiling for iterative optimization loops +- Uses shared utilities from timing.py +""" + +from __future__ import annotations + +import argparse +import inspect +import json +import sys +from pathlib import Path + +from timing import ( + import_module, + load_kernel_function, + load_problem_interface, + prepare_inputs, +) +from typing import Any, Callable, Tuple + +import torch +import triton.testing as tt + + +def _extract_model_params( + model: torch.nn.Module, +) -> tuple[torch.Tensor | None, torch.Tensor | None, dict[str, Any]]: + """Extract weight, bias, and layer parameters from a PyTorch model. + + Searches for Conv or Linear layers and extracts their parameters. + + Args: + model: PyTorch model to extract parameters from + + Returns: + Tuple of (weight, bias, layer_kwargs) where layer_kwargs contains + stride, padding, dilation, groups if applicable + """ + for _, module in model.named_modules(): + if isinstance( + module, + (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear), + ): + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight + bias = getattr(module, "bias", None) + + layer_kwargs: dict[str, Any] = {} + if hasattr(module, "stride"): + stride = module.stride + layer_kwargs["stride"] = ( + stride[0] if isinstance(stride, (tuple, list)) else stride + ) + if hasattr(module, "padding"): + padding = module.padding + layer_kwargs["padding"] = ( + padding[0] if isinstance(padding, (tuple, list)) else padding + ) + if hasattr(module, "dilation"): + dilation = module.dilation + layer_kwargs["dilation"] = ( + dilation[0] if isinstance(dilation, (tuple, list)) else dilation + ) + if hasattr(module, "groups"): + layer_kwargs["groups"] = module.groups + + return weight, bias, layer_kwargs + + return None, None, {} + + +def _run_once( + fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, name: str +) -> torch.Tensor: + """Run kernel once to verify execution and get output shape/dtype.""" + try: + with torch.inference_mode(): + return fn(*inputs, *init_inputs) + except Exception as exc: + raise RuntimeError(f"{name} failed to execute: {exc}") from exc + + +def _benchmark( + fn: Callable, + inputs: Tuple[torch.Tensor, ...], + init_inputs: list, + name: str, + warmup: int = 25, + rep: int = 100, +) -> float: + """Benchmark a kernel function using triton.testing.do_bench.""" + try: + ms = tt.do_bench( + lambda: fn(*inputs, *init_inputs), + warmup=warmup, + rep=rep, + return_mode="mean", + ) + print(f"{name}: {ms:.4f} ms (mean over {rep} runs)") + return ms + except Exception as exc: + print(f"❌ {name}: Benchmark failed: {exc}") + return float("inf") + + +def _parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Task-agnostic Triton kernel benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument( + "--problem", + type=Path, + required=True, + help="Path to problem file (must define Model and get_inputs)", + ) + parser.add_argument( + "--kernel", + type=Path, + required=True, + help="Path to kernel file (must define kernel_function)", + ) + parser.add_argument( + "--baseline", + action="store_true", + help="Include PyTorch reference model in benchmark", + ) + parser.add_argument("--warmup", type=int, default=25) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--size", type=int, default=4096, help="Problem size N") + parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + ) + parser.add_argument("--json", type=Path, help="Save results to JSON file") + parser.add_argument("--quiet", action="store_true") + + args = parser.parse_args() + args.problem = args.problem.resolve() + args.kernel = args.kernel.resolve() + return args + + +def _load_problem( + args: argparse.Namespace, + device: torch.device, + dtype: torch.dtype, +) -> tuple[type, tuple, list, torch.nn.Module | None]: + """Load problem interface, prepare inputs, and optionally create baseline model. + + Returns: + Tuple of (Model class, inputs, init_inputs, baseline_model or None) + """ + Model, get_inputs, get_init_inputs = load_problem_interface(args.problem) + + # Check for optional benchmark config override + try: + problem_mod = import_module(args.problem, "problem") + get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) + if get_benchmark_config is not None: + config = get_benchmark_config() + args.warmup = config.get("warmup", args.warmup) + args.repeat = config.get("repeat", args.repeat) + if not args.quiet: + print( + f"Using problem-specific config: " + f"warmup={args.warmup}, repeat={args.repeat}" + ) + except Exception: + pass + + inputs = prepare_inputs(get_inputs, device=device, dtype=dtype) + + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + # Create baseline model if requested + baseline_model = None + if args.baseline: + baseline_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + baseline_model.eval() + out = _run_once(baseline_model, inputs, [], "Reference") + if not args.quiet: + print(f"Reference output shape: {out.shape}, dtype: {out.dtype}") + print() + + return Model, inputs, init_inputs, baseline_model + + +def _prepare_kernel( + kernel_file: Path, + Model: type, + baseline_model: torch.nn.Module | None, + init_inputs: list, + device: torch.device, + dtype: torch.dtype, + quiet: bool = False, +) -> tuple[Callable, tuple, list]: + """Load kernel and wrap it with model parameters if needed. + + Returns: + Tuple of (kernel_function, kernel_args, kernel_init_args) + """ + kernel_function = load_kernel_function(kernel_file) + + # Check if kernel expects weight/bias parameters + needs_model = False + try: + sig = inspect.signature(kernel_function) + if "weight" in sig.parameters: + needs_model = True + except Exception: + pass + + kernel_init_args = init_inputs + + if needs_model and Model is not None: + try: + # Reuse baseline model if available + extract_model = baseline_model + if extract_model is None: + extract_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + + weight, bias, kernel_kwargs = _extract_model_params(extract_model) + + if weight is not None: + original_fn = kernel_function + + def kernel_with_model(*args, **kwargs): + return original_fn(args[0], weight, bias, **kernel_kwargs) + + kernel_function = kernel_with_model + kernel_init_args = [] + except Exception as exc: + if not quiet: + print(f"⚠️ Warning: Failed to extract model parameters: {exc}") + print(" Falling back to direct kernel invocation") + + return kernel_function, kernel_init_args + + +def _save_results(results: dict[str, Any], path: Path) -> None: + """Save benchmark results to JSON file.""" + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(results, f, indent=2) + print(f"Results saved to: {path}") + + +def main(): + args = _parse_args() + + device = torch.device(args.device) + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + dtype = dtype_map[args.dtype] + + if not args.quiet: + print("=" * 80) + print("TRITON KERNEL PROFILING") + print("=" * 80) + print(f"Problem: {args.problem.name}") + print(f"Size: {args.size}") + print(f"Device: {device}, Dtype: {dtype}") + print(f"Warmup: {args.warmup}, Repeat: {args.repeat}") + print() + + # Load problem and prepare inputs + try: + Model, inputs, init_inputs, baseline_model = _load_problem(args, device, dtype) + except Exception as exc: + print(f"❌ Failed to load problem: {exc}") + sys.exit(1) + + results: dict[str, Any] = { + "problem": str(args.problem), + "size": args.size, + "device": str(device), + "dtype": str(dtype), + "warmup": args.warmup, + "repeat": args.repeat, + "kernels": {}, + } + + # Benchmark baseline (if requested) + baseline_time = None + if baseline_model is not None: + if not args.quiet: + print("1. PyTorch Reference") + baseline_time = _benchmark( + baseline_model, inputs, [], "PyTorch", args.warmup, args.repeat + ) + results["kernels"]["pytorch_reference"] = { + "time_ms": baseline_time, + "speedup": 1.0, + } + if not args.quiet: + print() + + # Load and prepare kernel + kernel_name = args.kernel.stem + if not args.quiet: + idx = 2 if args.baseline else 1 + print(f"{idx}. Candidate: {kernel_name}") + + try: + kernel_fn, kernel_init_args = _prepare_kernel( + args.kernel, Model, baseline_model, init_inputs, device, dtype, args.quiet + ) + except Exception as exc: + print(f"❌ Failed to load kernel: {exc}") + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} + if args.json: + _save_results(results, args.json) + sys.exit(1) + + # Verify kernel executes + try: + out = _run_once(kernel_fn, inputs, kernel_init_args, kernel_name) + if not args.quiet: + print(f"✓ {kernel_name} executes successfully") + print(f" Output shape: {out.shape}, dtype: {out.dtype}") + except Exception as exc: + print(f"❌ {kernel_name} failed: {exc}") + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} + if args.json: + _save_results(results, args.json) + sys.exit(1) + + # Benchmark kernel + kernel_time = _benchmark( + kernel_fn, inputs, kernel_init_args, kernel_name, args.warmup, args.repeat + ) + results["kernels"][kernel_name] = {"time_ms": kernel_time, "path": str(args.kernel)} + + # Calculate speedup + if baseline_time is not None and kernel_time != float("inf"): + speedup = baseline_time / kernel_time + results["kernels"][kernel_name]["speedup"] = speedup + if not args.quiet: + print(f"Speedup vs PyTorch: {speedup:.2f}x") + + if args.json: + _save_results(results, args.json) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nBenchmark interrupted") + sys.exit(130) + except Exception as exc: + print(f"❌ Unexpected error: {exc}") + sys.exit(1) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py new file mode 100644 index 0000000..0e0730c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -0,0 +1,459 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core timing and model loading utilities for kernel benchmarking. + +This module consolidates: +- Timing functions (CUDA events, do_bench, host timing) +- Model/kernel loading utilities +- Statistics computation + +Inspired by KernelBench's timing.py +""" + +import hashlib +import importlib.util +import sys +from pathlib import Path +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch + + +# ============================================================================= +# Model and Kernel Loading Utilities +# ============================================================================= + + +class CompilationError(RuntimeError): + """Raised when a kernel or problem file fails to compile/import.""" + + pass + + +def import_module(path: Path, module_name: Optional[str] = None): + """Dynamically import a Python file. + + Args: + path: Path to the Python file + module_name: Optional name for the module (auto-generated if None) + + Returns: + The imported module + + Raises: + FileNotFoundError: If path doesn't exist + CompilationError: If import fails + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if module_name is None: + module_name = f"mod_{hashlib.md5(str(path).encode()).hexdigest()}" + + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise CompilationError(f"Failed to create spec for {path}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + spec.loader.exec_module(module) + except Exception as exc: + raise CompilationError(f"Failed to import {path}: {exc}") from exc + + return module + + +def load_problem_interface( + problem_file: Path, +) -> Tuple[type, Callable, Optional[Callable]]: + """Load the standard problem interface from a problem file. + + Args: + problem_file: Path to problem file + + Returns: + Tuple of (Model class, get_inputs function, get_init_inputs function) + + Raises: + CompilationError: If problem file doesn't define required interface + """ + module = import_module(problem_file, "problem") + + Model = getattr(module, "Model", None) + get_inputs = getattr(module, "get_inputs", None) + get_init_inputs = getattr(module, "get_init_inputs", None) + + if Model is None: + raise CompilationError("Problem file must define 'Model' class") + if get_inputs is None: + raise CompilationError("Problem file must define 'get_inputs()' function") + + return Model, get_inputs, get_init_inputs + + +def prepare_inputs( + get_inputs: Callable, + device: torch.device | str = "cuda", + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[torch.Tensor, ...]: + """Prepare inputs by converting to target device and dtype. + + Args: + get_inputs: Function that returns inputs + device: Target device + dtype: Target dtype for floating-point tensors + + Returns: + Tuple of prepared inputs + """ + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Convert inputs to target device and dtype + # IMPORTANT: Only convert floating-point tensors; preserve integer/bool tensors + converted_inputs = [] + for inp in inputs: + if isinstance(inp, torch.Tensor): + inp = inp.to(device=device) + # Preserve integer/bool tensors (e.g., targets for classification) + if inp.is_floating_point(): + inp = inp.to(dtype=dtype) + converted_inputs.append(inp) + + return tuple(converted_inputs) + + +def prepare_pytorch_model( + problem_file: Path, + device: torch.device | str = "cuda", + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]]: + """Prepare PyTorch model and inputs for benchmarking. + + This handles the full workflow: + 1. Load problem interface (Model, get_inputs, get_init_inputs) + 2. Initialize model with init inputs + 3. Move model to device + 4. Handle dtype conversion based on whether model has parameters + + Args: + problem_file: Path to problem file + device: Target device + dtype: Target dtype (auto-detected if None) + + Returns: + Tuple of (model, inputs) ready for benchmarking + """ + Model, get_inputs, get_init_inputs = load_problem_interface(problem_file) + + # Get initialization inputs (e.g., features, eps for RMSNorm) + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + model = Model(*init_inputs) if init_inputs else Model() + model = model.cuda() + has_parameters = any(p.numel() > 0 for p in model.parameters()) + + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Default to bfloat16 unless explicitly specified or model is a loss function + target_dtype = dtype or torch.bfloat16 + is_loss_function = isinstance(model, torch.nn.modules.loss._Loss) + + if has_parameters or not is_loss_function: + # Models with parameters (Conv, Linear, etc.) OR compute operations (matmul, etc.) + # → use bfloat16 (or user-specified dtype) + if has_parameters: + model = model.to(target_dtype) + inputs = [ + ( + inp.cuda().to(target_dtype) + if isinstance(inp, torch.Tensor) and inp.is_floating_point() + else inp.cuda() + if isinstance(inp, torch.Tensor) + else inp + ) + for inp in inputs + ] + else: + # Loss functions (no parameters) → use float32 for compatibility + # PyTorch cross_entropy doesn't support bf16 on CUDA + processed_inputs = [] + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor): + if i == 0 and inp.is_floating_point(): + # First input (predictions) - convert to float32 for compatibility + processed_inputs.append(inp.cuda().to(torch.float32)) + else: + # Other inputs (like targets) - just move to CUDA, preserve dtype + processed_inputs.append(inp.cuda()) + else: + processed_inputs.append(inp) + inputs = processed_inputs + + return model, tuple(inputs) + + +def load_kernel_function(kernel_file: Path) -> Callable: + """Load kernel_function from a kernel file. + + Args: + kernel_file: Path to kernel file + + Returns: + The kernel_function callable + + Raises: + CompilationError: If kernel file doesn't define kernel_function + """ + module = import_module(kernel_file, "kernel") + + kernel_function = getattr(module, "kernel_function", None) + if kernel_function is None: + raise CompilationError( + f"Kernel file {kernel_file.name} must define 'kernel_function'" + ) + + return kernel_function + + +# ============================================================================= +# Timing Utilities +# ============================================================================= + + +def clear_l2_cache(device: torch.device | str = "cuda") -> None: + """Clear L2 cache by thrashing with a large tensor. + + This ensures we measure cold cache performance, which is more representative + of real-world scenarios where data isn't already cached. + + Reference: KernelBench timing.py + L2 cache sizes: A100=40MB, H100=50MB, H200=90MB, RTX4090=72MB, L40S=48MB + We overwrite >256MB to fully thrash L2 cache. + + Args: + device: CUDA device to use + """ + # 32 * 1024 * 1024 * 8B = 256MB - enough to thrash most GPU L2 caches + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device=device) + dummy.fill_(42) # Write to tensor to ensure cache thrashing + del dummy + + +def time_with_cuda_events( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 3, + num_trials: int = 10, + clear_cache: bool = True, + discard_first: int = 0, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time a CUDA kernel using CUDA events for accurate device-side timing. + + This measures actual GPU execution time without host-side overhead. + Each trial clears L2 cache to measure cold-cache performance. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + num_trials: Number of timing trials + clear_cache: Whether to clear L2 cache between trials + discard_first: Number of initial trials to discard + verbose: Print per-trial timing info + device: CUDA device to use (None = current device) + + Returns: + List of elapsed times in milliseconds (length = num_trials) + """ + if device is None: + device = torch.cuda.current_device() + + with torch.cuda.device(device): + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + torch.cuda.synchronize(device=device) + + torch.cuda.empty_cache() + + if verbose: + print( + f"[Timing] Device: {torch.cuda.get_device_name(device)}, " + f"warmup={num_warmup}, trials={num_trials}" + ) + + elapsed_times: list[float] = [] + + # Timing trials + for trial in range(num_trials + discard_first): + torch.cuda.synchronize(device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if clear_cache: + clear_l2_cache(device=device) + + start_event.record() + kernel_fn(*args) + end_event.record() + + torch.cuda.synchronize(device=device) + elapsed_time_ms = start_event.elapsed_time(end_event) + + if trial >= discard_first: + if verbose: + print( + f" Trial {trial - discard_first + 1}: {elapsed_time_ms:.3f} ms" + ) + elapsed_times.append(elapsed_time_ms) + + return elapsed_times + + +def time_with_inductor_benchmarker( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 25, + verbose: bool = False, +) -> float: + """Time using PyTorch Inductor's benchmarker (simplest approach). + + This is a thin wrapper around torch._inductor.runtime.benchmarking.benchmarker, + which handles CUDA synchronization and timing internally. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + verbose: Print timing info + + Returns: + Elapsed time in milliseconds (single value, not a list) + + Note: + This uses a private PyTorch API (_inductor) which may change without notice. + """ + from torch._inductor.runtime.benchmarking import benchmarker + + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + + ms = benchmarker.benchmark_gpu(lambda: kernel_fn(*args)) + + if verbose: + print(f"[Timing] Inductor benchmarker: {ms:.4f} ms") + + return ms + + +def time_with_triton_do_bench( + kernel_fn: Callable, + args: list[Any], + warmup: int = 25, + rep: int = 100, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time using Triton's do_bench with adaptive trial count. + + Triton's do_bench automatically determines the number of trials based on + warmup/rep time budgets. This is convenient but gives less control. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + warmup: Warmup time budget in milliseconds + rep: Repetition time budget in milliseconds + verbose: Print timing info + device: CUDA device to use + + Returns: + List of all trial times in milliseconds + """ + if device is None: + device = torch.cuda.current_device() + + import triton.testing as triton_testing + + with torch.cuda.device(device): + if verbose: + print( + f"[Timing] Using triton.do_bench on {torch.cuda.get_device_name(device)}" + ) + + def wrapped_fn(): + return kernel_fn(*args) + + times = triton_testing.do_bench( + fn=wrapped_fn, + warmup=warmup, + rep=rep, + grad_to_none=None, + quantiles=None, + return_mode="all", + ) + + return times + + +def compute_timing_stats( + elapsed_times: list[float], + device: Optional[torch.device | str] = None, +) -> dict[str, Any]: + """Compute essential timing statistics. + + Args: + elapsed_times: List of elapsed times in milliseconds + device: CUDA device (for recording hardware info) + + Returns: + Dictionary with timing statistics: + - mean: Mean time in ms + - std: Standard deviation in ms + - min: Minimum time in ms + - max: Maximum time in ms + - num_trials: Number of trials + - all_times: All trial times + - hardware: GPU name (if device provided) + """ + times_array = np.array(elapsed_times) + + stats = { + "mean": float(np.mean(times_array)), + "std": float(np.std(times_array)), + "min": float(np.min(times_array)), + "max": float(np.max(times_array)), + "num_trials": len(elapsed_times), + "all_times": [float(t) for t in elapsed_times], + } + + if device is not None: + stats["hardware"] = torch.cuda.get_device_name(device=device) + stats["device"] = str(device) + + return stats diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py new file mode 100644 index 0000000..8f8d18b --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .kernel_profiler import KernelProfiler +from .ncu_wrapper_factory import NCUWrapperFactory + +__all__ = ["NCUWrapperFactory", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py new file mode 100644 index 0000000..682c4e7 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + +import json +import logging +import subprocess +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from functools import cached_property +from pathlib import Path +from typing import Any, Dict + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_factory import ( + NCUWrapperFactory, +) + +# Default timeout for NCU profiling in seconds +DEFAULT_NCU_TIMEOUT_SECONDS = 120 + + +@dataclass +class ProfilerMetadata: + """Metadata about a profiling run.""" + + kernel_file: str + problem_file: str + round_num: int + timestamp: str + ncu_version: str | None + + +@dataclass +class ProfilerResults: + """ + Results from a kernel profiling run. + + This dataclass encapsulates both the metrics DataFrame and the parsed + metrics dictionary, along with metadata about the profiling run. + """ + + metrics_df: pd.DataFrame + metrics: Dict[str, Any] + metadata: ProfilerMetadata + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary for JSON serialization.""" + return { + "metrics": self.metrics, + "metadata": asdict(self.metadata), + } + + def to_json(self) -> str: + """Convert results to a JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + +class KernelProfiler: + def __init__( + self, + logger: logging.Logger, + artifacts_dir: Path, + logs_dir: Path, + ncu_bin_path: str | None = None, + ncu_timeout_seconds: int = DEFAULT_NCU_TIMEOUT_SECONDS, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + artifacts_dir: Directory for optimization artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + ncu_timeout_seconds: Timeout for NCU profiling in seconds + """ + self.logger = logger + self.artifacts_dir = artifacts_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.ncu_timeout_seconds = ncu_timeout_seconds + self.wrapper_factory = NCUWrapperFactory(logger) + + @cached_property + def ncu_version(self) -> str | None: + """ + NCU version string (cached). + + Returns: + Version string like "2025.2.1.0" or None if unavailable + """ + try: + ncu_cmd = self.ncu_bin_path or "ncu" + result = subprocess.run( + [ncu_cmd, "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Extract version from output + # Example: "Version 2025.2.1.0 (build 35987062) (public-release)" + for line in result.stdout.strip().split("\n"): + if line.startswith("Version "): + return line.split()[1] + return None + return None + except Exception as e: + self.logger.debug(f"Failed to get NCU version: {e}") + return None + + def _wait_with_backoff(self, attempt: int) -> None: + """ + Wait with exponential backoff before retrying. + + Args: + attempt: Current attempt number (1-indexed) + """ + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + + def profile_kernel( + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 2, + ) -> ProfilerResults | None: + """ + Profile kernel with NCU (with retry logic). + + NCU profiling can fail due to GPU contention or transient issues. + This method automatically retries with exponential backoff. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number (used for file naming + and tracking which optimization iteration this profiling belongs to) + max_retries: Maximum number of retry attempts (default: 2) + + Returns: + ProfilerResults containing metrics and metadata, or None on failure + + Example: + >>> profiler = KernelProfiler(logger, artifacts_dir, logs_dir) + >>> results = profiler.profile_kernel( + ... Path("kernel.py"), Path("problem.py"), round_num=1 + ... ) + >>> if results: + ... print(f"DRAM throughput: {results.metrics['dram__throughput']}") + """ + # Create NCU wrapper script + wrapper_file = self.wrapper_factory.create_ncu_wrapper( + kernel_file, problem_file, self.artifacts_dir + ) + + for attempt in range(1, max_retries + 1): + try: + self.logger.info( + f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.artifacts_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=self.ncu_timeout_seconds, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics = json.loads(metrics_to_prompt(metrics_df)) + + # Build ProfilerResults + results = ProfilerResults( + metrics_df=metrics_df, + metrics=metrics, + metadata=ProfilerMetadata( + kernel_file=str(kernel_file), + problem_file=str(problem_file), + round_num=round_num, + timestamp=datetime.utcnow().isoformat() + "Z", + ncu_version=self.ncu_version, + ), + ) + + # Save metrics with metadata + self._save_profiler_results(results) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return results + + except FileNotFoundError as e: + self.logger.error(f"❌ File not found during profiling: {e}") + return None + + except subprocess.TimeoutExpired: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(final attempt {attempt}/{max_retries})" + ) + return None + else: + self.logger.debug( + f"NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(attempt {attempt}/{max_retries})" + ) + self._wait_with_backoff(attempt) + continue + + except json.JSONDecodeError as e: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Failed to parse NCU metrics (final attempt): {e}" + ) + return None + else: + self.logger.debug( + f"Failed to parse NCU metrics (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) + continue + + except Exception as e: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Unexpected error during profiling (final attempt): {e}", + exc_info=True, + ) + return None + else: + self.logger.debug( + f"Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) + continue + + # All retries exhausted + self.logger.error( + f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" + ) + return None + + def _save_profiler_results(self, results: ProfilerResults) -> None: + """ + Save profiling results with metadata to a JSON file. + + Args: + results: ProfilerResults to save + """ + metrics_file = ( + self.logs_dir / f"round{results.metadata.round_num:03d}_ncu_metrics.json" + ) + + with open(metrics_file, "w") as f: + f.write(results.to_json()) + + self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py new file mode 100644 index 0000000..e3c5a3c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from functools import cached_property +from pathlib import Path + +from jinja2 import Template + + +class NCUWrapperFactory: + """Factory for creating NCU wrapper scripts for profiling Triton kernels.""" + + # Template file path (relative to this file) + WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper factory. + + Args: + logger: Logger instance + """ + self.logger = logger + + @cached_property + def template(self) -> Template: + """ + Jinja2 template for wrapper script generation. + + Raises: + FileNotFoundError: If template file doesn't exist + """ + if not self.WRAPPER_TEMPLATE.exists(): + raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") + return Template(self.WRAPPER_TEMPLATE.read_text()) + + def create_ncu_wrapper( + self, + kernel_file: Path, + problem_file: Path, + output_dir: Path, + dtype_inference: bool = True, + model_extraction: bool = True, + ) -> Path: + """ + Create NCU wrapper script for profiling. + + The wrapper handles multiple kernel types: + - Standard kernels: kernel_function(*inputs) + - Conv/Linear kernels: Extracts weights from Model + - RMSNorm kernels: Passes init_inputs (features, eps) + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + dtype_inference: Enable automatic dtype inference from kernel source (default: True) + model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) + + Returns: + Path to created wrapper script + + Raises: + FileNotFoundError: If kernel_file or problem_file doesn't exist + OSError: If output_dir is not writable + """ + # Validate inputs + if not kernel_file.exists(): + raise FileNotFoundError(f"Kernel file not found: {kernel_file}") + if not problem_file.exists(): + raise FileNotFoundError(f"Problem file not found: {problem_file}") + + # Ensure output directory exists + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + wrapper_file = output_dir / "ncu_wrapper.py" + + # Render template + wrapper_content = self.template.render( + kernel_file_parent=repr(str(kernel_file.parent)), + problem_file_parent=repr(str(problem_file.parent)), + kernel_module=kernel_file.stem, + problem_module=problem_file.stem, + dtype_inference=dtype_inference, + model_extraction=model_extraction, + ) + + # Write wrapper file + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 new file mode 100644 index 0000000..310d8bb --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 @@ -0,0 +1,128 @@ +"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({{ kernel_file_parent }})) +sys.path.insert(0, str({{ problem_file_parent }})) + +from {{ kernel_module }} import kernel_function +from {{ problem_module }} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {{ problem_module }} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +{% if dtype_inference %} +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except Exception: + pass +{% else %} +# Dtype inference disabled +required_dtype = None +{% endif %} + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +{% if model_extraction %} +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except Exception: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + try: + model = Model(*init_inputs) if init_inputs else Model() + except TypeError: + # Model initialization failed, use fallback + model = Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {} # Empty dict + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: +{% endif %} + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output)))