From b0bcbf0744dd2562b61576a786330d69d401a66a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 29 Jan 2026 09:55:43 -0800 Subject: [PATCH 1/8] Remove upper bound for dataclasses-json Signed-off-by: Kevin Su --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d2c1e98fb0..e21de75243 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "click>=6.6", "cloudpickle>=2.0.0", "croniter>=0.3.20", - "dataclasses-json>=0.5.2,<0.5.12", # TODO: remove upper-bound after fixing change in contract + "dataclasses-json>=0.5.2", "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", From 8cb44d0ace8d1f7745c720c5e91fc217746fb4e9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 29 Jan 2026 12:19:29 -0800 Subject: [PATCH 2/8] test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 24a78f184b..e2d82ad4a5 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,7 +22,6 @@ import msgpack from cachetools import LRUCache -from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212 from google.protobuf import json_format as _json_format @@ -510,7 +509,7 @@ def __init__(self) -> None: self._json_encoder: Dict[Type, JSONEncoder] = dict() self._json_decoder: Dict[Type, JSONDecoder] = dict() - def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): + def assert_type(self, expected_type, v: T): # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type expected_type = get_underlying_type(expected_type) if type(v) == expected_type or issubclass(type(v), expected_type): @@ -614,7 +613,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # Drop all annotations and handle only the dataclass type passed in. t = args[0] - schema = None try: # This produce JSON SCHEMA draft 2020-12 from mashumaro.jsonschema import build_json_schema @@ -625,30 +623,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"Failed to extract schema for object {t}, error: {e}\n" f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition" ) - - if schema is None: - try: - # This produce JSON SCHEMA draft 2020-12 - from marshmallow_enum import EnumField, LoadDumpOptions - - if issubclass(t, DataClassJsonMixin): - s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() - for _, v in s.fields.items(): - # marshmallow-jsonschema only supports enums loaded by name. - # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 - if isinstance(v, EnumField): - v.load_by = LoadDumpOptions.name - # check if DataClass mixin - from marshmallow_jsonschema import JSONSchema - - schema = JSONSchema().dump(s) - except Exception as e: - # https://github.com/lovasoa/marshmallow_dataclass/issues/13 - logger.warning( - f"Failed to extract schema for object {t}, (will run schemaless) error: {e}" - f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed" - f"evaluation doesn't work with json dataclasses" - ) + raise # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} @@ -2488,7 +2463,7 @@ def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: t """ attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type: @@ -2499,7 +2474,7 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ """ attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) - return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + return dataclasses.make_dataclass(schema_name, attribute_list) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: From 9af018e4678745f8ab4fe0d8074e6c171e8980ce Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 29 Jan 2026 13:50:48 -0800 Subject: [PATCH 3/8] wip Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e2d82ad4a5..ee23ed94b9 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -35,6 +35,11 @@ from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin +try: + from dataclasses_json import DataClassJsonMixin +except ImportError: + DataClassJsonMixin = None # type: ignore + from flytekit.core.annotation import FlyteAnnotation from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT from flytekit.core.context_manager import FlyteContext @@ -914,6 +919,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types. if issubclass(expected_python_type, DataClassJSONMixin): dc = expected_python_type.from_json(json_str) # type: ignore + elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + dc = expected_python_type.from_json(json_str) # type: ignore else: # The function looks up or creates a JSONDecoder specifically designed for the object's type. # This decoder is then used to convert a JSON string into a data class. From d431c8cbd0eb49fcd2fead78b9ebe50d9e6c4d04 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 30 Jan 2026 02:42:04 -0800 Subject: [PATCH 4/8] fix tests Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 99 +++++++++++++++++++- tests/flytekit/unit/core/test_type_engine.py | 9 +- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index ce556aa06a..df6f4644e9 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -618,17 +618,18 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # Drop all annotations and handle only the dataclass type passed in. t = args[0] + schema = None try: # This produce JSON SCHEMA draft 2020-12 from mashumaro.jsonschema import build_json_schema schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: - logger.error( - f"Failed to extract schema for object {t}, error: {e}\n" - f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition" + logger.warning( + f"Failed to extract schema for object {t}, (will run schemaless) error: {e}\n" + f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed " + f"evaluation doesn't work with json dataclasses" ) - raise # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} @@ -884,12 +885,99 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An return dc + def _fix_val_flyte_file_directory( + self, ctx: FlyteContext, python_type: typing.Type, val: typing.Any + ) -> typing.Any: + """ + Fix FlyteFile and FlyteDirectory fields that were deserialized via dataclasses_json. + dataclasses_json doesn't use the custom _deserialize method, so we need to manually + set up the proper local path and downloader for remote files. + """ + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + + if val is None: + return val + + # Check for FlyteFile/FlyteDirectory FIRST, before the generic dataclass check + # because FlyteFile and FlyteDirectory are also dataclasses + if isinstance(val, FlyteFile): + # Fix FlyteFile if path is remote - need to set up local path and downloader + if ctx.file_access.is_remote(val.path): + from flytekit.types.file.file import FlyteFilePathTransformer + transformer = FlyteFilePathTransformer() + return transformer.dict_to_flyte_file( + {"path": val.path, "metadata": getattr(val, "metadata", None)}, + type(val) if type(val) != FlyteFile else FlyteFile + ) + return val + elif isinstance(val, FlyteDirectory): + # Fix FlyteDirectory if path is remote - need to set up local path and downloader + if ctx.file_access.is_remote(val.path): + from flytekit.types.directory.types import FlyteDirToMultipartBlobTransformer + transformer = FlyteDirToMultipartBlobTransformer() + return transformer.dict_to_flyte_directory( + {"path": val.path, "metadata": getattr(val, "metadata", None)}, + type(val) if type(val) != FlyteDirectory else FlyteDirectory + ) + return val + + python_type = get_underlying_type(python_type) + origin = get_origin(python_type) + args = get_args(python_type) + + if origin is list and args: + return [self._fix_val_flyte_file_directory(ctx, args[0], item) for item in val] + elif origin is dict and args and len(args) == 2: + return {k: self._fix_val_flyte_file_directory(ctx, args[1], v) for k, v in val.items()} + elif origin is typing.Union: + # Handle Optional and Union types + for arg in args: + if arg is type(None): + continue + try: + return self._fix_val_flyte_file_directory(ctx, arg, val) + except Exception: + continue + return val + elif dataclasses.is_dataclass(python_type) and not isinstance(python_type, type): + # Already an instance, skip + return val + elif dataclasses.is_dataclass(python_type): + return self._fix_dataclass_flyte_file_directory(ctx, python_type, val) + + return val + + def _fix_dataclass_flyte_file_directory( + self, ctx: FlyteContext, dc_type: Type[dataclasses.dataclass], dc: typing.Any + ) -> typing.Any: + """ + Walk through dataclass fields and fix FlyteFile/FlyteDirectory values. + """ + hints = typing.get_type_hints(dc_type) + for f in dataclasses.fields(dc_type): + python_type = hints.get(f.name, f.type) + val = getattr(dc, f.name) + fixed_val = self._fix_val_flyte_file_directory(ctx, python_type, val) + if fixed_val is not val: + object.__setattr__(dc, f.name, fixed_val) + return dc + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T: if binary_idl_object.tag == MESSAGEPACK: if issubclass(expected_python_type, DataClassJSONMixin): dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) json_str = json.dumps(dict_obj) dc = expected_python_type.from_json(json_str) # type: ignore + elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) + json_str = json.dumps(dict_obj) + dc = expected_python_type.from_json(json_str) # type: ignore + # Fix up FlyteFile/FlyteDirectory fields that don't get proper deserialization with dataclasses_json + from flytekit.core.context_manager import FlyteContextManager + ctx = FlyteContextManager.current_context() + dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc) else: try: decoder = self._msgpack_decoder[expected_python_type] @@ -922,6 +1010,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin): # Support legacy dataclasses_json.DataClassJsonMixin dc = expected_python_type.from_json(json_str) # type: ignore + # dataclasses_json doesn't use custom _deserialize methods for FlyteFile/FlyteDirectory, + # so we need to fix them up manually to ensure proper local paths and downloaders are set + dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc) else: # The function looks up or creates a JSONDecoder specifically designed for the object's type. # This decoder is then used to convert a JSON string into a data class. diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 8945ea46dd..22bec00cb1 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3821,8 +3821,7 @@ def test_strict_type_matching_error(): strict_type_hint_matching(xs, lt) -@pytest.mark.asyncio -async def test_dict_transformer_annotated_type(): +def test_dict_transformer_annotated_type(): ctx = FlyteContext.current_context() # Test case 1: Regular Dict type @@ -3831,7 +3830,7 @@ async def test_dict_transformer_annotated_type(): expected_type = TypeEngine.to_literal_type(regular_dict_type) # This should work fine - literal1 = await TypeEngine.async_to_literal(ctx, regular_dict, regular_dict_type, expected_type) + literal1 = TypeEngine.to_literal(ctx, regular_dict, regular_dict_type, expected_type) assert literal1.map.literals["a"].scalar.primitive.integer == 1 assert literal1.map.literals["b"].scalar.primitive.integer == 2 @@ -3840,7 +3839,7 @@ async def test_dict_transformer_annotated_type(): annotated_dict_type = Annotated[Dict[str, int], "some_metadata"] expected_type = TypeEngine.to_literal_type(annotated_dict_type) - literal2 = await TypeEngine.async_to_literal(ctx, annotated_dict, annotated_dict_type, expected_type) + literal2 = TypeEngine.to_literal(ctx, annotated_dict, annotated_dict_type, expected_type) assert literal2.map.literals["x"].scalar.primitive.integer == 10 assert literal2.map.literals["y"].scalar.primitive.integer == 20 @@ -3849,7 +3848,7 @@ async def test_dict_transformer_annotated_type(): nested_dict_type = Dict[str, Annotated[Dict[str, int], "inner_metadata"]] expected_type = TypeEngine.to_literal_type(nested_dict_type) - literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type) + literal3 = TypeEngine.to_literal(ctx, nested_dict, nested_dict_type, expected_type) assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42 @pytest.fixture(autouse=True) From ca64cc91a5cd4feef40b57c1d90041c7238853d8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 30 Jan 2026 03:00:59 -0800 Subject: [PATCH 5/8] test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 64 +++++++++++++++++-- .../unit/core/test_dataclass_guessing.py | 11 ++-- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index df6f4644e9..8b1924070c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -679,6 +679,13 @@ def to_generic_literal( # JSON serialization using mashumaro's DataClassJSONMixin if isinstance(python_val, DataClassJSONMixin): json_str = python_val.to_json() + elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + # We can't use mashumaro's encoder because it includes dataclass_json_config in the output + # We can't use to_json() directly because it doesn't properly serialize FlyteFile/FlyteDirectory + # So we manually convert to dict with proper handling + dict_obj = self._dataclass_to_dict(python_val) + json_str = json.dumps(dict_obj) else: try: encoder = self._json_encoder[python_type] @@ -719,6 +726,13 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp json_str = python_val.to_json() dict_obj = json.loads(json_str) msgpack_bytes = msgpack.dumps(dict_obj) + elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): + # Support legacy dataclasses_json.DataClassJsonMixin + # We can't use mashumaro's encoder because it includes dataclass_json_config in the output + # We can't use to_json() directly because it doesn't properly serialize FlyteFile/FlyteDirectory + # So we manually convert to dict with proper handling + dict_obj = self._dataclass_to_dict(python_val) + msgpack_bytes = msgpack.dumps(dict_obj) else: # The function looks up or creates a MessagePackEncoder specifically designed for the object's type. # This encoder is then used to convert a data class into MessagePack Bytes. @@ -841,6 +855,45 @@ def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t)) return python_val + def _dataclass_to_dict(self, python_val: typing.Any) -> typing.Dict[str, typing.Any]: + """ + Convert a dataclass to a dict, properly serializing FlyteFile/FlyteDirectory objects. + This is used for dataclasses_json types which don't have proper _serialize support. + """ + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + from flytekit.types.structured import StructuredDataset + + if python_val is None: + return None # type: ignore + + # Handle enums - convert to value + if isinstance(python_val, enum.Enum): + return python_val.value + + # Handle FlyteFile, FlyteDirectory, StructuredDataset by calling their _serialize method + if isinstance(python_val, (FlyteFile, FlyteDirectory, StructuredDataset)): + return python_val._serialize() + + # Handle lists + if isinstance(python_val, list): + return [self._dataclass_to_dict(item) for item in python_val] + + # Handle dicts + if isinstance(python_val, dict): + return {k: self._dataclass_to_dict(v) for k, v in python_val.items()} + + # Handle dataclasses + if dataclasses.is_dataclass(python_val) and not isinstance(python_val, type): + result = {} + for field in dataclasses.fields(python_val): + val = getattr(python_val, field.name) + result[field.name] = self._dataclass_to_dict(val) + return result + + # Return primitive values as-is + return python_val + def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val @@ -885,9 +938,7 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An return dc - def _fix_val_flyte_file_directory( - self, ctx: FlyteContext, python_type: typing.Type, val: typing.Any - ) -> typing.Any: + def _fix_val_flyte_file_directory(self, ctx: FlyteContext, python_type: typing.Type, val: typing.Any) -> typing.Any: """ Fix FlyteFile and FlyteDirectory fields that were deserialized via dataclasses_json. dataclasses_json doesn't use the custom _deserialize method, so we need to manually @@ -905,20 +956,22 @@ def _fix_val_flyte_file_directory( # Fix FlyteFile if path is remote - need to set up local path and downloader if ctx.file_access.is_remote(val.path): from flytekit.types.file.file import FlyteFilePathTransformer + transformer = FlyteFilePathTransformer() return transformer.dict_to_flyte_file( {"path": val.path, "metadata": getattr(val, "metadata", None)}, - type(val) if type(val) != FlyteFile else FlyteFile + type(val) if type(val) != FlyteFile else FlyteFile, ) return val elif isinstance(val, FlyteDirectory): # Fix FlyteDirectory if path is remote - need to set up local path and downloader if ctx.file_access.is_remote(val.path): from flytekit.types.directory.types import FlyteDirToMultipartBlobTransformer + transformer = FlyteDirToMultipartBlobTransformer() return transformer.dict_to_flyte_directory( {"path": val.path, "metadata": getattr(val, "metadata", None)}, - type(val) if type(val) != FlyteDirectory else FlyteDirectory + type(val) if type(val) != FlyteDirectory else FlyteDirectory, ) return val @@ -976,6 +1029,7 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[ dc = expected_python_type.from_json(json_str) # type: ignore # Fix up FlyteFile/FlyteDirectory fields that don't get proper deserialization with dataclasses_json from flytekit.core.context_manager import FlyteContextManager + ctx = FlyteContextManager.current_context() dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc) else: diff --git a/tests/flytekit/unit/core/test_dataclass_guessing.py b/tests/flytekit/unit/core/test_dataclass_guessing.py index e3face7342..80eafd1ba4 100644 --- a/tests/flytekit/unit/core/test_dataclass_guessing.py +++ b/tests/flytekit/unit/core/test_dataclass_guessing.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from flytekit.core.type_engine import TypeEngine, strict_type_hint_matching -from mashumaro.codecs.json import JSONDecoder +from mashumaro.codecs.json import JSONDecoder, JSONEncoder class JobConfig(BaseModel): @@ -51,7 +51,8 @@ def test_guessing_of_nested_pydantic(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config @@ -80,7 +81,8 @@ def test_nested_pydantic_reconstruction_from_raw_json(): input_config_dc_version = decoder.decode(existing_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version) assert reconstituted_pydantic == SchedulerConfig( input_storage_bucket="s3://input-storage-bucket", @@ -120,7 +122,8 @@ def test_guessing_of_nested_pydantic_mapped(): input_config_dc_version = decoder.decode(input_config_json) # recover the dataclass back into json, and then back into pydantic, and make sure it matches. - json_dc_version = input_config_dc_version.to_json() + encoder = JSONEncoder(guessed_type) + json_dc_version = encoder.encode(input_config_dc_version) reconstituted_pydantic = SchedulerConfigMapped.model_validate_json(json_dc_version) assert reconstituted_pydantic == input_config From 143029128abd47df1589df5d9ab56979a6487fda Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 30 Jan 2026 09:40:17 -0800 Subject: [PATCH 6/8] test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 46 ++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8b1924070c..6e88093e8f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -681,10 +681,17 @@ def to_generic_literal( json_str = python_val.to_json() elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): # Support legacy dataclasses_json.DataClassJsonMixin - # We can't use mashumaro's encoder because it includes dataclass_json_config in the output - # We can't use to_json() directly because it doesn't properly serialize FlyteFile/FlyteDirectory - # So we manually convert to dict with proper handling - dict_obj = self._dataclass_to_dict(python_val) + # Use mashumaro's encoder to handle serialization_strategy from mashumaro Config, + # then filter out dataclass_json_config which mashumaro includes from type hints + try: + encoder = self._json_encoder[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + self._json_encoder[python_type] = encoder + json_str = encoder.encode(python_val) + # Recursively remove dataclass_json_config from the output (including nested dataclasses) + dict_obj = json.loads(json_str) + dict_obj = self._remove_dataclass_json_config(dict_obj) json_str = json.dumps(dict_obj) else: try: @@ -728,10 +735,17 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp msgpack_bytes = msgpack.dumps(dict_obj) elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin): # Support legacy dataclasses_json.DataClassJsonMixin - # We can't use mashumaro's encoder because it includes dataclass_json_config in the output - # We can't use to_json() directly because it doesn't properly serialize FlyteFile/FlyteDirectory - # So we manually convert to dict with proper handling - dict_obj = self._dataclass_to_dict(python_val) + # Use mashumaro's encoder to handle serialization_strategy from mashumaro Config, + # then filter out dataclass_json_config which mashumaro includes from type hints + try: + encoder = self._msgpack_encoder[python_type] + except KeyError: + encoder = MessagePackEncoder(python_type) + self._msgpack_encoder[python_type] = encoder + msgpack_bytes = encoder.encode(python_val) + # Recursively remove dataclass_json_config from the output (including nested dataclasses) + dict_obj = msgpack.loads(msgpack_bytes, strict_map_key=False) + dict_obj = self._remove_dataclass_json_config(dict_obj) msgpack_bytes = msgpack.dumps(dict_obj) else: # The function looks up or creates a MessagePackEncoder specifically designed for the object's type. @@ -855,7 +869,19 @@ def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t)) return python_val - def _dataclass_to_dict(self, python_val: typing.Any) -> typing.Dict[str, typing.Any]: + def _remove_dataclass_json_config(self, dict_obj: typing.Any) -> typing.Any: + """ + Recursively remove dataclass_json_config from nested dicts. + This is needed because mashumaro includes it when serializing dataclasses_json types. + """ + if isinstance(dict_obj, dict): + result = {k: self._remove_dataclass_json_config(v) for k, v in dict_obj.items() if k != "dataclass_json_config"} + return result + elif isinstance(dict_obj, list): + return [self._remove_dataclass_json_config(item) for item in dict_obj] + return dict_obj + + def _dataclass_to_dict(self, python_val: typing.Any) -> typing.Any: """ Convert a dataclass to a dict, properly serializing FlyteFile/FlyteDirectory objects. This is used for dataclasses_json types which don't have proper _serialize support. @@ -1002,7 +1028,7 @@ def _fix_val_flyte_file_directory(self, ctx: FlyteContext, python_type: typing.T return val def _fix_dataclass_flyte_file_directory( - self, ctx: FlyteContext, dc_type: Type[dataclasses.dataclass], dc: typing.Any + self, ctx: FlyteContext, dc_type: type, dc: typing.Any ) -> typing.Any: """ Walk through dataclass fields and fix FlyteFile/FlyteDirectory values. From 4cfeb10451402c7b11dd257babfc0eddaf66d184 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 30 Jan 2026 10:16:15 -0800 Subject: [PATCH 7/8] nit Signed-off-by: Kevin Su --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e058f4b840..305c775b0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "keyring>=18.0.1", "markdown-it-py", "marshmallow-enum", - "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.15", "msgpack>=1.1.0", "protobuf!=4.25.0", From c20570f94623c25b6f5e6bcccea19fc594c6c68b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 30 Jan 2026 10:22:30 -0800 Subject: [PATCH 8/8] test Signed-off-by: Kevin Su --- dev-requirements.in | 1 + 1 file changed, 1 insertion(+) diff --git a/dev-requirements.in b/dev-requirements.in index cef6ce1929..47d9a1c34e 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -62,3 +62,4 @@ ipykernel orjson kubernetes>=12.0.1 httpx +marshmallow-jsonschema