Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List, Union, cast
from typing import Dict, List, Optional, Union, cast

from forestadmin.agent_toolkit.utils.context import User
from forestadmin.datasource_toolkit.decorators.collection_decorator import CollectionDecorator
from forestadmin.datasource_toolkit.interfaces.fields import ManyToOne, is_many_to_one
from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
Expand All @@ -17,7 +18,7 @@ async def list(self, caller: User, filter_: PaginatedFilter, projection: Project
refined_filter = cast(PaginatedFilter, await self._refine_filter(caller, filter_))
ret = await self.child_collection.list(caller, refined_filter, simplified_projection)

return self._apply_joins_on_records(projection, simplified_projection, ret)
return self._apply_joins_on_simplified_records(projection, simplified_projection, ret)

async def _refine_filter(
self, caller: User, _filter: Union[Filter, PaginatedFilter, None]
Expand All @@ -28,18 +29,39 @@ async def _refine_filter(
_filter.condition_tree = _filter.condition_tree.replace(
lambda leaf: (
ConditionTreeLeaf(
self._get_fk_field_for_projection(leaf.field),
self._get_fk_field_for_many_to_one_projection(leaf.field),
leaf.operator,
leaf.value,
)
if self._is_useless_join(leaf.field.split(":")[0], _filter.condition_tree.projection)
if self._is_useless_join_for_projection(leaf.field.split(":")[0], _filter.condition_tree.projection)
else leaf
)
)

return _filter

def _is_useless_join(self, relation: str, projection: Projection) -> bool:
async def aggregate(
self, caller: User, filter_: Union[Filter, None], aggregation: Aggregation, limit: Optional[int] = None
) -> List[AggregateResult]:
replaced = {} # new_name -> old_name; for a simpler reconciliation

def replacer(field_name: str) -> str:
if self._is_useless_join_for_projection(field_name.split(":")[0], aggregation.projection):
new_field_name = self._get_fk_field_for_many_to_one_projection(field_name)
replaced[new_field_name] = field_name
return new_field_name
return field_name

new_aggregation = aggregation.replace_fields(replacer)

aggregate_results = await self.child_collection.aggregate(
caller, cast(Filter, await self._refine_filter(caller, filter_)), new_aggregation, limit
)
if aggregation == new_aggregation:
return aggregate_results
return self._replace_fields_in_aggregate_group(aggregate_results, replaced)

def _is_useless_join_for_projection(self, relation: str, projection: Projection) -> bool:
relation_schema = self.schema["fields"][relation]
sub_projections = projection.relations[relation]

Expand All @@ -49,7 +71,7 @@ def _is_useless_join(self, relation: str, projection: Projection) -> bool:
and sub_projections[0] == relation_schema["foreign_key_target"]
)

def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_fk_field_for_many_to_one_projection(self, projection: str) -> str:
relation_name = projection.split(":")[0]
relation_schema = cast(ManyToOne, self.schema["fields"][relation_name])

Expand All @@ -58,18 +80,18 @@ def _get_fk_field_for_projection(self, projection: str) -> str:
def _get_projection_without_useless_joins(self, projection: Projection) -> Projection:
returned_projection = Projection(*projection)
for relation, relation_projections in projection.relations.items():
if self._is_useless_join(relation, projection):
if self._is_useless_join_for_projection(relation, projection):
# remove foreign key target from projection
returned_projection.remove(f"{relation}:{relation_projections[0]}")

# add foreign keys to projection
fk_field = self._get_fk_field_for_projection(relation)
fk_field = self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
if fk_field not in returned_projection:
returned_projection.append(fk_field)

return returned_projection

def _apply_joins_on_records(
def _apply_joins_on_simplified_records(
self, initial_projection: Projection, requested_projection: Projection, records: List[RecordsDataAlias]
) -> List[RecordsDataAlias]:
if requested_projection == initial_projection:
Expand All @@ -84,11 +106,27 @@ def _apply_joins_on_records(
relation_schema = self.schema["fields"][relation]

if is_many_to_one(relation_schema):
fk_value = record[self._get_fk_field_for_projection(f"{relation}:{relation_projections[0]}")]
fk_value = record[
self._get_fk_field_for_many_to_one_projection(f"{relation}:{relation_projections[0]}")
]
record[relation] = {relation_projections[0]: fk_value} if fk_value else None

# remove foreign keys
for projection in projections_to_rm:
del record[projection]

return records

def _replace_fields_in_aggregate_group(
self, aggregate_results: List[AggregateResult], field_to_replace: Dict[str, str]
) -> List[AggregateResult]:
for aggregate_result in aggregate_results:
group = {}
for field, value in aggregate_result["group"].items():
if field in field_to_replace:
group[field_to_replace[field]] = value
else:
group[field] = value
aggregate_result["group"] = group

return aggregate_results
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from forestadmin.datasource_toolkit.decorators.datasource_decorator import DatasourceDecorator
from forestadmin.datasource_toolkit.decorators.lazy_join.collection import LazyJoinCollectionDecorator
from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, ManyToOne, OneToMany, PrimitiveType
from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
from forestadmin.datasource_toolkit.interfaces.query.projections import Projection


Expand Down Expand Up @@ -53,6 +55,10 @@ def setUpClass(cls) -> None:
column_type=PrimitiveType.STRING,
type=FieldType.COLUMN,
),
"price": Column(
column_type=PrimitiveType.NUMBER,
type=FieldType.COLUMN,
),
}
)
cls.collection_person = Collection("Person", cls.datasource)
Expand Down Expand Up @@ -226,7 +232,7 @@ def test_should_disable_join_on_projection_but_not_in_condition_tree(self):
response,
)

def test_should_correctly_handle_null_relations(self):
def test_should_correctly_handle_null_relations_on_list(self):
with patch.object(
self.collection_book,
"list",
Expand All @@ -252,3 +258,131 @@ def test_should_correctly_handle_null_relations(self):
],
result,
)

def test_should_not_join_on_aggregate_when_group_by_foreign_pk(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)

def test_should_join_on_aggregate_when_group_by_foreign_field(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
{"value": 824.11, "group": {"author:first_name": "JK"}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:first_name": "Isaac"}},
{"value": 824.11, "group": {"author:first_name": "JK"}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:first_name"}]}),
None,
)

def test_should_not_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_pk(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:id", "not_equal", 50)}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author_id", "not_equal", 50)}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)

def test_should_join_on_aggregate_when_group_by_foreign_pk_and_filter_on_foreign_field(self):
with patch.object(
self.collection_book,
"aggregate",
new_callable=AsyncMock,
return_value=[
{"value": 1824.11, "group": {"author_id": 2}},
{"value": 824.11, "group": {"author_id": 3}},
],
) as mock_aggregate:
result = self.loop.run_until_complete(
self.decorated_book_collection.aggregate(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author:id"}]}),
None,
)
)
self.assertEqual(
result,
[
{"value": 1824.11, "group": {"author:id": 2}},
{"value": 824.11, "group": {"author:id": 3}},
],
)
mock_aggregate.assert_awaited_once_with(
self.mocked_caller,
Filter({"condition_tree": ConditionTreeLeaf("author:first_name", "not_equal", "wrong_name")}),
Aggregation({"field": "price", "operation": "Sum", "groups": [{"field": "author_id"}]}),
None,
)
Loading