diff --git a/dev-requirements.in b/dev-requirements.in index cef6ce1929..47d9a1c34e 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -62,3 +62,4 @@ ipykernel orjson kubernetes>=12.0.1 httpx +marshmallow-jsonschema diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3e436f0814..6e88093e8f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,7 +22,6 @@ import msgpack from cachetools import LRUCache -from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212 from google.protobuf import json_format as _json_format @@ -36,6 +35,11 @@ from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin +try: + from dataclasses_json import DataClassJsonMixin +except ImportError: + DataClassJsonMixin = None # type: ignore + from flytekit.core.annotation import FlyteAnnotation from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT from flytekit.core.context_manager import FlyteContext @@ -510,7 +514,7 @@ def __init__(self) -> None: self._json_encoder: Dict[Type, JSONEncoder] = dict() self._json_decoder: Dict[Type, JSONDecoder] = dict() - def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): + def assert_type(self, expected_type, v: T): # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type expected_type = get_underlying_type(expected_type) if type(v) == expected_type or issubclass(type(v), expected_type): @@ -621,35 +625,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: - logger.error( - f"Failed to extract schema for object {t}, error: {e}\n" - f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition" + logger.warning( + f"Failed to extract schema for object {t}, (will run schemaless) error: {e}\n" + f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed " + f"evaluation doesn't work with json dataclasses" ) - if schema is None: - try: - # This produce JSON SCHEMA draft 2020-12 - from marshmallow_enum import EnumField, LoadDumpOptions - - if issubclass(t, DataClassJsonMixin): - s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() - for _, v in s.fields.items(): - # marshmallow-jsonschema only supports enums loaded by name. - # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 - if isinstance(v, EnumField): - v.load_by = LoadDumpOptions.name - # check if DataClass mixin - from marshmallow_jsonschema import JSONSchema - - schema = JSONSchema().dump(s) - except Exception as e: - # https://github.com/lovasoa/marshmallow_dataclass/issues/13 - logger.warning( - f"Failed to extract schema for object {t}, (will run schemaless) error: {e}" - f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed" - f"evaluation doesn't work with json dataclasses" - ) - # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} @@ -698,6 +679,20 @@ def to_generic_literal( # JSON serialization using mashumaro's DataClassJSONMixin if isinstance(python_val, DataClassJSONMixin): json_str = python_val.to_json() + elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + # Use mashumaro's encoder to handle serialization_strategy from mashumaro Config, + # then filter out dataclass_json_config which mashumaro includes from type hints + try: + encoder = self._json_encoder[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + self._json_encoder[python_type] = encoder + json_str = encoder.encode(python_val) + # Recursively remove dataclass_json_config from the output (including nested dataclasses) + dict_obj = json.loads(json_str) + dict_obj = self._remove_dataclass_json_config(dict_obj) + json_str = json.dumps(dict_obj) else: try: encoder = self._json_encoder[python_type] @@ -738,6 +733,20 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp json_str = python_val.to_json() dict_obj = json.loads(json_str) msgpack_bytes = msgpack.dumps(dict_obj) + elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + # Use mashumaro's encoder to handle serialization_strategy from mashumaro Config, + # then filter out dataclass_json_config which mashumaro includes from type hints + try: + encoder = self._msgpack_encoder[python_type] + except KeyError: + encoder = MessagePackEncoder(python_type) + self._msgpack_encoder[python_type] = encoder + msgpack_bytes = encoder.encode(python_val) + # Recursively remove dataclass_json_config from the output (including nested dataclasses) + dict_obj = msgpack.loads(msgpack_bytes, strict_map_key=False) + dict_obj = self._remove_dataclass_json_config(dict_obj) + msgpack_bytes = msgpack.dumps(dict_obj) else: # The function looks up or creates a MessagePackEncoder specifically designed for the object's type. # This encoder is then used to convert a data class into MessagePack Bytes. @@ -860,6 +869,57 @@ def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t)) return python_val + def _remove_dataclass_json_config(self, dict_obj: typing.Any) -> typing.Any: + """ + Recursively remove dataclass_json_config from nested dicts. + This is needed because mashumaro includes it when serializing dataclasses_json types. + """ + if isinstance(dict_obj, dict): + result = {k: self._remove_dataclass_json_config(v) for k, v in dict_obj.items() if k != "dataclass_json_config"} + return result + elif isinstance(dict_obj, list): + return [self._remove_dataclass_json_config(item) for item in dict_obj] + return dict_obj + + def _dataclass_to_dict(self, python_val: typing.Any) -> typing.Any: + """ + Convert a dataclass to a dict, properly serializing FlyteFile/FlyteDirectory objects. + This is used for dataclasses_json types which don't have proper _serialize support. + """ + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + from flytekit.types.structured import StructuredDataset + + if python_val is None: + return None # type: ignore + + # Handle enums - convert to value + if isinstance(python_val, enum.Enum): + return python_val.value + + # Handle FlyteFile, FlyteDirectory, StructuredDataset by calling their _serialize method + if isinstance(python_val, (FlyteFile, FlyteDirectory, StructuredDataset)): + return python_val._serialize() + + # Handle lists + if isinstance(python_val, list): + return [self._dataclass_to_dict(item) for item in python_val] + + # Handle dicts + if isinstance(python_val, dict): + return {k: self._dataclass_to_dict(v) for k, v in python_val.items()} + + # Handle dataclasses + if dataclasses.is_dataclass(python_val) and not isinstance(python_val, type): + result = {} + for field in dataclasses.fields(python_val): + val = getattr(python_val, field.name) + result[field.name] = self._dataclass_to_dict(val) + return result + + # Return primitive values as-is + return python_val + def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val @@ -904,12 +964,100 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An return dc + def _fix_val_flyte_file_directory(self, ctx: FlyteContext, python_type: typing.Type, val: typing.Any) -> typing.Any: + """ + Fix FlyteFile and FlyteDirectory fields that were deserialized via dataclasses_json. + dataclasses_json doesn't use the custom _deserialize method, so we need to manually + set up the proper local path and downloader for remote files. + """ + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + + if val is None: + return val + + # Check for FlyteFile/FlyteDirectory FIRST, before the generic dataclass check + # because FlyteFile and FlyteDirectory are also dataclasses + if isinstance(val, FlyteFile): + # Fix FlyteFile if path is remote - need to set up local path and downloader + if ctx.file_access.is_remote(val.path): + from flytekit.types.file.file import FlyteFilePathTransformer + + transformer = FlyteFilePathTransformer() + return transformer.dict_to_flyte_file( + {"path": val.path, "metadata": getattr(val, "metadata", None)}, + type(val) if type(val) != FlyteFile else FlyteFile, + ) + return val + elif isinstance(val, FlyteDirectory): + # Fix FlyteDirectory if path is remote - need to set up local path and downloader + if ctx.file_access.is_remote(val.path): + from flytekit.types.directory.types import FlyteDirToMultipartBlobTransformer + + transformer = FlyteDirToMultipartBlobTransformer() + return transformer.dict_to_flyte_directory( + {"path": val.path, "metadata": getattr(val, "metadata", None)}, + type(val) if type(val) != FlyteDirectory else FlyteDirectory, + ) + return val + + python_type = get_underlying_type(python_type) + origin = get_origin(python_type) + args = get_args(python_type) + + if origin is list and args: + return [self._fix_val_flyte_file_directory(ctx, args[0], item) for item in val] + elif origin is dict and args and len(args) == 2: + return {k: self._fix_val_flyte_file_directory(ctx, args[1], v) for k, v in val.items()} + elif origin is typing.Union: + # Handle Optional and Union types + for arg in args: + if arg is type(None): + continue + try: + return self._fix_val_flyte_file_directory(ctx, arg, val) + except Exception: + continue + return val + elif dataclasses.is_dataclass(python_type) and not isinstance(python_type, type): + # Already an instance, skip + return val + elif dataclasses.is_dataclass(python_type): + return self._fix_dataclass_flyte_file_directory(ctx, python_type, val) + + return val + + def _fix_dataclass_flyte_file_directory( + self, ctx: FlyteContext, dc_type: type, dc: typing.Any + ) -> typing.Any: + """ + Walk through dataclass fields and fix FlyteFile/FlyteDirectory values. + """ + hints = typing.get_type_hints(dc_type) + for f in dataclasses.fields(dc_type): + python_type = hints.get(f.name, f.type) + val = getattr(dc, f.name) + fixed_val = self._fix_val_flyte_file_directory(ctx, python_type, val) + if fixed_val is not val: + object.__setattr__(dc, f.name, fixed_val) + return dc + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T: if binary_idl_object.tag == MESSAGEPACK: if issubclass(expected_python_type, DataClassJSONMixin): dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) json_str = json.dumps(dict_obj) dc = expected_python_type.from_json(json_str) # type: ignore + elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) + json_str = json.dumps(dict_obj) + dc = expected_python_type.from_json(json_str) # type: ignore + # Fix up FlyteFile/FlyteDirectory fields that don't get proper deserialization with dataclasses_json + from flytekit.core.context_manager import FlyteContextManager + + ctx = FlyteContextManager.current_context() + dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc) else: try: decoder = self._msgpack_decoder[expected_python_type] @@ -939,6 +1087,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types. if issubclass(expected_python_type, DataClassJSONMixin): dc = expected_python_type.from_json(json_str) # type: ignore + elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + dc = expected_python_type.from_json(json_str) # type: ignore + # dataclasses_json doesn't use custom _deserialize methods for FlyteFile/FlyteDirectory, + # so we need to fix them up manually to ensure proper local paths and downloaders are set + dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc) else: # The function looks up or creates a JSONDecoder specifically designed for the object's type. # This decoder is then used to convert a JSON string into a data class. @@ -2514,7 +2668,7 @@ def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: t """ attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type: @@ -2525,7 +2679,7 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ """ attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def _get_element_type( diff --git a/pyproject.toml b/pyproject.toml index cb27b8002e..305c775b0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "click>=6.6", "cloudpickle>=2.0.0", "croniter>=0.3.20", - "dataclasses-json>=0.5.2,<0.5.12", # TODO: remove upper-bound after fixing change in contract + "dataclasses-json>=0.5.2", "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", @@ -36,7 +36,6 @@ dependencies = [ "keyring>=18.0.1", "markdown-it-py", "marshmallow-enum", - "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.15", "msgpack>=1.1.0", "protobuf!=4.25.0", diff --git a/tests/flytekit/unit/core/test_dataclass_guessing.py b/tests/flytekit/unit/core/test_dataclass_guessing.py index e3face7342..80eafd1ba4 100644 --- a/tests/flytekit/unit/core/test_dataclass_guessing.py +++ b/tests/flytekit/unit/core/test_dataclass_guessing.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from flytekit.core.type_engine import TypeEngine, strict_type_hint_matching -from mashumaro.codecs.json import JSONDecoder +from mashumaro.codecs.json import JSONDecoder, JSONEncoder class JobConfig(BaseModel): @@ -51,7 +51,8 @@ def test_guessing_of_nested_pydantic(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config @@ -80,7 +81,8 @@ def test_nested_pydantic_reconstruction_from_raw_json(): input_config_dc_version = decoder.decode(existing_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == SchedulerConfig( input_storage_bucket="s3://input-storage-bucket", @@ -120,7 +122,8 @@ def test_guessing_of_nested_pydantic_mapped(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfigMapped.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8945ea46dd..22bec00cb1 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3821,8 +3821,7 @@ def test_strict_type_matching_error(): strict_type_hint_matching(xs, lt) -@pytest.mark.asyncio -async def test_dict_transformer_annotated_type(): +def test_dict_transformer_annotated_type(): ctx = FlyteContext.current_context() # Test case 1: Regular Dict type @@ -3831,7 +3830,7 @@ async def test_dict_transformer_annotated_type(): expected_type = TypeEngine.to_literal_type(regular_dict_type) # This should work fine - literal1 = await TypeEngine.async_to_literal(ctx, regular_dict, regular_dict_type, expected_type) + literal1 = TypeEngine.to_literal(ctx, regular_dict, regular_dict_type, expected_type) assert literal1.map.literals["a"].scalar.primitive.integer == 1 assert literal1.map.literals["b"].scalar.primitive.integer == 2 @@ -3840,7 +3839,7 @@ async def test_dict_transformer_annotated_type(): annotated_dict_type = Annotated[Dict[str, int], "some_metadata"] expected_type = TypeEngine.to_literal_type(annotated_dict_type) - literal2 = await TypeEngine.async_to_literal(ctx, annotated_dict, annotated_dict_type, expected_type) + literal2 = TypeEngine.to_literal(ctx, annotated_dict, annotated_dict_type, expected_type) assert literal2.map.literals["x"].scalar.primitive.integer == 10 assert literal2.map.literals["y"].scalar.primitive.integer == 20 @@ -3849,7 +3848,7 @@ async def test_dict_transformer_annotated_type(): nested_dict_type = Dict[str, Annotated[Dict[str, int], "inner_metadata"]] expected_type = TypeEngine.to_literal_type(nested_dict_type) - literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type) + literal3 = TypeEngine.to_literal(ctx, nested_dict, nested_dict_type, expected_type) assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42 @pytest.fixture(autouse=True)