From eeca32625374802db965c520978eed3cd53f0288 Mon Sep 17 00:00:00 2001
From: Denghaodong
Date: Wed, 7 Jan 2026 06:58:24 +0000
Subject: [PATCH 01/11] add grpc server and remote executor demo

---
 DESIGN.md                                     | 705 ++++++++++++++++++
 graph_net/graph_net_bench/grpc/message_pb2.py |  38 +-
 .../graph_net_bench/grpc/message_pb2_grpc.py  |  77 +-
 graph_net/graph_net_bench/grpc/server.py      | 177 ++++-
 .../graph_net_bench/sample_remote_executor.py | 100 ++-
 5 files changed, 1022 insertions(+), 75 deletions(-)
 create mode 100644 DESIGN.md

diff --git a/DESIGN.md b/DESIGN.md
new file mode 100644
index 000000000..a597f2818
--- /dev/null
+++ b/DESIGN.md
@@ -0,0 +1,705 @@
# GraphNet Remote Model Testing Framework - Design Document v4

## I. Design Principles

**Core idea**: the server is a **generic model execution engine** with no dependency on GraphNet code. The client drives the GraphNet test framework and executes models remotely via RPC.

**Hard constraint**: build strictly on the existing `message.proto`; the protocol definition is not modified.

---

## II. System Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                          Client                             │
│  ┌──────────────────────────────────────────────────────┐   │
│  │              GraphNet Test Framework                 │   │
│  │  ┌────────────────────────────────────────────┐      │   │
│  │  │  test_compiler.py                          │      │   │
│  │  │  - performance test (warmup + trials)      │      │   │
│  │  │  - correctness check                       │      │   │
│  │  └────────────────────────────────────────────┘      │   │
│  └───────────────────┬──────────────────────────────────┘   │
│                      │                                      │
│  ┌───────────────────▼──────────────────────────────────┐   │
│  │        SampleRemoteExecutor (graph_net_bench)        │   │
│  │  - packs the model directory (tar.gz)                │   │
│  │  - executes the model via RPC                        │   │
│  │  - returns tuple[Tensor] results                     │   │
│  └───────────────────┬──────────────────────────────────┘   │
│                      │ gRPC (message.proto)                 │
└──────────────────────┼──────────────────────────────────────┘
                       │
                       ▼
┌─────────────────────────────────────────────────────────────┐
│           Server Machine (no GraphNet dependency)           │
│  ┌──────────────────────────────────────────────────────┐   │
│  │          gRPC Server (existing proto as-is)          │   │
│  │  - receives ExecutionRequest                         │   │
│  │  - decompresses and executes the model               │   │
│  │  - returns ExecutionReply                            │   │
│  └──────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘
```

---

## III. The Existing Protocol (message.proto)

```protobuf
syntax = "proto3";

package sample_remote_executor;

message CompressedData {
  string filename = 1;
  uint32 original_size = 2;
  bytes payload = 3;
  string compression_algo = 4;
}

message RpcData {
  oneof rpc_data_type {
    CompressedData compressed_data = 1;
    string str_data = 2;
  }
}

message ExecutionRequest {
  string rpc_cmd = 1;
  RpcData rpc_input = 2;
  optional string output_file_name = 3;
}

message ExecutionReply {
  int64 ret_code = 1;
  string stdout = 2;
  string stderr = 3;
  RpcData rpc_output = 4;
}

service SampleRemoteExecutor {
  rpc Execute (ExecutionRequest) returns (ExecutionReply);
}
```

### Protocol usage conventions

| Field | Purpose |
|------|------|
| `rpc_cmd` | Command identifier: "execute_model" |
| `rpc_input.compressed_data` | Compressed model directory (tar.gz) |
| `output_file_name` | random_seed (as a string) |
| `stdout` | Serialized list of output tensors (JSON) |
| `ret_code` | 0 = success, non-zero = failure |
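For concreteness, the sketch below assembles one request following these conventions. It is only a sketch: the `payload` placeholder stands in for a real gzipped tar archive, and the generated `message_pb2` module is assumed to be importable.

```python
import message_pb2

payload = b"..."  # placeholder: gzipped tar archive of the model directory

request = message_pb2.ExecutionRequest(
    rpc_cmd="execute_model",                  # command identifier
    rpc_input=message_pb2.RpcData(
        compressed_data=message_pb2.CompressedData(
            filename="model.tar.gz",
            original_size=len(payload),
            payload=payload,
            compression_algo="gzip",
        )
    ),
    output_file_name=str(42),                 # random_seed travels as a string
)
```

---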
## IV. Client Implementation

### 4.1 The SampleRemoteExecutor class

**File**: `graph_net/graph_net_bench/sample_remote_executor.py`

**Status**: ✅ Implemented

```python
"""
SampleRemoteExecutor: remote model executor

Usage:
    import graph_net_bench as gnb

    # Create an executor
    sample_remote_executor = gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052)

    # Call it directly; returns tuple[Tensor, ...]
    ret: tuple[torch.Tensor, ...] = sample_remote_executor(sample_model_path, random_seed=42)

    # Context-manager support
    with sample_remote_executor:
        outputs = sample_remote_executor(model_path, random_seed=1024)
"""

import grpc
import tarfile
import json
from pathlib import Path
from io import BytesIO
from typing import Tuple, Optional

import torch


class SampleRemoteExecutor:
    """Remote model executor.

    Runs model inference on a remote server over gRPC.

    Attributes:
        machine: server IP address
        port: server port
        channel: gRPC channel
        stub: gRPC stub
    """

    def __init__(self, machine: str, port: int):
        """
        Args:
            machine: server IP address
            port: server port
        """
        self.machine = machine
        self.port = port
        self._channel: Optional[grpc.Channel] = None
        self._stub = None

    def _get_stub(self):
        """Get the gRPC stub (lazily initialized)."""
        if self._stub is None:
            from .grpc import message_pb2, message_pb2_grpc
            self._channel = grpc.insecure_channel(f"{self.machine}:{self.port}")
            self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel)
        return self._stub

    def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...]:
        """
        Execute the model remotely.

        Args:
            model_path: path to the model directory
            random_seed: random seed for reproducible inputs

        Returns:
            tuple[Tensor, ...]: the model's output tensors

        Raises:
            RuntimeError: remote execution failed
        """
        # 1. Compress the model directory
        compressed_data = self._compress_model(model_path)

        # 2. Build the RPC request
        from .grpc import message_pb2

        stub = self._get_stub()
        request = message_pb2.ExecutionRequest(
            rpc_cmd="execute_model",
            rpc_input=message_pb2.RpcData(
                compressed_data=compressed_data
            ),
            # The output_file_name field carries the random_seed
            output_file_name=str(random_seed)
        )

        # 3. Send the request
        reply = stub.Execute(request)

        # 4. Parse the result
        if reply.ret_code != 0:
            raise RuntimeError(f"Remote execution failed: {reply.stderr}")

        # stdout carries the serialized tensor data (JSON)
        return self._deserialize_tensors(reply.stdout)

    def __enter__(self):
        """Context-manager support."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release resources."""
        self.close()

    def close(self):
        """Close the gRPC channel."""
        if self._channel is not None:
            self._channel.close()
            self._channel = None
            self._stub = None

    def _compress_model(self, model_path: str):
        """Compress the model directory into a CompressedData message.

        Args:
            model_path: path to the model directory

        Returns:
            CompressedData: the compressed model data
        """
        from .grpc import message_pb2

        buffer = BytesIO()
        with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
            model_dir = Path(model_path)
            for item in model_dir.rglob("*"):
                if item.is_file():
                    arcname = item.relative_to(model_dir)
                    tar.add(item, arcname=arcname)

        compressed_bytes = buffer.getvalue()

        return message_pb2.CompressedData(
            filename=f"{Path(model_path).name}.tar.gz",
            original_size=len(compressed_bytes),
            payload=compressed_bytes,
            compression_algo="gzip"
        )

    def _deserialize_tensors(self, json_str: str) -> Tuple[torch.Tensor, ...]:
        """Deserialize a list of tensors from JSON.

        Args:
            json_str: tensor data in JSON form

        Returns:
            tuple[Tensor, ...]: the tensor tuple
        """
        import numpy as np

        data = json.loads(json_str)
        result = []

        for tensor_data in data:
            # "dtype" arrives as e.g. "torch.float32"; strip the prefix for numpy
            dtype_name = tensor_data["dtype"].replace("torch.", "")
            shape = tuple(tensor_data["shape"])
            # Handle the two payload forms
            if tensor_data["data"] is None:
                np_array = np.zeros(shape, dtype=np.dtype(dtype_name))
            else:
                np_array = np.frombuffer(
                    tensor_data["data"].encode("latin1"),
                    dtype=np.dtype(dtype_name)
                ).reshape(shape).copy()  # copy: frombuffer yields a read-only view
            result.append(torch.from_numpy(np_array))

        return tuple(result)
```

### 4.2 Usage Examples

```python
# Option 1: direct call
import graph_net_bench as gnb

# Create an executor
executor = gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052)

# Call it directly; returns tuple[Tensor, ...]
ret: tuple[torch.Tensor, ...] = executor("/path/to/model", random_seed=42)

# Option 2: context manager (recommended; closes the connection automatically)
with gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052) as executor:
    outputs = executor("/path/to/model", random_seed=1024)

# Option 3: inside the GraphNet test framework
def test_single_model_remote(args):
    import graph_net_bench as gnb

    executor = gnb.SampleRemoteExecutor(
        machine=args.remote_machine,
        port=args.remote_port
    )

    with executor:
        # Execute the model remotely
        eager_out = executor(args.model_path, random_seed=1024)

        # Compare correctness locally
        compare_correctness(eager_out, compiled_out, args)
```
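The examples above call `compare_correctness`, which is not defined in this document. A minimal sketch of such a helper (hypothetical: the real one would presumably take its tolerances from `args`):

```python
import torch

def compare_correctness(expected, actual, args=None, atol=1e-5, rtol=1e-5):
    """Hypothetical helper: element-wise comparison of two tensor tuples."""
    assert len(expected) == len(actual), "output arity mismatch"
    for i, (e, a) in enumerate(zip(expected, actual)):
        if not torch.allclose(e.cpu(), a.cpu(), atol=atol, rtol=rtol):
            raise AssertionError(f"output_{i} differs beyond tolerance")
```

---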
## V. Server Implementation (TODO)

### 5.1 remote_model_server.py

**File**: `graph_net/graph_net_bench/server/remote_model_server.py`

**Status**: ⏳ To be implemented

```python
import grpc
from concurrent import futures
import tempfile
import shutil
import tarfile
import json
import torch
from pathlib import Path
from io import BytesIO

import message_pb2
import message_pb2_grpc


class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer):
    """Remote model execution service."""

    def Execute(self, request, context):
        """
        Run model inference.

        Args:
            request: ExecutionRequest
                - rpc_cmd: "execute_model"
                - rpc_input.compressed_data: the compressed model
                - output_file_name: random_seed (string)

        Returns:
            ExecutionReply
                - ret_code: 0 = success
                - stdout: serialized output tensors (JSON)
                - stderr: error message
        """
        temp_dir = None
        try:
            # 1. Parse arguments
            if request.rpc_cmd != "execute_model":
                return message_pb2.ExecutionReply(
                    ret_code=-1,
                    stderr=f"Unknown rpc_cmd: {request.rpc_cmd}"
                )

            random_seed = int(request.output_file_name)

            # 2. Extract the model into a temp directory
            temp_dir = tempfile.mkdtemp(prefix="remote_model_")
            model_path = self._decompress_model(
                request.rpc_input.compressed_data,
                temp_dir
            )

            # 3. Load the model and its weights
            model = self._load_model(model_path)

            # 4. Seed the RNGs
            torch.manual_seed(random_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(random_seed)

            # 5. Run inference
            model.eval()
            with torch.no_grad():
                outputs = model()

            # 6. Serialize the outputs
            if not isinstance(outputs, tuple):
                outputs = (outputs,)

            json_output = self._serialize_tensors(outputs)

            return message_pb2.ExecutionReply(
                ret_code=0,
                stdout=json_output,
                stderr=""
            )

        except Exception as e:
            import traceback
            return message_pb2.ExecutionReply(
                ret_code=-1,
                stderr=f"{str(e)}\n{traceback.format_exc()}"
            )
        finally:
            if temp_dir:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def _decompress_model(self, compressed_data, temp_dir):
        """Extract the model directory."""
        buffer = BytesIO(compressed_data.payload)
        with tarfile.open(fileobj=buffer, mode="r:gz") as tar:
            tar.extractall(path=temp_dir)
        return temp_dir

    def _load_model(self, model_path):
        """Load the model.

        Model directory layout:
            - model.py: defines the GraphModule class
            - weight_meta.py: weight metadata (actual data or generation parameters)
        """
        import sys
        import importlib.util

        model_file = Path(model_path) / "model.py"
        spec = importlib.util.spec_from_file_location(
            "remote_model",
            str(model_file)
        )
        module = importlib.util.module_from_spec(spec)
        sys.modules["remote_model"] = module
        spec.loader.exec_module(module)

        # Build the weight tensors and load them into the model
        weight_tensors = self._create_weight_tensors(model_path, module.GraphModule)

        model = module.GraphModule()

        # Load the weights
        # GraphModule parameter names match the `name` fields in weight_meta
        for name, tensor in weight_tensors.items():
            param = getattr(model, name, None)
            if param is not None:
                param.data.copy_(tensor)

        return model

    def _create_weight_tensors(self, model_path, graph_module_class):
        """Create weight tensors from weight_meta.py.

        Reads the metadata in weight_meta.py and materializes tensors.
        If a meta entry has `data`, use it verbatim; otherwise generate
        a random tensor.
        """
        import importlib.util
        import numpy as np

        weight_meta_file = Path(model_path) / "weight_meta.py"
        if not weight_meta_file.exists():
            return {}

        # Import the weight_meta module
        spec = importlib.util.spec_from_file_location(
            "weight_meta",
            str(weight_meta_file)
        )
        weight_meta_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(weight_meta_module)

        weight_tensors = {}
        tensor_classes = [
            getattr(weight_meta_module, name)
            for name in dir(weight_meta_module)
            if name.startswith("Program_weight_tensor_meta_")
        ]

        for tensor_cls in tensor_classes:
            name = tensor_cls.name
            shape = tensor_cls.shape
            # dtype is recorded as e.g. "torch.float32"; strip the prefix
            dtype_name = tensor_cls.dtype.replace("torch.", "")
            dtype = getattr(torch, dtype_name)
            device = tensor_cls.device if hasattr(tensor_cls, 'device') else 'cpu'

            if tensor_cls.data is not None:
                # Use the recorded data
                np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype_name))
                np_array = np_array.reshape(shape)
                weight_tensors[name] = torch.from_numpy(np_array).to(device)
            else:
                # Generate a random tensor
                weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device)

        return weight_tensors

    def _serialize_tensors(self, outputs):
        """Serialize a list of tensors to JSON."""
        tensor_list = []

        for tensor in outputs:
            tensor_data = {
                "dtype": str(tensor.dtype),
                "shape": list(tensor.shape),
                "data": tensor.cpu().numpy().tobytes().decode("latin1")
            }
            tensor_list.append(tensor_data)

        return json.dumps(tensor_list)


def serve(port=50052, max_workers=4):
    """Start the gRPC server."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    message_pb2_grpc.add_SampleRemoteExecutorServicer_to_server(
        RemoteModelExecutorServicer(), server
    )
    server.add_insecure_port(f"0.0.0.0:{port}")
    print(f"Server started on port {port}...")
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
```
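Because the servicer validates `rpc_cmd` before it touches the gRPC context, it can be smoke-tested in-process without starting a server. A minimal sketch:

```python
import message_pb2

servicer = RemoteModelExecutorServicer()
reply = servicer.Execute(
    message_pb2.ExecutionRequest(rpc_cmd="no_such_cmd"),
    context=None,  # unused on this early-return path
)
assert reply.ret_code == -1 and "Unknown rpc_cmd" in reply.stderr
```

---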
## VI. Directory Layout

```
graph_net/graph_net_bench/
├── grpc/
│   ├── message.proto              # unchanged (protocol definition)
│   ├── message_pb2.py             # generated by protobuf
│   ├── message_pb2_grpc.py        # generated gRPC stubs
│   ├── server.py                  # kept (simple echo test)
│   └── client.py                  # kept (simple test)
├── server/
│   ├── __init__.py                # new: package init
│   └── remote_model_server.py     # ⏳ remote model service (to be implemented)
├── sample_remote_executor.py      # ✅ client implementation (done)
├── __init__.py                    # new: package init, exports SampleRemoteExecutor
└── DESIGN.md                      # this design document
```

---

## VII. Key Implementation Details

### 7.1 Data serialization

| Data | Serialization | Protocol field |
|------|------------|----------|
| model directory | tar.gz → bytes | `rpc_input.compressed_data` |
| random_seed | string via `str()` | `output_file_name` |
| output tensors | JSON → stdout | `stdout` |

### 7.2 Model directory layout

A GraphNet sample directory is laid out as:
```
model_path/
├── model.py                        # defines the GraphModule class
├── weight_meta.py                  # weight metadata (shapes, data)
└── input_tensor_constraints.py     # optional: input constraints
```

**weight_meta.py example**:
```python
class Program_weight_tensor_meta_L_self_modules_classifier_parameters_bias_:
    name = "L_self_modules_classifier_parameters_bias_"
    shape = [2]
    dtype = "torch.float32"
    device = "cuda:0"
    data = [0.0, 0.0]  # may be None (meaning: random)
```
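The server materializes each of these meta classes into a tensor: `data` is used verbatim when present, otherwise a random tensor of the declared shape and dtype is generated. A condensed sketch of that rule (note that the dtype string carries a `"torch."` prefix, which numpy does not understand):

```python
import numpy as np
import torch

def materialize(meta):
    """Turn one Program_weight_tensor_meta_* class into a tensor."""
    dtype_name = meta.dtype.replace("torch.", "")      # e.g. "float32"
    device = getattr(meta, "device", "cpu")
    if meta.data is not None:
        arr = np.array(meta.data, dtype=np.dtype(dtype_name)).reshape(meta.shape)
        return torch.from_numpy(arr).to(device)
    # randn only supports floating dtypes; bool/int metas would need
    # zeros or randint instead
    return torch.randn(meta.shape, dtype=getattr(torch, dtype_name), device=device)
```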
### 7.3 Limits and constraints

| Constraint | Notes |
|------|------|
| model size | bounded by the gRPC message size limit (4 MB by default) |
| tensor size | JSON serialization is inefficient; large tensors need optimization |
| concurrency | controlled by the server's max_workers |
| dependencies | the server needs torch, numpy, protobuf |

### 7.4 Future optimizations

| Item | Approach |
|--------|------|
| large model transfer | chunked transfer, streaming via RpcData |
| tensor serialization | protobuf bytes instead of JSON |
| result caching | cache loaded models |
| GPU placement | allow selecting a device (cuda:0, cpu, ...) |
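One possible bytes-based replacement for JSON, sketched with numpy's `.npz` format (this is the direction the later patches in this series take, returning the archive through `rpc_output`):

```python
from io import BytesIO

import numpy as np
import torch

# Serialize: one array per output tensor, keyed output_0, output_1, ...
outputs = (torch.ones(2, 3), torch.zeros(4))
buf = BytesIO()
np.savez(buf, **{f"output_{i}": t.cpu().numpy() for i, t in enumerate(outputs)})
npz_bytes = buf.getvalue()

# Deserialize: read the arrays back in sorted key order
# (lexicographic, so only stable for fewer than ten outputs).
with BytesIO(npz_bytes) as f:
    npz = np.load(f)
    restored = tuple(torch.from_numpy(npz[k]) for k in sorted(npz.files))
```

---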
## VIII. Command-Line Usage

### Server (remote machine)

```bash
# Install dependencies
pip install torch numpy grpcio grpcio-tools

# Start the remote model server
cd /denghaodong/code/GraphNet/graph_net/graph_net_bench/server
python remote_model_server.py --port 50052
```

### Client (local machine)

```bash
# Install dependencies
pip install torch numpy grpcio

# Example
python -c "
import graph_net_bench as gnb

# Create the executor
executor = gnb.SampleRemoteExecutor(machine='192.168.1.100', port=50052)

# Execute the model remotely
outputs = executor('/path/to/sample/model', random_seed=42)
print(f'Received {len(outputs)} output tensors')

# Close the connection
executor.close()
"
```

---

## IX. GraphNet Integration Examples

### 9.1 Standalone usage

```python
import graph_net_bench as gnb
from pathlib import Path

# Configuration
SERVER_IP = "192.168.1.100"
SERVER_PORT = 50052
MODEL_PATH = "/path/to/transformers-auto-model/model_name"

# Create the executor
executor = gnb.SampleRemoteExecutor(machine=SERVER_IP, port=SERVER_PORT)

try:
    # Run remote inference
    outputs = executor(MODEL_PATH, random_seed=42)
    print(f"Success: {len(outputs)} outputs")
finally:
    executor.close()
```

### 9.2 Integration with test_compiler.py

```python
# graph_net/torch/test_compiler.py

def test_single_model_remote(args):
    """Test a single model via the remote server."""
    import graph_net_bench as gnb

    executor = gnb.SampleRemoteExecutor(
        machine=args.remote_machine,
        port=args.remote_port
    )

    with executor:
        # 1. Run eager mode remotely (the baseline)
        eager_out = executor(args.model_path, random_seed=args.seed)

        # 2. Compile the model (locally or remotely)
        compiled_out = compile_and_execute(args)

        # 3. Compare correctness locally
        compare_correctness(eager_out, compiled_out, args)

    return True
```

---

## X. Troubleshooting

| Problem | Likely cause | Fix |
|------|----------|----------|
| connection timeout | network unreachable / port closed | check the firewall and server status |
| out of memory | model too large | use a smaller batch size |
| serialization failure | tensors contain complex types | simplify outputs or use chunked transfer |
| weight loading failure | parameter names don't match | check model.py against weight_meta.py |

---

## XI. References

- [gRPC Python docs](https://grpc.io/docs/languages/python/)
- [GraphNet project](https://github.com/PaddlePaddle/GraphNet)
- [PyTorch TorchScript](https://pytorch.org/docs/stable/jit.html)
\ No newline at end of file
diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py
index 4cb6c49a5..1565fbcb8 100644
--- a/graph_net/graph_net_bench/grpc/message_pb2.py
+++ b/graph_net/graph_net_bench/grpc/message_pb2.py
@@ -9,32 +9,36 @@
 from google.protobuf import runtime_version as _runtime_version
 from google.protobuf import symbol_database as _symbol_database
 from google.protobuf.internal import builder as _builder
-
 _runtime_version.ValidateProtobufRuntimeVersion(
-    _runtime_version.Domain.PUBLIC, 6, 31, 1, "", "message.proto"
+    _runtime_version.Domain.PUBLIC,
+    6,
+    31,
+    1,
+    '',
+    'message.proto'
 )
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\rmessage.proto\x12\x16sample_remote_executor"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x42\x0f\n\rrpc_data_type"\x8b\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x13\n\x11_output_file_name"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3'
-)
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\x8b\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "message_pb2", _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'message_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals["_COMPRESSEDDATA"]._serialized_start = 41 - _globals["_COMPRESSEDDATA"]._serialized_end = 141 - _globals["_RPCDATA"]._serialized_start = 143 - _globals["_RPCDATA"]._serialized_end = 256 - _globals["_EXECUTIONREQUEST"]._serialized_start = 259 - _globals["_EXECUTIONREQUEST"]._serialized_end = 398 - _globals["_EXECUTIONREPLY"]._serialized_start = 400 - _globals["_EXECUTIONREPLY"]._serialized_end = 519 - _globals["_SAMPLEREMOTEEXECUTOR"]._serialized_start = 521 - _globals["_SAMPLEREMOTEEXECUTOR"]._serialized_end = 636 + DESCRIPTOR._loaded_options = None + _globals['_COMPRESSEDDATA']._serialized_start=41 + _globals['_COMPRESSEDDATA']._serialized_end=141 + _globals['_RPCDATA']._serialized_start=143 + _globals['_RPCDATA']._serialized_end=256 + _globals['_EXECUTIONREQUEST']._serialized_start=259 + _globals['_EXECUTIONREQUEST']._serialized_end=398 + _globals['_EXECUTIONREPLY']._serialized_start=400 + _globals['_EXECUTIONREPLY']._serialized_end=519 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=521 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=636 # @@protoc_insertion_point(module_scope) diff --git a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py index acdbf3c87..82582553d 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py +++ b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py @@ -1,29 +1,27 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc +import warnings import message_pb2 as message__pb2 -GRPC_GENERATED_VERSION = "1.76.0" +GRPC_GENERATED_VERSION = '1.76.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False try: from grpc._utilities import first_version_is_lower - - _version_not_supported = first_version_is_lower( - GRPC_VERSION, GRPC_GENERATED_VERSION - ) + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) except ImportError: _version_not_supported = True if _version_not_supported: raise RuntimeError( - f"The grpc package installed is at version {GRPC_VERSION}," - + " but the generated code in message_pb2_grpc.py depends on" - + f" grpcio>={GRPC_GENERATED_VERSION}." - + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" - + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in message_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' ) @@ -37,11 +35,10 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Execute = channel.unary_unary( - "/sample_remote_executor.SampleRemoteExecutor/Execute", - request_serializer=message__pb2.ExecutionRequest.SerializeToString, - response_deserializer=message__pb2.ExecutionReply.FromString, - _registered_method=True, - ) + '/sample_remote_executor.SampleRemoteExecutor/Execute', + request_serializer=message__pb2.ExecutionRequest.SerializeToString, + response_deserializer=message__pb2.ExecutionReply.FromString, + _registered_method=True) class SampleRemoteExecutorServicer(object): @@ -50,48 +47,43 @@ class SampleRemoteExecutorServicer(object): def Execute(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def add_SampleRemoteExecutorServicer_to_server(servicer, server): rpc_method_handlers = { - "Execute": grpc.unary_unary_rpc_method_handler( - servicer.Execute, - request_deserializer=message__pb2.ExecutionRequest.FromString, - response_serializer=message__pb2.ExecutionReply.SerializeToString, - ), + 'Execute': grpc.unary_unary_rpc_method_handler( + servicer.Execute, + request_deserializer=message__pb2.ExecutionRequest.FromString, + response_serializer=message__pb2.ExecutionReply.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - "sample_remote_executor.SampleRemoteExecutor", rpc_method_handlers - ) + 'sample_remote_executor.SampleRemoteExecutor', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers( - "sample_remote_executor.SampleRemoteExecutor", rpc_method_handlers - ) + server.add_registered_method_handlers('sample_remote_executor.SampleRemoteExecutor', rpc_method_handlers) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. 
class SampleRemoteExecutor(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
-    def Execute(
-        request,
-        target,
-        options=(),
-        channel_credentials=None,
-        call_credentials=None,
-        insecure=False,
-        compression=None,
-        wait_for_ready=None,
-        timeout=None,
-        metadata=None,
-    ):
+    def Execute(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
         return grpc.experimental.unary_unary(
             request,
             target,
-            "/sample_remote_executor.SampleRemoteExecutor/Execute",
+            '/sample_remote_executor.SampleRemoteExecutor/Execute',
             message__pb2.ExecutionRequest.SerializeToString,
             message__pb2.ExecutionReply.FromString,
             options,
@@ -102,5 +94,4 @@ def Execute(
             wait_for_ready,
             timeout,
             metadata,
-            _registered_method=True,
-        )
+            _registered_method=True)
diff --git a/graph_net/graph_net_bench/grpc/server.py b/graph_net/graph_net_bench/grpc/server.py
index 481cdb8b3..98fd7307f 100644
--- a/graph_net/graph_net_bench/grpc/server.py
+++ b/graph_net/graph_net_bench/grpc/server.py
@@ -1,29 +1,182 @@
 import grpc
 from concurrent import futures
+import tempfile
+import shutil
+import tarfile
+import json
+import torch
+import numpy as np
+from pathlib import Path
+from io import BytesIO
+
 import message_pb2
 import message_pb2_grpc
 
 
-class SampleRemoteExecutor(message_pb2_grpc.SampleRemoteExecutorServicer):
+class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer):
+    """Remote model execution service."""
+
     def Execute(self, request, context):
-        print("[GraphNet] Received ExecuteRequest")
-        return message_pb2.ExecutionReply(
-            ret_code=0, stdout="", stderr="", rpc_output=request.rpc_input
+        """Run model inference."""
+        temp_dir = None
+
+        try:
+            # 1. Validate the command
+            if request.rpc_cmd != "execute_model":
+                return message_pb2.ExecutionReply(
+                    ret_code=-1,
+                    stderr=f"Unknown rpc_cmd: {request.rpc_cmd}"
+                )
+
+            # 2. Extract the random_seed
+            random_seed = int(request.output_file_name)
+
+            # 3. Decompress the model
+            temp_dir = tempfile.mkdtemp(prefix="remote_model_")
+            model_path = self._decompress_model(
+                request.rpc_input.compressed_data,
+                temp_dir
+            )
+
+            # 4. Load the model
+            model = self._load_model(model_path)
+
+            # 5. Seed the RNGs
+            torch.manual_seed(random_seed)
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed_all(random_seed)
+
+            # 6. Run inference
+            model.eval()
+            with torch.no_grad():
+                outputs = model()
+
+            # 7. Serialize the outputs
+            if not isinstance(outputs, tuple):
+                outputs = (outputs,)
+
+            json_output = self._serialize_tensors(outputs)
+
+            return message_pb2.ExecutionReply(
+                ret_code=0,
+                stdout=json_output,
+                stderr=""
+            )
+
+        except Exception as e:
+            import traceback
+            return message_pb2.ExecutionReply(
+                ret_code=-1,
+                stderr=f"{str(e)}\n{traceback.format_exc()}"
+            )
+
+        finally:
+            if temp_dir:
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+    def _decompress_model(self, compressed_data, temp_dir):
+        """Extract the model directory."""
+        buffer = BytesIO(compressed_data.payload)
+        with tarfile.open(fileobj=buffer, mode="r:gz") as tar:
+            tar.extractall(path=temp_dir)
+        return temp_dir
+
+    def _load_model(self, model_path):
+        """Load the model."""
+        import importlib.util
+
+        model_file = Path(model_path) / "model.py"
+        if not model_file.exists():
+            raise FileNotFoundError(f"model.py not found in {model_path}")
+
+        spec = importlib.util.spec_from_file_location(
+            "remote_model_module",
+            str(model_file)
         )
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+
+        if not hasattr(module, 'GraphModule'):
+            raise ValueError("model.py must define 'GraphModule' class")
+
+        # Build the weight tensors
+        weight_tensors = self._create_weight_tensors(model_path)
+        model = module.GraphModule()
+
+        # Load the weights
+        for name, tensor in weight_tensors.items():
+            param = getattr(model, name, None)
+            if param is not None and isinstance(param, torch.Tensor):
+                param.data.copy_(tensor)
+
+        return model
+
+    def _create_weight_tensors(self, model_path):
+        """Create weight tensors from weight_meta.py."""
+        import importlib.util
+
+        weight_meta_file = Path(model_path) / "weight_meta.py"
+        if not weight_meta_file.exists():
+            return {}
+
+        spec = importlib.util.spec_from_file_location(
+            "weight_meta_module",
+            str(weight_meta_file)
+        )
+        weight_meta_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(weight_meta_module)
 
-def serve():
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+        weight_tensors = {}
+        tensor_classes = [
+            getattr(weight_meta_module, name)
+            for name in dir(weight_meta_module)
+            if name.startswith("Program_weight_tensor_meta_")
+        ]
+
+        for tensor_cls in tensor_classes:
+            name = tensor_cls.name
+            shape = tensor_cls.shape
+            dtype = getattr(torch, tensor_cls.dtype)
+            device = getattr(tensor_cls, 'device', 'cpu')
+
+            if tensor_cls.data is not None:
+                np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype.__name__))
+                np_array = np_array.reshape(shape)
+                weight_tensors[name] = torch.from_numpy(np_array).to(device)
+            else:
+                weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device)
+
+        return weight_tensors
+
+    def _serialize_tensors(self, outputs):
+        """Serialize tensors to JSON."""
+        tensor_list = []
+        for tensor in outputs:
+            tensor_data = {
+                "dtype": str(tensor.dtype),
+                "shape": list(tensor.shape),
+                "data": tensor.cpu().numpy().tobytes().decode("latin1")
+            }
+            tensor_list.append(tensor_data)
+        return json.dumps(tensor_list)
+
+
+def serve(port=50052, max_workers=4):
+    """Start the gRPC server."""
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
     message_pb2_grpc.add_SampleRemoteExecutorServicer_to_server(
-        SampleRemoteExecutor(), server
+        RemoteModelExecutorServicer(), server
     )
-
-    # Listen on all interfaces (0.0.0.0) at port 50052
-    server.add_insecure_port("0.0.0.0:50052")
-    print("Server started on port 50052...")
+    server.add_insecure_port(f"0.0.0.0:{port}")
+    print(f"Server started on port {port}...")
     server.start()
     server.wait_for_termination()
 
 
 if __name__ == "__main__":
-    serve()
+    import argparse
+    parser = argparse.ArgumentParser(description="Remote Model Server")
+    parser.add_argument("--port", type=int, default=50052)
+    parser.add_argument("--max-workers", type=int, default=4)
+    args = parser.parse_args()
+    serve(port=args.port, max_workers=args.max_workers)
diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py
index 301554f0d..0c1a20370 100644
--- a/graph_net/graph_net_bench/sample_remote_executor.py
+++ b/graph_net/graph_net_bench/sample_remote_executor.py
@@ -1,7 +1,101 @@
+import grpc
+import tarfile
+import json
+from pathlib import Path
+from io import BytesIO
+from typing import Tuple, Optional
+from contextlib import contextmanager
+import torch
+
+
 class SampleRemoteExecutor:
+    """Remote model executor."""
+
     def __init__(self, machine: str, port: int):
         self.machine = machine
         self.port = port
-
-    def __call__(self, model_path: str, random_seed: int) -> tuple:
-        raise NotImplementedError("TODO")
+        self._channel: Optional[grpc.Channel] = None
+        self._stub = None
+
+    def _get_stub(self):
+        if self._stub is None:
+            from .grpc import message_pb2, message_pb2_grpc
+            self._channel = grpc.insecure_channel(f"{self.machine}:{self.port}")
+            self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel)
+        return self._stub
+
+    def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...]:
+        """Execute the model remotely."""
+        from .grpc import message_pb2
+
+        compressed_data = self._compress_model(model_path)
+
+        # 2. Build the request
+        stub = self._get_stub()
+        request = message_pb2.ExecutionRequest(
+            rpc_cmd="execute_model",
+            rpc_input=message_pb2.RpcData(compressed_data=compressed_data),
+            output_file_name=str(random_seed)
+        )
+
+        reply = stub.Execute(request)
+
+        if reply.ret_code != 0:
+            raise RuntimeError(f"Remote execution failed: {reply.stderr}")
+
+        return self._deserialize_tensors(reply.stdout)
+
+    def _compress_model(self, model_path: str):
+        from .grpc import message_pb2
+
+        buffer = BytesIO()
+        with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
+            model_dir = Path(model_path)
+            for item in model_dir.rglob("*"):
+                if item.is_file():
+                    arcname = item.relative_to(model_dir)
+                    tar.add(item, arcname=arcname)
+
+        compressed_bytes = buffer.getvalue()
+
+        return message_pb2.CompressedData(
+            filename=f"{Path(model_path).name}.tar.gz",
+            original_size=len(compressed_bytes),
+            payload=compressed_bytes,
+            compression_algo="gzip"
+        )
+
+    def _deserialize_tensors(self, json_str: str) -> Tuple[torch.Tensor, ...]:
+        import numpy as np
+
+        data = json.loads(json_str)
+        result = []
+
+        for tensor_data in data:
+            dtype = getattr(torch, tensor_data["dtype"])
+            shape = tuple(tensor_data["shape"])
+
+            if tensor_data["data"] is None:
+                np_array = np.zeros(shape, dtype=dtype.__name__)
+            else:
+                np_array = np.frombuffer(
+                    tensor_data["data"].encode("latin1"),
+                    dtype=np.dtype(dtype.__name__.replace("torch.", ""))
+                )
+                np_array = np_array.reshape(shape)
+
+            result.append(torch.from_numpy(np_array))
+
+        return tuple(result)
+
+    def close(self):
+        if self._channel is not None:
+            self._channel.close()
+            self._channel = None
+            self._stub = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()

From 933de80d756a319313fdc412f30f0665ff1e305b Mon Sep 17 00:00:00 2001
From: Denghaodong
Date: Wed, 7 Jan 2026 13:56:15 +0000
Subject: [PATCH 02/11] add rpc_cmd script

---
 graph_net/graph_net_bench/grpc/client.py      |  83 +++++++-
 graph_net/graph_net_bench/grpc/message.proto  |   3 +-
 graph_net/graph_net_bench/grpc/message_pb2.py |  18 +-
 graph_net/graph_net_bench/grpc/server.py      | 188 ++++++------------
 .../graph_net_bench/sample_remote_executor.py |  43 ++--
 graph_net/graph_net_bench/sample_rpc_cmd.py   | 139 +++++++++++++
 6 files changed, 309 insertions(+), 165 deletions(-)
 create mode 100644 graph_net/graph_net_bench/sample_rpc_cmd.py

diff --git a/graph_net/graph_net_bench/grpc/client.py b/graph_net/graph_net_bench/grpc/client.py
index 9df794b9b..5053ea1da 100644
--- a/graph_net/graph_net_bench/grpc/client.py
+++ b/graph_net/graph_net_bench/grpc/client.py
@@ -1,21 +1,82 @@
 import grpc
 import message_pb2
 import message_pb2_grpc
+import argparse
+import tarfile
+from pathlib import Path
+from io import BytesIO
 
 
-def run():
-    # REPLACE 'SERVER_IP' with the actual IP address of Machine A
-    server_ip = "localhost"
-    channel = grpc.insecure_channel(f"{server_ip}:50052")
-    stub = message_pb2_grpc.SampleRemoteExecutorStub(channel)
+def _compress_model(model_path: str):
+    buffer = BytesIO()
+    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
+        model_dir = Path(model_path)
+        for item in model_dir.rglob("*"):
+            if item.is_file():
+                arcname = item.relative_to(model_dir)
+                tar.add(item, arcname=arcname)
+
+    compressed_bytes = buffer.getvalue()
 
-    request = message_pb2.ExecutionRequest(
-        rpc_cmd="my-echo",
-        rpc_input=message_pb2.RpcData(str_data="gooooooooooood"),
+    return message_pb2.CompressedData(
+        filename=f"{Path(model_path).name}.tar.gz",
+        original_size=len(compressed_bytes),
+        payload=compressed_bytes,
+        compression_algo="gzip"
     )
-    response = stub.Execute(request)
-    print(f"{response.rpc_output=}")
+
+
+def run(server_ip: str = "localhost", port: int = 50052, timeout: float = 10.0,
+        rpc_cmd: str = "execute_model", output_file_name: str = "42",
+        model_path: str = None):
+    channel = grpc.insecure_channel(
+        f"{server_ip}:{port}",
+        options=[
+            ('grpc.max_send_message_length', 100 * 1024 * 1024),  # 100MB
+            ('grpc.max_receive_message_length', 100 * 1024 * 1024),  # 100MB
+        ]
+    )
+    stub = message_pb2_grpc.SampleRemoteExecutorStub(channel)
+
+    if model_path:
+        # Send the compressed model data
+        compressed_data = _compress_model(model_path)
+        request = message_pb2.ExecutionRequest(
+            rpc_cmd=rpc_cmd,
+            rpc_input=message_pb2.RpcData(compressed_data=compressed_data),
+            output_file_name=output_file_name
+        )
+    else:
+        # Send a simple string payload (connectivity test)
+        request = message_pb2.ExecutionRequest(
+            rpc_cmd=rpc_cmd,
+            rpc_input=message_pb2.RpcData(str_data="test data"),
+            output_file_name=output_file_name
+        )
+
+    try:
+        response = stub.Execute(request, timeout=timeout)
+        print(f"ret_code: {response.ret_code}")
+        # print(f"stdout: {response.stdout}")
+        print(f"stderr: {response.stderr}")
+        if response.stdout:
+            print(f"rpc_output: {response.rpc_output}")
+    except grpc.RpcError as e:
+        print(f"gRPC call failed: {e.code()}: {e.details()}")
+        raise
 
 
 if __name__ == "__main__":
-    run()
+    parser = argparse.ArgumentParser(description="gRPC Client")
+    parser.add_argument("--server-ip", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=50052)
+    parser.add_argument("--timeout", type=float, default=10.0)
+    parser.add_argument("--rpc-cmd", type=str, default="execute_model")
+    parser.add_argument("--output-file-name", type=str, default="42")
+    parser.add_argument("--model-path", type=str, default=None,
+                        help="Path to model directory (optional)")
+    args = parser.parse_args()
+
+    run(server_ip=args.server_ip, port=args.port, timeout=args.timeout,
+        rpc_cmd=args.rpc_cmd, output_file_name=args.output_file_name,
+ model_path=args.model_path) diff --git a/graph_net/graph_net_bench/grpc/message.proto b/graph_net/graph_net_bench/grpc/message.proto index 2d2ef4b41..4c6656f19 100644 --- a/graph_net/graph_net_bench/grpc/message.proto +++ b/graph_net/graph_net_bench/grpc/message.proto @@ -12,7 +12,8 @@ message CompressedData { message RpcData { oneof rpc_data_type { CompressedData compressed_data = 1; - string str_data = 2; + string str_data = 2; + bytes npz_data = 3; } } diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py index 1565fbcb8..52d793731 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2.py +++ b/graph_net/graph_net_bench/grpc/message_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\x8b\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"\x85\x01\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x12\x12\n\x08npz_data\x18\x03 \x01(\x0cH\x00\x42\x0f\n\rrpc_data_type\"\x8b\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -33,12 +33,12 @@ DESCRIPTOR._loaded_options = None _globals['_COMPRESSEDDATA']._serialized_start=41 _globals['_COMPRESSEDDATA']._serialized_end=141 - _globals['_RPCDATA']._serialized_start=143 - _globals['_RPCDATA']._serialized_end=256 - _globals['_EXECUTIONREQUEST']._serialized_start=259 - _globals['_EXECUTIONREQUEST']._serialized_end=398 - _globals['_EXECUTIONREPLY']._serialized_start=400 - _globals['_EXECUTIONREPLY']._serialized_end=519 - 
_globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=521
-    _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=636
+    _globals['_RPCDATA']._serialized_start=144
+    _globals['_RPCDATA']._serialized_end=277
+    _globals['_EXECUTIONREQUEST']._serialized_start=280
+    _globals['_EXECUTIONREQUEST']._serialized_end=419
+    _globals['_EXECUTIONREPLY']._serialized_start=421
+    _globals['_EXECUTIONREPLY']._serialized_end=540
+    _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=542
+    _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=657
 # @@protoc_insertion_point(module_scope)
diff --git a/graph_net/graph_net_bench/grpc/server.py b/graph_net/graph_net_bench/grpc/server.py
index 98fd7307f..9a81c0596 100644
--- a/graph_net/graph_net_bench/grpc/server.py
+++ b/graph_net/graph_net_bench/grpc/server.py
@@ -3,64 +3,75 @@
 import tempfile
 import shutil
 import tarfile
-import json
-import torch
-import numpy as np
 from pathlib import Path
 from io import BytesIO
-
+import subprocess
+import os
 import message_pb2
 import message_pb2_grpc
 
 
 class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer):
-    """Remote model execution service."""
 
     def Execute(self, request, context):
-        """Run model inference."""
-        temp_dir = None
+        input_workspace = None
+        output_workspace = None
 
         try:
-            # 1. Validate the command
-            if request.rpc_cmd != "execute_model":
+            input_workspace = tempfile.mkdtemp(prefix="input_workspace_")
+            output_workspace = tempfile.mkdtemp(prefix="output_workspace_")
+
+            self._decompress_model(request.rpc_input.compressed_data, input_workspace)
+
+            # 3. Build the path of the rpc_cmd script
+            rpc_cmd_path = os.path.join(input_workspace, request.rpc_cmd)
+
+            # 4. Set environment variables and run the rpc_cmd script
+            env = os.environ.copy()
+            env["INPUT_WORKSPACE"] = input_workspace
+            env["OUTPUT_WORKSPACE"] = output_workspace
+            env["OUTPUT_FILE_NAME"] = request.output_file_name
+
+            result = subprocess.run(
+                ["python3", rpc_cmd_path],
+                capture_output=True,
+                text=True,
+                env=env,
+                timeout=300  # 5-minute timeout
+            )
+
+            if result.returncode != 0:
                 return message_pb2.ExecutionReply(
                     ret_code=-1,
-                    stderr=f"Unknown rpc_cmd: {request.rpc_cmd}"
+                    stderr=f"rpc_cmd failed:\n{result.stderr}"
                 )
 
-            # 2. Extract the random_seed
-            random_seed = int(request.output_file_name)
-
-            # 3. Decompress the model
-            temp_dir = tempfile.mkdtemp(prefix="remote_model_")
-            model_path = self._decompress_model(
-                request.rpc_input.compressed_data,
-                temp_dir
+            # 5. Read output_workspace/{output_file_name} into reply.rpc_output
+            output_file_path = os.path.join(
+                output_workspace,
+                request.output_file_name
             )
 
-            # 4. Load the model
-            model = self._load_model(model_path)
-
-            # 5. Seed the RNGs
-            torch.manual_seed(random_seed)
-            if torch.cuda.is_available():
-                torch.cuda.manual_seed_all(random_seed)
-
-            # 6. Run inference
-            model.eval()
-            with torch.no_grad():
-                outputs = model()
-
-            # 7. Serialize the outputs
-            if not isinstance(outputs, tuple):
-                outputs = (outputs,)
-
-            json_output = self._serialize_tensors(outputs)
+            if not os.path.exists(output_file_path):
+                return message_pb2.ExecutionReply(
+                    ret_code=-1,
+                    stderr=f"Output file does not exist: {output_file_path}"
+                )
 
+            # Read the .npz file and return it
             return message_pb2.ExecutionReply(
                 ret_code=0,
-                stdout=json_output,
-                stderr=""
+                stdout=result.stdout,
+                stderr=result.stderr,
+                rpc_output=message_pb2.RpcData(
+                    npz_data=Path(output_file_path).read_bytes()
+                )
+            )
+
+        except subprocess.TimeoutExpired:
+            return message_pb2.ExecutionReply(
+                ret_code=-1,
+                stderr="rpc_cmd timed out (5 minutes)"
             )
 
         except Exception as e:
@@ -71,99 +82,28 @@ def Execute(self, request, context):
             )
 
         finally:
-            if temp_dir:
-                shutil.rmtree(temp_dir, ignore_errors=True)
+            # 6. Clean up the temp directories
+            if input_workspace and os.path.exists(input_workspace):
+                shutil.rmtree(input_workspace, ignore_errors=True)
+            if output_workspace and os.path.exists(output_workspace):
+                shutil.rmtree(output_workspace, ignore_errors=True)
 
-    def _decompress_model(self, compressed_data, temp_dir):
-        """Extract the model directory."""
+    def _decompress_model(self, compressed_data, target_dir):
         buffer = BytesIO(compressed_data.payload)
         with tarfile.open(fileobj=buffer, mode="r:gz") as tar:
-            tar.extractall(path=temp_dir)
-        return temp_dir
-
-    def _load_model(self, model_path):
-        """Load the model."""
-        import importlib.util
-
-        model_file = Path(model_path) / "model.py"
-        if not model_file.exists():
-            raise FileNotFoundError(f"model.py not found in {model_path}")
-
-        spec = importlib.util.spec_from_file_location(
-            "remote_model_module",
-            str(model_file)
-        )
-        module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(module)
-
-        if not hasattr(module, 'GraphModule'):
-            raise ValueError("model.py must define 'GraphModule' class")
-
-        # Build the weight tensors
-        weight_tensors = self._create_weight_tensors(model_path)
-        model = module.GraphModule()
-
-        # Load the weights
-        for name, tensor in weight_tensors.items():
-            param = getattr(model, name, None)
-            if param is not None and isinstance(param, torch.Tensor):
-                param.data.copy_(tensor)
-
-        return model
-
-    def _create_weight_tensors(self, model_path):
-        """Create weight tensors from weight_meta.py."""
-        import importlib.util
-
-        weight_meta_file = Path(model_path) / "weight_meta.py"
-        if not weight_meta_file.exists():
-            return {}
-
-        spec = importlib.util.spec_from_file_location(
-            "weight_meta_module",
-            str(weight_meta_file)
-        )
-        weight_meta_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(weight_meta_module)
 
-        weight_tensors = {}
-        tensor_classes = [
-            getattr(weight_meta_module, name)
-            for name in dir(weight_meta_module)
-            if name.startswith("Program_weight_tensor_meta_")
-        ]
-
-        for tensor_cls in tensor_classes:
-            name = tensor_cls.name
-            shape = tensor_cls.shape
-            dtype = getattr(torch, tensor_cls.dtype)
-            device = getattr(tensor_cls, 'device', 'cpu')
-
-            if tensor_cls.data is not None:
-                np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype.__name__))
-                np_array = np_array.reshape(shape)
-                weight_tensors[name] = torch.from_numpy(np_array).to(device)
-            else:
-                weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device)
-
-        return weight_tensors
-
-    def _serialize_tensors(self, outputs):
-        """Serialize tensors to JSON."""
-        tensor_list = []
-        for tensor in outputs:
-            tensor_data = {
-                "dtype": str(tensor.dtype),
-                "shape": list(tensor.shape),
-                "data": tensor.cpu().numpy().tobytes().decode("latin1")
-            }
-            tensor_list.append(tensor_data)
-        return json.dumps(tensor_list)
+            tar.extractall(path=target_dir)
 
 
 def serve(port=50052, max_workers=4):
     """Start the gRPC server."""
-    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
+    # Raise the message size limits (up to 100MB)
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=max_workers),
+        options=[
+            ('grpc.max_send_message_length', 100 * 1024 * 1024),  # 100MB
+            ('grpc.max_receive_message_length', 100 * 1024 * 1024),  # 100MB
+        ]
+    )
     message_pb2_grpc.add_SampleRemoteExecutorServicer_to_server(
         RemoteModelExecutorServicer(), server
     )
diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py
index 0c1a20370..57e33e0dd 100644
--- a/graph_net/graph_net_bench/sample_remote_executor.py
+++ b/graph_net/graph_net_bench/sample_remote_executor.py
@@ -9,7 +9,6 @@ class SampleRemoteExecutor:
-    """Remote model executor."""
 
     def __init__(self, machine: str, port: int):
         self.machine = machine
@@ -20,7 +19,12 @@ def __init__(self, machine: str, port: int):
     def _get_stub(self):
         if self._stub is None:
             from .grpc import message_pb2, message_pb2_grpc
-            self._channel = grpc.insecure_channel(f"{self.machine}:{self.port}")
+            self._channel = grpc.insecure_channel(
+                f"{self.machine}:{self.port}",
+                options=[
+                    ("grpc.max_receive_message_length", 32 * 1024 * 1024),  # 32MB
+                    ("grpc.max_send_message_length", 32 * 1024 * 1024),
+                ],)
             self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel)
         return self._stub
 
@@ -30,7 +34,6 @@ def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...
 
         compressed_data = self._compress_model(model_path)
 
-        # 2. Build the request
         stub = self._get_stub()
         request = message_pb2.ExecutionRequest(
             rpc_cmd="execute_model",
@@ -43,7 +46,11 @@ def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...
         if reply.ret_code != 0:
             raise RuntimeError(f"Remote execution failed: {reply.stderr}")
 
-        return self._deserialize_tensors(reply.stdout)
+        # Read the .npz file from rpc_output.npz_data
+        if reply.rpc_output.HasField("npz_data"):
+            return self._load_npz_from_bytes(reply.rpc_output.npz_data)
+        else:
+            raise RuntimeError(f"Invalid reply: expected npz_data in rpc_output")
 
@@ -65,26 +72,22 @@ def _compress_model(self, model_path: str):
             compression_algo="gzip"
         )
 
-    def _deserialize_tensors(self, json_str: str) -> Tuple[torch.Tensor, ...]:
+    def _load_npz_from_bytes(self, npz_bytes: bytes) -> Tuple[torch.Tensor, ...]:
+        """Load an .npz file from bytes and convert it to a tuple of tensors."""
         import numpy as np
+        from io import BytesIO
 
-        data = json.loads(json_str)
-        result = []
-
-        for tensor_data in data:
-            dtype = getattr(torch, tensor_data["dtype"])
-            shape = tuple(tensor_data["shape"])
+        # Wrap the bytes in an in-memory file
+        with BytesIO(npz_bytes) as f:
+            npz = np.load(f, allow_pickle=True)
 
-            if tensor_data["data"] is None:
-                np_array = np.zeros(shape, dtype=dtype.__name__)
-            else:
-                np_array = np.frombuffer(
-                    tensor_data["data"].encode("latin1"),
-                    dtype=np.dtype(dtype.__name__.replace("torch.", ""))
-                )
-                np_array = np_array.reshape(shape)
+            result = []
+            # Read all arrays in sorted key order (for consistency)
+            for key in sorted(npz.files):
+                arr = npz[key]
+                result.append(torch.from_numpy(arr))
 
-            result.append(torch.from_numpy(np_array))
+            npz.close()
 
         return tuple(result)
diff --git a/graph_net/graph_net_bench/sample_rpc_cmd.py b/graph_net/graph_net_bench/sample_rpc_cmd.py
new file mode 100644
index 000000000..860d68dbe
--- /dev/null
+++ b/graph_net/graph_net_bench/sample_rpc_cmd.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import json
+import importlib.util
+from pathlib import Path
+
+import torch
+import numpy as np
+
+
+def load_model_and_weights(model_path: str):
+
+    model_file = Path(model_path) / "model.py"
+    if not model_file.exists():
+        raise FileNotFoundError(f"model.py not found in {model_path}")
+
+    spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file))
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+    if not hasattr(module, 'GraphModule'):
+        raise ValueError("model.py must define 'GraphModule' class")
+
+    weight_tensors = {}
+    weight_meta_file = Path(model_path) / "weight_meta.py"
+    if weight_meta_file.exists():
+        spec = importlib.util.spec_from_file_location(
+            "weight_meta_module", str(weight_meta_file)
+        )
+        weight_meta_module = 
importlib.util.module_from_spec(spec) + spec.loader.exec_module(weight_meta_module) + + tensor_classes = [ + getattr(weight_meta_module, name) + for name in dir(weight_meta_module) + if name.startswith("Program_weight_tensor_meta_") + ] + + for tensor_cls in tensor_classes: + name = tensor_cls.name + shape = tensor_cls.shape + dtype_name = tensor_cls.dtype.replace("torch.", "") + dtype = getattr(torch, dtype_name) + device = getattr(tensor_cls, 'device', 'cpu') + + if tensor_cls.data is not None: + np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype.__name__)) + np_array = np_array.reshape(shape) + weight_tensors[name] = torch.from_numpy(np_array).to(device) + else: + if dtype == torch.bool: + weight_tensors[name] = torch.zeros(shape, dtype=dtype, device=device) + else: + weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device) + + model = module.GraphModule() + for name, tensor in weight_tensors.items(): + param = getattr(model, name, None) + if param is not None and isinstance(param, torch.Tensor): + param.data.copy_(tensor) + + return model, weight_tensors + + +def get_forward_inputs(model_path: str, weight_tensors: dict): + import inspect + + model_file = Path(model_path) / "model.py" + spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + forward_params = inspect.signature(module.GraphModule.forward).parameters + param_names = [ + name for name in forward_params.keys() + if name != 'self' + ] + + inputs = [] + for param_name in param_names: + if param_name in weight_tensors: + inputs.append(weight_tensors[param_name]) + else: + raise ValueError(f"Missing weight tensor for parameter: {param_name}") + + return inputs + + +def save_outputs_as_npz(outputs, output_path: str): + if not isinstance(outputs, tuple): + outputs = (outputs,) + + np_arrays = {} + for i, tensor in enumerate(outputs): + key = f"output_{i}" + np_arrays[key] = tensor.cpu().numpy() + + np.savez(output_path, **np_arrays) + + +def main(): + input_workspace = os.environ.get("INPUT_WORKSPACE") + output_workspace = os.environ.get("OUTPUT_WORKSPACE") + output_file_name = os.environ.get("OUTPUT_FILE_NAME") + + if not input_workspace: + raise RuntimeError("INPUT_WORKSPACE environment variable not set") + if not output_workspace: + raise RuntimeError("OUTPUT_WORKSPACE environment variable not set") + if not output_file_name: + raise RuntimeError("OUTPUT_FILE_NAME environment variable not set") + + print(f"INPUT_WORKSPACE: {input_workspace}", file=sys.stderr) + print(f"OUTPUT_WORKSPACE: {output_workspace}", file=sys.stderr) + print(f"OUTPUT_FILE_NAME: {output_file_name}", file=sys.stderr) + + print("Loading model and weights...", file=sys.stderr) + model, weight_tensors = load_model_and_weights(input_workspace) + print(f"Model loaded, {len(weight_tensors)} weight tensors", file=sys.stderr) + + print("Preparing inputs...", file=sys.stderr) + inputs = get_forward_inputs(input_workspace, weight_tensors) + print(f"Prepared {len(inputs)} inputs for forward()", file=sys.stderr) + + print("Running inference...", file=sys.stderr) + model.eval() + with torch.no_grad(): + outputs = model(*inputs) + + output_path = os.path.join(output_workspace, output_file_name) + print(f"Saving outputs to {output_path}...", file=sys.stderr) + save_outputs_as_npz(outputs, output_path) + print("Outputs saved successfully!", file=sys.stderr) + + +if __name__ == "__main__": + main() \ No newline at end of file From 
5872d0280a0ef440686f1fcb50aae8083f03aff7 Mon Sep 17 00:00:00 2001
From: Denghaodong
Date: Thu, 8 Jan 2026 08:34:41 +0000
Subject: [PATCH 03/11] add RpcExecutor

---
 graph_net/graph_net_bench/__init__.py         |   3 +
 graph_net/graph_net_bench/grpc/client.py      | 165 ++++++++++--------
 graph_net/graph_net_bench/grpc/message.proto  |   3 +
 graph_net/graph_net_bench/grpc/message_pb2.py |  12 +-
 graph_net/graph_net_bench/grpc/server.py      | 136 ++++++++-------
 .../graph_net_bench/sample_remote_executor.py | 129 +++++++-------
 graph_net/graph_net_bench/sample_rpc_cmd.py   |  57 +++---
 7 files changed, 279 insertions(+), 226 deletions(-)
 create mode 100644 graph_net/graph_net_bench/__init__.py

diff --git a/graph_net/graph_net_bench/__init__.py b/graph_net/graph_net_bench/__init__.py
new file mode 100644
index 000000000..95b260bf5
--- /dev/null
+++ b/graph_net/graph_net_bench/__init__.py
@@ -0,0 +1,3 @@
+from .sample_remote_executor import SampleRemoteExecutor
+
+__all__ = ["SampleRemoteExecutor"]
\ No newline at end of file
diff --git a/graph_net/graph_net_bench/grpc/client.py b/graph_net/graph_net_bench/grpc/client.py
index 5053ea1da..c2c3290b7 100644
--- a/graph_net/graph_net_bench/grpc/client.py
+++ b/graph_net/graph_net_bench/grpc/client.py
@@ -1,82 +1,103 @@
-import grpc
-import message_pb2
-import message_pb2_grpc
+#!/usr/bin/env python3
+"""gRPC Client CLI for SampleRemoteExecutor.
+
+Usage:
+    python -m graph_net.graph_net_bench.grpc.client --help
+"""
+
 import argparse
-import tarfile
-from pathlib import Path
-from io import BytesIO
-
-
-def _compress_model(model_path: str):
-    buffer = BytesIO()
-    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
-        model_dir = Path(model_path)
-        for item in model_dir.rglob("*"):
-            if item.is_file():
-                arcname = item.relative_to(model_dir)
-                tar.add(item, arcname=arcname)
-
-    compressed_bytes = buffer.getvalue()
-
-    return message_pb2.CompressedData(
-        filename=f"{Path(model_path).name}.tar.gz",
-        original_size=len(compressed_bytes),
-        payload=compressed_bytes,
-        compression_algo="gzip"
+import sys
+
+try:
+    from ..sample_remote_executor import SampleRemoteExecutor
+except ImportError:
+    import sys
+    from pathlib import Path
+    # Add graph_net_bench directory to path for direct execution
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    from sample_remote_executor import SampleRemoteExecutor
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="gRPC Client for remote model execution",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
     )
+    parser.add_argument(
+        "--machine",
+        type=str,
+        default="localhost",
+        help="Remote server address (default: localhost)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=50052,
+        help="gRPC server port (default: 50052)",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        required=True,
+        help="Path to model directory containing model.py and weight_meta.py",
+    )
+    parser.add_argument(
+        "--random-seed",
+        type=int,
+        default=42,
+        help="Random seed for reproducible inference (default: 42)",
+    )
+    parser.add_argument(
+        "--rpc-cmd",
+        type=str,
+        default="python3 /denghaodong/code/GraphNet/graph_net/graph_net_bench/sample_rpc_cmd.py",
+        help="Command to execute on remote server",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="Directory to save output tensors (default: current directory)",
+    )
 
-def run(server_ip: str = "localhost", port: int = 50052, timeout: float = 10.0,
-        rpc_cmd: str = "execute_model", output_file_name: str = "42",
-        model_path: str = None):
-    
channel = grpc.insecure_channel( - f"{server_ip}:{port}", - options=[ - ('grpc.max_send_message_length', 100 * 1024 * 1024), # 100MB - ('grpc.max_receive_message_length', 100 * 1024 * 1024), # 100MB - ] + args = parser.parse_args() + + executor = SampleRemoteExecutor( + machine=args.machine, + port=args.port, + rpc_cmd=args.rpc_cmd, ) - stub = message_pb2_grpc.SampleRemoteExecutorStub(channel) - - if model_path: - # 发送压缩的模型数据 - compressed_data = _compress_model(model_path) - request = message_pb2.ExecutionRequest( - rpc_cmd=rpc_cmd, - rpc_input=message_pb2.RpcData(compressed_data=compressed_data), - output_file_name=output_file_name - ) - else: - # 发送简单的字符串数据(用于测试连通性) - request = message_pb2.ExecutionRequest( - rpc_cmd=rpc_cmd, - rpc_input=message_pb2.RpcData(str_data="test data"), - output_file_name=output_file_name - ) try: - response = stub.Execute(request, timeout=timeout) - print(f"ret_code: {response.ret_code}") - # print(f"stdout: {response.stdout}") - print(f"stderr: {response.stderr}") - if response.stdout: - print(f"rpc_output: {response.rpc_output}") - except grpc.RpcError as e: - print(f"gRPC 调用失败: {e.code()}: {e.details()}") - raise + print(f"Sending request to {args.machine}:{args.port}...", file=sys.stderr) + tensors = executor(args.model_path, args.random_seed) + print(f"Received {len(tensors)} output tensors:", file=sys.stderr) + for i, tensor in enumerate(tensors): + print(f" output_{i}: shape={tensor.shape}, dtype={tensor.dtype}", file=sys.stderr) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="gRPC Client") - parser.add_argument("--server-ip", type=str, default="localhost") - parser.add_argument("--port", type=int, default=50052) - parser.add_argument("--timeout", type=float, default=10.0) - parser.add_argument("--rpc-cmd", type=str, default="execute_model") - parser.add_argument("--output-file-name", type=str, default="42") - parser.add_argument("--model-path", type=str, default=None, - help="Path to model directory (optional)") - args = parser.parse_args() + if args.output_dir: + import os + from pathlib import Path + import torch + + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) - run(server_ip=args.server_ip, port=args.port, timeout=args.timeout, - rpc_cmd=args.rpc_cmd, output_file_name=args.output_file_name, - model_path=args.model_path) + for i, tensor in enumerate(tensors): + output_file = output_path / f"output_{i}.pt" + torch.save(tensor, output_file) + print(f"Saved output_{i} to {output_file}", file=sys.stderr) + + print("Execution completed successfully!", file=sys.stderr) + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + finally: + executor.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/graph_net/graph_net_bench/grpc/message.proto b/graph_net/graph_net_bench/grpc/message.proto index 4c6656f19..411f93b50 100644 --- a/graph_net/graph_net_bench/grpc/message.proto +++ b/graph_net/graph_net_bench/grpc/message.proto @@ -21,6 +21,7 @@ message ExecutionRequest { string rpc_cmd = 1; RpcData rpc_input = 2; optional string output_file_name = 3; + int64 random_seed = 4; } message ExecutionReply { @@ -35,3 +36,5 @@ service SampleRemoteExecutor { } // python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. 
message.proto + + diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py index 52d793731..7d49fa7d1 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2.py +++ b/graph_net/graph_net_bench/grpc/message_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"\x85\x01\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x12\x12\n\x08npz_data\x18\x03 \x01(\x0cH\x00\x42\x0f\n\rrpc_data_type\"\x8b\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"\x85\x01\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x12\x12\n\x08npz_data\x18\x03 \x01(\x0cH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 \x01(\x03\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -36,9 +36,9 @@ _globals['_RPCDATA']._serialized_start=144 _globals['_RPCDATA']._serialized_end=277 _globals['_EXECUTIONREQUEST']._serialized_start=280 - _globals['_EXECUTIONREQUEST']._serialized_end=419 - _globals['_EXECUTIONREPLY']._serialized_start=421 - _globals['_EXECUTIONREPLY']._serialized_end=540 - _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=542 - _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=657 + _globals['_EXECUTIONREQUEST']._serialized_end=440 + _globals['_EXECUTIONREPLY']._serialized_start=442 + _globals['_EXECUTIONREPLY']._serialized_end=561 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=563 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=678 # @@protoc_insertion_point(module_scope) diff --git a/graph_net/graph_net_bench/grpc/server.py b/graph_net/graph_net_bench/grpc/server.py index 
9a81c0596..5dc3710b5 100644 --- a/graph_net/graph_net_bench/grpc/server.py +++ b/graph_net/graph_net_bench/grpc/server.py @@ -1,12 +1,15 @@ -import grpc -from concurrent import futures +import os +import sys +import subprocess import tempfile import shutil import tarfile -from pathlib import Path +from concurrent import futures from io import BytesIO -import subprocess -import os +from pathlib import Path + +import grpc + import message_pb2 import message_pb2_grpc @@ -14,98 +17,105 @@ class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer): def Execute(self, request, context): - input_workspace = None - output_workspace = None + input_workspace = tempfile.mkdtemp(prefix="remote_input_") + output_workspace = tempfile.mkdtemp(prefix="remote_output_") try: - input_workspace = tempfile.mkdtemp(prefix="input_workspace_") - output_workspace = tempfile.mkdtemp(prefix="output_workspace_") + # 0) 基本校验 + if not request.rpc_cmd: + return message_pb2.ExecutionReply(ret_code=-1, stderr="rpc_cmd is empty") + + if not request.HasField("output_file_name") or not request.output_file_name: + return message_pb2.ExecutionReply(ret_code=-1, stderr="output_file_name is required") + + if request.rpc_input.WhichOneof("rpc_data_type") != "compressed_data": + return message_pb2.ExecutionReply( + ret_code=-1, + stderr="rpc_input must be RpcData.compressed_data (tar.gz bytes)", + ) - self._decompress_model(request.rpc_input.compressed_data, input_workspace) + # 1) 解压输入到 input_workspace + self._decompress_to_dir(request.rpc_input.compressed_data, input_workspace) - # 3. 构建 rpc_cmd 脚本路径 - rpc_cmd_path = os.path.join(input_workspace, request.rpc_cmd) + # 2) 执行 rpc_cmd + out_path = Path(output_workspace) / request.output_file_name - # 4. 设置环境变量并执行 rpc_cmd 脚本 env = os.environ.copy() env["INPUT_WORKSPACE"] = input_workspace env["OUTPUT_WORKSPACE"] = output_workspace env["OUTPUT_FILE_NAME"] = request.output_file_name - - result = subprocess.run( - ["python3", rpc_cmd_path], + env["OUTPUT_FILE_PATH"] = str(out_path) + env["RANDOM_SEED"] = str(request.random_seed) + # Add grpc directory to PYTHONPATH so message_pb2 can be imported + grpc_dir = Path(__file__).parent.resolve() + env["PYTHONPATH"] = f"{grpc_dir}:{env.get('PYTHONPATH', '')}" + + print(f"Executing rpc_cmd: {request.rpc_cmd}", file=sys.stderr) + print(f"Working directory: {input_workspace}", file=sys.stderr) + proc = subprocess.run( + request.rpc_cmd, + shell=True, + cwd=input_workspace, + env=env, capture_output=True, text=True, - env=env, - timeout=300 # 5分钟超时 ) - if result.returncode != 0: + print(f"returncode: {proc.returncode}", file=sys.stderr) + print(f"stdout: {proc.stdout}", file=sys.stderr) + print(f"stderr: {proc.stderr}", file=sys.stderr) + + if proc.returncode != 0: return message_pb2.ExecutionReply( - ret_code=-1, - stderr=f"rpc_cmd 执行失败:\n{result.stderr}" + ret_code=proc.returncode, + stdout=proc.stdout or "", + stderr=proc.stderr or f"rpc_cmd failed with returncode={proc.returncode}", ) - # 5. 
读取 output_workspace/{output_file_name} 到 reply.rpc_output - output_file_path = os.path.join( - output_workspace, - request.output_file_name - ) - - if not os.path.exists(output_file_path): + # 3) 回读输出文件 + if not out_path.exists(): + print(f"Output file not found at {out_path}", file=sys.stderr) + print(f"Contents of output_workspace: {list(Path(output_workspace).rglob('*'))}", file=sys.stderr) return message_pb2.ExecutionReply( ret_code=-1, - stderr=f"输出文件不存在: {output_file_path}" + stdout=proc.stdout or "", + stderr=(proc.stderr or "") + f"\nExpected output not found: {out_path}", ) - # 读取 .npz 文件并返回 + payload = out_path.read_bytes() + return message_pb2.ExecutionReply( ret_code=0, - stdout=result.stdout, - stderr=result.stderr, + stdout=proc.stdout or "", + stderr=proc.stderr or "", rpc_output=message_pb2.RpcData( - npz_data=Path(output_file_path).read_bytes() - ) - ) - - except subprocess.TimeoutExpired: - return message_pb2.ExecutionReply( - ret_code=-1, - stderr="rpc_cmd 执行超时(5分钟)" + compressed_data=message_pb2.CompressedData( + filename=request.output_file_name, + original_size=len(payload), + payload=payload, + compression_algo="raw", + ) + ), ) except Exception as e: import traceback - return message_pb2.ExecutionReply( - ret_code=-1, - stderr=f"{str(e)}\n{traceback.format_exc()}" - ) - + return message_pb2.ExecutionReply(ret_code=-1, stderr=f"{e}\n{traceback.format_exc()}") finally: - # 6. 清理临时目录 - if input_workspace and os.path.exists(input_workspace): - shutil.rmtree(input_workspace, ignore_errors=True) - if output_workspace and os.path.exists(output_workspace): - shutil.rmtree(output_workspace, ignore_errors=True) + shutil.rmtree(input_workspace, ignore_errors=True) + shutil.rmtree(output_workspace, ignore_errors=True) - def _decompress_model(self, compressed_data, target_dir): + def _decompress_to_dir(self, compressed_data: message_pb2.CompressedData, dst_dir: str) -> None: buffer = BytesIO(compressed_data.payload) with tarfile.open(fileobj=buffer, mode="r:gz") as tar: - tar.extractall(path=target_dir) + tar.extractall(path=dst_dir) def serve(port=50052, max_workers=4): - """启动 gRPC 服务器""" - # 增加消息大小限制(支持最多 100MB) - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=max_workers), - options=[ - ('grpc.max_send_message_length', 100 * 1024 * 1024), # 100MB - ('grpc.max_receive_message_length', 100 * 1024 * 1024), # 100MB - ] - ) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers)) message_pb2_grpc.add_SampleRemoteExecutorServicer_to_server( - RemoteModelExecutorServicer(), server + RemoteModelExecutorServicer(), + server, ) server.add_insecure_port(f"0.0.0.0:{port}") print(f"Server started on port {port}...") @@ -119,4 +129,4 @@ def serve(port=50052, max_workers=4): parser.add_argument("--port", type=int, default=50052) parser.add_argument("--max-workers", type=int, default=4) args = parser.parse_args() - serve(port=args.port, max_workers=args.max_workers) + serve(port=args.port, max_workers=args.max_workers) \ No newline at end of file diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py index 57e33e0dd..282a0fc52 100644 --- a/graph_net/graph_net_bench/sample_remote_executor.py +++ b/graph_net/graph_net_bench/sample_remote_executor.py @@ -1,104 +1,111 @@ import grpc import tarfile -import json from pathlib import Path from io import BytesIO from typing import Tuple, Optional -from contextlib import contextmanager + import torch +try: + from .grpc import message_pb2 + from .grpc 
import message_pb2_grpc +except ImportError: + import message_pb2 + import message_pb2_grpc + class SampleRemoteExecutor: - - def __init__(self, machine: str, port: int): + + def __init__(self, machine: str, port: int, rpc_cmd: str = "python3 /denghaodong/code/GraphNet/graph_net/graph_net_bench/sample_rpc_cmd.py"): self.machine = machine self.port = port + self.rpc_cmd = rpc_cmd self._channel: Optional[grpc.Channel] = None self._stub = None - + def _get_stub(self): if self._stub is None: - from .grpc import message_pb2, message_pb2_grpc - self._channel = grpc.insecure_channel( - f"{self.machine}:{self.port}", - options=[ - ("grpc.max_receive_message_length", 32 * 1024 * 1024), # 32MB - ("grpc.max_send_message_length", 32 * 1024 * 1024), - ],) + self._channel = grpc.insecure_channel(f"{self.machine}:{self.port}") self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel) return self._stub - + def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...]: - """远程执行模型""" - from .grpc import message_pb2 - - compressed_data = self._compress_model(model_path) - - stub = self._get_stub() + + compressed_data = self._compress_dir(model_path) + + # 输出文件名必须包含扩展名(mentor 约定) + output_file_name = f"outputs_seed_{random_seed}.npz" + request = message_pb2.ExecutionRequest( - rpc_cmd="execute_model", + rpc_cmd=self.rpc_cmd, rpc_input=message_pb2.RpcData(compressed_data=compressed_data), - output_file_name=str(random_seed) + output_file_name=output_file_name, + random_seed=int(random_seed), ) - + + stub = self._get_stub() reply = stub.Execute(request) - + if reply.ret_code != 0: - raise RuntimeError(f"Remote execution failed: {reply.stderr}") - - # 从 rpc_output.npz_data 读取 .npz 文件 - if reply.rpc_output.HasField("npz_data"): - return self._load_npz_from_bytes(reply.rpc_output.npz_data) - else: - raise RuntimeError(f"Invalid reply: expected npz_data in rpc_output") - - def _compress_model(self, model_path: str): - from .grpc import message_pb2 - + raise RuntimeError( + "Remote execution failed:\n" + f"ret_code={reply.ret_code}\n" + f"stdout:\n{reply.stdout}\n" + f"stderr:\n{reply.stderr}\n" + ) + + if reply.rpc_output.WhichOneof("rpc_data_type") != "compressed_data": + raise RuntimeError("Remote execution succeeded but rpc_output is not compressed_data") + + npz_bytes = reply.rpc_output.compressed_data.payload + return self._npz_bytes_to_tensors(npz_bytes) + + def _compress_dir(self, model_path: str): buffer = BytesIO() + model_dir = Path(model_path) + with tarfile.open(fileobj=buffer, mode="w:gz") as tar: - model_dir = Path(model_path) for item in model_dir.rglob("*"): if item.is_file(): arcname = item.relative_to(model_dir) tar.add(item, arcname=arcname) - + compressed_bytes = buffer.getvalue() - + return message_pb2.CompressedData( - filename=f"{Path(model_path).name}.tar.gz", + filename=f"{model_dir.name}.tar.gz", original_size=len(compressed_bytes), payload=compressed_bytes, - compression_algo="gzip" + compression_algo="gzip", ) - - def _load_npz_from_bytes(self, npz_bytes: bytes) -> Tuple[torch.Tensor, ...]: - """从 bytes 加载 .npz 文件并转换为张量元组""" + + def _npz_bytes_to_tensors(self, npz_bytes: bytes) -> Tuple[torch.Tensor, ...]: import numpy as np - from io import BytesIO - - # 将 bytes 写入临时内存文件 - with BytesIO(npz_bytes) as f: - npz = np.load(f, allow_pickle=True) - - result = [] - # 按字母顺序读取所有数组(保持一致性) - for key in sorted(npz.files): - arr = npz[key] - result.append(torch.from_numpy(arr)) - - npz.close() - - return tuple(result) - + + with np.load(BytesIO(npz_bytes), 
allow_pickle=False) as npz: + # 只接受 output_{i} 格式,按 i 排序,保证返回 tuple 稳定 + keys = [k for k in npz.files if k.startswith("output_")] + + def key_index(k: str) -> int: + # "output_0" -> 0 + return int(k.split("_", 1)[1]) + + keys.sort(key=key_index) + + tensors = [] + for k in keys: + arr = npz[k] + tensors.append(torch.from_numpy(arr)) # 默认 CPU + return tuple(tensors) + def close(self): if self._channel is not None: self._channel.close() self._channel = None self._stub = None - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + self.close() \ No newline at end of file diff --git a/graph_net/graph_net_bench/sample_rpc_cmd.py b/graph_net/graph_net_bench/sample_rpc_cmd.py index 860d68dbe..c293a6b1a 100644 --- a/graph_net/graph_net_bench/sample_rpc_cmd.py +++ b/graph_net/graph_net_bench/sample_rpc_cmd.py @@ -2,7 +2,6 @@ import os import sys -import json import importlib.util from pathlib import Path @@ -11,18 +10,17 @@ def load_model_and_weights(model_path: str): - model_file = Path(model_path) / "model.py" if not model_file.exists(): raise FileNotFoundError(f"model.py not found in {model_path}") - + spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file)) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + if not hasattr(module, 'GraphModule'): raise ValueError("model.py must define 'GraphModule' class") - + weight_tensors = {} weight_meta_file = Path(model_path) / "weight_meta.py" if weight_meta_file.exists(): @@ -31,22 +29,22 @@ def load_model_and_weights(model_path: str): ) weight_meta_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(weight_meta_module) - + tensor_classes = [ getattr(weight_meta_module, name) for name in dir(weight_meta_module) if name.startswith("Program_weight_tensor_meta_") ] - + for tensor_cls in tensor_classes: name = tensor_cls.name shape = tensor_cls.shape dtype_name = tensor_cls.dtype.replace("torch.", "") dtype = getattr(torch, dtype_name) device = getattr(tensor_cls, 'device', 'cpu') - + if tensor_cls.data is not None: - np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype.__name__)) + np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype_name)) np_array = np_array.reshape(shape) weight_tensors[name] = torch.from_numpy(np_array).to(device) else: @@ -54,49 +52,49 @@ def load_model_and_weights(model_path: str): weight_tensors[name] = torch.zeros(shape, dtype=dtype, device=device) else: weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device) - + model = module.GraphModule() for name, tensor in weight_tensors.items(): param = getattr(model, name, None) if param is not None and isinstance(param, torch.Tensor): param.data.copy_(tensor) - + return model, weight_tensors def get_forward_inputs(model_path: str, weight_tensors: dict): import inspect - + model_file = Path(model_path) / "model.py" spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file)) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + forward_params = inspect.signature(module.GraphModule.forward).parameters param_names = [ name for name in forward_params.keys() if name != 'self' ] - + inputs = [] for param_name in param_names: if param_name in weight_tensors: inputs.append(weight_tensors[param_name]) else: raise ValueError(f"Missing weight tensor for parameter: {param_name}") - + return inputs def save_outputs_as_npz(outputs, output_path: str): if not isinstance(outputs, tuple): outputs = 
(outputs,) - + np_arrays = {} for i, tensor in enumerate(outputs): key = f"output_{i}" np_arrays[key] = tensor.cpu().numpy() - + np.savez(output_path, **np_arrays) @@ -104,32 +102,43 @@ def main(): input_workspace = os.environ.get("INPUT_WORKSPACE") output_workspace = os.environ.get("OUTPUT_WORKSPACE") output_file_name = os.environ.get("OUTPUT_FILE_NAME") - + output_file_path = os.environ.get("OUTPUT_FILE_PATH") + seed_str = os.environ.get("RANDOM_SEED") + if not input_workspace: raise RuntimeError("INPUT_WORKSPACE environment variable not set") if not output_workspace: raise RuntimeError("OUTPUT_WORKSPACE environment variable not set") if not output_file_name: raise RuntimeError("OUTPUT_FILE_NAME environment variable not set") - + print(f"INPUT_WORKSPACE: {input_workspace}", file=sys.stderr) print(f"OUTPUT_WORKSPACE: {output_workspace}", file=sys.stderr) print(f"OUTPUT_FILE_NAME: {output_file_name}", file=sys.stderr) - + if output_file_path: + print(f"OUTPUT_FILE_PATH: {output_file_path}", file=sys.stderr) + + if seed_str: + seed = int(seed_str) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + print(f"RANDOM_SEED: {seed}", file=sys.stderr) + print("Loading model and weights...", file=sys.stderr) model, weight_tensors = load_model_and_weights(input_workspace) print(f"Model loaded, {len(weight_tensors)} weight tensors", file=sys.stderr) - + print("Preparing inputs...", file=sys.stderr) inputs = get_forward_inputs(input_workspace, weight_tensors) print(f"Prepared {len(inputs)} inputs for forward()", file=sys.stderr) - + print("Running inference...", file=sys.stderr) model.eval() with torch.no_grad(): outputs = model(*inputs) - - output_path = os.path.join(output_workspace, output_file_name) + + output_path = output_file_path or os.path.join(output_workspace, output_file_name) print(f"Saving outputs to {output_path}...", file=sys.stderr) save_outputs_as_npz(outputs, output_path) print("Outputs saved successfully!", file=sys.stderr) From 9925ca3c4de2fa60568ef97f472bb298f155b768 Mon Sep 17 00:00:00 2001 From: Denghaodong Date: Thu, 8 Jan 2026 08:41:35 +0000 Subject: [PATCH 04/11] remove DESIGN.md --- DESIGN.md | 705 ------------------------------------------------------ 1 file changed, 705 deletions(-) delete mode 100644 DESIGN.md diff --git a/DESIGN.md b/DESIGN.md deleted file mode 100644 index a597f2818..000000000 --- a/DESIGN.md +++ /dev/null @@ -1,705 +0,0 @@ -# GraphNet 远程模型测试框架 - 设计文档 v4 - -## 一、设计原则 - -**核心思想**:服务端是**通用的模型执行引擎**,不依赖 GraphNet 代码。客户端负责调用 GraphNet 测试框架,通过 RPC 远程执行模型。 - -**重要约束**:严格基于现有的 `message.proto`,不修改协议定义。 - ---- - -## 二、系统架构 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Client │ -│ ┌──────────────────────────────────────────────────────┐ │ -│ │ GraphNet Test Framework │ │ -│ │ ┌────────────────────────────────────────────┐ │ │ -│ │ │ test_compiler.py │ │ │ -│ │ │ - 性能测试 (warmup + trials) │ │ │ -│ │ │ - 正确性验证 │ │ │ -│ │ └────────────────────────────────────────────┘ │ │ -│ └───────────────────┬──────────────────────────────────┘ │ -│ │ │ -│ ┌───────────────────▼──────────────────────────────────┐ │ -│ │ SampleRemoteExecutor (graph_net_bench) │ │ -│ │ - 打包模型目录 (tar.gz) │ │ -│ │ - 通过 RPC 执行模型 │ │ -│ │ - 返回 tuple[Tensor] 结果 │ │ -│ └───────────────────┬──────────────────────────────────┘│ -│ │ gRPC (message.proto) │ -└──────────────────────┼───────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ Server Machine (无 GraphNet 依赖) │ -│ 
┌──────────────────────────────────────────────────────┐ │ -│ │ gRPC Server (基于现有 proto) │ │ -│ │ - 接收 ExecutionRequest │ │ -│ │ - 解压并执行模型 │ │ -│ │ - 返回 ExecutionReply │ │ -│ └───────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────┘ -``` - ---- - -## 三、现有协议 (message.proto) - -```protobuf -syntax = "proto3"; - -package sample_remote_executor; - -message CompressedData { - string filename = 1; - uint32 original_size = 2; - bytes payload = 3; - string compression_algo = 4; -} - -message RpcData { - oneof rpc_data_type { - CompressedData compressed_data = 1; - string str_data = 2; - } -} - -message ExecutionRequest { - string rpc_cmd = 1; - RpcData rpc_input = 2; - optional string output_file_name = 3; -} - -message ExecutionReply { - int64 ret_code = 1; - string stdout = 2; - string stderr = 3; - RpcData rpc_output = 4; -} - -service SampleRemoteExecutor { - rpc Execute (ExecutionRequest) returns (ExecutionReply); -} -``` - -### 协议使用约定 - -| 字段 | 用途 | -|------|------| -| `rpc_cmd` | 命令标识:"execute_model" | -| `rpc_input.compressed_data` | 压缩的模型目录 (tar.gz) | -| `output_file_name` | random_seed (字符串形式) | -| `stdout` | 序列化的输出张量列表 (JSON) | -| `ret_code` | 0=成功, 非0=失败 | - ---- - -## 四、客户端实现 - -### 4.1 SampleRemoteExecutor 类 - -**文件**: `graph_net/graph_net_bench/sample_remote_executor.py` - -**状态**: ✅ 已实现 - -```python -""" -SampleRemoteExecutor: 远程模型执行器 - -使用方式: - import graph_net_bench as gnb - - # 创建执行器 - sample_remote_executor = gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052) - - # 直接调用,返回 tuple[Tensor, ...] - ret: tuple[torch.Tensor, ...] = sample_remote_executor(sample_model_path, random_seed=42) - - # 支持上下文管理器 - with sample_remote_executor: - outputs = sample_remote_executor(model_path, random_seed=1024) -""" - -import grpc -import tarfile -import json -from pathlib import Path -from io import BytesIO -from typing import Tuple, Optional -from contextlib import contextmanager - -import torch - - -class SampleRemoteExecutor: - """远程模型执行器 - - 通过 gRPC 在远程服务器上执行模型推理。 - - Attributes: - machine: 服务器 IP 地址 - port: 服务器端口 - channel: gRPC 通道 - stub: gRPC 存根 - """ - - def __init__(self, machine: str, port: int): - """ - Args: - machine: 服务器 IP 地址 - port: 服务器端口 - """ - self.machine = machine - self.port = port - self._channel: Optional[grpc.Channel] = None - self._stub = None - - def _get_stub(self): - """获取 gRPC 存根(延迟初始化)""" - if self._stub is None: - from .grpc import message_pb2, message_pb2_grpc - self._channel = grpc.insecure_channel(f"{self.machine}:{self.port}") - self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel) - return self._stub - - @contextmanager - def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...]: - """ - 远程执行模型 - - Args: - model_path: 模型目录路径 - random_seed: 随机种子,用于生成可复现的输入 - - Returns: - tuple[Tensor, ...]: 模型输出张量 - - Raises: - RuntimeError: 远程执行失败 - """ - # 1. 压缩模型目录 - compressed_data = self._compress_model(model_path) - - # 2. 构造 RPC 请求 - from .grpc import message_pb2 - - stub = self._get_stub() - request = message_pb2.ExecutionRequest( - rpc_cmd="execute_model", - rpc_input=message_pb2.RpcData( - compressed_data=compressed_data - ), - # 使用 output_file_name 字段传递 random_seed - output_file_name=str(random_seed) - ) - - # 3. 发送请求 - reply = stub.Execute(request) - - # 4. 
解析结果 - if reply.ret_code != 0: - raise RuntimeError(f"Remote execution failed: {reply.stderr}") - - # stdout 包含序列化的张量数据 (JSON 格式) - return self._deserialize_tensors(reply.stdout) - - def __enter__(self): - """支持上下文管理器""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """清理资源""" - self.close() - - def close(self): - """关闭 gRPC 通道""" - if self._channel is not None: - self._channel.close() - self._channel = None - self._stub = None - - def _compress_model(self, model_path: str): - """压缩模型目录为 CompressedData - - Args: - model_path: 模型目录路径 - - Returns: - CompressedData: 压缩的模型数据 - """ - from .grpc import message_pb2 - - buffer = BytesIO() - with tarfile.open(fileobj=buffer, mode="w:gz") as tar: - model_dir = Path(model_path) - for item in model_dir.rglob("*"): - if item.is_file(): - arcname = item.relative_to(model_dir) - tar.add(item, arcname=arcname) - - compressed_bytes = buffer.getvalue() - - return message_pb2.CompressedData( - filename=f"{Path(model_path).name}.tar.gz", - original_size=len(compressed_bytes), - payload=compressed_bytes, - compression_algo="gzip" - ) - - def _deserialize_tensors(self, json_str: str) -> Tuple[torch.Tensor, ...]: - """从 JSON 反序列化张量列表 - - Args: - json_str: JSON 格式的张量数据 - - Returns: - tuple[Tensor, ...]: 张量元组 - """ - import numpy as np - - data = json.loads(json_str) - result = [] - - for tensor_data in data: - dtype = getattr(torch, tensor_data["dtype"]) - shape = tuple(tensor_data["shape"]) - # 处理不同数据类型的序列化 - if tensor_data["data"] is None: - np_array = np.zeros(shape, dtype=dtype.__name__) - else: - np_array = np.frombuffer( - tensor_data["data"].encode("latin1"), - dtype=np.dtype(dtype.__name__.replace("torch.", "")) - ) - np_array = np_array.reshape(shape) - result.append(torch.from_numpy(np_array)) - - return tuple(result) -``` - -### 4.2 使用示例 - -```python -# 方式一:直接调用 -import graph_net_bench as gnb - -# 创建执行器 -executor = gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052) - -# 直接调用,返回 tuple[Tensor, ...] -ret: tuple[torch.Tensor, ...] = executor("/path/to/model", random_seed=42) - -# 方式二:使用上下文管理器(推荐,自动关闭连接) -with gnb.SampleRemoteExecutor(machine="192.168.1.100", port=50052) as executor: - outputs = executor("/path/to/model", random_seed=1024) - -# 方式三:在 GraphNet 测试框架中使用 -def test_single_model_remote(args): - import graph_net_bench as gnb - - executor = gnb.SampleRemoteExecutor( - machine=args.remote_machine, - port=args.remote_port - ) - - with executor: - # 远程执行模型 - eager_out = executor(args.model_path, random_seed=1024) - - # 本地进行正确性对比 - compare_correctness(eager_out, compiled_out, args) -``` - ---- - -## 五、服务端实现 (TODO) - -### 5.1 remote_model_server.py - -**文件**: `graph_net/graph_net_bench/server/remote_model_server.py` - -**状态**: ⏳ 待实现 - -```python -import grpc -from concurrent import futures -import tempfile -import shutil -import tarfile -import json -import torch -from pathlib import Path -from io import BytesIO - -import message_pb2 -import message_pb2_grpc - - -class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer): - """远程模型执行服务""" - - def Execute(self, request, context): - """ - 执行模型推理 - - Args: - request: ExecutionRequest - - rpc_cmd: "execute_model" - - rpc_input.compressed_data: 压缩的模型 - - output_file_name: random_seed (字符串) - - Returns: - ExecutionReply - - ret_code: 0=成功 - - stdout: 序列化的输出张量 (JSON) - - stderr: 错误信息 - """ - temp_dir = None - try: - # 1. 
解析参数 - if request.rpc_cmd != "execute_model": - return message_pb2.ExecutionReply( - ret_code=-1, - stderr=f"Unknown rpc_cmd: {request.rpc_cmd}" - ) - - random_seed = int(request.output_file_name) - - # 2. 解压模型到临时目录 - temp_dir = tempfile.mkdtemp(prefix="remote_model_") - model_path = self._decompress_model( - request.rpc_input.compressed_data, - temp_dir - ) - - # 3. 加载模型和权重 - model = self._load_model(model_path) - - # 4. 设置随机种子 - torch.manual_seed(random_seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(random_seed) - - # 5. 执行推理 - model.eval() - with torch.no_grad(): - outputs = model() - - # 6. 序列化输出 - if not isinstance(outputs, tuple): - outputs = (outputs,) - - json_output = self._serialize_tensors(outputs) - - return message_pb2.ExecutionReply( - ret_code=0, - stdout=json_output, - stderr="" - ) - - except Exception as e: - import traceback - return message_pb2.ExecutionReply( - ret_code=-1, - stderr=f"{str(e)}\n{traceback.format_exc()}" - ) - finally: - if temp_dir: - shutil.rmtree(temp_dir, ignore_errors=True) - - def _decompress_model(self, compressed_data, temp_dir): - """解压模型目录""" - buffer = BytesIO(compressed_data.payload) - with tarfile.open(fileobj=buffer, mode="r:gz") as tar: - tar.extractall(path=temp_dir) - return temp_dir - - def _load_model(self, model_path): - """加载模型 - - 模型目录结构: - - model.py: 定义 GraphModule 类 - - weight_meta.py: 权重元数据 (包含权重数据或生成参数) - """ - import sys - import importlib.util - - model_file = Path(model_path) / "model.py" - spec = importlib.util.spec_from_file_location( - "remote_model", - str(model_file) - ) - module = importlib.util.module_from_spec(spec) - sys.modules["remote_model"] = module - spec.loader.exec_module(module) - - # 动态创建权重并加载到模型 - weight_tensors = self._create_weight_tensors(model_path, module.GraphModule) - - model = module.GraphModule() - - # 加载权重 - # GraphModule 的参数名与 weight_meta 中的 name 匹配 - for name, tensor in weight_tensors.items(): - param = getattr(model, name, None) - if param is not None: - param.data.copy_(tensor) - - return model - - def _create_weight_tensors(self, model_path, graph_module_class): - """根据 weight_meta.py 创建权重张量 - - 从 weight_meta.py 中读取元数据,动态生成张量数据。 - 如果元数据中有 data,则使用实际数据;否则生成随机张量。 - """ - import importlib.util - import numpy as np - - weight_meta_file = Path(model_path) / "weight_meta.py" - if not weight_meta_file.exists(): - return {} - - # 导入 weight_meta 模块 - spec = importlib.util.spec_from_file_location( - "weight_meta", - str(weight_meta_file) - ) - weight_meta_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(weight_meta_module) - - weight_tensors = {} - tensor_classes = [ - getattr(weight_meta_module, name) - for name in dir(weight_meta_module) - if name.startswith("Program_weight_tensor_meta_") - ] - - for tensor_cls in tensor_classes: - name = tensor_cls.name - shape = tensor_cls.shape - dtype = getattr(torch, tensor_cls.dtype) - device = tensor_cls.device if hasattr(tensor_cls, 'device') else 'cpu' - - if tensor_cls.data is not None: - # 使用实际数据 - np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype.__name__)) - np_array = np_array.reshape(shape) - weight_tensors[name] = torch.from_numpy(np_array).to(device) - else: - # 生成随机张量 - weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device) - - return weight_tensors - - def _serialize_tensors(self, outputs): - """序列化张量列表为 JSON""" - tensor_list = [] - - for tensor in outputs: - tensor_data = { - "dtype": str(tensor.dtype), - "shape": list(tensor.shape), - "data": 
tensor.numpy().tobytes().decode("latin1") - } - tensor_list.append(tensor_data) - - return json.dumps(tensor_list) - - -def serve(port=50052, max_workers=4): - """启动 gRPC 服务器""" - server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers)) - message_pb2_grpc.add_SampleRemoteExecutorServicer_to_server( - RemoteModelExecutorServicer(), server - ) - server.add_insecure_port(f"0.0.0.0:{port}") - print(f"Server started on port {port}...") - server.start() - server.wait_for_termination() - - -if __name__ == "__main__": - serve() -``` - ---- - -## 六、目录结构 - -``` -graph_net/graph_net_bench/ -├── grpc/ -│ ├── message.proto # 保持不变 (协议定义) -│ ├── message_pb2.py # protobuf 自动生成 -│ ├── message_pb2_grpc.py # gRPC 存根自动生成 -│ ├── server.py # 保留 (简单 echo 测试) -│ └── client.py # 保留 (简单测试) -├── server/ -│ ├── __init__.py # 新增: 包初始化 -│ └── remote_model_server.py # ⏳ 远程模型服务 (待实现) -├── sample_remote_executor.py # ✅ 客户端实现 (已完成) -├── __init__.py # 新增: 包初始化,导出 SampleRemoteExecutor -└── DESIGN.md # 本设计文档 -``` - ---- - -## 七、关键实现细节 - -### 7.1 数据序列化方案 - -| 数据 | 序列化方式 | 协议字段 | -|------|------------|----------| -| 模型目录 | tar.gz → bytes | `rpc_input.compressed_data` | -| random_seed | 字符串 → `str()` | `output_file_name` | -| 输出张量 | JSON → stdout | `stdout` | - -### 7.2 模型目录结构 - -GraphNet 的 sample 目录结构: -``` -model_path/ -├── model.py # 定义 GraphModule 类 -├── weight_meta.py # 权重元数据 (包含形状、数据) -└── input_tensor_constraints.py # 可选: 输入约束 -``` - -**weight_meta.py 示例**: -```python -class Program_weight_tensor_meta_L_self_modules_classifier_parameters_bias_: - name = "L_self_modules_classifier_parameters_bias_" - shape = [2] - dtype = "torch.float32" - device = "cuda:0" - data = [0.0, 0.0] # 可为 None (表示随机) -``` - -### 7.3 限制与约束 - -| 约束 | 说明 | -|------|------| -| 模型大小 | 受 gRPC 消息大小限制 (默认 4MB) | -| 张量大小 | JSON 序列化效率较低,大张量需优化 | -| 并发 | 服务端 max_workers 控制并发数 | -| 依赖 | 服务端需要 torch, numpy, protobuf | - -### 7.4 未来优化 - -| 优化项 | 方案 | -|--------|------| -| 大模型传输 | 分块传输,使用 RpcData 流式处理 | -| 张量序列化 | 使用 protobuf bytes 替代 JSON | -| 结果缓存 | 缓存模型加载结果 | -| GPU 分配 | 支持指定设备 (cuda:0, cpu 等) | - ---- - -## 八、命令行使用 - -### 服务端 (远程机器) - -```bash -# 安装依赖 -pip install torch numpy grpcio grpcio-tools - -# 启动远程模型服务器 -cd /denghaodong/code/GraphNet/graph_net/graph_net_bench/server -python remote_model_server.py --port 50052 -``` - -### 客户端 (本地机器) - -```bash -# 安装依赖 -pip install torch numpy grpcio - -# 使用示例 -python -c " -import graph_net_bench as gnb - -# 创建执行器 -executor = gnb.SampleRemoteExecutor(machine='192.168.1.100', port=50052) - -# 远程执行模型 -outputs = executor('/path/to/sample/model', random_seed=42) -print(f'Received {len(outputs)} output tensors') - -# 关闭连接 -executor.close() -" -``` - ---- - -## 九、与 GraphNet 集成示例 - -### 9.1 独立使用 - -```python -import graph_net_bench as gnb -from pathlib import Path - -# 配置 -SERVER_IP = "192.168.1.100" -SERVER_PORT = 50052 -MODEL_PATH = "/path/to/transformers-auto-model/model_name" - -# 创建执行器 -executor = gnb.SampleRemoteExecutor(machine=SERVER_IP, port=SERVER_PORT) - -try: - # 执行远程推理 - outputs = executor(MODEL_PATH, random_seed=42) - print(f"Success: {len(outputs)} outputs") -finally: - executor.close() -``` - -### 9.2 在 test_compiler.py 中集成 - -```python -# graph_net/torch/test_compiler.py - -def test_single_model_remote(args): - """使用远程服务器测试单个模型""" - import graph_net_bench as gnb - - executor = gnb.SampleRemoteExecutor( - machine=args.remote_machine, - port=args.remote_port - ) - - with executor: - # 1. 远程执行 eager 模式 (基线) - eager_out = executor(args.model_path, random_seed=args.seed) - - # 2. 
编译模型 (本地或远程) - compiled_out = compile_and_execute(args) - - # 3. 本地进行正确性对比 - compare_correctness(eager_out, compiled_out, args) - - return True -``` - ---- - -## 十、故障排除 - -| 问题 | 可能原因 | 解决方案 | -|------|----------|----------| -| 连接超时 | 网络不通/端口未开放 | 检查防火墙和服务器状态 | -| 内存不足 | 模型过大 | 使用更小的 batch size | -| 序列化失败 | 张量包含复杂类型 | 简化输出或使用分块传输 | -| 权重加载失败 | 参数名不匹配 | 检查 model.py 和 weight_meta.py | - ---- - -## 十一、参考资源 - -- [gRPC Python 文档](https://grpc.io/docs/languages/python/) -- [GraphNet 项目](https://github.com/PaddlePaddle/GraphNet) -- [PyTorch TorchScript](https://pytorch.org/docs/stable/jit.html) \ No newline at end of file From 398f78c3f430f6ea5e774f7c79ced634242a07af Mon Sep 17 00:00:00 2001 From: Denghaodong Date: Fri, 9 Jan 2026 08:08:06 +0000 Subject: [PATCH 05/11] refactor: reorganize code structure --- graph_net/graph_net_bench/__init__.py | 0 graph_net/graph_net_bench/_init_.py | 3 --- graph_net/graph_net_bench/{grpc => }/client.py | 11 ++--------- graph_net/graph_net_bench/grpc/__init__.py | 1 + graph_net/graph_net_bench/grpc/message_pb2_grpc.py | 2 +- graph_net/graph_net_bench/sample_remote_executor.py | 8 ++------ graph_net/graph_net_bench/{grpc => }/server.py | 6 +++--- 7 files changed, 9 insertions(+), 22 deletions(-) create mode 100644 graph_net/graph_net_bench/__init__.py delete mode 100644 graph_net/graph_net_bench/_init_.py rename graph_net/graph_net_bench/{grpc => }/client.py (87%) create mode 100644 graph_net/graph_net_bench/grpc/__init__.py rename graph_net/graph_net_bench/{grpc => }/server.py (96%) diff --git a/graph_net/graph_net_bench/__init__.py b/graph_net/graph_net_bench/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graph_net/graph_net_bench/_init_.py b/graph_net/graph_net_bench/_init_.py deleted file mode 100644 index 95b260bf5..000000000 --- a/graph_net/graph_net_bench/_init_.py +++ /dev/null @@ -1,3 +0,0 @@ -from .sample_remote_executor import SampleRemoteExecutor - -__all__ = ["SampleRemoteExecutor"] \ No newline at end of file diff --git a/graph_net/graph_net_bench/grpc/client.py b/graph_net/graph_net_bench/client.py similarity index 87% rename from graph_net/graph_net_bench/grpc/client.py rename to graph_net/graph_net_bench/client.py index c2c3290b7..9ab4af7f4 100644 --- a/graph_net/graph_net_bench/grpc/client.py +++ b/graph_net/graph_net_bench/client.py @@ -2,20 +2,13 @@ """gRPC Client CLI for SampleRemoteExecutor. 
Usage: - python -m graph_net.graph_net_bench.grpc.client --help + python -m graph_net.graph_net_bench.client --help """ import argparse import sys -try: - from ..sample_remote_executor import SampleRemoteExecutor -except ImportError: - import sys - from pathlib import Path - # Add graph_net_bench directory to path for direct execution - sys.path.insert(0, str(Path(__file__).parent.parent)) - from sample_remote_executor import SampleRemoteExecutor +from graph_net.graph_net_bench.sample_remote_executor import SampleRemoteExecutor def main(): diff --git a/graph_net/graph_net_bench/grpc/__init__.py b/graph_net/graph_net_bench/grpc/__init__.py new file mode 100644 index 000000000..946232e40 --- /dev/null +++ b/graph_net/graph_net_bench/grpc/__init__.py @@ -0,0 +1 @@ +# gRPC generated code package \ No newline at end of file diff --git a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py index 82582553d..3de844294 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py +++ b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py @@ -3,7 +3,7 @@ import grpc import warnings -import message_pb2 as message__pb2 +from . import message_pb2 as message__pb2 GRPC_GENERATED_VERSION = '1.76.0' GRPC_VERSION = grpc.__version__ diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py index 282a0fc52..3d0c47340 100644 --- a/graph_net/graph_net_bench/sample_remote_executor.py +++ b/graph_net/graph_net_bench/sample_remote_executor.py @@ -6,12 +6,8 @@ import torch -try: - from .grpc import message_pb2 - from .grpc import message_pb2_grpc -except ImportError: - import message_pb2 - import message_pb2_grpc +from graph_net.graph_net_bench.grpc import message_pb2 +from graph_net.graph_net_bench.grpc import message_pb2_grpc class SampleRemoteExecutor: diff --git a/graph_net/graph_net_bench/grpc/server.py b/graph_net/graph_net_bench/server.py similarity index 96% rename from graph_net/graph_net_bench/grpc/server.py rename to graph_net/graph_net_bench/server.py index 5dc3710b5..c18142dc4 100644 --- a/graph_net/graph_net_bench/grpc/server.py +++ b/graph_net/graph_net_bench/server.py @@ -10,8 +10,8 @@ import grpc -import message_pb2 -import message_pb2_grpc +from graph_net.graph_net_bench.grpc import message_pb2 +from graph_net.graph_net_bench.grpc import message_pb2_grpc class RemoteModelExecutorServicer(message_pb2_grpc.SampleRemoteExecutorServicer): @@ -47,7 +47,7 @@ def Execute(self, request, context): env["OUTPUT_FILE_PATH"] = str(out_path) env["RANDOM_SEED"] = str(request.random_seed) # Add grpc directory to PYTHONPATH so message_pb2 can be imported - grpc_dir = Path(__file__).parent.resolve() + grpc_dir = Path(__file__).parent / "grpc" env["PYTHONPATH"] = f"{grpc_dir}:{env.get('PYTHONPATH', '')}" print(f"Executing rpc_cmd: {request.rpc_cmd}", file=sys.stderr) From 929709e098241e0b3db7ebc36eeb069ae9e4094b Mon Sep 17 00:00:00 2001 From: Denghaodong Date: Sun, 11 Jan 2026 11:43:52 +0000 Subject: [PATCH 06/11] change rpc_cmd to test_reference_device.py --- graph_net/graph_net_bench/client.py | 2 +- graph_net/graph_net_bench/sample_remote_executor.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/graph_net/graph_net_bench/client.py b/graph_net/graph_net_bench/client.py index 9ab4af7f4..c39e8ecf3 100644 --- a/graph_net/graph_net_bench/client.py +++ b/graph_net/graph_net_bench/client.py @@ -44,7 +44,7 @@ def main(): parser.add_argument( "--rpc-cmd", type=str, - default="python3 
/denghaodong/code/GraphNet/graph_net/graph_net_bench/sample_rpc_cmd.py", + default="python3 -m graph_net.torch.test_reference_device", help="Command to execute on remote server", ) parser.add_argument( diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py index 3d0c47340..1d6b8c741 100644 --- a/graph_net/graph_net_bench/sample_remote_executor.py +++ b/graph_net/graph_net_bench/sample_remote_executor.py @@ -12,7 +12,7 @@ class SampleRemoteExecutor: - def __init__(self, machine: str, port: int, rpc_cmd: str = "python3 /denghaodong/code/GraphNet/graph_net/graph_net_bench/sample_rpc_cmd.py"): + def __init__(self, machine: str, port: int, rpc_cmd: str = "python3 -m graph_net.torch.test_reference_device"): self.machine = machine self.port = port self.rpc_cmd = rpc_cmd @@ -29,7 +29,6 @@ def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ... compressed_data = self._compress_dir(model_path) - # 输出文件名必须包含扩展名(mentor 约定) output_file_name = f"outputs_seed_{random_seed}.npz" request = message_pb2.ExecutionRequest( @@ -79,11 +78,9 @@ def _npz_bytes_to_tensors(self, npz_bytes: bytes) -> Tuple[torch.Tensor, ...]: import numpy as np with np.load(BytesIO(npz_bytes), allow_pickle=False) as npz: - # 只接受 output_{i} 格式,按 i 排序,保证返回 tuple 稳定 keys = [k for k in npz.files if k.startswith("output_")] def key_index(k: str) -> int: - # "output_0" -> 0 return int(k.split("_", 1)[1]) keys.sort(key=key_index) From f0e01795ae6544b2dc85b5a318990f129a317e1b Mon Sep 17 00:00:00 2001 From: Denghaodong Date: Mon, 12 Jan 2026 05:55:24 +0000 Subject: [PATCH 07/11] reconstruct SampleRemoteExecutor --- graph_net/graph_net_bench/grpc/message.proto | 5 +- graph_net/graph_net_bench/grpc/message_pb2.py | 18 +- .../graph_net_bench/grpc/message_pb2_grpc.py | 2 +- .../graph_net_bench/sample_remote_executor.py | 32 ++-- graph_net/graph_net_bench/server.py | 30 ++-- ...raph_decompose_and_evaluation_step_test.sh | 22 +++ .../torch/sample_pass/subgraph_generator.py | 2 +- .../torch/test_remote_reference_device.py | 155 ++++++++++++++++++ 8 files changed, 218 insertions(+), 48 deletions(-) create mode 100644 graph_net/torch/test_remote_reference_device.py diff --git a/graph_net/graph_net_bench/grpc/message.proto b/graph_net/graph_net_bench/grpc/message.proto index 411f93b50..1460f1ab6 100644 --- a/graph_net/graph_net_bench/grpc/message.proto +++ b/graph_net/graph_net_bench/grpc/message.proto @@ -11,9 +11,8 @@ message CompressedData { message RpcData { oneof rpc_data_type { - CompressedData compressed_data = 1; - string str_data = 2; - bytes npz_data = 3; + CompressedData compressed_data = 1; // For input (single tar.gz) + string str_data = 3; } } diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py index 7d49fa7d1..b2c47df3b 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2.py +++ b/graph_net/graph_net_bench/grpc/message_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"\x85\x01\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x12\x12\n\x08npz_data\x18\x03 
\x01(\x0cH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 \x01(\x03\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x03 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 \x01(\x03\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -33,12 +33,12 @@ DESCRIPTOR._loaded_options = None _globals['_COMPRESSEDDATA']._serialized_start=41 _globals['_COMPRESSEDDATA']._serialized_end=141 - _globals['_RPCDATA']._serialized_start=144 - _globals['_RPCDATA']._serialized_end=277 - _globals['_EXECUTIONREQUEST']._serialized_start=280 - _globals['_EXECUTIONREQUEST']._serialized_end=440 - _globals['_EXECUTIONREPLY']._serialized_start=442 - _globals['_EXECUTIONREPLY']._serialized_end=561 - _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=563 - _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=678 + _globals['_RPCDATA']._serialized_start=143 + _globals['_RPCDATA']._serialized_end=256 + _globals['_EXECUTIONREQUEST']._serialized_start=259 + _globals['_EXECUTIONREQUEST']._serialized_end=419 + _globals['_EXECUTIONREPLY']._serialized_start=421 + _globals['_EXECUTIONREPLY']._serialized_end=540 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_start=542 + _globals['_SAMPLEREMOTEEXECUTOR']._serialized_end=657 # @@protoc_insertion_point(module_scope) diff --git a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py index 3de844294..82582553d 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py +++ b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py @@ -3,7 +3,7 @@ import grpc import warnings -from . 
import message_pb2 as message__pb2 +import message_pb2 as message__pb2 GRPC_GENERATED_VERSION = '1.76.0' GRPC_VERSION = grpc.__version__ diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py index 1d6b8c741..11a87a29e 100644 --- a/graph_net/graph_net_bench/sample_remote_executor.py +++ b/graph_net/graph_net_bench/sample_remote_executor.py @@ -25,7 +25,7 @@ def _get_stub(self): self._stub = message_pb2_grpc.SampleRemoteExecutorStub(self._channel) return self._stub - def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ...]: + def __call__(self, model_path: str, random_seed: int) -> dict: compressed_data = self._compress_dir(model_path) @@ -52,8 +52,8 @@ def __call__(self, model_path: str, random_seed: int) -> Tuple[torch.Tensor, ... if reply.rpc_output.WhichOneof("rpc_data_type") != "compressed_data": raise RuntimeError("Remote execution succeeded but rpc_output is not compressed_data") - npz_bytes = reply.rpc_output.compressed_data.payload - return self._npz_bytes_to_tensors(npz_bytes) + # 解压返回的 tar.gz 文件 + return self._extract_tar_to_dict(reply.rpc_output.compressed_data) def _compress_dir(self, model_path: str): buffer = BytesIO() @@ -74,22 +74,16 @@ def _compress_dir(self, model_path: str): compression_algo="gzip", ) - def _npz_bytes_to_tensors(self, npz_bytes: bytes) -> Tuple[torch.Tensor, ...]: - import numpy as np - - with np.load(BytesIO(npz_bytes), allow_pickle=False) as npz: - keys = [k for k in npz.files if k.startswith("output_")] - - def key_index(k: str) -> int: - return int(k.split("_", 1)[1]) - - keys.sort(key=key_index) - - tensors = [] - for k in keys: - arr = npz[k] - tensors.append(torch.from_numpy(arr)) # 默认 CPU - return tuple(tensors) + def _extract_tar_to_dict(self, compressed_data: message_pb2.CompressedData) -> dict: + """Extract tar.gz to {filename: bytes} dict.""" + buffer = BytesIO(compressed_data.payload) + files_dict = {} + with tarfile.open(fileobj=buffer, mode="r:gz") as tar: + for member in tar.getmembers(): + if member.isfile(): + file_content = tar.extractfile(member).read() + files_dict[member.name] = file_content + return files_dict def close(self): if self._channel is not None: diff --git a/graph_net/graph_net_bench/server.py b/graph_net/graph_net_bench/server.py index c18142dc4..24c5445b1 100644 --- a/graph_net/graph_net_bench/server.py +++ b/graph_net/graph_net_bench/server.py @@ -21,19 +21,13 @@ def Execute(self, request, context): output_workspace = tempfile.mkdtemp(prefix="remote_output_") try: - # 0) 基本校验 + # 0) 基本校验 if not request.rpc_cmd: return message_pb2.ExecutionReply(ret_code=-1, stderr="rpc_cmd is empty") if not request.HasField("output_file_name") or not request.output_file_name: return message_pb2.ExecutionReply(ret_code=-1, stderr="output_file_name is required") - if request.rpc_input.WhichOneof("rpc_data_type") != "compressed_data": - return message_pb2.ExecutionReply( - ret_code=-1, - stderr="rpc_input must be RpcData.compressed_data (tar.gz bytes)", - ) - # 1) 解压输入到 input_workspace self._decompress_to_dir(request.rpc_input.compressed_data, input_workspace) @@ -72,17 +66,23 @@ def Execute(self, request, context): stderr=proc.stderr or f"rpc_cmd failed with returncode={proc.returncode}", ) - # 3) 回读输出文件 - if not out_path.exists(): - print(f"Output file not found at {out_path}", file=sys.stderr) - print(f"Contents of output_workspace: {list(Path(output_workspace).rglob('*'))}", file=sys.stderr) + # 3) 将所有文件打包成 tar.gz + output_tar_path = 
Path(output_workspace) / "output.tar.gz" + with tarfile.open(output_tar_path, "w:gz") as tar: + for file_path in Path(output_workspace).rglob("*"): + if file_path.is_file() and file_path != output_tar_path: + arcname = file_path.relative_to(output_workspace) + tar.add(file_path, arcname=arcname) + + if not output_tar_path.exists(): + print(f"No output files found in {output_workspace}", file=sys.stderr) return message_pb2.ExecutionReply( ret_code=-1, stdout=proc.stdout or "", - stderr=(proc.stderr or "") + f"\nExpected output not found: {out_path}", + stderr=(proc.stderr or "") + f"\nNo output files found in {output_workspace}", ) - payload = out_path.read_bytes() + payload = output_tar_path.read_bytes() return message_pb2.ExecutionReply( ret_code=0, @@ -90,10 +90,10 @@ def Execute(self, request, context): stderr=proc.stderr or "", rpc_output=message_pb2.RpcData( compressed_data=message_pb2.CompressedData( - filename=request.output_file_name, + filename=output_tar_path.name, original_size=len(payload), payload=payload, - compression_algo="raw", + compression_algo="gzip", ) ), ) diff --git a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh index 35dd5f216..ff7cf293e 100755 --- a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh +++ b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh @@ -49,6 +49,26 @@ test_target_device_config_str=$(cat < Date: Mon, 12 Jan 2026 06:14:39 +0000 Subject: [PATCH 08/11] remove unrelated changes --- graph_net/graph_net_bench/{client.py => client_demo.py} | 0 graph_net/graph_net_bench/grpc/message.proto | 2 +- graph_net/graph_net_bench/grpc/message_pb2.py | 2 +- graph_net/torch/sample_pass/subgraph_generator.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename graph_net/graph_net_bench/{client.py => client_demo.py} (100%) diff --git a/graph_net/graph_net_bench/client.py b/graph_net/graph_net_bench/client_demo.py similarity index 100% rename from graph_net/graph_net_bench/client.py rename to graph_net/graph_net_bench/client_demo.py diff --git a/graph_net/graph_net_bench/grpc/message.proto b/graph_net/graph_net_bench/grpc/message.proto index 1460f1ab6..0cb614f1b 100644 --- a/graph_net/graph_net_bench/grpc/message.proto +++ b/graph_net/graph_net_bench/grpc/message.proto @@ -12,7 +12,7 @@ message CompressedData { message RpcData { oneof rpc_data_type { CompressedData compressed_data = 1; // For input (single tar.gz) - string str_data = 3; + string str_data = 2; } } diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py index b2c47df3b..f672481c8 100644 --- a/graph_net/graph_net_bench/grpc/message_pb2.py +++ b/graph_net/graph_net_bench/grpc/message_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x03 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 
diff --git a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
index 35dd5f216..ff7cf293e 100755
--- a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
+++ b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
@@ -49,6 +49,26 @@ test_target_device_config_str=$(cat <

From: Denghaodong
Date: Mon, 12 Jan 2026 06:14:39 +0000
Subject: [PATCH 08/11] remove unrelated changes

---
 graph_net/graph_net_bench/{client.py => client_demo.py} | 0
 graph_net/graph_net_bench/grpc/message.proto            | 2 +-
 graph_net/graph_net_bench/grpc/message_pb2.py           | 2 +-
 graph_net/torch/sample_pass/subgraph_generator.py       | 2 +-
 4 files changed, 3 insertions(+), 3 deletions(-)
 rename graph_net/graph_net_bench/{client.py => client_demo.py} (100%)

diff --git a/graph_net/graph_net_bench/client.py b/graph_net/graph_net_bench/client_demo.py
similarity index 100%
rename from graph_net/graph_net_bench/client.py
rename to graph_net/graph_net_bench/client_demo.py
diff --git a/graph_net/graph_net_bench/grpc/message.proto b/graph_net/graph_net_bench/grpc/message.proto
index 1460f1ab6..0cb614f1b 100644
--- a/graph_net/graph_net_bench/grpc/message.proto
+++ b/graph_net/graph_net_bench/grpc/message.proto
@@ -12,7 +12,7 @@ message CompressedData {
 message RpcData {
   oneof rpc_data_type {
     CompressedData compressed_data = 1; // For input (single tar.gz)
-    string str_data = 3;
+    string str_data = 2;
   }
 }
diff --git a/graph_net/graph_net_bench/grpc/message_pb2.py b/graph_net/graph_net_bench/grpc/message_pb2.py
index b2c47df3b..f672481c8 100644
--- a/graph_net/graph_net_bench/grpc/message_pb2.py
+++ b/graph_net/graph_net_bench/grpc/message_pb2.py
@@ -24,7 +24,7 @@

-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x03 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 \x01(\x03\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rmessage.proto\x12\x16sample_remote_executor\"d\n\x0e\x43ompressedData\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x15\n\roriginal_size\x18\x02 \x01(\r\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12\x18\n\x10\x63ompression_algo\x18\x04 \x01(\t\"q\n\x07RpcData\x12\x41\n\x0f\x63ompressed_data\x18\x01 \x01(\x0b\x32&.sample_remote_executor.CompressedDataH\x00\x12\x12\n\x08str_data\x18\x02 \x01(\tH\x00\x42\x0f\n\rrpc_data_type\"\xa0\x01\n\x10\x45xecutionRequest\x12\x0f\n\x07rpc_cmd\x18\x01 \x01(\t\x12\x32\n\trpc_input\x18\x02 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData\x12\x1d\n\x10output_file_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x0brandom_seed\x18\x04 \x01(\x03\x42\x13\n\x11_output_file_name\"w\n\x0e\x45xecutionReply\x12\x10\n\x08ret_code\x18\x01 \x01(\x03\x12\x0e\n\x06stdout\x18\x02 \x01(\t\x12\x0e\n\x06stderr\x18\x03 \x01(\t\x12\x33\n\nrpc_output\x18\x04 \x01(\x0b\x32\x1f.sample_remote_executor.RpcData2s\n\x14SampleRemoteExecutor\x12[\n\x07\x45xecute\x12(.sample_remote_executor.ExecutionRequest\x1a&.sample_remote_executor.ExecutionReplyb\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
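
The `message_pb2.py` change above is mechanical: once `str_data` moves from field number 3 to 2, the serialized descriptor has to be regenerated rather than hand-edited. Note also that a protobuf field number is its wire tag, so this renumbering is only safe because client and server in this demo ship from the same tree. A sketch of the regeneration step using grpcio-tools (assumed installed; the working directory is illustrative):

```python
# Regenerate message_pb2.py and message_pb2_grpc.py after editing message.proto.
# Assumes `pip install grpcio-tools`; run from graph_net/graph_net_bench/grpc/.
from grpc_tools import protoc

protoc.main([
    "protoc",               # argv[0] placeholder, ignored by protoc
    "-I.",                  # proto search path
    "--python_out=.",       # emits message_pb2.py
    "--grpc_python_out=.",  # emits message_pb2_grpc.py
    "message.proto",
])
```
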
diff --git a/graph_net/torch/sample_pass/subgraph_generator.py b/graph_net/torch/sample_pass/subgraph_generator.py
index b47bf257f..4205d16b5 100644
--- a/graph_net/torch/sample_pass/subgraph_generator.py
+++ b/graph_net/torch/sample_pass/subgraph_generator.py
@@ -168,7 +168,7 @@ def forward(self, *args):
         self.extracted = True
         return self.submodule(*args)

-    def _subgra_subgraph_sources(self):
+    def _save_subgraph_sources(self):
         sources_json_obj = self._get_sources_json_obj()
         model_path = self._get_model_path()
         model_path.mkdir(parents=True, exist_ok=True)

From 74c26dd2cd6065ca1b6cdea6b8f6c4363a716a60 Mon Sep 17 00:00:00 2001
From: Denghaodong
Date: Mon, 12 Jan 2026 06:26:20 +0000
Subject: [PATCH 09/11] add test config

---
 graph_net/subgraph_decompose_and_evaluation_step.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py
index 2c1c70a6e..2c328c1a6 100755
--- a/graph_net/subgraph_decompose_and_evaluation_step.py
+++ b/graph_net/subgraph_decompose_and_evaluation_step.py
@@ -92,6 +92,7 @@ def _init_task_scheduler(self, test_module_name):
         assert test_module_name in [
             "test_compiler",
             "test_reference_device",
+            "test_remote_reference_device",
             "test_target_device",
         ]
         if test_module_name == "test_compiler":
@@ -106,6 +107,12 @@ def _init_task_scheduler(self, test_module_name):
                 "run_evaluation": True,
                 "post_analysis": False,
             }
+        elif test_module_name == "test_remote_reference_device":
+            self.task_scheduler = {
+                "run_decomposer": False,
+                "run_evaluation": True,
+                "post_analysis": False,
+            }
         elif test_module_name == "test_target_device":
             self.task_scheduler = {
                 "run_decomposer": False,

From 62b90ef4b9b8a555d56795b30cac0d00e7f400e7 Mon Sep 17 00:00:00 2001
From: Denghaodong
Date: Mon, 12 Jan 2026 07:24:29 +0000
Subject: [PATCH 10/11] executor

---
 .../graph_net_bench/grpc/message_pb2_grpc.py  |   4 +-
 .../graph_net_bench/sample_remote_executor.py |   2 +-
 graph_net/graph_net_bench/sample_rpc_cmd.py   | 148 ------------------
 .../subgraph_decompose_and_evaluation_step.py |   2 +-
 ...raph_decompose_and_evaluation_step_test.sh |   2 +-
 5 files changed, 4 insertions(+), 154 deletions(-)
 delete mode 100644 graph_net/graph_net_bench/sample_rpc_cmd.py

diff --git a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py
index 82582553d..97aa80ebb 100644
--- a/graph_net/graph_net_bench/grpc/message_pb2_grpc.py
+++ b/graph_net/graph_net_bench/grpc/message_pb2_grpc.py
@@ -1,9 +1,7 @@
 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
 import grpc
 import warnings
-
-import message_pb2 as message__pb2
+from . import message_pb2 as message__pb2

 GRPC_GENERATED_VERSION = '1.76.0'
 GRPC_VERSION = grpc.__version__
diff --git a/graph_net/graph_net_bench/sample_remote_executor.py b/graph_net/graph_net_bench/sample_remote_executor.py
index 11a87a29e..b327b8c6e 100644
--- a/graph_net/graph_net_bench/sample_remote_executor.py
+++ b/graph_net/graph_net_bench/sample_remote_executor.py
@@ -29,7 +29,7 @@ def __call__(self, model_path: str, random_seed: int) -> dict:

         compressed_data = self._compress_dir(model_path)

-        output_file_name = f"outputs_seed_{random_seed}.npz"
+        output_file_name = f"outputs_seed_{random_seed}"

         request = message_pb2.ExecutionRequest(
             rpc_cmd=self.rpc_cmd,
diff --git a/graph_net/graph_net_bench/sample_rpc_cmd.py b/graph_net/graph_net_bench/sample_rpc_cmd.py
deleted file mode 100644
index c293a6b1a..000000000
--- a/graph_net/graph_net_bench/sample_rpc_cmd.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-import sys
-import importlib.util
-from pathlib import Path
-
-import torch
-import numpy as np
-
-
-def load_model_and_weights(model_path: str):
-    model_file = Path(model_path) / "model.py"
-    if not model_file.exists():
-        raise FileNotFoundError(f"model.py not found in {model_path}")
-
-    spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file))
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-
-    if not hasattr(module, 'GraphModule'):
-        raise ValueError("model.py must define 'GraphModule' class")
-
-    weight_tensors = {}
-    weight_meta_file = Path(model_path) / "weight_meta.py"
-    if weight_meta_file.exists():
-        spec = importlib.util.spec_from_file_location(
-            "weight_meta_module", str(weight_meta_file)
-        )
-        weight_meta_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(weight_meta_module)
-
-        tensor_classes = [
-            getattr(weight_meta_module, name)
-            for name in dir(weight_meta_module)
-            if name.startswith("Program_weight_tensor_meta_")
-        ]
-
-        for tensor_cls in tensor_classes:
-            name = tensor_cls.name
-            shape = tensor_cls.shape
-            dtype_name = tensor_cls.dtype.replace("torch.", "")
-            dtype = getattr(torch, dtype_name)
-            device = getattr(tensor_cls, 'device', 'cpu')
-
-            if tensor_cls.data is not None:
-                np_array = np.array(tensor_cls.data, dtype=np.dtype(dtype_name))
-                np_array = np_array.reshape(shape)
-                weight_tensors[name] = torch.from_numpy(np_array).to(device)
-            else:
-                if dtype == torch.bool:
-                    weight_tensors[name] = torch.zeros(shape, dtype=dtype, device=device)
-                else:
-                    weight_tensors[name] = torch.randn(shape, dtype=dtype, device=device)
-
-    model = module.GraphModule()
-    for name, tensor in weight_tensors.items():
-        param = getattr(model, name, None)
-        if param is not None and isinstance(param, torch.Tensor):
-            param.data.copy_(tensor)
-
-    return model, weight_tensors
-
-
-def get_forward_inputs(model_path: str, weight_tensors: dict):
-    import inspect
-
-    model_file = Path(model_path) / "model.py"
-    spec = importlib.util.spec_from_file_location("remote_model_module", str(model_file))
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
-
-    forward_params = inspect.signature(module.GraphModule.forward).parameters
-    param_names = [
-        name for name in forward_params.keys()
-        if name != 'self'
-    ]
-
-    inputs = []
-    for param_name in param_names:
-        if param_name in weight_tensors:
-            inputs.append(weight_tensors[param_name])
-        else:
-            raise ValueError(f"Missing weight tensor for parameter: {param_name}")
-
-    return inputs
-
-
-def save_outputs_as_npz(outputs, output_path: str):
-    if not isinstance(outputs, tuple):
-        outputs = (outputs,)
-
-    np_arrays = {}
-    for i, tensor in enumerate(outputs):
-        key = f"output_{i}"
-        np_arrays[key] = tensor.cpu().numpy()
-
-    np.savez(output_path, **np_arrays)
-
-
-def main():
-    input_workspace = os.environ.get("INPUT_WORKSPACE")
-    output_workspace = os.environ.get("OUTPUT_WORKSPACE")
-    output_file_name = os.environ.get("OUTPUT_FILE_NAME")
-    output_file_path = os.environ.get("OUTPUT_FILE_PATH")
-    seed_str = os.environ.get("RANDOM_SEED")
-
-    if not input_workspace:
-        raise RuntimeError("INPUT_WORKSPACE environment variable not set")
-    if not output_workspace:
-        raise RuntimeError("OUTPUT_WORKSPACE environment variable not set")
-    if not output_file_name:
-        raise RuntimeError("OUTPUT_FILE_NAME environment variable not set")
-
-    print(f"INPUT_WORKSPACE: {input_workspace}", file=sys.stderr)
-    print(f"OUTPUT_WORKSPACE: {output_workspace}", file=sys.stderr)
-    print(f"OUTPUT_FILE_NAME: {output_file_name}", file=sys.stderr)
-    if output_file_path:
-        print(f"OUTPUT_FILE_PATH: {output_file_path}", file=sys.stderr)
-
-    if seed_str:
-        seed = int(seed_str)
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)
-        print(f"RANDOM_SEED: {seed}", file=sys.stderr)
-
-    print("Loading model and weights...", file=sys.stderr)
-    model, weight_tensors = load_model_and_weights(input_workspace)
-    print(f"Model loaded, {len(weight_tensors)} weight tensors", file=sys.stderr)
-
-    print("Preparing inputs...", file=sys.stderr)
-    inputs = get_forward_inputs(input_workspace, weight_tensors)
-    print(f"Prepared {len(inputs)} inputs for forward()", file=sys.stderr)
-
-    print("Running inference...", file=sys.stderr)
-    model.eval()
-    with torch.no_grad():
-        outputs = model(*inputs)
-
-    output_path = output_file_path or os.path.join(output_workspace, output_file_name)
-    print(f"Saving outputs to {output_path}...", file=sys.stderr)
-    save_outputs_as_npz(outputs, output_path)
-    print("Outputs saved successfully!", file=sys.stderr)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
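
With `sample_rpc_cmd.py` deleted, the server no longer ships a model runner: it just spawns whatever `rpc_cmd` the request names and talks to it through environment variables, as the deleted script illustrates. A minimal sketch of the contract a replacement command would have to honor — the variable names come from the deleted script; the body is a placeholder:

```python
#!/usr/bin/env python3
# Minimal rpc_cmd skeleton: read the workspace contract from the environment,
# do the work, and leave results in OUTPUT_WORKSPACE for the server to tar up.
import os
from pathlib import Path

input_workspace = Path(os.environ["INPUT_WORKSPACE"])    # unpacked client tar.gz
output_workspace = Path(os.environ["OUTPUT_WORKSPACE"])  # server tars everything here
output_file_name = os.environ["OUTPUT_FILE_NAME"]        # e.g. "outputs_seed_42"
seed = int(os.environ.get("RANDOM_SEED", "0"))

# ... load input_workspace / "model.py", seed the RNGs, run inference ...

(output_workspace / output_file_name).write_bytes(b"placeholder results")
```
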
diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py
index 2c328c1a6..983f2698f 100755
--- a/graph_net/subgraph_decompose_and_evaluation_step.py
+++ b/graph_net/subgraph_decompose_and_evaluation_step.py
@@ -403,7 +403,7 @@ def run_evaluation(
         test_module_name = test_config["test_module_name"]
         test_module_arguments = test_config[f"{test_module_name}_arguments"]
         test_module_arguments["model-path"] = work_dir
-        if test_module_name in ["test_reference_device", "test_target_device"]:
+        if test_module_name in ["test_reference_device", "test_remote_reference_device", "test_target_device"]:
             test_module_arguments["reference-dir"] = os.path.join(
                 work_dir, "reference_device_outputs"
             )
diff --git a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
index ff7cf293e..58c8fecee 100755
--- a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
+++ b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh
@@ -69,7 +69,7 @@ test_remote_reference_device_config_str=$(cat <
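
For orientation, a sketch of the arguments `run_evaluation()` now assembles for the remote module after the hunk above — the key names are taken from that code, while the `work_dir` value is made up:

```python
import os

# Keys are exactly those set in run_evaluation(); the path is illustrative.
work_dir = "/tmp/decompose_work_dir"
test_module_arguments = {
    "model-path": work_dir,
    "reference-dir": os.path.join(work_dir, "reference_device_outputs"),
}
```
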
"test_target_device"]: + if test_module_name in ["test_reference_device", "test_remote_reference_device", "test_target_device"]: test_module_arguments["reference-dir"] = os.path.join( work_dir, "reference_device_outputs" ) diff --git a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh index ff7cf293e..58c8fecee 100755 --- a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh +++ b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh @@ -69,7 +69,7 @@ test_remote_reference_device_config_str=$(cat < Date: Mon, 12 Jan 2026 08:50:08 +0000 Subject: [PATCH 11/11] change --- .../{client_demo.py => grpc/sample_remote_executor_test.py} | 4 +--- graph_net/graph_net_bench/{ => grpc}/server.py | 0 graph_net/subgraph_decompose_and_evaluation_step.py | 2 +- graph_net/test/subgraph_decompose_and_evaluation_step_test.sh | 4 +++- 4 files changed, 5 insertions(+), 5 deletions(-) rename graph_net/graph_net_bench/{client_demo.py => grpc/sample_remote_executor_test.py} (98%) rename graph_net/graph_net_bench/{ => grpc}/server.py (100%) diff --git a/graph_net/graph_net_bench/client_demo.py b/graph_net/graph_net_bench/grpc/sample_remote_executor_test.py similarity index 98% rename from graph_net/graph_net_bench/client_demo.py rename to graph_net/graph_net_bench/grpc/sample_remote_executor_test.py index c39e8ecf3..ad129ba3d 100644 --- a/graph_net/graph_net_bench/client_demo.py +++ b/graph_net/graph_net_bench/grpc/sample_remote_executor_test.py @@ -10,8 +10,7 @@ from graph_net.graph_net_bench.sample_remote_executor import SampleRemoteExecutor - -def main(): +def main(args): parser = argparse.ArgumentParser( description="gRPC Client for remote model execution", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -54,7 +53,6 @@ def main(): help="Directory to save output tensors (default: current directory)", ) - args = parser.parse_args() executor = SampleRemoteExecutor( machine=args.machine, diff --git a/graph_net/graph_net_bench/server.py b/graph_net/graph_net_bench/grpc/server.py similarity index 100% rename from graph_net/graph_net_bench/server.py rename to graph_net/graph_net_bench/grpc/server.py diff --git a/graph_net/subgraph_decompose_and_evaluation_step.py b/graph_net/subgraph_decompose_and_evaluation_step.py index 983f2698f..feffc043c 100755 --- a/graph_net/subgraph_decompose_and_evaluation_step.py +++ b/graph_net/subgraph_decompose_and_evaluation_step.py @@ -104,7 +104,7 @@ def _init_task_scheduler(self, test_module_name): elif test_module_name == "test_reference_device": self.task_scheduler = { "run_decomposer": True, - "run_evaluation": True, + "run_evaluation": False, "post_analysis": False, } elif test_module_name == "test_remote_reference_device": diff --git a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh index 58c8fecee..8663a2b0d 100755 --- a/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh +++ b/graph_net/test/subgraph_decompose_and_evaluation_step_test.sh @@ -69,7 +69,9 @@ test_remote_reference_device_config_str=$(cat <