diff --git a/graph_net/test/test_device_test.sh b/graph_net/test/test_device_test.sh new file mode 100755 index 000000000..10e0ab766 --- /dev/null +++ b/graph_net/test/test_device_test.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +AI4C_ROOT=$(python3 -c "import graph_net_bench; import os; print(os.path.dirname(os.path.dirname(graph_net_bench.__file__)))") +OUTPUT_PATH=/tmp/workspace_eval_device_diff_test +REFERENCE_DIR="$OUTPUT_PATH/reference" + +mkdir -p "$OUTPUT_PATH" +mkdir -p "$REFERENCE_DIR" + +MODEL_PATH="$AI4C_ROOT/samples/ultralytics/yolov3-tinyu" + +echo "==========================================" +echo "Step 1: Generate reference on device A (simulated)" +echo "==========================================" +python3 -m graph_net.torch.test_reference_device \ + --model-path "$MODEL_PATH" \ + --compiler nope \ + --device cuda \ + --warmup 1 \ + --trials 1 \ + --reference-dir "$REFERENCE_DIR" \ + 2>&1 | tee "$OUTPUT_PATH/reference.log" + +echo "" +echo "==========================================" +echo "Step 2: Compare on device B (simulated)" +echo "==========================================" +python3 -m graph_net.torch.test_target_device \ + --model-path "$MODEL_PATH" \ + --device cuda \ + --reference-dir "$REFERENCE_DIR" \ + 2>&1 | tee "$OUTPUT_PATH/target.log" + +echo "" +echo "==========================================" +echo "Test completed. Logs saved to: $OUTPUT_PATH" +echo "==========================================" \ No newline at end of file diff --git a/graph_net/torch/test_reference_device.py b/graph_net/torch/test_reference_device.py index e2b1a7ef7..bb80c1e8c 100644 --- a/graph_net/torch/test_reference_device.py +++ b/graph_net/torch/test_reference_device.py @@ -1,94 +1,33 @@ import argparse -import torch import os -from pathlib import Path -from contextlib import redirect_stdout, redirect_stderr -import json import sys -import traceback +import types +from pathlib import Path from graph_net_bench import path_utils -from graph_net_bench import test_compiler_util from graph_net import model_path_util -from graph_net_bench.torch import test_compiler - - -def get_reference_log_path(reference_dir, model_path): - model_name = model_path.split("samples/")[-1].replace(os.sep, "_") - return os.path.join(reference_dir, f"{model_name}.log") - - -def get_reference_output_path(reference_dir, model_path): - model_name = model_path.split("samples/")[-1].replace(os.sep, "_") - return os.path.join(reference_dir, f"{model_name}.pth") - - -def register_op_lib(op_lib): - if op_lib == "flaggems": - import flag_gems - - flag_gems.enable() - else: - pass +from graph_net_bench.torch import eval_backend_perf + + +def convert_args_for_eval_backend(args): + """Convert test_reference_device args to eval_backend_perf args format.""" + return types.SimpleNamespace( + model_path=args.model_path, + output_path=args.reference_dir, + seed=args.seed, + compiler=args.compiler, + device=args.device, + op_lib=args.op_lib, + warmup=args.warmup, + trials=args.trials, + log_prompt=args.log_prompt, + backend_config=getattr(args, "config", None), + ) def test_single_model(args): - ref_log = get_reference_log_path(args.reference_dir, args.model_path) - ref_dump = get_reference_output_path(args.reference_dir, args.model_path) - print(f"Reference log path: {ref_log}", file=sys.stderr, flush=True) - print(f"Reference outputs path: {ref_dump}", file=sys.stderr, flush=True) - - with open(ref_log, "w", encoding="utf-8") as log_f: - with redirect_stdout(log_f), redirect_stderr(log_f): - compiler = 
test_compiler.get_compiler_backend(args) - - input_dict = test_compiler.get_input_dict(args) - model = test_compiler.get_model(args) - model.eval() - - test_compiler_util.print_with_log_prompt( - "[Config] seed:", args.seed, args.log_prompt - ) - - test_compiler_util.print_basic_config( - args, - test_compiler.get_hardward_name(args), - test_compiler.get_compile_framework_version(args), - ) - - test_compiler_util.print_with_log_prompt( - "[Config] op_lib:", args.op_lib, args.log_prompt - ) - - success = False - time_stats = {} - try: - compiled_model = compiler(model) - - def model_call(): - return compiled_model(**input_dict) - - outputs, time_stats = test_compiler.measure_performance( - model_call, args, compiler - ) - success = True - except Exception as e: - print( - f"Run model failed: {str(e)}\n{traceback.format_exc()}", - file=sys.stderr, - flush=True, - ) - - test_compiler_util.print_running_status(args, success) - if success: - torch.save(outputs, str(ref_dump)) - test_compiler_util.print_with_log_prompt( - "[Performance][eager]:", json.dumps(time_stats), args.log_prompt - ) - - with open(ref_log, "r", encoding="utf-8") as f: - content = f.read() - print(content, file=sys.stderr, flush=True) + eval_args = convert_args_for_eval_backend(args) + eval_backend_perf.eval_single_model_with_single_backend(eval_args) def test_multi_models(args): @@ -136,20 +75,15 @@ def test_multi_models(args): def main(args): assert os.path.isdir(args.model_path) - # Support all torch compilers - valid_compilers = list(test_compiler.compiler_backend_name2class.keys()) - assert ( - args.compiler in valid_compilers - ), f"Compiler must be one of {valid_compilers}" - assert args.device in ["cuda"] + assert args.device in ["cuda", "cpu"] - test_compiler.set_seed(random_seed=args.seed) + eval_backend_perf.set_seed(args.seed) ref_dump_dir = Path(args.reference_dir) ref_dump_dir.mkdir(parents=True, exist_ok=True) if path_utils.is_single_model_dir(args.model_path): - register_op_lib(args.op_lib) + eval_backend_perf.register_op_lib(args.op_lib) test_single_model(args) else: test_multi_models(args) diff --git a/graph_net/torch/test_target_device.py b/graph_net/torch/test_target_device.py index 87fe22077..ee46ceee6 100644 --- a/graph_net/torch/test_target_device.py +++ b/graph_net/torch/test_target_device.py @@ -1,15 +1,13 @@ import argparse import os -import json import sys -import traceback +import types import torch from graph_net_bench import path_utils from graph_net_bench import test_compiler_util from graph_net import model_path_util -from graph_net_bench.torch import test_compiler -from graph_net.torch import test_reference_device +from graph_net_bench.torch import utils, eval_backend_perf, eval_backend_diff def parse_config_from_reference_log(log_path): @@ -31,94 +29,59 @@ def parse_config_from_reference_log(log_path): return config -def parse_time_stats_from_reference_log(log_path): - assert os.path.isfile( - log_path - ), f"{log_path} does not exist or is not a regular file." 
- - with open(log_path, "r", encoding="utf-8") as f: - lines = f.readlines() - for line in reversed(lines): - if "[Performance][eager]" in line: - start = line.find("{") - end = line.rfind("}") - time_stats = json.loads(line[start : end + 1]) - return time_stats +def get_ref_config_from_log(args, model_path): + """Extract config from reference log file.""" + ref_log = utils.get_log_path(args.reference_dir, model_path) + config = parse_config_from_reference_log(ref_log) + return config -def update_args_and_set_seed(args, model_path): - ref_log = test_reference_device.get_reference_log_path( - args.reference_dir, model_path +def convert_args_for_eval_backend(args, output_path): + """Convert test_target_device args to eval_backend_perf args format.""" + return types.SimpleNamespace( + model_path=args.model_path, + output_path=output_path, + seed=args.seed, + compiler=args.compiler, + device=args.device, + op_lib=args.op_lib, + warmup=args.warmup, + trials=args.trials, + log_prompt=args.log_prompt, + backend_config=getattr(args, "config", None), ) - config = parse_config_from_reference_log(ref_log) - vars(args)["model_path"] = model_path - vars(args)["compiler"] = config.get("compiler") - vars(args)["trials"] = int(config.get("trials")) - vars(args)["warmup"] = int(config.get("warmup")) - test_compiler.set_seed(random_seed=int(config.get("seed"))) - return args def test_single_model(args): - compiler = test_compiler.get_compiler_backend(args) + target_dir = "/tmp/eval_device_diff/target" - input_dict = test_compiler.get_input_dict(args) - model = test_compiler.get_model(args) - model.eval() + ref_config = get_ref_config_from_log(args, args.model_path) + vars(args)["compiler"] = ref_config.get("compiler") + vars(args)["trials"] = int(ref_config.get("trials")) + vars(args)["warmup"] = int(ref_config.get("warmup")) + vars(args)["seed"] = int(ref_config.get("seed")) - model_path = os.path.normpath(args.model_path) - test_compiler_util.print_with_log_prompt( - "[Processing]", model_path, args.log_prompt - ) - test_compiler_util.print_basic_config( - args, - test_compiler.get_hardward_name(args), - test_compiler.get_compile_framework_version(args), - ) + eval_args = convert_args_for_eval_backend(args, target_dir) + eval_backend_perf.eval_single_model_with_single_backend(eval_args) - success = False - time_stats = {} - try: - compiled_model = compiler(model) - - def model_call(): - return compiled_model(**input_dict) - - outputs, time_stats = test_compiler.measure_performance( - model_call, args, compiler - ) - success = True - except Exception as e: - print( - f"Run model failed: {str(e)}\n{traceback.format_exc()}", - file=sys.stderr, - flush=True, - ) - - test_compiler_util.print_running_status(args, success) - - model_name = test_compiler_util.get_model_name(args.model_path) - if test_compiler_util.get_subgraph_tag(args.model_path): - model_name += "_" + test_compiler_util.get_subgraph_tag(args.model_path) - - ref_dump = test_reference_device.get_reference_output_path( - args.reference_dir, args.model_path - ) + ref_dump = utils.get_output_path(args.reference_dir, args.model_path) ref_out = torch.load(str(ref_dump)) - - ref_log = test_reference_device.get_reference_log_path( - args.reference_dir, args.model_path + ref_log = utils.get_log_path(args.reference_dir, args.model_path) + ref_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(ref_log) + + target_dump = utils.get_output_path(target_dir, args.model_path) + target_out = torch.load(str(target_dump)) + target_log = 
utils.get_log_path(target_dir, args.model_path) + target_time_stats = eval_backend_diff.parse_time_stats_from_reference_log( + target_log ) - ref_time_stats = parse_time_stats_from_reference_log(ref_log) - - if success: - test_compiler.compare_correctness(ref_out, outputs, args) - test_compiler_util.print_times_and_speedup(args, ref_time_stats, time_stats) + eval_backend_diff.compare_correctness(ref_out, target_out, eval_args) + test_compiler_util.print_times_and_speedup(args, ref_time_stats, target_time_stats) def is_reference_log_exist(reference_dir, model_path): - log_path = test_reference_device.get_reference_log_path(reference_dir, model_path) + log_path = utils.get_log_path(reference_dir, model_path) return os.path.isfile(log_path) @@ -172,18 +135,16 @@ def main(args): if path_utils.is_single_model_dir(args.model_path): if args.op_lib == "origin": - ref_log = test_reference_device.get_reference_log_path( - args.reference_dir, args.model_path - ) - config = parse_config_from_reference_log(ref_log) - vars(args)["op_lib"] = config.get("op_lib") - test_compiler_util.print_with_log_prompt( - "[Config] op_lib:", args.op_lib, args.log_prompt + ref_config = get_ref_config_from_log(args, args.model_path) + vars(args)["op_lib"] = ref_config.get("op_lib") + print( + f"{args.log_prompt} [Config] op_lib: {args.op_lib}", + file=sys.stderr, + flush=True, ) else: - test_reference_device.register_op_lib(args.op_lib) + eval_backend_perf.register_op_lib(args.op_lib) - args = update_args_and_set_seed(args, args.model_path) test_single_model(args) else: test_multi_models(args) diff --git a/graph_net_bench/test_compiler_util.py b/graph_net_bench/test_compiler_util.py index f587da2ff..44ccc703e 100644 --- a/graph_net_bench/test_compiler_util.py +++ b/graph_net_bench/test_compiler_util.py @@ -5,6 +5,7 @@ import time import subprocess import shutil +import base64 import numpy as np from dataclasses import dataclass from contextlib import contextmanager @@ -156,6 +157,24 @@ def print_basic_config(args, hardware_name, compile_framework_version): ) +def print_config(args, hardware_name, compiler_version): + model_path = os.path.normpath(args.model_path) + model_name = get_model_name(model_path) + print_with_log_prompt("[Config] model:", model_name, args.log_prompt) + print_with_log_prompt("[Config] seed:", args.seed, args.log_prompt) + print_with_log_prompt("[Config] device:", args.device, args.log_prompt) + print_with_log_prompt("[Config] hardware:", hardware_name, args.log_prompt) + print_with_log_prompt("[Config] op_lib:", args.op_lib, args.log_prompt) + print_with_log_prompt("[Config] compiler:", args.compiler, args.log_prompt) + print_with_log_prompt("[Config] warmup:", args.warmup, args.log_prompt) + print_with_log_prompt("[Config] trials:", args.trials, args.log_prompt) + print_with_log_prompt( + "[Config] compile_framework_version:", + compiler_version, + args.log_prompt, + ) + + def print_running_status(args, eager_success, compiled_success=None): def convert_to_str(b): return "success" if b else "failed" @@ -353,3 +372,12 @@ def get_allow_samples(allow_list, model_path_prefix): test_samples.append(os.path.join(model_path_prefix, line.strip())) return test_samples + + +def convert_to_dict(config_str): + if config_str in {None, "", "null", "None"}: + return {} + config_str = base64.b64decode(config_str).decode("utf-8") + config = json.loads(config_str) + assert isinstance(config, dict), f"config should be a dict. 
{config_str=}" + return config diff --git a/graph_net_bench/torch/eval_backend_diff.py b/graph_net_bench/torch/eval_backend_diff.py index 8488b71b7..cfa171dc6 100755 --- a/graph_net_bench/torch/eval_backend_diff.py +++ b/graph_net_bench/torch/eval_backend_diff.py @@ -1,307 +1,15 @@ from . import utils -import subprocess import argparse -import importlib.util import torch -from pathlib import Path -from typing import Type import sys import os import os.path import traceback import json -import random -import numpy as np -import platform -import base64 -from graph_net_bench.torch.backend.graph_compiler_backend import GraphCompilerBackend -from graph_net_bench.torch.backend.tvm_backend import TvmBackend -from graph_net_bench.torch.backend.xla_backend import XlaBackend -from graph_net_bench.torch.backend.inductor_backend import InductorBackend -from graph_net_bench.torch.backend.tensorrt_backend import TensorRTBackend -from graph_net_bench.torch.backend.blade_disc_backend import BladeDISCBackend -from graph_net_bench.torch.backend.nope_backend import NopeBackend -from graph_net_bench.torch.backend.pass_mgr_backend import PassMgrBackend -from graph_net_bench.torch.backend.unstable_to_stable_backend import ( - UnstableToStableBackend, -) -from graph_net_bench.torch.backend.range_decomposer_validator_backend import ( - RangeDecomposerValidatorBackend, -) -from graph_net_bench.torch.backend.graph_variable_renamer_validator_backend import ( - GraphVariableRenamerValidatorBackend, -) +import types from graph_net_bench import test_compiler_util from graph_net_bench import path_utils - - -compiler_backend_name2class = { - "tvm": TvmBackend, - "xla": XlaBackend, - "inductor": InductorBackend, - "tensorrt": TensorRTBackend, - "bladedisc": BladeDISCBackend, - "nope": NopeBackend, - "pass_mgr": PassMgrBackend, - "unstable_to_stable": UnstableToStableBackend, - "range_decomposer_validator": RangeDecomposerValidatorBackend, - "graph_variable_renamer_validator": GraphVariableRenamerValidatorBackend, -} - - -def set_seed(random_seed): - random.seed(random_seed) - np.random.seed(random_seed) - torch.manual_seed(random_seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(random_seed) - torch.cuda.manual_seed_all(random_seed) - - -def get_hardward_name(args): - hardware_name = "unknown" - if "cuda" in args.device: - hardware_name = torch.cuda.get_device_name(args.device) - elif args.device == "cpu": - hardware_name = platform.processor() - return hardware_name - - -def get_compile_framework_version(args): - if args.compiler in ["inductor", "nope", "unstable_to_stable"]: - return torch.__version__ - elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]: - # Assuming compiler object has a version attribute - return f"{args.compiler.capitalize()} {args.compiler.version}" - return "unknown" - - -def load_class_from_file( - args: argparse.Namespace, class_name: str, device: str -) -> Type[torch.nn.Module]: - file_path = f"{args.model_path}/model.py" - file = Path(file_path).resolve() - module_name = file.stem - - with open(file_path, "r", encoding="utf-8") as f: - model_code = f.read() - model_code = utils.modify_code_by_device(model_code, device) - spec = importlib.util.spec_from_loader(module_name, loader=None) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - compiled_code = compile(model_code, filename=file, mode="exec") - exec(compiled_code, module.__dict__) - - model_class = getattr(module, class_name, None) - setattr(model_class, "__graph_net_file_path__", 
file_path) - setattr(model_class, "__graph_net_device__", device) - return model_class - - -def convert_to_dict(config_str): - if config_str in {None, "", "null", "None"}: - return {} - config_str = base64.b64decode(config_str).decode("utf-8") - config = json.loads(config_str) - assert isinstance(config, dict), f"config should be a dict. {config_str=}" - return config - - -def get_compiler_backend(args) -> GraphCompilerBackend: - assert ( - args.compiler in compiler_backend_name2class - ), f"Unknown compiler: {args.compiler}" - backend_class = compiler_backend_name2class[args.compiler] - return backend_class(args.backend_config) - - -def get_model(args): - device = "xla" if args.compiler == "xla" else args.device - - # device: Torch device object specifying the target device for model loading (e.g., 'cuda', 'cpu', 'xla') - model_class = load_class_from_file(args, class_name="GraphModule", device=device) - model = model_class().to(torch.device(args.device)) - return model - - -def get_input_dict(args): - inputs_params = utils.load_converted_from_text(f"{args.model_path}") - params = inputs_params["weight_info"] - for tensor_meta in params.values(): - if "device" in tensor_meta["info"]: - tensor_meta["info"]["device"] = args.device - return { - k: utils.replay_tensor(v).to(torch.device(args.device)) - for k, v in params.items() - } - - -def measure_performance(model_call, args, compiler): - stats = {} - outs = model_call() - - # Warmup runs - for _ in range(args.warmup): - model_call() - compiler.synchronize() - - hardware_name = get_hardward_name(args) - print( - f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}", - file=sys.stderr, - flush=True, - ) - - if "cuda" in args.device: - """ - Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings, - With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench - """ - - e2e_times = [] - gpu_times = [] - - for i in range(args.trials): - # End-to-end timing (naive_timer) - duration_box = test_compiler_util.DurationBox(-1) - with test_compiler_util.naive_timer(duration_box, compiler.synchronize): - # GPU-only timing (CUDA Events) - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - model_call() - - end_event.record() - compiler.synchronize() - - gpu_time_ms = start_event.elapsed_time(end_event) - e2e_times.append(duration_box.value) - gpu_times.append(gpu_time_ms) - print( - f"Trial {i + 1}: e2e={duration_box.value:.5f} ms, gpu={gpu_time_ms:.5f} ms", - file=sys.stderr, - flush=True, - ) - - stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) - stats["gpu"] = test_compiler_util.get_timing_stats(gpu_times) - - else: # CPU or other devices - e2e_times = [] - for i in range(args.trials): - duration_box = test_compiler_util.DurationBox(-1) - with test_compiler_util.naive_timer(duration_box, compiler.synchronize): - model_call() - print( - f"Trial {i + 1}: e2e={duration_box.value:.5f} ms", - file=sys.stderr, - flush=True, - ) - e2e_times.append(duration_box.value) - stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) - - return outs, stats - - -def test_single_model(args): - compiler = get_compiler_backend(args) - input_dict = get_input_dict(args) - model = get_model(args) - model_path = os.path.normpath(args.model_path) - test_compiler_util.print_with_log_prompt( - "[Processing]", model_path, args.log_prompt - ) - 
test_compiler_util.print_basic_config( - args, get_hardward_name(args), get_compile_framework_version(args) - ) - - runtime_seed = 1024 - eager_failure = False - expected_out = None - eager_time_stats = {} - - try: - - def eager_model_call(): - return model(**input_dict) - - expected_out, eager_time_stats = measure_performance( - eager_model_call, args, compiler - ) - - torch.manual_seed(runtime_seed) - if not isinstance(expected_out, tuple): - expected_out = (expected_out,) - except (TypeError, RuntimeError) as e: - print(f"Eager model execution failed: {str(e)}", file=sys.stderr) - eager_failure = True - - compiled_failure = False - compiled_model = None - compiled_time_stats = {} - - try: - compiled_model = compiler(model) - torch.manual_seed(runtime_seed) - - def compiled_model_call(): - return compiled_model(**input_dict) - - compiled_out, compiled_time_stats = measure_performance( - compiled_model_call, args, compiler - ) - - if not isinstance(compiled_out, tuple): - compiled_out = (compiled_out,) - if args.compiler == "xla": - compiled_out = tuple(item.to("cpu").to("cuda") for item in compiled_out) - except (TypeError, RuntimeError) as e: - print(f"Compiled model execution failed: {str(e)}", file=sys.stderr) - compiled_failure = True - print("\n--- Full Traceback ---") - traceback.print_exc() - print(f"debug-model-execution {type(e).__name__} {args.model_path}", flush=True) - except Exception as e: - compiled_failure = True - print("\n--- Full Traceback ---") - traceback.print_exc() - print(f"debug-model-execution {type(e).__name__} {args.model_path}", flush=True) - - if eager_failure: - print(f"{args.log_prompt} [Result] status: failed", file=sys.stderr, flush=True) - print( - f"{args.log_prompt} [Fail due to eager model execution error.]", - file=sys.stderr, - flush=True, - ) - elif compiled_failure: - print(f"{args.log_prompt} [Result] status: failed", file=sys.stderr, flush=True) - print( - f"{args.log_prompt} [Fail due to compiled model execution error.]", - file=sys.stderr, - flush=True, - ) - else: - compare_correctness(expected_out, compiled_out, args) - - print( - f"{args.log_prompt} [Result] status: success", file=sys.stderr, flush=True - ) - - test_compiler_util.print_times_and_speedup( - args, eager_time_stats, compiled_time_stats - ) - - -def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwargs): - cmp_ret = cmp_func(expected_out, compiled_out, **kwargs) - print( - f"{args.log_prompt} [Correctness]{key}: {cmp_ret}", - file=sys.stderr, - flush=True, - ) - return cmp_ret +from .eval_backend_perf import eval_single_model_with_single_backend def compare_correctness(expected_out, compiled_out, args): @@ -386,124 +94,159 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol): return " ".join(results) -def get_sample_root(args): - return args.model_path_prefix +def parse_time_stats_from_reference_log(log_path): + assert os.path.isfile( + log_path + ), f"{log_path} does not exist or is not a regular file." 
+ with open(log_path, "r", encoding="utf-8") as f: + lines = f.readlines() + for line in reversed(lines): + if "[Performance][eager]" in line: + start = line.find("{") + end = line.rfind("}") + time_stats = json.loads(line[start : end + 1]) + return time_stats -def test_multi_models(args): - test_samples = test_compiler_util.get_allow_samples( - args.model_path_list, get_sample_root(args) - ) - sample_idx = 0 - failed_samples = [] - module_name = os.path.splitext(os.path.basename(__file__))[0] - for model_path in path_utils.get_recursively_model_path(args.model_path): - if test_samples is None or os.path.abspath(model_path) in test_samples: - print( - f"[{sample_idx}] {module_name}, model_path: {model_path}", - file=sys.stderr, - flush=True, - ) - cmd = " ".join( - [ - sys.executable, - f"-m graph_net_bench.torch.{module_name}", - f"--model-path {model_path}", - f"--config {args.config}", - ] +def _get_model_paths(args, model_path_prefix, use_model_list): + if use_model_list: + assert os.path.isdir(model_path_prefix) and os.path.isfile(args.model_path_list) + + test_samples = test_compiler_util.get_allow_samples( + args.model_path_list, model_path_prefix + ) + model_paths = [ + os.path.join(model_path_prefix, rel_model_path) + for rel_model_path in test_samples + if os.path.exists( + os.path.join(model_path_prefix, rel_model_path, "model.py") ) - try: - process = subprocess.Popen(cmd, shell=True) - cmd_ret = process.wait() - except KeyboardInterrupt: - print("KeyboardInterrupt") - sys.exit(1) - except Exception: - print("\n--- Full Traceback ---") - traceback.print_exc() - if cmd_ret != 0: - failed_samples.append(model_path) - sample_idx += 1 + ] + else: + assert os.path.isdir(args.model_path) - print( - f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples.", - file=sys.stderr, - flush=True, - ) - for model_path in failed_samples: - print(f"- {model_path}", file=sys.stderr, flush=True) + test_samples = test_compiler_util.get_allow_samples( + args.model_path_list, model_path_prefix + ) + model_paths = [ + model_path + for model_path in path_utils.get_recursively_model_path(args.model_path) + if test_samples is None or os.path.abspath(model_path) in test_samples + ] + return model_paths -def test_multi_models_with_prefix(args): - assert os.path.isdir(args.model_path_prefix) - assert os.path.isfile(args.model_path_list) - test_samples = test_compiler_util.get_allow_samples( - args.model_path_list, get_sample_root(args) - ) - py_module_name = os.path.splitext(os.path.basename(__file__))[0] - for rel_model_path in test_samples: - model_path = os.path.join(args.model_path_prefix, rel_model_path) - if not os.path.exists(model_path): - continue - if not os.path.exists(os.path.join(model_path, "model.py")): - continue - cmd = " ".join( - [ - sys.executable, - f"-m graph_net_bench.torch.{py_module_name}", - f"--model-path {model_path}", - f"--config {args.config}", - ] + +def _create_model_args(model_path, reference_config, target_config): + args = argparse.Namespace() + args.model_path = model_path + args.model_path_list = None + args.reference_config = reference_config + args.target_config = target_config + return args + + +def eval_multi_models(args, model_path_prefix=None, use_model_list=False): + module_name = os.path.splitext(os.path.basename(__file__))[0] + + model_paths = _get_model_paths(args, model_path_prefix, use_model_list) + failed_samples = [] + for sample_idx, model_path in enumerate(model_paths): + print( + f"[{sample_idx}] {module_name}, model_path: 
{model_path}", + file=sys.stderr, + flush=True, ) + + model_args = argparse.Namespace() + model_args.model_path = model_path + model_args.model_path_list = None + model_args.reference_config = args.reference_config + model_args.target_config = args.target_config + try: - process = subprocess.Popen(cmd, shell=True) - process.wait() + eval_single_model(model_args) + success = True except KeyboardInterrupt: print("KeyboardInterrupt") sys.exit(1) except Exception: print("\n--- Full Traceback ---") traceback.print_exc() + success = False + if not success: + failed_samples.append(model_path) -def main(args): - if args.model_path_list is not None and args.model_path_prefix is not None: - test_multi_models_with_prefix(args) - return - assert os.path.isdir(args.model_path) + print( + f"Totally {len(model_paths)} verified samples, failed {len(failed_samples)} samples.", + file=sys.stderr, + flush=True, + ) + if failed_samples: + for model_path in failed_samples: + print(f"- {model_path}", file=sys.stderr, flush=True) + + +def eval_single_model(args): + ref_dir = "/tmp/eval_perf_diff/reference" + target_dir = "/tmp/eval_perf_diff/target" + + ref_args = types.SimpleNamespace( + model_path=args.model_path, + output_path=ref_dir, + **test_compiler_util.convert_to_dict(args.reference_config), + ) + target_args = types.SimpleNamespace( + model_path=args.model_path, + output_path=target_dir, + **test_compiler_util.convert_to_dict(args.target_config), + ) + + eval_single_model_with_single_backend(ref_args) + eval_single_model_with_single_backend(target_args) + + # compare_perf_diff + # A + ref_dump_path = utils.get_output_path(ref_dir, args.model_path) + ref_out = torch.load(str(ref_dump_path)) + + ref_log_path = utils.get_log_path(ref_dir, args.model_path) + ref_time_stats = parse_time_stats_from_reference_log(ref_log_path) - initalize_seed = 123 - set_seed(random_seed=initalize_seed) + # B + target_dump_path = utils.get_output_path(target_dir, args.model_path) + target_out = torch.load(str(target_dump_path)) - if path_utils.is_single_model_dir(args.model_path): - test_single_model(args) + target_log_path = utils.get_log_path(target_dir, args.model_path) + target_time_stats = parse_time_stats_from_reference_log(target_log_path) + + compare_correctness(ref_out, target_out, ref_args) + test_compiler_util.print_times_and_speedup( + ref_args, ref_time_stats, target_time_stats + ) + + +def main(args): + ref_config = test_compiler_util.convert_to_dict(args.reference_config) + model_path_prefix = ref_config.get("model_path_prefix") + + if args.model_path_list and model_path_prefix: + eval_multi_models(args, model_path_prefix, use_model_list=True) + elif os.path.isdir(args.model_path): + if path_utils.is_single_model_dir(args.model_path): + eval_single_model(args) + else: + eval_multi_models(args, model_path_prefix, use_model_list=False) else: - test_multi_models(args) - - -def complete_default_args( - mut_args, - compiler: str = "inductor", # Compiler name - device: str = "cuda", # Device for testing the compiler (e.g., 'cpu' or 'cuda') - warmup: int = 3, # Number of warmup steps - trials: int = 5, # Number of timing trials - log_prompt: str = "graph-net-test-compiler-log", # Log prompt for performance log filtering - model_path_prefix: str = None, # Prefix path to model path in --model-path-list - backend_config: dict = None, # backend configuration json -): - backend_config = backend_config if backend_config is not None else {} - mut_args.compiler = compiler - mut_args.device = device - mut_args.warmup = warmup 
- mut_args.trials = trials - mut_args.log_prompt = log_prompt - mut_args.model_path_prefix = model_path_prefix - mut_args.backend_config = backend_config + raise ValueError(f"Invalid model path: {args.model_path}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Test compiler performance.") + parser = argparse.ArgumentParser( + description="Evaluate Backend Performance Difference." + ) parser.add_argument( "--model-path", type=str, @@ -519,12 +262,16 @@ def complete_default_args( help="Path to samples list, each line contains a sample path", ) parser.add_argument( - "--config", + "--reference-config", type=str, - required=False, - default=None, - help="base64 encode configuration json.", + required=True, + help="base64 encode reference config json.", + ) + parser.add_argument( + "--target-config", + type=str, + required=True, + help="base64 encode target config json.", ) args = parser.parse_args() - complete_default_args(args, **convert_to_dict(args.config)) main(args=args) diff --git a/graph_net_bench/torch/eval_backend_perf.py b/graph_net_bench/torch/eval_backend_perf.py new file mode 100644 index 000000000..5c8586f30 --- /dev/null +++ b/graph_net_bench/torch/eval_backend_perf.py @@ -0,0 +1,301 @@ +from . import utils +import argparse +import importlib.util +import torch +from pathlib import Path +from typing import Type +import sys +import os +import traceback +import json +import random +import numpy as np +import platform +import types +from contextlib import redirect_stdout, redirect_stderr +from graph_net_bench.torch.backend.graph_compiler_backend import GraphCompilerBackend +from graph_net_bench import test_compiler_util + + +def register_op_lib(op_lib): + if op_lib == "flaggems": + import flag_gems + + flag_gems.enable() + else: + pass + + +def set_seed(random_seed): + random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + + +def get_hardward_name(device): + hardware_name = "unknown" + if "cuda" in device: + hardware_name = torch.cuda.get_device_name(device) + elif device == "cpu": + hardware_name = platform.processor() + return hardware_name + + +def get_compiler_version(compiler): + if compiler in ["inductor", "nope", "unstable_to_stable"]: + return torch.__version__ + elif compiler in ["tvm", "xla", "tensorrt", "bladedisc"]: + # Assuming compiler object has a version attribute + return f"{compiler.capitalize()} {compiler.version}" + return "unknown" + + +def load_class_from_file( + model_path: str, class_name: str, device: str +) -> Type[torch.nn.Module]: + file_path = f"{model_path}/model.py" + file = Path(file_path).resolve() + module_name = file.stem + + with open(file_path, "r", encoding="utf-8") as f: + model_code = f.read() + model_code = utils.modify_code_by_device(model_code, device) + spec = importlib.util.spec_from_loader(module_name, loader=None) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + compiled_code = compile(model_code, filename=file, mode="exec") + exec(compiled_code, module.__dict__) + + model_class = getattr(module, class_name, None) + setattr(model_class, "__graph_net_file_path__", file_path) + setattr(model_class, "__graph_net_device__", device) + return model_class + + +def get_compiler_backend(args) -> GraphCompilerBackend: + """ + Dynamically load backend class based on args.compiler + """ + compiler_name = args.compiler.lower() + 
module_name = f"graph_net_bench.torch.backend.{compiler_name}_backend" + + try: + module = __import__(module_name, fromlist=[f"{compiler_name.title()}Backend"]) + + class_name = ( + f"{''.join(part.title() for part in compiler_name.split('_'))}Backend" + ) + + backend_class = None + if hasattr(module, class_name): + backend_class = getattr(module, class_name) + else: + raise ImportError(f"No valid backend class found in {module_name}") + + except ImportError as e: + raise ImportError(f"Failed to import backend module for '{compiler_name}': {e}") + + backend_config = ( + test_compiler_util.convert_to_dict(args.backend_config) + if args.backend_config is not None + else {} + ) + return backend_class(backend_config) + + +def get_model(args): + device = "xla" if args.compiler == "xla" else args.device + + # device: Torch device object specifying the target device for model loading (e.g., 'cuda', 'cpu', 'xla') + model_class = load_class_from_file( + args.model_path, class_name="GraphModule", device=device + ) + model = model_class().to(torch.device(args.device)) + return model + + +def get_input_dict(args): + inputs_params = utils.load_converted_from_text(f"{args.model_path}") + params = inputs_params["weight_info"] + for tensor_meta in params.values(): + if "device" in tensor_meta["info"]: + tensor_meta["info"]["device"] = args.device + return { + k: utils.replay_tensor(v).to(torch.device(args.device)) + for k, v in params.items() + } + + +def measure_performance(model_call, args, compiler): + stats = {} + outs = model_call() + + # Warmup runs + for _ in range(args.warmup): + model_call() + compiler.synchronize() + + print( + f"[Profiling] Warm up {args.warmup}, Trials {args.trials}", + file=sys.stderr, + flush=True, + ) + + if "cuda" in args.device: + torch.cuda.empty_cache() + e2e_times = [] + gpu_times = [] + + for i in range(args.trials): + # End-to-end timing (naive_timer) + duration_box = test_compiler_util.DurationBox(-1) + with test_compiler_util.naive_timer(duration_box, compiler.synchronize): + # GPU-only timing (CUDA Events) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + model_call() + + end_event.record() + compiler.synchronize() + + gpu_time_ms = start_event.elapsed_time(end_event) + e2e_times.append(duration_box.value) + gpu_times.append(gpu_time_ms) + print( + f"Trial {i + 1}: e2e={duration_box.value:.5f} ms, gpu={gpu_time_ms:.5f} ms", + file=sys.stderr, + flush=True, + ) + + stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) + stats["gpu"] = test_compiler_util.get_timing_stats(gpu_times) + + else: # CPU or other devices + e2e_times = [] + for i in range(args.trials): + duration_box = test_compiler_util.DurationBox(-1) + with test_compiler_util.naive_timer(duration_box, compiler.synchronize): + model_call() + print( + f"Trial {i + 1}: e2e={duration_box.value:.5f} ms", + file=sys.stderr, + flush=True, + ) + e2e_times.append(duration_box.value) + stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times) + + return outs, stats + + +def eval_single_model_with_single_backend(args): + check_and_complete_args(args) + set_seed(args.seed) + os.makedirs(args.output_path, exist_ok=True) + log_path = utils.get_log_path(args.output_path, args.model_path) + output_dump_path = utils.get_output_path(args.output_path, args.model_path) + print(f"Log path: {log_path}", file=sys.stderr, flush=True) + print(f"Outputs path: {output_dump_path}", file=sys.stderr, flush=True) + + with open(log_path, "w", 
encoding="utf-8") as log_f: + with redirect_stdout(log_f), redirect_stderr(log_f): + compiler = get_compiler_backend(args) + + input_dict = get_input_dict(args) + model = get_model(args) + model.eval() + + test_compiler_util.print_config( + args, + get_hardward_name(args.device), + get_compiler_version(args.compiler), + ) + + success = False + time_stats = {} + try: + compiled_model = compiler(model) + + def model_call(): + return compiled_model(**input_dict) + + outputs, time_stats = measure_performance(model_call, args, compiler) + success = True + except Exception as e: + print( + f"Run model failed: {str(e)}\n{traceback.format_exc()}", + file=sys.stderr, + flush=True, + ) + + test_compiler_util.print_running_status(args, success) + if success: + torch.save(outputs, str(output_dump_path)) + test_compiler_util.print_with_log_prompt( + "[Performance][eager]:", json.dumps(time_stats), args.log_prompt + ) + + with open(log_path, "r", encoding="utf-8") as f: + content = f.read() + print(content, file=sys.stderr, flush=True) + + +def check_and_complete_args(args): + """ + Ensure all required arguments are present with default values if missing + """ + defaults = { + "model_path": None, # Model path + "output_path": None, # Log and output directory + "seed": 123, # Random seed + "compiler": "inductor", # Compiler name + "device": "cuda", # Device for testing the compiler (e.g., 'cpu' or 'cuda') + "op_lib": None, # Operator library + "warmup": 3, # Number of warmup steps + "trials": 5, # Number of timing trials + "log_prompt": "graph-net-bench-log", # Log prompt for performance log filtering + "model_path_prefix": None, # Prefix path to model path in args.model-path + "backend_config": None, # backend configuration json + } + + for key, default in defaults.items(): + if not hasattr(args, key): + setattr(args, key, default) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Single Backend Performance Evaluation" + ) + parser.add_argument( + "--model-path", + type=str, + required=False, + default=None, + help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model", + ) + parser.add_argument( + "--output-path", + type=str, + required=False, + default="/tmp/test_save", + help="Path to save outputs", + ) + parser.add_argument( + "--config", + type=str, + required=False, + default=None, + help="base64 encode configuration json.", + ) + args = parser.parse_args() + mut_args = types.SimpleNamespace( + model_path=args.model_path, + output_path=args.output_path, + **test_compiler_util.convert_to_dict(args.config), + ) + eval_single_model_with_single_backend(mut_args) diff --git a/graph_net_bench/torch/utils.py b/graph_net_bench/torch/utils.py index c937ff4de..700a59972 100755 --- a/graph_net_bench/torch/utils.py +++ b/graph_net_bench/torch/utils.py @@ -1,4 +1,5 @@ import torch +import os import ast import math import inspect @@ -7,6 +8,16 @@ kLiteralTensorSize = 64 +def get_log_path(log_dir, model_path): + model_name = model_path.split("torch_samples/")[-1].replace(os.sep, "_") + return os.path.join(log_dir, f"{model_name}.log") + + +def get_output_path(output_dir, model_path): + model_name = model_path.split("torch_samples/")[-1].replace(os.sep, "_") + return os.path.join(output_dir, f"{model_name}.pth") + + def get_limited_precision_float_str(value): if not isinstance(value, float): return value diff --git a/test/eval_backend_diff_test.sh b/test/eval_backend_diff_test.sh index e3fa79602..1eaca5ecd 100755 --- 
a/test/eval_backend_diff_test.sh
+++ b/test/eval_backend_diff_test.sh
@@ -8,11 +8,23 @@
 model_list="$AI4C_ROOT/test/workspace_eval_backend_diff/sample_list.txt"
 python3 -m graph_net_bench.torch.eval_backend_diff \
     --model-path-list $model_list \
-    --config $(base64 -w 0 <&1 | tee "$OUTPUT_PATH/validation.log"
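
For orientation, here is a minimal sketch of driving the reworked eval_backend_diff entry point by hand. It assumes AI4C_ROOT is resolved the same way as in the test scripts above, and the config keys simply mirror the defaults filled in by check_and_complete_args; the exact JSON payloads used by eval_backend_diff_test.sh are truncated in the hunk above, so the values below are illustrative only.

# Illustrative sketch, not the patch's own test config: build two base64-encoded
# JSON configs (decoded by test_compiler_util.convert_to_dict) and compare a
# reference backend against a target backend on one sample model.
REF_CONFIG=$(echo -n '{"compiler": "nope", "device": "cuda", "warmup": 1, "trials": 1}' | base64 -w 0)
TARGET_CONFIG=$(echo -n '{"compiler": "inductor", "device": "cuda", "warmup": 1, "trials": 1}' | base64 -w 0)

python3 -m graph_net_bench.torch.eval_backend_diff \
    --model-path "$AI4C_ROOT/samples/ultralytics/yolov3-tinyu" \
    --reference-config "$REF_CONFIG" \
    --target-config "$TARGET_CONFIG" \
    2>&1 | tee /tmp/eval_backend_diff.log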