diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh
index 546835680..31528ef5a 100755
--- a/graph_net/test/dtype_gen_test.sh
+++ b/graph_net/test/dtype_gen_test.sh
@@ -20,7 +20,7 @@ python3 -m graph_net.apply_sample_pass \
     "limits_handled_models": null
 }
 EOF
-)
+)
 
 # Step 2: Apply passes to generate samples
 python3 -m graph_net.apply_sample_pass \
diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index 926874ce3..718b39197 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -10,6 +10,27 @@ import torch
 import torch
 import torch.fx as fx
 from graph_net.torch.dtype_gen_passes.pass_base import DtypeGeneralizationPass
+import operator
+
+AMP_CALL_FUNCTION = {
+    torch.matmul,
+    torch.mm,
+    torch.bmm,
+    operator.matmul,
+    torch.nn.functional.linear,
+    torch.nn.functional.conv1d,
+    torch.nn.functional.conv2d,
+    torch.nn.functional.conv3d,
+    torch.nn.functional.scaled_dot_product_attention,
+    torch.addmm,
+    torch.einsum,
+}
+
+AMP_CALL_METHOD = {
+    "matmul",
+    "mm",
+    "bmm",
+}
 
 
 class ConcretePass(DtypeGeneralizationPass):
@@ -86,14 +107,59 @@ def create_get_attr(node: fx.Node) -> fx.Node:
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
+        def convert_arg(arg: fx.Node) -> fx.Node:
+            """Map an original node into new_graph, casting float32 tensors."""
+            mapped = val_map[arg]
+            if self._is_float32_tensor(arg):
+                mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+            return mapped
+
+        def create_new_args(node: fx.Node) -> list:
+            """Return node.args remapped into new_graph with dtype conversion.
+
+            fx.map_arg also visits nodes nested in lists, tuples and dicts
+            (e.g. torch.einsum(eq, [a, b])), so no argument is left pointing
+            at the old graph.
+            """
+            return list(fx.map_arg(node.args, convert_arg))
+
+        def create_new_kwargs(node: fx.Node) -> dict:
+            """Return node.kwargs remapped into new_graph with dtype conversion."""
+            return dict(fx.map_arg(node.kwargs, convert_arg))
+
+        def create_call_function(node: fx.Node) -> fx.Node:
+            """Create a call_function node with dtype conversion if needed."""
+            if node.target not in AMP_CALL_FUNCTION:
+                return new_graph.node_copy(node, lambda x: val_map[x])
+
+            return new_graph.call_function(
+                node.target,
+                args=tuple(create_new_args(node)),
+                kwargs=create_new_kwargs(node),
+            )
+
+        def create_call_method(node: fx.Node) -> fx.Node:
+            """Create a call_method node with dtype conversion if needed."""
+            if node.target not in AMP_CALL_METHOD:
+                return new_graph.node_copy(node, lambda x: val_map[x])
+
+            return new_graph.call_method(
+                node.target,
+                args=tuple(create_new_args(node)),
+                kwargs=create_new_kwargs(node),
+            )
+
         for node in gm.graph.nodes:
             if node.op == "placeholder":
                 val_map[node] = create_placeholder(node)
             elif node.op == "get_attr":
                 val_map[node] = create_get_attr(node)
+            elif node.op == "call_function":
+                val_map[node] = create_call_function(node)
+            elif node.op == "call_method":
+                val_map[node] = create_call_method(node)
             else:
-                new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
-                val_map[node] = new_node
+                val_map[node] = new_graph.node_copy(node, lambda x: val_map.get(x, x))
 
         # Replace the graph
         gm.graph = new_graph