Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions graph_net/dimension_generalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,41 @@ def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas):
(to_model_path / "weight_meta.py").write_text(weight_meta_code)

def _get_to_model_path(self, rel_model_path, symbol2example_value):
    """
    Build the output directory for one reified dimension configuration.

    Output layout is ``<output_dir>/<index>/<rel_model_path>``, where
    ``<index>`` is the position of the current dimension values inside the
    reified dimension list returned by ``_get_symbols_and_reified_dims``
    (e.g. ``0/model1``, ``1/model1``), replacing the older
    ``<rel_model_path>/<model_name>__<symbolic_dims>`` layout.

    Args:
        rel_model_path: Model path relative to ``config["model_path_prefix"]``.
        symbol2example_value: Mapping from each symbol to the concrete
            dimension value used for this configuration.

    Returns:
        The ``Path`` the generalized model should be written to.

    Raises:
        ValueError: If the current dimension values match none of the
            reified configurations. Falling back to index 0 would silently
            collide with the real configuration 0.
    """
    model_path = Path(self.config["model_path_prefix"]) / rel_model_path
    constraints = DynamicDimConstraints.unserialize_from_py_file(
        os.path.join(
            self.config["model_path_prefix"],
            rel_model_path,
            "input_tensor_constraints.py",
        )
    )
    symbols, reified_dims = self._get_symbols_and_reified_dims(
        model_path, constraints
    )
    current_dims = tuple(symbol2example_value[symbol] for symbol in symbols)

    # Locate the configuration index by matching concrete dimension values.
    dim_index = next(
        (
            index
            for index, dims in enumerate(reified_dims)
            if tuple(dims) == current_dims
        ),
        None,
    )
    if dim_index is None:
        raise ValueError(
            f"dimension values {current_dims} of {rel_model_path} "
            "do not match any reified dimension configuration"
        )

    return Path(self.config["output_dir"]) / f"{dim_index}" / rel_model_path

Expand Down
87 changes: 61 additions & 26 deletions graph_net/sample_pass/group_ranges_from_subgraph_sources.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from graph_net.sample_pass.sample_pass import SamplePass
from pathlib import Path
import json
import re


class GroupRangesFromSubgraphSources(SamplePass):
Expand Down Expand Up @@ -30,57 +31,89 @@ def __call__(self, subgraph_rel_model_path: str):
)
subgraph_sources = json.load(open(model_path))
for original_graph_rel_model_path, subgraph_ranges in subgraph_sources.items():
self._collect_original_graph_rel_model_path2ranges(
original_graph_rel_model_path, subgraph_ranges
)
self._collect_original_graph_rel_model_path2subgraph_rel_model_path(
original_graph_rel_model_path,
[subgraph_rel_model_path] * len(subgraph_ranges),
)
# Extract actual start-end range from path, establish precise one-to-one mapping
path_range = self._extract_range_from_path(subgraph_rel_model_path)
if path_range:
self._collect_original_graph_rel_model_path2ranges(
original_graph_rel_model_path, path_range
)
self._collect_original_graph_rel_model_path2subgraph_rel_model_path(
original_graph_rel_model_path, [subgraph_rel_model_path]
)

def _extract_range_from_path(self, path: str) -> list[int]:
"""
Parses subgraph path names to extract the node range information.
This establishes the precise correspondence between
subgraph locations and their operational scope within the original graph.

For example: model_startX_endY_Z -> [X, Y]
"""
match = re.search(r"start(\d+)_end(\d+)", path)
if match:
return [int(match.group(1)), int(match.group(2))]
return None

def _collect_original_graph_rel_model_path2subgraph_rel_model_path(
self,
original_graph_rel_model_path: str,
subgraph_rel_model_paths: list[str],
):
"""Collect subgraph paths with automatic deduplication"""
old = self.original_graph_rel_model_path2subgraph_rel_model_paths.get(
original_graph_rel_model_path, []
)
# Deduplicate and merge, maintaining path uniqueness
combined = old + [p for p in subgraph_rel_model_paths if p not in old]
self.original_graph_rel_model_path2subgraph_rel_model_paths[
original_graph_rel_model_path
] = [
*old,
*subgraph_rel_model_paths,
]
] = combined

def _collect_original_graph_rel_model_path2ranges(
self, original_graph_rel_model_path, subgraph_ranges
self, original_graph_path: str, path_range: list[int]
):
"""Collect subgraph ranges with automatic deduplication"""
old_ranges = self.original_graph_rel_model_path2ranges.get(
original_graph_rel_model_path, []
original_graph_path, []
)
self.original_graph_rel_model_path2ranges[original_graph_rel_model_path] = [
*old_ranges,
*subgraph_ranges,
]
if path_range not in old_ranges:
old_ranges.append(path_range)
self.original_graph_rel_model_path2ranges[original_graph_path] = old_ranges

def END(self, rel_model_paths: list[str]):
    """
    Sort the collected ranges per original graph and save the results.

    For every original graph, the collected ranges are sorted by
    ``(start, end)`` and each range is paired with the subgraph path whose
    name encodes that same range. The sorted ranges and their paths are then
    passed to ``_save_json``.

    Args:
        rel_model_paths: Not used by this method.
    """
    paths_by_graph = self.original_graph_rel_model_path2subgraph_rel_model_paths
    for (
        original_graph_rel_model_path,
        actual_ranges,
    ) in self.original_graph_rel_model_path2ranges.items():
        # Map each (start, end) range to the subgraph path that encodes it.
        # If several paths encode the same range, the last one wins.
        range_to_path = {}
        for path in paths_by_graph.get(original_graph_rel_model_path, []):
            path_range = self._extract_range_from_path(path)
            if path_range is not None:
                range_to_path[tuple(path_range)] = path

        # Sort on the full (start, end) pair so ranges sharing a start are
        # ordered deterministically instead of by insertion order.
        sorted_ranges = sorted(actual_ranges, key=tuple)
        sorted_paths = [
            range_to_path[tuple(r)]
            for r in sorted_ranges
            if tuple(r) in range_to_path
        ]
        self._save_json(original_graph_rel_model_path, sorted_ranges, sorted_paths)

def _save_json(
self, original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
):
"""Save final aggregated results to JSON files"""
model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path
model_dir.mkdir(parents=True, exist_ok=True)
ranges_json = self._get_ranges_json(subgraph_ranges)
Expand All @@ -90,6 +123,7 @@ def _save_json(
(model_dir / self.config["output_json_file_name"]).write_text(json_str)

def _get_paths_json(self, subgraph_rel_model_paths: list[str]):
"""Generate JSON object for paths section"""
json_obj = {
self.config[
"output_json_subgraph_rel_model_path_key"
Expand All @@ -98,5 +132,6 @@ def _get_paths_json(self, subgraph_rel_model_paths: list[str]):
return json_obj

def _get_ranges_json(self, subgraph_ranges: list[(int, int)]):
"""Generate JSON object for ranges section"""
json_obj = {self.config["output_json_key"]: subgraph_ranges}
return json_obj