Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion iwf/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
WorkflowWorkerRpcRequestInternalChannelInfos,
WorkflowWorkerRpcRequestSignalChannelInfos,
)
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder
from iwf.state_movement import StateMovement
from iwf.type_store import TypeStore
Expand All @@ -16,7 +17,7 @@ class Communication:
_internal_channel_type_store: TypeStore
_signal_channel_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]
_to_publish_internal_channel: dict[str, list[Union[EncodedObject, Unset]]]
_state_movements: list[StateMovement]
_internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos]
_signal_channel_infos: Optional[WorkflowWorkerRpcRequestSignalChannelInfos]
Expand Down
7 changes: 5 additions & 2 deletions iwf/communication_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ class CommunicationMethod:
is_prefix: bool

@classmethod
def signal_channel_def(cls, name: str, value_type: type):
def signal_channel_def(cls, name: str, value_type: Union[type, None]):
return CommunicationMethod(
name, CommunicationMethodType.SignalChannel, value_type, False
name,
CommunicationMethodType.SignalChannel,
value_type if value_type is not None else type(None),
False,
)

@classmethod
Expand Down
9 changes: 5 additions & 4 deletions iwf/data_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@

from iwf.errors import WorkflowDefinitionError
from iwf.iwf_api.models import EncodedObject
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder
from iwf.type_store import TypeStore


class DataAttributes:
_type_store: TypeStore
_object_encoder: ObjectEncoder
_current_values: dict[str, Union[EncodedObject, None]]
_updated_values_to_return: dict[str, EncodedObject]
_current_values: dict[str, Union[EncodedObject, None, Unset]]
_updated_values_to_return: dict[str, Union[EncodedObject, Unset]]

def __init__(
self,
type_store: TypeStore,
object_encoder: ObjectEncoder,
current_values: dict[str, Union[EncodedObject, None]],
current_values: dict[str, Union[EncodedObject, None, Unset]],
):
self._object_encoder = object_encoder
self._type_store = type_store
Expand Down Expand Up @@ -56,5 +57,5 @@ def set_data_attribute(self, key: str, value: Any):
self._current_values[key] = encoded_value
self._updated_values_to_return[key] = encoded_value

def get_updated_values_to_return(self) -> dict[str, EncodedObject]:
def get_updated_values_to_return(self) -> dict[str, Union[EncodedObject, Unset]]:
return self._updated_values_to_return
81 changes: 44 additions & 37 deletions iwf/object_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing_extensions import Literal

from iwf.iwf_api.models import EncodedObject
from iwf.iwf_api.types import Unset
from iwf.iwf_api.types import UNSET, Unset

# StrEnum is available in 3.11+
if sys.version_info >= (3, 11):
Expand All @@ -50,14 +50,15 @@ class PayloadConverter(ABC):
def to_payload(
self,
value: Any,
) -> EncodedObject:
) -> Union[EncodedObject, Unset]:
"""Encode values into payloads.

Args:
value: value to be converted

Returns:
Converted payload.
A boolean to indicate if the payload was converted and the converted value
or Unset

Raises:
Exception: Any issue during conversion.
Expand Down Expand Up @@ -90,19 +91,20 @@ class EncodingPayloadConverter(ABC):

@property
@abstractmethod
def encoding(self) -> str:
def encoding(self) -> Union[str, Unset]:
"""Encoding for the payload this converter works with."""
raise NotImplementedError

@abstractmethod
def to_payload(self, value: Any) -> Optional[EncodedObject]:
def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]:
"""Encode a single value to a payload or None.

Args:
value: Value to be converted.

Returns:
Payload of the value or None if unable to convert.
A boolean to indicate if the payload was converted and the converted value
or Unset

