diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py new file mode 100644 index 0000000..61e1d08 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarking infrastructure for kernel performance measurement.""" + +from .benchmark import Benchmark + +__all__ = ["Benchmark"] diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py new file mode 100644 index 0000000..8ee39e4 --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/benchmark.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified benchmarking for Triton kernels and PyTorch baselines. + +This module consolidates kernel and PyTorch benchmarking with improved timing +utilities, L2 cache clearing, and comprehensive statistics. +""" + +import json +import logging +import subprocess +import sys +import traceback +from pathlib import Path +from typing import Any, Optional + +import torch + +from triton_kernel_agent.opt_worker_component.benchmarking.timing import ( + compute_timing_stats, + prepare_pytorch_model, + time_with_cuda_events, + time_with_triton_do_bench, +) + + +class BenchmarkLockManager: + """Manages GPU benchmarking locks to prevent resource contention.""" + + def __init__(self, lock: Any, worker_id: int, logger: logging.Logger): + """Initialize the lock manager. 
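+
+        The manager is intended to be used as a context manager. A minimal
+        sketch (``lock`` is assumed to be a shared ``multiprocessing.Lock``
+        and ``logger`` a standard ``logging.Logger``)::
+
+            with BenchmarkLockManager(lock, worker_id=0, logger=logger):
+                ...  # GPU benchmarking runs while the lock is held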
+ + Args: + lock: Shared multiprocessing lock for serializing GPU access + worker_id: Worker ID for logging + logger: Logger instance + """ + self.lock = lock + self.worker_id = worker_id + self.logger = logger + + def __enter__(self): + """Acquire the benchmarking lock.""" + self.logger.info(f"⏳ Waiting for benchmark lock (worker {self.worker_id})...") + self.lock.acquire() + self.logger.info(f"🔓 Acquired benchmark lock (worker {self.worker_id})") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release the benchmarking lock.""" + try: + self.lock.release() + self.logger.info(f"🔒 Released benchmark lock (worker {self.worker_id})") + except Exception as e: + self.logger.warning(f"Failed to release benchmark lock: {e}") + return False + + +class Benchmark: + """Unified benchmark for Triton kernels and PyTorch baselines. + + Supports two modes: + 1. Subprocess mode: Runs benchmarks in isolated processes (for compatibility) + 2. Direct mode: Uses in-process timing utilities (faster, more flexible) + """ + + def __init__( + self, + logger: logging.Logger, + artifacts_dir: Path, + benchmark_lock: Any, + worker_id: int = 0, + warmup: int = 25, + repeat: int = 100, + timing_method: str = "cuda_event", + ): + """Initialize the benchmark. + + Args: + logger: Logger instance + artifacts_dir: Directory for benchmark artifacts + benchmark_lock: Shared lock to serialize GPU benchmarking + worker_id: Worker ID + warmup: Number of warmup iterations (or warmup time in ms for do_bench) + repeat: Number of repeat iterations (or rep time in ms for do_bench) + timing_method: Timing method ("cuda_event", "do_bench", "host_time") + """ + self.logger = logger + self.artifacts_dir = artifacts_dir + self.lock_manager = BenchmarkLockManager(benchmark_lock, worker_id, logger) + self.warmup = warmup + self.repeat = repeat + self.timing_method = timing_method + + def benchmark_kernel( + self, + kernel_file: Path, + problem_file: Path, + baseline_file: Optional[Path] = None, + ) -> dict[str, Any]: + """Benchmark Triton kernel performance using subprocess isolation. + + Uses subprocess for crash protection of potentially buggy kernels. 
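+
+        A minimal usage sketch (the paths, ``logger``, and ``benchmark_lock``
+        below are placeholders, not defined by this module)::
+
+            bench = Benchmark(logger, Path("artifacts"), benchmark_lock)
+            result = bench.benchmark_kernel(
+                kernel_file=Path("kernel.py"),
+                problem_file=Path("problem.py"),
+            )
+            print(result["time_ms"], result["speedup"])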
+ + Args: + kernel_file: Path to kernel file + problem_file: Path to problem file + baseline_file: Path to baseline kernel (optional) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - speedup: Speedup vs baseline + """ + try: + with self.lock_manager: + results_json = self.artifacts_dir / "benchmark_results.json" + benchmark_script = Path(__file__).parent / "kernel_subprocess.py" + + cmd = [ + sys.executable, + str(benchmark_script), + "--problem", + str(problem_file), + "--kernel", + str(kernel_file), + "--warmup", + str(self.warmup), + "--repeat", + str(self.repeat), + "--json", + str(results_json), + "--quiet", + ] + + if baseline_file: + cmd.extend(["--baseline"]) + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, + ) + + if result.returncode != 0: + error_msg = ( + result.stderr.strip() + or result.stdout.strip() + or "Unknown error" + ) + self.logger.error(f"Kernel benchmark failed: {error_msg}") + return {"time_ms": float("inf"), "speedup": 0.0} + + with open(results_json, "r") as f: + results = json.load(f) + + kernel_name = kernel_file.stem + kernel_results = results.get("kernels", {}).get(kernel_name, {}) + + return { + "time_ms": kernel_results.get("time_ms", float("inf")), + "speedup": kernel_results.get("speedup", 1.0), + } + + except Exception as e: + self.logger.error(f"Kernel benchmark failed: {e}") + return {"time_ms": float("inf"), "speedup": 0.0} + + def benchmark_pytorch( + self, + problem_file: Path, + dtype: Optional[torch.dtype] = None, + ) -> dict[str, Any]: + """Benchmark PyTorch baseline using direct in-process timing. + + Always uses direct mode (PyTorch is stable, doesn't need subprocess isolation). + + Args: + problem_file: Path to problem file (must define Model class and get_inputs()) + dtype: Data type to use (default: auto-detect based on model parameters) + + Returns: + Dictionary with benchmark results: + - time_ms: Mean time in ms + - stats: Full timing statistics (mean, std, min, max, all_times, etc.) + """ + try: + with self.lock_manager: + model, inputs = prepare_pytorch_model( + problem_file=problem_file, + device="cuda", + dtype=dtype, + ) + + if self.timing_method == "do_bench": + times = time_with_triton_do_bench( + lambda: model(*inputs), + [], + warmup=self.warmup, + rep=self.repeat, + verbose=False, + ) + else: # cuda_event + times = time_with_cuda_events( + lambda: model(*inputs), + [], + num_warmup=self.warmup, + num_trials=self.repeat, + clear_cache=True, + verbose=False, + ) + + stats = compute_timing_stats(times) + + return { + "time_ms": stats["mean"], + "stats": stats, + } + + except Exception as e: + self.logger.error(f"PyTorch baseline benchmark failed: {e}") + self.logger.error(traceback.format_exc()) + return {"time_ms": float("inf")} diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py new file mode 100644 index 0000000..733216c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/kernel_subprocess.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A task-agnostic, profiling-only benchmark script for Triton kernels. +This script ONLY benchmarks candidate kernels without correctness checks. +Assumes correctness has been verified upstream. + +Design: +- Skips correctness verification (assumes already verified) +- Only runs candidate kernels +- Fast profiling for iterative optimization loops +- Uses shared utilities from timing.py +""" + +from __future__ import annotations + +import argparse +import inspect +import json +import sys +from pathlib import Path + +from timing import ( + import_module, + load_kernel_function, + load_problem_interface, + prepare_inputs, +) +from typing import Any, Callable, Tuple + +import torch +import triton.testing as tt + + +def _extract_model_params( + model: torch.nn.Module, +) -> tuple[torch.Tensor | None, torch.Tensor | None, dict[str, Any]]: + """Extract weight, bias, and layer parameters from a PyTorch model. + + Searches for Conv or Linear layers and extracts their parameters. + + Args: + model: PyTorch model to extract parameters from + + Returns: + Tuple of (weight, bias, layer_kwargs) where layer_kwargs contains + stride, padding, dilation, groups if applicable + """ + for _, module in model.named_modules(): + if isinstance( + module, + (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear), + ): + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight + bias = getattr(module, "bias", None) + + layer_kwargs: dict[str, Any] = {} + if hasattr(module, "stride"): + stride = module.stride + layer_kwargs["stride"] = ( + stride[0] if isinstance(stride, (tuple, list)) else stride + ) + if hasattr(module, "padding"): + padding = module.padding + layer_kwargs["padding"] = ( + padding[0] if isinstance(padding, (tuple, list)) else padding + ) + if hasattr(module, "dilation"): + dilation = module.dilation + layer_kwargs["dilation"] = ( + dilation[0] if isinstance(dilation, (tuple, list)) else dilation + ) + if hasattr(module, "groups"): + layer_kwargs["groups"] = module.groups + + return weight, bias, layer_kwargs + + return None, None, {} + + +def _run_once( + fn: Callable, inputs: Tuple[torch.Tensor, ...], init_inputs: list, name: str +) -> torch.Tensor: + """Run kernel once to verify execution and get output shape/dtype.""" + try: + with torch.inference_mode(): + return fn(*inputs, *init_inputs) + except Exception as exc: + raise RuntimeError(f"{name} failed to execute: {exc}") from exc + + +def _benchmark( + fn: Callable, + inputs: Tuple[torch.Tensor, ...], + init_inputs: list, + name: str, + warmup: int = 25, + rep: int = 100, +) -> float: + """Benchmark a kernel function using triton.testing.do_bench.""" + try: + ms = tt.do_bench( + lambda: fn(*inputs, *init_inputs), + warmup=warmup, + rep=rep, + return_mode="mean", + ) + print(f"{name}: {ms:.4f} ms (mean over {rep} runs)") + return ms + except Exception as exc: + print(f"❌ {name}: Benchmark failed: {exc}") + return float("inf") + + +def _parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Task-agnostic Triton kernel 
benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + parser.add_argument( + "--problem", + type=Path, + required=True, + help="Path to problem file (must define Model and get_inputs)", + ) + parser.add_argument( + "--kernel", + type=Path, + required=True, + help="Path to kernel file (must define kernel_function)", + ) + parser.add_argument( + "--baseline", + action="store_true", + help="Include PyTorch reference model in benchmark", + ) + parser.add_argument("--warmup", type=int, default=25) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--size", type=int, default=4096, help="Problem size N") + parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["float32", "float16", "bfloat16"], + ) + parser.add_argument("--json", type=Path, help="Save results to JSON file") + parser.add_argument("--quiet", action="store_true") + + args = parser.parse_args() + args.problem = args.problem.resolve() + args.kernel = args.kernel.resolve() + return args + + +def _load_problem( + args: argparse.Namespace, + device: torch.device, + dtype: torch.dtype, +) -> tuple[type, tuple, list, torch.nn.Module | None]: + """Load problem interface, prepare inputs, and optionally create baseline model. + + Returns: + Tuple of (Model class, inputs, init_inputs, baseline_model or None) + """ + Model, get_inputs, get_init_inputs = load_problem_interface(args.problem) + + # Check for optional benchmark config override + try: + problem_mod = import_module(args.problem, "problem") + get_benchmark_config = getattr(problem_mod, "get_benchmark_config", None) + if get_benchmark_config is not None: + config = get_benchmark_config() + args.warmup = config.get("warmup", args.warmup) + args.repeat = config.get("repeat", args.repeat) + if not args.quiet: + print( + f"Using problem-specific config: " + f"warmup={args.warmup}, repeat={args.repeat}" + ) + except Exception: + pass + + inputs = prepare_inputs(get_inputs, device=device, dtype=dtype) + + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + # Create baseline model if requested + baseline_model = None + if args.baseline: + baseline_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + baseline_model.eval() + out = _run_once(baseline_model, inputs, [], "Reference") + if not args.quiet: + print(f"Reference output shape: {out.shape}, dtype: {out.dtype}") + print() + + return Model, inputs, init_inputs, baseline_model + + +def _prepare_kernel( + kernel_file: Path, + Model: type, + baseline_model: torch.nn.Module | None, + init_inputs: list, + device: torch.device, + dtype: torch.dtype, + quiet: bool = False, +) -> tuple[Callable, tuple, list]: + """Load kernel and wrap it with model parameters if needed. 
+ + Returns: + Tuple of (kernel_function, kernel_args, kernel_init_args) + """ + kernel_function = load_kernel_function(kernel_file) + + # Check if kernel expects weight/bias parameters + needs_model = False + try: + sig = inspect.signature(kernel_function) + if "weight" in sig.parameters: + needs_model = True + except Exception: + pass + + kernel_init_args = init_inputs + + if needs_model and Model is not None: + try: + # Reuse baseline model if available + extract_model = baseline_model + if extract_model is None: + extract_model = ( + Model(*init_inputs).to(device=device, dtype=dtype) + if init_inputs + else Model().to(device=device, dtype=dtype) + ) + + weight, bias, kernel_kwargs = _extract_model_params(extract_model) + + if weight is not None: + original_fn = kernel_function + + def kernel_with_model(*args, **kwargs): + return original_fn(args[0], weight, bias, **kernel_kwargs) + + kernel_function = kernel_with_model + kernel_init_args = [] + except Exception as exc: + if not quiet: + print(f"⚠️ Warning: Failed to extract model parameters: {exc}") + print(" Falling back to direct kernel invocation") + + return kernel_function, kernel_init_args + + +def _save_results(results: dict[str, Any], path: Path) -> None: + """Save benchmark results to JSON file.""" + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(results, f, indent=2) + print(f"Results saved to: {path}") + + +def main(): + args = _parse_args() + + device = torch.device(args.device) + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + dtype = dtype_map[args.dtype] + + if not args.quiet: + print("=" * 80) + print("TRITON KERNEL PROFILING") + print("=" * 80) + print(f"Problem: {args.problem.name}") + print(f"Size: {args.size}") + print(f"Device: {device}, Dtype: {dtype}") + print(f"Warmup: {args.warmup}, Repeat: {args.repeat}") + print() + + # Load problem and prepare inputs + try: + Model, inputs, init_inputs, baseline_model = _load_problem(args, device, dtype) + except Exception as exc: + print(f"❌ Failed to load problem: {exc}") + sys.exit(1) + + results: dict[str, Any] = { + "problem": str(args.problem), + "size": args.size, + "device": str(device), + "dtype": str(dtype), + "warmup": args.warmup, + "repeat": args.repeat, + "kernels": {}, + } + + # Benchmark baseline (if requested) + baseline_time = None + if baseline_model is not None: + if not args.quiet: + print("1. PyTorch Reference") + baseline_time = _benchmark( + baseline_model, inputs, [], "PyTorch", args.warmup, args.repeat + ) + results["kernels"]["pytorch_reference"] = { + "time_ms": baseline_time, + "speedup": 1.0, + } + if not args.quiet: + print() + + # Load and prepare kernel + kernel_name = args.kernel.stem + if not args.quiet: + idx = 2 if args.baseline else 1 + print(f"{idx}. 
Candidate: {kernel_name}") + + try: + kernel_fn, kernel_init_args = _prepare_kernel( + args.kernel, Model, baseline_model, init_inputs, device, dtype, args.quiet + ) + except Exception as exc: + print(f"❌ Failed to load kernel: {exc}") + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} + if args.json: + _save_results(results, args.json) + sys.exit(1) + + # Verify kernel executes + try: + out = _run_once(kernel_fn, inputs, kernel_init_args, kernel_name) + if not args.quiet: + print(f"✓ {kernel_name} executes successfully") + print(f" Output shape: {out.shape}, dtype: {out.dtype}") + except Exception as exc: + print(f"❌ {kernel_name} failed: {exc}") + results["kernels"][kernel_name] = {"time_ms": float("inf"), "error": str(exc)} + if args.json: + _save_results(results, args.json) + sys.exit(1) + + # Benchmark kernel + kernel_time = _benchmark( + kernel_fn, inputs, kernel_init_args, kernel_name, args.warmup, args.repeat + ) + results["kernels"][kernel_name] = {"time_ms": kernel_time, "path": str(args.kernel)} + + # Calculate speedup + if baseline_time is not None and kernel_time != float("inf"): + speedup = baseline_time / kernel_time + results["kernels"][kernel_name]["speedup"] = speedup + if not args.quiet: + print(f"Speedup vs PyTorch: {speedup:.2f}x") + + if args.json: + _save_results(results, args.json) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nBenchmark interrupted") + sys.exit(130) + except Exception as exc: + print(f"❌ Unexpected error: {exc}") + sys.exit(1) diff --git a/triton_kernel_agent/opt_worker_component/benchmarking/timing.py b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py new file mode 100644 index 0000000..0e0730c --- /dev/null +++ b/triton_kernel_agent/opt_worker_component/benchmarking/timing.py @@ -0,0 +1,459 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core timing and model loading utilities for kernel benchmarking. + +This module consolidates: +- Timing functions (CUDA events, do_bench, host timing) +- Model/kernel loading utilities +- Statistics computation + +Inspired by KernelBench's timing.py +""" + +import hashlib +import importlib.util +import sys +from pathlib import Path +from typing import Any, Callable, Optional, Tuple + +import numpy as np +import torch + + +# ============================================================================= +# Model and Kernel Loading Utilities +# ============================================================================= + + +class CompilationError(RuntimeError): + """Raised when a kernel or problem file fails to compile/import.""" + + pass + + +def import_module(path: Path, module_name: Optional[str] = None): + """Dynamically import a Python file. 
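+
+    A small illustrative example (the path is a placeholder)::
+
+        mod = import_module(Path("my_kernel.py"))
+        kernel_fn = getattr(mod, "kernel_function", None)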
+ + Args: + path: Path to the Python file + module_name: Optional name for the module (auto-generated if None) + + Returns: + The imported module + + Raises: + FileNotFoundError: If path doesn't exist + CompilationError: If import fails + """ + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if module_name is None: + module_name = f"mod_{hashlib.md5(str(path).encode()).hexdigest()}" + + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise CompilationError(f"Failed to create spec for {path}") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + try: + spec.loader.exec_module(module) + except Exception as exc: + raise CompilationError(f"Failed to import {path}: {exc}") from exc + + return module + + +def load_problem_interface( + problem_file: Path, +) -> Tuple[type, Callable, Optional[Callable]]: + """Load the standard problem interface from a problem file. + + Args: + problem_file: Path to problem file + + Returns: + Tuple of (Model class, get_inputs function, get_init_inputs function) + + Raises: + CompilationError: If problem file doesn't define required interface + """ + module = import_module(problem_file, "problem") + + Model = getattr(module, "Model", None) + get_inputs = getattr(module, "get_inputs", None) + get_init_inputs = getattr(module, "get_init_inputs", None) + + if Model is None: + raise CompilationError("Problem file must define 'Model' class") + if get_inputs is None: + raise CompilationError("Problem file must define 'get_inputs()' function") + + return Model, get_inputs, get_init_inputs + + +def prepare_inputs( + get_inputs: Callable, + device: torch.device | str = "cuda", + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[torch.Tensor, ...]: + """Prepare inputs by converting to target device and dtype. + + Args: + get_inputs: Function that returns inputs + device: Target device + dtype: Target dtype for floating-point tensors + + Returns: + Tuple of prepared inputs + """ + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Convert inputs to target device and dtype + # IMPORTANT: Only convert floating-point tensors; preserve integer/bool tensors + converted_inputs = [] + for inp in inputs: + if isinstance(inp, torch.Tensor): + inp = inp.to(device=device) + # Preserve integer/bool tensors (e.g., targets for classification) + if inp.is_floating_point(): + inp = inp.to(dtype=dtype) + converted_inputs.append(inp) + + return tuple(converted_inputs) + + +def prepare_pytorch_model( + problem_file: Path, + device: torch.device | str = "cuda", + dtype: Optional[torch.dtype] = None, +) -> Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]]: + """Prepare PyTorch model and inputs for benchmarking. + + This handles the full workflow: + 1. Load problem interface (Model, get_inputs, get_init_inputs) + 2. Initialize model with init inputs + 3. Move model to device + 4. 
Handle dtype conversion based on whether model has parameters + + Args: + problem_file: Path to problem file + device: Target device + dtype: Target dtype (auto-detected if None) + + Returns: + Tuple of (model, inputs) ready for benchmarking + """ + Model, get_inputs, get_init_inputs = load_problem_interface(problem_file) + + # Get initialization inputs (e.g., features, eps for RMSNorm) + init_inputs = get_init_inputs() if get_init_inputs is not None else [] + if not isinstance(init_inputs, (tuple, list)): + init_inputs = [init_inputs] + + model = Model(*init_inputs) if init_inputs else Model() + model = model.cuda() + has_parameters = any(p.numel() > 0 for p in model.parameters()) + + inputs = get_inputs() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + + # Default to bfloat16 unless explicitly specified or model is a loss function + target_dtype = dtype or torch.bfloat16 + is_loss_function = isinstance(model, torch.nn.modules.loss._Loss) + + if has_parameters or not is_loss_function: + # Models with parameters (Conv, Linear, etc.) OR compute operations (matmul, etc.) + # → use bfloat16 (or user-specified dtype) + if has_parameters: + model = model.to(target_dtype) + inputs = [ + ( + inp.cuda().to(target_dtype) + if isinstance(inp, torch.Tensor) and inp.is_floating_point() + else inp.cuda() + if isinstance(inp, torch.Tensor) + else inp + ) + for inp in inputs + ] + else: + # Loss functions (no parameters) → use float32 for compatibility + # PyTorch cross_entropy doesn't support bf16 on CUDA + processed_inputs = [] + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor): + if i == 0 and inp.is_floating_point(): + # First input (predictions) - convert to float32 for compatibility + processed_inputs.append(inp.cuda().to(torch.float32)) + else: + # Other inputs (like targets) - just move to CUDA, preserve dtype + processed_inputs.append(inp.cuda()) + else: + processed_inputs.append(inp) + inputs = processed_inputs + + return model, tuple(inputs) + + +def load_kernel_function(kernel_file: Path) -> Callable: + """Load kernel_function from a kernel file. + + Args: + kernel_file: Path to kernel file + + Returns: + The kernel_function callable + + Raises: + CompilationError: If kernel file doesn't define kernel_function + """ + module = import_module(kernel_file, "kernel") + + kernel_function = getattr(module, "kernel_function", None) + if kernel_function is None: + raise CompilationError( + f"Kernel file {kernel_file.name} must define 'kernel_function'" + ) + + return kernel_function + + +# ============================================================================= +# Timing Utilities +# ============================================================================= + + +def clear_l2_cache(device: torch.device | str = "cuda") -> None: + """Clear L2 cache by thrashing with a large tensor. + + This ensures we measure cold cache performance, which is more representative + of real-world scenarios where data isn't already cached. + + Reference: KernelBench timing.py + L2 cache sizes: A100=40MB, H100=50MB, H200=90MB, RTX4090=72MB, L40S=48MB + We overwrite >256MB to fully thrash L2 cache. 
+ + Args: + device: CUDA device to use + """ + # 32 * 1024 * 1024 * 8B = 256MB - enough to thrash most GPU L2 caches + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device=device) + dummy.fill_(42) # Write to tensor to ensure cache thrashing + del dummy + + +def time_with_cuda_events( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 3, + num_trials: int = 10, + clear_cache: bool = True, + discard_first: int = 0, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time a CUDA kernel using CUDA events for accurate device-side timing. + + This measures actual GPU execution time without host-side overhead. + Each trial clears L2 cache to measure cold-cache performance. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + num_trials: Number of timing trials + clear_cache: Whether to clear L2 cache between trials + discard_first: Number of initial trials to discard + verbose: Print per-trial timing info + device: CUDA device to use (None = current device) + + Returns: + List of elapsed times in milliseconds (length = num_trials) + """ + if device is None: + device = torch.cuda.current_device() + + with torch.cuda.device(device): + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + torch.cuda.synchronize(device=device) + + torch.cuda.empty_cache() + + if verbose: + print( + f"[Timing] Device: {torch.cuda.get_device_name(device)}, " + f"warmup={num_warmup}, trials={num_trials}" + ) + + elapsed_times: list[float] = [] + + # Timing trials + for trial in range(num_trials + discard_first): + torch.cuda.synchronize(device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if clear_cache: + clear_l2_cache(device=device) + + start_event.record() + kernel_fn(*args) + end_event.record() + + torch.cuda.synchronize(device=device) + elapsed_time_ms = start_event.elapsed_time(end_event) + + if trial >= discard_first: + if verbose: + print( + f" Trial {trial - discard_first + 1}: {elapsed_time_ms:.3f} ms" + ) + elapsed_times.append(elapsed_time_ms) + + return elapsed_times + + +def time_with_inductor_benchmarker( + kernel_fn: Callable, + args: list[Any], + num_warmup: int = 25, + verbose: bool = False, +) -> float: + """Time using PyTorch Inductor's benchmarker (simplest approach). + + This is a thin wrapper around torch._inductor.runtime.benchmarking.benchmarker, + which handles CUDA synchronization and timing internally. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + num_warmup: Number of warmup iterations + verbose: Print timing info + + Returns: + Elapsed time in milliseconds (single value, not a list) + + Note: + This uses a private PyTorch API (_inductor) which may change without notice. + """ + from torch._inductor.runtime.benchmarking import benchmarker + + # Warmup + for _ in range(num_warmup): + kernel_fn(*args) + + ms = benchmarker.benchmark_gpu(lambda: kernel_fn(*args)) + + if verbose: + print(f"[Timing] Inductor benchmarker: {ms:.4f} ms") + + return ms + + +def time_with_triton_do_bench( + kernel_fn: Callable, + args: list[Any], + warmup: int = 25, + rep: int = 100, + verbose: bool = False, + device: Optional[torch.device | str] = None, +) -> list[float]: + """Time using Triton's do_bench with adaptive trial count. + + Triton's do_bench automatically determines the number of trials based on + warmup/rep time budgets. 
This is convenient but gives less control. + + Args: + kernel_fn: Function to time + args: Arguments to pass to kernel_fn + warmup: Warmup time budget in milliseconds + rep: Repetition time budget in milliseconds + verbose: Print timing info + device: CUDA device to use + + Returns: + List of all trial times in milliseconds + """ + if device is None: + device = torch.cuda.current_device() + + import triton.testing as triton_testing + + with torch.cuda.device(device): + if verbose: + print( + f"[Timing] Using triton.do_bench on {torch.cuda.get_device_name(device)}" + ) + + def wrapped_fn(): + return kernel_fn(*args) + + times = triton_testing.do_bench( + fn=wrapped_fn, + warmup=warmup, + rep=rep, + grad_to_none=None, + quantiles=None, + return_mode="all", + ) + + return times + + +def compute_timing_stats( + elapsed_times: list[float], + device: Optional[torch.device | str] = None, +) -> dict[str, Any]: + """Compute essential timing statistics. + + Args: + elapsed_times: List of elapsed times in milliseconds + device: CUDA device (for recording hardware info) + + Returns: + Dictionary with timing statistics: + - mean: Mean time in ms + - std: Standard deviation in ms + - min: Minimum time in ms + - max: Maximum time in ms + - num_trials: Number of trials + - all_times: All trial times + - hardware: GPU name (if device provided) + """ + times_array = np.array(elapsed_times) + + stats = { + "mean": float(np.mean(times_array)), + "std": float(np.std(times_array)), + "min": float(np.min(times_array)), + "max": float(np.max(times_array)), + "num_trials": len(elapsed_times), + "all_times": [float(t) for t in elapsed_times], + } + + if device is not None: + stats["hardware"] = torch.cuda.get_device_name(device=device) + stats["device"] = str(device) + + return stats
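
A minimal end-to-end sketch of the direct (in-process) timing path provided by timing.py; the problem path and trial counts below are illustrative assumptions, not fixed by the module:

    from pathlib import Path

    from triton_kernel_agent.opt_worker_component.benchmarking.timing import (
        compute_timing_stats,
        prepare_pytorch_model,
        time_with_cuda_events,
    )

    # Load the problem's Model and inputs, then time the forward pass with
    # CUDA events, clearing L2 cache between trials for cold-cache numbers.
    model, inputs = prepare_pytorch_model(Path("problem.py"), device="cuda")
    times = time_with_cuda_events(
        lambda: model(*inputs),  # callable under test
        [],                      # no extra positional args
        num_warmup=10,
        num_trials=50,
        clear_cache=True,
    )
    stats = compute_timing_stats(times)
    print(f"mean={stats['mean']:.3f} ms over {stats['num_trials']} trials")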