diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/action/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/action/collections.py index edd6b3f9b..222d2745f 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/action/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/action/collections.py @@ -160,7 +160,7 @@ def _set_watch_changes_attr( self._set_watch_changes_attr(element["elements"], context) # type:ignore def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: - actions_schema = {} + actions_schema = {**sub_schema["actions"]} for name, action in self._actions.items(): dynamics: List[bool] = [] for field in action.get("form", []): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/chart_collection_decorator.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/chart_collection_decorator.py index 7c9886e0c..9a5862eff 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/chart_collection_decorator.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/chart_collection_decorator.py @@ -37,7 +37,7 @@ async def render_chart(self, caller: User, name: str, record_id: List) -> Chart: return await self.child_collection.render_chart(caller, name, record_id) def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: - charts = {} + charts = {**sub_schema["charts"]} for name, chart in self._charts.items(): charts[name] = chart diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/segments/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/segments/collections.py index 631d82e05..e5fe27a7a 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/segments/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/segments/collections.py @@ -24,7 +24,7 @@ def add_segment(self, name: str, segment: SegmentAlias): self.mark_schema_as_dirty() def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: - return {**sub_schema, "segments": [*self._segments.keys()]} + return {**sub_schema, "segments": [*sub_schema["segments"], *self._segments.keys()]} async def _refine_filter( self, caller: User, _filter: Union[Optional[PaginatedFilter], Optional[Filter]] diff --git a/src/datasource_toolkit/tests/decorators/action/test_action_decorator.py b/src/datasource_toolkit/tests/decorators/action/test_action_decorator.py index ac53d486b..141ce0a74 100644 --- a/src/datasource_toolkit/tests/decorators/action/test_action_decorator.py +++ b/src/datasource_toolkit/tests/decorators/action/test_action_decorator.py @@ -859,3 +859,10 @@ def test_get_form_should_return_hidden_fields_when_asked(self): ], ) if_fn.assert_not_called() + + def test_should_schema_should_contains_actions_define_in_custom_datasource(self): + with patch.dict(self.collection_product.schema, {"actions": {"action_test": {}}}): + self.assertIn( + "action_test", + self.datasource_decorator.get_collection("Product").schema["actions"], + ) diff --git a/src/datasource_toolkit/tests/decorators/chart/test_chart_collection_decorator.py b/src/datasource_toolkit/tests/decorators/chart/test_chart_collection_decorator.py index 50b26ca43..81564b625 100644 --- a/src/datasource_toolkit/tests/decorators/chart/test_chart_collection_decorator.py +++ b/src/datasource_toolkit/tests/decorators/chart/test_chart_collection_decorator.py @@ -24,8 +24,8 @@ def setUpClass(cls) -> None: cls.datasource: Datasource = Datasource() Collection.__abstractmethods__ = set() # to instantiate abstract class - cls.collection_book = Collection("Product", cls.datasource) - cls.collection_book.add_fields( + cls.collection_product = Collection("Product", cls.datasource) + cls.collection_product.add_fields( { "id": { "column_type": PrimitiveType.NUMBER, @@ -35,7 +35,7 @@ def setUpClass(cls) -> None: }, } ) - cls.datasource.add_collection(cls.collection_book) + cls.datasource.add_collection(cls.collection_product) cls.mocked_caller = User( rendering_id=1, @@ -54,7 +54,7 @@ def setUp(self) -> None: self.decorated_collection: ChartCollectionDecorator = self.decorated_datasource.get_collection("Product") def test_schema_should_not_change(self): - assert self.decorated_collection.schema["charts"] == self.collection_book.schema["charts"] + assert self.decorated_collection.schema["charts"] == self.collection_product.schema["charts"] def test_add_chart_should_raise_if_chart_name_already_exists(self): self.decorated_collection.add_chart("test_chart", lambda ctx, result_builder: True) @@ -67,7 +67,7 @@ def test_add_chart_should_raise_if_chart_name_already_exists(self): ) def test_render_chart_should_call_child_collection(self): - with patch.object(self.collection_book, "render_chart", new_callable=AsyncMock) as mock_render_chart: + with patch.object(self.collection_product, "render_chart", new_callable=AsyncMock) as mock_render_chart: self.loop.run_until_complete(self.decorated_collection.render_chart(self.mocked_caller, "child_chart", [1])) mock_render_chart.assert_awaited_once_with(self.mocked_caller, "child_chart", [1]) @@ -82,3 +82,10 @@ async def chart_fn(context, result_builder: ResultBuilder, record_id): ) assert result == {"countCurrent": 1, "countPrevious": None} + + def test_should_schema_should_contains_charts_define_in_custom_datasource(self): + with patch.dict(self.collection_product._schema, {"charts": {"chart_test": None}}): + self.assertIn( + "chart_test", + self.decorated_collection.schema["charts"], + ) diff --git a/src/datasource_toolkit/tests/decorators/segments/test_segments_decorator.py b/src/datasource_toolkit/tests/decorators/segments/test_segments_decorator.py index d60659e9c..2b9001ec6 100644 --- a/src/datasource_toolkit/tests/decorators/segments/test_segments_decorator.py +++ b/src/datasource_toolkit/tests/decorators/segments/test_segments_decorator.py @@ -1,6 +1,7 @@ import asyncio import sys from unittest import TestCase +from unittest.mock import patch if sys.version_info >= (3, 9): import zoneinfo @@ -26,8 +27,8 @@ def setUpClass(cls) -> None: cls.datasource: Datasource = Datasource() Collection.__abstractmethods__ = set() # to instantiate abstract class - cls.collection_book = Collection("Product", cls.datasource) - cls.collection_book.add_fields( + cls.collection_product = Collection("Product", cls.datasource) + cls.collection_product.add_fields( { "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), "name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), @@ -38,7 +39,7 @@ def setUpClass(cls) -> None: } ) - cls.datasource.add_collection(cls.collection_book) + cls.datasource.add_collection(cls.collection_product) cls.datasource_decorator = DatasourceDecorator(cls.datasource, SegmentCollectionDecorator) cls.mocked_caller = User( @@ -121,3 +122,10 @@ async def segment_fn(context: CollectionCustomizationContext): self.decorated_collection_product._refine_filter(self.mocked_caller, filter_) ) assert returned_filter == Filter({"condition_tree": ConditionTreeLeaf("name", Operator.EQUAL, "a_name_value")}) + + def test_should_schema_should_contains_segments_define_in_custom_datasource(self): + with patch.dict(self.collection_product._schema, {"segments": ["segment_test"]}): + self.assertIn( + "segment_test", + self.decorated_collection_product.schema["segments"], + )