Raises:
TypeError: Value is not the expected type.
Expand Down Expand Up @@ -145,7 +147,7 @@ class CompositePayloadConverter(PayloadConverter):
converters: List of payload converters to delegate to, in order.
"""

converters: Mapping[str, EncodingPayloadConverter]
converters: Mapping[Union[str, Unset], EncodingPayloadConverter]

def __init__(self, *converters: EncodingPayloadConverter) -> None:
"""Initializes the data converter.
Expand All @@ -159,7 +161,7 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None:
def to_payload(
self,
value: Any,
) -> EncodedObject:
) -> Union[EncodedObject, Unset]:
"""Encode values trying each converter.

See base class. Always returns the same number of payloads as values.
Expand All @@ -169,12 +171,13 @@ def to_payload(
"""
# We intentionally attempt these serially just in case a stateful
# converter may rely on the previous values
payload = None
payload: Union[EncodedObject, Unset] = Unset()
is_encoded = False
for converter in self.converters.values():
payload = converter.to_payload(value)
if payload is not None:
is_encoded, payload = converter.to_payload(value)
if is_encoded:
break
if payload is None:
if not is_encoded:
raise RuntimeError(
f"Value of type {type(value)} has no known converter",
)
Expand All @@ -194,7 +197,7 @@ def from_payload(
RuntimeError: Error during decode
"""
encoding = payload.encoding
assert isinstance(encoding, str)
assert isinstance(encoding, (str, Unset))
converter = self.converters.get(encoding)
if converter is None:
raise KeyError(f"Unknown payload encoding {encoding}")
Expand Down Expand Up @@ -229,17 +232,15 @@ class BinaryNullPayloadConverter(EncodingPayloadConverter):
"""Converter for 'binary/null' payloads supporting None values."""

@property
def encoding(self) -> str:
def encoding(self) -> Union[str, Unset]:
"""See base class."""
return "binary/null"
return UNSET

def to_payload(self, value: Any) -> Optional[EncodedObject]:
def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]:
"""See base class."""
if value is None:
return EncodedObject(
encoding=self.encoding,
)
return None
return (True, UNSET)
return (False, UNSET)

def from_payload(
self,
Expand All @@ -256,18 +257,21 @@ class BinaryPlainPayloadConverter(EncodingPayloadConverter):
"""Converter for 'binary/plain' payloads supporting bytes values."""

@property
def encoding(self) -> str:
def encoding(self) -> Union[str, Unset]:
"""See base class."""
return "binary/plain"

def to_payload(self, value: Any) -> Optional[EncodedObject]:
def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]:
"""See base class."""
if isinstance(value, bytes):
return EncodedObject(
encoding=self.encoding,
data=str(value),
return (
True,
EncodedObject(
encoding=self.encoding,
data=str(value),
),
)
return None
return (False, UNSET)

def from_payload(
self,
Expand Down Expand Up @@ -345,11 +349,11 @@ def __init__(
self._custom_type_converters = custom_type_converters

@property
def encoding(self) -> str:
def encoding(self) -> Union[str, Unset]:
"""See base class."""
return self._encoding

def to_payload(self, value: Any) -> Optional[EncodedObject]:
def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]:
"""See base class."""
# Check for pydantic then send warning
if hasattr(value, "parse_obj"):
Expand All @@ -358,13 +362,16 @@ def to_payload(self, value: Any) -> Optional[EncodedObject]:
"https://github.com/temporalio/samples-python/tree/main/pydantic_converter for better support",
)
# We let JSON conversion errors be thrown to caller
return EncodedObject(
encoding=self.encoding,
data=json.dumps(
value,
cls=self._encoder,
separators=(",", ":"),
sort_keys=True,
return (
True,
EncodedObject(
encoding=self.encoding,
data=json.dumps(
value,
cls=self._encoder,
separators=(",", ":"),
sort_keys=True,
),
),
)

Expand Down Expand Up @@ -428,7 +435,7 @@ class PayloadCodec(ABC):
@abstractmethod
def encode(
self,
payload: EncodedObject,
payload: Union[EncodedObject, Unset],
) -> EncodedObject:
"""Encode the given payloads.

Expand Down Expand Up @@ -486,7 +493,7 @@ def __post_init__(self) -> None: # noqa: D105
def encode(
self,
value: Any,
) -> EncodedObject:
) -> Union[EncodedObject, Unset]:
"""Encode values into payloads.

First converts values to payload then encodes payload using codec.
Expand Down
11 changes: 6 additions & 5 deletions iwf/state_execution_locals.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Union

from iwf.errors import WorkflowDefinitionError
from iwf.iwf_api.models import EncodedObject, KeyValue
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder


class StateExecutionLocals:
_record_events: dict[str, EncodedObject]
_attribute_name_to_encoded_object_map: dict[str, EncodedObject]
_upsert_attributes_to_return_to_server: dict[str, EncodedObject]
_record_events: dict[str, Union[EncodedObject, Unset]]
_attribute_name_to_encoded_object_map: dict[str, Union[EncodedObject, Unset]]
_upsert_attributes_to_return_to_server: dict[str, Union[EncodedObject, Unset]]
_object_encoder: ObjectEncoder

def __init__(
self,
attribute_name_to_encoded_object_map: dict[str, EncodedObject],
attribute_name_to_encoded_object_map: dict[str, Union[EncodedObject, Unset]],
object_encoder: ObjectEncoder,
):
self._object_encoder = object_encoder
Expand Down
1 change: 0 additions & 1 deletion iwf/tests/test_rpc_with_memo_duplicate_java_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class TestRpcWithMemo(unittest.TestCase):
def setUpClass(cls):
cls.client = Client(registry)

@unittest.skip("Currently broken: difference in behavior with the iwf-java-sdk")
def test_rpc_memo_workflow_func1(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
run_id = self.client.start_workflow(
Expand Down
8 changes: 4 additions & 4 deletions iwf/tests/workflows/wait_signal_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ class WaitSignalWorkflow(ObjectWorkflow):
def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.signal_channel_def(test_channel_int, int),
CommunicationMethod.signal_channel_def(test_channel_none, type(None)),
CommunicationMethod.signal_channel_def(test_channel_none, None),
CommunicationMethod.signal_channel_def(test_channel_str, str),
CommunicationMethod.signal_channel_def(test_idle_channel_none, type(None)),
CommunicationMethod.signal_channel_def(test_channel1, type(None)),
CommunicationMethod.signal_channel_def(test_channel2, type(None)),
CommunicationMethod.signal_channel_def(test_idle_channel_none, None),
CommunicationMethod.signal_channel_def(test_channel1, None),
CommunicationMethod.signal_channel_def(test_channel2, None),
)

def get_workflow_states(self) -> StateSchema:
Expand Down
18 changes: 13 additions & 5 deletions iwf/worker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def handle_workflow_worker_rpc(
unset_to_none(request.input_), rpc_info.input_type
)

current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {}
current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = (
{}
)
if not isinstance(request.data_attributes, Unset):
current_data_attributes = {
assert_not_unset(attr.key): unset_to_none(attr.value)
Expand Down Expand Up @@ -184,7 +186,9 @@ def handle_workflow_state_wait_until(
unset_to_none(request.state_input), get_input_type(state)
)

current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {}
current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = (
{}
)
if not isinstance(request.data_objects, Unset):
current_data_attributes = {
assert_not_unset(attr.key): unset_to_none(attr.value)
Expand Down Expand Up @@ -267,7 +271,9 @@ def handle_workflow_state_execute(
unset_to_none(request.state_input), get_input_type(state)
)

current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {}
current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = (
{}
)
if not isinstance(request.data_objects, Unset):
current_data_attributes = {
assert_not_unset(attr.key): unset_to_none(attr.value)
Expand Down Expand Up @@ -407,9 +413,11 @@ def _create_upsert_search_attributes(
return sas


def to_map(key_values: Union[None, Unset, List[KeyValue]]) -> dict[str, EncodedObject]:
def to_map(
key_values: Union[None, Unset, List[KeyValue]],
) -> dict[str, Union[EncodedObject, Unset]]:
key_values = unset_to_none(key_values) or []
kvs = {}
kvs: dict[str, Union[EncodedObject, Unset]] = {}
for kv in key_values:
k = unset_to_none(kv.key)
v = unset_to_none(kv.value)
Expand Down