From 72cc04f9e1bfc4b7558e2f53fab8ca37cf196fae Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 19:24:26 +0200 Subject: [PATCH 1/5] Add failing test for torch.fx.replace_pattern --- test/fx/test_subgraph_rewriter.py | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 17cf05bb67d9a..567c2fe1b4844 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -458,3 +458,35 @@ def forward(self, x): if n.op == 'placeholder': assert n.type == int assert m.type == int + + def test_subgraph_writer_replace_consecutive_submodules(self): + + def f(x): + x = torch.sigmoid(x) + x = torch.sigmoid(x) + return torch.sigmoid(x) + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.exp(x) + + def comparison(x): + x = torch.exp(x) + x = torch.exp(x) + return torch.exp(x) + + traced = symbolic_trace(f) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + traced.graph.lint() + + ref_outs = comparison_fn(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) + From b26e21f77d6542322fa813c64eb41d855526bae7 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 20:50:26 +0200 Subject: [PATCH 2/5] Proposed fix for pattern_matching --- torch/fx/subgraph_rewriter.py | 89 ++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 33 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 72ea56aa31196..2fab89cfac16b 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,6 +131,10 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() +def add_suffix_to_graph(graph, suffix): + for node 
in graph.nodes: + node.name += suffix + @compatibility(is_backward_compatible=True) def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: """ @@ -261,11 +265,10 @@ def forward(self, x, w1, w2): if matcher.matches_subgraph_from_anchor(anchor): - def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: + def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: # `lookup` represents all the nodes in `original_graph` # that are part of `pattern` - lookup: Dict[Node, Node] = {v : k for k, v - in nodes_map.items()} + lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} for n in lookup.keys(): # Nodes that can "leak"... @@ -295,39 +298,46 @@ def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: # It's not a match if the pattern leaks out into the rest # of the graph if pattern_is_contained(matcher.nodes_map): - for k, v in matcher.nodes_map.items(): - # Shallow copy nodes_map - matches.append(Match(anchor=anchor, - nodes_map=copy.copy(matcher.nodes_map))) + # Shallow copy nodes_map + matches.append(Match(anchor=anchor, + nodes_map=copy.copy({ + key: value + for key, value in matcher.nodes_map.items() + }))) # The set of all nodes in `original_graph` that we've seen thus far # as part of a pattern match replaced_nodes: Set[Node] = set() + # As we progressively replace node, we need to keep track on how the match results need to change also + match_changed_node: Dict[Node, Node] = dict() # Return True if one of the nodes in the current match has already # been used as part of another match def overlaps_with_prev_match(match: Match) -> bool: - for n in match.nodes_map.values(): - if n in replaced_nodes and n.op != "placeholder": + for pn, gn in match.nodes_map.items(): + if pn.op in ["placeholder", "output"]: + continue + if gn in replaced_nodes and gn.op != "placeholder": return True return False - for match in matches: - + for i, match in enumerate(matches): # Skip overlapping matches if 
overlaps_with_prev_match(match): continue + suffixed_replacement_graph = replacement_graph.__deepcopy__() + add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} pattern_placeholders = [n for n in pattern_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) - replacement_placeholders = [n for n in replacement_graph.nodes + replacement_placeholders = [n for n in suffixed_replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) - placeholder_map = {r : p for r, p + placeholder_map = {r: p for r, p in zip(replacement_placeholders, pattern_placeholders)} # node from `original_graph` that matched with the output node @@ -341,21 +351,23 @@ def mark_node_as_replaced(n: Node) -> None: mark_node_as_replaced(n_) replaced_nodes.add(n) - mark_node_as_replaced(subgraph_output) + for input_node in subgraph_output.all_input_nodes: + mark_node_as_replaced(input_node) - # Intialize `val_map` with mappings from placeholder nodes in + # Initialize `val_map` with mappings from placeholder nodes in # `replacement` to their corresponding node in `original_graph` for replacement_node in replacement_placeholders: # Get the `original_graph` placeholder node # corresponding to the current `replacement_node` pattern_node = placeholder_map[replacement_node] - original_graph_node = match.nodes_map[pattern_node] + original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node]) + # Populate `val_map` val_map[replacement_node] = original_graph_node # Copy the replacement graph over with original_graph.inserting_before(subgraph_output): - copied_output = original_graph.graph_copy(replacement_graph, + copied_output = original_graph.graph_copy(suffixed_replacement_graph, val_map) # Hook the output Node of the replacement subgraph in to the @@ -366,20 +378,33 @@ def 
mark_node_as_replaced(n: Node) -> None: # original graph that corresponds to the end of the pattern # subgraph if subgraph_output.op != "output": - # `subgraph_output` may have multiple args. These args could - # be from the orignal graph, or they could have come from - # the insertion of `replacement_subgraph`. We need to find - # the Node that was originally matched as part of - # `pattern` (i.e. a Node from the original graph). We can - # figure this out by looking in `match.nodes_map`. The map - # was created before `replacement_subgraph` was spliced in, - # so we know that, if a Node is in `match.nodes_map.values`, - # it must have come from the original graph - for n in subgraph_output.all_input_nodes: - if (n.op != "placeholder" - and n in match.nodes_map.values()): - subgraph_output = n - break + pattern_outputs = [n for n in pattern_graph.nodes + if n.op == "output"] + assert len(pattern_outputs) + replacement_outputs = [n for n in suffixed_replacement_graph.nodes + if n.op == "output"] + assert len(replacement_outputs) == len(pattern_outputs) + outputs_map = {p: r for r, p + in zip(replacement_outputs, pattern_outputs)} + + for pn, gn in match.nodes_map.items(): + if gn.op == "placeholder": + continue + + # We search for the node corresponding to the output of the pattern. 
+ if pn.op != "output": + continue + assert subgraph_output == gn + + # We update all anchor inputs to the new nodes + rn = outputs_map[pn] + for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): + gn_input = match.nodes_map[pn_input] + rn_input_in_original_graph = val_map[rn_input] + gn.replace_input_with(gn_input, rn_input_in_original_graph) + # We store the updated node point in case other nodes want to use it + match_changed_node[gn_input] = rn_input_in_original_graph + assert subgraph_output.op != "output" # CASE 2: The pattern subgraph match extends to the end of the # original graph, so we need to change the current graph's @@ -392,8 +417,6 @@ def mark_node_as_replaced(n: Node) -> None: subgraph_output._input_nodes = {copied_output: None} assert isinstance(copied_output, Node) - subgraph_output.replace_all_uses_with(copied_output) - # Erase the `pattern` nodes for node in reversed(original_graph.nodes): if len(node.users) == 0 and node.op != "output": From 93f8ef2ec62c4ea5d27a4a5296a0d0c419b0f36e Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 21:21:19 +0200 Subject: [PATCH 3/5] Fix --- torch/fx/subgraph_rewriter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 2fab89cfac16b..e64f428e52bf2 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,7 +131,7 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() -def add_suffix_to_graph(graph, suffix): +def _add_suffix_to_graph(graph, suffix): for node in graph.nodes: node.name += suffix @@ -327,7 +327,7 @@ def overlaps_with_prev_match(match: Match) -> bool: continue suffixed_replacement_graph = replacement_graph.__deepcopy__() - add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") + _add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph 
nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} @@ -401,7 +401,7 @@ def mark_node_as_replaced(n: Node) -> None: for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): gn_input = match.nodes_map[pn_input] rn_input_in_original_graph = val_map[rn_input] - gn.replace_input_with(gn_input, rn_input_in_original_graph) + gn_input.replace_all_uses_with(rn_input_in_original_graph) # We store the updated node point in case other nodes want to use it match_changed_node[gn_input] = rn_input_in_original_graph From bbe5a47f661ae4c635806b092766f776e4c591bd Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 21:47:51 +0200 Subject: [PATCH 4/5] Lint --- test/fx/test_subgraph_rewriter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 567c2fe1b4844..7d6eface81b42 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -489,4 +489,3 @@ def comparison(x): ref_outs = comparison_fn(x) test_outs = traced.forward(x) self.assertEqual(ref_outs, test_outs) - From 0c92850e1426b6d9269e53a638f95b3723793969 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 12 Oct 2021 20:02:40 +0200 Subject: [PATCH 5/5] Remove unnecessary copies --- torch/fx/subgraph_rewriter.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index e64f428e52bf2..71e2d64b53982 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,10 +131,6 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() -def _add_suffix_to_graph(graph, suffix): - for node in graph.nodes: - node.name += suffix - @compatibility(is_backward_compatible=True) def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> 
List[Match]: """ @@ -326,15 +322,13 @@ def overlaps_with_prev_match(match: Match) -> bool: if overlaps_with_prev_match(match): continue - suffixed_replacement_graph = replacement_graph.__deepcopy__() - _add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} pattern_placeholders = [n for n in pattern_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) - replacement_placeholders = [n for n in suffixed_replacement_graph.nodes + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) placeholder_map = {r: p for r, p @@ -367,7 +361,7 @@ def mark_node_as_replaced(n: Node) -> None: # Copy the replacement graph over with original_graph.inserting_before(subgraph_output): - copied_output = original_graph.graph_copy(suffixed_replacement_graph, + copied_output = original_graph.graph_copy(replacement_graph, val_map) # Hook the output Node of the replacement subgraph in to the @@ -381,7 +375,7 @@ def mark_node_as_replaced(n: Node) -> None: pattern_outputs = [n for n in pattern_graph.nodes if n.op == "output"] assert len(pattern_outputs) - replacement_outputs = [n for n in suffixed_replacement_graph.nodes + replacement_outputs = [n for n in replacement_graph.nodes if n.op == "output"] assert len(replacement_outputs) == len(pattern_outputs) outputs_map = {p: r for r, p