From 1616ee76db2a7841f2b9c5f7ec85c7347e2d1a93 Mon Sep 17 00:00:00 2001 From: denghaodong Date: Thu, 15 Jan 2026 13:12:23 +0000 Subject: [PATCH] add remote graph decomposer --- graph_net/torch/remote_graph_decomposer.py | 189 +++++++++++++++ .../torch/remote_graph_decomposer_entry.py | 223 ++++++++++++++++++ .../test/test_remote_graph_decomposer.sh | 49 ++++ 3 files changed, 461 insertions(+) create mode 100644 graph_net/torch/remote_graph_decomposer.py create mode 100644 graph_net/torch/remote_graph_decomposer_entry.py create mode 100755 graph_net_rpc/test/test_remote_graph_decomposer.sh diff --git a/graph_net/torch/remote_graph_decomposer.py b/graph_net/torch/remote_graph_decomposer.py new file mode 100644 index 000000000..9aa253d06 --- /dev/null +++ b/graph_net/torch/remote_graph_decomposer.py @@ -0,0 +1,189 @@ +import argparse +import os +import sys +from pathlib import Path +from typing import Dict, Optional + +from graph_net_rpc.sample_remote_executor import SampleRemoteExecutor + + +def _build_remote_rpc_cmd(args) -> str: + """ + Build the command string that will be executed on the remote server. + + The gRPC server will set: + INPUT_WORKSPACE= + OUTPUT_WORKSPACE= + + So we append required args with those env vars. + """ + cmd = (args.rpc_cmd or "").strip() + if not cmd: + raise ValueError("rpc_cmd is empty") + + # For our recommended server entrypoint, append the required parameters. + # We keep this logic permissive: if user provides a different rpc_cmd, they can + # still rely on these appended args as long as their entrypoint accepts them. + # + # Important: The decorator-config is a (potentially long) base64 string; we wrap it + # in double quotes. If the string itself contains quotes (shouldn't), user must handle it. + cmd += ' --model-path "$INPUT_WORKSPACE"' + cmd += f' --decorator-config "{args.decorator_config}"' + + # optional: let server choose log file name; we still pass it so server can avoid collisions + if args.remote_log_file: + cmd += f' --log-file "{args.remote_log_file}"' + + return cmd + + +def _write_files_dict_to_dir(files_dict: Dict[str, bytes], dst_dir: str) -> None: + """ + Materialize returned {relative_path: bytes} into dst_dir/relative_path. + """ + dst_root = Path(dst_dir) + dst_root.mkdir(parents=True, exist_ok=True) + + for rel_path, content in files_dict.items(): + # Normalize tar paths (always forward slashes) + rel_path = rel_path.lstrip("/") + + out_path = dst_root / rel_path + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_bytes(content) + + +def _print_remote_log(files_dict: Dict[str, bytes], log_name: Optional[str]) -> None: + """ + Print remote log to stderr if found. + """ + if not log_name: + return + + # Try exact match first. + if log_name in files_dict: + data = files_dict[log_name] + else: + # Fallback: try to find any .txt or .log likely produced by server. + candidates = sorted( + [k for k in files_dict.keys() if k.endswith((".log", ".txt"))] + ) + if len(candidates) == 1: + data = files_dict[candidates[0]] + log_name = candidates[0] + else: + print( + f"[remote_subgraph_decomposer] Remote log not found. expected={log_name}, candidates={candidates}", + file=sys.stderr, + flush=True, + ) + return + + try: + text = data.decode("utf-8", errors="replace") + print(f"\n===== Remote log: {log_name} =====", file=sys.stderr, flush=True) + print(text, file=sys.stderr, flush=True) + print("===== End remote log =====\n", file=sys.stderr, flush=True) + except Exception as e: + print( + f"[remote_subgraph_decomposer] Failed to decode remote log {log_name}: {e} (bytes={len(data)})", + file=sys.stderr, + flush=True, + ) + + +def main(args): + model_path = os.path.abspath(args.model_path) + assert os.path.isdir(model_path), f"--model-path must be a directory: {model_path}" + + local_out_dir = os.path.abspath(args.output_dir) + Path(local_out_dir).mkdir(parents=True, exist_ok=True) + + rpc_cmd = _build_remote_rpc_cmd(args) + + print( + f"[remote_subgraph_decomposer] model_path: {model_path}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decomposer] output_dir: {local_out_dir}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decomposer] server: {args.machine}:{args.port}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decomposer] remote rpc_cmd: {rpc_cmd}", + file=sys.stderr, + flush=True, + ) + + executor = SampleRemoteExecutor(machine=args.machine, port=args.port) + try: + files_dict = executor.execute(model_path, rpc_cmd) + + # Write everything returned by server OUTPUT_WORKSPACE into local_out_dir. + _write_files_dict_to_dir(files_dict, local_out_dir) + # Optionally print log + _print_remote_log(files_dict, args.remote_log_file) + + print( + f"[remote_subgraph_decomposer] Done. Materialized {len(files_dict)} file(s) under {local_out_dir}", + file=sys.stderr, + flush=True, + ) + finally: + executor.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Remote subgraph decomposer client (torch)." + ) + parser.add_argument( + "--model-path", + type=str, + required=True, + help="Path to a *single model directory* to upload to the server.", + ) + parser.add_argument( + "--decorator-config", + type=str, + required=True, + help="Base64-encoded decorator config JSON (same as passed to graph_net.torch.run_model).", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help=( + "Local output directory to materialize files returned from server OUTPUT_WORKSPACE. " + "Typically this should be the pass workspace dir, e.g. (so that it contains samples/...)." + ), + ) + parser.add_argument("--machine", type=str, default="localhost") + parser.add_argument("--port", type=int, default=50052) + + parser.add_argument( + "--rpc-cmd", + type=str, + default="python3 -m graph_net.torch.remote_graph_decomposer_entry", + help=( + "Command to execute on remote server. It should accept " + "--model-path and --decorator-config (and optionally --log-file)." + ), + ) + + parser.add_argument( + "--remote-log-file", + type=str, + default="log_remote_graph_decomposer.txt", + help="Remote log filename expected to be produced under OUTPUT_WORKSPACE (optional).", + ) + + args = parser.parse_args() + main(args) diff --git a/graph_net/torch/remote_graph_decomposer_entry.py b/graph_net/torch/remote_graph_decomposer_entry.py new file mode 100644 index 000000000..7ee44a38d --- /dev/null +++ b/graph_net/torch/remote_graph_decomposer_entry.py @@ -0,0 +1,223 @@ +import argparse +import base64 +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, Iterable + + +def b64_to_json(b64str: str) -> Dict[str, Any]: + return json.loads(base64.b64decode(b64str).decode("utf-8")) + + +def json_to_b64(obj: Dict[str, Any]) -> str: + return base64.b64encode(json.dumps(obj).encode("utf-8")).decode("utf-8") + + +def get_repo_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def rewrite_path_to_repo_root(path_value: str, repo_root: Path) -> str: + """ + Rewrite a path that points into GraphNet repo to server-local repo_root. + + Remap any string containing "graph_net/<...>" to "/graph_net/<...>". + """ + if not isinstance(path_value, str) or not path_value: + return path_value + + norm = path_value.replace("\\", "/") + # If it's already under this repo_root, keep it. + repo_norm = str(repo_root).replace("\\", "/").rstrip("/") + if norm.startswith(repo_norm + "/"): + return str(Path(norm)) + + marker = "/graph_net/" + idx = norm.find(marker) + if idx >= 0: + rel = norm[idx + 1 :] # strip leading "/" to make it relative + return str(repo_root / rel) + # Try without leading slash (edge cases) + marker2 = "graph_net/" + idx2 = norm.find(marker2) + if idx2 >= 0: + rel = norm[idx2:] + return str(repo_root / rel) + # Not a repo-internal path; leave as-is. + return path_value + + +def deep_rewrite_paths( + obj: Any, repo_root: Path, keys_to_rewrite: Iterable[str] +) -> Any: + """ + Recursively rewrite known path fields in a nested dict/list structure. + """ + if isinstance(obj, dict): + new_d: Dict[str, Any] = {} + for k, v in obj.items(): + if k in keys_to_rewrite and isinstance(v, str): + new_d[k] = rewrite_path_to_repo_root(v, repo_root) + else: + new_d[k] = deep_rewrite_paths(v, repo_root, keys_to_rewrite) + return new_d + if isinstance(obj, list): + return [deep_rewrite_paths(x, repo_root, keys_to_rewrite) for x in obj] + return obj + + +def prepare_decorator_config_for_server( + decorator_config_b64: str, output_workspace: str, repo_root: Path +) -> str: + """ + - Decode client config + - Rewrite repo paths to server-local repo_root + - Force output_dir = $OUTPUT_WORKSPACE/samples (torch) + - Re-encode to base64 + """ + cfg = b64_to_json(decorator_config_b64) + + # These keys are used by subgraph_decompose_and_evaluation_step.py (torch/paddle) + # plus some forward-compatible names. + keys_to_rewrite = { + "decorator_path", + "custom_extractor_path", + "post_extract_process_path", + "post_process_path", + "extractor_path", + } + cfg = deep_rewrite_paths(cfg, repo_root=repo_root, keys_to_rewrite=keys_to_rewrite) + + # Force output dir (torch): OUTPUT_WORKSPACE/samples + forced_output_dir = os.path.join(output_workspace, "samples") + try: + custom_cfg = cfg["decorator_config"]["custom_extractor_config"] + if isinstance(custom_cfg, dict): + custom_cfg["output_dir"] = forced_output_dir + except Exception: + pass + + return json_to_b64(cfg) + + +def run_remote_subgraph_decompose(args) -> int: + input_workspace = os.environ.get("INPUT_WORKSPACE") or args.model_path + output_workspace = os.environ.get("OUTPUT_WORKSPACE") + + if not input_workspace: + print( + "[remote_subgraph_decompose_entry][ERROR] Missing INPUT_WORKSPACE and --model-path.", + file=sys.stderr, + flush=True, + ) + return 2 + if not output_workspace: + print( + "[remote_subgraph_decompose_entry][ERROR] Missing OUTPUT_WORKSPACE env var.", + file=sys.stderr, + flush=True, + ) + return 2 + + input_workspace = os.path.abspath(input_workspace) + output_workspace = os.path.abspath(output_workspace) + os.makedirs(output_workspace, exist_ok=True) + + repo_root = get_repo_root() + + # Rewrite decorator-config for server environment + new_b64 = prepare_decorator_config_for_server( + decorator_config_b64=args.decorator_config, + output_workspace=output_workspace, + repo_root=repo_root, + ) + + # Ensure forced samples dir exists (not required, but clearer) + os.makedirs(os.path.join(output_workspace, "samples"), exist_ok=True) + + log_path = os.path.join(output_workspace, args.log_file) + + cmd = [ + sys.executable, + "-m", + "graph_net.torch.run_model", + "--model-path", + input_workspace, + "--decorator-config", + new_b64, + ] + + # Helpful markers that show up in the gRPC server captured stderr/stdout as well. + print( + f"[remote_subgraph_decompose_entry] repo_root: {repo_root}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decompose_entry] INPUT_WORKSPACE: {input_workspace}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decompose_entry] OUTPUT_WORKSPACE: {output_workspace}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decompose_entry] log_path: {log_path}", + file=sys.stderr, + flush=True, + ) + print( + f"[remote_subgraph_decompose_entry] cmd: {' '.join(cmd)}", + file=sys.stderr, + flush=True, + ) + + # Run and log to OUTPUT_WORKSPACE (will be included in output.tar.gz) + with open(log_path, "w", encoding="utf-8") as f: + proc = subprocess.run(cmd, stdout=f, stderr=f, text=True) + + if proc.returncode != 0: + print( + f"[remote_subgraph_decompose_entry][ERROR] run_model failed with returncode={proc.returncode}. " + f"See {args.log_file} in returned output.tar.gz.", + file=sys.stderr, + flush=True, + ) + + return proc.returncode + + +def main(args) -> int: + return run_remote_subgraph_decompose(args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Server entrypoint: remote subgraph decomposition for torch." + ) + parser.add_argument( + "--model-path", + type=str, + required=False, + default=None, + help="Optional; if not provided, uses $INPUT_WORKSPACE. Kept for CLI symmetry.", + ) + parser.add_argument( + "--decorator-config", + type=str, + required=True, + help="Base64-encoded decorator config JSON from client.", + ) + parser.add_argument( + "--log-file", + type=str, + default="log_remote_decompose.txt", + help="Log filename to create under $OUTPUT_WORKSPACE.", + ) + args = parser.parse_args() + main(args) diff --git a/graph_net_rpc/test/test_remote_graph_decomposer.sh b/graph_net_rpc/test/test_remote_graph_decomposer.sh new file mode 100755 index 000000000..0d56c2f09 --- /dev/null +++ b/graph_net_rpc/test/test_remote_graph_decomposer.sh @@ -0,0 +1,49 @@ +#!/bin/bash +export LD_LIBRARY_PATH=/usr/lib64/:/usr/local/lib/:$LD_LIBRARY_PATH +set -e +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(graph_net.__file__))") +# Test configuration +MODEL_NAME="resnet18" +MODEL_PATH="$GRAPH_NET_ROOT/../samples/timm/$MODEL_NAME" +OUTPUT_DIR="/tmp/remote_graph_decomposer" +MACHINE="localhost" +PORT=50052 + +# Create decorator config (similar to naive_graph_decomposer_test.sh) +config_json_str=$(cat <