From 72cc04f9e1bfc4b7558e2f53fab8ca37cf196fae Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 19:24:26 +0200 Subject: [PATCH 1/7] Add failing test for torch.fx.replace_pattern --- test/fx/test_subgraph_rewriter.py | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 17cf05bb67d9a..567c2fe1b4844 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -458,3 +458,35 @@ def forward(self, x): if n.op == 'placeholder': assert n.type == int assert m.type == int + + def test_subgraph_writer_replace_consecutive_submodules(self): + + def f(x): + x = torch.sigmoid(x) + x = torch.sigmoid(x) + return torch.sigmoid(x) + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.exp(x) + + def comparison(x): + x = torch.exp(x) + x = torch.exp(x) + return torch.exp(x) + + traced = symbolic_trace(f) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + traced.graph.lint() + + ref_outs = comparison_fn(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) + From b26e21f77d6542322fa813c64eb41d855526bae7 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 20:50:26 +0200 Subject: [PATCH 2/7] Proposed fix for pattern_matching --- torch/fx/subgraph_rewriter.py | 89 ++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 33 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 72ea56aa31196..2fab89cfac16b 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,6 +131,10 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() +def add_suffix_to_graph(graph, suffix): + for node in graph.nodes: + node.name += suffix + @compatibility(is_backward_compatible=True) def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: """ @@ -261,11 +265,10 @@ def forward(self, x, w1, w2): if matcher.matches_subgraph_from_anchor(anchor): - def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: + def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: # `lookup` represents all the nodes in `original_graph` # that are part of `pattern` - lookup: Dict[Node, Node] = {v : k for k, v - in nodes_map.items()} + lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} for n in lookup.keys(): # Nodes that can "leak"... @@ -295,39 +298,46 @@ def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: # It's not a match if the pattern leaks out into the rest # of the graph if pattern_is_contained(matcher.nodes_map): - for k, v in matcher.nodes_map.items(): - # Shallow copy nodes_map - matches.append(Match(anchor=anchor, - nodes_map=copy.copy(matcher.nodes_map))) + # Shallow copy nodes_map + matches.append(Match(anchor=anchor, + nodes_map=copy.copy({ + key: value + for key, value in matcher.nodes_map.items() + }))) # The set of all nodes in `original_graph` that we've seen thus far # as part of a pattern match replaced_nodes: Set[Node] = set() + # As we progressively replace node, we need to keep track on how the match results need to change also + match_changed_node: Dict[Node, Node] = dict() # Return True if one of the nodes in the current match has already # been used as part of another match def overlaps_with_prev_match(match: Match) -> bool: - for n in match.nodes_map.values(): - if n in replaced_nodes and n.op != "placeholder": + for pn, gn in match.nodes_map.items(): + if pn.op in ["placeholder", "output"]: + continue + if gn in replaced_nodes and gn.op != "placeholder": return True return False - for match in matches: - + for i, match in enumerate(matches): # Skip overlapping matches if overlaps_with_prev_match(match): continue + suffixed_replacement_graph = replacement_graph.__deepcopy__() + add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} pattern_placeholders = [n for n in pattern_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) - replacement_placeholders = [n for n in replacement_graph.nodes + replacement_placeholders = [n for n in suffixed_replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) - placeholder_map = {r : p for r, p + placeholder_map = {r: p for r, p in zip(replacement_placeholders, pattern_placeholders)} # node from `original_graph` that matched with the output node @@ -341,21 +351,23 @@ def mark_node_as_replaced(n: Node) -> None: mark_node_as_replaced(n_) replaced_nodes.add(n) - mark_node_as_replaced(subgraph_output) + for input_node in subgraph_output.all_input_nodes: + mark_node_as_replaced(input_node) - # Intialize `val_map` with mappings from placeholder nodes in + # Initialize `val_map` with mappings from placeholder nodes in # `replacement` to their corresponding node in `original_graph` for replacement_node in replacement_placeholders: # Get the `original_graph` placeholder node # corresponding to the current `replacement_node` pattern_node = placeholder_map[replacement_node] - original_graph_node = match.nodes_map[pattern_node] + original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node]) + # Populate `val_map` val_map[replacement_node] = original_graph_node # Copy the replacement graph over with original_graph.inserting_before(subgraph_output): - copied_output = original_graph.graph_copy(replacement_graph, + copied_output = original_graph.graph_copy(suffixed_replacement_graph, val_map) # Hook the output Node of the replacement subgraph in to the @@ -366,20 +378,33 @@ def mark_node_as_replaced(n: Node) -> None: # original graph that corresponds to the end of the pattern # subgraph if subgraph_output.op != "output": - # `subgraph_output` may have multiple args. These args could - # be from the orignal graph, or they could have come from - # the insertion of `replacement_subgraph`. We need to find - # the Node that was originally matched as part of - # `pattern` (i.e. a Node from the original graph). We can - # figure this out by looking in `match.nodes_map`. The map - # was created before `replacement_subgraph` was spliced in, - # so we know that, if a Node is in `match.nodes_map.values`, - # it must have come from the original graph - for n in subgraph_output.all_input_nodes: - if (n.op != "placeholder" - and n in match.nodes_map.values()): - subgraph_output = n - break + pattern_outputs = [n for n in pattern_graph.nodes + if n.op == "output"] + assert len(pattern_outputs) + replacement_outputs = [n for n in suffixed_replacement_graph.nodes + if n.op == "output"] + assert len(replacement_outputs) == len(pattern_outputs) + outputs_map = {p: r for r, p + in zip(replacement_outputs, pattern_outputs)} + + for pn, gn in match.nodes_map.items(): + if gn.op == "placeholder": + continue + + # We search for the node corresponding to the output of the pattern. + if pn.op != "output": + continue + assert subgraph_output == gn + + # We update all anchor inputs to the new nodes + rn = outputs_map[pn] + for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): + gn_input = match.nodes_map[pn_input] + rn_input_in_original_graph = val_map[rn_input] + gn.replace_input_with(gn_input, rn_input_in_original_graph) + # We store the updated node point in case other nodes want to use it + match_changed_node[gn_input] = rn_input_in_original_graph + assert subgraph_output.op != "output" # CASE 2: The pattern subgraph match extends to the end of the # original graph, so we need to change the current graph's @@ -392,8 +417,6 @@ def mark_node_as_replaced(n: Node) -> None: subgraph_output._input_nodes = {copied_output: None} assert isinstance(copied_output, Node) - subgraph_output.replace_all_uses_with(copied_output) - # Erase the `pattern` nodes for node in reversed(original_graph.nodes): if len(node.users) == 0 and node.op != "output": From 93f8ef2ec62c4ea5d27a4a5296a0d0c419b0f36e Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 21:21:19 +0200 Subject: [PATCH 3/7] Fix --- torch/fx/subgraph_rewriter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 2fab89cfac16b..e64f428e52bf2 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,7 +131,7 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() -def add_suffix_to_graph(graph, suffix): +def _add_suffix_to_graph(graph, suffix): for node in graph.nodes: node.name += suffix @@ -327,7 +327,7 @@ def overlaps_with_prev_match(match: Match) -> bool: continue suffixed_replacement_graph = replacement_graph.__deepcopy__() - add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") + _add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} @@ -401,7 +401,7 @@ def mark_node_as_replaced(n: Node) -> None: for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): gn_input = match.nodes_map[pn_input] rn_input_in_original_graph = val_map[rn_input] - gn.replace_input_with(gn_input, rn_input_in_original_graph) + gn_input.replace_all_uses_with(rn_input_in_original_graph) # We store the updated node point in case other nodes want to use it match_changed_node[gn_input] = rn_input_in_original_graph From bbe5a47f661ae4c635806b092766f776e4c591bd Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Mon, 11 Oct 2021 21:47:51 +0200 Subject: [PATCH 4/7] Lint --- test/fx/test_subgraph_rewriter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 567c2fe1b4844..7d6eface81b42 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -489,4 +489,3 @@ def comparison(x): ref_outs = comparison_fn(x) test_outs = traced.forward(x) self.assertEqual(ref_outs, test_outs) - From 0c92850e1426b6d9269e53a638f95b3723793969 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 12 Oct 2021 20:02:40 +0200 Subject: [PATCH 5/7] Remove unecessary copies --- torch/fx/subgraph_rewriter.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index e64f428e52bf2..71e2d64b53982 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -131,10 +131,6 @@ def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Mo gm.graph.lint() -def _add_suffix_to_graph(graph, suffix): - for node in graph.nodes: - node.name += suffix - @compatibility(is_backward_compatible=True) def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]: """ @@ -326,15 +322,13 @@ def overlaps_with_prev_match(match: Match) -> bool: if overlaps_with_prev_match(match): continue - suffixed_replacement_graph = replacement_graph.__deepcopy__() - _add_suffix_to_graph(suffixed_replacement_graph, f"_{i}") # Map replacement graph nodes to their copy in `original_graph` val_map: Dict[Node, Node] = {} pattern_placeholders = [n for n in pattern_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) - replacement_placeholders = [n for n in suffixed_replacement_graph.nodes + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) placeholder_map = {r: p for r, p @@ -367,7 +361,7 @@ def mark_node_as_replaced(n: Node) -> None: # Copy the replacement graph over with original_graph.inserting_before(subgraph_output): - copied_output = original_graph.graph_copy(suffixed_replacement_graph, + copied_output = original_graph.graph_copy(replacement_graph, val_map) # Hook the output Node of the replacement subgraph in to the @@ -381,7 +375,7 @@ def mark_node_as_replaced(n: Node) -> None: pattern_outputs = [n for n in pattern_graph.nodes if n.op == "output"] assert len(pattern_outputs) - replacement_outputs = [n for n in suffixed_replacement_graph.nodes + replacement_outputs = [n for n in replacement_graph.nodes if n.op == "output"] assert len(replacement_outputs) == len(pattern_outputs) outputs_map = {p: r for r, p From 21b3882b9b0a4ec70e3805a581cb83cb960733d4 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Wed, 13 Oct 2021 16:04:41 +0200 Subject: [PATCH 6/7] [torch.fx] Fix pattern matching the same node multiple times --- test/fx/test_subgraph_rewriter.py | 58 ++++++++++ torch/fx/subgraph_rewriter.py | 183 +++++++++++++++--------------- 2 files changed, 148 insertions(+), 93 deletions(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 7d6eface81b42..70ff66a7a7d16 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -489,3 +489,61 @@ def comparison(x): ref_outs = comparison_fn(x) test_outs = traced.forward(x) self.assertEqual(ref_outs, test_outs) + + def test_subgraph_rewriter_replaces_parallel_functions(self): + def f(x): + y = torch.sigmoid(x) + z = torch.sigmoid(x) + return y, z + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + y = torch.relu(x) + z = torch.relu(x) + return y, z + + traced = symbolic_trace(f) + + print(traced.graph) + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + print(traced.graph) + traced.graph.lint() + + x = torch.randn(3, 4) + ref_outs = comparison(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) + + def test_subgraph_rewriter_replaces_parallel_functions_when_agregated(self): + def f(x): + y = torch.sigmoid(x) + z = torch.sigmoid(x) + return y + z + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.relu(x) + + def comparison(x): + y = torch.relu(x) + z = torch.relu(x) + return y + z + + traced = symbolic_trace(f) + + print(traced.graph) + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + print(traced.graph) + traced.graph.lint() + + x = torch.randn(3, 4) + ref_outs = comparison(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) \ No newline at end of file diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 71e2d64b53982..c466592ded194 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -28,9 +28,9 @@ def __init__(self, pattern: Graph) -> None: assert len(self.pattern_anchor.all_input_nodes) == 1, \ "Pattern matching on multiple outputs is not supported" # Maps nodes in the pattern subgraph to nodes in the larger graph - self.nodes_map: Dict[Node, Node] = {} + self.nodes_map: List[Dict[Node, Node]] = [{}] - def matches_subgraph_from_anchor(self, anchor: Node) -> bool: + def matches_subgraph_from_anchor(self, anchor: Node) -> List[Dict[Node, Node]]: """ Checks if the whole pattern can be matched starting from ``anchor`` in the larger graph. @@ -38,16 +38,20 @@ def matches_subgraph_from_anchor(self, anchor: Node) -> bool: Pattern matching is done by recursively comparing the pattern node's use-def relationships against the graph node's. """ - self.nodes_map = {} - return self._match_nodes(self.pattern_anchor, anchor) + self.nodes_map: List[Dict[Node, Node]] = [{}] + self._match_nodes(self.pattern_anchor, anchor) + + # We need to filter out the one that are empty + self.nodes_map = [elt for elt in self.nodes_map if len(elt) > 0] + return self.nodes_map # Compare the pattern node `pn` against the graph node `gn` - def _match_nodes(self, pn: Node, gn: Node) -> bool: + def _match_nodes(self, pn: Node, gn: Node, graph_id: int = 0) -> bool: # Check if we've already matched these nodes in the current # traversal - if pn in self.nodes_map: - return self.nodes_map[pn] == gn + if pn in self.nodes_map[graph_id]: + return self.nodes_map[graph_id][pn] == gn def attributes_are_equal(pn: Node, gn: Node) -> bool: # Use placeholder and output nodes as wildcards. The @@ -63,7 +67,7 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool: return False # Optimistically mark `pn` as a match for `gn` - self.nodes_map[pn] = gn + self.nodes_map[graph_id][pn] = gn # Traverse the use-def relationships to ensure that `pn` is a true # match for `gn` @@ -73,14 +77,22 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool: and len(pn.all_input_nodes) != len(gn.all_input_nodes)): return False if pn.op == "output": - match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_) - for gn_ in gn.all_input_nodes) + # Only the first graph compares the output. + assert graph_id == 0 + # We broadcast the result to all the other potential graph matching. + self.nodes_map += [copy.copy(self.nodes_map[graph_id]) for _ in range(len(gn.all_input_nodes) - 1)] + all_matches = tuple(self._match_nodes(pn.all_input_nodes[0], gn_, graph_id_) + for graph_id_, gn_ in enumerate(gn.all_input_nodes) + ) + self.nodes_map = [node_map for node_map, match in zip(self.nodes_map, all_matches) if match] + # This is not really needed to return that value + return any(all_matches) else: match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes) - and all(self._match_nodes(pn_, gn_) for pn_, gn_ + and all(self._match_nodes(pn_, gn_, graph_id) for pn_, gn_ in zip(pn.all_input_nodes, gn.all_input_nodes))) if not match_found: - self.nodes_map.pop(pn) + self.nodes_map[graph_id].pop(pn) return False return True @@ -256,50 +268,49 @@ def forward(self, x, w1, w2): matcher = _SubgraphMatcher(pattern_graph) matches: List[Match] = [] - # Consider each node as an "anchor" (deepest matching graph node) - for anchor in original_graph.nodes: + def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} + for n in lookup.keys(): - if matcher.matches_subgraph_from_anchor(anchor): - - def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: - # `lookup` represents all the nodes in `original_graph` - # that are part of `pattern` - lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} - for n in lookup.keys(): - - # Nodes that can "leak"... - - # Placeholders (by definition) - if n.op == "placeholder": - continue - # Pattern output (acts as a container) - if lookup[n].op == "output": - continue - # Result contained by pattern output (what we'll - # hook in to the new Graph, thus what we'll - # potentially use in other areas of the Graph as - # an input Node) - if (len(lookup[n].users) == 1 - and list(lookup[n].users.keys())[0].op == "output"): - continue - - for user in n.users: - # If this node has users that were not in - # `lookup`, then it must leak out of the - # pattern subgraph - if user not in lookup: - return False - return True + # Nodes that can "leak"... + + # Placeholders (by definition) + if n.op == "placeholder": + continue + # Pattern output (acts as a container) + if lookup[n].op == "output": + continue + # Placeholders (by definition) + if lookup[n].op == "placeholder": + continue + # Result contained by pattern output (what we'll + # hook in to the new Graph, thus what we'll + # potentially use in other areas of the Graph as + # an input Node) + if (len(lookup[n].users) == 1 + and list(lookup[n].users.keys())[0].op == "output"): + continue + + for user in n.users: + # If this node has users that were not in + # `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True - # It's not a match if the pattern leaks out into the rest - # of the graph - if pattern_is_contained(matcher.nodes_map): + # Consider each node as an "anchor" (deepest matching graph node) + for anchor in original_graph.nodes: + potential_matches = matcher.matches_subgraph_from_anchor(anchor) + # It's not a match if the pattern leaks out into the rest + # of the graph + for node_map in potential_matches: + if pattern_is_contained(node_map): # Shallow copy nodes_map matches.append(Match(anchor=anchor, - nodes_map=copy.copy({ - key: value - for key, value in matcher.nodes_map.items() - }))) + nodes_map=copy.copy(node_map))) # The set of all nodes in `original_graph` that we've seen thus far # as part of a pattern match @@ -367,48 +378,34 @@ def mark_node_as_replaced(n: Node) -> None: # Hook the output Node of the replacement subgraph in to the # original Graph at the correct location - # CASE 1: We need to hook the replacement subgraph in somewhere - # in the middle of the graph. We replace the Node in the - # original graph that corresponds to the end of the pattern - # subgraph - if subgraph_output.op != "output": - pattern_outputs = [n for n in pattern_graph.nodes + pattern_outputs = [n for n in pattern_graph.nodes + if n.op == "output"] + assert len(pattern_outputs) + replacement_outputs = [n for n in replacement_graph.nodes if n.op == "output"] - assert len(pattern_outputs) - replacement_outputs = [n for n in replacement_graph.nodes - if n.op == "output"] - assert len(replacement_outputs) == len(pattern_outputs) - outputs_map = {p: r for r, p - in zip(replacement_outputs, pattern_outputs)} - - for pn, gn in match.nodes_map.items(): - if gn.op == "placeholder": - continue - - # We search for the node corresponding to the output of the pattern. - if pn.op != "output": - continue - assert subgraph_output == gn - - # We update all anchor inputs to the new nodes - rn = outputs_map[pn] - for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): - gn_input = match.nodes_map[pn_input] - rn_input_in_original_graph = val_map[rn_input] - gn_input.replace_all_uses_with(rn_input_in_original_graph) - # We store the updated node point in case other nodes want to use it - match_changed_node[gn_input] = rn_input_in_original_graph - - assert subgraph_output.op != "output" - # CASE 2: The pattern subgraph match extends to the end of the - # original graph, so we need to change the current graph's - # output Node to reflect the insertion of the replacement graph. - # We'll keep the current output Node, but update its args and - # `_input_nodes` as necessary - else: - subgraph_output.args = ((copied_output,)) - if isinstance(copied_output, Node): - subgraph_output._input_nodes = {copied_output: None} + assert len(replacement_outputs) == len(pattern_outputs) + outputs_map = {p: r for r, p + in zip(replacement_outputs, pattern_outputs)} + + for pn, gn in match.nodes_map.items(): + if gn.op == "placeholder": + continue + + # We search for the node corresponding to the output of the pattern. + if pn.op != "output": + continue + + # the anchor should correspond to `subgraph_output` + assert subgraph_output == gn + + # We update all anchor inputs to the new nodes + rn = outputs_map[pn] + for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): + gn_input = match.nodes_map[pn_input] + rn_input_in_original_graph = val_map[rn_input] + gn_input.replace_all_uses_with(rn_input_in_original_graph) + # We store the updated node point in case other nodes want to use it + match_changed_node[gn_input] = rn_input_in_original_graph assert isinstance(copied_output, Node) # Erase the `pattern` nodes From b262b84e5a58286dd4a623a2e79fffbd797de5c8 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Thu, 14 Oct 2021 17:45:02 +0200 Subject: [PATCH 7/7] Remove prints --- test/fx/test_subgraph_rewriter.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 70ff66a7a7d16..9e38969700726 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -509,9 +509,7 @@ def comparison(x): traced = symbolic_trace(f) - print(traced.graph) subgraph_rewriter.replace_pattern(traced, pattern, replacement) - print(traced.graph) traced.graph.lint() x = torch.randn(3, 4) @@ -519,7 +517,7 @@ def comparison(x): test_outs = traced.forward(x) self.assertEqual(ref_outs, test_outs) - def test_subgraph_rewriter_replaces_parallel_functions_when_agregated(self): + def test_subgraph_rewriter_replaces_parallel_functions_when_aggregated(self): def f(x): y = torch.sigmoid(x) z = torch.sigmoid(x) @@ -538,9 +536,7 @@ def comparison(x): traced = symbolic_trace(f) - print(traced.graph) subgraph_rewriter.replace_pattern(traced, pattern, replacement) - print(traced.graph) traced.graph.lint() x = torch.randn(3, 4)