From d79a257563e399610820e4b093081d8fc82f2122 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Tue, 6 Jan 2026 22:47:50 +0800
Subject: [PATCH 1/4] Support use_all_inputs when subgraph_range starts at 0.

---
 .../graphnet_sequence_sample_generator.py | 20 +++++--
 athena/graphnet_samples.py | 55 +++++++++++--------
 2 files changed, 48 insertions(+), 27 deletions(-)

diff --git a/athena/generators/graphnet_sequence_sample_generator.py b/athena/generators/graphnet_sequence_sample_generator.py
index 1ce317c..93a59e5 100644
--- a/athena/generators/graphnet_sequence_sample_generator.py
+++ b/athena/generators/graphnet_sequence_sample_generator.py
@@ -33,17 +33,27 @@ class SequenceFuncDesc:
 
 
 class GraphnetSequenceSampleGenerator:
-    def __init__(self, program_id, op_example_inputs_meta_getter):
+    def __init__(self, program_id, program_seq_stmts, op_example_inputs_meta_getter):
         self.program_id = program_id
+        self.program_seq_stmts = program_seq_stmts
         self.op_example_inputs_meta_getter = op_example_inputs_meta_getter
         self.input_spec_mode = "original"
 
-    def Generate(self, seq_stmts):
-        seq_func_desc = self.MakeSequenceFuncDesc(seq_stmts)
+    def Generate(self, subgraph_range, use_all_inputs):
+        assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2
+        seq_stmts = self.program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
+        seq_func_desc = self.MakeSequenceFuncDesc(
+            seq_stmts, use_all_inputs and subgraph_range[0] == 0
+        )
         return self._RenderTemplate(seq_func_desc)
 
-    def MakeSequenceFuncDesc(self, seq_stmts):
-        op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
+    def MakeSequenceFuncDesc(self, seq_stmts, use_all_inputs):
+        if use_all_inputs:
+            op_id2seq_stmt = OrderedDict(
+                (stmt.op_id, stmt) for stmt in self.program_seq_stmts
+            )
+        else:
+            op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
         ops_func_signature = OpsFuncSignature(
             tensor_ids=self.GetTensorIds(op_id2seq_stmt),
             operand_ids=self.GetOperandIds(op_id2seq_stmt),
diff --git a/athena/graphnet_samples.py b/athena/graphnet_samples.py
index a6dab6e..8831120 100644
--- a/athena/graphnet_samples.py
+++ b/athena/graphnet_samples.py
@@ -56,12 +56,17 @@
     True,
     "Whether extend split_positions to include the head and tail of the statement sequence.",
 )
+flags.DEFINE_boolean(
+    "use_all_inputs",
+    False,
+    "Whether to use all inputs of the IR program.",
+)
 flags.DEFINE_boolean(
     "eval_mode",
     False,
     "Generate graphnet sample for eval, which only keep output tensors with maximum depth (longest chain).",
 )
-flags.DEFINE_string("tmp_dir", tempfile.gettempdir(), "tmp directory.")
+flags.DEFINE_string("tmp_dir", None, "tmp directory.")
 
 
 @dataclass
@@ -154,7 +159,7 @@ def __init__(
         example_inputs_file,
         op_example_inputs_file,
         eval_mode,
-        tmp_dir,
+        tmp_dir=None,
     ):
         self.model_name = model_name
         self.programs_file = programs_file
@@ -243,42 +248,44 @@ def ExtendHeadAndTail(self, seq_stmts, split_positions, group_head_and_tail):
         print(f"split_positions_for_seq_stmts: {split_positions_for_seq_stmts}")
         return split_positions_for_seq_stmts
 
-    def GetOutputSampleStrings(self, split_positions, group_head_and_tail=True):
+    def GetOutputSampleStrings(
+        self, split_positions, group_head_and_tail=True, use_all_inputs=False
+    ):
         def MakeSequenceSampleGenerator(
-            program_id, seq_stmts, op_example_inputs_meta_getter
+            program_id, program_seq_stmts, op_example_inputs_meta_getter
         ):
-            generator = GraphnetSequenceSampleGenerator(
-                program_id, op_example_inputs_meta_getter
+            return GraphnetSequenceSampleGenerator(
+                program_id, program_seq_stmts, op_example_inputs_meta_getter
             )
-            return generator.Generate(seq_stmts)
 
         print(f"origin split_positions: {split_positions}")
         generated_sample_strs = set()
-        for subgraph_idx, (program_id, seq_stmts) in enumerate(
+        for subgraph_idx, (program_id, program_seq_stmts) in enumerate(
             self.program_seq_stmts_list
         ):
+            generator = MakeSequenceSampleGenerator(
+                program_id, program_seq_stmts, self.op_example_inputs_meta_getter
+            )
             split_positions_for_seq_stmts = self.ExtendHeadAndTail(
-                seq_stmts, split_positions, group_head_and_tail
+                program_seq_stmts, split_positions, group_head_and_tail
             )
             for i in range(len(split_positions_for_seq_stmts) - 1):
-                seq_stmts_slice = seq_stmts[
-                    split_positions_for_seq_stmts[i] : split_positions_for_seq_stmts[
-                        i + 1
-                    ]
-                ]
-                sample_str = MakeSequenceSampleGenerator(
-                    program_id, seq_stmts_slice, self.op_example_inputs_meta_getter
-                )
+                subgraph_range = split_positions_for_seq_stmts[i : i + 2]
+                sample_str = generator.Generate(subgraph_range, use_all_inputs)
                 if sample_str not in generated_sample_strs:
                     generated_sample_strs.add(sample_str)
-                    stmt_hash = GetSeqStmtsHash(seq_stmts_slice)
+                    stmt_hash = GetSeqStmtsHash(
+                        program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
+                    )
                     yield (subgraph_idx, program_id, stmt_hash, sample_str)
 
-    def __call__(self, split_positions, group_head_and_tail=True):
+    def __call__(self, split_positions, group_head_and_tail=True, use_all_inputs=False):
         graphnet_sample_results = []
         seg_counter = defaultdict(lambda: itertools.count())
         for _, (subgraph_idx, program_id, uid, sample_str) in enumerate(
-            self.GetOutputSampleStrings(split_positions, group_head_and_tail)
+            self.GetOutputSampleStrings(
+                split_positions, group_head_and_tail, use_all_inputs
+            )
         ):
             unique_name = f"{uid}_{next(seg_counter[uid])}"
             sample = ConvertOutputStringToSample(
@@ -286,7 +293,7 @@ def __call__(self, split_positions, group_head_and_tail=True):
             )
             graphnet_sample_results.append(sample)
         print(
-            f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=})."
+            f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=}, {use_all_inputs=})."
         )
         return graphnet_sample_results
 
@@ -298,6 +305,7 @@ def RunGeneration(
     op_example_inputs,
     split_positions,
     group_head_and_tail,
+    use_all_inputs,
     eval_mode,
     tmp_dir=None,
 ):
@@ -313,7 +321,9 @@ def RunGeneration(
         eval_mode,
         tmp_dir,
     )
-    graphnet_sample_results = generator(split_positions, group_head_and_tail)
+    graphnet_sample_results = generator(
+        split_positions, group_head_and_tail, use_all_inputs
+    )
     return graphnet_sample_results
 
 
@@ -327,6 +337,7 @@ def main(argv):
         op_example_inputs=FLAGS.op_example_inputs,
         split_positions=split_positions,
        group_head_and_tail=FLAGS.group_head_and_tail,
+        use_all_inputs=FLAGS.use_all_inputs,
         eval_mode=FLAGS.eval_mode,
     )
 

From f332ce5094cdfb40dde931509ddc75220a231a7f Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Wed, 7 Jan 2026 11:13:04 +0800
Subject: [PATCH 2/4] Fix the case when tensor is not initialized.
---
 .../template_op_example_input_meta_script.jinja | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/athena/generators/template_op_example_input_meta_script.jinja b/athena/generators/template_op_example_input_meta_script.jinja
index d255f62..e374f3f 100644
--- a/athena/generators/template_op_example_input_meta_script.jinja
+++ b/athena/generators/template_op_example_input_meta_script.jinja
@@ -77,6 +77,16 @@ def CalculateTensorMeta(tensor, meta_name):
     raise NotImplementedError(f"meta_name: {meta_name}")
 
 
+def IsInitialized(tensor):
+    try:
+        is_initialized = isinstance(tensor, paddle.Tensor) and tensor.numel() > 0
+        _ = tensor.shape
+        _ = tensor.dtype
+        return is_initialized
+    except Exception:
+        return False
+
+
 def InitTensorMeta(tensor, meta_name, tensor_meta):
     if tensor_meta:
         return getattr(tensor_meta, meta_name)
@@ -88,6 +98,8 @@ def InitTensorMeta(tensor, meta_name, tensor_meta):
         return [InitTensorMeta(t, meta_name, tensor_meta) for t in tensor]
     if not hasattr(tensor, meta_name):
         raise NotImplementedError(f"type(tensor): {type(tensor)}, meta_name: {meta_name}")
+    if not IsInitialized(tensor):
+        return None
     kLimit = 64
     if tensor.numel().item() < kLimit:
         return None
@@ -109,6 +121,8 @@ def InitTensorData(tensor):
         return [InitTensorData(t) for t in tensor]
     if not hasattr(tensor, 'numel'):
         raise NotImplementedError(f"type(tensor): {type(tensor)}")
+    if not IsInitialized(tensor):
+        return None
     kLimit = 64
     if tensor.numel().item() >= kLimit:
         return None

From fed09f72be5866b5b1d8571c520363872c5e1d9e Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Wed, 7 Jan 2026 17:51:46 +0800
Subject: [PATCH 3/4] Remove full_int_array from input arguments.

---
 athena/generators/template_graphnet_sequence_sample.jinja | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/athena/generators/template_graphnet_sequence_sample.jinja b/athena/generators/template_graphnet_sequence_sample.jinja
index bed3f08..b318bf6 100644
--- a/athena/generators/template_graphnet_sequence_sample.jinja
+++ b/athena/generators/template_graphnet_sequence_sample.jinja
@@ -120,8 +120,11 @@ class GraphModule(paddle.nn.Layer):
     {%- endfor -%}
     {%- for tensor_id in sig.tensor_ids -%}
     {%- if 'parameter' not in sig.tensor_name4tensor_id(tensor_id) %}
+      {%- set data, dtype = sig.immediate_value4int_array_member_id(tensor_id) -%}
+      {%- if data is none %}
       {{"\t\t"}}{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}{{","}}
     {%- endif %}
+      {%- endif %}
     {%- endfor -%}
   ):
     {%- for tensor_id in sig.tensor_ids -%}

From a741abe55e9cef8246b2774a767f1e77823eea30 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Thu, 8 Jan 2026 09:37:02 +0800
Subject: [PATCH 4/4] Record subgraph_range.
--- athena/graphnet_samples.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/athena/graphnet_samples.py b/athena/graphnet_samples.py index 8831120..64fb7d6 100644 --- a/athena/graphnet_samples.py +++ b/athena/graphnet_samples.py @@ -9,7 +9,7 @@ from itertools import groupby from collections import defaultdict from dataclasses import dataclass -from typing import Dict +from typing import Dict, List from athena.generators.blocks_generator import BlocksGenerator from athena.generators.block_name_generator import BlockNameGenerator @@ -78,10 +78,11 @@ class GraphnetSample: input_meta: str weight_meta: str model: str + subgraph_range: List[int] = None def ConvertOutputStringToSample( - model_name, unique_name, subgraph_idx, program_id, sample_str + model_name, unique_name, subgraph_idx, program_id, sample_str, subgraph_range=None ): metadata = { "framework": "paddle", @@ -99,6 +100,7 @@ def ConvertOutputStringToSample( input_meta=input_meta.strip("\n\n\n") + "\n", weight_meta=weight_meta.rstrip("\n\n\n") + "\n", model=model, + subgraph_range=subgraph_range, ) # PrintToTerminal(unique_name, sample_str) return sample @@ -277,19 +279,30 @@ def MakeSequenceSampleGenerator( stmt_hash = GetSeqStmtsHash( program_seq_stmts[subgraph_range[0] : subgraph_range[1]] ) - yield (subgraph_idx, program_id, stmt_hash, sample_str) + yield ( + subgraph_idx, + program_id, + stmt_hash, + subgraph_range, + sample_str, + ) def __call__(self, split_positions, group_head_and_tail=True, use_all_inputs=False): graphnet_sample_results = [] seg_counter = defaultdict(lambda: itertools.count()) - for _, (subgraph_idx, program_id, uid, sample_str) in enumerate( + for _, (subgraph_idx, program_id, uid, subgraph_range, sample_str) in enumerate( self.GetOutputSampleStrings( split_positions, group_head_and_tail, use_all_inputs ) ): unique_name = f"{uid}_{next(seg_counter[uid])}" sample = ConvertOutputStringToSample( - self.model_name, unique_name, subgraph_idx, program_id, sample_str + self.model_name, + unique_name, + subgraph_idx, + program_id, + sample_str, + subgraph_range, ) graphnet_sample_results.append(sample) print(
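
Note (illustrative sketch, not part of the patch series): how the reworked generator API from PATCH 1/4 is expected to be driven. It assumes `program_id`, `program_seq_stmts`, and `op_example_inputs_meta_getter` have already been built by the surrounding pipeline; only the class, method, and parameter names come from the patches above.

    from athena.generators.graphnet_sequence_sample_generator import (
        GraphnetSequenceSampleGenerator,
    )

    # One generator per program; subgraph slicing now happens inside Generate().
    generator = GraphnetSequenceSampleGenerator(
        program_id, program_seq_stmts, op_example_inputs_meta_getter
    )

    # use_all_inputs only takes effect when the subgraph starts at statement 0;
    # for any other start position, MakeSequenceFuncDesc falls back to the sliced
    # statements (see the `use_all_inputs and subgraph_range[0] == 0` guard).
    sample_str = generator.Generate(
        subgraph_range=(0, len(program_seq_stmts)), use_all_inputs=True
    )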