From 5dbe5adecf3fbad15ff788822169183df49f63d1 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Tue, 20 Jan 2026 14:35:31 +0800 Subject: [PATCH 1/9] fix dtype_generalization_pass.py --- graph_net/test/dtype_gen_test.sh | 77 ++++++++++--------- .../dtype_generalization_pass.py | 39 +++++++++- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index 546835680..a65e0d7ef 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -7,8 +7,9 @@ OUTPUT_DIR="/tmp/dtype_gen_samples" mkdir -p "$OUTPUT_DIR" # Step 1: Initialize dtype generalization passes (samples of torchvision) -python3 -m graph_net.apply_sample_pass \ - --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ +# python3 -m graph_net.apply_sample_pass \ +python3 -m pdb -m graph_net.apply_sample_pass \ + --model-path-list "graph_net/config/f16_error_samples.txt" \ --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ --sample-pass-class-name InitDataTypeGeneralizationPasses \ --sample-pass-config $(base64 -w 0 <&1) +# output=$(python -m graph_net.torch.validate \ +# --model-path "$model_path" 2>&1) - if echo "$output" | grep -q "Validation success, model_path="; then - echo "SUCCESS" - ((SUCCESS_CNT++)) - else - echo "FAIL" - ((FAIL_CNT++)) - fi -done +# if echo "$output" | grep -q "Validation success, model_path="; then +# echo "SUCCESS" +# ((SUCCESS_CNT++)) +# else +# echo "FAIL" +# ((FAIL_CNT++)) +# fi +# done -echo "====================" -echo "SUCCESS $SUCCESS_CNT" -echo "FAIL $FAIL_CNT" \ No newline at end of file +# echo "====================" +# echo "SUCCESS $SUCCESS_CNT" +# echo "FAIL $FAIL_CNT" \ No newline at end of file diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 926874ce3..1d971e1b7 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -85,12 +85,49 @@ def create_get_attr(node: fx.Node) -> fx.Node: ): return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node - + + def create_call_function(node: fx.Node) -> fx.Node: + """Create a call_function node with dtype conversion if needed.""" + if node.target in ( + torch.matmul, + torch.nn.functional.linear, + torch.nn.functional.conv2d, + torch.bmm, + torch.nn.functional.scaled_dot_product_attention, + ): + new_args = [] + + for arg in node.args: + if isinstance(arg, fx.Node): + mapped = val_map[arg] + if self._is_float32_tensor(arg): + mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) + new_args.append(mapped) + else: + new_args.append(arg) + + new_kwargs = { + k: val_map[v] if isinstance(v, fx.Node) else v + for k, v in node.kwargs.items() + } + + new_node = new_graph.call_function( + node.target, + args=tuple(new_args), + kwargs=new_kwargs, + ) + + return new_node + else: + return new_graph.node_copy(node, lambda x: val_map[x]) + for node in gm.graph.nodes: if node.op == "placeholder": val_map[node] = create_placeholder(node) elif node.op == "get_attr": val_map[node] = create_get_attr(node) + elif node.op == "call_function": + val_map[node] = create_call_function(node) else: new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x)) val_map[node] = new_node From d85abf4e9048f8a745b873a1e6c95c0b4c6d8a70 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Tue, 20 Jan 2026 14:37:16 +0800 Subject: [PATCH 2/9] fix dtype_generalization_pass.py --- graph_net/test/dtype_gen_test.sh | 74 ++++++++++++++++---------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index a65e0d7ef..e24687248 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -23,46 +23,46 @@ python3 -m pdb -m graph_net.apply_sample_pass \ EOF ) -# Step 2: Apply passes to generate samples -# python3 -m graph_net.apply_sample_pass \ -# --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ -# --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ -# --sample-pass-class-name ApplyDataTypeGeneralizationPasses \ -# --sample-pass-config $(base64 -w 0 <&1) + output=$(python -m graph_net.torch.validate \ + --model-path "$model_path" 2>&1) -# if echo "$output" | grep -q "Validation success, model_path="; then -# echo "SUCCESS" -# ((SUCCESS_CNT++)) -# else -# echo "FAIL" -# ((FAIL_CNT++)) -# fi -# done + if echo "$output" | grep -q "Validation success, model_path="; then + echo "SUCCESS" + ((SUCCESS_CNT++)) + else + echo "FAIL" + ((FAIL_CNT++)) + fi +done -# echo "====================" -# echo "SUCCESS $SUCCESS_CNT" -# echo "FAIL $FAIL_CNT" \ No newline at end of file +echo "====================" +echo "SUCCESS $SUCCESS_CNT" +echo "FAIL $FAIL_CNT" \ No newline at end of file From f454c1409bba21ba0e4e3cb18c0ffe1de5b48279 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Tue, 20 Jan 2026 14:38:50 +0800 Subject: [PATCH 3/9] fix dtype_generalization_pass.py --- graph_net/test/dtype_gen_test.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index e24687248..fb3ced4d5 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -7,9 +7,8 @@ OUTPUT_DIR="/tmp/dtype_gen_samples" mkdir -p "$OUTPUT_DIR" # Step 1: Initialize dtype generalization passes (samples of torchvision) -# python3 -m graph_net.apply_sample_pass \ python3 -m pdb -m graph_net.apply_sample_pass \ - --model-path-list "graph_net/config/f16_error_samples.txt" \ + --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ --sample-pass-class-name InitDataTypeGeneralizationPasses \ --sample-pass-config $(base64 -w 0 < Date: Tue, 20 Jan 2026 14:40:50 +0800 Subject: [PATCH 4/9] fix dtype_generalization_pass.py --- graph_net/test/dtype_gen_test.sh | 2 +- .../dtype_gen_passes/dtype_generalization_pass.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index fb3ced4d5..3be617f13 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -7,7 +7,7 @@ OUTPUT_DIR="/tmp/dtype_gen_samples" mkdir -p "$OUTPUT_DIR" # Step 1: Initialize dtype generalization passes (samples of torchvision) -python3 -m pdb -m graph_net.apply_sample_pass \ +python3 -m graph_net.apply_sample_pass \ --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ --sample-pass-class-name InitDataTypeGeneralizationPasses \ diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 1d971e1b7..01aab34b7 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -85,10 +85,10 @@ def create_get_attr(node: fx.Node) -> fx.Node: ): return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node - + def create_call_function(node: fx.Node) -> fx.Node: - """Create a call_function node with dtype conversion if needed.""" - if node.target in ( + """Create a call_function node with dtype conversion if needed.""" + if node.target in ( torch.matmul, torch.nn.functional.linear, torch.nn.functional.conv2d, @@ -101,7 +101,9 @@ def create_call_function(node: fx.Node) -> fx.Node: if isinstance(arg, fx.Node): mapped = val_map[arg] if self._is_float32_tensor(arg): - mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) + mapped = new_graph.call_method( + "to", (mapped, self.torch_dtype) + ) new_args.append(mapped) else: new_args.append(arg) @@ -120,7 +122,7 @@ def create_call_function(node: fx.Node) -> fx.Node: return new_node else: return new_graph.node_copy(node, lambda x: val_map[x]) - + for node in gm.graph.nodes: if node.op == "placeholder": val_map[node] = create_placeholder(node) From 5c771d04a9212bdb569b5427fdd19e3193dd70c5 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Tue, 20 Jan 2026 14:42:17 +0800 Subject: [PATCH 5/9] fix dtype_generalization_pass.py --- graph_net/test/dtype_gen_test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index 3be617f13..31528ef5a 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -22,7 +22,7 @@ python3 -m graph_net.apply_sample_pass \ EOF ) -Step 2: Apply passes to generate samples +# Step 2: Apply passes to generate samples python3 -m graph_net.apply_sample_pass \ --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ @@ -43,7 +43,7 @@ EOF ) -Step 3: Valiation +# Step 3: Valiation SUCCESS_CNT=0 FAIL_CNT=0 From 18f1a82b3dcf66f2f033e9f3c3dc95deb8a65ef2 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Wed, 21 Jan 2026 17:12:20 +0800 Subject: [PATCH 6/9] fix dtype_generalization.py --- .../dtype_generalization_pass.py | 109 ++++++++++++------ 1 file changed, 73 insertions(+), 36 deletions(-) diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 01aab34b7..52b4473f2 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -11,6 +11,22 @@ import torch.fx as fx from graph_net.torch.dtype_gen_passes.pass_base import DtypeGeneralizationPass +AMP_CALL_FUNCTION = { + torch.matmul, + torch.mm, + torch.bmm, + torch.nn.functional.linear, + torch.nn.functional.conv1d, + torch.nn.functional.conv2d, + torch.nn.functional.conv3d, + torch.nn.functional.scaled_dot_product_attention, +} + +AMP_CALL_METHOD = { + "matmul", + "mm", + "bmm", +} class ConcretePass(DtypeGeneralizationPass): """ @@ -86,43 +102,62 @@ def create_get_attr(node: fx.Node) -> fx.Node: return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node + def create_call_function(node: fx.Node) -> fx.Node: - """Create a call_function node with dtype conversion if needed.""" - if node.target in ( - torch.matmul, - torch.nn.functional.linear, - torch.nn.functional.conv2d, - torch.bmm, - torch.nn.functional.scaled_dot_product_attention, - ): - new_args = [] - - for arg in node.args: - if isinstance(arg, fx.Node): - mapped = val_map[arg] - if self._is_float32_tensor(arg): - mapped = new_graph.call_method( - "to", (mapped, self.torch_dtype) - ) - new_args.append(mapped) - else: - new_args.append(arg) - - new_kwargs = { - k: val_map[v] if isinstance(v, fx.Node) else v - for k, v in node.kwargs.items() - } - - new_node = new_graph.call_function( - node.target, - args=tuple(new_args), - kwargs=new_kwargs, - ) - - return new_node - else: + """Create a call_function node with dtype conversion if needed.""" + if node.target not in AMP_CALL_FUNCTION: + return new_graph.node_copy(node, lambda x: val_map[x]) + + new_args = [] + + for arg in node.args: + if isinstance(arg, fx.Node): + mapped = val_map[arg] + if self._is_float32_tensor(arg): + mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) + new_args.append(mapped) + else: + new_args.append(arg) + + new_kwargs = { + k: val_map[v] if isinstance(v, fx.Node) else v + for k, v in node.kwargs.items() + } + + return new_graph.call_function( + node.target, + args=tuple(new_args), + kwargs=new_kwargs, + ) + + def create_call_method(node: fx.Node) -> fx.Node: + if node.target not in AMP_CALL_METHOD: return new_graph.node_copy(node, lambda x: val_map[x]) + new_args = [] + for arg in node.args: + if isinstance(arg, fx.Node): + mapped = val_map[arg] + if self._is_float32_tensor(arg): + mapped = new_graph.call_method( + "to", (mapped, self.torch_dtype) + ) + new_args.append(mapped) + else: + new_args.append(arg) + + new_kwargs = { + k: (val_map[v] if isinstance(v, fx.Node) else v) + for k, v in node.kwargs.items() + } + + return new_graph.call_method( + node.target, + tuple(new_args), + new_kwargs, + ) + + for node in gm.graph.nodes: if node.op == "placeholder": val_map[node] = create_placeholder(node) @@ -130,9 +165,10 @@ def create_call_function(node: fx.Node) -> fx.Node: val_map[node] = create_get_attr(node) elif node.op == "call_function": val_map[node] = create_call_function(node) + elif node.op == "call_method": + val_map[node] = create_call_method(node) else: - new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x)) - val_map[node] = new_node + val_map[node] = new_graph.node_copy(node, lambda x: val_map.get(x, x)) # Replace the graph gm.graph = new_graph @@ -140,6 +176,7 @@ def create_call_function(node: fx.Node) -> fx.Node: return gm + def _is_float32_tensor(self, node: fx.Node) -> bool: """ Check if a node represents a float32 tensor. From a8db2775411ad867481a3dd2a34a2248bee1311b Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Wed, 21 Jan 2026 17:16:24 +0800 Subject: [PATCH 7/9] fix pre-commit of dtype_generalization.py --- .../dtype_gen_passes/dtype_generalization_pass.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 52b4473f2..a79f9b60b 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -28,6 +28,7 @@ "bmm", } + class ConcretePass(DtypeGeneralizationPass): """ FX Graph pass that converts dtypes of tensors. @@ -102,9 +103,8 @@ def create_get_attr(node: fx.Node) -> fx.Node: return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node - def create_call_function(node: fx.Node) -> fx.Node: - """Create a call_function node with dtype conversion if needed.""" + """Create a call_function node with dtype conversion if needed.""" if node.target not in AMP_CALL_FUNCTION: return new_graph.node_copy(node, lambda x: val_map[x]) @@ -139,9 +139,7 @@ def create_call_method(node: fx.Node) -> fx.Node: if isinstance(arg, fx.Node): mapped = val_map[arg] if self._is_float32_tensor(arg): - mapped = new_graph.call_method( - "to", (mapped, self.torch_dtype) - ) + mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) new_args.append(mapped) else: new_args.append(arg) @@ -157,7 +155,6 @@ def create_call_method(node: fx.Node) -> fx.Node: new_kwargs, ) - for node in gm.graph.nodes: if node.op == "placeholder": val_map[node] = create_placeholder(node) @@ -176,7 +173,6 @@ def create_call_method(node: fx.Node) -> fx.Node: return gm - def _is_float32_tensor(self, node: fx.Node) -> bool: """ Check if a node represents a float32 tensor. From 4f723a69f4585b7ad053fcac7010d899b442a1f7 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Wed, 21 Jan 2026 19:02:57 +0800 Subject: [PATCH 8/9] fix dtype_generalization.py --- graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index a79f9b60b..720ad02a0 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -10,16 +10,20 @@ import torch import torch.fx as fx from graph_net.torch.dtype_gen_passes.pass_base import DtypeGeneralizationPass +import operator AMP_CALL_FUNCTION = { torch.matmul, torch.mm, torch.bmm, + operator.matmul, torch.nn.functional.linear, torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d, torch.nn.functional.scaled_dot_product_attention, + torch.addmm, + torch.einsum, } AMP_CALL_METHOD = { From e41ae05218a5531c97a83e7141cdb3050054d888 Mon Sep 17 00:00:00 2001 From: WHoutstanding Date: Thu, 22 Jan 2026 16:39:14 +0800 Subject: [PATCH 9/9] add create_new_args --- .../dtype_generalization_pass.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py index 720ad02a0..718b39197 100644 --- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py +++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py @@ -107,11 +107,8 @@ def create_get_attr(node: fx.Node) -> fx.Node: return new_graph.call_method("to", args=(new_node, self.torch_dtype)) return new_node - def create_call_function(node: fx.Node) -> fx.Node: - """Create a call_function node with dtype conversion if needed.""" - if node.target not in AMP_CALL_FUNCTION: - return new_graph.node_copy(node, lambda x: val_map[x]) - + def create_new_args(node: fx.Node) -> list: + """new_args of node with dtype conversion if needed.""" new_args = [] for arg in node.args: @@ -122,6 +119,14 @@ def create_call_function(node: fx.Node) -> fx.Node: new_args.append(mapped) else: new_args.append(arg) + return new_args + + def create_call_function(node: fx.Node) -> fx.Node: + """Create a call_function node with dtype conversion if needed.""" + if node.target not in AMP_CALL_FUNCTION: + return new_graph.node_copy(node, lambda x: val_map[x]) + + new_args = create_new_args(node) new_kwargs = { k: val_map[v] if isinstance(v, fx.Node) else v @@ -138,15 +143,7 @@ def create_call_method(node: fx.Node) -> fx.Node: if node.target not in AMP_CALL_METHOD: return new_graph.node_copy(node, lambda x: val_map[x]) - new_args = [] - for arg in node.args: - if isinstance(arg, fx.Node): - mapped = val_map[arg] - if self._is_float32_tensor(arg): - mapped = new_graph.call_method("to", (mapped, self.torch_dtype)) - new_args.append(mapped) - else: - new_args.append(arg) + new_args = create_new_args(node) new_kwargs = { k: (val_map[v] if isinstance(v, fx.Node) else v)