Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions athena/generators/graphnet_sequence_sample_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,27 @@ class SequenceFuncDesc:


class GraphnetSequenceSampleGenerator:
def __init__(self, program_id, op_example_inputs_meta_getter):
def __init__(self, program_id, program_seq_stmts, op_example_inputs_meta_getter):
self.program_id = program_id
self.program_seq_stmts = program_seq_stmts
self.op_example_inputs_meta_getter = op_example_inputs_meta_getter
self.input_spec_mode = "original"

def Generate(self, seq_stmts):
seq_func_desc = self.MakeSequenceFuncDesc(seq_stmts)
def Generate(self, subgraph_range, use_all_inputs):
assert isinstance(subgraph_range, (tuple, list)) and len(subgraph_range) == 2
seq_stmts = self.program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
seq_func_desc = self.MakeSequenceFuncDesc(
seq_stmts, use_all_inputs and subgraph_range[0] == 0
)
return self._RenderTemplate(seq_func_desc)

def MakeSequenceFuncDesc(self, seq_stmts):
op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
def MakeSequenceFuncDesc(self, seq_stmts, use_all_inputs):
if use_all_inputs:
op_id2seq_stmt = OrderedDict(
(stmt.op_id, stmt) for stmt in self.program_seq_stmts
)
else:
op_id2seq_stmt = OrderedDict((stmt.op_id, stmt) for stmt in seq_stmts)
ops_func_signature = OpsFuncSignature(
tensor_ids=self.GetTensorIds(op_id2seq_stmt),
operand_ids=self.GetOperandIds(op_id2seq_stmt),
Expand Down
3 changes: 3 additions & 0 deletions athena/generators/template_graphnet_sequence_sample.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ class GraphModule(paddle.nn.Layer):
{%- endfor -%}
{%- for tensor_id in sig.tensor_ids -%}
{%- if 'parameter' not in sig.tensor_name4tensor_id(tensor_id) %}
{%- set data, dtype = sig.immediate_value4int_array_member_id(tensor_id) -%}
{%- if data is none %}
{{"\t\t"}}{{tensor_name_converter(sig.tensor_name4tensor_id(tensor_id))}}{{","}}
{%- endif %}
{%- endif %}
{%- endfor -%}
):
{%- for tensor_id in sig.tensor_ids -%}
Expand Down
14 changes: 14 additions & 0 deletions athena/generators/template_op_example_input_meta_script.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def CalculateTensorMeta(tensor, meta_name):
raise NotImplementedError(f"meta_name: {meta_name}")


def IsInitialized(tensor):
    """Best-effort check that `tensor` is a usable, non-empty paddle Tensor.

    Returns True only when `tensor` is a paddle.Tensor with at least one
    element AND its `shape`/`dtype` metadata can be read without raising
    (uninitialized/released tensors may throw on metadata access).
    Any exception is treated as "not initialized" so callers can use this
    as a guard before touching tensor data.
    """
    try:
        is_initialized = isinstance(tensor, paddle.Tensor) and tensor.numel() > 0
        # BUG FIX: the original probed `x.shape` / `x.dtype`, but `x` was
        # undefined — every call raised NameError, was swallowed by the
        # except below, and the function always returned False. Probe the
        # actual argument instead; these accesses intentionally exercise
        # metadata so that a broken tensor trips the except handler.
        _ = tensor.shape
        _ = tensor.dtype
        return is_initialized
    except Exception:
        # Deliberate best-effort: any failure means "treat as uninitialized".
        return False


def InitTensorMeta(tensor, meta_name, tensor_meta):
if tensor_meta:
return getattr(tensor_meta, meta_name)
Expand All @@ -88,6 +98,8 @@ def InitTensorMeta(tensor, meta_name, tensor_meta):
return [InitTensorMeta(t, meta_name, tensor_meta) for t in tensor]
if not hasattr(tensor, meta_name):
raise NotImplementedError(f"type(tensor): {type(tensor)}, meta_name: {meta_name}")
if not IsInitialized(tensor):
return None
kLimit = 64
if tensor.numel().item() < kLimit:
return None
Expand All @@ -109,6 +121,8 @@ def InitTensorData(tensor):
return [InitTensorData(t) for t in tensor]
if not hasattr(tensor, 'numel'):
raise NotImplementedError(f"type(tensor): {type(tensor)}")
if not IsInitialized(tensor):
return None
kLimit = 64
if tensor.numel().item() >= kLimit:
return None
Expand Down
80 changes: 52 additions & 28 deletions athena/graphnet_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from itertools import groupby
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict
from typing import Dict, List

from athena.generators.blocks_generator import BlocksGenerator
from athena.generators.block_name_generator import BlockNameGenerator
Expand Down Expand Up @@ -56,12 +56,17 @@
True,
"Whether extend split_positions to include the head and tail of the statement sequence.",
)
flags.DEFINE_boolean(
"use_all_inputs",
False,
"Whether use all inputs of the ir program.",
)
flags.DEFINE_boolean(
"eval_mode",
False,
"Generate graphnet sample for eval, which only keep output tensors with maximum depth (longest chain).",
)
flags.DEFINE_string("tmp_dir", tempfile.gettempdir(), "tmp directory.")
flags.DEFINE_string("tmp_dir", None, "tmp directory.")


@dataclass
Expand All @@ -73,10 +78,11 @@ class GraphnetSample:
input_meta: str
weight_meta: str
model: str
subgraph_range: List[int] = None


