26 commits
14f77b9
Support saving the running_states.
Xreki Dec 3, 2025
5678f1c
Define a dataclass DecomposeConfig.
Xreki Dec 3, 2025
7d9581f
Fix a sample.
Xreki Dec 3, 2025
3d30e86
Merge branch 'develop' into opt_saved_results
Xreki Dec 3, 2025
864e7b3
Record the number of original incorrect models.
Xreki Dec 3, 2025
4f954ce
Add original_name for ResNet18.
Xreki Dec 5, 2025
3b6c041
Support using the original tensor_meta to recover the re-extracted samples.
Xreki Dec 5, 2025
a59fbba
Add original_name in meta for some paddle samples.
Xreki Dec 8, 2025
78010b8
Merge branch 'add_original_name_sample' into add_original_names
Xreki Dec 8, 2025
2543be1
Merge branch 'develop' into add_original_names
Xreki Dec 8, 2025
a7982d5
Optimize codes.
Xreki Dec 8, 2025
6d15fda
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 8, 2025
5c40420
Enable meta restorer in binary composer.
Xreki Dec 8, 2025
74f423e
Merge branch 'develop' into opt_saved_results
Xreki Dec 8, 2025
1196549
Optimize codes.
Xreki Dec 8, 2025
c8a9f68
Merge branch 'develop' into add_original_names
Xreki Dec 9, 2025
c067624
Merge branch 'develop' into opt_saved_results
Xreki Dec 9, 2025
ade8bb9
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 9, 2025
2b99941
Temporarily support saving the random states.
Xreki Dec 9, 2025
d310856
Change the initialization method of tensor back to truncated normal.
Xreki Dec 9, 2025
d7c91a2
Optimize codes.
Xreki Dec 10, 2025
00b070d
Support fixed-start method.
Xreki Dec 10, 2025
7e95d7f
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 10, 2025
7cfd4eb
Support fixed-start method.
Xreki Dec 10, 2025
a2f80c6
Merge branch 'opt_saved_results' into add_original_names
Xreki Dec 10, 2025
d24ec70
Merge branch 'develop' into add_original_names
Xreki Jan 6, 2026
61 changes: 61 additions & 0 deletions graph_net/paddle/random_util.py
@@ -0,0 +1,61 @@
import os
import pickle
import numpy as np
import random
import re
import paddle

from graph_net.paddle import samples_util


def set_seed(random_seed):
    paddle.seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)


def _extract_model_name_for_original_sample(model_path):
    fields = model_path.rstrip("/").split(os.sep)
    pattern = r"^subgraph(_\d+)?$"
    model_name = (
        f"{fields[-2]}_{fields[-1]}" if re.match(pattern, fields[-1]) else fields[-1]
    )
    return model_name


def _extract_model_name_for_decomposed_subgraph(model_path):
    # Parse model name and subgraph index
    model_name_with_subgraph_idx = model_path.rstrip("/").split(os.sep)[-1]
    model_name = "_".join(model_name_with_subgraph_idx.split("_")[:-1])
    return model_name


def _generate_random_state_filename(model_path):
    samples_dir = samples_util.get_default_samples_directory()
    if os.path.abspath(model_path).startswith(samples_dir):
        model_name = _extract_model_name_for_original_sample(model_path)
    else:
        model_name = _extract_model_name_for_decomposed_subgraph(model_path)
    return f"{model_name}.random_states.pkl"


def save_random_states(model_path, output_dir, random_state_dict):
    filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
    print(f"Write to {filepath}.", flush=True)
    try:
        with open(filepath, "wb") as f:
            pickle.dump(random_state_dict, f)
    except Exception:
        print(f"Fail to open {filepath}.")


def load_random_states(model_path, output_dir):
    filepath = os.path.join(output_dir, _generate_random_state_filename(model_path))
    print(f"Read from {filepath}.", flush=True)
    random_states = None
    try:
        with open(filepath, "rb") as f:
            random_states = pickle.load(f)
    except Exception:
        print(f"Fail to open {filepath}.")
    return random_states
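For reference, a minimal usage sketch of these helpers (not part of the PR); the paths and the "input_0" key are placeholders, and the per-tensor state is NumPy's get_state(), as used elsewhere in this change:

import numpy as np
from graph_net.paddle import random_util

# Capture the NumPy RNG state right before a tensor is generated, keyed by name.
random_state_dict = {"input_0": np.random.get_state()}

# Persist the states; the filename is derived from the model path
# (here a placeholder directory for a decomposed subgraph).
random_util.save_random_states(
    "/path/to/ResNet18_0", "/path/to/outputs", random_state_dict
)

# Later, reload and restore the state before replaying the same tensor.
states = random_util.load_random_states("/path/to/ResNet18_0", "/path/to/outputs")
if states is not None and states.get("input_0") is not None:
    np.random.set_state(states["input_0"])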
25 changes: 12 additions & 13 deletions graph_net/paddle/run_model.py
@@ -3,19 +3,12 @@
import base64
import argparse
import numpy as np
import random

os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"

import paddle
from graph_net import imp_util
from graph_net.paddle import utils


def set_seed(random_seed):
paddle.seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)
from graph_net.paddle import utils, random_util


def load_class_from_file(file_path: str, class_name: str):
@@ -31,17 +24,20 @@ def get_input_dict(model_path):
params = inputs_params["weight_info"]
inputs = inputs_params["input_info"]

random_state_dict = {}
input_dict = {}
for name, meta in params.items():
original_name = (
meta["original_name"] if meta.get("original_name", None) else name
)
random_state_dict[name] = np.random.get_state()
input_dict[name] = paddle.nn.parameter.Parameter(
utils.replay_tensor(meta), name=original_name
)
for name, meta in inputs.items():
random_state_dict[name] = np.random.get_state()
input_dict[name] = utils.replay_tensor(meta)
return input_dict
return input_dict, random_state_dict


def _convert_to_dict(config_str):
@@ -66,9 +62,6 @@ def _get_decorator(args):


def main(args):
initalize_seed = 123
set_seed(random_seed=initalize_seed)

model_path = args.model_path
model_class = load_class_from_file(
f"{model_path}/model.py", class_name="GraphModule"
@@ -77,7 +70,12 @@ def main(args):
model = model_class()
print(f"{model_path=}")

input_dict = get_input_dict(args.model_path)
initalize_seed = 123
random_util.set_seed(random_seed=initalize_seed)

input_dict, random_state_dict = get_input_dict(args.model_path)
output_dir = "/work/GraphNet/graph_net/test/outputs/pass_0"
random_util.save_random_states(model_path, output_dir, random_state_dict)
model = _get_decorator(args)(model)
model(**input_dict)

@@ -98,4 +96,5 @@ def main(args):
help="decorator configuration string",
)
args = parser.parse_args()
print(args)
main(args=args)
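The point of recording a state per tensor, as get_input_dict now does, is that restoring a saved NumPy RNG state replays exactly the same draw. A small self-contained illustration (not part of the PR):

import numpy as np

state = np.random.get_state()
a = np.random.normal(size=(2, 3))

# Restoring the captured state makes the next draw identical to the first.
np.random.set_state(state)
b = np.random.normal(size=(2, 3))

assert np.array_equal(a, b)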
31 changes: 17 additions & 14 deletions graph_net/paddle/test_compiler.py
@@ -5,16 +5,14 @@
import sys
import os
import numpy as np
import random
import platform
import traceback
import subprocess
import re

from graph_net.paddle import utils
from graph_net import path_utils
from graph_net import test_compiler_util

from graph_net.paddle import utils, random_util
from graph_net.paddle.backend.graph_compiler_backend import GraphCompilerBackend
from graph_net.paddle.backend.cinn_backend import CinnBackend
from graph_net.paddle.backend.nope_backend import NopeBackend
@@ -31,12 +29,6 @@ def get_compiler_backend(args) -> GraphCompilerBackend:
return registry_backend[args.compiler]


def set_seed(random_seed):
paddle.seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)


def init_env(args):
if test_compiler_util.is_gpu_device(args.device):
paddle.set_flags({"FLAGS_cudnn_exhaustive_search": 1})
@@ -96,14 +88,25 @@ def get_model(model_path):
return model_class()


def get_input_dict(model_path):
def get_input_dict(model_path, random_states_path=None):
inputs_params = utils.load_converted_from_text(f"{model_path}")
params = inputs_params["weight_info"]
inputs = inputs_params["input_info"]

params.update(inputs)
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
return state_dict

random_states = (
random_util.load_random_states(model_path, random_states_path)
if random_states_path
else None
)

input_dict = {}
for name, meta in params.items():
if random_states is not None and random_states.get(name, None) is not None:
np.random.set_state(random_states[name])
tensor = utils.replay_tensor(meta)
input_dict[name] = tensor
return input_dict


def get_input_spec(model_path):
@@ -486,7 +489,7 @@ def main(args):
assert args.device in ["cuda", "dcu", "xpu", "cpu"]

initalize_seed = 123
set_seed(random_seed=initalize_seed)
random_util.set_seed(random_seed=initalize_seed)

if path_utils.is_single_model_dir(args.model_path):
test_single_model(args)
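A hedged usage sketch of the extended get_input_dict; the directories below are placeholders, and the second argument points at the directory where the *.random_states.pkl files were saved by run_model.py:

from graph_net.paddle import test_compiler

# Replay the inputs of a sample, restoring the NumPy RNG state recorded for each
# tensor so the regenerated values match the original run.
input_dict = test_compiler.get_input_dict(
    "/path/to/samples/ResNet18",
    random_states_path="/path/to/outputs/pass_0",
)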
17 changes: 13 additions & 4 deletions graph_net/paddle/test_reference_device.py
@@ -9,7 +9,7 @@

from graph_net import path_utils
from graph_net import test_compiler_util
from graph_net.paddle import test_compiler
from graph_net.paddle import random_util, test_compiler


def get_reference_log_path(reference_dir, model_path):
@@ -38,8 +38,10 @@ def test_single_model(args):
compiler = test_compiler.get_compiler_backend(args)
test_compiler.check_and_print_gpu_utilization(compiler)

input_dict = test_compiler.get_input_dict(model_path)
model = test_compiler.get_model(model_path)
input_dict = test_compiler.get_input_dict(
args.model_path, args.random_states_path
)
model = test_compiler.get_model(args.model_path)
model.eval()

test_compiler_util.print_with_log_prompt(
@@ -107,6 +109,7 @@ def test_multi_models(args):
f"--trials {args.trials}",
f"--log-prompt {args.log_prompt}",
f"--seed {args.seed}",
f"--random-states-path {args.random_states_path}",
f"--reference-dir {args.reference_dir}",
]
)
@@ -130,7 +133,7 @@ def main(args):
assert args.compiler in {"cinn", "nope"}
assert args.device in ["cuda"]

test_compiler.set_seed(random_seed=args.seed)
random_util.set_seed(random_seed=args.seed)
test_compiler.init_env(args)

ref_dump_dir = Path(args.reference_dir)
@@ -191,6 +194,12 @@ def main(args):
default=123,
help="Random seed (default: 123)",
)
parser.add_argument(
"--random-states-path",
type=str,
required=False,
help="Path to random-states of model (s)",
)
parser.add_argument(
"--reference-dir",
type=str,
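For illustration, one way the new flag could be passed when driving this module from Python (a sketch, not part of the PR; the paths are placeholders and the model-selection, compiler, and device flag names are assumed to follow the existing interface):

import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "graph_net.paddle.test_reference_device",
        "--model-path", "/path/to/samples/ResNet18",   # assumed existing flag
        "--compiler", "nope",                           # assumed existing flag
        "--device", "cuda",                             # assumed existing flag
        "--seed", "123",
        "--random-states-path", "/path/to/outputs/pass_0",
        "--reference-dir", "/path/to/reference_device_outputs",
    ],
    check=True,
)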
15 changes: 11 additions & 4 deletions graph_net/paddle/test_target_device.py
@@ -7,7 +7,7 @@
import paddle
from graph_net import path_utils
from graph_net import test_compiler_util
from graph_net.paddle import test_compiler, test_reference_device
from graph_net.paddle import random_util, test_compiler, test_reference_device


def parse_config_from_reference_log(log_path):
@@ -53,7 +53,7 @@ def update_args_and_set_seed(args, model_path):
vars(args)["compiler"] = config.get("compiler")
vars(args)["trials"] = int(config.get("trials"))
vars(args)["warmup"] = int(config.get("warmup"))
test_compiler.set_seed(random_seed=int(config.get("seed")))
random_util.set_seed(random_seed=int(config.get("seed")))
return args


@@ -67,8 +67,8 @@ def test_single_model(args):
compiler = test_compiler.get_compiler_backend(args)
test_compiler.check_and_print_gpu_utilization(compiler)

input_dict = test_compiler.get_input_dict(model_path)
model = test_compiler.get_model(model_path)
input_dict = test_compiler.get_input_dict(args.model_path, args.random_states_path)
model = test_compiler.get_model(args.model_path)
model.eval()

test_compiler_util.print_basic_config(
@@ -146,6 +146,7 @@ def test_multi_models(args):
f"--device {args.device}",
f"--log-prompt {args.log_prompt}",
f"--reference-dir {args.reference_dir}",
f"--random-states-path {args.random_states_path}",
]
)
cmd_ret = os.system(cmd)
@@ -210,5 +211,11 @@ def main(args):
default=None,
help="Path to samples list, each line contains a sample path",
)
parser.add_argument(
"--random-states-path",
type=str,
required=False,
help="Path to random-states of model (s)",
)
args = parser.parse_args()
main(args=args)
15 changes: 10 additions & 5 deletions graph_net/paddle/utils.py
@@ -2,6 +2,7 @@
import ast
import math
import numpy as np
from scipy.stats import truncnorm
import paddle

kLiteralTensorSize = 64
@@ -197,17 +198,21 @@ def init_integer_tensor(dtype, shape, min_val, max_val, use_numpy):
def init_float_tensor(shape, mean, std, min_val, max_val, use_numpy):
tensor = None
if use_numpy:
if mean is not None and std is not None:
if mean is not None and std is not None and std != 0.0:
# NumPy does not support truncated normal, we simulate it here.
array = np.random.normal(0, 1, shape) * std * 0.2 + mean
array = np.clip(array, min_val, max_val)
a = (min_val - mean) / std
b = (max_val - mean) / std
array = truncnorm.rvs(a, b, loc=mean, scale=std, size=shape)
else:
array = np.random.uniform(low=min_val, high=max_val, size=shape)
tensor = paddle.to_tensor(array)
else:
if mean is not None and std is not None:
tensor = paddle.randn(shape, dtype="float32") * std * 0.2 + mean
tensor = paddle.clip(tensor, min=min_val, max=max_val)
tensor = paddle.empty(shape=shape, dtype="float32")
initializer = paddle.nn.initializer.TruncatedNormal(
mean=mean, std=std, a=min_val, b=max_val
)
initializer(tensor)
else:
tensor = paddle.uniform(
shape=shape, dtype="float32", min=min_val, max=max_val
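For clarity, the truncnorm bounds used above are expressed in standard-deviation units relative to the mean, so every sample falls inside [min_val, max_val]. A small self-contained check (not part of the PR; the numbers are arbitrary):

import numpy as np
from scipy.stats import truncnorm

mean, std, min_val, max_val = 0.0, 0.02, -0.04, 0.04
a = (min_val - mean) / std
b = (max_val - mean) / std
samples = truncnorm.rvs(a, b, loc=mean, scale=std, size=(4, 8))

# Every draw respects the truncation bounds.
assert samples.min() >= min_val and samples.max() <= max_val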
1 change: 1 addition & 0 deletions graph_net/subgraph_decompose_and_evaluation_step.py
@@ -400,6 +400,7 @@ def run_evaluation(
test_module_arguments["reference-dir"] = os.path.join(
work_dir, "reference_device_outputs"
)
test_module_arguments["random-states-path"] = work_dir

cmd = [sys.executable, "-m", f"graph_net.{framework}.{test_module_name}"] + [
item