From c86ff763f253cae25233f9d1f8e970704976b548 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Sun, 4 Jan 2026 15:07:26 +0800 Subject: [PATCH 1/6] Implement PrologueUnittestGenerator. --- graph_net/paddle/extractor.py | 32 +- .../paddle/prologue_unittest_generator.py | 409 ++++++++++++++++++ .../__init__.py | 0 .../sample_pass/agent_unittest_generator.py | 4 +- 4 files changed, 433 insertions(+), 12 deletions(-) create mode 100644 graph_net/paddle/prologue_unittest_generator.py rename graph_net/paddle/{sample_passes => sample_pass}/__init__.py (100%) diff --git a/graph_net/paddle/extractor.py b/graph_net/paddle/extractor.py index 2d8e6c63d..2d4ff5619 100644 --- a/graph_net/paddle/extractor.py +++ b/graph_net/paddle/extractor.py @@ -174,29 +174,43 @@ def run_model_with_dump_enabled(self, model_dump_path, **input_dict): paddle.set_flags(old_flags) return static_model - def translate_pir_program_to_sample_codes( - self, - model_dump_path, - split_positions=None, - group_head_and_tail=True, - ): + def get_ir_programs_path(self, model_dump_path): ir_programs_path = os.path.join(model_dump_path, "exec_programs.py") - example_inputs_path = os.path.join( - model_dump_path, "programs_example_input_tensor_meta.py" - ) assert os.path.isfile( ir_programs_path ), f"{ir_programs_path} is not a regular file." + return ir_programs_path + + def get_example_inputs_path(self, model_dump_path): + example_inputs_path = os.path.join( + model_dump_path, "programs_example_input_tensor_meta.py" + ) assert os.path.isfile( example_inputs_path ), f"{example_inputs_path} is not a regular file." + return example_inputs_path + def generate_op_example_inputs_path(self, model_dump_path, split_positions): # Arguments for graph decomposer op_example_inputs_path = ( os.path.join(model_dump_path, "op_example_input_tensor_meta.py") if split_positions else None ) + return op_example_inputs_path + + def translate_pir_program_to_sample_codes( + self, + model_dump_path, + split_positions=None, + group_head_and_tail=True, + ): + ir_programs_path = self.get_ir_programs_path(model_dump_path) + example_inputs_path = self.get_example_inputs_path(model_dump_path) + op_example_inputs_path = self.generate_op_example_inputs_path( + model_dump_path, split_positions + ) + all_samples = RunGeneration( model_name=self.name, ir_programs=ir_programs_path, diff --git a/graph_net/paddle/prologue_unittest_generator.py b/graph_net/paddle/prologue_unittest_generator.py new file mode 100644 index 000000000..6f5415482 --- /dev/null +++ b/graph_net/paddle/prologue_unittest_generator.py @@ -0,0 +1,409 @@ +import os +import re +import sys +import subprocess +import ast +import inspect +import jinja2 +import textwrap +import tempfile +from pathlib import Path +from typing import Literal, List +from collections import namedtuple + +import paddle +from athena.graphnet_samples import SubgraphGenerator +from graph_net import imp_util +from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor +from graph_net.tensor_meta import TensorMeta + + +class GraphExtractor: + def __init__( + self, + config: dict, + model, + name, + dynamic, + input_spec=None, + ): + self.model = model + self.name = name.replace("/", "_") + self.dynamic = dynamic + self.input_spec = input_spec + self.config = self.make_config(**config) + + def make_config( + self, + subgraph_range: list, + device: Literal["auto", "cpu", "cuda", "xpu"] = "auto", + try_run: bool = False, + data_input_predicator_filepath: str = None, + data_input_predicator_class_name: str = None, + 
output_dir: str = "/tmp/prologue_unittests", + ): + assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2 + for pos in subgraph_range: + assert isinstance( + pos, int + ), f"subgraph_range should be list of int, {subgraph_range=}" + return { + "subgraph_range": subgraph_range, + "device": device, + "try_run": try_run, + "data_input_predicator_filepath": data_input_predicator_filepath, + "data_input_predicator_class_name": data_input_predicator_class_name, + "output_dir": output_dir, + } + + def __call__(self, **input_dict): + extracted_model = self.get_prologue_subgraph_unittest_generator()(**input_dict) + return extracted_model + + def get_prologue_subgraph_unittest_generator(self): + return PrologueSubgraphUnittestGenerator( + config=self.config, + parent_model=self.model, + parent_model_name=self.name, + parent_input_spec=self.input_spec, + ) + + +PADDLE_UNITTEST_TEMPLATE = r""" +import unittest +import numpy as np +import paddle + +{% macro get_input_tensor_instance(tensor_meta, device) -%} +{%- set shape = tensor_meta.shape -%} +{%- set dtype = tensor_meta.dtype -%} +{%- set data = tensor_meta.data -%} +{%- set min_val = tensor_meta.min_val -%} +{%- set max_val = tensor_meta.max_val -%} +{%- set mean = tensor_meta.mean -%} +{%- set std = tensor_meta.std -%} +{%- if data is not none -%} + paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}') +{%- elif dtype == "bool" -%} + paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') +{%- elif dtype in ["int8", "int16", "int32", "int64"] -%} + paddle.randint(low={{min_val}}, high={{max_val}} + 1, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') +{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%} + {%- if mean is not none or std is not none -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', max_val={{max_val}}, min_val={{min_val}}, mean={{mean}}, std={{std}}) + {%- else -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', max_val={{max_val}}, min_val={{min_val}}) + {%- endif -%} +{%- endif -%} +{%- endmacro -%} + + +def init_float_tensor(shape, dtype, max_val, min_val, mean=None, std=None): + if mean is not None and std is not None: + tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean + tensor = paddle.clip(tensor, min=min_val, max=max_val) + else: + tensor = paddle.uniform(shape=shape, dtype="float32", min=min_val, max=max_val) + return tensor.to(dtype).to('{{graph_module_desc.device}}') + + +class PrologueLayer(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, {{graph_module_desc.prologue_weight_arg_names | join(", ")}}, {{graph_module_desc.prologue_input_arg_names | join(", ")}}): + {{graph_module_desc.prologue_forward_body}} + + +class Model(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, {{graph_module_desc.weight_arg_names | join(", ")}}, {{graph_module_desc.input_arg_names | join(", ")}}): + {{graph_module_desc.forward_body}} + + +def get_inputs(): + {%- for arg_name in graph_module_desc.input_arg_names %} + {%- set input_idx = loop.index0 %} + {{arg_name}} = {{get_input_tensor_instance(graph_module_desc.input_tensor_metas[input_idx], graph_module_desc.device)}} + {%- endfor %} + return [{{graph_module_desc.input_arg_names | join(", ")}}] + + +class {{graph_module_desc.model_name}}Test(unittest.TestCase): + def setUp(self): + paddle.seed(123) + self.model = Model() + + def test_main(self): + inputs = get_inputs() + outputs 
= self.model(*inputs) + + +if __name__ == "__main__": + unittest.main() +""" + + +GraphModuleDescriptor = namedtuple( + "GraphModuleDescriptor", + [ + "device", + "model_name", + "input_arg_names", + "input_tensor_metas", + "weight_arg_names", + "weight_tensor_metas", + "forward_body", + ], +) + + +def load_class_from_file(file_path: str, class_name: str): + print(f"Load {class_name} from {file_path}") + module = imp_util.load_module(file_path, "unnamed") + model_class = getattr(module, class_name, None) + return model_class + + +class PrologueSubgraphUnittestGenerator: + def __init__( + self, + config: dict, + parent_model: paddle.nn.Layer, + parent_model_name: str, + parent_input_spec: List[paddle.static.InputSpec], + ): + self.config = config + self.extracted = False + self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__) + self.builtin_extractor = BuiltinGraphExtractor( + model=parent_model, + name=parent_model_name, + dynamic=False, + input_spec=parent_input_spec, + workspace_path=self.config["output_dir"], + ) + self.subgraph_range = self.config["subgraph_range"] + self.device = self._choose_device(self.config["device"]) + self.try_run = self.config["try_run"] + self.data_input_predicator = self._make_data_input_predicator( + self.config["data_input_predicator_filepath"], + self.config["data_input_predicator_class_name"], + ) + + def __call__(self, **input_dict): + extracted_model = None + if not self.extracted: + extracted_model = self.do_extract(**input_dict) + self.extracted = True + return extracted_model + + def do_extract(self, **input_dict): + # 1. Run the model to dump pir programs + model_dump_path = os.path.join( + self.builtin_extractor.dump_path, self.builtin_extractor.name + ) + static_model = self.builtin_extractor.run_model_with_dump_enabled( + model_dump_path, **input_dict + ) + + # 2. 
Convert pir programs to graphnet samples + ir_programs_path = self.builtin_extractor.get_ir_programs_path(model_dump_path) + example_inputs_path = self.builtin_extractor.get_example_inputs_path( + model_dump_path + ) + op_example_inputs_path = self.builtin_extractor.generate_op_example_inputs_path( + model_dump_path, self.subgraph_range + ) + generator = SubgraphGenerator( + model_name=self.builtin_extractor.name, + programs_file=ir_programs_path, + example_inputs_file=example_inputs_path, + op_example_inputs_file=op_example_inputs_path, + eval_mode=True, + tmp_dir=model_dump_path, + ) + graphnet_sample_results = generator(self.subgraph_range, False) + assert len(graphnet_sample_results) == 1 + + return static_model + + def generate(self): + print(f"[PrologueUnittestGenerator] Generate unittest for {self.model_path}") + model_name = "".join( + word.capitalize() for word in re.split(r"[_.-]", self.model_path.name) + ) + graph_module = load_class_from_file( + self.model_path / "model.py", class_name="GraphModule" + ) + input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names( + graph_module + ) + ( + input_tensor_metas, + weight_tensor_metas, + ) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names) + + def _generate_unittest(): + graph_module_desc = GraphModuleDescriptor( + device=self.device, + model_name=model_name, + input_arg_names=input_arg_names, + input_tensor_metas=input_tensor_metas, + weight_arg_names=weight_arg_names, + weight_tensor_metas=weight_tensor_metas, + forward_body=self._get_forward_body( + graph_module, input_arg_names, weight_arg_names + ), + ) + return self._render_template(graph_module_desc) + + # Generate unittest with main for try-run. + unittest = _generate_unittest() + self._write_to_file(unittest, self.output_dir) + if self._try_to_run_unittest(unittest): + self._write_to_file(unittest, self.output_dir) + + def _choose_device(self, device) -> str: + import paddle + + if device in ["cpu", "gpu", "xpu"]: + return device + return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu" + + def _make_data_input_predicator( + self, data_input_predicator_filepath, data_input_predicator_class_name + ): + if data_input_predicator_filepath and data_input_predicator_class_name: + module = imp_util.load_module(data_input_predicator_filepath) + cls = getattr(module, data_input_predicator_class_name) + return cls(config={}) + return lambda *args, **kwargs: True + + def _write_to_file(self, unittest, output_dir): + output_path = Path(output_dir) / self.output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(unittest, encoding="utf-8") + print( + f"[AgentUnittestGenerator] Generate unittest: {output_path} (device={self.device})" + ) + return output_path + + def _try_to_run_unittest(self, unittest): + if not self.try_run: + return True + + with tempfile.TemporaryDirectory(prefix="unittest_") as temp_dir: + output_path = self._write_to_file(unittest, temp_dir) + result = subprocess.run( + [sys.executable, output_path], + check=True, + ) + return result.returncode == 0 + + def _get_input_and_weight_arg_names(self, graph_module): + input_arg_names = [] + weight_arg_names = [] + sig = inspect.signature(graph_module.forward) + for name, param in sig.parameters.items(): + if name == "self": + continue + is_not_data_input = not self.data_input_predicator(self.model_path, name) + if is_not_data_input: + weight_arg_names.append(name) + else: + input_arg_names.append(name) + return input_arg_names, 
weight_arg_names + + def _get_input_and_weight_tensor_metas(self, input_arg_names, weight_arg_names): + tensor_metas = TensorMeta.unserialize_from_py_file( + self.model_path / "weight_meta.py" + ) + tensor_metas.extend( + TensorMeta.unserialize_from_py_file(self.model_path / "input_meta.py") + ) + name2tensor_metas = {meta.name: meta for meta in tensor_metas} + input_tensor_metas = [name2tensor_metas[name] for name in input_arg_names] + weight_tensor_metas = [name2tensor_metas[name] for name in weight_arg_names] + return input_tensor_metas, weight_tensor_metas + + def _get_forward_body(self, graph_module, input_arg_names, weight_arg_names): + def _remove_clear_stmt_of_args(stmt): + def _need_remove(target): + return isinstance(target, ast.Name) and target.id in arg_names + + arg_names = input_arg_names + weight_arg_names + if ( + isinstance(stmt, ast.Assign) + and isinstance(stmt.value, ast.Constant) + and stmt.value.value is None + ): + # remove stmt like w_0 = None + new_targets = [t for t in stmt.targets if not _need_remove(t)] + if not new_targets: + return None + stmt.targets = new_targets + elif isinstance(stmt, ast.Delete): + # remove stmt like del w_0 + new_targets = [] + for t in stmt.targets: + if isinstance(t, ast.Tuple): + kept = [e for e in t.elts if not _need_remove(e)] + if kept: + new_targets.append(ast.Tuple(elts=kept, ctx=ast.Del())) + elif not _need_remove(t): + new_targets.append(t) + if not new_targets: + return None + stmt.targets = new_targets + return stmt + + def _rewrite_reference_for_weight(stmt): + if isinstance(stmt, ast.Name): + if isinstance(stmt.ctx, ast.Load) and stmt.id in weight_arg_names: + return ast.Attribute( + value=ast.Name(id="self", ctx=ast.Load()), + attr=stmt.id, + ctx=ast.Load(), + ) + return stmt + + for field, value in ast.iter_fields(stmt): + if isinstance(value, list): + new_list = [] + for item in value: + if isinstance(item, ast.AST): + item = _rewrite_reference_for_weight(item) + new_list.append(item) + setattr(stmt, field, new_list) + elif isinstance(value, ast.AST): + setattr(stmt, field, _rewrite_reference_for_weight(value)) + return stmt + + def _update_for_weight(stmt): + stmt = _remove_clear_stmt_of_args(stmt) + if stmt is not None and weight_arg_names: + stmt = _rewrite_reference_for_weight(stmt) + ast.fix_missing_locations(stmt) + return stmt + + source = inspect.getsource(graph_module.forward) + lines = source.splitlines() + num_indents = len(lines[-1]) - len(lines[-1].lstrip()) + + tree = ast.parse(textwrap.dedent(source)) + func_def = tree.body[0] + dedented_stmts = [ + ast.unparse(s) + for stmt in func_def.body + if (s := _update_for_weight(stmt)) is not None + ] + + indent = " " * num_indents + return f"\n{indent}".join(dedented_stmts) + + def _render_template(self, graph_module_desc): + template_str = PADDLE_UNITTEST_TEMPLATE + return jinja2.Template(template_str).render(graph_module_desc=graph_module_desc) diff --git a/graph_net/paddle/sample_passes/__init__.py b/graph_net/paddle/sample_pass/__init__.py similarity index 100% rename from graph_net/paddle/sample_passes/__init__.py rename to graph_net/paddle/sample_pass/__init__.py diff --git a/graph_net/sample_pass/agent_unittest_generator.py b/graph_net/sample_pass/agent_unittest_generator.py index d0c7ec661..3f314924f 100644 --- a/graph_net/sample_pass/agent_unittest_generator.py +++ b/graph_net/sample_pass/agent_unittest_generator.py @@ -204,8 +204,6 @@ def load_class_from_file(file_path: str, class_name: str): class AgentUnittestGenerator: - """Generate standalone 
unittest scripts for Torch samples.""" - def __init__( self, framework: str, @@ -432,7 +430,7 @@ def _render_template(self, graph_module_desc): class AgentUnittestGeneratorPass(SamplePass, ResumableSamplePassMixin): - """SamplePass wrapper to generate Torch unittests via model_path_handler.""" + """SamplePass to generate unittests in KernelBench format.""" def __init__(self, config=None): super().__init__(config) From cbdc80a921c89c9b57512b3e88f0c2db698c9b16 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 5 Jan 2026 10:02:28 +0800 Subject: [PATCH 2/6] Generate correct prologue unittest. --- .../prologue_subgraph_unittest_generator.py | 514 ++++++++++++++++++ .../paddle/prologue_unittest_generator.py | 409 -------------- .../sample_pass/agent_unittest_generator.py | 2 +- graph_net/tensor_meta.py | 37 +- .../test/prologue_unittest_generator_test.sh | 24 + 5 files changed, 575 insertions(+), 411 deletions(-) create mode 100644 graph_net/paddle/prologue_subgraph_unittest_generator.py delete mode 100644 graph_net/paddle/prologue_unittest_generator.py create mode 100644 graph_net/test/prologue_unittest_generator_test.sh diff --git a/graph_net/paddle/prologue_subgraph_unittest_generator.py b/graph_net/paddle/prologue_subgraph_unittest_generator.py new file mode 100644 index 000000000..6fd43534c --- /dev/null +++ b/graph_net/paddle/prologue_subgraph_unittest_generator.py @@ -0,0 +1,514 @@ +import os +import re +import sys +import subprocess +import ast +import inspect +import jinja2 +import textwrap +import tempfile +from pathlib import Path +from typing import Literal, List +from collections import namedtuple + +import paddle +from athena.graphnet_samples import SubgraphGenerator +from graph_net import imp_util +from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor +from graph_net.paddle.graph_meta_restorer import GraphMetaRestorer +from graph_net.tensor_meta import TensorMeta + + +class GraphExtractor: + def __init__( + self, + config: dict, + model, + name, + dynamic, + input_spec=None, + ): + self.model = model + self.name = name.replace("/", "_") + self.dynamic = dynamic + self.input_spec = input_spec + self.config = self.make_config(**config) + + def make_config( + self, + subgraph_range: list, + device: Literal["auto", "cpu", "cuda", "xpu"] = "auto", + try_run: bool = False, + output_dir: str = "/tmp/prologue_unittests", + ): + assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2 + for pos in subgraph_range: + assert isinstance( + pos, int + ), f"subgraph_range should be list of int, {subgraph_range=}" + return { + "subgraph_range": subgraph_range, + "device": device, + "try_run": try_run, + "output_dir": output_dir, + } + + def __call__(self, **input_dict): + extracted_model = self.get_prologue_subgraph_unittest_generator()(**input_dict) + return extracted_model + + def get_prologue_subgraph_unittest_generator(self): + return PrologueSubgraphUnittestGenerator( + config=self.config, + parent_model=self.model, + parent_model_name=self.name, + parent_input_spec=self.input_spec, + ) + + +PADDLE_UNITTEST_TEMPLATE = r""" +import os +import sys +import argparse +import unittest +import random +import numpy as np +import paddle + + +{% macro get_input_tensor_instance(tensor_meta) -%} +{%- set shape = tensor_meta.shape -%} +{%- set dtype = tensor_meta.dtype -%} +{%- set data = tensor_meta.data -%} +{%- set min_val = tensor_meta.min_val -%} +{%- set max_val = tensor_meta.max_val -%} +{%- set mean = tensor_meta.mean -%} +{%- set std = 
tensor_meta.std -%} +{%- if data is not none -%} + paddle.reshape(paddle.to_tensor({{data}}, dtype='{{dtype}}'), shape={{shape}}).to(device=device) +{%- elif dtype == "bool" -%} + init_integer_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val=0, max_val=1) +{%- elif dtype in ["int8", "int16", "int32", "int64"] -%} + init_integer_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}) +{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%} + {%- if mean is not none or std is not none -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}, mean={{mean}}, std={{std}}) + {%- else -%} + init_float_tensor(shape={{shape}}, dtype='{{dtype}}', device=device, min_val={{min_val}}, max_val={{max_val}}) + {%- endif -%} +{%- endif -%} +{%- endmacro -%} + + +def init_integer_tensor(shape, dtype, device, min_val, max_val): + array = np.random.randint( + low=min_val, high=max_val + 1, size=shape, dtype="int64" + ) + return paddle.to_tensor(array).to(dtype).to(device) + + +def init_float_tensor(shape, dtype, device, min_val, max_val, mean=None, std=None): + if mean is not None and std is not None: + array = np.random.normal(0, 1, shape) * std * 0.2 + mean + array = np.clip(array, min_val, max_val) + else: + array = np.random.uniform(low=min_val, high=max_val, size=shape) + return paddle.to_tensor(array).to(dtype).to(device) + + +class PrologueModel(paddle.nn.Layer): +{{graph_module_desc.prologue_forward_func}} + + +class TestModel(paddle.nn.Layer): +{{graph_module_desc.test_forward_func}} + + +def get_input_dict(device): + input_dict = { + {%- for tensor_meta in graph_module_desc.tensor_metas %} + '{{tensor_meta.name}}': {{get_input_tensor_instance(tensor_meta)}}, + {%- endfor %} + } + return input_dict + + +def tolerance_generator(tolerance, dtype): + if dtype == paddle.float16: + yield 10 ** (tolerance * 3 / 5), 10**tolerance + elif dtype == paddle.bfloat16: + yield 10 ** (tolerance * 1.796 / 5), 10**tolerance + elif dtype == paddle.float32: + yield 10 ** (tolerance * 5.886 / 5), 10**tolerance + elif dtype == paddle.float64: + yield 10 ** (tolerance * 7 / 5), 10 ** (tolerance * 7 / 5) + else: + assert False, f"Unsupported {dtype=}." 
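+
+# Worked example of the (atol, rtol) pairs produced above, for tolerance = -5:
+#   float16  -> (1e-3, 1e-5)      bfloat16 -> (~1.6e-2, 1e-5)
+#   float32  -> (~1.3e-6, 1e-5)   float64  -> (1e-7, 1e-7)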
+ + +class {{graph_module_desc.test_name}}Test(unittest.TestCase): + def setUp(self): + self.device = TEST_ARGS.device + self.is_reference = TEST_ARGS.is_reference + self.reference_dir = TEST_ARGS.reference_dir + self.tolerance = TEST_ARGS.tolerance + + paddle.seed(123) + random.seed(123) + np.random.seed(123) + + self.input_dict = get_input_dict(self.device) + self.prologue_model = PrologueModel() + self.test_model = TestModel() + + def _flatten_outputs_to_list(self, outs): + flattened_outs = outs + if isinstance(outs, paddle.Tensor): + flattened_outs = [outs] + else: + flattened_outs = [ + x + for out in outs + for x in (out if isinstance(out, (tuple, list)) else (out,)) + ] + return flattened_outs + + def run_prologue_model(self): + prologue_inputs = [ + {%- for arg_name in graph_module_desc.prologue_arg_names %} + self.input_dict['{{arg_name}}'], + {%- endfor %} + ] + prologue_outputs = self.prologue_model(*prologue_inputs) + return self._flatten_outputs_to_list(prologue_outputs) + + def run_test_model(self, prologue_outputs): + test_inputs = [ + {%- for arg_name in graph_module_desc.test_arg_names %} + {%- if arg_name not in graph_module_desc.prologue_returns %} + self.input_dict['{{arg_name}}'], + {%- else %} + {%- for output_name in graph_module_desc.prologue_returns %} + {%- if arg_name == output_name %} + prologue_outputs[{{loop.index0}}], + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endfor %} + ] + test_outputs = self.test_model(*test_inputs) + return self._flatten_outputs_to_list(test_outputs) + + def check_dtypes(self, reference_outputs, target_outputs): + def _get_output_dtypes(outs): + dtypes = [ + str(tensor.dtype).replace("paddle.", "") + if isinstance(tensor, paddle.Tensor) + else None + for i, tensor in enumerate(outs) + ] + return dtypes + + reference_dtypes = _get_output_dtypes(reference_outputs) + target_dtypes = _get_output_dtypes(target_outputs) + dtype_match = all( + reference == target for reference, target in zip(reference_dtypes, target_dtypes) + ) + self.assertTrue(dtype_match, f"Data type of outputs are not matched ({reference_dtypes=} vs {target_dtypes}).") + + def check_shapes(self, reference_outputs, target_outputs): + def _get_output_shapes(self, outs): + shapes = [ + tensor.shape if isinstance(tensor, paddle.Tensor) else None + for i, tensor in enumerate(outs) + ] + return shapes + + reference_shapes = _get_output_shapes(reference_outputs) + target_shapes = _get_output_shapes(target_outputs) + shape_match = all( + reference == target for reference, target in zip(reference_shapes, target_shapes) + ) + self.assertTrue(shape_match, f"Shape of outputs are not matched ({reference_shapes=} vs {target_shapes}).") + + def check_results(self, reference_outputs, target_outputs): + def _convert_to_numpy(out): + if out.dtype in [paddle.float16, paddle.bfloat16]: + return out.cast("float32").numpy() + else: + return out.numpy() + + assert len(reference_outputs) == len(target_outputs), f"The number of outputs is not equal ({len(reference_outputs)=} vs {len(target_outputs)})." 
+ self.check_dtypes(reference_outputs, target_outputs) + self.check_shapes(reference_outputs, target_outputs) + + for reference, target in zip(reference_outputs, target_outputs): + atol, rtol = tolerance_generator(self.tolerance, reference.dtype) + np.testing.assert_allclose(_convert_to_numpy(reference), _convert_to_numpy(target), atol, rtol) + + def test_main(self): + prologue_output_path = os.path.join(self.reference_dir, "prologue.pdout") + if self.is_reference: + prologue_outputs = self.run_prologue_model() + print(f"Save prologue output tensors to {prologue_output_path}.") + paddle.save(prologue_outputs, prologue_output_path) + else: + print(f"Load prologue output tensors from {prologue_output_path}") + prologue_outputs = paddle.load(prologue_output_path) + + test_output_path = os.path.join(self.reference_dir, "test_reference.pdout") + test_outputs = self.run_test_model(prologue_outputs) + if self.is_reference: + print(f"Save test output tensors to {test_output_path}.") + paddle.save(test_outputs, test_output_path) + else: + print(f"Load test output tensors on reference device from {test_output_path}.") + test_reference_outputs = paddle.load(test_reference_outputs) + self.check_results(test_reference_outputs, test_outputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--is-reference", action="store_true", default=False) + parser.add_argument("--device", type=str, required=True) + parser.add_argument("--reference-dir", type=str, required=True) + parser.add_argument("--tolerance", type=int, choices=range(-10, 5), default=0) + args, remaining = parser.parse_known_args() + + global TEST_ARGS + TEST_ARGS = args + + unittest.main(argv=[sys.argv[0]] + remaining) +""" + + +GraphModuleDescriptor = namedtuple( + "GraphModuleDescriptor", + [ + "test_name", + "tensor_metas", + "prologue_arg_names", + "prologue_returns", + "prologue_forward_func", + "test_arg_names", + "test_forward_func", + ], +) + + +def load_class_from_file(file_path: str, class_name: str): + print(f"Load {class_name} from {file_path}") + module = imp_util.load_module(file_path, "unnamed") + model_class = getattr(module, class_name, None) + return model_class + + +class PrologueSubgraphUnittestGenerator: + def __init__( + self, + config: dict, + parent_model: paddle.nn.Layer, + parent_model_name: str, + parent_input_spec: List[paddle.static.InputSpec], + ): + self.config = config + self.extracted = False + self.model_name = parent_model_name + self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__) + self.builtin_extractor = BuiltinGraphExtractor( + model=parent_model, + name=parent_model_name, + dynamic=False, + input_spec=parent_input_spec, + workspace_path=self.config["output_dir"], + ) + self.subgraph_range = self.config["subgraph_range"] + self.device = self._choose_device(self.config["device"]) + self.try_run = self.config["try_run"] + self.output_dir = self.config["output_dir"] + self.graph_meta_restorer = self._make_graph_meta_restorer() + + def __call__(self, **input_dict): + extracted_model = None + if not self.extracted: + extracted_model = self.do_extract(**input_dict) + self.extracted = True + return extracted_model + + def do_extract(self, **input_dict): + print( + f"[PrologueUnittestGenerator] Generate unittest for {self.parent_model_path} with subgraph_range={self.subgraph_range}" + ) + + # 1. 
Run the model to dump pir programs + model_dump_path = os.path.join( + self.builtin_extractor.dump_path, self.model_name + ) + static_model = self.builtin_extractor.run_model_with_dump_enabled( + model_dump_path, **input_dict + ) + + # 2. Convert pir programs to graphnet samples + ir_programs_path = self.builtin_extractor.get_ir_programs_path(model_dump_path) + example_inputs_path = self.builtin_extractor.get_example_inputs_path( + model_dump_path + ) + op_example_inputs_path = self.builtin_extractor.generate_op_example_inputs_path( + model_dump_path, self.subgraph_range + ) + subgraph_generator = SubgraphGenerator( + model_name=self.model_name, + programs_file=ir_programs_path, + example_inputs_file=example_inputs_path, + op_example_inputs_file=op_example_inputs_path, + eval_mode=True, + tmp_dir=model_dump_path, + ) + + # 3. Generate unittest + with tempfile.TemporaryDirectory(prefix="prologue_unittest_") as tmp_dir: + self.generate(subgraph_generator, tmp_dir) + return static_model + + def _save_and_get_graph_module(self, subgraph_generator, subgraph_range, tmp_dir): + results = subgraph_generator(subgraph_range, False) + assert len(results) == 1 + output_name = f"{subgraph_range[0]}_{subgraph_range[1]}" + output_path = os.path.join(tmp_dir, f"{self.model_name}-{output_name}") + self.builtin_extractor.write_sample_to_file(output_path, results[0]) + graph_module = load_class_from_file( + Path(output_path) / "model.py", class_name="GraphModule" + ) + return graph_module, output_path + + def generate(self, subgraph_generator, tmp_dir): + test_name = "".join( + word.capitalize() for word in re.split(r"[_.-]", self.model_name) + ) + + graph_module, output_path = self._save_and_get_graph_module( + subgraph_generator, self.subgraph_range, tmp_dir + ) + self.graph_meta_restorer(output_path) + tensor_metas = self._get_tensor_metas(output_path) + + # prologue model information + prologue_subgraph_range = [self.subgraph_range[0], self.subgraph_range[1] - 1] + prologue_graph_module, _ = self._save_and_get_graph_module( + subgraph_generator, prologue_subgraph_range, tmp_dir + ) + prologue_forward_func, prologue_returns = self._get_forward_func_and_returns( + prologue_graph_module + ) + prologue_arg_names = self._get_forward_arg_names(prologue_graph_module) + + # test model information + test_subgraph_range = [self.subgraph_range[1] - 1, self.subgraph_range[1]] + test_graph_module, _ = self._save_and_get_graph_module( + subgraph_generator, test_subgraph_range, tmp_dir + ) + test_forward_func, _ = self._get_forward_func_and_returns(test_graph_module) + test_arg_names = self._get_forward_arg_names(test_graph_module) + + def _generate_unittest(): + graph_module_desc = GraphModuleDescriptor( + test_name=test_name, + tensor_metas=tensor_metas, + prologue_arg_names=prologue_arg_names, + prologue_returns=prologue_returns, + prologue_forward_func=prologue_forward_func, + test_arg_names=test_arg_names, + test_forward_func=test_forward_func, + ) + return self._render_template(graph_module_desc) + + # Generate unittest with main for try-run. 
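+        # The unittest is written to output_dir up front so a failing try-run
+        # still leaves a copy for inspection; the try-run executes a separate
+        # copy under a temporary directory and, on success, the final file is
+        # rewritten.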
+ unittest = _generate_unittest() + self._write_to_file(unittest, self.output_dir) + if self._try_to_run_unittest(unittest, tmp_dir): + self._write_to_file(unittest, self.output_dir) + + def _choose_device(self, device) -> str: + if device in ["cpu", "gpu", "xpu"]: + return device + return paddle.get_device() + + def _make_graph_meta_restorer(self): + config = { + "update_inplace": True, + "input_meta_allow_partial_update": False, + } + graph_meta_restorer = GraphMetaRestorer( + config=config, parent_model_path=self.parent_model_path + ) + return graph_meta_restorer + + def _write_to_file(self, unittest, output_dir): + output_path = Path(output_dir) / f"{self.model_name}_test.py" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(unittest, encoding="utf-8") + print(f"[PrologueUnittestGenerator] Generate unittest: {output_path}") + return output_path + + def _try_to_run_unittest(self, unittest, tmp_dir): + if not self.try_run: + return True + + output_path = self._write_to_file(unittest, tmp_dir) + result = subprocess.run( + [ + sys.executable, + output_path, + "--is-reference", + "--device", + self.device, + "--reference-dir", + os.path.join(tmp_dir, "test_reference_outputs"), + ], + check=True, + ) + return result.returncode == 0 + + def _get_tensor_metas(self, model_path): + tensor_metas = TensorMeta.unserialize_from_py_file_order_preserved( + Path(model_path) / "weight_meta.py" + ) + tensor_metas.extend( + TensorMeta.unserialize_from_py_file_order_preserved( + Path(model_path) / "input_meta.py" + ) + ) + return tensor_metas + + def _get_forward_arg_names(self, graph_module): + arg_names = [] + sig = inspect.signature(graph_module.forward) + for name, param in sig.parameters.items(): + if name == "self": + continue + arg_names.append(name) + return arg_names + + def _get_forward_func_and_returns(self, graph_module): + # the whole definition of forward function + source = inspect.getsource(graph_module.forward) + lines = source.splitlines() + num_indents = len(lines[0]) - len(lines[0].lstrip()) + + tree = ast.parse(textwrap.dedent(source)) + func_def = tree.body[0] + forward_func = textwrap.indent(ast.unparse(func_def), " " * num_indents) + + # the return statements + return_node_values = [ + node.value for node in ast.walk(func_def) if isinstance(node, ast.Return) + ] + return_codes = [ + ast.unparse(value) if value is not None else "None" + for value in return_node_values + ] + return forward_func, return_codes + + def _render_template(self, graph_module_desc): + template_str = PADDLE_UNITTEST_TEMPLATE + return jinja2.Template(template_str).render(graph_module_desc=graph_module_desc) diff --git a/graph_net/paddle/prologue_unittest_generator.py b/graph_net/paddle/prologue_unittest_generator.py deleted file mode 100644 index 6f5415482..000000000 --- a/graph_net/paddle/prologue_unittest_generator.py +++ /dev/null @@ -1,409 +0,0 @@ -import os -import re -import sys -import subprocess -import ast -import inspect -import jinja2 -import textwrap -import tempfile -from pathlib import Path -from typing import Literal, List -from collections import namedtuple - -import paddle -from athena.graphnet_samples import SubgraphGenerator -from graph_net import imp_util -from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor -from graph_net.tensor_meta import TensorMeta - - -class GraphExtractor: - def __init__( - self, - config: dict, - model, - name, - dynamic, - input_spec=None, - ): - self.model = model - self.name = name.replace("/", "_") - 
self.dynamic = dynamic - self.input_spec = input_spec - self.config = self.make_config(**config) - - def make_config( - self, - subgraph_range: list, - device: Literal["auto", "cpu", "cuda", "xpu"] = "auto", - try_run: bool = False, - data_input_predicator_filepath: str = None, - data_input_predicator_class_name: str = None, - output_dir: str = "/tmp/prologue_unittests", - ): - assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2 - for pos in subgraph_range: - assert isinstance( - pos, int - ), f"subgraph_range should be list of int, {subgraph_range=}" - return { - "subgraph_range": subgraph_range, - "device": device, - "try_run": try_run, - "data_input_predicator_filepath": data_input_predicator_filepath, - "data_input_predicator_class_name": data_input_predicator_class_name, - "output_dir": output_dir, - } - - def __call__(self, **input_dict): - extracted_model = self.get_prologue_subgraph_unittest_generator()(**input_dict) - return extracted_model - - def get_prologue_subgraph_unittest_generator(self): - return PrologueSubgraphUnittestGenerator( - config=self.config, - parent_model=self.model, - parent_model_name=self.name, - parent_input_spec=self.input_spec, - ) - - -PADDLE_UNITTEST_TEMPLATE = r""" -import unittest -import numpy as np -import paddle - -{% macro get_input_tensor_instance(tensor_meta, device) -%} -{%- set shape = tensor_meta.shape -%} -{%- set dtype = tensor_meta.dtype -%} -{%- set data = tensor_meta.data -%} -{%- set min_val = tensor_meta.min_val -%} -{%- set max_val = tensor_meta.max_val -%} -{%- set mean = tensor_meta.mean -%} -{%- set std = tensor_meta.std -%} -{%- if data is not none -%} - paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}') -{%- elif dtype == "bool" -%} - paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') -{%- elif dtype in ["int8", "int16", "int32", "int64"] -%} - paddle.randint(low={{min_val}}, high={{max_val}} + 1, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}') -{%- elif dtype in ["float16", "bfloat16", "float32", "float64"] -%} - {%- if mean is not none or std is not none -%} - init_float_tensor(shape={{shape}}, dtype='{{dtype}}', max_val={{max_val}}, min_val={{min_val}}, mean={{mean}}, std={{std}}) - {%- else -%} - init_float_tensor(shape={{shape}}, dtype='{{dtype}}', max_val={{max_val}}, min_val={{min_val}}) - {%- endif -%} -{%- endif -%} -{%- endmacro -%} - - -def init_float_tensor(shape, dtype, max_val, min_val, mean=None, std=None): - if mean is not None and std is not None: - tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean - tensor = paddle.clip(tensor, min=min_val, max=max_val) - else: - tensor = paddle.uniform(shape=shape, dtype="float32", min=min_val, max=max_val) - return tensor.to(dtype).to('{{graph_module_desc.device}}') - - -class PrologueLayer(paddle.nn.Layer): - def __init__(self): - super().__init__() - - def forward(self, {{graph_module_desc.prologue_weight_arg_names | join(", ")}}, {{graph_module_desc.prologue_input_arg_names | join(", ")}}): - {{graph_module_desc.prologue_forward_body}} - - -class Model(paddle.nn.Layer): - def __init__(self): - super().__init__() - - def forward(self, {{graph_module_desc.weight_arg_names | join(", ")}}, {{graph_module_desc.input_arg_names | join(", ")}}): - {{graph_module_desc.forward_body}} - - -def get_inputs(): - {%- for arg_name in graph_module_desc.input_arg_names %} - {%- set input_idx = loop.index0 %} - {{arg_name}} = 
{{get_input_tensor_instance(graph_module_desc.input_tensor_metas[input_idx], graph_module_desc.device)}} - {%- endfor %} - return [{{graph_module_desc.input_arg_names | join(", ")}}] - - -class {{graph_module_desc.model_name}}Test(unittest.TestCase): - def setUp(self): - paddle.seed(123) - self.model = Model() - - def test_main(self): - inputs = get_inputs() - outputs = self.model(*inputs) - - -if __name__ == "__main__": - unittest.main() -""" - - -GraphModuleDescriptor = namedtuple( - "GraphModuleDescriptor", - [ - "device", - "model_name", - "input_arg_names", - "input_tensor_metas", - "weight_arg_names", - "weight_tensor_metas", - "forward_body", - ], -) - - -def load_class_from_file(file_path: str, class_name: str): - print(f"Load {class_name} from {file_path}") - module = imp_util.load_module(file_path, "unnamed") - model_class = getattr(module, class_name, None) - return model_class - - -class PrologueSubgraphUnittestGenerator: - def __init__( - self, - config: dict, - parent_model: paddle.nn.Layer, - parent_model_name: str, - parent_input_spec: List[paddle.static.InputSpec], - ): - self.config = config - self.extracted = False - self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__) - self.builtin_extractor = BuiltinGraphExtractor( - model=parent_model, - name=parent_model_name, - dynamic=False, - input_spec=parent_input_spec, - workspace_path=self.config["output_dir"], - ) - self.subgraph_range = self.config["subgraph_range"] - self.device = self._choose_device(self.config["device"]) - self.try_run = self.config["try_run"] - self.data_input_predicator = self._make_data_input_predicator( - self.config["data_input_predicator_filepath"], - self.config["data_input_predicator_class_name"], - ) - - def __call__(self, **input_dict): - extracted_model = None - if not self.extracted: - extracted_model = self.do_extract(**input_dict) - self.extracted = True - return extracted_model - - def do_extract(self, **input_dict): - # 1. Run the model to dump pir programs - model_dump_path = os.path.join( - self.builtin_extractor.dump_path, self.builtin_extractor.name - ) - static_model = self.builtin_extractor.run_model_with_dump_enabled( - model_dump_path, **input_dict - ) - - # 2. 
Convert pir programs to graphnet samples - ir_programs_path = self.builtin_extractor.get_ir_programs_path(model_dump_path) - example_inputs_path = self.builtin_extractor.get_example_inputs_path( - model_dump_path - ) - op_example_inputs_path = self.builtin_extractor.generate_op_example_inputs_path( - model_dump_path, self.subgraph_range - ) - generator = SubgraphGenerator( - model_name=self.builtin_extractor.name, - programs_file=ir_programs_path, - example_inputs_file=example_inputs_path, - op_example_inputs_file=op_example_inputs_path, - eval_mode=True, - tmp_dir=model_dump_path, - ) - graphnet_sample_results = generator(self.subgraph_range, False) - assert len(graphnet_sample_results) == 1 - - return static_model - - def generate(self): - print(f"[PrologueUnittestGenerator] Generate unittest for {self.model_path}") - model_name = "".join( - word.capitalize() for word in re.split(r"[_.-]", self.model_path.name) - ) - graph_module = load_class_from_file( - self.model_path / "model.py", class_name="GraphModule" - ) - input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names( - graph_module - ) - ( - input_tensor_metas, - weight_tensor_metas, - ) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names) - - def _generate_unittest(): - graph_module_desc = GraphModuleDescriptor( - device=self.device, - model_name=model_name, - input_arg_names=input_arg_names, - input_tensor_metas=input_tensor_metas, - weight_arg_names=weight_arg_names, - weight_tensor_metas=weight_tensor_metas, - forward_body=self._get_forward_body( - graph_module, input_arg_names, weight_arg_names - ), - ) - return self._render_template(graph_module_desc) - - # Generate unittest with main for try-run. - unittest = _generate_unittest() - self._write_to_file(unittest, self.output_dir) - if self._try_to_run_unittest(unittest): - self._write_to_file(unittest, self.output_dir) - - def _choose_device(self, device) -> str: - import paddle - - if device in ["cpu", "gpu", "xpu"]: - return device - return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu" - - def _make_data_input_predicator( - self, data_input_predicator_filepath, data_input_predicator_class_name - ): - if data_input_predicator_filepath and data_input_predicator_class_name: - module = imp_util.load_module(data_input_predicator_filepath) - cls = getattr(module, data_input_predicator_class_name) - return cls(config={}) - return lambda *args, **kwargs: True - - def _write_to_file(self, unittest, output_dir): - output_path = Path(output_dir) / self.output_name - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(unittest, encoding="utf-8") - print( - f"[AgentUnittestGenerator] Generate unittest: {output_path} (device={self.device})" - ) - return output_path - - def _try_to_run_unittest(self, unittest): - if not self.try_run: - return True - - with tempfile.TemporaryDirectory(prefix="unittest_") as temp_dir: - output_path = self._write_to_file(unittest, temp_dir) - result = subprocess.run( - [sys.executable, output_path], - check=True, - ) - return result.returncode == 0 - - def _get_input_and_weight_arg_names(self, graph_module): - input_arg_names = [] - weight_arg_names = [] - sig = inspect.signature(graph_module.forward) - for name, param in sig.parameters.items(): - if name == "self": - continue - is_not_data_input = not self.data_input_predicator(self.model_path, name) - if is_not_data_input: - weight_arg_names.append(name) - else: - input_arg_names.append(name) - return input_arg_names, 
weight_arg_names - - def _get_input_and_weight_tensor_metas(self, input_arg_names, weight_arg_names): - tensor_metas = TensorMeta.unserialize_from_py_file( - self.model_path / "weight_meta.py" - ) - tensor_metas.extend( - TensorMeta.unserialize_from_py_file(self.model_path / "input_meta.py") - ) - name2tensor_metas = {meta.name: meta for meta in tensor_metas} - input_tensor_metas = [name2tensor_metas[name] for name in input_arg_names] - weight_tensor_metas = [name2tensor_metas[name] for name in weight_arg_names] - return input_tensor_metas, weight_tensor_metas - - def _get_forward_body(self, graph_module, input_arg_names, weight_arg_names): - def _remove_clear_stmt_of_args(stmt): - def _need_remove(target): - return isinstance(target, ast.Name) and target.id in arg_names - - arg_names = input_arg_names + weight_arg_names - if ( - isinstance(stmt, ast.Assign) - and isinstance(stmt.value, ast.Constant) - and stmt.value.value is None - ): - # remove stmt like w_0 = None - new_targets = [t for t in stmt.targets if not _need_remove(t)] - if not new_targets: - return None - stmt.targets = new_targets - elif isinstance(stmt, ast.Delete): - # remove stmt like del w_0 - new_targets = [] - for t in stmt.targets: - if isinstance(t, ast.Tuple): - kept = [e for e in t.elts if not _need_remove(e)] - if kept: - new_targets.append(ast.Tuple(elts=kept, ctx=ast.Del())) - elif not _need_remove(t): - new_targets.append(t) - if not new_targets: - return None - stmt.targets = new_targets - return stmt - - def _rewrite_reference_for_weight(stmt): - if isinstance(stmt, ast.Name): - if isinstance(stmt.ctx, ast.Load) and stmt.id in weight_arg_names: - return ast.Attribute( - value=ast.Name(id="self", ctx=ast.Load()), - attr=stmt.id, - ctx=ast.Load(), - ) - return stmt - - for field, value in ast.iter_fields(stmt): - if isinstance(value, list): - new_list = [] - for item in value: - if isinstance(item, ast.AST): - item = _rewrite_reference_for_weight(item) - new_list.append(item) - setattr(stmt, field, new_list) - elif isinstance(value, ast.AST): - setattr(stmt, field, _rewrite_reference_for_weight(value)) - return stmt - - def _update_for_weight(stmt): - stmt = _remove_clear_stmt_of_args(stmt) - if stmt is not None and weight_arg_names: - stmt = _rewrite_reference_for_weight(stmt) - ast.fix_missing_locations(stmt) - return stmt - - source = inspect.getsource(graph_module.forward) - lines = source.splitlines() - num_indents = len(lines[-1]) - len(lines[-1].lstrip()) - - tree = ast.parse(textwrap.dedent(source)) - func_def = tree.body[0] - dedented_stmts = [ - ast.unparse(s) - for stmt in func_def.body - if (s := _update_for_weight(stmt)) is not None - ] - - indent = " " * num_indents - return f"\n{indent}".join(dedented_stmts) - - def _render_template(self, graph_module_desc): - template_str = PADDLE_UNITTEST_TEMPLATE - return jinja2.Template(template_str).render(graph_module_desc=graph_module_desc) diff --git a/graph_net/sample_pass/agent_unittest_generator.py b/graph_net/sample_pass/agent_unittest_generator.py index 3f314924f..f1636d533 100644 --- a/graph_net/sample_pass/agent_unittest_generator.py +++ b/graph_net/sample_pass/agent_unittest_generator.py @@ -116,7 +116,7 @@ def test_main(self): {%- set mean = tensor_meta.mean -%} {%- set std = tensor_meta.std -%} {%- if data is not none -%} - paddle.to_tensor({{data}}, dtype='{{dtype}}', shape={{shape}}).to(device='{{device}}') + paddle.reshape(paddle.to_tensor({{data}}, dtype='{{dtype}}'), shape={{shape}}).to(device='{{device}}') {%- elif dtype == "bool" -%} 
paddle.randint(low=0, high=2, shape={{shape}}, dtype='{{dtype}}').to(device='{{device}}')
{%- elif dtype in ["int8", "int16", "int32", "int64"] -%}
diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py
index f1a082a4e..b3758adbf 100755
--- a/graph_net/tensor_meta.py
+++ b/graph_net/tensor_meta.py
@@ -1,3 +1,4 @@
+import ast
 import importlib.util as imp
 import inspect
 from dataclasses import dataclass
@@ -38,6 +39,26 @@ def unserialize_from_py_file(cls, file_path) -> list["TensorMeta"]:
             for attrs in [cls._convert_cls_to_attrs(tensor_meta_cls)]
         ]
 
+    @classmethod
+    def unserialize_from_py_file_order_preserved(cls, file_path) -> list["TensorMeta"]:
+        return [
+            TensorMeta(
+                record_class_name=attrs.get("record_class_name"),
+                name=attrs.get("name"),
+                original_name=attrs.get("original_name", None),
+                shape=attrs.get("shape", []),
+                dtype=attrs.get("dtype"),
+                device=attrs.get("device", None),
+                mean=attrs.get("mean", None),
+                std=attrs.get("std", None),
+                data=attrs.get("data", None),
+                max_val=attrs.get("max_val", None),
+                min_val=attrs.get("min_val", None),
+            )
+            for name, tensor_meta_cls in cls._get_classes_order_preserved(file_path)
+            for attrs in [cls._convert_cls_to_attrs(tensor_meta_cls)]
+        ]
+
     @classmethod
     def _convert_cls_to_attrs(cls, tensor_meta_cls):
         attrs = {
@@ -50,11 +71,25 @@ def _convert_cls_to_attrs(cls, tensor_meta_cls):
 
     @classmethod
     def _get_classes(cls, file_path, name="unnamed"):
-        spec = imp.spec_from_file_location("unnamed", file_path)
+        spec = imp.spec_from_file_location(name, file_path)
         unnamed = imp.module_from_spec(spec)
         spec.loader.exec_module(unnamed)
         yield from inspect.getmembers(unnamed, inspect.isclass)
 
+    @classmethod
+    def _get_classes_order_preserved(cls, file_path, name="unnamed"):
+        with open(file_path, "r", encoding="utf-8") as f:
+            tree = ast.parse(f.read(), filename=file_path)
+
+        class_names = [
+            node.name for node in tree.body if isinstance(node, ast.ClassDef)
+        ]
+
+        spec = imp.spec_from_file_location(name, file_path)
+        unnamed = imp.module_from_spec(spec)
+        spec.loader.exec_module(unnamed)
+        yield from [(name, getattr(unnamed, name)) for name in class_names]
+
     def serialize_to_py_str(self) -> str:
         lines = [
             (f"class {self.record_class_name}:"),
diff --git a/graph_net/test/prologue_unittest_generator_test.sh b/graph_net/test/prologue_unittest_generator_test.sh
new file mode 100644
index 000000000..0184e21e6
--- /dev/null
+++ b/graph_net/test/prologue_unittest_generator_test.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+
+GRAPH_NET_ROOT=$(python -c "import graph_net, os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
+
+MODEL_PATH=paddle_samples/PaddleX/ResNet18
+decorator_config_json_str=$(cat <<EOF
+{
+    "decorator_path": "$GRAPH_NET_ROOT/graph_net/paddle/extractor.py",
+    "decorator_config": {
+        "name": "ResNet18",
+        "custom_extractor_path": "$GRAPH_NET_ROOT/graph_net/paddle/prologue_subgraph_unittest_generator.py",
+        "custom_extractor_config": {
+            "output_dir": "/tmp/prologue_unittests",
+            "subgraph_range": [0, 4],
+            "device": "auto",
+            "try_run": true
+        }
+    }
+}
+EOF
+)
+
+python3 -m graph_net.paddle.run_model --model-path $GRAPH_NET_ROOT/$MODEL_PATH \
+    --decorator-config $(echo -n "$decorator_config_json_str" | base64 -w 0)

From: Liu Yiqun
Date: Mon, 5 Jan 2026 23:24:42 +0800
Subject: [PATCH 3/6] Fix the parsing of returned names.
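
The generated forward function ends in a single tuple return of plain
variable names, so assert that there is exactly one return statement and
split its unparsed tuple into names, instead of unparsing every return
value found in the tree. Also prefix the saved reference-output file
names with the model name so references of different models do not
overwrite each other.

A minimal, standalone illustration of the new parsing (identifiers are
illustrative):

    import ast

    source = "def forward(self, a, b):\n    c = a + b\n    return (c, b)"
    func_def = ast.parse(source).body[0]
    returns = [n.value for n in ast.walk(func_def) if isinstance(n, ast.Return)]
    assert len(returns) == 1
    names = [s.strip() for s in ast.unparse(returns[0]).strip("()").split(",")]
    assert names == ["c", "b"]
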
--- .../prologue_subgraph_unittest_generator.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/graph_net/paddle/prologue_subgraph_unittest_generator.py b/graph_net/paddle/prologue_subgraph_unittest_generator.py index 6fd43534c..1dd56bcfe 100644 --- a/graph_net/paddle/prologue_subgraph_unittest_generator.py +++ b/graph_net/paddle/prologue_subgraph_unittest_generator.py @@ -247,7 +247,7 @@ def _convert_to_numpy(out): np.testing.assert_allclose(_convert_to_numpy(reference), _convert_to_numpy(target), atol, rtol) def test_main(self): - prologue_output_path = os.path.join(self.reference_dir, "prologue.pdout") + prologue_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_prologue.pdout") if self.is_reference: prologue_outputs = self.run_prologue_model() print(f"Save prologue output tensors to {prologue_output_path}.") @@ -256,7 +256,7 @@ def test_main(self): print(f"Load prologue output tensors from {prologue_output_path}") prologue_outputs = paddle.load(prologue_output_path) - test_output_path = os.path.join(self.reference_dir, "test_reference.pdout") + test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_test_reference.pdout") test_outputs = self.run_test_model(prologue_outputs) if self.is_reference: print(f"Save test output tensors to {test_output_path}.") @@ -285,6 +285,7 @@ def test_main(self): GraphModuleDescriptor = namedtuple( "GraphModuleDescriptor", [ + "model_name", "test_name", "tensor_metas", "prologue_arg_names", @@ -412,6 +413,7 @@ def generate(self, subgraph_generator, tmp_dir): def _generate_unittest(): graph_module_desc = GraphModuleDescriptor( + model_name=self.model_name, test_name=test_name, tensor_metas=tensor_metas, prologue_arg_names=prologue_arg_names, @@ -503,11 +505,10 @@ def _get_forward_func_and_returns(self, graph_module): return_node_values = [ node.value for node in ast.walk(func_def) if isinstance(node, ast.Return) ] - return_codes = [ - ast.unparse(value) if value is not None else "None" - for value in return_node_values - ] - return forward_func, return_codes + assert len(return_node_values) == 1 and return_node_values[0] is not None + return_code = ast.unparse(return_node_values[0]) + return_names = [name.strip() for name in return_code.strip("()").split(",")] + return forward_func, return_names def _render_template(self, graph_module_desc): template_str = PADDLE_UNITTEST_TEMPLATE From ef1eaa78d0a306002a0b57a8214f8db29b7e4d53 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 5 Jan 2026 23:25:58 +0800 Subject: [PATCH 4/6] Add unittest generating step. --- .../subgraph_decompose_and_evaluation_step.py | 91 +++++++++++++++++-- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py index cd2a42796..addee52c8 100755 --- a/graph_net/subgraph_decompose_and_evaluation_step.py +++ b/graph_net/subgraph_decompose_and_evaluation_step.py @@ -418,6 +418,79 @@ def run_evaluation( ), f"[ERROR] test failed for {work_dir}, please check the log." 
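+
+# The functions below drive the prologue generator out-of-process through
+# run_model with a base64-encoded decorator config; the equivalent manual
+# invocation is roughly (paths illustrative):
+#
+#   python -m graph_net.paddle.run_model \
+#       --model-path paddle_samples/PaddleX/ResNet18 \
+#       --decorator-config <base64-encoded decorator_config JSON>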
+def generate_unittest_for_single_model( + framework, model_name, model_path, subgraph_range, output_dir, log_path +): + graphnet_root = path_utils.get_graphnet_root() + decorator_config = { + "decorator_path": f"{graphnet_root}/graph_net/paddle/extractor.py", + "decorator_config": { + "name": model_name, + "custom_extractor_path": f"{graphnet_root}/graph_net/paddle/prologue_subgraph_unittest_generator.py", + "custom_extractor_config": { + "output_dir": output_dir, + "subgraph_range": subgraph_range, + "device": "auto", + "try_run": True, + }, + }, + } + + decorator_config_b64 = convert_json_to_b64_string(decorator_config) + + print( + f"[Unittest] model_path: {model_path}, subgraph_range: {subgraph_range}", + flush=True, + ) + cmd = [ + sys.executable, + "-m", + f"graph_net.{framework}.run_model", + "--model-path", + model_path, + "--decorator-config", + decorator_config_b64, + ] + with open(log_path, "a") as f: + result = subprocess.run(cmd, stdout=f, stderr=f, text=True) + assert result.returncode == 0 + + +def generate_unittest(decompose_config, pass_id, output_dir): + running_state = decompose_config.get_running_state(pass_id) + + unittest_dir = os.path.join(output_dir, "unittests") + log_path = os.path.join(output_dir, "log_unittest_generation.txt") + print(f"[Unittest] log_path: {log_path}", flush=True) + for model_name, model_record in running_state.model_name2record.items(): + if not model_record.incorrect_subgraph_idxs: + continue + + original_path = model_record.original_path + subgraph_idx = model_record.incorrect_subgraph_idxs[0] + if decompose_config.decompose_method == "fixed-start": + subgraph_range = [0, model_record.uniform_split_positions[subgraph_idx + 1]] + else: + subgraph_range = [ + model_record.uniform_split_positions[subgraph_idx], + model_record.uniform_split_positions[subgraph_idx + 1], + ] + + rectified_model_path = get_rectfied_model_path(original_path) + assert os.path.exists( + rectified_model_path + ), f"{rectified_model_path} does not exist." 
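+        # One unittest per model, covering only its first failing subgraph.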
+
+        generate_unittest_for_single_model(
+            decompose_config.framework,
+            model_name,
+            rectified_model_path,
+            subgraph_range,
+            unittest_dir,
+            log_path,
+        )
+
+
 def reconstruct_split_positions_for_subgraphs(
     split_positions, subgraph_idxs, max_subgraph_size
 ):
@@ -496,6 +569,8 @@ def generate_successor_tasks(args, output_dir, pass_id):
         max_subgraph_size=max_subgraph_size,
         running_states=prev_config.running_states,
     )
+    if max_subgraph_size <= 0:
+        return decompose_config
 
     prev_running_state = prev_config.get_running_state(pass_id - 1)
     assert prev_running_state is not None
@@ -538,14 +613,17 @@ def prepare_tasks_and_verify(args, pass_id, output_dir):
     print_incorrect_models(decompose_config, pass_id - 1, log_prompt="[Init]")
 
     if not decompose_config.get_incorrect_models(pass_id - 1):
-        print("[FINISHED] No models need processing.", flush=True)
+        print(
+            f"\n[Conclusion] No incorrect models after {pass_id - 1} steps.", flush=True
+        )
         sys.exit(0)
 
     if decompose_config.max_subgraph_size <= 0:
         print(
-            f"[FINISHED] Cannot decompose with max_subgraph_size {decompose_config.max_subgraph_size}.",
+            f"\n[Conclusion] Decomposition has reached the minimal granularity (max_subgraph_size = 1) after {pass_id - 1} steps.",
             flush=True,
         )
+        generate_unittest(decompose_config, pass_id - 1, output_dir)
         sys.exit(0)
 
     return decompose_config
@@ -648,15 +726,14 @@ def print_summary_and_suggestion(decompose_config, pass_id):
         flush=True,
     )
     print(
-        ">>> Please start next round decomposition test (Run this script again).",
+        ">>> Please start the next decomposition step (Run this script again).",
         flush=True,
     )
-    elif num_incorrect_models > 0 and decompose_config.max_subgraph_size <= 1:
+    elif decompose_config.max_subgraph_size <= 1:
         print(
-            ">>> [FAILURE] Minimal granularity reached, but errors persist.", flush=True
+            f">>> [Conclusion] Decomposition has reached the minimal granularity (max_subgraph_size = 1) after {pass_id - 1} steps.",
+            flush=True,
         )
-    else:
-        print(">>> [SUCCESS] Debugging converged.", flush=True)
     print("=" * 80, flush=True)


From 2728e84f508be5ae3d8e7dafc8f1abb7d59f01cf Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Tue, 6 Jan 2026 11:32:55 +0800
Subject: [PATCH 5/6] Fix several errors.

---
 .../paddle/prologue_subgraph_unittest_generator.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/graph_net/paddle/prologue_subgraph_unittest_generator.py b/graph_net/paddle/prologue_subgraph_unittest_generator.py
index 1dd56bcfe..afe98b1d9 100644
--- a/graph_net/paddle/prologue_subgraph_unittest_generator.py
+++ b/graph_net/paddle/prologue_subgraph_unittest_generator.py
@@ -135,13 +135,13 @@ def get_input_dict(device):
 
 def tolerance_generator(tolerance, dtype):
     if dtype == paddle.float16:
-        yield 10 ** (tolerance * 3 / 5), 10**tolerance
+        return 10 ** (tolerance * 3 / 5), 10**tolerance
     elif dtype == paddle.bfloat16:
-        yield 10 ** (tolerance * 1.796 / 5), 10**tolerance
+        return 10 ** (tolerance * 1.796 / 5), 10**tolerance
     elif dtype == paddle.float32:
-        yield 10 ** (tolerance * 5.886 / 5), 10**tolerance
+        return 10 ** (tolerance * 5.886 / 5), 10**tolerance
     elif dtype == paddle.float64:
-        yield 10 ** (tolerance * 7 / 5), 10 ** (tolerance * 7 / 5)
+        return 10 ** (tolerance * 7 / 5), 10 ** (tolerance * 7 / 5)
     else:
         assert False, f"Unsupported {dtype=}."
@@ -217,7 +217,7 @@ def _get_output_dtypes(outs):
         self.assertTrue(dtype_match, f"Output dtypes do not match ({reference_dtypes=} vs {target_dtypes=}).")
 
     def check_shapes(self, reference_outputs, target_outputs):
-        def _get_output_shapes(self, outs):
+        def _get_output_shapes(outs):
             shapes = [
                 tensor.shape if isinstance(tensor, paddle.Tensor) else None
                 for i, tensor in enumerate(outs)
@@ -263,7 +263,7 @@ def test_main(self):
             paddle.save(test_outputs, test_output_path)
         else:
             print(f"Load test output tensors on reference device from {test_output_path}.")
-            test_reference_outputs = paddle.load(test_reference_outputs)
+            test_reference_outputs = paddle.load(test_output_path)
 
         self.check_results(test_reference_outputs, test_outputs)


From 1129bd803bb2eb76aaf14df0463e3e4858b50788 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Tue, 6 Jan 2026 13:27:03 +0800
Subject: [PATCH 6/6] Add test_separated and test_combined.

---
 .../prologue_subgraph_unittest_generator.py   | 89 +++++++++++++------
 .../subgraph_decompose_and_evaluation_step.py |  8 +-
 ...logue_subgraph_unittest_generator_test.sh} |  1 +
 3 files changed, 68 insertions(+), 30 deletions(-)
 rename graph_net/test/{prologue_unittest_generator_test.sh => prologue_subgraph_unittest_generator_test.sh} (96%)

diff --git a/graph_net/paddle/prologue_subgraph_unittest_generator.py b/graph_net/paddle/prologue_subgraph_unittest_generator.py
index afe98b1d9..227e117d5 100644
--- a/graph_net/paddle/prologue_subgraph_unittest_generator.py
+++ b/graph_net/paddle/prologue_subgraph_unittest_generator.py
@@ -38,6 +38,7 @@ def make_config(
         self,
         subgraph_range: list,
         device: Literal["auto", "cpu", "cuda", "xpu"] = "auto",
+        tolerance: int = 0,
         try_run: bool = False,
         output_dir: str = "/tmp/prologue_unittests",
     ):
@@ -49,6 +50,7 @@ def make_config(
         return {
             "subgraph_range": subgraph_range,
             "device": device,
+            "tolerance": tolerance,
             "try_run": try_run,
             "output_dir": output_dir,
         }
@@ -116,12 +118,24 @@ def init_float_tensor(shape, dtype, device, min_val, max_val, mean=None, std=Non
     return paddle.to_tensor(array).to(dtype).to(device)
 
 
-class PrologueModel(paddle.nn.Layer):
+class PrologueLayer(paddle.nn.Layer):
 {{graph_module_desc.prologue_forward_func}}
 
 
+class SuspectLayer(paddle.nn.Layer):
+{{graph_module_desc.suspect_forward_func}}
+
+
 class TestModel(paddle.nn.Layer):
-{{graph_module_desc.test_forward_func}}
+    def __init__(self):
+        super().__init__()
+        self.prologue_layer = PrologueLayer()
+        self.suspect_layer = SuspectLayer()
+
+    def forward(self, {{graph_module_desc.arg_names | join(", ")}}):
+        {{graph_module_desc.prologue_returns | join(", ")}} = self.prologue_layer({{graph_module_desc.prologue_arg_names | join(", ")}})
+        {{graph_module_desc.suspect_returns | join(", ")}} = self.suspect_layer({{graph_module_desc.suspect_arg_names | join(", ")}})
+        return ({{graph_module_desc.suspect_returns | join(", ")}},)
 
 
 def get_input_dict(device):
@@ -151,14 +165,13 @@ def setUp(self):
         self.device = TEST_ARGS.device
         self.is_reference = TEST_ARGS.is_reference
         self.reference_dir = TEST_ARGS.reference_dir
-        self.tolerance = TEST_ARGS.tolerance
+        self.tolerance = {{graph_module_desc.tolerance}}
 
         paddle.seed(123)
         random.seed(123)
         np.random.seed(123)
 
         self.input_dict = get_input_dict(self.device)
-        self.prologue_model = PrologueModel()
         self.test_model = TestModel()
 
     def _flatten_outputs_to_list(self, outs):
@@ -173,18 +186,18 @@ def _flatten_outputs_to_list(self, outs):
         ]
         return flattened_outs
 
-    def run_prologue_model(self):
+    def run_prologue_layer(self):
         prologue_inputs = [
             {%- 
for arg_name in graph_module_desc.prologue_arg_names %}
             self.input_dict['{{arg_name}}'],
             {%- endfor %}
         ]
-        prologue_outputs = self.prologue_model(*prologue_inputs)
+        prologue_outputs = self.test_model.prologue_layer(*prologue_inputs)
         return self._flatten_outputs_to_list(prologue_outputs)
 
-    def run_test_model(self, prologue_outputs):
-        test_inputs = [
-            {%- for arg_name in graph_module_desc.test_arg_names %}
+    def run_suspect_layer(self, prologue_outputs):
+        suspect_inputs = [
+            {%- for arg_name in graph_module_desc.suspect_arg_names %}
             {%- if arg_name not in graph_module_desc.prologue_returns %}
             self.input_dict['{{arg_name}}'],
             {%- else %}
@@ -196,7 +209,11 @@ def run_test_model(self):
             {%- endif %}
             {%- endfor %}
         ]
-        test_outputs = self.test_model(*test_inputs)
+        suspect_outputs = self.test_model.suspect_layer(*suspect_inputs)
+        return self._flatten_outputs_to_list(suspect_outputs)
+
+    def run_test_model(self):
+        test_outputs = self.test_model(**self.input_dict)
         return self._flatten_outputs_to_list(test_outputs)
 
     def check_dtypes(self, reference_outputs, target_outputs):
@@ -246,18 +263,29 @@ def _convert_to_numpy(out):
             atol, rtol = tolerance_generator(self.tolerance, reference.dtype)
             np.testing.assert_allclose(_convert_to_numpy(reference), _convert_to_numpy(target), rtol=rtol, atol=atol)
 
-    def test_main(self):
+    def test_separated(self):
         prologue_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_separate_prologue.pdout")
         if self.is_reference:
-            prologue_outputs = self.run_prologue_model()
+            prologue_outputs = self.run_prologue_layer()
             print(f"Save prologue output tensors to {prologue_output_path}.")
             paddle.save(prologue_outputs, prologue_output_path)
         else:
             print(f"Load prologue output tensors from {prologue_output_path}.")
             prologue_outputs = paddle.load(prologue_output_path)
 
-        test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_test_reference.pdout")
-        test_outputs = self.run_test_model(prologue_outputs)
+        test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_separate_reference.pdout")
+        test_outputs = self.run_suspect_layer(prologue_outputs)
         if self.is_reference:
             print(f"Save test output tensors to {test_output_path}.")
             paddle.save(test_outputs, test_output_path)
         else:
             print(f"Load test output tensors on reference device from {test_output_path}.")
             test_reference_outputs = paddle.load(test_output_path)
             self.check_results(test_reference_outputs, test_outputs)
+
+    def test_combined(self):
+        test_output_path = os.path.join(self.reference_dir, "{{graph_module_desc.model_name}}_combined_reference.pdout")
+        test_outputs = self.run_test_model()
         if self.is_reference:
             print(f"Save test output tensors to {test_output_path}.")
             paddle.save(test_outputs, test_output_path)
@@ -272,7 +300,6 @@ def test_main(self):
     parser.add_argument("--is-reference", action="store_true", default=False)
     parser.add_argument("--device", type=str, required=True)
     parser.add_argument("--reference-dir", type=str, required=True)
-    parser.add_argument("--tolerance", type=int, choices=range(-10, 5), default=0)
 
     args, remaining = parser.parse_known_args()
     global TEST_ARGS
@@ -287,12 +314,15 @@ def test_main(self):
     [
         "model_name",
         "test_name",
+        "tolerance",
+        "arg_names",
         "tensor_metas",
         "prologue_arg_names",
         "prologue_returns",
         "prologue_forward_func",
-        "test_arg_names",
-        "test_forward_func",
+        
"suspect_arg_names", + "suspect_returns", + "suspect_forward_func", ], ) @@ -325,6 +355,7 @@ def __init__( ) self.subgraph_range = self.config["subgraph_range"] self.device = self._choose_device(self.config["device"]) + self.tolerance = self.config["tolerance"] self.try_run = self.config["try_run"] self.output_dir = self.config["output_dir"] self.graph_meta_restorer = self._make_graph_meta_restorer() @@ -390,6 +421,7 @@ def generate(self, subgraph_generator, tmp_dir): graph_module, output_path = self._save_and_get_graph_module( subgraph_generator, self.subgraph_range, tmp_dir ) + arg_names = self._get_forward_arg_names(graph_module) self.graph_meta_restorer(output_path) tensor_metas = self._get_tensor_metas(output_path) @@ -403,24 +435,29 @@ def generate(self, subgraph_generator, tmp_dir): ) prologue_arg_names = self._get_forward_arg_names(prologue_graph_module) - # test model information - test_subgraph_range = [self.subgraph_range[1] - 1, self.subgraph_range[1]] - test_graph_module, _ = self._save_and_get_graph_module( - subgraph_generator, test_subgraph_range, tmp_dir + # suspect model information + suspect_subgraph_range = [self.subgraph_range[1] - 1, self.subgraph_range[1]] + suspect_graph_module, _ = self._save_and_get_graph_module( + subgraph_generator, suspect_subgraph_range, tmp_dir + ) + suspect_forward_func, suspect_returns = self._get_forward_func_and_returns( + suspect_graph_module ) - test_forward_func, _ = self._get_forward_func_and_returns(test_graph_module) - test_arg_names = self._get_forward_arg_names(test_graph_module) + suspect_arg_names = self._get_forward_arg_names(suspect_graph_module) def _generate_unittest(): graph_module_desc = GraphModuleDescriptor( model_name=self.model_name, test_name=test_name, + tolerance=self.tolerance, + arg_names=arg_names, tensor_metas=tensor_metas, prologue_arg_names=prologue_arg_names, prologue_returns=prologue_returns, prologue_forward_func=prologue_forward_func, - test_arg_names=test_arg_names, - test_forward_func=test_forward_func, + suspect_arg_names=suspect_arg_names, + suspect_returns=suspect_returns, + suspect_forward_func=suspect_forward_func, ) return self._render_template(graph_module_desc) diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py index addee52c8..cb222c111 100755 --- a/graph_net/subgraph_decompose_and_evaluation_step.py +++ b/graph_net/subgraph_decompose_and_evaluation_step.py @@ -419,7 +419,7 @@ def run_evaluation( def generate_unittest_for_single_model( - framework, model_name, model_path, subgraph_range, output_dir, log_path + framework, model_name, model_path, subgraph_range, tolerance, output_dir, log_path ): graphnet_root = path_utils.get_graphnet_root() decorator_config = { @@ -431,6 +431,7 @@ def generate_unittest_for_single_model( "output_dir": output_dir, "subgraph_range": subgraph_range, "device": "auto", + "tolerance": tolerance, "try_run": True, }, }, @@ -486,6 +487,7 @@ def generate_unittest(decompose_config, pass_id, output_dir): model_name, rectified_model_path, subgraph_range, + decompose_config.tolerance[0], unittest_dir, log_path, ) @@ -770,9 +772,7 @@ def main(args): # --- Step 4: Analysis --- if task_controller.task_scheduler["post_analysis"]: - tolerance = ( - args.tolerance[0] if isinstance(args.tolerance, list) else args.tolerance - ) + tolerance = args.tolerance[0] print(f"\n--- Phase 3: Analysis (torlance={tolerance}) ---") next_pass_incorrect_models = sorted(get_incorrect_models(tolerance, log_path)) 
decompose_config.update_running_state_with_incorrect_models( diff --git a/graph_net/test/prologue_unittest_generator_test.sh b/graph_net/test/prologue_subgraph_unittest_generator_test.sh similarity index 96% rename from graph_net/test/prologue_unittest_generator_test.sh rename to graph_net/test/prologue_subgraph_unittest_generator_test.sh index 0184e21e6..23bb4ca3f 100644 --- a/graph_net/test/prologue_unittest_generator_test.sh +++ b/graph_net/test/prologue_subgraph_unittest_generator_test.sh @@ -13,6 +13,7 @@ decorator_config_json_str=$(cat <