def ConvertOutputStringToSample(
model_name, unique_name, subgraph_idx, program_id, sample_str
model_name, unique_name, subgraph_idx, program_id, sample_str, subgraph_range=None
):
metadata = {
"framework": "paddle",
Expand All @@ -94,6 +100,7 @@ def ConvertOutputStringToSample(
input_meta=input_meta.strip("\n\n\n") + "\n",
weight_meta=weight_meta.rstrip("\n\n\n") + "\n",
model=model,
subgraph_range=subgraph_range,
)
# PrintToTerminal(unique_name, sample_str)
return sample
Expand Down Expand Up @@ -154,7 +161,7 @@ def __init__(
example_inputs_file,
op_example_inputs_file,
eval_mode,
tmp_dir,
tmp_dir=None,
):
self.model_name = model_name
self.programs_file = programs_file
Expand Down Expand Up @@ -243,50 +250,63 @@ def ExtendHeadAndTail(self, seq_stmts, split_positions, group_head_and_tail):
print(f"split_positions_for_seq_stmts: {split_positions_for_seq_stmts}")
return split_positions_for_seq_stmts

def GetOutputSampleStrings(self, split_positions, group_head_and_tail=True):
def GetOutputSampleStrings(
self, split_positions, group_head_and_tail=True, use_all_inputs=False
):
def MakeSequenceSampleGenerator(
program_id, seq_stmts, op_example_inputs_meta_getter
program_id, program_seq_stmts, op_example_inputs_meta_getter
):
generator = GraphnetSequenceSampleGenerator(
program_id, op_example_inputs_meta_getter
return GraphnetSequenceSampleGenerator(
program_id, program_seq_stmts, op_example_inputs_meta_getter
)
return generator.Generate(seq_stmts)

print(f"origin split_positions: {split_positions}")
generated_sample_strs = set()
for subgraph_idx, (program_id, seq_stmts) in enumerate(
for subgraph_idx, (program_id, program_seq_stmts) in enumerate(
self.program_seq_stmts_list
):
generator = MakeSequenceSampleGenerator(
program_id, program_seq_stmts, self.op_example_inputs_meta_getter
)
split_positions_for_seq_stmts = self.ExtendHeadAndTail(
seq_stmts, split_positions, group_head_and_tail
program_seq_stmts, split_positions, group_head_and_tail
)
for i in range(len(split_positions_for_seq_stmts) - 1):
seq_stmts_slice = seq_stmts[
split_positions_for_seq_stmts[i] : split_positions_for_seq_stmts[
i + 1
]
]
sample_str = MakeSequenceSampleGenerator(
program_id, seq_stmts_slice, self.op_example_inputs_meta_getter
)
subgraph_range = split_positions_for_seq_stmts[i : i + 2]
sample_str = generator.Generate(subgraph_range, use_all_inputs)
if sample_str not in generated_sample_strs:
generated_sample_strs.add(sample_str)
stmt_hash = GetSeqStmtsHash(seq_stmts_slice)
yield (subgraph_idx, program_id, stmt_hash, sample_str)

def __call__(self, split_positions, group_head_and_tail=True):
stmt_hash = GetSeqStmtsHash(
program_seq_stmts[subgraph_range[0] : subgraph_range[1]]
)
yield (
subgraph_idx,
program_id,
stmt_hash,
subgraph_range,
sample_str,
)

def __call__(self, split_positions, group_head_and_tail=True, use_all_inputs=False):
graphnet_sample_results = []
seg_counter = defaultdict(lambda: itertools.count())
for _, (subgraph_idx, program_id, uid, sample_str) in enumerate(
self.GetOutputSampleStrings(split_positions, group_head_and_tail)
for _, (subgraph_idx, program_id, uid, subgraph_range, sample_str) in enumerate(
self.GetOutputSampleStrings(
split_positions, group_head_and_tail, use_all_inputs
)
):
unique_name = f"{uid}_{next(seg_counter[uid])}"
sample = ConvertOutputStringToSample(
self.model_name, unique_name, subgraph_idx, program_id, sample_str
self.model_name,
unique_name,
subgraph_idx,
program_id,
sample_str,
subgraph_range,
)
graphnet_sample_results.append(sample)
print(
f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=})."
f"[SubgraphGenerator] Generate {len(graphnet_sample_results)} graphnet subgraph samples ({split_positions=}, {group_head_and_tail=}, {use_all_inputs=})."
)
return graphnet_sample_results

Expand All @@ -298,6 +318,7 @@ def RunGeneration(
op_example_inputs,
split_positions,
group_head_and_tail,
use_all_inputs,
eval_mode,
tmp_dir=None,
):
Expand All @@ -313,7 +334,9 @@ def RunGeneration(
eval_mode,
tmp_dir,
)
graphnet_sample_results = generator(split_positions, group_head_and_tail)
graphnet_sample_results = generator(
split_positions, group_head_and_tail, use_all_inputs
)
return graphnet_sample_results


Expand All @@ -327,6 +350,7 @@ def main(argv):
op_example_inputs=FLAGS.op_example_inputs,
split_positions=split_positions,
group_head_and_tail=FLAGS.group_head_and_tail,
use_all_inputs=FLAGS.use_all_inputs,
eval_mode=FLAGS.eval_mode,
)

Expand Down