Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ ipykernel
orjson
kubernetes>=12.0.1
httpx
marshmallow-jsonschema
216 changes: 185 additions & 31 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import msgpack
from cachetools import LRUCache
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from fsspec.asyn import _run_coros_in_chunks # pylint: disable=W0212
from google.protobuf import json_format as _json_format
Expand All @@ -36,6 +35,11 @@
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

try:
from dataclasses_json import DataClassJsonMixin
except ImportError:
DataClassJsonMixin = None # type: ignore

from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.constants import CACHE_KEY_METADATA, FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK, SERIALIZATION_FORMAT
from flytekit.core.context_manager import FlyteContext
Expand Down Expand Up @@ -510,7 +514,7 @@ def __init__(self) -> None:
self._json_encoder: Dict[Type, JSONEncoder] = dict()
self._json_decoder: Dict[Type, JSONDecoder] = dict()

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
def assert_type(self, expected_type, v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
expected_type = get_underlying_type(expected_type)
if type(v) == expected_type or issubclass(type(v), expected_type):
Expand Down Expand Up @@ -621,35 +625,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
except Exception as e:
logger.error(
f"Failed to extract schema for object {t}, error: {e}\n"
f"Please remove `DataClassJsonMixin` and `dataclass_json` decorator from the dataclass definition"
logger.warning(
f"Failed to extract schema for object {t}, (will run schemaless) error: {e}\n"
f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed "
f"evaluation doesn't work with json dataclasses"
)

if schema is None:
try:
# This produce JSON SCHEMA draft 2020-12
from marshmallow_enum import EnumField, LoadDumpOptions

if issubclass(t, DataClassJsonMixin):
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
# check if DataClass mixin
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
logger.warning(
f"Failed to extract schema for object {t}, (will run schemaless) error: {e}"
f"If you have postponed annotations turned on (PEP 563) turn it off please. Postponed"
f"evaluation doesn't work with json dataclasses"
)

# Recursively construct the dataclass_type which contains the literal type of each field
literal_type = {}

Expand Down Expand Up @@ -698,6 +679,20 @@ def to_generic_literal(
# JSON serialization using mashumaro's DataClassJSONMixin
if isinstance(python_val, DataClassJSONMixin):
json_str = python_val.to_json()
elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin):
# Support legacy dataclasses_json.DataClassJsonMixin
# Use mashumaro's encoder to handle serialization_strategy from mashumaro Config,
# then filter out dataclass_json_config which mashumaro includes from type hints
try:
encoder = self._json_encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._json_encoder[python_type] = encoder
json_str = encoder.encode(python_val)
# Recursively remove dataclass_json_config from the output (including nested dataclasses)
dict_obj = json.loads(json_str)
dict_obj = self._remove_dataclass_json_config(dict_obj)
json_str = json.dumps(dict_obj)
else:
try:
encoder = self._json_encoder[python_type]
Expand Down Expand Up @@ -738,6 +733,20 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
json_str = python_val.to_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
elif DataClassJsonMixin is not None and isinstance(python_val, DataClassJsonMixin):
# Support legacy dataclasses_json.DataClassJsonMixin
# Use mashumaro's encoder to handle serialization_strategy from mashumaro Config,
# then filter out dataclass_json_config which mashumaro includes from type hints
try:
encoder = self._msgpack_encoder[python_type]
except KeyError:
encoder = MessagePackEncoder(python_type)
self._msgpack_encoder[python_type] = encoder
msgpack_bytes = encoder.encode(python_val)
# Recursively remove dataclass_json_config from the output (including nested dataclasses)
dict_obj = msgpack.loads(msgpack_bytes, strict_map_key=False)
dict_obj = self._remove_dataclass_json_config(dict_obj)
msgpack_bytes = msgpack.dumps(dict_obj)
else:
# The function looks up or creates a MessagePackEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into MessagePack Bytes.
Expand Down Expand Up @@ -860,6 +869,57 @@ def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
object.__setattr__(python_val, n, self._make_dataclass_serializable(val, t))
return python_val

def _remove_dataclass_json_config(self, dict_obj: typing.Any) -> typing.Any:
"""
Recursively remove dataclass_json_config from nested dicts.
This is needed because mashumaro includes it when serializing dataclasses_json types.
"""
if isinstance(dict_obj, dict):
result = {k: self._remove_dataclass_json_config(v) for k, v in dict_obj.items() if k != "dataclass_json_config"}
return result
elif isinstance(dict_obj, list):
return [self._remove_dataclass_json_config(item) for item in dict_obj]
return dict_obj

def _dataclass_to_dict(self, python_val: typing.Any) -> typing.Any:
"""
Convert a dataclass to a dict, properly serializing FlyteFile/FlyteDirectory objects.
This is used for dataclasses_json types which don't have proper _serialize support.
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

if python_val is None:
return None # type: ignore

# Handle enums - convert to value
if isinstance(python_val, enum.Enum):
return python_val.value

# Handle FlyteFile, FlyteDirectory, StructuredDataset by calling their _serialize method
if isinstance(python_val, (FlyteFile, FlyteDirectory, StructuredDataset)):
return python_val._serialize()

# Handle lists
if isinstance(python_val, list):
return [self._dataclass_to_dict(item) for item in python_val]

# Handle dicts
if isinstance(python_val, dict):
return {k: self._dataclass_to_dict(v) for k, v in python_val.items()}

# Handle dataclasses
if dataclasses.is_dataclass(python_val) and not isinstance(python_val, type):
result = {}
for field in dataclasses.fields(python_val):
val = getattr(python_val, field.name)
result[field.name] = self._dataclass_to_dict(val)
return result

# Return primitive values as-is
return python_val

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if val is None:
return val
Expand Down Expand Up @@ -904,12 +964,100 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An

return dc

    def _fix_val_flyte_file_directory(self, ctx: FlyteContext, python_type: typing.Type, val: typing.Any) -> typing.Any:
        """
        Fix FlyteFile and FlyteDirectory fields that were deserialized via dataclasses_json.
        dataclasses_json doesn't use the custom _deserialize method, so we need to manually
        set up the proper local path and downloader for remote files.

        :param ctx: Active FlyteContext; its ``file_access`` decides whether a path is remote.
        :param python_type: Declared type of ``val`` (may be Annotated, Optional/Union,
            list, dict, or a nested dataclass type).
        :param val: The deserialized value to repair; returned unchanged when no fix applies.
        :return: ``val`` itself, or a rebuilt FlyteFile/FlyteDirectory wired for lazy download.
        """
        from flytekit.types.directory import FlyteDirectory
        from flytekit.types.file import FlyteFile

        if val is None:
            return val

        # Check for FlyteFile/FlyteDirectory FIRST, before the generic dataclass check
        # because FlyteFile and FlyteDirectory are also dataclasses
        if isinstance(val, FlyteFile):
            # Fix FlyteFile if path is remote - need to set up local path and downloader
            if ctx.file_access.is_remote(val.path):
                from flytekit.types.file.file import FlyteFilePathTransformer

                transformer = FlyteFilePathTransformer()
                # Round-trip through the transformer so the result carries a proper
                # local path/downloader. The second argument preserves FlyteFile
                # subclasses (when type(val) is exactly FlyteFile this collapses to
                # FlyteFile, which is the same value — kept for explicitness).
                return transformer.dict_to_flyte_file(
                    {"path": val.path, "metadata": getattr(val, "metadata", None)},
                    type(val) if type(val) != FlyteFile else FlyteFile,
                )
            return val
        elif isinstance(val, FlyteDirectory):
            # Fix FlyteDirectory if path is remote - need to set up local path and downloader
            if ctx.file_access.is_remote(val.path):
                from flytekit.types.directory.types import FlyteDirToMultipartBlobTransformer

                transformer = FlyteDirToMultipartBlobTransformer()
                # Same round-trip as the FlyteFile branch, via the directory transformer.
                return transformer.dict_to_flyte_directory(
                    {"path": val.path, "metadata": getattr(val, "metadata", None)},
                    type(val) if type(val) != FlyteDirectory else FlyteDirectory,
                )
            return val

        # Unwrap Annotated[...] before inspecting the generic structure of the type.
        python_type = get_underlying_type(python_type)
        origin = get_origin(python_type)
        args = get_args(python_type)

        if origin is list and args:
            # Recurse element-wise using the declared element type.
            return [self._fix_val_flyte_file_directory(ctx, args[0], item) for item in val]
        elif origin is dict and args and len(args) == 2:
            # Recurse over values using the declared value type; keys pass through unchanged.
            return {k: self._fix_val_flyte_file_directory(ctx, args[1], v) for k, v in val.items()}
        elif origin is typing.Union:
            # Handle Optional and Union types
            # Try each non-None member in declaration order; the first recursion that
            # does not raise wins. If every member fails, return val untouched.
            for arg in args:
                if arg is type(None):
                    continue
                try:
                    return self._fix_val_flyte_file_directory(ctx, arg, val)
                except Exception:
                    continue
            return val
        elif dataclasses.is_dataclass(python_type) and not isinstance(python_type, type):
            # Already an instance, skip
            # (a dataclass *instance* was passed where a type was expected — nothing to recurse into)
            return val
        elif dataclasses.is_dataclass(python_type):
            # Nested dataclass type: fix each of its fields in turn.
            return self._fix_dataclass_flyte_file_directory(ctx, python_type, val)

        return val

def _fix_dataclass_flyte_file_directory(
self, ctx: FlyteContext, dc_type: type, dc: typing.Any
) -> typing.Any:
"""
Walk through dataclass fields and fix FlyteFile/FlyteDirectory values.
"""
hints = typing.get_type_hints(dc_type)
for f in dataclasses.fields(dc_type):
python_type = hints.get(f.name, f.type)
val = getattr(dc, f.name)
fixed_val = self._fix_val_flyte_file_directory(ctx, python_type, val)
if fixed_val is not val:
object.__setattr__(dc, f.name, fixed_val)
return dc

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
if binary_idl_object.tag == MESSAGEPACK:
if issubclass(expected_python_type, DataClassJSONMixin):
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
json_str = json.dumps(dict_obj)
dc = expected_python_type.from_json(json_str) # type: ignore
elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin):
# Support legacy dataclasses_json.DataClassJsonMixin
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
json_str = json.dumps(dict_obj)
dc = expected_python_type.from_json(json_str) # type: ignore
# Fix up FlyteFile/FlyteDirectory fields that don't get proper deserialization with dataclasses_json
from flytekit.core.context_manager import FlyteContextManager

ctx = FlyteContextManager.current_context()
dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc)
else:
try:
decoder = self._msgpack_decoder[expected_python_type]
Expand Down Expand Up @@ -939,6 +1087,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types.
if issubclass(expected_python_type, DataClassJSONMixin):
dc = expected_python_type.from_json(json_str) # type: ignore
elif DataClassJsonMixin is not None and issubclass(expected_python_type, DataClassJsonMixin):
# Support legacy dataclasses_json.DataClassJsonMixin
dc = expected_python_type.from_json(json_str) # type: ignore
# dataclasses_json doesn't use custom _deserialize methods for FlyteFile/FlyteDirectory,
# so we need to fix them up manually to ensure proper local paths and downloaders are set
dc = self._fix_dataclass_flyte_file_directory(ctx, expected_python_type, dc)
else:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
Expand Down Expand Up @@ -2514,7 +2668,7 @@ def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: t
"""

attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name)
return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))
return dataclasses.make_dataclass(schema_name, attribute_list)


