diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 17cf05bb67d9a..7d6eface81b42 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -458,3 +458,34 @@ def forward(self, x): if n.op == 'placeholder': assert n.type == int assert m.type == int + + def test_subgraph_writer_replace_consecutive_submodules(self): + + def f(x): + x = torch.sigmoid(x) + x = torch.sigmoid(x) + return torch.sigmoid(x) + + def pattern(x): + return torch.sigmoid(x) + + def replacement(x): + return torch.exp(x) + + def comparison(x): + x = torch.exp(x) + x = torch.exp(x) + return torch.exp(x) + + traced = symbolic_trace(f) + comparison_fn = symbolic_trace(comparison) + + x = torch.randn(3, 4) + + subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + traced.graph.lint() + + ref_outs = comparison_fn(x) + test_outs = traced.forward(x) + self.assertEqual(ref_outs, test_outs) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 72ea56aa31196..71e2d64b53982 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -261,11 +261,10 @@ def forward(self, x, w1, w2): if matcher.matches_subgraph_from_anchor(anchor): - def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: + def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool: # `lookup` represents all the nodes in `original_graph` # that are part of `pattern` - lookup: Dict[Node, Node] = {v : k for k, v - in nodes_map.items()} + lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()} for n in lookup.keys(): # Nodes that can "leak"... @@ -295,25 +294,30 @@ def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool: # It's not a match if the pattern leaks out into the rest # of the graph if pattern_is_contained(matcher.nodes_map): - for k, v in matcher.nodes_map.items(): - # Shallow copy nodes_map - matches.append(Match(anchor=anchor, - nodes_map=copy.copy(matcher.nodes_map))) + # Shallow copy nodes_map + matches.append(Match(anchor=anchor, + nodes_map=copy.copy({ + key: value + for key, value in matcher.nodes_map.items() + }))) # The set of all nodes in `original_graph` that we've seen thus far # as part of a pattern match replaced_nodes: Set[Node] = set() + # As we progressively replace node, we need to keep track on how the match results need to change also + match_changed_node: Dict[Node, Node] = dict() # Return True if one of the nodes in the current match has already # been used as part of another match def overlaps_with_prev_match(match: Match) -> bool: - for n in match.nodes_map.values(): - if n in replaced_nodes and n.op != "placeholder": + for pn, gn in match.nodes_map.items(): + if pn.op in ["placeholder", "output"]: + continue + if gn in replaced_nodes and gn.op != "placeholder": return True return False - for match in matches: - + for i, match in enumerate(matches): # Skip overlapping matches if overlaps_with_prev_match(match): continue @@ -327,7 +331,7 @@ def overlaps_with_prev_match(match: Match) -> bool: replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] assert len(pattern_placeholders) == len(replacement_placeholders) - placeholder_map = {r : p for r, p + placeholder_map = {r: p for r, p in zip(replacement_placeholders, pattern_placeholders)} # node from `original_graph` that matched with the output node @@ -341,15 +345,17 @@ def mark_node_as_replaced(n: Node) -> None: mark_node_as_replaced(n_) replaced_nodes.add(n) - mark_node_as_replaced(subgraph_output) + for input_node in subgraph_output.all_input_nodes: + mark_node_as_replaced(input_node) - # Intialize `val_map` with mappings from placeholder nodes in + # Initialize `val_map` with mappings from placeholder nodes in # `replacement` to their corresponding node in `original_graph` for replacement_node in replacement_placeholders: # Get the `original_graph` placeholder node # corresponding to the current `replacement_node` pattern_node = placeholder_map[replacement_node] - original_graph_node = match.nodes_map[pattern_node] + original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node]) + # Populate `val_map` val_map[replacement_node] = original_graph_node @@ -366,20 +372,33 @@ def mark_node_as_replaced(n: Node) -> None: # original graph that corresponds to the end of the pattern # subgraph if subgraph_output.op != "output": - # `subgraph_output` may have multiple args. These args could - # be from the orignal graph, or they could have come from - # the insertion of `replacement_subgraph`. We need to find - # the Node that was originally matched as part of - # `pattern` (i.e. a Node from the original graph). We can - # figure this out by looking in `match.nodes_map`. The map - # was created before `replacement_subgraph` was spliced in, - # so we know that, if a Node is in `match.nodes_map.values`, - # it must have come from the original graph - for n in subgraph_output.all_input_nodes: - if (n.op != "placeholder" - and n in match.nodes_map.values()): - subgraph_output = n - break + pattern_outputs = [n for n in pattern_graph.nodes + if n.op == "output"] + assert len(pattern_outputs) + replacement_outputs = [n for n in replacement_graph.nodes + if n.op == "output"] + assert len(replacement_outputs) == len(pattern_outputs) + outputs_map = {p: r for r, p + in zip(replacement_outputs, pattern_outputs)} + + for pn, gn in match.nodes_map.items(): + if gn.op == "placeholder": + continue + + # We search for the node corresponding to the output of the pattern. + if pn.op != "output": + continue + assert subgraph_output == gn + + # We update all anchor inputs to the new nodes + rn = outputs_map[pn] + for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes): + gn_input = match.nodes_map[pn_input] + rn_input_in_original_graph = val_map[rn_input] + gn_input.replace_all_uses_with(rn_input_in_original_graph) + # We store the updated node point in case other nodes want to use it + match_changed_node[gn_input] = rn_input_in_original_graph + assert subgraph_output.op != "output" # CASE 2: The pattern subgraph match extends to the end of the # original graph, so we need to change the current graph's @@ -392,8 +411,6 @@ def mark_node_as_replaced(n: Node) -> None: subgraph_output._input_nodes = {copied_output: None} assert isinstance(copied_output, Node) - subgraph_output.replace_all_uses_with(copied_output) - # Erase the `pattern` nodes for node in reversed(original_graph.nodes): if len(node.users) == 0 and node.op != "output":