diff --git a/kernel_perf_agent/README.md b/kernel_perf_agent/README.md new file mode 100644 index 0000000..e27dfec --- /dev/null +++ b/kernel_perf_agent/README.md @@ -0,0 +1,3 @@ +# Kernel Performance Utilities + +Low-level, reusable utilities for kernel optimization. diff --git a/kernel_perf_agent/__init__.py b/kernel_perf_agent/__init__.py new file mode 100644 index 0000000..1f49766 --- /dev/null +++ b/kernel_perf_agent/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Kernel Performance Agent package.""" + +# "Kernel Performance Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/__init__.py b/kernel_perf_agent/kernel_opt/profiler/__init__.py new file mode 100644 index 0000000..d177194 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""kernel_perf_agent package.""" + +# Kernel Perf Agent package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py new file mode 100644 index 0000000..4ce8568 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/profiler/ncu_profiler.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +NCU Profiling Module for Triton Kernels + +This module wraps three tasks: +1) Collect core metrics for Triton kernels with Nsight Compute into CSV (`profile_triton_kernel`). +2) Extract and clean those metrics into a DataFrame from the CSV (`load_ncu_metrics`). +3) Convert the metrics table into a string suitable for inclusion in an LLM prompt (`metrics_to_prompt`). + + +""" + +import json +import math +import os +import shutil +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd + +# This selection of metrics is inspried by the CudaForge team (https://github.com/OptimAI-Lab/CudaForge/tree/main) + +METRICS = ",".join( + [ + "sm__cycles_active.avg", + "sm__warps_active.avg.pct_of_peak_sustained_active", + "launch__occupancy_limit_blocks", + "launch__occupancy_limit_registers", + "launch__occupancy_limit_shared_mem", + "launch__registers_per_thread", + "sm__inst_executed.sum", + "sm__inst_executed_pipe_fp32.avg.pct_of_peak_sustained_active", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "dram__bytes_read.sum", + "dram__bytes_write.sum", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "dram__bytes.sum.per_second", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "launch__shared_mem_per_block_allocated", + "l1tex__t_sector_hit_rate.pct", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "lts__t_sector_hit_rate.pct", + "lts__throughput.avg.pct_of_peak_sustained_active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "smsp__warp_issue_stalled_memory_dependency_per_warp_active.pct", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "smsp__sass_average_branch_targets_threads_uniform.pct", + ] +) + +# List version for convenient header selection +METRIC_COLUMNS: List[str] = [s.strip() for s in METRICS.split(",")] + + +class MetricSelectionPolicy(Enum): + """Policy for selecting rows when multiple rows exist for a kernel.""" + + FIRST = "first" + LAST = "last" + MAX_CYCLES = "max_cycles" + + +def profile_triton_kernel( + benchmark_script: Path, + workdir: Path, + out_csv: str = "ncu_output.csv", + python_executable: Optional[str] = None, + ncu_bin: Optional[str] = None, + launch_count: int = 20, + timeout: int = 120, + use_sudo: bool = False, +) -> Path: + """ + Profile a Triton kernel using NCU. + + Args: + benchmark_script: Path to benchmark script that calls the kernel + workdir: Working directory for execution + out_csv: Output CSV filename + python_executable: Python executable to use (default: sys.executable) + ncu_bin: Path to NCU binary (default: auto-detect) + launch_count: Number of kernel launches to profile + timeout: Timeout in seconds for NCU execution + use_sudo: Whether to run NCU with sudo. Can also be enabled via + KERNELAGENT_NCU_USE_SUDO=1 environment variable. Default: False + + Returns: + Path to output CSV file + + Raises: + RuntimeError: If NCU profiling fails + FileNotFoundError: If NCU binary or output CSV not found + """ + # Check for environment variable override + use_sudo = use_sudo or os.environ.get("KERNELAGENT_NCU_USE_SUDO", "0") == "1" + + # Resolve paths + if python_executable is None: + python_executable = sys.executable + + if ncu_bin is None: + ncu_bin = shutil.which("ncu") or "/usr/local/cuda/bin/ncu" + + if not Path(ncu_bin).exists(): + raise FileNotFoundError(f"NCU binary not found: {ncu_bin}") + + csv_path = (workdir / out_csv).resolve() + benchmark_script = benchmark_script.resolve() + + if not benchmark_script.exists(): + raise FileNotFoundError(f"Benchmark script not found: {benchmark_script}") + + # Preserve important environment variables + env = os.environ.copy() + + # Add Triton-specific environment variables + env["TRITON_CACHE_DIR"] = str(workdir / ".triton_cache") + + preserve = ",".join( + [ + "PATH", + "LD_LIBRARY_PATH", + "CUDA_VISIBLE_DEVICES", + "PYTHONPATH", + "TRITON_CACHE_DIR", + "TORCH_EXTENSIONS_DIR", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + ] + ) + + # Build NCU command + cmd = [] + if use_sudo: + cmd.extend(["sudo", "-E", f"--preserve-env={preserve}"]) + + cmd.extend( + [ + ncu_bin, + "--csv", + "--page=raw", + "--kernel-name-base=demangled", + "--target-processes=all", + "--replay-mode=kernel", + "--profile-from-start=on", + f"--log-file={str(csv_path)}", + f"--metrics={METRICS}", + "--launch-skip=0", + f"--launch-count={launch_count}", + python_executable, + str(benchmark_script), + ] + ) + + print("[NCU] Running profiling...") + print(f"[NCU] Benchmark: {benchmark_script.name}") + print(f"[NCU] Output: {csv_path}") + print(f"[NCU] Using sudo: {use_sudo}") + print(f"[NCU] Command: {' '.join(cmd[:10])}... (truncated)") + + try: + result = subprocess.run( + cmd, + cwd=str(workdir), + env=env, + capture_output=True, + text=True, + timeout=timeout, + ) + + if result.returncode != 0: + error_msg = result.stderr or result.stdout + + # Check for common permission errors + permission_errors = [ + "permission", + "ERR_NVGPUCTRPERM", + "profiling permissions", + "requires root", + "access denied", + ] + is_permission_error = any( + err.lower() in error_msg.lower() for err in permission_errors + ) + + if is_permission_error and not use_sudo: + raise RuntimeError( + f"NCU requires elevated permissions on this system.\n" + f"Options to resolve:\n" + f" 1. Rerun with use_sudo=True\n" + f" 2. Set environment variable: KERNELAGENT_NCU_USE_SUDO=1\n" + f" 3. Configure driver permissions: " + f"Original error:\n{error_msg[:500]}" + ) + + raise RuntimeError( + f"NCU profiling failed with return code {result.returncode}:\n{error_msg[:500]}" + ) + + if not csv_path.exists(): + raise FileNotFoundError(f"NCU did not create output CSV: {csv_path}") + + # Check if CSV has content + csv_size = csv_path.stat().st_size + if csv_size < 100: + raise RuntimeError( + f"NCU CSV file is too small ({csv_size} bytes), likely empty" + ) + + print("[NCU] ✓ Profiling completed successfully") + print(f"[NCU] ✓ CSV written: {csv_path} ({csv_size} bytes)") + return csv_path + + except subprocess.TimeoutExpired: + raise RuntimeError(f"NCU profiling timed out after {timeout} seconds") + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"NCU profiling failed: {e}") + + +def _apply_selection_policy( + df: pd.DataFrame, + policy: MetricSelectionPolicy, +) -> pd.DataFrame: + """ + Apply selection policy to choose a single row from a DataFrame. + + Args: + df: DataFrame with one or more rows + policy: Selection policy to apply + + Returns: + DataFrame with a single row based on the policy + """ + if df.empty: + return df + + if len(df) == 1: + return df + + if policy == MetricSelectionPolicy.FIRST: + return df.iloc[[0]] + elif policy == MetricSelectionPolicy.LAST: + return df.iloc[[-1]] + elif policy == MetricSelectionPolicy.MAX_CYCLES: + if "sm__cycles_active.avg" in df.columns: + return df.sort_values("sm__cycles_active.avg", ascending=False).head(1) + # Fallback to last if cycles column not available + return df.iloc[[-1]] + else: + # Fallback to last for unknown policies + return df.iloc[[-1]] + + +def _filter_by_kernel_names( + df: pd.DataFrame, + name_list: Sequence[str], + policy: MetricSelectionPolicy, + keep_cols: List[str], +) -> pd.DataFrame: + """ + Filter DataFrame by kernel names with substring matching. + + Args: + df: DataFrame with NCU metrics + name_list: List of kernel name substrings to match + policy: Selection policy when multiple rows match + keep_cols: Columns to preserve in empty result + + Returns: + Filtered DataFrame with one row per matched kernel name + """ + results = [] + for name in name_list: + # Use contains match instead of exact equality (for Triton's long kernel names) + matched = df[ + df["Kernel Name"].astype(str).str.contains(name, regex=False, na=False) + ] + if matched.empty: + continue + + row = _apply_selection_policy(matched, policy) + results.append(row) + + if results: + return pd.concat(results, ignore_index=True) + else: + return pd.DataFrame(columns=keep_cols) + + +def load_ncu_metrics( + csv_path: Union[str, Path], + columns: Optional[Sequence[str]] = None, + extra_keep: Optional[Sequence[str]] = ("Kernel Name",), + coerce_numeric: bool = True, + name_list: Optional[Sequence[str]] = None, + select: Union[str, MetricSelectionPolicy] = MetricSelectionPolicy.LAST, +) -> pd.DataFrame: + """ + Load and parse NCU metrics from CSV file. + + Args: + csv_path: Path to NCU CSV output + columns: Specific metric columns to load (default: all METRIC_COLUMNS) + extra_keep: Additional columns to keep (e.g., "Kernel Name") + coerce_numeric: Convert metric values to numeric + name_list: Filter by kernel names (substring match) + select: Selection policy when multiple rows per name. + Can be MetricSelectionPolicy enum or string ("first", "last", "max_cycles") + + Returns: + DataFrame with parsed metrics + + Raises: + FileNotFoundError: If CSV file not found + ValueError: If no requested columns found in CSV or invalid select value + """ + csv_path = Path(csv_path) + if not csv_path.exists(): + raise FileNotFoundError(f"CSV not found: {csv_path}") + + # Convert string to enum if needed + if isinstance(select, str): + try: + policy = MetricSelectionPolicy(select) + except ValueError: + raise ValueError( + f"Invalid select value: {select}. " + f"Must be one of: {[p.value for p in MetricSelectionPolicy]}" + ) + else: + policy = select + + df = pd.read_csv(csv_path, comment="=", low_memory=False) + + metric_cols = list(columns) if columns is not None else METRIC_COLUMNS + keep_cols: List[str] = [] + if extra_keep: + keep_cols.extend([c for c in extra_keep if c in df.columns]) + keep_cols.extend([c for c in metric_cols if c in df.columns]) + + if not keep_cols: + raise ValueError("No requested columns found in the CSV header.") + + sub = df[keep_cols].copy() + + # Drop the units row (first row often contains units like "%", "inst", etc.) + if len(sub) > 0: + first_row_str = sub.iloc[0].astype(str).str.lower() + unit_tokens = ("%", "inst", "cycle", "block", "register", "register/thread") + if first_row_str.apply(lambda x: any(tok in x for tok in unit_tokens)).any(): + sub = sub.iloc[1:].reset_index(drop=True) + + # Coerce metrics to numeric + if coerce_numeric: + metric_in_sub = [c for c in metric_cols if c in sub.columns] + sub[metric_in_sub] = ( + sub[metric_in_sub] + .replace({",": "", "%": ""}, regex=True) + .apply(pd.to_numeric, errors="coerce") + ) + + # Filter by kernel name list if provided + if name_list: + sub = _filter_by_kernel_names(sub, name_list, policy, keep_cols) + else: + # Apply selection to all rows if no name filter + sub = _apply_selection_policy(sub, policy) + + return sub + + +def metrics_to_prompt( + df: pd.DataFrame, + title: str = "GPU Profiling Metrics:", + key_by: str = "Kernel Name", + round_digits: Optional[int] = 3, + compact: bool = False, + keep_cols: Optional[List[str]] = None, +) -> str: + """ + Convert NCU metrics DataFrame to JSON string for LLM prompts. + + Returns JSON in format: + { + "": { "": , ... } + } + Args: + df: DataFrame with NCU metrics + title: Title for the metrics section (not included in output) + key_by: Column to use as key (usually "Kernel Name") + round_digits: Number of decimal places for rounding + compact: If True, use compact JSON (no indentation) + keep_cols: Specific columns to include in output + + Returns: + JSON string with metrics + """ + + def _safe(v: Any) -> Any: + """Convert values to JSON-safe format.""" + if v is None or (isinstance(v, float) and math.isnan(v)): + return None + if isinstance(v, (pd.Timestamp, pd.Timedelta, pd.Interval)): + return str(v) + if isinstance(v, np.generic): + v = v.item() + if isinstance(v, float) and math.isinf(v): + return "inf" if v > 0 else "-inf" + if isinstance(v, float) and round_digits is not None: + return round(v, round_digits) + return v + + # Empty table + if df is None or df.empty: + return "{}" + + cols = list(df.columns) + + # Round numeric columns + if round_digits is not None: + num_cols = df.select_dtypes(include="number").columns + if len(num_cols) > 0: + df = df.copy() + df[num_cols] = df[num_cols].round(round_digits) + + # If key column is missing, return a list of rows + if key_by not in cols: + rows = [ + {k: _safe(v) for k, v in rec.items()} + for rec in df.to_dict(orient="records") + ] + return json.dumps(rows, ensure_ascii=False, indent=None if compact else 2) + + # Determine value columns + value_cols = [c for c in cols if c != key_by] + if keep_cols is not None: + value_cols = [c for c in value_cols if c in keep_cols] + + data: Dict[str, Any] = {} + for rec in df[[key_by] + value_cols].to_dict(orient="records"): + k = str(rec.pop(key_by)) + val_obj = {ck: _safe(cv) for ck, cv in rec.items()} + if k in data: + # Multiple rows for same key - convert to list + if isinstance(data[k], list): + data[k].append(val_obj) + else: + data[k] = [data[k], val_obj] + else: + data[k] = val_obj + + return json.dumps(data, ensure_ascii=False, indent=None if compact else 2) + + +if __name__ == "__main__": + print("ncu_profiler module loaded.") + print("Import its functions in your scripts:") + print( + " from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import profile_triton_kernel, load_ncu_metrics, metrics_to_prompt" + ) diff --git a/pyproject.toml b/pyproject.toml index 2d5382f..a31dc3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ pipeline-ui = "scripts.pipeline_ui:main" "Repository" = "https://github.com/pytorch-labs/KernelAgent" [tool.setuptools.packages.find] -include = ["triton_kernel_agent*", "Fuser*", "scripts", "utils"] +include = ["triton_kernel_agent*", "kernel_perf_agent*", "Fuser*", "scripts", "utils"] [tool.setuptools] include-package-data = true diff --git a/triton_kernel_agent/opt_worker_component/README.md b/triton_kernel_agent/opt_worker_component/README.md new file mode 100644 index 0000000..76026b2 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/README.md @@ -0,0 +1,12 @@ +# Opt Worker Components + +High-level components used by `OptimizationWorker`. + +These components are **thin wrappers** around low-level utilities +from `kernel_perf_agent` that provide: +- Logging integration +- Error handling +- Worker-specific configuration + +## Dependency Flow +worker_components → kernel_perf_agent (implementation) diff --git a/triton_kernel_agent/opt_worker_component/profiling/__init__.py b/triton_kernel_agent/opt_worker_component/profiling/__init__.py new file mode 100644 index 0000000..8f8d18b --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiling infrastructure for NCU-based kernel analysis.""" + +from .kernel_profiler import KernelProfiler +from .ncu_wrapper_factory import NCUWrapperFactory + +__all__ = ["NCUWrapperFactory", "KernelProfiler"] diff --git a/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py new file mode 100644 index 0000000..682c4e7 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/kernel_profiler.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Profiles Triton kernels using NVIDIA Nsight Compute (NCU).""" + +import json +import logging +import subprocess +import time +from dataclasses import asdict, dataclass +from datetime import datetime +from functools import cached_property +from pathlib import Path +from typing import Any, Dict + +import pandas as pd + +from kernel_perf_agent.kernel_opt.profiler.ncu_profiler import ( + load_ncu_metrics, + metrics_to_prompt, + profile_triton_kernel, +) + +from triton_kernel_agent.opt_worker_component.profiling.ncu_wrapper_factory import ( + NCUWrapperFactory, +) + +# Default timeout for NCU profiling in seconds +DEFAULT_NCU_TIMEOUT_SECONDS = 120 + + +@dataclass +class ProfilerMetadata: + """Metadata about a profiling run.""" + + kernel_file: str + problem_file: str + round_num: int + timestamp: str + ncu_version: str | None + + +@dataclass +class ProfilerResults: + """ + Results from a kernel profiling run. + + This dataclass encapsulates both the metrics DataFrame and the parsed + metrics dictionary, along with metadata about the profiling run. + """ + + metrics_df: pd.DataFrame + metrics: Dict[str, Any] + metadata: ProfilerMetadata + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary for JSON serialization.""" + return { + "metrics": self.metrics, + "metadata": asdict(self.metadata), + } + + def to_json(self) -> str: + """Convert results to a JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + +class KernelProfiler: + def __init__( + self, + logger: logging.Logger, + artifacts_dir: Path, + logs_dir: Path, + ncu_bin_path: str | None = None, + ncu_timeout_seconds: int = DEFAULT_NCU_TIMEOUT_SECONDS, + ): + """ + Initialize the kernel profiler. + + Args: + logger: Logger instance + artifacts_dir: Directory for optimization artifacts + logs_dir: Directory for saving profiling logs + ncu_bin_path: Path to NCU binary (auto-detect if None) + ncu_timeout_seconds: Timeout for NCU profiling in seconds + """ + self.logger = logger + self.artifacts_dir = artifacts_dir + self.logs_dir = logs_dir + self.ncu_bin_path = ncu_bin_path + self.ncu_timeout_seconds = ncu_timeout_seconds + self.wrapper_factory = NCUWrapperFactory(logger) + + @cached_property + def ncu_version(self) -> str | None: + """ + NCU version string (cached). + + Returns: + Version string like "2025.2.1.0" or None if unavailable + """ + try: + ncu_cmd = self.ncu_bin_path or "ncu" + result = subprocess.run( + [ncu_cmd, "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Extract version from output + # Example: "Version 2025.2.1.0 (build 35987062) (public-release)" + for line in result.stdout.strip().split("\n"): + if line.startswith("Version "): + return line.split()[1] + return None + return None + except Exception as e: + self.logger.debug(f"Failed to get NCU version: {e}") + return None + + def _wait_with_backoff(self, attempt: int) -> None: + """ + Wait with exponential backoff before retrying. + + Args: + attempt: Current attempt number (1-indexed) + """ + wait_time = 2**attempt # Exponential backoff: 2, 4, 8 seconds + self.logger.warning(f"Retrying in {wait_time}s...") + time.sleep(wait_time) + + def profile_kernel( + self, + kernel_file: Path, + problem_file: Path, + round_num: int, + max_retries: int = 2, + ) -> ProfilerResults | None: + """ + Profile kernel with NCU (with retry logic). + + NCU profiling can fail due to GPU contention or transient issues. + This method automatically retries with exponential backoff. + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + round_num: Current optimization round number (used for file naming + and tracking which optimization iteration this profiling belongs to) + max_retries: Maximum number of retry attempts (default: 2) + + Returns: + ProfilerResults containing metrics and metadata, or None on failure + + Example: + >>> profiler = KernelProfiler(logger, artifacts_dir, logs_dir) + >>> results = profiler.profile_kernel( + ... Path("kernel.py"), Path("problem.py"), round_num=1 + ... ) + >>> if results: + ... print(f"DRAM throughput: {results.metrics['dram__throughput']}") + """ + # Create NCU wrapper script + wrapper_file = self.wrapper_factory.create_ncu_wrapper( + kernel_file, problem_file, self.artifacts_dir + ) + + for attempt in range(1, max_retries + 1): + try: + self.logger.info( + f"[Round {round_num}] NCU profiling attempt {attempt}/{max_retries}..." + ) + + # Profile with NCU + csv_file = f"ncu_round_{round_num}.csv" + csv_path = profile_triton_kernel( + benchmark_script=wrapper_file, + workdir=self.artifacts_dir, + out_csv=csv_file, + ncu_bin=self.ncu_bin_path, + launch_count=20, + timeout=self.ncu_timeout_seconds, + ) + + # Load and parse metrics + metrics_df = load_ncu_metrics(csv_path, select="last") + metrics = json.loads(metrics_to_prompt(metrics_df)) + + # Build ProfilerResults + results = ProfilerResults( + metrics_df=metrics_df, + metrics=metrics, + metadata=ProfilerMetadata( + kernel_file=str(kernel_file), + problem_file=str(problem_file), + round_num=round_num, + timestamp=datetime.utcnow().isoformat() + "Z", + ncu_version=self.ncu_version, + ), + ) + + # Save metrics with metadata + self._save_profiler_results(results) + + self.logger.info(f"✅ NCU profiling completed for round {round_num}") + return results + + except FileNotFoundError as e: + self.logger.error(f"❌ File not found during profiling: {e}") + return None + + except subprocess.TimeoutExpired: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(final attempt {attempt}/{max_retries})" + ) + return None + else: + self.logger.debug( + f"NCU profiling timed out after {self.ncu_timeout_seconds}s " + f"(attempt {attempt}/{max_retries})" + ) + self._wait_with_backoff(attempt) + continue + + except json.JSONDecodeError as e: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Failed to parse NCU metrics (final attempt): {e}" + ) + return None + else: + self.logger.debug( + f"Failed to parse NCU metrics (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) + continue + + except Exception as e: + is_final_attempt = attempt >= max_retries + if is_final_attempt: + self.logger.error( + f"❌ Unexpected error during profiling (final attempt): {e}", + exc_info=True, + ) + return None + else: + self.logger.debug( + f"Unexpected error during profiling (attempt {attempt}/{max_retries}): {e}" + ) + self._wait_with_backoff(attempt) + continue + + # All retries exhausted + self.logger.error( + f"❌ NCU profiling failed after {max_retries} attempts for round {round_num}" + ) + return None + + def _save_profiler_results(self, results: ProfilerResults) -> None: + """ + Save profiling results with metadata to a JSON file. + + Args: + results: ProfilerResults to save + """ + metrics_file = ( + self.logs_dir / f"round{results.metadata.round_num:03d}_ncu_metrics.json" + ) + + with open(metrics_file, "w") as f: + f.write(results.to_json()) + + self.logger.debug(f"Saved metrics with metadata: {metrics_file}") diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py new file mode 100644 index 0000000..e3c5a3c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_factory.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NCU wrapper script generation for kernel profiling.""" + +import logging +from functools import cached_property +from pathlib import Path + +from jinja2 import Template + + +class NCUWrapperFactory: + """Factory for creating NCU wrapper scripts for profiling Triton kernels.""" + + # Template file path (relative to this file) + WRAPPER_TEMPLATE = Path(__file__).parent / "ncu_wrapper_template.j2" + + def __init__(self, logger: logging.Logger): + """ + Initialize the NCU wrapper factory. + + Args: + logger: Logger instance + """ + self.logger = logger + + @cached_property + def template(self) -> Template: + """ + Jinja2 template for wrapper script generation. + + Raises: + FileNotFoundError: If template file doesn't exist + """ + if not self.WRAPPER_TEMPLATE.exists(): + raise FileNotFoundError(f"Template not found: {self.WRAPPER_TEMPLATE}") + return Template(self.WRAPPER_TEMPLATE.read_text()) + + def create_ncu_wrapper( + self, + kernel_file: Path, + problem_file: Path, + output_dir: Path, + dtype_inference: bool = True, + model_extraction: bool = True, + ) -> Path: + """ + Create NCU wrapper script for profiling. + + The wrapper handles multiple kernel types: + - Standard kernels: kernel_function(*inputs) + - Conv/Linear kernels: Extracts weights from Model + - RMSNorm kernels: Passes init_inputs (features, eps) + + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + output_dir: Directory to write wrapper script + dtype_inference: Enable automatic dtype inference from kernel source (default: True) + model_extraction: Enable model weight extraction for Conv/Linear kernels (default: True) + + Returns: + Path to created wrapper script + + Raises: + FileNotFoundError: If kernel_file or problem_file doesn't exist + OSError: If output_dir is not writable + """ + # Validate inputs + if not kernel_file.exists(): + raise FileNotFoundError(f"Kernel file not found: {kernel_file}") + if not problem_file.exists(): + raise FileNotFoundError(f"Problem file not found: {problem_file}") + + # Ensure output directory exists + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + wrapper_file = output_dir / "ncu_wrapper.py" + + # Render template + wrapper_content = self.template.render( + kernel_file_parent=repr(str(kernel_file.parent)), + problem_file_parent=repr(str(problem_file.parent)), + kernel_module=kernel_file.stem, + problem_module=problem_file.stem, + dtype_inference=dtype_inference, + model_extraction=model_extraction, + ) + + # Write wrapper file + wrapper_file.write_text(wrapper_content) + self.logger.info(f"Created NCU wrapper: {wrapper_file}") + return wrapper_file diff --git a/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 new file mode 100644 index 0000000..310d8bb --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/profiling/ncu_wrapper_template.j2 @@ -0,0 +1,128 @@ +"""NCU profiling wrapper.""" +import sys +import torch +import inspect +sys.path.insert(0, str({{ kernel_file_parent }})) +sys.path.insert(0, str({{ problem_file_parent }})) + +from {{ kernel_module }} import kernel_function +from {{ problem_module }} import get_inputs, get_init_inputs + +# Try to import Model if it exists (for Conv, Linear, etc.) +try: + from {{ problem_module }} import Model + has_model = True +except ImportError: + has_model = False + +# Get inputs +inputs = get_inputs() + +# Get additional initialization inputs (e.g., features, eps for RMSNorm) +init_inputs = get_init_inputs() + +{% if dtype_inference %} +# Infer required dtype from kernel function signature/docstring +required_dtype = None +try: + # Try to get dtype from kernel function docstring or source + kernel_source = inspect.getsource(kernel_function) + if 'bfloat16' in kernel_source.lower(): + required_dtype = torch.bfloat16 + elif 'float16' in kernel_source.lower() or 'half' in kernel_source.lower(): + required_dtype = torch.float16 + elif 'float32' in kernel_source.lower(): + required_dtype = torch.float32 +except Exception: + pass +{% else %} +# Dtype inference disabled +required_dtype = None +{% endif %} + +# Prepare inputs: move to CUDA and convert dtype if needed +# IMPORTANT: Only convert floating-point tensors; preserve integer tensors (e.g., class labels) +cuda_inputs = [] +for inp in inputs: + if isinstance(inp, torch.Tensor): + # Move to CUDA if not already + if not inp.is_cuda: + inp = inp.cuda() + # Convert dtype if required, but ONLY for floating-point tensors + # Preserve integer/bool tensors (e.g., targets for classification) + if required_dtype is not None and inp.is_floating_point() and inp.dtype != required_dtype: + inp = inp.to(required_dtype) + cuda_inputs.append(inp) + else: + cuda_inputs.append(inp) + +{% if model_extraction %} +# Check if this is a conv-like kernel that needs a Model to extract weights +needs_model = False +try: + sig = inspect.signature(kernel_function) + params = list(sig.parameters.keys()) + # Check if kernel expects 'weight' parameter (common for Conv, Linear, etc.) + if 'weight' in params: + needs_model = True +except Exception: + pass + +if needs_model and has_model and init_inputs: + # Initialize model to extract weight and bias + try: + model = Model(*init_inputs) if init_inputs else Model() + except TypeError: + # Model initialization failed, use fallback + model = Model() + + # Move model to CUDA and convert dtype + model = model.cuda() + if required_dtype is not None: + model = model.to(required_dtype) + + # Extract weight and bias from model + # Check various possible attribute names + weight = None + bias = None + layer = None + for attr_name in ['conv1', 'conv2', 'conv3', 'conv1d', 'conv2d', 'conv', 'conv3d', 'linear', 'fc']: + if hasattr(model, attr_name): + layer = getattr(model, attr_name) + if hasattr(layer, 'weight'): + weight = layer.weight + bias = layer.bias if hasattr(layer, 'bias') else None + break + + if weight is not None and layer is not None: + # Build arguments for kernel_function using keyword arguments + # to avoid positional argument misalignment issues + kernel_kwargs = {} # Empty dict + + # Add conv/linear-specific parameters if they exist + if hasattr(layer, 'stride'): + stride = layer.stride[0] if isinstance(layer.stride, (tuple, list)) else layer.stride + kernel_kwargs['stride'] = stride + if hasattr(layer, 'padding'): + padding = layer.padding[0] if isinstance(layer.padding, (tuple, list)) else layer.padding + kernel_kwargs['padding'] = padding + if hasattr(layer, 'dilation'): + dilation = layer.dilation[0] if isinstance(layer.dilation, (tuple, list)) else layer.dilation + kernel_kwargs['dilation'] = dilation + if hasattr(layer, 'groups'): + kernel_kwargs['groups'] = layer.groups + + # Call kernel with extracted parameters + output = kernel_function(cuda_inputs[0], weight, bias, **kernel_kwargs) + else: + # Fallback to original behavior + output = kernel_function(*cuda_inputs, *init_inputs) +else: +{% endif %} + # Run kernel with both tensor inputs and initialization inputs + # For example: RMSNorm needs kernel_function(x, features, eps) + # For cross-entropy: kernel_function(predictions, targets) + # where inputs come from get_inputs() and init_inputs from get_init_inputs() + output = kernel_function(*cuda_inputs, *init_inputs) + +print("Kernel executed successfully, output shape: " + str(output.shape if hasattr(output, 'shape') else type(output)))