Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/labthings_fastapi/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import httpx
from urllib.parse import urlparse, urljoin

import numpy as np
from pydantic import BaseModel

from .outputs import ClientBlobOutput
Expand Down Expand Up @@ -159,7 +160,9 @@ def set_property(self, path: str, value: Any) -> None:
r = self.client.put(urljoin(self.path, path), json=value)
r.raise_for_status()

def invoke_action(self, path: str, **kwargs: Any) -> Any:
def invoke_action(
self, path: str, labthings_typehint: str | None, **kwargs: Any
) -> Any:
r"""Invoke an action on the Thing.

This method will make the initial POST request to invoke an action,
Expand Down Expand Up @@ -205,7 +208,7 @@ def invoke_action(self, path: str, **kwargs: Any) -> Any:
href=invocation["output"]["href"],
client=self.client,
)
return invocation["output"]
return _adjust_type(invocation["output"], labthings_typehint)
else:
raise RuntimeError(f"Action did not complete successfully: {invocation}")

Expand Down Expand Up @@ -276,6 +279,15 @@ class Client(cls): # type: ignore[valid-type, misc]
return Client


def _adjust_type(value: Any, labthings_typehint: str | None) -> Any:
"""Adjust the return type based on labthings_typehint."""
if labthings_typehint is None:
return value
if labthings_typehint == "ndarray":
return np.array(value)
raise ValueError(f"No type of {labthings_typehint} known")


class PropertyClientDescriptor:
"""A base class for properties on `.ThingClient` objects."""

Expand Down Expand Up @@ -361,9 +373,12 @@ def add_action(cls: type[ThingClient], action_name: str, action: dict) -> None:
:param action: a dictionary representing the action, in :ref:`wot_td`
format.
"""
labthings_typehint = action["output"].get("format", None)

def action_method(self: ThingClient, **kwargs: Any) -> Any:
return self.invoke_action(action_name, **kwargs)
return self.invoke_action(
action_name, labthings_typehint=labthings_typehint, **kwargs
)

if "output" in action and "type" in action["output"]:
action_method.__annotations__["return"] = action["output"]["type"]
Expand Down
5 changes: 5 additions & 0 deletions src/labthings_fastapi/thing_description/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,12 @@ def type_to_dataschema(t: type, **kwargs: Any) -> DataSchema:
:raise ValidationError: if the datatype cannot be represented
by a `.DataSchema`.
"""
data_format = None
if hasattr(t, "model_json_schema"):
# The input should be a `BaseModel` subclass, in which case this works:
json_schema = t.model_json_schema()
if "_labthings_typehint" in t.__private_attributes__:
data_format = t.__private_attributes__["_labthings_typehint"].default
else:
# In principle, the below should work for any type, though some
# deferred annotations can go wrong.
Expand All @@ -319,6 +322,8 @@ def type_to_dataschema(t: type, **kwargs: Any) -> DataSchema:
if k in schema_dict:
del schema_dict[k]
schema_dict.update(kwargs)
if data_format is not None:
schema_dict["format"] = data_format
try:
return DataSchema(**schema_dict)
except ValidationError as ve:
Expand Down
12 changes: 12 additions & 0 deletions src/labthings_fastapi/types/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,15 @@ class DenumpifyingDict(RootModel):

root: Annotated[Mapping, WrapSerializer(denumpify_serializer)]
model_config = ConfigDict(arbitrary_types_allowed=True)


class ArrayModel(RootModel):
"""A model automatically used by actions as the return type for a numpy array.

This models is passed to FastAPI as the return model for any action that returns
a numpy array. The private typehint is saved as format information to allow
a ThingClient to reconstruct the array from the list sent over HTTP.
"""

root: NDArray
_labthings_typehint: str = "ndarray"
6 changes: 6 additions & 0 deletions src/labthings_fastapi/utilities/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from pydantic import BaseModel, ConfigDict, Field, RootModel
from pydantic.main import create_model
from fastapi.dependencies.utils import analyze_param, get_typed_signature
import numpy as np

from ..types.numpy import ArrayModel


class EmptyObject(BaseModel):
Expand Down Expand Up @@ -178,6 +181,9 @@ def return_type(func: Callable) -> Type:
else:
# We use `get_type_hints` rather than just `sig.return_annotation`
# because it resolves forward references, etc.
rtype = get_type_hints(func)["return"]
if isinstance(rtype, type) and issubclass(rtype, np.ndarray):
return ArrayModel
type_hints = get_type_hints(func, include_extras=True)
return type_hints["return"]

Expand Down
24 changes: 18 additions & 6 deletions tests/test_numpy_type.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

from pydantic import BaseModel, RootModel
from pydantic import BaseModel
import numpy as np
from fastapi.testclient import TestClient

from labthings_fastapi.testing import create_thing_without_server
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict
from labthings_fastapi.types.numpy import NDArray, DenumpifyingDict, ArrayModel
import labthings_fastapi as lt


class ArrayModel(RootModel):
root: NDArray


def check_field_works_with_list(data):
class Model(BaseModel):
a: NDArray
Expand Down Expand Up @@ -70,6 +67,10 @@ class MyNumpyThing(lt.Thing):
def action_with_arrays(self, a: NDArray) -> NDArray:
return a * 2

@lt.action
def read_array(self) -> NDArray:
return np.array([1, 2])


def test_thing_description():
"""Make sure the TD validates when numpy types are used."""
Expand Down Expand Up @@ -102,3 +103,14 @@ def test_rootmodel():
m = ArrayModel(root=input)
assert isinstance(m.root, np.ndarray)
assert (m.model_dump() == [0, 1, 2]).all()


def test_numpy_over_http():
"""Read numpy array over http."""
server = lt.ThingServer({"np_thing": MyNumpyThing})
with TestClient(server.app) as client:
np_thing_client = lt.ThingClient.from_url("/np_thing/", client=client)

array = np_thing_client.read_array()
assert isinstance(array, np.ndarray)
assert np.array_equal(array, np.array([1, 2]))
Loading