diff --git a/pyproject.toml b/pyproject.toml index ca58e34..4887b47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ test = [ "pytest-timeout", ] publish = ["twine", "build"] -dev = ["mkdocs-material[imaging]", "wampproto@git+https://github.com/xconnio/wampproto-python"] +dev = ["mkdocs-material[imaging]", "wampproto@git+https://github.com/xconnio/wampproto-python", "protobuf", "pycapnp"] capnproto = [ "wampproto-messages-capnproto@git+https://github.com/xconnio/wampproto-messages-capnproto@main#subdirectory=python" ] @@ -81,6 +81,7 @@ exclude = [ "node_modules", "site-packages", "venv", + "tests/schemas/*" ] line-length = 120 diff --git a/tests/codec_test.py b/tests/codec_test.py new file mode 100644 index 0000000..4804bf8 --- /dev/null +++ b/tests/codec_test.py @@ -0,0 +1,227 @@ +import os +from pathlib import Path +from typing import Type, TypeVar + +import capnp +from google.protobuf.message import Message + +from xconn.client import connect_anonymous +from xconn.codec import Codec +from xconn.types import Event, OutgoingDataMessage, IncomingDataMessage +from tests.schemas.profile_pb2 import ProfileCreate, ProfileGet + + +T = TypeVar("T", bound=Message) + + +class ProtobufCodec(Codec[T]): + def name(self) -> str: + return "protobuf" + + def encode(self, obj: T) -> OutgoingDataMessage: + payload = obj.SerializeToString() + return OutgoingDataMessage(args=[payload], kwargs={}, details={}) + + def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T: + if len(msg.args) == 0 or not isinstance(msg.args[0], bytes): + raise ValueError("ProtobufCodec: cannot decode, expected first arg to be bytes") + + obj = out_type() + obj.ParseFromString(msg.args[0]) + return obj + + +def test_rpc_object_protobuf(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(ProtobufCodec()) + + def inv_handler(profile: ProfileCreate) -> ProfileGet: + return ProfileGet( + id="123", + username=profile.username, + email=profile.email, + age=profile.age, + created_at="2025-10-28T17:00:00Z", + ) + + session.register_object("io.xconn.profile.create", inv_handler) + create_msg = ProfileCreate(username="john", email="john@xconn.io", age=25) + + result = session.call_object("io.xconn.profile.create", create_msg, ProfileGet) + assert isinstance(result, ProfileGet) + assert result.username == "john" + assert result.email == "john@xconn.io" + assert result.age == 25 + assert result.id == "123" + assert result.created_at == "2025-10-28T17:00:00Z" + + session.leave() + + +def test_pubsub_protobuf(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(ProtobufCodec()) + + def event_handler(event: Event): + user: ProfileCreate = event.args[0] + assert user.username == "john" + assert user.email == "john@xconn.io" + assert user.age == 25 + + session.subscribe_object("io.xconn.object", event_handler, ProfileCreate) + + create_msg = ProfileCreate(username="john", email="john@xconn.io", age=25) + session.publish_object("io.xconn.object", create_msg) + + session.leave() + + +def test_rpc_object_one_param_with_return_type(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(ProtobufCodec()) + + def create_profile_handler(prof: ProfileCreate) -> ProfileGet: + return ProfileGet( + id="356", + username=prof.username, + email=prof.email, + age=prof.age, + created_at="2025-10-30T17:00:00Z", + ) + + session.register_object("io.xconn.profile.create", create_profile_handler) + + profile_create = ProfileCreate(username="john", email="john@xconn.io", age=25) + profile = session.call_object("io.xconn.profile.create", profile_create, ProfileGet) + + assert profile.id == "356" + assert profile.username == "john" + assert profile.email == "john@xconn.io" + assert profile.age == 25 + assert profile.created_at == "2025-10-30T17:00:00Z" + + session.leave() + + +def test_rpc_object_no_param(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(ProtobufCodec()) + + options = {"flag": False} + + def invocation_handler() -> None: + options["flag"] = True + + session.register_object("io.xconn.param.none", invocation_handler) + + result = session.call_object("io.xconn.param.none") + + assert options["flag"] is True + assert result is None + + session.leave() + + +def test_rpc_object_no_param_with_return(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(ProtobufCodec()) + + def get_profile_handler() -> ProfileGet: + return ProfileGet( + id="636", + username="admin", + email="admin@xconn.io", + age=30, + created_at="2025-10-30T17:00:00Z", + ) + + session.register_object("io.xconn.profile.get", get_profile_handler) + + profile = session.call_object("io.xconn.profile.get", return_type=ProfileGet) + + assert profile.id == "636" + assert profile.username == "admin" + assert profile.email == "admin@xconn.io" + assert profile.age == 30 + assert profile.created_at == "2025-10-30T17:00:00Z" + + session.leave() + + +T = TypeVar("T") + +root_dir = Path(__file__).resolve().parent +module_file = os.path.join(root_dir, "schemas", "user.capnp") +user_capnp = capnp.load(str(module_file)) + +UserCreate = user_capnp.UserCreate +UserGet = user_capnp.UserGet + + +class CapnpProtoCodec(Codec[T]): + def name(self) -> str: + return "capnproto" + + def encode(self, obj: T) -> OutgoingDataMessage: + payload = obj.to_bytes_packed() + return OutgoingDataMessage(args=[payload], kwargs={}, details={}) + + def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T: + if len(msg.args) == 0 or not isinstance(msg.args[0], bytes): + raise ValueError("CapnpProtoCodec: cannot decode, expected first arg to be bytes") + + return out_type.from_bytes_packed(msg.args[0]) + + +def test_rpc_object_capnproto(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(CapnpProtoCodec()) + + def create_handler(user_create: UserCreate) -> UserGet: + user_get = UserGet.new_message() + user_get.id = 999 + user_get.name = user_create.name + user_get.email = user_create.email + user_get.age = user_create.age + user_get.isAdmin = False + + return user_get + + session.register_object("io.xconn.user.create", create_handler) + + new_user = UserCreate.new_message() + new_user.name = "john" + new_user.email = "john@xconn.io" + new_user.age = 35 + + user = session.call_object("io.xconn.user.create", new_user, UserGet) + + assert user.id == 999 + assert user.name == "john" + assert user.email == "john@xconn.io" + assert user.age == 35 + assert not user.isAdmin + + session.leave() + + +def test_pubsub_capnproto(): + session = connect_anonymous("ws://localhost:8080/ws", "realm1") + session.set_payload_codec(CapnpProtoCodec()) + + def event_handler(event: Event): + user: UserCreate = event.args[0] + assert user.name == "alice" + assert user.email == "alice@xconn.io" + assert user.age == 21 + + session.subscribe_object("io.xconn.object", event_handler, UserCreate) + + new_user = UserCreate.new_message() + new_user.name = "alice" + new_user.email = "alice@xconn.io" + new_user.age = 21 + + session.publish_object("io.xconn.object", new_user) + + session.leave() diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/schemas/profile.proto b/tests/schemas/profile.proto new file mode 100644 index 0000000..49d1b18 --- /dev/null +++ b/tests/schemas/profile.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +message ProfileCreate { + string username = 1; + string email = 2; + int32 age = 3; +} + +message ProfileGet { + string id = 1; + string username = 2; + string email = 3; + int32 age = 4; + string created_at = 5; +} diff --git a/tests/schemas/profile_pb2.py b/tests/schemas/profile_pb2.py new file mode 100644 index 0000000..176d7e5 --- /dev/null +++ b/tests/schemas/profile_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tests/profile.proto +"""Generated protocol buffer code.""" + +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13tests/profile.proto"=\n\rProfileCreate\x12\x10\n\x08username\x18\x01 \x01(\t\x12\r\n\x05\x65' + b'mail\x18\x02 \x01(\t\x12\x0b\n\x03\x61ge\x18\x03 \x01(\x05"Z\n\nProfileGet\x12\n\n\x02id\x18\x01 ' + b"\x01(\t\x12\x10\n\x08username\x18\x02 \x01(\t\x12\r\n\x05\x65mail\x18\x03 \x01(\t\x12\x0b\n\x03\x61" + b"ge\x18\x04 \x01(\x05\x12\x12\n\ncreated_at\x18\x05 \x01(\tb\x06proto3" +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "tests.profile_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _PROFILECREATE._serialized_start = 23 + _PROFILECREATE._serialized_end = 84 + _PROFILEGET._serialized_start = 86 + _PROFILEGET._serialized_end = 176 +# @@protoc_insertion_point(module_scope) diff --git a/tests/schemas/profile_pb2.pyi b/tests/schemas/profile_pb2.pyi new file mode 100644 index 0000000..a1f8304 --- /dev/null +++ b/tests/schemas/profile_pb2.pyi @@ -0,0 +1,38 @@ +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class ProfileCreate(_message.Message): + __slots__ = ["age", "email", "username"] + AGE_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] + USERNAME_FIELD_NUMBER: _ClassVar[int] + age: int + email: str + username: str + def __init__( + self, username: _Optional[str] = ..., email: _Optional[str] = ..., age: _Optional[int] = ... + ) -> None: ... + +class ProfileGet(_message.Message): + __slots__ = ["age", "created_at", "email", "id", "username"] + AGE_FIELD_NUMBER: _ClassVar[int] + CREATED_AT_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + USERNAME_FIELD_NUMBER: _ClassVar[int] + age: int + created_at: str + email: str + id: str + username: str + def __init__( + self, + id: _Optional[str] = ..., + username: _Optional[str] = ..., + email: _Optional[str] = ..., + age: _Optional[int] = ..., + created_at: _Optional[str] = ..., + ) -> None: ... diff --git a/tests/schemas/user.capnp b/tests/schemas/user.capnp new file mode 100644 index 0000000..6b63716 --- /dev/null +++ b/tests/schemas/user.capnp @@ -0,0 +1,14 @@ +@0x9d0356c942e754d3; +struct UserCreate { + name @0 :Text; + email @1 :Text; + age @2 :UInt16; +} + +struct UserGet { + id @0 :UInt32; + name @1 :Text; + email @2 :Text; + age @3 :UInt16; + isAdmin @4 :Bool; +} diff --git a/xconn/codec.py b/xconn/codec.py new file mode 100644 index 0000000..3e5f314 --- /dev/null +++ b/xconn/codec.py @@ -0,0 +1,18 @@ +from typing import Generic, Type, TypeVar + +from xconn.types import IncomingDataMessage, OutgoingDataMessage + +T = TypeVar("T") + + +class Codec(Generic[T]): + def name(self) -> str: + raise NotImplementedError + + def encode(self, obj: T) -> OutgoingDataMessage: + """Serialize a Python object to bytes.""" + raise NotImplementedError + + def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T: + """Deserialize the incoming message into an instance of out_type.""" + raise NotImplementedError diff --git a/xconn/session.py b/xconn/session.py index b13aed7..1611028 100644 --- a/xconn/session.py +++ b/xconn/session.py @@ -1,16 +1,21 @@ from __future__ import annotations +import inspect from concurrent.futures import Future from threading import Thread -from typing import Callable, Any +from typing import Callable, Any, TypeVar, Type, overload from dataclasses import dataclass from wampproto import messages, session, uris from xconn import types, exception, uris as xconn_uris +from xconn.codec import Codec from xconn.exception import ApplicationError from xconn.helpers import exception_from_error, SessionScopeIDGenerator +TReq = TypeVar("TReq") +TRes = TypeVar("TRes") + @dataclass class RegisterRequest: @@ -68,6 +73,8 @@ def __init__(self, base_session: types.BaseSession): self._disconnect_callback: list[Callable[[], None] | None] = [] + self._payload_codec: Codec = None + thread = Thread(target=self._wait, daemon=False) thread.start() @@ -170,6 +177,99 @@ def _process_incoming_message(self, msg: messages.Message): else: raise ValueError("received unknown message") + def set_payload_codec(self, codec: Codec) -> None: + self._payload_codec = codec + + @overload + def call_object(self, procedure: str, request: TReq, return_type: Type[TRes]) -> TRes: + ... + + @overload + def call_object(self, procedure: str, request: None = None, return_type: None = None) -> None: + ... + + @overload + def call_object(self, procedure: str, request: None, return_type: Type[TRes]) -> TRes: + ... + + def call_object(self, procedure: str, request: TReq = None, return_type: Type[TRes] | None = None) -> TRes | None: + if self._payload_codec is None: + raise ValueError("no payload codec set") + + if request is not None: + encoded = self._payload_codec.encode(request) + result = self.call(procedure, args=encoded.args, kwargs=encoded.kwargs, options=encoded.details) + else: + result = self.call(procedure) + + if return_type is not None: + return self._payload_codec.decode(result.args[0], return_type) + + return None + + def subscribe_object(self, topic: str, event_handler: Callable[[types.Event], None], return_type: Type[TRes]): + if self._payload_codec is None: + raise ValueError("no payload codec set") + + def _event_handler(event: types.Event): + if len(event.args) != 1: + raise ValueError("only one argument expected in event") + + data = event.args[0] + d = self._payload_codec.decode(data, return_type) + event_handler(types.Event(args=[d], kwargs={}, details={})) + + return self.subscribe(topic, _event_handler) + + def publish_object(self, topic: str, request: TReq): + if self._payload_codec is None: + raise ValueError("no payload codec set") + + encoded = self._payload_codec.encode(request) + return self.publish(topic, [encoded]) + + def register_object( + self, + procedure: str, + invocation_handler: Callable[[TReq], TRes | None] | Callable[[], TRes | None], + ): + if self._payload_codec is None: + raise ValueError("no payload codec set") + + sig = inspect.signature(invocation_handler) + + params = list(sig.parameters.values()) + if len(params) > 1: + raise ValueError("invocation handler must accept 0 or 1 argument") + + if len(params) == 1: + # get parameter's type hint + param_type = params[0].annotation + if param_type is inspect._empty: + raise TypeError("invocation handler parameter must have a type annotation") + else: + param_type = None + + def _invocation_handler(invocation: types.Invocation): + request_obj = None + if param_type is not None: + if len(invocation.args) != 1: + raise ValueError("only one argument expected in invocation") + + request_obj = self._payload_codec.decode(invocation.args[0], param_type) + + result = invocation_handler(request_obj) if param_type is not None else invocation_handler() + + # no return type in invocation handler + if sig.return_annotation is inspect._empty or result is None: + return None + + encoded = self._payload_codec.encode(result) + + return types.Result(args=[encoded]) + + return self.register(procedure, _invocation_handler) + def call( self, procedure: str, diff --git a/xconn/types.py b/xconn/types.py index 07b4d08..1f358a9 100644 --- a/xconn/types.py +++ b/xconn/types.py @@ -26,26 +26,34 @@ class UnsubscribeRequest: @dataclass -class Result: +class OutgoingDataMessage: args: list | None = None kwargs: dict | None = None details: dict | None = None @dataclass -class Invocation: - args: list | None - kwargs: dict | None - details: dict | None +class Result(OutgoingDataMessage): + pass @dataclass -class Event: +class IncomingDataMessage: args: list | None kwargs: dict | None details: dict | None +@dataclass +class Invocation(IncomingDataMessage): + pass + + +@dataclass() +class Event(IncomingDataMessage): + pass + + @dataclass class TransportConfig: # max wait time for connection to be established