From 6a5542ab49a311073b4996ea8aff15d70c5d6948 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 26 Dec 2025 14:32:30 +0800 Subject: [PATCH] Fix infinite floating value in data. --- graph_net/tensor_meta.py | 4 ++-- graph_net/torch/utils.py | 21 +++++++-------------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py index f1a082a4e..a007160db 100755 --- a/graph_net/tensor_meta.py +++ b/graph_net/tensor_meta.py @@ -86,14 +86,14 @@ def _get_limited_precision_float_str(self, value): return value if math.isnan(value) or math.isinf(value): return f'float("{value}")' - return f"{value:.3f}" + return f"{value:.6f}" def _format_data(self, data): if data is None: return "None" elif isinstance(data, list): return "[{}]".format( - ", ".join(f"{x:.6f}" if isinstance(x, float) else str(x) for x in data) + ", ".join(self._get_limited_precision_float_str(x) for x in data) ) else: return repr(data) diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index 62837fc8c..2b0648479 100755 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@ -21,11 +21,9 @@ def apply_templates(forward_code: str) -> str: def get_limited_precision_float_str(value): if not isinstance(value, float): return value - if math.isnan(value): - return "float('nan')" - if math.isinf(value): - return "float('inf')" if value > 0 else "float('-inf')" - return f"{value:.3f}" + if math.isnan(value) or math.isinf(value): + return f'float("{value}")' + return f"{value:.6f}" def convert_state_and_inputs_impl(state_dict, example_inputs): @@ -130,16 +128,11 @@ def format_data(data): return "None" elif isinstance(data, torch.Tensor): if data.dtype.is_floating_point: - - def float_to_str(x): - if math.isinf(x): - return "float('inf')" if x > 0 else "float('-inf')" - if math.isnan(x): - return "float('nan')" - return f"{x:.6f}" - return "[{}]".format( - ", ".join(float_to_str(x) for x in data.flatten().tolist()) + ", ".join( + get_limited_precision_float_str(x) + for x in data.flatten().tolist() + ) ) else: return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))