diff --git a/iwf/communication.py b/iwf/communication.py index 78d141a..0a79504 100644 --- a/iwf/communication.py +++ b/iwf/communication.py @@ -7,6 +7,7 @@ WorkflowWorkerRpcRequestInternalChannelInfos, WorkflowWorkerRpcRequestSignalChannelInfos, ) +from iwf.iwf_api.types import Unset from iwf.object_encoder import ObjectEncoder from iwf.state_movement import StateMovement from iwf.type_store import TypeStore @@ -16,7 +17,7 @@ class Communication: _internal_channel_type_store: TypeStore _signal_channel_type_store: dict[str, Optional[type]] _object_encoder: ObjectEncoder - _to_publish_internal_channel: dict[str, list[EncodedObject]] + _to_publish_internal_channel: dict[str, list[Union[EncodedObject, Unset]]] _state_movements: list[StateMovement] _internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos] _signal_channel_infos: Optional[WorkflowWorkerRpcRequestSignalChannelInfos] diff --git a/iwf/communication_schema.py b/iwf/communication_schema.py index 8f60d02..d93ad3b 100644 --- a/iwf/communication_schema.py +++ b/iwf/communication_schema.py @@ -16,9 +16,12 @@ class CommunicationMethod: is_prefix: bool @classmethod - def signal_channel_def(cls, name: str, value_type: type): + def signal_channel_def(cls, name: str, value_type: Union[type, None]): return CommunicationMethod( - name, CommunicationMethodType.SignalChannel, value_type, False + name, + CommunicationMethodType.SignalChannel, + value_type if value_type is not None else type(None), + False, ) @classmethod diff --git a/iwf/data_attributes.py b/iwf/data_attributes.py index 159d43c..200d864 100644 --- a/iwf/data_attributes.py +++ b/iwf/data_attributes.py @@ -2,6 +2,7 @@ from iwf.errors import WorkflowDefinitionError from iwf.iwf_api.models import EncodedObject +from iwf.iwf_api.types import Unset from iwf.object_encoder import ObjectEncoder from iwf.type_store import TypeStore @@ -9,14 +10,14 @@ class DataAttributes: _type_store: TypeStore _object_encoder: ObjectEncoder - _current_values: dict[str, Union[EncodedObject, None]] - _updated_values_to_return: dict[str, EncodedObject] + _current_values: dict[str, Union[EncodedObject, None, Unset]] + _updated_values_to_return: dict[str, Union[EncodedObject, Unset]] def __init__( self, type_store: TypeStore, object_encoder: ObjectEncoder, - current_values: dict[str, Union[EncodedObject, None]], + current_values: dict[str, Union[EncodedObject, None, Unset]], ): self._object_encoder = object_encoder self._type_store = type_store @@ -56,5 +57,5 @@ def set_data_attribute(self, key: str, value: Any): self._current_values[key] = encoded_value self._updated_values_to_return[key] = encoded_value - def get_updated_values_to_return(self) -> dict[str, EncodedObject]: + def get_updated_values_to_return(self) -> dict[str, Union[EncodedObject, Unset]]: return self._updated_values_to_return diff --git a/iwf/object_encoder.py b/iwf/object_encoder.py index b1166a4..19d3e7b 100644 --- a/iwf/object_encoder.py +++ b/iwf/object_encoder.py @@ -33,7 +33,7 @@ from typing_extensions import Literal from iwf.iwf_api.models import EncodedObject -from iwf.iwf_api.types import Unset +from iwf.iwf_api.types import UNSET, Unset # StrEnum is available in 3.11+ if sys.version_info >= (3, 11): @@ -50,14 +50,15 @@ class PayloadConverter(ABC): def to_payload( self, value: Any, - ) -> EncodedObject: + ) -> Union[EncodedObject, Unset]: """Encode values into payloads. Args: value: value to be converted Returns: - Converted payload. + A boolean to indicate if the payload was converted and the converted value + or Unset Raises: Exception: Any issue during conversion. @@ -90,19 +91,20 @@ class EncodingPayloadConverter(ABC): @property @abstractmethod - def encoding(self) -> str: + def encoding(self) -> Union[str, Unset]: """Encoding for the payload this converter works with.""" raise NotImplementedError @abstractmethod - def to_payload(self, value: Any) -> Optional[EncodedObject]: + def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]: """Encode a single value to a payload or None. Args: value: Value to be converted. Returns: - Payload of the value or None if unable to convert. + A boolean to indicate if the payload was converted and the converted value + or Unset Raises: TypeError: Value is not the expected type. @@ -145,7 +147,7 @@ class CompositePayloadConverter(PayloadConverter): converters: List of payload converters to delegate to, in order. """ - converters: Mapping[str, EncodingPayloadConverter] + converters: Mapping[Union[str, Unset], EncodingPayloadConverter] def __init__(self, *converters: EncodingPayloadConverter) -> None: """Initializes the data converter. @@ -159,7 +161,7 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None: def to_payload( self, value: Any, - ) -> EncodedObject: + ) -> Union[EncodedObject, Unset]: """Encode values trying each converter. See base class. Always returns the same number of payloads as values. @@ -169,12 +171,13 @@ def to_payload( """ # We intentionally attempt these serially just in case a stateful # converter may rely on the previous values - payload = None + payload: Union[EncodedObject, Unset] = Unset() + is_encoded = False for converter in self.converters.values(): - payload = converter.to_payload(value) - if payload is not None: + is_encoded, payload = converter.to_payload(value) + if is_encoded: break - if payload is None: + if not is_encoded: raise RuntimeError( f"Value of type {type(value)} has no known converter", ) @@ -194,7 +197,7 @@ def from_payload( RuntimeError: Error during decode """ encoding = payload.encoding - assert isinstance(encoding, str) + assert isinstance(encoding, (str, Unset)) converter = self.converters.get(encoding) if converter is None: raise KeyError(f"Unknown payload encoding {encoding}") @@ -229,17 +232,15 @@ class BinaryNullPayloadConverter(EncodingPayloadConverter): """Converter for 'binary/null' payloads supporting None values.""" @property - def encoding(self) -> str: + def encoding(self) -> Union[str, Unset]: """See base class.""" - return "binary/null" + return UNSET - def to_payload(self, value: Any) -> Optional[EncodedObject]: + def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]: """See base class.""" if value is None: - return EncodedObject( - encoding=self.encoding, - ) - return None + return (True, UNSET) + return (False, UNSET) def from_payload( self, @@ -256,18 +257,21 @@ class BinaryPlainPayloadConverter(EncodingPayloadConverter): """Converter for 'binary/plain' payloads supporting bytes values.""" @property - def encoding(self) -> str: + def encoding(self) -> Union[str, Unset]: """See base class.""" return "binary/plain" - def to_payload(self, value: Any) -> Optional[EncodedObject]: + def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]: """See base class.""" if isinstance(value, bytes): - return EncodedObject( - encoding=self.encoding, - data=str(value), + return ( + True, + EncodedObject( + encoding=self.encoding, + data=str(value), + ), ) - return None + return (False, UNSET) def from_payload( self, @@ -345,11 +349,11 @@ def __init__( self._custom_type_converters = custom_type_converters @property - def encoding(self) -> str: + def encoding(self) -> Union[str, Unset]: """See base class.""" return self._encoding - def to_payload(self, value: Any) -> Optional[EncodedObject]: + def to_payload(self, value: Any) -> tuple[bool, Union[EncodedObject, Unset]]: """See base class.""" # Check for pydantic then send warning if hasattr(value, "parse_obj"): @@ -358,13 +362,16 @@ def to_payload(self, value: Any) -> Optional[EncodedObject]: "https://github.com/temporalio/samples-python/tree/main/pydantic_converter for better support", ) # We let JSON conversion errors be thrown to caller - return EncodedObject( - encoding=self.encoding, - data=json.dumps( - value, - cls=self._encoder, - separators=(",", ":"), - sort_keys=True, + return ( + True, + EncodedObject( + encoding=self.encoding, + data=json.dumps( + value, + cls=self._encoder, + separators=(",", ":"), + sort_keys=True, + ), ), ) @@ -428,7 +435,7 @@ class PayloadCodec(ABC): @abstractmethod def encode( self, - payload: EncodedObject, + payload: Union[EncodedObject, Unset], ) -> EncodedObject: """Encode the given payloads. @@ -486,7 +493,7 @@ def __post_init__(self) -> None: # noqa: D105 def encode( self, value: Any, - ) -> EncodedObject: + ) -> Union[EncodedObject, Unset]: """Encode values into payloads. First converts values to payload then encodes payload using codec. diff --git a/iwf/state_execution_locals.py b/iwf/state_execution_locals.py index 8e24fcf..bf553ff 100644 --- a/iwf/state_execution_locals.py +++ b/iwf/state_execution_locals.py @@ -1,19 +1,20 @@ -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union from iwf.errors import WorkflowDefinitionError from iwf.iwf_api.models import EncodedObject, KeyValue +from iwf.iwf_api.types import Unset from iwf.object_encoder import ObjectEncoder class StateExecutionLocals: - _record_events: dict[str, EncodedObject] - _attribute_name_to_encoded_object_map: dict[str, EncodedObject] - _upsert_attributes_to_return_to_server: dict[str, EncodedObject] + _record_events: dict[str, Union[EncodedObject, Unset]] + _attribute_name_to_encoded_object_map: dict[str, Union[EncodedObject, Unset]] + _upsert_attributes_to_return_to_server: dict[str, Union[EncodedObject, Unset]] _object_encoder: ObjectEncoder def __init__( self, - attribute_name_to_encoded_object_map: dict[str, EncodedObject], + attribute_name_to_encoded_object_map: dict[str, Union[EncodedObject, Unset]], object_encoder: ObjectEncoder, ): self._object_encoder = object_encoder diff --git a/iwf/tests/test_rpc_with_memo_duplicate_java_tests.py b/iwf/tests/test_rpc_with_memo_duplicate_java_tests.py index d6dc7f3..7b5cb79 100644 --- a/iwf/tests/test_rpc_with_memo_duplicate_java_tests.py +++ b/iwf/tests/test_rpc_with_memo_duplicate_java_tests.py @@ -22,7 +22,6 @@ class TestRpcWithMemo(unittest.TestCase): def setUpClass(cls): cls.client = Client(registry) - @unittest.skip("Currently broken: difference in behavior with the iwf-java-sdk") def test_rpc_memo_workflow_func1(self): wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}" run_id = self.client.start_workflow( diff --git a/iwf/tests/workflows/wait_signal_workflow.py b/iwf/tests/workflows/wait_signal_workflow.py index 6238871..ca0e040 100644 --- a/iwf/tests/workflows/wait_signal_workflow.py +++ b/iwf/tests/workflows/wait_signal_workflow.py @@ -125,11 +125,11 @@ class WaitSignalWorkflow(ObjectWorkflow): def get_communication_schema(self) -> CommunicationSchema: return CommunicationSchema.create( CommunicationMethod.signal_channel_def(test_channel_int, int), - CommunicationMethod.signal_channel_def(test_channel_none, type(None)), + CommunicationMethod.signal_channel_def(test_channel_none, None), CommunicationMethod.signal_channel_def(test_channel_str, str), - CommunicationMethod.signal_channel_def(test_idle_channel_none, type(None)), - CommunicationMethod.signal_channel_def(test_channel1, type(None)), - CommunicationMethod.signal_channel_def(test_channel2, type(None)), + CommunicationMethod.signal_channel_def(test_idle_channel_none, None), + CommunicationMethod.signal_channel_def(test_channel1, None), + CommunicationMethod.signal_channel_def(test_channel2, None), ) def get_workflow_states(self) -> StateSchema: diff --git a/iwf/worker_service.py b/iwf/worker_service.py index dfe0acc..cdb4450 100644 --- a/iwf/worker_service.py +++ b/iwf/worker_service.py @@ -85,7 +85,9 @@ def handle_workflow_worker_rpc( unset_to_none(request.input_), rpc_info.input_type ) - current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {} + current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = ( + {} + ) if not isinstance(request.data_attributes, Unset): current_data_attributes = { assert_not_unset(attr.key): unset_to_none(attr.value) @@ -184,7 +186,9 @@ def handle_workflow_state_wait_until( unset_to_none(request.state_input), get_input_type(state) ) - current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {} + current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = ( + {} + ) if not isinstance(request.data_objects, Unset): current_data_attributes = { assert_not_unset(attr.key): unset_to_none(attr.value) @@ -267,7 +271,9 @@ def handle_workflow_state_execute( unset_to_none(request.state_input), get_input_type(state) ) - current_data_attributes: dict[str, typing.Union[EncodedObject, None]] = {} + current_data_attributes: dict[str, typing.Union[EncodedObject, None, Unset]] = ( + {} + ) if not isinstance(request.data_objects, Unset): current_data_attributes = { assert_not_unset(attr.key): unset_to_none(attr.value) @@ -407,9 +413,11 @@ def _create_upsert_search_attributes( return sas -def to_map(key_values: Union[None, Unset, List[KeyValue]]) -> dict[str, EncodedObject]: +def to_map( + key_values: Union[None, Unset, List[KeyValue]], +) -> dict[str, Union[EncodedObject, Unset]]: key_values = unset_to_none(key_values) or [] - kvs = {} + kvs: dict[str, Union[EncodedObject, Unset]] = {} for kv in key_values: k = unset_to_none(kv.key) v = unset_to_none(kv.value)