Skip to content
83 changes: 56 additions & 27 deletions graph_net/auto_fault_locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def __init__(self, args):
self.machine = args.machine
self.port = args.port

def get_one_step_cmd(self, config_str):
config_b64 = convert_json_to_b64_string(config_str)
return [
def execute_one_step_cmd(self, test_config):
test_config_b64_str = convert_json_to_b64_string(test_config)
cmd = [
sys.executable,
"-m",
"graph_net.subgraph_decompose_and_evaluation_step",
Expand All @@ -36,7 +36,7 @@ def get_one_step_cmd(self, config_str):
"--framework",
self.framework,
"--test-config",
config_b64,
test_config_b64_str,
"--decompose-method",
self.decompose_method,
"--tolerance",
Expand All @@ -45,49 +45,78 @@ def get_one_step_cmd(self, config_str):
self.max_subgraph_size,
]

def run_remote_test_reference(self):
print(f"[AutoFaultLocator] Executing: {' '.join(cmd)}", flush=True)
result = subprocess.run(cmd, check=True, text=True)
return result

def run_test_reference_device(self, is_remote):
    """Run the reference-device step (decomposition and evaluation).

    Builds the test config for either ``test_remote_reference_device``
    (``is_remote`` true) or ``test_reference_device`` and hands it to
    ``self.execute_one_step_cmd``.

    Args:
        is_remote: When true, the remote test module is selected and
            ``self.machine`` / ``self.port`` are added to its arguments.

    Raises:
        AssertionError: If the step command reports a non-zero return code.
    """
    print(
        "\n>>> [AutoFaultLocator 1/2] Run Test Reference Device (Decomposition And Evaluation)\n",
        flush=True,
    )

    test_module_name = (
        "test_remote_reference_device" if is_remote else "test_reference_device"
    )
    arguments = {
        "model-path": None,
        "reference-dir": None,
        "compiler": "nope",
        "device": self.reference_device,
        "warmup": 5,
        "trials": 20,
        "seed": 123,
    }
    # Read the framework from the instance: there is no ``args`` name in this
    # method's scope, so ``args.framework`` only worked if a module-level
    # global of that name happened to exist.
    if self.framework == "torch":
        arguments["op-lib"] = "default"
    if is_remote:
        arguments["machine"] = self.machine
        arguments["port"] = self.port

    test_reference_device_config = {
        "test_module_name": test_module_name,
        f"{test_module_name}_arguments": arguments,
    }

    result = self.execute_one_step_cmd(test_reference_device_config)
    assert (
        result.returncode == 0
    ), f"Run Test Reference Device failed with return code {result.returncode}"

def run_local_test_target(self):
print("\n>>> [Step 2] Run Local Target Device (Evaluation And Analysis)\n")
def run_test_target_device(self, is_remote):
    """Run the target-device step (evaluation and analysis).

    Builds the test config for either ``test_remote_target_device``
    (``is_remote`` true) or ``test_target_device`` and hands it to
    ``self.execute_one_step_cmd``.

    Args:
        is_remote: When true, the remote test module is selected and
            ``self.machine`` / ``self.port`` are added to its arguments.

    Raises:
        AssertionError: If the step command reports a non-zero return code.
    """
    print(
        "\n>>> [AutoFaultLocator 2/2] Run Test Target Device (Evaluation And Analysis)\n",
        flush=True,
    )

    test_module_name = (
        "test_remote_target_device" if is_remote else "test_target_device"
    )
    arguments = {
        "model-path": None,
        "reference-dir": None,
        "compiler": "nope",
        "device": self.target_device,
        "warmup": 5,
        "trials": 20,
        "seed": 123,
    }
    if is_remote:
        arguments["machine"] = self.machine
        arguments["port"] = self.port

    test_target_device_config = {
        "test_module_name": test_module_name,
        f"{test_module_name}_arguments": arguments,
    }

    result = self.execute_one_step_cmd(test_target_device_config)
    # The message no longer says "Local": this step may run remotely too.
    assert (
        result.returncode == 0
    ), f"Run Test Target Device failed with return code {result.returncode}"
Expand All @@ -114,8 +143,8 @@ def analyze_and_decide_next(self):
def main(args):
locator = AutoFaultLocator(args)
while True:
locator.run_remote_test_reference()
locator.run_local_test_target()
locator.run_test_reference_device(is_remote=False)
locator.run_test_target_device(is_remote=True)
should_continue = locator.analyze_and_decide_next()
if not should_continue:
break
Expand Down
64 changes: 63 additions & 1 deletion graph_net/subgraph_decompose_and_evaluation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import glob
from dataclasses import dataclass, field, asdict
from typing import List, Dict
from graph_net_bench.analysis_util import get_incorrect_models
from graph_net_bench.analysis_util import get_incorrect_models, get_min_passed_tolerance
from graph_net.graph_net_root import get_graphnet_root
from graph_net_bench import path_utils

Expand Down Expand Up @@ -297,6 +297,64 @@ def update_running_state_with_incorrect_models(
)


class ToleranceRecord:
    """Class-level store of per-model subgraph tolerances, persisted per pass.

    The record is kept on the class itself (not on instances) and is written
    as JSON to ``tolerance_record.json`` inside the workspace directory that
    ``get_decompose_workspace_path`` returns for a given pass.
    """

    # Maps model name -> {subgraph split point (int) -> tolerance}.
    model_name2subgraph_tolerance_record = {}
    filename = "tolerance_record.json"

    @classmethod
    def load(cls, pass_id, output_dir):
        """Load the record saved for ``pass_id``; do nothing if it is negative."""
        if pass_id < 0:
            return

        work_dir = get_decompose_workspace_path(output_dir, pass_id)
        filepath = os.path.join(work_dir, cls.filename)
        with open(filepath, "r") as f:
            data = json.load(f)

        # JSON object keys always come back as strings. Convert split points
        # back to int so that update() overwrites an existing entry instead
        # of adding a second one (e.g. both "12" and 12), which json.dump
        # would then emit as a duplicated key.
        cls.model_name2subgraph_tolerance_record = {
            model_name: {
                int(split_point): tolerance
                for split_point, tolerance in subgraph_record.items()
            }
            for model_name, subgraph_record in data.items()
        }

    @classmethod
    def save(cls, pass_id, output_dir):
        """Write the current record as JSON into the workspace of ``pass_id``."""
        work_dir = get_decompose_workspace_path(output_dir, pass_id)
        filepath = os.path.join(work_dir, cls.filename)
        print(f"Save tolerance record to: {filepath}.")
        with open(filepath, "w") as f:
            json.dump(cls.model_name2subgraph_tolerance_record, f, indent=4)

    @classmethod
    def update(cls, pass_id, output_dir, decompose_config, log_path):
        """Merge tolerances parsed from ``log_path`` into the record and save it.

        Starts from the record of the previous pass (``pass_id - 1``), adds
        one entry per subgraph path returned by ``get_min_passed_tolerance``
        whose model is present in the running state of ``pass_id``, sorts the
        record, and saves it under ``pass_id``.
        """
        cls.load(pass_id - 1, output_dir)

        subgraph_path2tolerance = get_min_passed_tolerance(log_path)
        running_state = decompose_config.get_running_state(pass_id)
        record = cls.model_name2subgraph_tolerance_record
        for subgraph_path, tolerance in subgraph_path2tolerance.items():
            model_name, _ = extract_model_name_and_subgraph_idx(subgraph_path)
            if model_name not in running_state.model_name2record:
                continue

            split_positions = running_state.model_name2record[
                model_name
            ].get_split_positions(decompose_config.decompose_method)
            assert len(split_positions) >= 2
            # NOTE(review): the entry is always keyed by split_positions[1];
            # the subgraph index parsed from the path is not used. Confirm
            # this is intended.
            subgraph_split_point = int(split_positions[1])
            record.setdefault(model_name, {})[subgraph_split_point] = tolerance

        # Sort models by name and split points numerically so the saved file
        # has a stable order.
        cls.model_name2subgraph_tolerance_record = {
            model_name: dict(
                sorted(subgraph_record.items(), key=lambda item: int(item[0]))
            )
            for model_name, subgraph_record in sorted(record.items())
        }
        cls.save(pass_id, output_dir)


def get_rectfied_model_path(model_path):
    """Return ``model_path`` re-rooted under the directory from ``get_graphnet_root()``.

    Only the text after the last ``"GraphNet/"`` in ``model_path`` is kept;
    when that marker is absent the whole path is used.
    """
    root_dir = get_graphnet_root()
    _, _, tail = model_path.rpartition("GraphNet/")
    return os.path.join(root_dir, tail)
Expand Down Expand Up @@ -792,6 +850,10 @@ def main(args):
print_incorrect_models(
decompose_config, current_pass_id, log_prompt="[Analysis]"
)

ToleranceRecord.update(
current_pass_id, base_output_dir, decompose_config, log_path
)
print_summary_and_suggestion(decompose_config, current_pass_id)

# --- Step 5: Save States ---
Expand Down
16 changes: 11 additions & 5 deletions graph_net/test/get_incorrect_models_test.sh
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
#!/bin/bash


SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
GRAPH_NET_DIR=$(dirname "$SCRIPT_DIR")
PROJECT_ROOT=$(dirname "$GRAPH_NET_DIR")

# Add the project root directory to the Python path
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"

TOLERANCE_LIST=(-2 -1 0 1 2)
LOG_FILE_PATH="log_file_for_test.txt"

python3 - <<END
from graph_net_bench import analysis_util

result = analysis_util.get_incorrect_models($TOLERANCE_LIST, '$LOG_FILE_PATH')
print_detail = False
for tolerance in range(-10, 3, 1):
incorrect_models = analysis_util.get_incorrect_models(tolerance, '$LOG_FILE_PATH')
if tolerance <= 1:
incorrect_models_next = analysis_util.get_incorrect_models(tolerance + 1, '$LOG_FILE_PATH')
incorrect_models = incorrect_models - incorrect_models_next
print(f"- tolerance=[{tolerance}, {tolerance + 1}), number of incorrect_models={len(incorrect_models)}")
if print_detail:
for idx, sample in enumerate(incorrect_models):
print(f"- {idx=}, {sample=}")
print()

for item in result:
print(item)
END
24 changes: 24 additions & 0 deletions graph_net_bench/analysis_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,27 @@ def get_incorrect_models(
failed_models.add(sample.get("model_path")) if not is_correct else None

return failed_models


def get_min_passed_tolerance(log_file_path: str, type: str = "ESt") -> dict:
    """Map each model path in a log to the smallest tolerance it passes.

    Every sample parsed from ``log_file_path`` is checked with
    ``check_sample_correctness`` at tolerances -10 through 4 in increasing
    order; the first tolerance reported as correct is recorded.

    Args:
        log_file_path: Path of the log file handed to ``parse_logs_to_data``.
        type: When equal to ``"ESt"``, recorded values above 1 are lowered
            to 1 and samples that passed at no tolerance are recorded as 1.
            For any other value, samples that never passed are left out.

    Returns:
        A dict from model path to tolerance (not an ``int``, as the previous
        annotation claimed).
    """
    model_path2tolerance = {}
    samples = parse_logs_to_data(log_file_path)

    for sample in samples:
        model_path = sample.get("model_path")
        for tolerance in range(-10, 5):
            # Only the correctness flag matters here; the failure type is
            # ignored.
            is_correct, _ = check_sample_correctness(sample, tolerance)
            if is_correct:
                model_path2tolerance[model_path] = tolerance
                break

    if type == "ESt":
        for sample in samples:
            model_path = sample.get("model_path")
            if (
                model_path not in model_path2tolerance
                or model_path2tolerance[model_path] > 1
            ):
                model_path2tolerance[model_path] = 1

    return model_path2tolerance
6 changes: 3 additions & 3 deletions graph_net_rpc/sample_remote_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def __init__(self, machine: str, port: int):

def _get_stub(self):
if self._stub is None:
# Default is 4MB (4194304), increase it to 320MB
# Default is 4MB (4194304), increase it to 1024MB
options = [
("grpc.max_send_message_length", 320 * 1024 * 1024),
("grpc.max_receive_message_length", 320 * 1024 * 1024),
("grpc.max_send_message_length", 1024 * 1024 * 1024),
("grpc.max_receive_message_length", 1024 * 1024 * 1024),
]
self._channel = grpc.insecure_channel(
f"{self.machine}:{self.port}", options=options
Expand Down
Loading