37 changes: 37 additions & 0 deletions graph_net/test/test_device_test.sh
@@ -0,0 +1,37 @@
#!/bin/bash

AI4C_ROOT=$(python3 -c "import graph_net_bench; import os; print(os.path.dirname(os.path.dirname(graph_net_bench.__file__)))")
OUTPUT_PATH=/tmp/workspace_eval_device_diff_test
REFERENCE_DIR="$OUTPUT_PATH/reference"

mkdir -p "$OUTPUT_PATH"
mkdir -p "$REFERENCE_DIR"

MODEL_PATH="$AI4C_ROOT/samples/ultralytics/yolov3-tinyu"

echo "=========================================="
echo "Step 1: Generate reference on device A (simulated)"
echo "=========================================="
python3 -m graph_net.torch.test_reference_device \
--model-path "$MODEL_PATH" \
--compiler nope \
--device cuda \
--warmup 1 \
--trials 1 \
--reference-dir "$REFERENCE_DIR" \
2>&1 | tee "$OUTPUT_PATH/reference.log"

echo ""
echo "=========================================="
echo "Step 2: Compare on device B (simulated)"
echo "=========================================="
python3 -m graph_net.torch.test_target_device \
--model-path "$MODEL_PATH" \
--device cuda \
--reference-dir "$REFERENCE_DIR" \
2>&1 | tee "$OUTPUT_PATH/target.log"

echo ""
echo "=========================================="
echo "Test completed. Logs saved to: $OUTPUT_PATH"
echo "=========================================="
114 changes: 24 additions & 90 deletions graph_net/torch/test_reference_device.py
@@ -1,94 +1,33 @@
import argparse
import torch
import os
from pathlib import Path
from contextlib import redirect_stdout, redirect_stderr
import json
import sys
import traceback
import types
from pathlib import Path

from graph_net_bench import path_utils
from graph_net_bench import test_compiler_util
from graph_net import model_path_util
from graph_net_bench.torch import test_compiler


def get_reference_log_path(reference_dir, model_path):
model_name = model_path.split("samples/")[-1].replace(os.sep, "_")
return os.path.join(reference_dir, f"{model_name}.log")


def get_reference_output_path(reference_dir, model_path):
model_name = model_path.split("samples/")[-1].replace(os.sep, "_")
return os.path.join(reference_dir, f"{model_name}.pth")


def register_op_lib(op_lib):
if op_lib == "flaggems":
import flag_gems

flag_gems.enable()
else:
pass
from graph_net_bench.torch import eval_backend_perf


def convert_args_for_eval_backend(args):
"""Convert test_reference_device args to eval_backend_perf args format."""
return types.SimpleNamespace(
model_path=args.model_path,
output_path=args.reference_dir,
seed=args.seed,
compiler=args.compiler,
device=args.device,
op_lib=args.op_lib,
warmup=args.warmup,
trials=args.trials,
log_prompt=args.log_prompt,
backend_config=getattr(args, "config", None),
)


def test_single_model(args):
ref_log = get_reference_log_path(args.reference_dir, args.model_path)
ref_dump = get_reference_output_path(args.reference_dir, args.model_path)
print(f"Reference log path: {ref_log}", file=sys.stderr, flush=True)
print(f"Reference outputs path: {ref_dump}", file=sys.stderr, flush=True)

with open(ref_log, "w", encoding="utf-8") as log_f:
with redirect_stdout(log_f), redirect_stderr(log_f):
compiler = test_compiler.get_compiler_backend(args)

input_dict = test_compiler.get_input_dict(args)
model = test_compiler.get_model(args)
model.eval()

test_compiler_util.print_with_log_prompt(
"[Config] seed:", args.seed, args.log_prompt
)

test_compiler_util.print_basic_config(
args,
test_compiler.get_hardward_name(args),
test_compiler.get_compile_framework_version(args),
)

test_compiler_util.print_with_log_prompt(
"[Config] op_lib:", args.op_lib, args.log_prompt
)

success = False
time_stats = {}
try:
compiled_model = compiler(model)

def model_call():
return compiled_model(**input_dict)

outputs, time_stats = test_compiler.measure_performance(
model_call, args, compiler
)
success = True
except Exception as e:
print(
f"Run model failed: {str(e)}\n{traceback.format_exc()}",
file=sys.stderr,
flush=True,
)

test_compiler_util.print_running_status(args, success)
if success:
torch.save(outputs, str(ref_dump))
test_compiler_util.print_with_log_prompt(
"[Performance][eager]:", json.dumps(time_stats), args.log_prompt
)

with open(ref_log, "r", encoding="utf-8") as f:
content = f.read()
print(content, file=sys.stderr, flush=True)
eval_args = convert_args_for_eval_backend(args)
eval_backend_perf.eval_single_model_with_single_backend(eval_args)


def test_multi_models(args):
@@ -136,20 +75,15 @@ def test_multi_models(args):

def main(args):
assert os.path.isdir(args.model_path)
# Support all torch compilers
valid_compilers = list(test_compiler.compiler_backend_name2class.keys())
assert (
args.compiler in valid_compilers
), f"Compiler must be one of {valid_compilers}"
assert args.device in ["cuda"]
assert args.device in ["cuda", "cpu"]

test_compiler.set_seed(random_seed=args.seed)
eval_backend_perf.set_seed(args.seed)

ref_dump_dir = Path(args.reference_dir)
ref_dump_dir.mkdir(parents=True, exist_ok=True)

if path_utils.is_single_model_dir(args.model_path):
register_op_lib(args.op_lib)
eval_backend_perf.register_op_lib(args.op_lib)
test_single_model(args)
else:
test_multi_models(args)
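
After this refactor, test_single_model in test_reference_device.py is a thin wrapper: it repackages the CLI arguments via convert_args_for_eval_backend and hands them to eval_backend_perf.eval_single_model_with_single_backend. Calling the backend directly would look roughly like the sketch below (the sample path, seed, and log_prompt values are assumptions for illustration; the field names mirror convert_args_for_eval_backend):

import types
from graph_net_bench.torch import eval_backend_perf

# Same shape as the namespace built by convert_args_for_eval_backend;
# the concrete values here are placeholders, not defaults from the PR.
eval_args = types.SimpleNamespace(
    model_path="samples/ultralytics/yolov3-tinyu",
    output_path="/tmp/workspace_eval_device_diff_test/reference",
    seed=0,
    compiler="nope",
    device="cuda",
    op_lib="origin",
    warmup=1,
    trials=1,
    log_prompt="[GraphNet]",
    backend_config=None,
)

eval_backend_perf.set_seed(eval_args.seed)
eval_backend_perf.register_op_lib(eval_args.op_lib)
eval_backend_perf.eval_single_model_with_single_backend(eval_args)
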
133 changes: 47 additions & 86 deletions graph_net/torch/test_target_device.py
@@ -1,15 +1,13 @@
import argparse
import os
import json
import sys
import traceback
import types

import torch
from graph_net_bench import path_utils
from graph_net_bench import test_compiler_util
from graph_net import model_path_util
from graph_net_bench.torch import test_compiler
from graph_net.torch import test_reference_device
from graph_net_bench.torch import utils, eval_backend_perf, eval_backend_diff


def parse_config_from_reference_log(log_path):
@@ -31,94 +29,59 @@ def parse_config_from_reference_log(log_path):
return config


def parse_time_stats_from_reference_log(log_path):
assert os.path.isfile(
log_path
), f"{log_path} does not exist or is not a regular file."

with open(log_path, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in reversed(lines):
if "[Performance][eager]" in line:
start = line.find("{")
end = line.rfind("}")
time_stats = json.loads(line[start : end + 1])
return time_stats
def get_ref_config_from_log(args, model_path):
"""Extract config from reference log file."""
ref_log = utils.get_log_path(args.reference_dir, model_path)
config = parse_config_from_reference_log(ref_log)
return config


def update_args_and_set_seed(args, model_path):
ref_log = test_reference_device.get_reference_log_path(
args.reference_dir, model_path
def convert_args_for_eval_backend(args, output_path):
"""Convert test_target_device args to eval_backend_perf args format."""
return types.SimpleNamespace(
model_path=args.model_path,
output_path=output_path,
seed=args.seed,
compiler=args.compiler,
device=args.device,
op_lib=args.op_lib,
warmup=args.warmup,
trials=args.trials,
log_prompt=args.log_prompt,
backend_config=getattr(args, "config", None),
)
config = parse_config_from_reference_log(ref_log)
vars(args)["model_path"] = model_path
vars(args)["compiler"] = config.get("compiler")
vars(args)["trials"] = int(config.get("trials"))
vars(args)["warmup"] = int(config.get("warmup"))
test_compiler.set_seed(random_seed=int(config.get("seed")))
return args


def test_single_model(args):
compiler = test_compiler.get_compiler_backend(args)
target_dir = "/tmp/eval_device_diff/target"

input_dict = test_compiler.get_input_dict(args)
model = test_compiler.get_model(args)
model.eval()
ref_config = get_ref_config_from_log(args, args.model_path)
vars(args)["compiler"] = ref_config.get("compiler")
vars(args)["trials"] = int(ref_config.get("trials"))
vars(args)["warmup"] = int(ref_config.get("warmup"))
vars(args)["seed"] = int(ref_config.get("seed"))

model_path = os.path.normpath(args.model_path)
test_compiler_util.print_with_log_prompt(
"[Processing]", model_path, args.log_prompt
)
test_compiler_util.print_basic_config(
args,
test_compiler.get_hardward_name(args),
test_compiler.get_compile_framework_version(args),
)
eval_args = convert_args_for_eval_backend(args, target_dir)
eval_backend_perf.eval_single_model_with_single_backend(eval_args)

success = False
time_stats = {}
try:
compiled_model = compiler(model)

def model_call():
return compiled_model(**input_dict)

outputs, time_stats = test_compiler.measure_performance(
model_call, args, compiler
)
success = True
except Exception as e:
print(
f"Run model failed: {str(e)}\n{traceback.format_exc()}",
file=sys.stderr,
flush=True,
)

test_compiler_util.print_running_status(args, success)

model_name = test_compiler_util.get_model_name(args.model_path)
if test_compiler_util.get_subgraph_tag(args.model_path):
model_name += "_" + test_compiler_util.get_subgraph_tag(args.model_path)

ref_dump = test_reference_device.get_reference_output_path(
args.reference_dir, args.model_path
)
ref_dump = utils.get_output_path(args.reference_dir, args.model_path)
ref_out = torch.load(str(ref_dump))

ref_log = test_reference_device.get_reference_log_path(
args.reference_dir, args.model_path
ref_log = utils.get_log_path(args.reference_dir, args.model_path)
ref_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(ref_log)

target_dump = utils.get_output_path(target_dir, args.model_path)
target_out = torch.load(str(target_dump))
target_log = utils.get_log_path(target_dir, args.model_path)
target_time_stats = eval_backend_diff.parse_time_stats_from_reference_log(
target_log
)
ref_time_stats = parse_time_stats_from_reference_log(ref_log)

if success:
test_compiler.compare_correctness(ref_out, outputs, args)

test_compiler_util.print_times_and_speedup(args, ref_time_stats, time_stats)
eval_backend_diff.compare_correctness(ref_out, target_out, eval_args)
test_compiler_util.print_times_and_speedup(args, ref_time_stats, target_time_stats)


def is_reference_log_exist(reference_dir, model_path):
log_path = test_reference_device.get_reference_log_path(reference_dir, model_path)
log_path = utils.get_log_path(reference_dir, model_path)
return os.path.isfile(log_path)


@@ -172,18 +135,16 @@ def main(args):

if path_utils.is_single_model_dir(args.model_path):
if args.op_lib == "origin":
ref_log = test_reference_device.get_reference_log_path(
args.reference_dir, args.model_path
)
config = parse_config_from_reference_log(ref_log)
vars(args)["op_lib"] = config.get("op_lib")
test_compiler_util.print_with_log_prompt(
"[Config] op_lib:", args.op_lib, args.log_prompt
ref_config = get_ref_config_from_log(args, args.model_path)
vars(args)["op_lib"] = ref_config.get("op_lib")
print(
f"{args.log_prompt} [Config] op_lib: {args.op_lib}",
file=sys.stderr,
flush=True,
)
else:
test_reference_device.register_op_lib(args.op_lib)
eval_backend_perf.register_op_lib(args.op_lib)

args = update_args_and_set_seed(args, args.model_path)
test_single_model(args)
else:
test_multi_models(args)
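
On the target device, the flow is: recover the run config (compiler, trials, warmup, seed, op_lib) from the reference log, re-run the model locally through eval_backend_perf, then load both dumps and compare. A stripped-down sketch of the comparison tail (the sample path, the two directories, and the placeholder namespace values are assumptions; the helper calls mirror test_single_model above):

import types
import torch
from graph_net_bench.torch import utils, eval_backend_diff

reference_dir = "/tmp/workspace_eval_device_diff_test/reference"  # produced on device A
target_dir = "/tmp/eval_device_diff/target"                        # written locally by eval_backend_perf
model_path = "samples/ultralytics/yolov3-tinyu"                    # assumed sample path

# Load the output dumps written on each device.
ref_out = torch.load(utils.get_output_path(reference_dir, model_path))
target_out = torch.load(utils.get_output_path(target_dir, model_path))

# Pull the [Performance][eager] timing stats back out of each log.
ref_stats = eval_backend_diff.parse_time_stats_from_reference_log(
    utils.get_log_path(reference_dir, model_path)
)
target_stats = eval_backend_diff.parse_time_stats_from_reference_log(
    utils.get_log_path(target_dir, model_path)
)

# Namespace with the same shape as convert_args_for_eval_backend builds; values are placeholders.
eval_args = types.SimpleNamespace(
    model_path=model_path, output_path=target_dir, seed=0, compiler="nope",
    device="cuda", op_lib="origin", warmup=1, trials=1,
    log_prompt="[GraphNet]", backend_config=None,
)

eval_backend_diff.compare_correctness(ref_out, target_out, eval_args)
print("reference:", ref_stats, "target:", target_stats)
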