def _get_to_model_path(self, rel_model_path, symbol2example_value):
    """Return the output directory for one dimension configuration of a model.

    The directory is ``<output_dir>/<index>/<rel_model_path>``, where
    ``index`` is the position of the current dimension values within the
    sequence returned by ``self._get_symbols_and_reified_dims`` (e.g.
    ``0/model1``, ``1/model1``).

    Args:
        rel_model_path: Model path relative to ``config["model_path_prefix"]``;
            also reused as the trailing component of the output path.
        symbol2example_value: Mapping indexed by each symbol returned by
            ``self._get_symbols_and_reified_dims``; its values form the
            current dimension configuration.

    Returns:
        A ``pathlib.Path`` pointing at the output directory.

    Raises:
        ValueError: If the current dimension values match none of the
            reified configurations. Falling back to index 0 here would
            silently write into the directory of a different configuration.
    """
    model_path_prefix = self.config["model_path_prefix"]
    constraints = DynamicDimConstraints.unserialize_from_py_file(
        os.path.join(
            model_path_prefix,
            rel_model_path,
            "input_tensor_constraints.py",
        )
    )
    # NOTE(review): presumably `reified_dims` is an ordered collection of
    # per-symbol value sequences aligned with `symbols` -- confirm against
    # `_get_symbols_and_reified_dims`.
    symbols, reified_dims = self._get_symbols_and_reified_dims(
        Path(model_path_prefix) / rel_model_path,
        constraints,
    )
    current_dims = tuple(symbol2example_value[symbol] for symbol in symbols)

    # First configuration whose values equal the current ones; None if absent.
    dim_index = next(
        (
            index
            for index, dims in enumerate(reified_dims)
            if tuple(dims) == current_dims
        ),
        None,
    )
    if dim_index is None:
        raise ValueError(
            f"dimension values {current_dims!r} of {rel_model_path!r} "
            "do not match any reified dimension configuration"
        )

    return Path(self.config["output_dir"]) / str(dim_index) / rel_model_path
import re


def _extract_range_from_path(self, path: str) -> "list[int] | None":
    """Parse the ``start<X>_end<Y>`` node range embedded in a subgraph path.

    Example: ``model_start3_end17_0`` -> ``[3, 17]``.

    Args:
        path: Subgraph path (or any string) to search.

    Returns:
        ``[start, end]`` as ints taken from the first ``start<digits>_end<digits>``
        occurrence in ``path``, or ``None`` when there is no such occurrence.
    """
    match = re.search(r"start(\d+)_end(\d+)", path)
    if match is None:
        return None
    return [int(match.group(1)), int(match.group(2))]


def _collect_original_graph_rel_model_path2subgraph_rel_model_path(
    self,
    original_graph_rel_model_path: str,
    subgraph_rel_model_paths: list[str],
):
    """Record subgraph paths for an original graph, keeping each path once.

    Paths already recorded keep their position; new paths are appended in
    the order given. Duplicates inside ``subgraph_rel_model_paths`` itself
    are dropped as well, not only those that clash with earlier calls.
    """
    known = self.original_graph_rel_model_path2subgraph_rel_model_paths.get(
        original_graph_rel_model_path, []
    )
    # dict.fromkeys de-duplicates while preserving first-seen order.
    self.original_graph_rel_model_path2subgraph_rel_model_paths[
        original_graph_rel_model_path
    ] = list(dict.fromkeys([*known, *subgraph_rel_model_paths]))


def _collect_original_graph_rel_model_path2ranges(
    self, original_graph_path: str, path_range: list[int]
):
    """Record a ``[start, end]`` range for an original graph, once per range.

    The range is stored as a fresh list, so a later mutation of the caller's
    object cannot alter recorded data, and a tuple and a list with the same
    values count as the same range.
    """
    normalized = list(path_range)
    recorded = self.original_graph_rel_model_path2ranges.setdefault(
        original_graph_path, []
    )
    if normalized not in recorded:
        recorded.append(normalized)


def END(self, rel_model_paths: list[str]):
    """Sort the collected ranges per original graph and save them with their paths.

    For every original graph, each recorded range is paired with the
    subgraph path whose name encodes that range. The pairs are ordered by
    range start and handed to ``self._save_json`` as two index-aligned lists.

    Args:
        rel_model_paths: Not used by this method; kept so the signature
            stays unchanged for callers.
    """
    paths_by_graph = self.original_graph_rel_model_path2subgraph_rel_model_paths
    for graph_path, ranges in self.original_graph_rel_model_path2ranges.items():
        # NOTE(review): two subgraph paths that encode the same range collapse
        # to a single entry here (the later path wins) -- confirm this is the
        # intended de-duplication.
        range_to_path = {}
        for path in paths_by_graph.get(graph_path, []):
            path_range = self._extract_range_from_path(path)
            if path_range is not None:
                range_to_path[tuple(path_range)] = path

        # Keep only ranges that have a path so both lists stay the same length
        # and position i of one always corresponds to position i of the other.
        matched_ranges = [
            rng
            for rng in sorted(ranges, key=lambda rng: rng[0])
            if tuple(rng) in range_to_path
        ]
        matched_paths = [range_to_path[tuple(rng)] for rng in matched_ranges]
        self._save_json(graph_path, matched_ranges, matched_paths)


def _get_ranges_json(self, subgraph_ranges: list[list[int]]):
    """Wrap the ranges in a dict keyed by ``config["output_json_key"]``."""
    return {self.config["output_json_key"]: subgraph_ranges}