diff --git a/graph_net/fault_locator/bi_search.py b/graph_net/fault_locator/bi_search.py
index b69eeb3c1..3d2c94167 100644
--- a/graph_net/fault_locator/bi_search.py
+++ b/graph_net/fault_locator/bi_search.py
@@ -56,9 +56,6 @@ def bi_search(
         if stoper(search_history, high=high):
             break
 
-        # Determine current split point
-        truncate_pos = (low + high) // 2
-
         # Interval update
         if is_fault:
             # Fault detected in current prefix; search earlier for the root cause
@@ -67,6 +64,9 @@ def bi_search(
             # Current prefix is healthy; the first fault must be in the suffix
             low = truncate_pos + 1
 
+        # Determine current split point
+        truncate_pos = (low + high) // 2
+
         # Safety break for boundary convergence
         if low >= high:
             # Ensure the final point is captured if the loop terminates via boundary
diff --git a/graph_net/fault_locator/torch/op_lib_evaluator.py b/graph_net/fault_locator/torch/op_lib_evaluator.py
new file mode 100644
index 000000000..9a6b82643
--- /dev/null
+++ b/graph_net/fault_locator/torch/op_lib_evaluator.py
@@ -0,0 +1,106 @@
+import sys
+import subprocess
+import time
+from pathlib import Path
+from graph_net.declare_config_mixin import DeclareConfigMixin
+
+
+class OpLibEvaluator(DeclareConfigMixin):
+    """
+    Functor responsible for evaluating model samples by comparing a target operator
+    library's (e.g., FlagGems) performance and accuracy against a reference implementation.
+    The evaluator manages reference data generation and captures execution logs.
+    """
+
+    def __init__(self, config=None):
+        self.init_config(config)
+
+    def declare_config(
+        self,
+        model_path_prefix: str,
+        output_dir: str,
+        op_lib: str,
+        device: str = "cuda",
+        compiler: str = "nope",
+    ):
+        """
+        Configuration schema for operator library benchmarking.
+        The reference_data directory is automatically managed within the output_dir.
+        """
+        pass
+
+    def __call__(self, rel_model_path: str) -> str:
+        """
+        Orchestrates the pipeline for reference data generation and target library testing.
+        Returns:
+            The complete log content from the target device test execution.
+        """
+        output_path = Path(self.config["output_dir"])
+        # Create an isolated workspace for the current model sample
+        workspace = output_path / rel_model_path
+        workspace.mkdir(parents=True, exist_ok=True)
+
+        # Determine the shared directory for reference ground truth
+        reference_dir = output_path / "reference_data"
+        reference_dir.mkdir(parents=True, exist_ok=True)
+
+        # Construct the absolute model path
+        full_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
+        log_file = workspace / "op_lib_validation.log"
+
+        # 1. Execute reference test to establish baseline metrics
+        self._run_reference_test(full_model_path, reference_dir)
+
+        # 2. Execute target library test and capture performance/accuracy logs
+        return self._run_target_test(full_model_path, reference_dir, log_file)
+
+    def _run_reference_test(self, full_model_path: Path, reference_dir: Path):
+        """
+        Invokes the reference device test module to generate ground truth data.
+        """
+        cmd = [
+            sys.executable,
+            "-m",
+            "graph_net.torch.test_reference_device",
+            "--model-path",
+            str(full_model_path),
+            "--reference-dir",
+            str(reference_dir),
+            "--compiler",
+            self.config["compiler"],
+            "--device",
+            self.config["device"],
+        ]
+        # Reference tests are executed synchronously; output is captured but not returned
+        subprocess.run(cmd, check=True, capture_output=True, text=True)
+
+    def _run_target_test(
+        self, full_model_path: Path, reference_dir: Path, log_file: Path
+    ) -> str:
+        """
+        Invokes the target device test module for the specified op_lib and merges
+        stdout/stderr into the local log file.
+        """
+        cmd = [
+            sys.executable,
+            "-m",
+            "graph_net.torch.test_target_device",
+            "--model-path",
+            str(full_model_path),
+            "--reference-dir",
+            str(reference_dir),
+            "--device",
+            self.config["device"],
+            "--op-lib",
+            self.config["op_lib"],
+        ]
+
+        print(" ".join(cmd))
+        # Redirect all output to the log file for persistence and analysis
+        with log_file.open("w") as f:
+            starttime = time.time()
+            subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, check=True)
+            endtime = time.time()
+        print("run_target_test running time {:.5f} s".format(endtime - starttime))
+
+        return log_file.read_text()
diff --git a/graph_net/test/flaggems_fault_bisearcher_test.sh b/graph_net/test/flaggems_fault_bisearcher_test.sh
new file mode 100644
index 000000000..ce702c908
--- /dev/null
+++ b/graph_net/test/flaggems_fault_bisearcher_test.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Resolve the root directory of the project
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
+
+# Test Environment Setup
+MODEL_LIST="$GRAPH_NET_ROOT/graph_net/test/small10_torch_samples_list.txt"
+MODEL_PREFIX="$GRAPH_NET_ROOT"
+OUTPUT_DIR="/tmp/workspace_auto_fault_bisearcher"
+
+# Execute the SamplePass via the standard CLI entry point
+python3 -m graph_net.apply_sample_pass \
+    --model-path-list "$MODEL_LIST" \
+    --sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/auto_fault_bisearcher.py" \
+    --sample-pass-class-name AutoFaultBisearcher \
+    --sample-pass-config $(base64 -w 0 <