def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type:
Expand All @@ -2525,7 +2679,7 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
"""

attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))
return dataclasses.make_dataclass(schema_name, attribute_list)


def _get_element_type(
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"click>=6.6",
"cloudpickle>=2.0.0",
"croniter>=0.3.20",
"dataclasses-json>=0.5.2,<0.5.12", # TODO: remove upper-bound after fixing change in contract
"dataclasses-json>=0.5.2",
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
Expand All @@ -36,7 +36,6 @@ dependencies = [
"keyring>=18.0.1",
"markdown-it-py",
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.15",
"msgpack>=1.1.0",
"protobuf!=4.25.0",
Expand Down
11 changes: 7 additions & 4 deletions tests/flytekit/unit/core/test_dataclass_guessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel

from flytekit.core.type_engine import TypeEngine, strict_type_hint_matching
from mashumaro.codecs.json import JSONDecoder
from mashumaro.codecs.json import JSONDecoder, JSONEncoder


class JobConfig(BaseModel):
Expand Down Expand Up @@ -51,7 +51,8 @@ def test_guessing_of_nested_pydantic():
input_config_dc_version = decoder.decode(input_config_json)

# recover the dataclass back into json, and then back into pydantic, and make sure it matches.
json_dc_version = input_config_dc_version.to_json()
encoder = JSONEncoder(guessed_type)
json_dc_version = encoder.encode(input_config_dc_version)
reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version)
assert reconstituted_pydantic == input_config

Expand Down Expand Up @@ -80,7 +81,8 @@ def test_nested_pydantic_reconstruction_from_raw_json():
input_config_dc_version = decoder.decode(existing_json)

# recover the dataclass back into json, and then back into pydantic, and make sure it matches.
json_dc_version = input_config_dc_version.to_json()
encoder = JSONEncoder(guessed_type)
json_dc_version = encoder.encode(input_config_dc_version)
reconstituted_pydantic = SchedulerConfig.model_validate_json(json_dc_version)
assert reconstituted_pydantic == SchedulerConfig(
input_storage_bucket="s3://input-storage-bucket",
Expand Down Expand Up @@ -120,7 +122,8 @@ def test_guessing_of_nested_pydantic_mapped():
input_config_dc_version = decoder.decode(input_config_json)

# recover the dataclass back into json, and then back into pydantic, and make sure it matches.
json_dc_version = input_config_dc_version.to_json()
encoder = JSONEncoder(guessed_type)
json_dc_version = encoder.encode(input_config_dc_version)
reconstituted_pydantic = SchedulerConfigMapped.model_validate_json(json_dc_version)
assert reconstituted_pydantic == input_config

Expand Down
9 changes: 4 additions & 5 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3821,8 +3821,7 @@ def test_strict_type_matching_error():
strict_type_hint_matching(xs, lt)


@pytest.mark.asyncio
async def test_dict_transformer_annotated_type():
def test_dict_transformer_annotated_type():
ctx = FlyteContext.current_context()

# Test case 1: Regular Dict type
Expand All @@ -3831,7 +3830,7 @@ async def test_dict_transformer_annotated_type():
expected_type = TypeEngine.to_literal_type(regular_dict_type)

# This should work fine
literal1 = await TypeEngine.async_to_literal(ctx, regular_dict, regular_dict_type, expected_type)
literal1 = TypeEngine.to_literal(ctx, regular_dict, regular_dict_type, expected_type)
assert literal1.map.literals["a"].scalar.primitive.integer == 1
assert literal1.map.literals["b"].scalar.primitive.integer == 2

Expand All @@ -3840,7 +3839,7 @@ async def test_dict_transformer_annotated_type():
annotated_dict_type = Annotated[Dict[str, int], "some_metadata"]
expected_type = TypeEngine.to_literal_type(annotated_dict_type)

literal2 = await TypeEngine.async_to_literal(ctx, annotated_dict, annotated_dict_type, expected_type)
literal2 = TypeEngine.to_literal(ctx, annotated_dict, annotated_dict_type, expected_type)
assert literal2.map.literals["x"].scalar.primitive.integer == 10
assert literal2.map.literals["y"].scalar.primitive.integer == 20

Expand All @@ -3849,7 +3848,7 @@ async def test_dict_transformer_annotated_type():
nested_dict_type = Dict[str, Annotated[Dict[str, int], "inner_metadata"]]
expected_type = TypeEngine.to_literal_type(nested_dict_type)

literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
literal3 = TypeEngine.to_literal(ctx, nested_dict, nested_dict_type, expected_type)
assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42

@pytest.fixture(autouse=True)
Expand Down
Loading