diff --git a/graph_net/paddle/extractor.py b/graph_net/paddle/extractor.py index 2d8e6c63d..2d4ff5619 100644 --- a/graph_net/paddle/extractor.py +++ b/graph_net/paddle/extractor.py @@ -174,29 +174,43 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict): paddle.set_flags(old_flags) return static_model - def translate_pir_program_to_sample_codes( - self, - model_dump_path, - split_positions=None, - group_head_and_tail=True, - ): + def get_ir_programs_path(self, model_dump_path): ir_programs_path = os.path.join(model_dump_path, "exec_programs.py") - example_inputs_path = os.path.join( - model_dump_path, "programs_example_input_tensor_meta.py" - ) assert os.path.isfile( ir_programs_path ), f"{ir_programs_path} is not a regular file." + return ir_programs_path + + def get_example_inputs_path(self, model_dump_path): + example_inputs_path = os.path.join( + model_dump_path, "programs_example_input_tensor_meta.py" + ) assert os.path.isfile( example_inputs_path ), f"{example_inputs_path} is not a regular file." + return example_inputs_path + def generate_op_example_inputs_path(self, model_dump_path, split_positions): # Arguments for graph decomposer op_example_inputs_path = ( os.path.join(model_dump_path, "op_example_input_tensor_meta.py") if split_positions else None ) + return op_example_inputs_path + + def translate_pir_program_to_sample_codes( + self, + model_dump_path, + split_positions=None, + group_head_and_tail=True, + ): + ir_programs_path = self.get_ir_programs_path(model_dump_path) + example_inputs_path = self.get_example_inputs_path(model_dump_path) + op_example_inputs_path = self.generate_op_example_inputs_path( + model_dump_path, split_positions + ) + all_samples = RunGeneration( model_name=self.name, ir_programs=ir_programs_path, diff --git a/graph_net/paddle/prologue_subgraph_unittest_generator.py b/graph_net/paddle/prologue_subgraph_unittest_generator.py new file mode 100644 index 000000000..227e117d5 --- /dev/null +++ b/graph_net/paddle/prologue_subgraph_unittest_generator.py @@ -0,0 +1,552 @@ +import os +import re +import sys +import subprocess +import ast +import inspect +import jinja2 +import textwrap +import tempfile +from pathlib import Path +from typing import Literal, List +from collections import namedtuple + +import paddle +from athena.graphnet_samples import SubgraphGenerator +from graph_net import imp_util +from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor +from graph_net.paddle.graph_meta_restorer import GraphMetaRestorer +from graph_net.tensor_meta import TensorMeta + + +class GraphExtractor: + def __init__( + self, + config: dict, + model, + name, + dynamic, + input_spec=None, + ): + self.model = model + self.name = name.replace("/", "_") + self.dynamic = dynamic + self.input_spec = input_spec + self.config = self.make_config(**config) + + def make_config( + self, + subgraph_range: list, + device: Literal["auto", "cpu", "cuda", "xpu"] = "auto", + tolerance: int = 0, + try_run: bool = False, + output_dir: str = "/tmp/prologue_unittests", + ): + assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2 + for pos in subgraph_range: + assert isinstance( + pos, int + ), f"subgraph_range should be list of int, {subgraph_range=}" + return { + "subgraph_range": subgraph_range, + "device": device, + "tolerance": tolerance, + "try_run": try_run, + "output_dir": output_dir, + } + + def __call__(self, **input_dict): + extracted_model = self.get_prologue_subgraph_unittest_generator()(**input_dict) + return extracted_model + + def get_prologue_subgraph_unittest_generator(self): + return PrologueSubgraphUnittestGenerator( + config=self.config, + parent_model=self.model, + parent_model_name=self.name, + parent_input_spec=self.input_spec, + ) + + +PADDLE_UNITTEST_TEMPLATE = r""" +import os +import sys +import argparse +import unittest +import random +import numpy as np +import paddle + + +{% macro get_input_tensor_instance(tensor_meta) -%} +{%- set shape = tensor_meta.shape -%} +{%- set dtype = tensor_meta.dtype -%} +{%- set data = tensor_meta.data -%} +{%- set min_val = tensor_meta.min_val -%} +{%- set max_val = tensor_meta.max_val -%} +{%- set mean = tensor_meta.mean -%} +{%- set std = tensor_meta.std -%} +{%- if data is not none -%} + paddle.reshape(paddle.to_tensor({{data}}, dtype='{{dtype}}'), shape={{shape}}).to(device=device) +{%- elif dtype == "bool" -%} + init_integer_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val=0, max_val=1) +{%- elif dtype in ["int8", "int16", "int32", "int64"] -%} + init_integer_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}) +{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%} + {%- if mean is not none or std is not none -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}, mean={{mean}}, std={{std}}) + {%- else -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}) + {%- endif -%} +{%- endif -%} +{%- endmacro -%} + + +def init_integer_tensor(shape, dtype, device, min_val, max_val): + array = np.random.randint( + low=min_val, high=max_val + 1, size=shape, dtype="int64" + ) + return paddle.to_tensor(array).to(dtype).to(device) + + +def init_float_tensor(shape, dtype, device, min_val, max_val, mean=None, std=None): + if mean is not None and std is not None: + array = np.random.normal(0, 1, shape) * std * 0.2 + mean + array = np.clip(array, min_val, max_val) + else: + array = np.random.uniform(low=min_val, high=max_val, size=shape) + return paddle.to_tensor(array).to(dtype).to(device) + + +class PrologueLayer(paddle.nn.Layer): +{{graph_module_desc.prologue_forward_func}} + + +class SuspectLayer(paddle.nn.Layer): +{{graph_module_desc.suspect_forward_func}} + + +class TestModel(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.prologue_layer = PrologueLayer() + self.suspect_layer = SuspectLayer() + + def forward(self, {{graph_module_desc.arg_names | join(", ")}}): + {{graph_module_desc.prologue_returns | join(", ")}} = self.prologue_layer({{graph_module_desc.prologue_arg_names | join(", ")}}) + {{graph_module_desc.suspect_returns | join(", ")}} = self.suspect_layer({{graph_module_desc.suspect_arg_names | join(", ")}}) + return ({{graph_module_desc.suspect_returns | join(", ")}},) + + +def get_input_dict(device): + input_dict = { + {%- for tensor_meta in graph_module_desc.tensor_metas %} + '{{tensor_meta.name}}': {{get_input_tensor_instance(tensor_meta)}}, + {%- endfor %} + } + return input_dict + + +def tolerance_generator(tolerance, dtype): + if dtype == paddle.float16: + return 10 ** (tolerance * 3 / 5), 10**tolerance + elif dtype == paddle.bfloat16: + return 10 ** (tolerance * 1.796 / 5), 10**tolerance + elif dtype == paddle.float32: + return 10 ** (tolerance * 5.886 / 5), 10**tolerance + elif dtype == paddle.float64: + return 10 ** (tolerance * 7 / 5), 10 ** (tolerance * 7 / 5) + else: + assert False, f"Unsupported {dtype=}." + + +class {{graph_module_desc.test_name}}Test(unittest.TestCase): + def setUp(self): + self.device = TEST_ARGS.device + self.is_reference = TEST_ARGS.is_reference + self.reference_dir = TEST_ARGS.reference_dir + self.tolerance = {{graph_module_desc.tolerance}} + + paddle.seed(123) + random.seed(123) + np.random.seed(123) + + self.input_dict = get_input_dict(self.device) + self.test_model = TestModel() + + def _flatten_outputs_to_list(self, outs): + flattened_outs = outs + if isinstance(outs, paddle.Tensor): + flattened_outs = [outs] + else: + flattened_outs = [ + x + for out in outs + for x in (out if isinstance(out, (tuple, list)) else (out,)) + ] + return flattened_outs + + def run_prologue_layer(self): + prologue_inputs = [ + {%- for arg_name in graph_module_desc.prologue_arg_names %} + self.input_dict['{{arg_name}}'], + {%- endfor %} + ] + prologue_outputs = self.test_model.prologue_layer(*prologue_inputs) + return self._flatten_outputs_to_list(prologue_outputs) + + def run_suspect_layer(self, prologue_outputs): + suspect_inputs = [ + {%- for arg_name in graph_module_desc.suspect_arg_names %} + {%- if arg_name not in graph_module_desc.prologue_returns %} + self.input_dict['{{arg_name}}'], + {%- else %} + {%- for output_name in graph_module_desc.prologue_returns %} + {%- if arg_name == output_name %} + prologue_outputs[{{loop.index0}}], + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endfor %} + ] + suspect_outputs = self.test_model.suspect_layer(*suspect_inputs) + return self._flatten_outputs_to_list(suspect_outputs) + + def run_test_model(self): + test_outputs = self.test_model(**self.input_dict) + return self._flatten_outputs_to_list(test_outputs) + + def check_dtypes(self, reference_outputs, target_outputs): + def _get_output_dtypes(outs): + dtypes = [ + str(tensor.dtype).replace("paddle.", "") + if isinstance(tensor, paddle.Tensor) + else None + for i, tensor in enumerate(outs) + ] + return dtypes + + reference_dtypes = _get_output_dtypes(reference_outputs) + target_dtypes = _get_output_dtypes(target_outputs) + dtype_match = all( + reference == target for reference, target in zip(reference_dtypes, target_dtypes) + ) + self.assertTrue(dtype_match, f"Data type of outputs are not matched ({reference_dtypes=} vs {target_dtypes}).") + + def check_shapes(self, reference_outputs, target_outputs): + def _get_output_shapes(outs): + shapes = [ + tensor.shape if isinstance(tensor, paddle.Tensor) else None + for i, tensor in enumerate(outs) + ] + return shapes + + reference_shapes = _get_output_shapes(reference_outputs) + target_shapes = _get_output_shapes(target_outputs) + shape_match = all( + reference == target for reference, target in zip(reference_shapes, target_shapes) + ) + self.assertTrue(shape_match, f"Shape of outputs are not matched ({reference_shapes=} vs {target_shapes}).") + + def check_results(self, reference_outputs, target_outputs): + def _convert_to_numpy(out): + if out.dtype in [paddle.float16, paddle.bfloat16]: + return out.cast("float32").numpy() + else: + return out.numpy() + + assert len(reference_outputs) == len(target_outputs), f"The number of outputs is not equal ({len(reference_outputs)=} vs {len(target_outputs)})." + self.check_dtypes(reference_outputs, target_outputs) + self.check_shapes(reference_outputs, target_outputs) + + for reference, target in zip(reference_outputs, target_outputs): + atol, rtol = tolerance_generator(self.tolerance, reference.dtype) + np.testing.assert_allclose(_convert_to_numpy(reference), _convert_to_numpy(target), atol, rtol) + + def test_separated(self): + prologue_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_separate_prologue.pdout") + if self.is_reference: + prologue_outputs = self.run_prologue_layer() + print(f"Save prologue output tensors to {prologue_output_path}.") + paddle.save(prologue_outputs, prologue_output_path) + else: + print(f"Load prologue output tensors from {prologue_output_path}") + prologue_outputs = paddle.load(prologue_output_path) + + test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_separate_reference.pdout") + test_outputs = self.run_suspect_layer(prologue_outputs) + if self.is_reference: + print(f"Save test output tensors to {test_output_path}.") + paddle.save(test_outputs, test_output_path) + else: + print(f"Load test output tensors on reference device from {test_output_path}.") + test_reference_outputs = paddle.load(test_output_path) + self.check_results(test_reference_outputs, test_outputs) + + def test_combined(self): + test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_combined_reference.pdout") + test_outputs = self.run_test_model() + if self.is_reference: + print(f"Save test output tensors to {test_output_path}.") + paddle.save(test_outputs, test_output_path) + else: + print(f"Load test output tensors on reference device from {test_output_path}.") + test_reference_outputs = paddle.load(test_output_path) + self.check_results(test_reference_outputs, test_outputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--is-reference", action="store_true", default=False) + parser.add_argument("--device", type=str, required=True) + parser.add_argument("--reference-dir", type=str, required=True) + args, remaining = parser.parse_known_args() + + global TEST_ARGS + TEST_ARGS = args + + unittest.main(argv=[sys.argv[0]] + remaining) +""" + + +GraphModuleDescriptor = namedtuple( + "GraphModuleDescriptor", + [ + "model_name", + "test_name", + "tolerance", + "arg_names", + "tensor_metas", + "prologue_arg_names", + "prologue_returns", + "prologue_forward_func", + "suspect_arg_names", + "suspect_returns", + "suspect_forward_func", + ], +) + + +def load_class_from_file(file_path: str, class_name: str): + print(f"Load {class_name} from {file_path}") + module = imp_util.load_module(file_path, "unnamed") + model_class = getattr(module, class_name, None) + return model_class + + +class PrologueSubgraphUnittestGenerator: + def __init__( + self, + config: dict, + parent_model: paddle.nn.Layer, + parent_model_name: str, + parent_input_spec: List[paddle.static.InputSpec], + ): + self.config = config + self.extracted = False + self.model_name = parent_model_name + self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__) + self.builtin_extractor = BuiltinGraphExtractor( + model=parent_model, + name=parent_model_name, + dynamic=False, + input_spec=parent_input_spec, + workspace_path=self.config["output_dir"], + ) + self.subgraph_range = self.config["subgraph_range"] + self.device = self._choose_device(self.config["device"]) + self.tolerance = self.config["tolerance"] + self.try_run = self.config["try_run"] + self.output_dir = self.config["output_dir"] + self.graph_meta_restorer = self._make_graph_meta_restorer() + + def __call__(self, **input_dict): + extracted_model = None + if not self.extracted: + extracted_model = self.do_extract(**input_dict) + self.extracted = True + return extracted_model + + def do_extract(self, **input_dict): + print( + f"[PrologueUnittestGenerator] Generate unittest for {self.parent_model_path} with subgraph_range={self.subgraph_range}" + ) + + # 1. Run the model to dump pir programs + model_dump_path = os.path.join( + self.builtin_extractor.dump_path, self.model_name + ) + static_model = self.builtin_extractor.run_model_with_dump_enabled( + model_dump_path, **input_dict + ) + + # 2. Convert pir programs to graphnet samples + ir_programs_path = self.builtin_extractor.get_ir_programs_path(model_dump_path) + example_inputs_path = self.builtin_extractor.get_example_inputs_path( + model_dump_path + ) + op_example_inputs_path = self.builtin_extractor.generate_op_example_inputs_path( + model_dump_path, self.subgraph_range + ) + subgraph_generator = SubgraphGenerator( + model_name=self.model_name, + programs_file=ir_programs_path, + example_inputs_file=example_inputs_path, + op_example_inputs_file=op_example_inputs_path, + eval_mode=True, + tmp_dir=model_dump_path, + ) + + # 3. Generate unittest + with tempfile.TemporaryDirectory(prefix="prologue_unittest_") as tmp_dir: + self.generate(subgraph_generator, tmp_dir) + return static_model + + def _save_and_get_graph_module(self, subgraph_generator, subgraph_range, tmp_dir): + results = subgraph_generator(subgraph_range, False) + assert len(results) == 1 + output_name = f"{subgraph_range[0]}_{subgraph_range[1]}" + output_path = os.path.join(tmp_dir, f"{self.model_name}-{output_name}") + self.builtin_extractor.write_sample_to_file(output_path, results[0]) + graph_module = load_class_from_file( + Path(output_path) / "model.py", class_name="GraphModule" + ) + return graph_module, output_path + + def generate(self, subgraph_generator, tmp_dir): + test_name = "".join( + word.capitalize() for word in re.split(r"[_.-]", self.model_name) + ) + + graph_module, output_path = self._save_and_get_graph_module( + subgraph_generator, self.subgraph_range, tmp_dir + ) + arg_names = self._get_forward_arg_names(graph_module) + self.graph_meta_restorer(output_path) + tensor_metas = self._get_tensor_metas(output_path) + + # prologue model information + prologue_subgraph_range = [self.subgraph_range[0], self.subgraph_range[1] - 1] + prologue_graph_module, _ = self._save_and_get_graph_module( + subgraph_generator, prologue_subgraph_range, tmp_dir + ) + prologue_forward_func, prologue_returns = self._get_forward_func_and_returns( + prologue_graph_module + ) + prologue_arg_names = self._get_forward_arg_names(prologue_graph_module) + + # suspect model information + suspect_subgraph_range = [self.subgraph_range[1] - 1, self.subgraph_range[1]] + suspect_graph_module, _ = self._save_and_get_graph_module( + subgraph_generator, suspect_subgraph_range, tmp_dir + ) + suspect_forward_func, suspect_returns = self._get_forward_func_and_returns( + suspect_graph_module + ) + suspect_arg_names = self._get_forward_arg_names(suspect_graph_module) + + def _generate_unittest(): + graph_module_desc = GraphModuleDescriptor( + model_name=self.model_name, + test_name=test_name, + tolerance=self.tolerance, + arg_names=arg_names, + tensor_metas=tensor_metas, + prologue_arg_names=prologue_arg_names, + prologue_returns=prologue_returns, + prologue_forward_func=prologue_forward_func, + suspect_arg_names=suspect_arg_names, + suspect_returns=suspect_returns, + suspect_forward_func=suspect_forward_func, + ) + return self._render_template(graph_module_desc) + + # Generate unittest with main for try-run. + unittest = _generate_unittest() + self._write_to_file(unittest, self.output_dir) + if self._try_to_run_unittest(unittest, tmp_dir): + self._write_to_file(unittest, self.output_dir) + + def _choose_device(self, device) -> str: + if device in ["cpu", "gpu", "xpu"]: + return device + return paddle.get_device() + + def _make_graph_meta_restorer(self): + config = { + "update_inplace": True, + "input_meta_allow_partial_update": False, + } + graph_meta_restorer = GraphMetaRestorer( + config=config, parent_model_path=self.parent_model_path + ) + return graph_meta_restorer + + def _write_to_file(self, unittest, output_dir): + output_path = Path(output_dir) / f"{self.model_name}_test.py" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(unittest, encoding="utf-8") + print(f"[PrologueUnittestGenerator] Generate unittest: {output_path}") + return output_path + + def _try_to_run_unittest(self, unittest, tmp_dir): + if not self.try_run: + return True + + output_path = self._write_to_file(unittest, tmp_dir) + result = subprocess.run( + [ + sys.executable, + output_path, + "--is-reference", + "--device", + self.device, + "--reference-dir", + os.path.join(tmp_dir, "test_reference_outputs"), + ], + check=True, + ) + return result.returncode == 0 + + def _get_tensor_metas(self, model_path): + tensor_metas = TensorMeta.unserialize_from_py_file_order_preserved( + Path(model_path) / "weight_meta.py" + ) + tensor_metas.extend( + TensorMeta.unserialize_from_py_file_order_preserved( + Path(model_path) / "input_meta.py" + ) + ) + return tensor_metas + + def _get_forward_arg_names(self, graph_module): + arg_names = [] + sig = inspect.signature(graph_module.forward) + for name, param in sig.parameters.items(): + if name == "self": + continue + arg_names.append(name) + return arg_names + + def _get_forward_func_and_returns(self, graph_module): + # the whole definition of forward function + source = inspect.getsource(graph_module.forward) + lines = source.splitlines() + num_indents = len(lines[0]) - len(lines[0].lstrip()) + + tree = ast.parse(textwrap.dedent(source)) + func_def = tree.body[0] + forward_func = textwrap.indent(ast.unparse(func_def), " " * num_indents) + + # the return statements + return_node_values = [ + node.value for node in ast.walk(func_def) if isinstance(node, ast.Return) + ] + assert len(return_node_values) == 1 and return_node_values[0] is not None + return_code = ast.unparse(return_node_values[0]) + return_names = [name.strip() for name in return_code.strip("()").split(",")] + return forward_func, return_names + + def _render_template(self, graph_module_desc): + template_str = PADDLE_UNITTEST_TEMPLATE + return jinja2.Template(template_str).render(graph_module_desc=graph_module_desc) diff --git a/graph_net/paddle/sample_passes/__init__.py b/graph_net/paddle/sample_pass/__init__.py similarity index 100% rename from graph_net/paddle/sample_passes/__init__.py rename to graph_net/paddle/sample_pass/__init__.py diff --git a/graph_net/sample_pass/agent_unittest_generator.py b/graph_net/sample_pass/agent_unittest_generator.py index d0c7ec661..f1636d533 100644 --- a/graph_net/sample_pass/agent_unittest_generator.py +++ b/graph_net/sample_pass/agent_unittest_generator.py @@ -116,7 +116,7 @@ def test_main(self): {%- set mean = tensor_meta.mean -%} {%- set std = tensor_meta.std -%} {%- if data is not none -%} - paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}') + paddle.reshape(paddle.to_tensor({{data}}, dtype='{{dtype}}'), shape={{shape}}).to(device='{{device}}') {%- elif dtype == "bool" -%} paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') {%- elif dtype in ["int8", "int16", "int32", "int64"] -%} @@ -204,8 +204,6 @@ def load_class_from_file(file_path: str, class_name: str): class AgentUnittestGenerator: - """Generate standalone unittest scripts for Torch samples.""" - def __init__( self, framework: str, @@ -432,7 +430,7 @@ def _render_template(self, graph_module_desc): class AgentUnittestGeneratorPass(SamplePass, ResumableSamplePassMixin): - """SamplePass wrapper to generate Torch unittests via model_path_handler.""" + """SamplePass to generate unittests in KernelBench format.""" def __init__(self, config=None): super().__init__(config) diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py index cd2a42796..cb222c111 100755 --- a/graph_net/subgraph_decompose_and_evaluation_step.py +++ b/graph_net/subgraph_decompose_and_evaluation_step.py @@ -418,6 +418,81 @@ def run_evaluation( ), f"[ERROR] test failed for {work_dir}, please check the log." +def generate_unittest_for_single_model( + framework, model_name, model_path, subgraph_range, tolerance, output_dir, log_path +): + graphnet_root = path_utils.get_graphnet_root() + decorator_config = { + "decorator_path": f"{graphnet_root}/graph_net/paddle/extractor.py", + "decorator_config": { + "name": model_name, + "custom_extractor_path": f"{graphnet_root}/graph_net/paddle/prologue_subgraph_unittest_generator.py", + "custom_extractor_config": { + "output_dir": output_dir, + "subgraph_range": subgraph_range, + "device": "auto", + "tolerance": tolerance, + "try_run": True, + }, + }, + } + + decorator_config_b64 = convert_json_to_b64_string(decorator_config) + + print( + f"[Unittest] model_path: {model_path}, subgraph_range: {subgraph_range}", + flush=True, + ) + cmd = [ + sys.executable, + "-m", + f"graph_net.{framework}.run_model", + "--model-path", + model_path, + "--decorator-config", + decorator_config_b64, + ] + with open(log_path, "a") as f: + result = subprocess.run(cmd, stdout=f, stderr=f, text=True) + assert result.returncode == 0 + + +def generate_unittest(decompose_config, pass_id, output_dir): + running_state = decompose_config.get_running_state(pass_id) + + unittest_dir = os.path.join(output_dir, "unittests") + log_path = os.path.join(output_dir, "log_unittest_generation.txt") + print(f"[Unittest] log_path: {log_path}", flush=True) + for model_name, model_record in running_state.model_name2record.items(): + if not model_record.incorrect_subgraph_idxs: + continue + + original_path = model_record.original_path + subgraph_idx = model_record.incorrect_subgraph_idxs[0] + if decompose_config.decompose_method == "fixed-start": + subgraph_range = [0, model_record.uniform_split_positions[subgraph_idx + 1]] + else: + subgraph_range = [ + model_record.uniform_split_positions[subgraph_idx], + model_record.uniform_split_positions[subgraph_idx + 1], + ] + + rectified_model_path = get_rectfied_model_path(original_path) + assert os.path.exists( + rectified_model_path + ), f"{rectified_model_path} does not exist." + + generate_unittest_for_single_model( + decompose_config.framework, + model_name, + rectified_model_path, + subgraph_range, + decompose_config.tolerance[0], + unittest_dir, + log_path, + ) + + def reconstruct_split_positions_for_subgraphs( split_positions, subgraph_idxs, max_subgraph_size ): @@ -496,6 +571,8 @@ def generate_successor_tasks(args, output_dir, pass_id): max_subgraph_size=max_subgraph_size, running_states=prev_config.running_states, ) + if max_subgraph_size <= 0: + return decompose_config prev_running_state = prev_config.get_running_state(pass_id - 1) assert prev_running_state is not None @@ -538,14 +615,17 @@ def prepare_tasks_and_verify(args, pass_id, output_dir): print_incorrect_models(decompose_config, pass_id - 1, log_prompt="[Init]") if not decompose_config.get_incorrect_models(pass_id - 1): - print("[FINISHED] No models need processing.", flush=True) + print( + f"\n[Conclusion] No incorrect models after {pass_id - 1} steps.", flush=True + ) sys.exit(0) if decompose_config.max_subgraph_size <= 0: print( - f"[FINISHED] Cannot decompose with max_subgraph_size {decompose_config.max_subgraph_size}.", + f"\n[Conclusion] Decomposition has reached the minimal granularity (max_subgraph_size = 1) after {pass_id - 1} steps.", flush=True, ) + generate_unittest(decompose_config, pass_id - 1, output_dir) sys.exit(0) return decompose_config @@ -648,15 +728,14 @@ def print_summary_and_suggestion(decompose_config, pass_id): flush=True, ) print( - ">>> Please start next round decomposition test (Run this script again).", + ">>> Please start next decomposition step (Run this script again).", flush=True, ) - elif num_incorrect_models > 0 and decompose_config.max_subgraph_size <= 1: + elif decompose_config.max_subgraph_size <= 1: print( - ">>> [FAILURE] Minimal granularity reached, but errors persist.", flush=True + ">>> [Conclusion] Decomposition has reached the minimal granularity (max_subgraph_size = 1) after {pass_id - 1} steps.", + flush=True, ) - else: - print(">>> [SUCCESS] Debugging converged.", flush=True) print("=" * 80, flush=True) @@ -693,9 +772,7 @@ def main(args): # --- Step 4: Analysis --- if task_controller.task_scheduler["post_analysis"]: - tolerance = ( - args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance - ) + tolerance = args.tolerance[0] print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---") next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path)) decompose_config.update_running_state_with_incorrect_models( diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py index d0416c524..cd1e03a0f 100755 --- a/graph_net/tensor_meta.py +++ b/graph_net/tensor_meta.py @@ -1,3 +1,4 @@ +import ast import importlib.util as imp import inspect from dataclasses import dataclass @@ -40,6 +41,26 @@ def unserialize_from_py_file(cls, file_path: str) -> list["TensorMeta"]: for attrs in [cls._convert_cls_to_attrs(tensor_meta_cls)] ] + @classmethod + def unserialize_from_py_file_order_preserved(cls, file_path) -> list["TensorMeta"]: + return [ + TensorMeta( + record_class_name=attrs.get("record_class_name"), + name=attrs.get("name"), + original_name=attrs.get("original_name", None), + shape=attrs.get("shape", []), + dtype=attrs.get("dtype"), + device=attrs.get("device", None), + mean=attrs.get("mean", None), + std=attrs.get("std", None), + data=attrs.get("data", None), + max_val=attrs.get("max_val", None), + min_val=attrs.get("min_val", None), + ) + for name, tensor_meta_cls in cls._get_classes_order_preserved(file_path) + for attrs in [cls._convert_cls_to_attrs(tensor_meta_cls)] + ] + @classmethod def _convert_cls_to_attrs(cls, tensor_meta_cls): attrs = { @@ -52,11 +73,25 @@ def _convert_cls_to_attrs(cls, tensor_meta_cls): @classmethod def _get_classes(cls, file_path, name="unnamed"): - spec = imp.spec_from_file_location("unnamed", file_path) + spec = imp.spec_from_file_location(name, file_path) unnamed = imp.module_from_spec(spec) spec.loader.exec_module(unnamed) yield from inspect.getmembers(unnamed, inspect.isclass) + @classmethod + def _get_classes_order_preserved(cls, file_path, name="unnamed"): + with open(file_path, "r", encoding="utf-8") as f: + tree = ast.parse(f.read(), filename=file_path) + + class_names = [ + node.name for node in tree.body if isinstance(node, ast.ClassDef) + ] + + spec = imp.spec_from_file_location(name, file_path) + unnamed = imp.module_from_spec(spec) + spec.loader.exec_module(unnamed) + yield from [(name, getattr(unnamed, name)) for name in class_names] + def update_shape_safely(self, shape): self.shape = shape if self.data is None: diff --git a/graph_net/test/prologue_subgraph_unittest_generator_test.sh b/graph_net/test/prologue_subgraph_unittest_generator_test.sh new file mode 100644 index 000000000..23bb4ca3f --- /dev/null +++ b/graph_net/test/prologue_subgraph_unittest_generator_test.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") + +MODEL_PATH=paddle_samples/PaddleX/ResNet18 +decorator_config_json_str=$(cat <