Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
1 change: 1 addition & 0 deletions graph_net/graph_net_bench/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# gRPC generated code package
21 changes: 0 additions & 21 deletions graph_net/graph_net_bench/grpc/client.py

This file was deleted.

7 changes: 5 additions & 2 deletions graph_net/graph_net_bench/grpc/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ message CompressedData {

message RpcData {
oneof rpc_data_type {
CompressedData compressed_data = 1;
string str_data = 2;
CompressedData compressed_data = 1; // For input (single tar.gz)
string str_data = 2;
}
}

message ExecutionRequest {
string rpc_cmd = 1;
RpcData rpc_input = 2;
optional string output_file_name = 3;
int64 random_seed = 4;
}

message ExecutionReply {
Expand All @@ -34,3 +35,5 @@ service SampleRemoteExecutor {
}

// python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. message.proto


38 changes: 21 additions & 17 deletions graph_net/graph_net_bench/grpc/message_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

81 changes: 35 additions & 46 deletions graph_net/graph_net_bench/grpc/message_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from . import message_pb2 as message__pb2

import message_pb2 as message__pb2

GRPC_GENERATED_VERSION = "1.76.0"
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
from grpc._utilities import first_version_is_lower

_version_not_supported = first_version_is_lower(
GRPC_VERSION, GRPC_GENERATED_VERSION
)
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True

if _version_not_supported:
raise RuntimeError(
f"The grpc package installed is at version {GRPC_VERSION},"
+ " but the generated code in message_pb2_grpc.py depends on"
+ f" grpcio>={GRPC_GENERATED_VERSION}."
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in message_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)


Expand All @@ -37,11 +33,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Execute = channel.unary_unary(
"/sample_remote_executor.SampleRemoteExecutor/Execute",
request_serializer=message__pb2.ExecutionRequest.SerializeToString,
response_deserializer=message__pb2.ExecutionReply.FromString,
_registered_method=True,
)
'/sample_remote_executor.SampleRemoteExecutor/Execute',
request_serializer=message__pb2.ExecutionRequest.SerializeToString,
response_deserializer=message__pb2.ExecutionReply.FromString,
_registered_method=True)


class SampleRemoteExecutorServicer(object):
Expand All @@ -50,48 +45,43 @@ class SampleRemoteExecutorServicer(object):
def Execute(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_SampleRemoteExecutorServicer_to_server(servicer, server):
rpc_method_handlers = {
"Execute": grpc.unary_unary_rpc_method_handler(
servicer.Execute,
request_deserializer=message__pb2.ExecutionRequest.FromString,
response_serializer=message__pb2.ExecutionReply.SerializeToString,
),
'Execute': grpc.unary_unary_rpc_method_handler(
servicer.Execute,
request_deserializer=message__pb2.ExecutionRequest.FromString,
response_serializer=message__pb2.ExecutionReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"sample_remote_executor.SampleRemoteExecutor", rpc_method_handlers
)
'sample_remote_executor.SampleRemoteExecutor', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers(
"sample_remote_executor.SampleRemoteExecutor", rpc_method_handlers
)
server.add_registered_method_handlers('sample_remote_executor.SampleRemoteExecutor', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class SampleRemoteExecutor(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def Execute(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
def Execute(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
"/sample_remote_executor.SampleRemoteExecutor/Execute",
'/sample_remote_executor.SampleRemoteExecutor/Execute',
message__pb2.ExecutionRequest.SerializeToString,
message__pb2.ExecutionReply.FromString,
options,
Expand All @@ -102,5 +92,4 @@ def Execute(
wait_for_ready,
timeout,
metadata,
_registered_method=True,
)
_registered_method=True)
94 changes: 94 additions & 0 deletions graph_net/graph_net_bench/grpc/sample_remote_executor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
"""gRPC Client CLI for SampleRemoteExecutor.

Usage:
python -m graph_net.graph_net_bench.client --help
"""

import argparse
import sys

from graph_net.graph_net_bench.sample_remote_executor import SampleRemoteExecutor

def main(args):
parser = argparse.ArgumentParser(
description="gRPC Client for remote model execution",
formatter_class=argparse.RawDescriptionHelpFormatter,
)

parser.add_argument(
"--machine",
type=str,
default="localhost",
help="Remote server address (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
default=50052,
help="gRPC server port (default: 50052)",
)
parser.add_argument(
"--model-path",
type=str,
required=True,
help="Path to model directory containing model.py and weight_meta.py",
)
parser.add_argument(
"--random-seed",
type=int,
default=42,
help="Random seed for reproducible inference (default: 42)",
)
parser.add_argument(
"--rpc-cmd",
type=str,
default="python3 -m graph_net.torch.test_reference_device",
help="Command to execute on remote server",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Directory to save output tensors (default: current directory)",
)


executor = SampleRemoteExecutor(
machine=args.machine,
port=args.port,
rpc_cmd=args.rpc_cmd,
)

try:
print(f"Sending request to {args.machine}:{args.port}...", file=sys.stderr)
tensors = executor(args.model_path, args.random_seed)

print(f"Received {len(tensors)} output tensors:", file=sys.stderr)
for i, tensor in enumerate(tensors):
print(f" output_{i}: shape={tensor.shape}, dtype={tensor.dtype}", file=sys.stderr)

if args.output_dir:
import os
from pathlib import Path
import torch

output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)

for i, tensor in enumerate(tensors):
output_file = output_path / f"output_{i}.pt"
torch.save(tensor, output_file)
print(f"Saved output_{i} to {output_file}", file=sys.stderr)

print("Execution completed successfully!", file=sys.stderr)

except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
finally:
executor.close()


if __name__ == "__main__":
main()
Loading
Loading