From d6e15447a2ae5994d6470188addc107ba98a3b56 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 16:25:26 +0100 Subject: [PATCH 1/6] chore: add aggregate to lazy join --- .../decorators/lazy_join/collection.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py index 42612c2b9..db4d3c5c9 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py @@ -3,6 +3,7 @@ from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator from forestadmin.datasource_toolkit.interfaces.fields import ManyToOne, is_many_to_one +from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter @@ -39,6 +40,38 @@ async def _refine_filter( return _filter + async def aggregate( + self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None + ) -> List[AggregateResult]: + replaced = {} + + def replacer(field_name: str) -> str: + if self._is_useless_join(field_name.split(":")[0], aggregation.projection): + new_field_name = self._get_fk_field_for_projection(field_name) + replaced[new_field_name] = field_name + return new_field_name + else: + return field_name + + new_aggregation = aggregation.replace_fields(replacer) + + aggregate_result = await self.child_collection.aggregate( + caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit + ) + if aggregation == new_aggregation: + return aggregate_result + + for result in aggregate_result: + group = {} + for field, value in result["group"].items(): + if field in replaced: + group[replaced[field]] = value + else: + group[field] = value + result["group"] = group + + return aggregate_result + def _is_useless_join(self, relation: str, projection: Projection) -> bool: relation_schema = self.schema["fields"][relation] sub_projections = projection.relations[relation] @@ -63,7 +96,7 @@ def _get_projection_without_useless_joins(self, projection: Projection) -> Proje returned_projection.remove(f"{relation}:{relation_projections[0]}") # add foreign keys to projection - fk_field = self._get_fk_field_for_projection(relation) + fk_field = self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}") if fk_field not in returned_projection: returned_projection.append(fk_field) From d51cee8f77c6f0e489cb785989a3e9c40f57a03c Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 16:53:20 +0100 Subject: [PATCH 2/6] chore: add few tests --- .../lazy_join/test_lazy_join_decorator.py | 139 +++++++++++++++++- 1 file changed, 137 insertions(+), 2 deletions(-) diff --git a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py index 7e30e1b39..675edb00a 100644 --- a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py +++ b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py @@ -1,7 +1,10 @@ import asyncio import sys from unittest import TestCase -from unittest.mock import AsyncMock, patch +from unittest.mock import ANY, AsyncMock, patch + +from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter if sys.version_info >= (3, 9): import zoneinfo @@ -53,6 +56,10 @@ def setUpClass(cls) -> None: column_type=PrimitiveType.STRING, type=FieldType.COLUMN, ), + "price": Column( + column_type=PrimitiveType.NUMBER, + type=FieldType.COLUMN, + ), } ) cls.collection_person = Collection("Person", cls.datasource) @@ -226,7 +233,7 @@ def test_should_disable_join_on_projection_but_not_in_condition_tree(self): response, ) - def test_should_correctly_handle_null_relations(self): + def test_should_correctly_handle_null_relations_on_list(self): with patch.object( self.collection_book, "list", @@ -252,3 +259,131 @@ def test_should_correctly_handle_null_relations(self): ], result, ) + + def test_should_not_join_on_aggregate_when_group_by_foreign_pk(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + ) + + def test_should_join_on_aggregate_when_group_by_foreign_field(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author:first_name": "Isaac"}}, + {"value": 824.11, "group": {"author:first_name": "JK"}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:first_name": "Isaac"}}, + {"value": 824.11, "group": {"author:first_name": "JK"}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}), + None, + ) + + def test_should_not_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_pk(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:id", "not_equal", 50)}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author_id", "not_equal", 50)}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + ) + + def test_should_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_field(self): + with patch.object( + self.collection_book, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 1824.11, "group": {"author_id": 2}}, + {"value": 824.11, "group": {"author_id": 3}}, + ], + ) as mock_aggregate: + result = self.loop.run_until_complete( + self.decorated_book_collection.aggregate( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}), + None, + ) + ) + self.assertEqual( + result, + [ + {"value": 1824.11, "group": {"author:id": 2}}, + {"value": 824.11, "group": {"author:id": 3}}, + ], + ) + mock_aggregate.assert_awaited_once_with( + self.mocked_caller, + Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}), + Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}), + None, + ) From 78719d66efc881efe4cc49a1521a3435a68eb827 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 17:25:54 +0100 Subject: [PATCH 3/6] chore: fix linting --- .../tests/decorators/lazy_join/test_lazy_join_decorator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py index 675edb00a..b50cc7b73 100644 --- a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py +++ b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py @@ -1,7 +1,7 @@ import asyncio import sys from unittest import TestCase -from unittest.mock import ANY, AsyncMock, patch +from unittest.mock import AsyncMock, patch from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter From 84634e8e18ef076cffdc9e8ffaad6b3e5e9acf69 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 17:31:32 +0100 Subject: [PATCH 4/6] chore: fix syntax for python 3.9 --- .../datasource_toolkit/decorators/lazy_join/collection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py index db4d3c5c9..1b78f06bb 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py @@ -1,4 +1,4 @@ -from typing import List, Union, cast +from typing import List, Optional, Union, cast from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator @@ -41,7 +41,7 @@ async def _refine_filter( return _filter async def aggregate( - self, caller: User, filter_: Filter | None, aggregation: Aggregation, limit: int | None = None + self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None ) -> List[AggregateResult]: replaced = {} From 34e72f9563cd13c41e51291203066fc42c570abb Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 18 Dec 2024 11:10:37 +0100 Subject: [PATCH 5/6] chore: refactor a bit --- .../decorators/lazy_join/collection.py | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py index 1b78f06bb..59d956df2 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/lazy_join/collection.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, cast +from typing import Dict, List, Optional, Union, cast from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator @@ -18,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_)) ret = await self.child_collection.list(caller, refined_filter, simplified_projection) - return self._apply_joins_on_records(projection, simplified_projection, ret) + return self._apply_joins_on_simplified_records(projection, simplified_projection, ret) async def _refine_filter( self, caller: User, _filter: Union[Filter, PaginatedFilter, None] @@ -29,11 +29,11 @@ async def _refine_filter( _filter.condition_tree = _filter.condition_tree.replace( lambda leaf: ( ConditionTreeLeaf( - self._get_fk_field_for_projection(leaf.field), + self._get_fk_field_for_many_to_one_projection(leaf.field), leaf.operator, leaf.value, ) - if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection) + if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection) else leaf ) ) @@ -43,36 +43,25 @@ async def _refine_filter( async def aggregate( self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None ) -> List[AggregateResult]: - replaced = {} + replaced = {} # new_name -> old_name; for a simpler reconciliation def replacer(field_name: str) -> str: - if self._is_useless_join(field_name.split(":")[0], aggregation.projection): - new_field_name = self._get_fk_field_for_projection(field_name) + if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection): + new_field_name = self._get_fk_field_for_many_to_one_projection(field_name) replaced[new_field_name] = field_name return new_field_name - else: - return field_name + return field_name new_aggregation = aggregation.replace_fields(replacer) - aggregate_result = await self.child_collection.aggregate( + aggregate_results = await self.child_collection.aggregate( caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit ) if aggregation == new_aggregation: - return aggregate_result + return aggregate_results + return self._replace_fields_in_aggregate_group(aggregate_results, replaced) - for result in aggregate_result: - group = {} - for field, value in result["group"].items(): - if field in replaced: - group[replaced[field]] = value - else: - group[field] = value - result["group"] = group - - return aggregate_result - - def _is_useless_join(self, relation: str, projection: Projection) -> bool: + def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool: relation_schema = self.schema["fields"][relation] sub_projections = projection.relations[relation] @@ -82,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool: and sub_projections[0] == relation_schema["foreign_key_target"] ) - def _get_fk_field_for_projection(self, projection: str) -> str: + def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str: relation_name = projection.split(":")[0] relation_schema = cast(ManyToOne, self.schema["fields"][relation_name]) @@ -91,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str: def _get_projection_without_useless_joins(self, projection: Projection) -> Projection: returned_projection = Projection(*projection) for relation, relation_projections in projection.relations.items(): - if self._is_useless_join(relation, projection): + if self._is_useless_join_for_projection(relation, projection): # remove foreign key target from projection returned_projection.remove(f"{relation}:{relation_projections[0]}") # add foreign keys to projection - fk_field = self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}") + fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}") if fk_field not in returned_projection: returned_projection.append(fk_field) return returned_projection - def _apply_joins_on_records( + def _apply_joins_on_simplified_records( self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias] ) -> List[RecordsDataAlias]: if requested_projection == initial_projection: @@ -117,7 +106,9 @@ def _apply_joins_on_records( relation_schema = self.schema["fields"][relation] if is_many_to_one(relation_schema): - fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")] + fk_value = record[ + self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}") + ] record[relation] = {relation_projections[0]: fk_value} if fk_value else None # remove foreign keys @@ -125,3 +116,17 @@ def _apply_joins_on_records( del record[projection] return records + + def _replace_fields_in_aggregate_group( + self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str] + ) -> List[AggregateResult]: + for aggregate_result in aggregate_results: + group = {} + for field, value in aggregate_result["group"].items(): + if field in field_to_replace: + group[field_to_replace[field]] = value + else: + group[field] = value + aggregate_result["group"] = group + + return aggregate_results From 8f63700cc686c58c7641d79446046715ec7caffb Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 18 Dec 2024 11:22:06 +0100 Subject: [PATCH 6/6] chore: change import order --- .../tests/decorators/lazy_join/test_lazy_join_decorator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py index b50cc7b73..647101048 100644 --- a/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py +++ b/src/datasource_toolkit/tests/decorators/lazy_join/test_lazy_join_decorator.py @@ -3,9 +3,6 @@ from unittest import TestCase from unittest.mock import AsyncMock, patch -from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation -from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter - if sys.version_info >= (3, 9): import zoneinfo else: @@ -17,8 +14,10 @@ from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator from forestadmin.datasource_toolkit.decorators.lazy_join.collection import LazyJoinCollectionDecorator from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, ManyToOne, OneToMany, PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter +from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter from forestadmin.datasource_toolkit.interfaces.query.projections import Projection