diff --git a/AUTHORS.rst b/AUTHORS.rst index 2752f1c..1eac9ad 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -30,3 +30,4 @@ Contributors (chronological) - Robert Sawicki `@ww3pl `_ - `@aberres `_ - George Alton `@georgealton `_ +- Adrian Vandier Ast `@AdrianVandierAst `_ diff --git a/marshmallow_jsonapi/fields.py b/marshmallow_jsonapi/fields.py index 985ee54..7bbd064 100644 --- a/marshmallow_jsonapi/fields.py +++ b/marshmallow_jsonapi/fields.py @@ -201,7 +201,7 @@ def extract_value(self, data): # fall back below to old behaviour of only IDs. if "attributes" in data and self.__schema: result = self.schema.load( - {"data": data, "included": self.root.included_data} + {"data": data, "included": self.root.included_data.values()} ) return result.data if _MARSHMALLOW_VERSION_INFO[0] < 3 else result diff --git a/marshmallow_jsonapi/schema.py b/marshmallow_jsonapi/schema.py index b8de86a..4955f29 100644 --- a/marshmallow_jsonapi/schema.py +++ b/marshmallow_jsonapi/schema.py @@ -166,19 +166,18 @@ def unwrap_item(self, item): # Fold included data related to this relationship into the item, so # that we can deserialize the whole objects instead of just IDs. if self.included_data: - included_data = [] + included_data = None inner_data = value.get("data", []) # Data may be ``None`` (for empty relationships), but we only # need to process it when it's present. if inner_data: if not is_collection(inner_data): - included_data = next( - self._extract_from_included(inner_data), None - ) + included_data = self._extract_from_included(inner_data) else: + included_data = [] for data in inner_data: - included_data.extend(self._extract_from_included(data)) + included_data.append(self._extract_from_included(data)) if included_data: value["data"] = included_data @@ -235,7 +234,7 @@ def _do_load(self, data, many=None, **kwargs): # Store this on the instance so we have access to the included data # when processing relationships (``included`` is outside of the # ``data``). - self.included_data = data.get("included", {}) + self.included_data = self._load_included_data(data.get("included", [])) self.document_meta = data.get("meta", {}) try: @@ -257,16 +256,28 @@ def _do_load(self, data, many=None, **kwargs): return data, formatted_messages return result - def _extract_from_included(self, data): - """Extract included data matching the items in ``data``. + def _load_included_data(self, included): + """ Transform a list of resource object into a dict indexed by object type and id. + """ + included_data = {} + for item in included: + if "type" not in item.keys() or "id" not in item.keys(): + raise ma.ValidationError( + [ + { + "detail": "`included` objects must include `type` and `id` keys.", + "source": {"pointer": "/included"}, + } + ] + ) + included_data[(item["type"], item["id"])] = item + return included_data - For each item in ``data``, extract the full data from the included - data. + def _extract_from_included(self, data): + """Extract included data matching the item in ``data``. """ - return ( - item - for item in self.included_data - if item["type"] == data["type"] and str(item["id"]) == str(data["id"]) + return self.included_data.get( + (data["type"], data["id"]), {"type": data["type"], "id": data["id"]} ) def inflect(self, text): diff --git a/tests/base.py b/tests/base.py index d722613..078988c 100644 --- a/tests/base.py +++ b/tests/base.py @@ -92,6 +92,7 @@ class CommentSchema(Schema): related_url_kwargs={"id": ""}, schema=AuthorSchema, many=False, + type_="people", ) class Meta: diff --git a/tests/test_schema.py b/tests/test_schema.py index 4ca16ea..226e372 100755 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -4,13 +4,14 @@ from marshmallow_jsonapi import Schema, fields from marshmallow_jsonapi.exceptions import IncorrectTypeError from marshmallow_jsonapi.utils import _MARSHMALLOW_VERSION_INFO -from tests.base import unpack +from tests.base import unpack, fake from tests.base import ( AuthorSchema, CommentSchema, PostSchema, PolygonSchema, ArticleSchema, + Comment, ) @@ -372,6 +373,31 @@ class Meta(PostSchema.Meta): assert "from_context" in included["attributes"] assert included["attributes"]["from_context"] == "Hello World" + def test_load_n_dump_same_schema(self): + json_data = { + "data": { + "type": "comments", + "id": "1", + "attributes": {"body": fake.bs()}, + "relationships": {"author": {"data": {"type": "people", "id": "1"}}}, + }, + "included": [ + { + "type": "people", + "id": "1", + "attributes": { + "first_name": fake.first_name(), + "last_name": fake.last_name(), + }, + } + ], + } + schema = CommentSchema() + data = unpack(schema.load(json_data)) + comment = Comment(**data) + out_json_data = unpack(schema.dump(comment)) + assert json_data["included"] == out_json_data["included"] + def get_error_by_field(errors, field): for err in errors["errors"]: