From aa0bba7f46efc2065f20ee49e06201c2503dc63f Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 6 Nov 2024 16:12:14 +0100 Subject: [PATCH 01/71] chore: add datasource_composite --- .../context/relaxed_wrappers/collection.py | 4 ++ .../datasource_composite.py | 59 +++++++++++++++++++ .../datasource_toolkit/datasources.py | 3 +- .../interfaces/models/collections.py | 9 +++ 4 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index 21781b7c6..b11c3e671 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -51,6 +51,10 @@ def get_collection(self, name: str) -> "RelaxedCollection": def add_collection(self, collection: "RelaxedCollection") -> None: raise RelaxedDatasourceException("Cannot modify existing datasources") + @property + def schema(self): + return self.datasource.schema + class RelaxedCollection(Collection): def __init__(self, collection: Collection): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py new file mode 100644 index 000000000..7ae1c5a22 --- /dev/null +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -0,0 +1,59 @@ +from typing import Any, List + +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException +from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, DatasourceSchema + + +class CompositeDatasource(Datasource): + def __init__(self) -> None: + super().__init__() + self._datasources: List[Datasource] = [] + + @property + def schema(self) -> DatasourceSchema: + charts = {} + for datasource in self._datasources: + charts.update(datasource.schema["charts"]) + return {"charts": charts} + + @property + def collections(self) -> List[BoundCollection]: + ret = [] + for datasource in self._datasources: + ret.extend(datasource.collections) + return ret + + def get_collection(self, name: str) -> Any: + for datasource in self._datasources: + try: + return datasource.get_collection(name) + except Exception: + pass + + collection_names = [c.name for c in self.collections] + collection_names.sort() + raise DatasourceToolkitException( + f"Collection {name} not found. 
List of available collection: {', '.join(collection_names)}" + ) + + def add_datasource(self, datasource: Datasource): + existing_collection_names = [c.name for c in self.collections] + for collection in datasource.collections: + if collection.name in existing_collection_names: + raise DatasourceToolkitException(f"Collection '{collection.name}' already exists.") + + for chart_name in datasource.schema["charts"].keys(): + if chart_name in self.schema["charts"].keys(): + raise DatasourceToolkitException(f"Chart '{chart_name}' already exists.") + + self._datasources.append(datasource) + + async def render_chart(self, caller: User, name: str) -> Chart: + for datasource in self._datasources: + if name in datasource.schema["charts"]: + return await datasource.render_chart(caller, name) + + raise DatasourceToolkitException(f"Chart {name} is not defined in the datasource.") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 9623aa5f3..9141d2618 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -5,6 +5,7 @@ from forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection from forestadmin.datasource_toolkit.interfaces.models.collections import Datasource as DatasourceInterface +from forestadmin.datasource_toolkit.interfaces.models.collections import DatasourceSchema class DatasourceException(DatasourceToolkitException): @@ -16,7 +17,7 @@ def __init__(self) -> None: self._collections: Dict[str, BoundCollection] = {} @property - def schema(self): + def schema(self) -> DatasourceSchema: return {"charts": {}} @property diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index 818a19cc4..7e3e0b43f 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -15,6 +15,10 @@ class CollectionSchema(TypedDict): charts: Dict[str, Callable] +class DatasourceSchema(TypedDict): + charts: Dict[str, Callable] + + class Collection(abc.ABC): @property @abc.abstractmethod @@ -49,3 +53,8 @@ def get_collection(self, name: str) -> BoundCollection: @abc.abstractmethod def add_collection(self, collection: BoundCollection) -> None: raise NotImplementedError + + @property + @abc.abstractmethod + def schema(self) -> DatasourceSchema: + raise NotImplementedError From eef7c43ae70fe32dc32e2bad1e3b5ef9efeb1e73 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 6 Nov 2024 16:15:24 +0100 Subject: [PATCH 02/71] chore: add tests for composite datasource --- .../test_composite_datasource.py | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py diff --git a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py new file mode 100644 index 000000000..9691b8704 --- /dev/null +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -0,0 +1,162 @@ +import asyncio +import sys +from unittest import TestCase +from unittest.mock 
import AsyncMock, Mock, PropertyMock, patch + +if sys.version_info >= (3, 9): + import zoneinfo +else: + from backports import zoneinfo + +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException +from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, PrimitiveType + + +class BaseTestCompositeDatasource(TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.loop = asyncio.new_event_loop() + Collection.__abstractmethods__ = set() # to instantiate abstract class + + cls.mocked_caller = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, + ) + + def setUp(self) -> None: + self.composite_ds = CompositeDatasource() + self.ds1_charts = {"charts": {"chart1": Mock()}} + self.ds2_charts = {"charts": {"chart2": Mock()}} + + DS1 = type("DS1", (Datasource,), {"schema": PropertyMock(return_value=self.ds1_charts)}) + DS2 = type("DS2", (Datasource,), {"schema": PropertyMock(return_value=self.ds2_charts)}) + + self.datasource_1: Datasource = DS1() + self.collection_person = Collection("Person", self.datasource_1) + self.collection_person.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "first_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + "last_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + } + ) + self.datasource_1.add_collection(self.collection_person) + + self.datasource_2: Datasource = DS2() + self.collection_order = Collection("Order", self.datasource_2) + self.collection_order.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "customer_id": Column(column_type=PrimitiveType.NUMBER, type=FieldType.COLUMN), + "price": Column(column_type=PrimitiveType.NUMBER, type=FieldType.COLUMN), + } + ) + + self.datasource_2.add_collection(self.collection_order) + + +class TestCompositeDatasource(BaseTestCompositeDatasource): + def setUp(self) -> None: + super().setUp() + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_multiple_collection_with_same_name(self): + collection_person = Collection("Person", self.datasource_2) + collection_person.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "first_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + "last_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + } + ) + self.datasource_2.add_collection(collection_person) + + self.assertRaisesRegex( + DatasourceToolkitException, + r"Collection 'Person' already exists\.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_collection_should_return_collection_of_all_datasources(self): + self.composite_ds.add_datasource(self.datasource_2) + + self.assertEqual(len(self.composite_ds.collections), 2) + self.assertIn("Person", [c.name for c in self.composite_ds.collections]) + self.assertIn("Order", [c.name for c in self.composite_ds.collections]) + + def 
test_get_collection_should_search_in_all_datasources(self): + self.composite_ds.add_datasource(self.datasource_2) + + collection = self.composite_ds.get_collection("Person") + self.assertEqual(collection.name, "Person") + self.assertEqual(collection.datasource, self.datasource_1) + + collection = self.composite_ds.get_collection("Order") + self.assertEqual(collection.name, "Order") + self.assertEqual(collection.datasource, self.datasource_2) + + def test_get_collection_should_list_collection_names_if_collection_not_found(self): + self.composite_ds.add_datasource(self.datasource_2) + + self.assertRaisesRegex( + DatasourceToolkitException, + "Collection Unknown not found. List of available collection: Order, Person", + self.composite_ds.get_collection, + "Unknown", + ) + + +class TestCompositeDatasourceCharts(BaseTestCompositeDatasource): + def setUp(self) -> None: + self.ds1_charts = {"charts": {"chart1": Mock()}} + self.ds2_charts = {"charts": {"chart2": Mock()}} + super().setUp() + + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_duplicated_chart(self): + self.ds1_charts["charts"]["chart2"] = Mock() + + self.assertRaisesRegex( + DatasourceToolkitException, + r"Chart 'chart2' already exists.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_schema_should_contains_all_charts(self): + self.composite_ds.add_datasource(self.datasource_2) + self.assertIn("chart1", self.composite_ds.schema["charts"]) + self.assertIn("chart2", self.composite_ds.schema["charts"]) + + def test_render_chart_should_raise_if_chart_is_unknown(self): + self.composite_ds.add_datasource(self.datasource_2) + self.assertRaisesRegex( + DatasourceToolkitException, + "Chart unknown is not defined in the datasource.", + self.loop.run_until_complete, + self.composite_ds.render_chart(self.mocked_caller, "unknown"), + ) + + def test_render_chart_should_call_render_chart_on_good_datasource(self): + self.composite_ds.add_datasource(self.datasource_2) + + with patch.object(self.datasource_1, "render_chart", new_callable=AsyncMock) as mock_render_chart: + self.loop.run_until_complete(self.composite_ds.render_chart(self.mocked_caller, "chart1")) + mock_render_chart.assert_awaited_with(self.mocked_caller, "chart1") + + with patch.object(self.datasource_2, "render_chart", new_callable=AsyncMock) as mock_render_chart: + self.loop.run_until_complete(self.composite_ds.render_chart(self.mocked_caller, "chart2")) + mock_render_chart.assert_awaited_with(self.mocked_caller, "chart2") From e58fa78ab5e20492ec19ae781d3d03fd5bdd6039 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 6 Nov 2024 16:15:43 +0100 Subject: [PATCH 03/71] chore: use datasource composite --- .../datasource_customizer/datasource_customizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py index 38deb3e44..2904eb167 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py @@ -1,6 +1,7 @@ from typing import Dict, Optional from forestadmin.datasource_toolkit.datasource_customizer.collection_customizer import CollectionCustomizer +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite 
import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.types import DataSourceOptions from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.decorators.chart.types import DataSourceChartDefinition @@ -12,7 +13,7 @@ class DatasourceCustomizer: def __init__(self) -> None: - self.composite_datasource: Datasource = Datasource() + self.composite_datasource: CompositeDatasource = CompositeDatasource() self.stack = DecoratorStack(self.composite_datasource) @property @@ -52,8 +53,7 @@ async def _add_datasource(): rename_decorator.rename_collections(_options.get("rename", {})) datasource = rename_decorator - for collection in datasource.collections: - self.composite_datasource.add_collection(collection) + self.composite_datasource.add_datasource(datasource) self.stack.queue_customization(_add_datasource) return self From db910e9467697858a546b9d17ed6fc7f36839a4c Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 6 Nov 2024 16:49:56 +0100 Subject: [PATCH 04/71] chore: add name in datasources --- .../datasource_django/datasource.py | 6 +++-- .../datasource_sqlalchemy/datasource.py | 4 ++-- .../context/relaxed_wrappers/collection.py | 4 ++++ .../datasource_composite.py | 22 ++++++++++++++++++- .../datasource_toolkit/datasources.py | 9 ++++++-- .../interfaces/models/collections.py | 5 +++++ .../test_composite_datasource.py | 12 ++++++++++ 7 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 453fcf68b..62eeb9119 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -1,11 +1,13 @@ +from typing import Optional + from django.apps import apps from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.interface import BaseDjangoDatasource class DjangoDatasource(BaseDjangoDatasource): - def __init__(self, support_polymorphic_relations: bool = False) -> None: - super().__init__() + def __init__(self, support_polymorphic_relations: bool = False, name: Optional[str] = None) -> None: + super().__init__(name) self.support_polymorphic_relations = support_polymorphic_relations self._create_collections() diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 24668e05e..a044ca0b6 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -8,8 +8,8 @@ class SqlAlchemyDatasource(BaseSqlAlchemyDatasource): - def __init__(self, Base: Any, db_uri: Optional[str] = None) -> None: - super().__init__() + def __init__(self, Base: Any, db_uri: Optional[str] = None, name: Optional[str] = None) -> None: + super().__init__(name) self._base = Base self.__is_using_flask_sqlalchemy = hasattr(Base, "Model") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index b11c3e671..a9086418a 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -55,6 
+55,10 @@ def add_collection(self, collection: "RelaxedCollection") -> None: def schema(self): return self.datasource.schema + @property + def name(self): + return self.datasource.name + class RelaxedCollection(Collection): def __init__(self, collection: Collection): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index 7ae1c5a22..6eec48d67 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -1,5 +1,6 @@ -from typing import Any, List +from typing import Any, Dict, List +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException @@ -49,6 +50,12 @@ def add_datasource(self, datasource: Datasource): if chart_name in self.schema["charts"].keys(): raise DatasourceToolkitException(f"Chart '{chart_name}' already exists.") + if datasource.name in [ds.name for ds in self._datasources]: + ForestLogger.log( + "warning", + f"A datasource with the name '{datasource.name}' already exists. " + "You can use the optional parameter 'name' when creating a datasource.", + ) self._datasources.append(datasource) async def render_chart(self, caller: User, name: str) -> Chart: @@ -57,3 +64,16 @@ async def render_chart(self, caller: User, name: str) -> Chart: return await datasource.render_chart(caller, name) raise DatasourceToolkitException(f"Chart {name} is not defined in the datasource.") + + def get_datasources(self) -> List[Datasource]: + return [*self._datasources] + + def get_datasource(self, name: str) -> Datasource: + for datasource in self._datasources: + if name == datasource.name: + return datasource + + raise DatasourceToolkitException( + f"Datasource with name '{name}' is not found. 
Datasources names are: " + f"{', '.join([ds.name for ds in self._datasources])}" + ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 9141d2618..bef43701a 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Optional from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException @@ -13,8 +13,13 @@ class DatasourceException(DatasourceToolkitException): class Datasource(DatasourceInterface[BoundCollection]): - def __init__(self) -> None: + def __init__(self, name: Optional[str] = None) -> None: self._collections: Dict[str, BoundCollection] = {} + self._name = name if name is not None else self.__class__.__name__ + + @property + def name(self) -> str: + return self._name @property def schema(self) -> DatasourceSchema: diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index 7e3e0b43f..d9e05ed03 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -41,6 +41,11 @@ def schema(self) -> CollectionSchema: class Datasource(Generic[BoundCollection], abc.ABC): + @property + @abc.abstractmethod + def name(self) -> str: + raise NotImplementedError + @property @abc.abstractmethod def collections(self) -> List[BoundCollection]: diff --git a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py index 9691b8704..53d9229a9 100644 --- a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -117,6 +117,18 @@ def test_get_collection_should_list_collection_names_if_collection_not_found(sel "Unknown", ) + def test_should_log_if_multiple_datasources_have_same_name(self): + ds1 = Datasource(name="test") + ds2 = Datasource(name="test") + self.composite_ds.add_datasource(ds1) + with patch("forestadmin.datasource_toolkit.datasource_customizer.datasource_composite.ForestLogger.log") as log: + self.composite_ds.add_datasource(ds2) + log.assert_any_call( + "warning", + "A datasource with the name 'test' already exists. 
You can use the optional parameter 'name' when " + "creating a datasource.", + ) + class TestCompositeDatasourceCharts(BaseTestCompositeDatasource): def setUp(self) -> None: From 3fbc694aa320faa68a4527e49faf07138e5d2add Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 19 Nov 2024 11:36:10 +0100 Subject: [PATCH 05/71] chore: add native query to datasources --- .../datasource_django/datasource.py | 29 ++++++++++++++++++- .../datasource_sqlalchemy/datasource.py | 17 +++++++++-- .../context/relaxed_wrappers/collection.py | 3 ++ .../datasource_composite.py | 5 ++++ .../datasource_toolkit/datasources.py | 11 +++++-- .../decorators/decorator_stack.py | 2 +- .../interfaces/models/collections.py | 6 ++++ 7 files changed, 67 insertions(+), 6 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 62eeb9119..829314e95 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -1,8 +1,13 @@ -from typing import Optional +from datetime import date +from typing import List, Optional +from asgiref.sync import sync_to_async from django.apps import apps +from django.db import connection from forestadmin.datasource_django.collection import DjangoCollection +from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoDatasource +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias class DjangoDatasource(BaseDjangoDatasource): @@ -10,6 +15,7 @@ def __init__(self, support_polymorphic_relations: bool = False, name: Optional[s super().__init__(name) self.support_polymorphic_relations = support_polymorphic_relations self._create_collections() + self.enable_native_query() def _create_collections(self): models = apps.get_models(include_auto_created=True) @@ -17,3 +23,24 @@ def _create_collections(self): if model._meta.proxy is False: collection = DjangoCollection(self, model, self.support_polymorphic_relations) self.add_collection(collection) + + async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + def _execute_native_query(): + cursor = connection.cursor() + try: + rows = cursor.execute(native_query) + ret = [] + for row in rows: + return_row = {} + for i, field_name in enumerate(rows.description): + value = row[i] + if isinstance(value, date): + value = value.isoformat() + return_row[field_name[0]] = value + ret.append(return_row) + return ret + except Exception as e: + # TODO: verify + raise DjangoDatasourceException(str(e)) + + return await sync_to_async(_execute_native_query)() diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index a044ca0b6..6e1263ee9 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from forestadmin.datasource_sqlalchemy.collections import SqlAlchemyCollection from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException from forestadmin.datasource_sqlalchemy.interfaces import BaseSqlAlchemyDatasource -from sqlalchemy import create_engine +from forestadmin.datasource_toolkit.interfaces.records 
import RecordsDataAlias +from sqlalchemy import create_engine, text from sqlalchemy.orm import Mapper, sessionmaker @@ -25,6 +26,7 @@ def __init__(self, Base: Any, db_uri: Optional[str] = None, name: Optional[str] self.Session = sessionmaker(bind) self._create_collections() + self.enable_native_query() def _find_db_uri(self, base_class): engine = None @@ -56,6 +58,17 @@ def build_mappers(self) -> Dict[str, Mapper]: mappers[mapper.persist_selectable.name] = mapper return mappers + async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + try: + session = self.Session() + query = native_query + if isinstance(query, str): + query = text(query) + rows = session.execute(query) + return [*rows.mappings()] + except Exception as exc: + raise SqlAlchemyDatasourceException(str(exc)) + # unused code, can be use full but can be remove # from forestadmin.datasource_toolkit.datasources import DatasourceException # from forestadmin.datasource_toolkit.interfaces.fields import FieldType, ManyToMany, ManyToOne diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index a9086418a..20e18e9f9 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -59,6 +59,9 @@ def schema(self): def name(self): return self.datasource.name + async def execute_native_query(self, native_query: str) -> Any: + raise RelaxedDatasourceException("Cannot use this method. Please use 'collection.get_native_driver' instead.") + class RelaxedCollection(Collection): def __init__(self, collection: Collection): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index 6eec48d67..984b28bfa 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -77,3 +77,8 @@ def get_datasource(self, name: str) -> Datasource: f"Datasource with name '{name}' is not found. Datasources names are: " f"{', '.join([ds.name for ds in self._datasources])}" ) + + async def execute_native_query(self, native_query: str) -> Any: + raise DatasourceToolkitException( + "Cannot use this method. Please use 'get_datasource(name).execute_native_query' instead." 
+ ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index bef43701a..62d595b52 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException @@ -16,6 +16,7 @@ class Datasource(DatasourceInterface[BoundCollection]): def __init__(self, name: Optional[str] = None) -> None: self._collections: Dict[str, BoundCollection] = {} self._name = name if name is not None else self.__class__.__name__ + self._schema: DatasourceSchema = {"charts": {}, "native_query": False} @property def name(self) -> str: @@ -23,7 +24,10 @@ def name(self) -> str: @property def schema(self) -> DatasourceSchema: - return {"charts": {}} + return self._schema + + def enable_native_query(self): + self._schema["native_query"] = True @property def collections(self) -> List[BoundCollection]: @@ -46,3 +50,6 @@ def add_collection(self, collection: BoundCollection) -> None: async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceException(f"Chart {name} not exists on this datasource.") + + async def execute_native_query(self, native_query: str) -> Any: + raise NotImplementedError(f"'execute_native_query' is not implemented on {self.__class__.__name__}") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py index 5a1d297fd..760a708e8 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -28,7 +28,7 @@ class DecoratorStack: def __init__(self, datasource: Datasource) -> None: self._customizations: List = list() - last = datasource + last = self.base_datasource = datasource # Step 0: Do not query datasource when we know the result with yield an empty set. 
last = self.override = DatasourceDecorator(last, OverrideCollectionDecorator) # type: ignore diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index d9e05ed03..bc1ec5b3f 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -3,6 +3,7 @@ from forestadmin.datasource_toolkit.interfaces.actions import Action from forestadmin.datasource_toolkit.interfaces.fields import FieldAlias +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias from typing_extensions import Self @@ -17,6 +18,7 @@ class CollectionSchema(TypedDict): class DatasourceSchema(TypedDict): charts: Dict[str, Callable] + native_query: bool class Collection(abc.ABC): @@ -63,3 +65,7 @@ def add_collection(self, collection: BoundCollection) -> None: @abc.abstractmethod def schema(self) -> DatasourceSchema: raise NotImplementedError + + @abc.abstractmethod + async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + raise NotImplementedError From 3924ec1549196ba4df52918fb7d2e9efba637621 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 19 Nov 2024 11:53:57 +0100 Subject: [PATCH 06/71] chore: add native_query route --- .../forestadmin/agent_toolkit/agent.py | 9 ++++ .../resources/collections/native_query.py | 46 +++++++++++++++++++ .../forestadmin/django_agent/urls.py | 3 +- .../django_agent/views/native_query.py | 15 ++++++ .../forestadmin/flask_agent/agent.py | 13 +++++- 5 files changed, 84 insertions(+), 2 deletions(-) create mode 100644 src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py create mode 100644 src/django_agent/forestadmin/django_agent/views/native_query.py diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index ccc4dcb88..e7e0ad01a 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -9,6 +9,7 @@ from forestadmin.agent_toolkit.resources.collections.charts_datasource import ChartsDatasourceResource from forestadmin.agent_toolkit.resources.collections.crud import CrudResource from forestadmin.agent_toolkit.resources.collections.crud_related import CrudRelatedResource +from forestadmin.agent_toolkit.resources.collections.native_query import NativeQueryResource from forestadmin.agent_toolkit.resources.collections.stats import StatsResource from forestadmin.agent_toolkit.resources.security.resources import Authentication from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService @@ -37,6 +38,7 @@ class Resources(TypedDict): actions: ActionResource collection_charts: ChartsCollectionResource datasource_charts: ChartsDatasourceResource + native_query: NativeQueryResource class Agent: @@ -112,6 +114,13 @@ async def __mk_resources(self): self._ip_white_list_service, self.options, ), + "native_query": NativeQueryResource( + self.customizer.composite_datasource, + await self.customizer.get_datasource(), + self._permission_service, + self._ip_white_list_service, + self.options, + ), } async def get_resources(self): diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py 
b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py new file mode 100644 index 000000000..d5e5f83f3 --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -0,0 +1,46 @@ +from typing import Literal, Union + +from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.options import Options +from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource +from forestadmin.agent_toolkit.resources.collections.decorators import authenticate, check_method, ip_white_list +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService +from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer +from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, Datasource + +DatasourceAlias = Union[Datasource[BoundCollection], DatasourceCustomizer] + + +LiteralMethod = Literal["native_query"] + + +class NativeQueryResource(BaseCollectionResource): + def __init__( + self, + composite_datasource: CompositeDatasource, + datasource: DatasourceAlias, + permission: PermissionService, + ip_white_list_service: IpWhiteListService, + options: Options, + ): + super().__init__(datasource, permission, ip_white_list_service, options) + self.composite_datasource: CompositeDatasource = composite_datasource + + @ip_white_list + async def dispatch(self, request: Request, method_name: Literal["native_query"]) -> Response: + try: + return HttpResponseBuilder.build_success_response(await self.handle_native_query(request)) + except Exception as exc: + ForestLogger.log("exception", exc) + return HttpResponseBuilder.build_client_error_response([exc]) + + @check_method(RequestMethod.POST) + @authenticate + async def handle_native_query(self, request: Request) -> Response: + # TODO: permission check + # TODO: context variable injector + ds = self.composite_datasource.get_datasource(request.body["datasource"]) + return await ds.execute_native_query(request.body["native_query"]) diff --git a/src/django_agent/forestadmin/django_agent/urls.py b/src/django_agent/forestadmin/django_agent/urls.py index 60e9098e7..da2765286 100644 --- a/src/django_agent/forestadmin/django_agent/urls.py +++ b/src/django_agent/forestadmin/django_agent/urls.py @@ -1,7 +1,7 @@ from django.conf import settings from django.urls import path -from .views import actions, authentication, capabilities, charts, crud, crud_related, index, stats +from .views import actions, authentication, capabilities, charts, crud, crud_related, index, native_query, stats app_name = "django_agent" @@ -16,6 +16,7 @@ # generic path(f"{prefix}forest/", index.index, name="index"), path(f"{prefix}forest/_internal/capabilities", capabilities.capabilities, name="capabilities"), + path(f"{prefix}forest/_internal/native_query", native_query.native_query, name="capabilities"), path(f"{prefix}forest/scope-cache-invalidation", index.scope_cache_invalidation, name="scope_invalidation"), # authentication path(f"{prefix}forest/authentication", authentication.authentication, name="authentication"), diff --git 
a/src/django_agent/forestadmin/django_agent/views/native_query.py b/src/django_agent/forestadmin/django_agent/views/native_query.py new file mode 100644 index 000000000..a49f95512 --- /dev/null +++ b/src/django_agent/forestadmin/django_agent/views/native_query.py @@ -0,0 +1,15 @@ +from asgiref.sync import async_to_sync +from django.http import HttpRequest +from forestadmin.django_agent.apps import DjangoAgentApp +from forestadmin.django_agent.utils.converter import convert_request, convert_response + + +@async_to_sync +async def native_query(request: HttpRequest, **kwargs): + resource = (await DjangoAgentApp.get_agent().get_resources())["native_query"] + response = await resource.dispatch(convert_request(request, kwargs), "native_query") + return convert_response(response) + + +# This is so ugly... But django.views.decorators.csrf.csrf_exempt is not asyncio ready +native_query.csrf_exempt = True diff --git a/src/flask_agent/forestadmin/flask_agent/agent.py b/src/flask_agent/forestadmin/flask_agent/agent.py index e5c267797..ce87e0507 100644 --- a/src/flask_agent/forestadmin/flask_agent/agent.py +++ b/src/flask_agent/forestadmin/flask_agent/agent.py @@ -16,6 +16,7 @@ from forestadmin.agent_toolkit.resources.base import BaseResource from forestadmin.agent_toolkit.resources.capabilities import LiteralMethod as CapabilitiesLiteralMethod from forestadmin.agent_toolkit.resources.collections.crud import LiteralMethod as CrudLiteralMethod +from forestadmin.agent_toolkit.resources.collections.native_query import LiteralMethod as NativeQueryLiteralMethod from forestadmin.agent_toolkit.resources.security.resources import LiteralMethod as AuthLiteralMethod from forestadmin.agent_toolkit.utils.context import Request from forestadmin.agent_toolkit.utils.forest_schema.type import AgentMeta @@ -116,7 +117,13 @@ async def _get_collection_response( request: FlaskRequest, resource: BaseResource, method: Optional[ - Union[AuthLiteralMethod, CrudLiteralMethod, ActionLiteralMethod, CapabilitiesLiteralMethod] + Union[ + AuthLiteralMethod, + CrudLiteralMethod, + ActionLiteralMethod, + CapabilitiesLiteralMethod, + NativeQueryLiteralMethod, + ] ] = None, detail: bool = False, ) -> FlaskResponse: @@ -133,6 +140,10 @@ async def index() -> FlaskResponse: # type: ignore async def capabilities() -> FlaskResponse: # type: ignore return await _get_collection_response(request, (await agent.get_resources())["capabilities"], "capabilities") + @blueprint.route("/_internal/native_query", methods=["POST"]) + async def native_query() -> FlaskResponse: # type: ignore + return await _get_collection_response(request, (await agent.get_resources())["native_query"], "native_query") + @blueprint.route("/authentication/callback", methods=["GET"]) async def callback() -> FlaskResponse: # type: ignore return await _get_collection_response(request, (await agent.get_resources())["authentication"], "callback") From e628fe6a23a394f6c47d436b8311fef6bcbb2d24 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 21 Nov 2024 10:12:17 +0100 Subject: [PATCH 07/71] chore(datasources): change datasource name to native query connections --- .../datasource_django/datasource.py | 59 ++++++++++++++++--- .../datasource_sqlalchemy/datasource.py | 11 ++-- .../datasource_composite.py | 42 ++++++------- .../datasource_toolkit/datasources.py | 17 ++---- .../interfaces/models/collections.py | 11 +--- 5 files changed, 85 insertions(+), 55 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py 
b/src/datasource_django/forestadmin/datasource_django/datasource.py index 829314e95..ed26fb9f9 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -1,9 +1,10 @@ from datetime import date -from typing import List, Optional +from typing import Dict, List, Optional, Union from asgiref.sync import sync_to_async from django.apps import apps -from django.db import connection +from django.db import connections +from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoDatasource @@ -11,11 +12,42 @@ class DjangoDatasource(BaseDjangoDatasource): - def __init__(self, support_polymorphic_relations: bool = False, name: Optional[str] = None) -> None: - super().__init__(name) + def __init__( + self, + support_polymorphic_relations: bool = False, + live_query_connection: Optional[Union[str, Dict[str, str]]] = None, + ) -> None: + self._live_query_connections: Dict[str, str] = self._handle_live_query_connections_param(live_query_connection) + super().__init__([*self._live_query_connections.keys()]) + self.support_polymorphic_relations = support_polymorphic_relations self._create_collections() - self.enable_native_query() + + def _handle_live_query_connections_param( + self, live_query_connections: Optional[Union[str, Dict[str, str]]] + ) -> Dict[str, str]: + if live_query_connections is None: + return {} + + if isinstance(live_query_connections, str): + ret = {live_query_connections: "default"} + if len(connections.all()) > 1: + ForestLogger.log( + "info", + f"You enabled live query as {live_query_connections} for django 'default' database." + " To use it over multiple databases, read the related documentation here: http://link.", + # TODO: link + ) + else: + ret = live_query_connections + + for forest_name, db_name in ret.items(): + if db_name not in connections: + raise DjangoDatasourceException( + f"Connection to database '{db_name}' for alias '{forest_name}' is not found in django databases. " + f"Existing connections are {','.join([*connections])}" + ) + return ret def _create_collections(self): models = apps.get_models(include_auto_created=True) @@ -24,9 +56,22 @@ def _create_collections(self): collection = DjangoCollection(self, model, self.support_polymorphic_relations) self.add_collection(collection) - async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: + if self._live_query_connections is None or connection_name not in self._live_query_connections.keys(): + # This one should never occur + raise DjangoDatasourceException( + f"Native query connection '{connection_name}' is not known by DjangoDatasource." + ) + + if self._live_query_connections[connection_name] not in connections: + raise DjangoDatasourceException( + f"Connection to database '{self._live_query_connections[connection_name]}' for alias " + f"'{connection_name}' is not found in django connections. 
" + f"Existing connections are {','.join([*connections])}" + ) + def _execute_native_query(): - cursor = connection.cursor() + cursor = connections[self._live_query_connections[connection_name]].cursor() # type: ignore try: rows = cursor.execute(native_query) ret = [] diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 6e1263ee9..dd497e31c 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -9,8 +9,8 @@ class SqlAlchemyDatasource(BaseSqlAlchemyDatasource): - def __init__(self, Base: Any, db_uri: Optional[str] = None, name: Optional[str] = None) -> None: - super().__init__(name) + def __init__(self, Base: Any, db_uri: Optional[str] = None, live_query_connection: Optional[str] = None) -> None: + super().__init__([live_query_connection] if live_query_connection is not None else None) self._base = Base self.__is_using_flask_sqlalchemy = hasattr(Base, "Model") @@ -26,7 +26,6 @@ def __init__(self, Base: Any, db_uri: Optional[str] = None, name: Optional[str] self.Session = sessionmaker(bind) self._create_collections() - self.enable_native_query() def _find_db_uri(self, base_class): engine = None @@ -58,7 +57,11 @@ def build_mappers(self) -> Dict[str, Mapper]: mappers[mapper.persist_selectable.name] = mapper return mappers - async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: + if connection_name != self.schema["native_query_connections"][0]: + raise SqlAlchemyDatasourceException( + f"The native query connection '{connection_name}' doesn't belongs to this datasource." 
+ ) try: session = self.Session() query = native_query diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index 984b28bfa..f11676ccf 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List +from typing import Any, List -from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException @@ -16,9 +15,12 @@ def __init__(self) -> None: @property def schema(self) -> DatasourceSchema: charts = {} + native_queries = [] for datasource in self._datasources: charts.update(datasource.schema["charts"]) - return {"charts": charts} + native_queries.extend(datasource.schema["native_query_connections"]) + + return {"charts": charts, "native_query_connections": native_queries} @property def collections(self) -> List[BoundCollection]: @@ -46,16 +48,14 @@ def add_datasource(self, datasource: Datasource): if collection.name in existing_collection_names: raise DatasourceToolkitException(f"Collection '{collection.name}' already exists.") - for chart_name in datasource.schema["charts"].keys(): - if chart_name in self.schema["charts"].keys(): - raise DatasourceToolkitException(f"Chart '{chart_name}' already exists.") + for connection in datasource.schema["charts"].keys(): + if connection in self.schema["charts"].keys(): + raise DatasourceToolkitException(f"Chart '{connection}' already exists.") + + for connection in datasource.schema["native_query_connections"]: + if connection in self.schema["native_query_connections"]: + raise DatasourceToolkitException(f"Native query connection '{connection}' already exists.") - if datasource.name in [ds.name for ds in self._datasources]: - ForestLogger.log( - "warning", - f"A datasource with the name '{datasource.name}' already exists. " - "You can use the optional parameter 'name' when creating a datasource.", - ) self._datasources.append(datasource) async def render_chart(self, caller: User, name: str) -> Chart: @@ -65,20 +65,12 @@ async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceToolkitException(f"Chart {name} is not defined in the datasource.") - def get_datasources(self) -> List[Datasource]: - return [*self._datasources] - - def get_datasource(self, name: str) -> Datasource: + async def execute_native_query(self, connection_name: str, native_query: str) -> Any: for datasource in self._datasources: - if name == datasource.name: - return datasource - - raise DatasourceToolkitException( - f"Datasource with name '{name}' is not found. Datasources names are: " - f"{', '.join([ds.name for ds in self._datasources])}" - ) + if connection_name in datasource.schema["native_query_connections"]: + return await datasource.execute_native_query(connection_name, native_query) - async def execute_native_query(self, native_query: str) -> Any: raise DatasourceToolkitException( - "Cannot use this method. Please use 'get_datasource(name).execute_native_query' instead." + f"Cannot find {connection_name} in datasources. 
" + f"Existing connection names are: {','.join(self.schema['native_query_connections'])}" ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 62d595b52..6b2970975 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -13,22 +13,17 @@ class DatasourceException(DatasourceToolkitException): class Datasource(DatasourceInterface[BoundCollection]): - def __init__(self, name: Optional[str] = None) -> None: + def __init__(self, live_query_connections: Optional[List[str]] = None) -> None: self._collections: Dict[str, BoundCollection] = {} - self._name = name if name is not None else self.__class__.__name__ - self._schema: DatasourceSchema = {"charts": {}, "native_query": False} - - @property - def name(self) -> str: - return self._name + self._schema: DatasourceSchema = { + "charts": {}, + "native_query_connections": live_query_connections or [], + } @property def schema(self) -> DatasourceSchema: return self._schema - def enable_native_query(self): - self._schema["native_query"] = True - @property def collections(self) -> List[BoundCollection]: return list(self._collections.values()) @@ -51,5 +46,5 @@ def add_collection(self, collection: BoundCollection) -> None: async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceException(f"Chart {name} not exists on this datasource.") - async def execute_native_query(self, native_query: str) -> Any: + async def execute_native_query(self, connection_name: str, native_query: str) -> Any: raise NotImplementedError(f"'execute_native_query' is not implemented on {self.__class__.__name__}") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index bc1ec5b3f..a55687039 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -1,5 +1,5 @@ import abc -from typing import Callable, Dict, Generic, List, TypedDict, TypeVar +from typing import Any, Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union from forestadmin.datasource_toolkit.interfaces.actions import Action from forestadmin.datasource_toolkit.interfaces.fields import FieldAlias @@ -18,7 +18,7 @@ class CollectionSchema(TypedDict): class DatasourceSchema(TypedDict): charts: Dict[str, Callable] - native_query: bool + native_query_connections: List[str] class Collection(abc.ABC): @@ -43,11 +43,6 @@ def schema(self) -> CollectionSchema: class Datasource(Generic[BoundCollection], abc.ABC): - @property - @abc.abstractmethod - def name(self) -> str: - raise NotImplementedError - @property @abc.abstractmethod def collections(self) -> List[BoundCollection]: @@ -67,5 +62,5 @@ def schema(self) -> DatasourceSchema: raise NotImplementedError @abc.abstractmethod - async def execute_native_query(self, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: raise NotImplementedError From ed3a43dd69ad82fba3efa35fbc34ad79141cd5c4 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 21 Nov 2024 15:10:47 +0100 Subject: [PATCH 08/71] chore(datasources): native query connections is no longer in schema --- 
.../datasource_django/datasource.py | 19 ++++++++++++------ .../datasource_sqlalchemy/datasource.py | 2 ++ .../datasource_composite.py | 20 ++++++++++++------- .../datasource_toolkit/datasources.py | 5 ++++- .../interfaces/models/collections.py | 7 +++++-- 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index ed26fb9f9..62a673e0f 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -17,8 +17,10 @@ def __init__( support_polymorphic_relations: bool = False, live_query_connection: Optional[Union[str, Dict[str, str]]] = None, ) -> None: - self._live_query_connections: Dict[str, str] = self._handle_live_query_connections_param(live_query_connection) - super().__init__([*self._live_query_connections.keys()]) + self._django_live_query_connections: Dict[str, str] = self._handle_live_query_connections_param( + live_query_connection + ) + super().__init__([*self._django_live_query_connections.keys()]) self.support_polymorphic_relations = support_polymorphic_relations self._create_collections() @@ -57,21 +59,26 @@ def _create_collections(self): self.add_collection(collection) async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: - if self._live_query_connections is None or connection_name not in self._live_query_connections.keys(): + if ( + self._django_live_query_connections is None + or connection_name not in self._django_live_query_connections.keys() + ): + # TODO: verify # This one should never occur raise DjangoDatasourceException( f"Native query connection '{connection_name}' is not known by DjangoDatasource." ) - if self._live_query_connections[connection_name] not in connections: + if self._django_live_query_connections[connection_name] not in connections: + # TODO: verify raise DjangoDatasourceException( - f"Connection to database '{self._live_query_connections[connection_name]}' for alias " + f"Connection to database '{self._django_live_query_connections[connection_name]}' for alias " f"'{connection_name}' is not found in django connections. " f"Existing connections are {','.join([*connections])}" ) def _execute_native_query(): - cursor = connections[self._live_query_connections[connection_name]].cursor() # type: ignore + cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore try: rows = cursor.execute(native_query) ret = [] diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index dd497e31c..068ba002d 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -59,6 +59,7 @@ def build_mappers(self) -> Dict[str, Mapper]: async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: if connection_name != self.schema["native_query_connections"][0]: + # TODO: verify raise SqlAlchemyDatasourceException( f"The native query connection '{connection_name}' doesn't belongs to this datasource." 
) @@ -70,6 +71,7 @@ async def execute_native_query(self, connection_name: str, native_query: str) -> rows = session.execute(query) return [*rows.mappings()] except Exception as exc: + # TODO: verify raise SqlAlchemyDatasourceException(str(exc)) # unused code, can be use full but can be remove diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index f11676ccf..b93541135 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -15,12 +15,17 @@ def __init__(self) -> None: @property def schema(self) -> DatasourceSchema: charts = {} - native_queries = [] for datasource in self._datasources: charts.update(datasource.schema["charts"]) - native_queries.extend(datasource.schema["native_query_connections"]) - return {"charts": charts, "native_query_connections": native_queries} + return {"charts": charts} + + def get_native_query_connection(self) -> List[str]: + native_queries = [] + for datasource in self._datasources: + + native_queries.extend(datasource.get_native_query_connection()) + return native_queries @property def collections(self) -> List[BoundCollection]: @@ -52,8 +57,9 @@ def add_datasource(self, datasource: Datasource): if connection in self.schema["charts"].keys(): raise DatasourceToolkitException(f"Chart '{connection}' already exists.") - for connection in datasource.schema["native_query_connections"]: - if connection in self.schema["native_query_connections"]: + existing_native_query_connection_names = self.get_native_query_connection() + for connection in datasource.get_native_query_connection(): + if connection in existing_native_query_connection_names: raise DatasourceToolkitException(f"Native query connection '{connection}' already exists.") self._datasources.append(datasource) @@ -67,10 +73,10 @@ async def render_chart(self, caller: User, name: str) -> Chart: async def execute_native_query(self, connection_name: str, native_query: str) -> Any: for datasource in self._datasources: - if connection_name in datasource.schema["native_query_connections"]: + if connection_name in datasource.get_native_query_connection(): return await datasource.execute_native_query(connection_name, native_query) raise DatasourceToolkitException( f"Cannot find {connection_name} in datasources. 
" - f"Existing connection names are: {','.join(self.schema['native_query_connections'])}" + f"Existing connection names are: {','.join(self.get_native_query_connection())}" ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 6b2970975..608045d30 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -15,11 +15,14 @@ class DatasourceException(DatasourceToolkitException): class Datasource(DatasourceInterface[BoundCollection]): def __init__(self, live_query_connections: Optional[List[str]] = None) -> None: self._collections: Dict[str, BoundCollection] = {} + self._live_query_connections = live_query_connections self._schema: DatasourceSchema = { "charts": {}, - "native_query_connections": live_query_connections or [], } + def get_native_query_connection(self) -> List[str]: + return self._live_query_connections or [] + @property def schema(self) -> DatasourceSchema: return self._schema diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index a55687039..3dd747f55 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union +from typing import Callable, Dict, Generic, List, TypedDict, TypeVar from forestadmin.datasource_toolkit.interfaces.actions import Action from forestadmin.datasource_toolkit.interfaces.fields import FieldAlias @@ -18,7 +18,6 @@ class CollectionSchema(TypedDict): class DatasourceSchema(TypedDict): charts: Dict[str, Callable] - native_query_connections: List[str] class Collection(abc.ABC): @@ -52,6 +51,10 @@ def collections(self) -> List[BoundCollection]: def get_collection(self, name: str) -> BoundCollection: raise NotImplementedError + @abc.abstractmethod + def get_native_query_connection(self) -> List[str]: + raise NotImplementedError + @abc.abstractmethod def add_collection(self, collection: BoundCollection) -> None: raise NotImplementedError From e94c627a540f8ccde8c957157318fc059a1abc3d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 21 Nov 2024 15:22:27 +0100 Subject: [PATCH 09/71] chore(datasources): adapt native query and capabilities route --- .../forestadmin/agent_toolkit/agent.py | 4 +++- .../agent_toolkit/resources/capabilities.py | 13 +++++++++---- .../resources/collections/native_query.py | 9 ++++++--- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index e7e0ad01a..2fc452a1d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -75,7 +75,9 @@ def __del__(self): async def __mk_resources(self): self._resources: Resources = { "capabilities": CapabilitiesResource( - await self.customizer.get_datasource(), self._ip_white_list_service, self.options + self.customizer.composite_datasource, + self._ip_white_list_service, + self.options, ), "authentication": Authentication(self._ip_white_list_service, self.options), "crud": CrudResource( diff --git 
a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py index 6e6aa7680..02a530380 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py @@ -7,6 +7,7 @@ from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response from forestadmin.agent_toolkit.utils.forest_schema.generator_field import SchemaFieldGenerator +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.exceptions import BusinessError from forestadmin.datasource_toolkit.interfaces.fields import Column, is_column @@ -19,12 +20,12 @@ class CapabilitiesResource(IpWhitelistResource): def __init__( self, - datasource: DatasourceAlias, + composite_datasource: CompositeDatasource, ip_white_list_service: IpWhiteListService, options: Options, ): super().__init__(ip_white_list_service, options) - self.datasource = datasource + self.composite_datasource: CompositeDatasource = composite_datasource @ip_white_list async def dispatch(self, request: Request, method_name: LiteralMethod) -> Response: @@ -40,14 +41,18 @@ async def dispatch(self, request: Request, method_name: LiteralMethod) -> Respon @check_method(RequestMethod.POST) @authenticate async def capabilities(self, request: Request) -> Response: - ret = {"collections": []} + ret = {"collections": [], "nativeQueryConnections": []} requested_collections = request.body.get("collectionNames", []) for collection_name in requested_collections: ret["collections"].append(self._get_collection_capability(collection_name)) + + ret["nativeQueryConnections"] = [ + {"name": connection} for connection in self.composite_datasource.get_native_query_connection() + ] return HttpResponseBuilder.build_success_response(ret) def _get_collection_capability(self, collection_name: str) -> Dict[str, Any]: - collection = self.datasource.get_collection(collection_name) + collection = self.composite_datasource.get_collection(collection_name) fields = [] for field_name, field_schema in collection.schema["fields"].items(): if is_column(field_schema): diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index d5e5f83f3..33bb828de 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -32,7 +32,7 @@ def __init__( @ip_white_list async def dispatch(self, request: Request, method_name: Literal["native_query"]) -> Response: try: - return HttpResponseBuilder.build_success_response(await self.handle_native_query(request)) + return await self.handle_native_query(request) # type:ignore except Exception as exc: ForestLogger.log("exception", exc) return HttpResponseBuilder.build_client_error_response([exc]) @@ -42,5 +42,8 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) async def handle_native_query(self, request: Request) -> Response: # TODO: permission check # TODO: context variable injector - ds = 
self.composite_datasource.get_datasource(request.body["datasource"]) - return await ds.execute_native_query(request.body["native_query"]) + return HttpResponseBuilder.build_success_response( + await self.composite_datasource.execute_native_query( + request.body["connection_name"], request.body["native_query"] + ) + ) From 0be8e229c72e58dddaf6400746d999c61f41c3d8 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 09:57:39 +0100 Subject: [PATCH 10/71] chore: refactor context variable injector --- .../resources/collections/crud.py | 3 +- .../resources/collections/stats.py | 32 +--------- .../context_variable_injector_mixin.py | 64 +++++++++++++++++++ 3 files changed, 69 insertions(+), 30 deletions(-) create mode 100644 src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 942872632..daa044358 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -21,6 +21,7 @@ parse_timezone, ) from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException +from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.services.serializers import add_search_metadata from forestadmin.agent_toolkit.services.serializers.json_api import JsonApiException, JsonApiSerializer from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response, User @@ -62,7 +63,7 @@ LiteralMethod = Literal["list", "count", "add", "get", "delete_list", "csv"] -class CrudResource(BaseCollectionResource): +class CrudResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): @ip_white_list async def dispatch(self, request: Request, method_name: LiteralMethod) -> Response: method = getattr(self, method_name) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py index 465ef794a..5b78966f4 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py @@ -14,12 +14,10 @@ ) from forestadmin.agent_toolkit.resources.collections.filter import build_filter from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException +from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.utils.context import FileResponse, HttpResponseBuilder, Request, RequestMethod, Response -from forestadmin.agent_toolkit.utils.context_variable_injector import ContextVariableInjector -from forestadmin.agent_toolkit.utils.context_variable_instantiator import ContextVariablesInstantiator from forestadmin.datasource_toolkit.exceptions import ForestException from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.factory import ConditionTreeFactory from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree from 
forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf @@ -28,7 +26,7 @@ from forestadmin.datasource_toolkit.utils.schema import SchemaUtils -class StatsResource(BaseCollectionResource): +class StatsResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): FREQUENCIES = {"Day": "d", "Week": "W-MON", "Month": "BMS", "Year": "BYS"} FORMAT = {"Day": "%d/%m/%Y", "Week": "W%V-%Y", "Month": "%b %Y", "Year": "%Y"} @@ -234,7 +232,7 @@ def _use_interval_res(tree: ConditionTree) -> None: async def _get_filter(self, request: RequestCollection) -> Filter: scope_tree = await self.permission.get_scope(request.user, request.collection) - await self.__inject_context_variables(request) + await self.inject_context_variables_in_filter(request) return build_filter(request, scope_tree) async def _compute_value(self, request: RequestCollection, filter: Filter) -> int: @@ -246,27 +244,3 @@ async def _compute_value(self, request: RequestCollection, filter: Filter) -> in if len(rows): res = int(rows[0]["value"]) return res - - async def __inject_context_variables(self, request: RequestCollection): - context_variables_dct = request.body.pop("contextVariables", {}) - if request.body.get("filter") is None: - return - - context_variables = await ContextVariablesInstantiator.build_context_variables( - request.user, context_variables_dct, self.permission - ) - condition_tree = request.body["filter"] - if isinstance(request.body["filter"], str): - condition_tree = json.loads(condition_tree) - - condition_tree = ConditionTreeFactory.from_plain_object(condition_tree) - - injected_filter: ConditionTree = condition_tree.replace( - lambda leaf: ConditionTreeLeaf( - leaf.field, - leaf.operator, - ContextVariableInjector.inject_context_in_value(leaf.value, context_variables), - ) - ) - - request.body["filter"] = injected_filter.to_plain_object() diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py new file mode 100644 index 000000000..1d5974922 --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +from forestadmin.agent_toolkit.utils.context_variable_injector import ContextVariableInjector +from forestadmin.agent_toolkit.utils.context_variable_instantiator import ContextVariablesInstantiator +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.factory import ConditionTreeFactory +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf + +if TYPE_CHECKING: + from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection + from forestadmin.agent_toolkit.utils.context import Request + + +class ContextVariableInjectorResourceMixin: + async def inject_context_variables_in_filter(self, request: "RequestCollection"): + context_variables_dct = request.body.pop("contextVariables", {}) + if request.body.get("filter") is None: + return + + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + 
condition_tree = request.body["filter"] + if isinstance(request.body["filter"], str): + condition_tree = json.loads(condition_tree) + + condition_tree = ConditionTreeFactory.from_plain_object(condition_tree) + + injected_filter: ConditionTree = condition_tree.replace( + lambda leaf: ConditionTreeLeaf( + leaf.field, + leaf.operator, + ContextVariableInjector.inject_context_in_value(leaf.value, context_variables), + ) + ) + + request.body["filter"] = injected_filter.to_plain_object() + + async def inject_context_variables_in_live_query_segment(self, request: "RequestCollection"): + # TODO: handle context variables from front or not ?? + if request.query.get("segmentQuery") is None: + return + context_variables_dct = request.query.pop("contextVariables", {}) + + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + + request.query["segmentQuery"] = ContextVariableInjector.inject_context_in_value( + request.query["segmentQuery"], context_variables + ) + + async def inject_context_variables_in_live_query_chart(self, request: "Request"): + context_variables_dct = request.body.get("contextVariables", {}) + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + + request.body["query"] = ContextVariableInjector.inject_context_in_value( + request.body["query"], context_variables + ) From 1e1262888975811756bcf4bbe0f9c37a88656fab Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 11:59:25 +0100 Subject: [PATCH 11/71] chore: fix typo --- .../agent_toolkit/resources/capabilities.py | 2 +- .../forestadmin/datasource_sqlalchemy/datasource.py | 2 +- .../datasource_customizer/datasource_composite.py | 12 ++++++------ .../forestadmin/datasource_toolkit/datasources.py | 2 +- .../interfaces/models/collections.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py index 02a530380..997e8ba19 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py @@ -47,7 +47,7 @@ async def capabilities(self, request: Request) -> Response: ret["collections"].append(self._get_collection_capability(collection_name)) ret["nativeQueryConnections"] = [ - {"name": connection} for connection in self.composite_datasource.get_native_query_connection() + {"name": connection} for connection in self.composite_datasource.get_native_query_connections() ] return HttpResponseBuilder.build_success_response(ret) diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 068ba002d..4f1bddca2 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -58,7 +58,7 @@ def build_mappers(self) -> Dict[str, Mapper]: return mappers async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: - if connection_name != self.schema["native_query_connections"][0]: + if connection_name != self.get_native_query_connections()[0]: # TODO: verify raise SqlAlchemyDatasourceException( f"The native query connection '{connection_name}' doesn't belongs to this 
datasource." diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index b93541135..b81b52207 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -20,11 +20,11 @@ def schema(self) -> DatasourceSchema: return {"charts": charts} - def get_native_query_connection(self) -> List[str]: + def get_native_query_connections(self) -> List[str]: native_queries = [] for datasource in self._datasources: - native_queries.extend(datasource.get_native_query_connection()) + native_queries.extend(datasource.get_native_query_connections()) return native_queries @property @@ -57,8 +57,8 @@ def add_datasource(self, datasource: Datasource): if connection in self.schema["charts"].keys(): raise DatasourceToolkitException(f"Chart '{connection}' already exists.") - existing_native_query_connection_names = self.get_native_query_connection() - for connection in datasource.get_native_query_connection(): + existing_native_query_connection_names = self.get_native_query_connections() + for connection in datasource.get_native_query_connections(): if connection in existing_native_query_connection_names: raise DatasourceToolkitException(f"Native query connection '{connection}' already exists.") @@ -73,10 +73,10 @@ async def render_chart(self, caller: User, name: str) -> Chart: async def execute_native_query(self, connection_name: str, native_query: str) -> Any: for datasource in self._datasources: - if connection_name in datasource.get_native_query_connection(): + if connection_name in datasource.get_native_query_connections(): return await datasource.execute_native_query(connection_name, native_query) raise DatasourceToolkitException( f"Cannot find {connection_name} in datasources. 
" - f"Existing connection names are: {','.join(self.get_native_query_connection())}" + f"Existing connection names are: {','.join(self.get_native_query_connections())}" ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 608045d30..d08868ba1 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -20,7 +20,7 @@ def __init__(self, live_query_connections: Optional[List[str]] = None) -> None: "charts": {}, } - def get_native_query_connection(self) -> List[str]: + def get_native_query_connections(self) -> List[str]: return self._live_query_connections or [] @property diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index 3dd747f55..cf71ed388 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -52,7 +52,7 @@ def get_collection(self, name: str) -> BoundCollection: raise NotImplementedError @abc.abstractmethod - def get_native_query_connection(self) -> List[str]: + def get_native_query_connections(self) -> List[str]: raise NotImplementedError @abc.abstractmethod From bd4e75c1fb9db8691a3a42a91f5a84a75fe68670 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 12:00:05 +0100 Subject: [PATCH 12/71] chore: ineject context variables in chart queries --- .../resources/collections/native_query.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 33bb828de..d6d511e96 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -4,6 +4,7 @@ from forestadmin.agent_toolkit.options import Options from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource from forestadmin.agent_toolkit.resources.collections.decorators import authenticate, check_method, ip_white_list +from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response @@ -17,7 +18,7 @@ LiteralMethod = Literal["native_query"] -class NativeQueryResource(BaseCollectionResource): +class NativeQueryResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): def __init__( self, composite_datasource: CompositeDatasource, @@ -40,10 +41,8 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) @check_method(RequestMethod.POST) @authenticate async def handle_native_query(self, request: Request) -> Response: - # TODO: permission check - # TODO: context variable injector + await self.permission.can_chart(request) + await self.inject_context_variables_in_live_query_chart(request) return 
HttpResponseBuilder.build_success_response( - await self.composite_datasource.execute_native_query( - request.body["connection_name"], request.body["native_query"] - ) + await self.composite_datasource.execute_native_query(request.body["connectionName"], request.body["query"]) ) From e0f61266c44570cba2590e264d777cedab1523ce Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 12:00:34 +0100 Subject: [PATCH 13/71] chore: draft permissions for segments live queries --- .../permissions/permission_service.py | 68 +++++++++++++++++-- .../permissions/permissions_functions.py | 18 +++++ 2 files changed, 79 insertions(+), 7 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index 78ffbfc0f..824fe1e31 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -8,6 +8,8 @@ _decode_actions_permissions, _decode_crud_permissions, _decode_scope_permissions, + _decode_segment_query_permissions, + _dict_hash, _hash_chart, ) from forestadmin.agent_toolkit.services.permissions.smart_actions_checker import SmartActionChecker @@ -75,6 +77,36 @@ async def can_chart(self, request: RequestCollection) -> bool: ) return is_allowed + async def can_live_query_segment(self, request: RequestCollection) -> bool: + live_query = request.query["segmentQuery"] + # connection_name = request.query["connectionName"] + hash_live_query = _dict_hash( + { + "query": live_query, + # "connection_name": connection_name # TODO: review when connectionName in permissions + } + ) + is_allowed = hash_live_query in (await self._get_segment_queries(request.user.rendering_id, False)).get( + request.collection.name + ) + + # Refetch + if is_allowed is False: + is_allowed = hash_live_query in await self._get_segment_queries(request.user.rendering_id, True) + + # still not allowed - throw forbidden message + if is_allowed is False: + ForestLogger.log( + "debug", + f"User {request.user.user_id} cannot retrieve segment queries on rendering {request.user.rendering_id}", + ) + raise ForbiddenError("You don't have permission to access this segment query.") + ForestLogger.log( + "debug", + f"User {request.user.user_id} can retrieve segment queries on rendering {request.user.rendering_id}", + ) + return is_allowed + async def can_smart_action( self, request: RequestCollection, collection: Collection, filter_: Filter, allow_fetch: bool = True ): @@ -151,10 +183,7 @@ async def _get_chart_data(self, rendering_id: int, force_fetch: bool = False) -> ForestLogger.log("debug", f"Loading rendering permissions for rendering {rendering_id}") response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - stat_hash = [] - for stat in response["stats"]: - stat_hash.append(f'{stat["type"]}:{_hash_chart(stat)}') - self.cache["forest.stats"] = stat_hash + self._handle_rendering_permissions(response) return self.cache["forest.stats"] @@ -178,12 +207,37 @@ async def _get_collection_permissions_data(self, force_fetch: bool = False): async def _get_scope_and_team_data(self, rendering_id: int): if "forest.scopes" not in self.cache: response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - data = {"scopes": _decode_scope_permissions(response["collections"]), "team": response["team"]} - - self.cache["forest.scopes"] = data 
+ self._handle_rendering_permissions(response) return self.cache["forest.scopes"] + async def _get_segment_queries(self, rendering_id: int, force_fetch: bool): + if force_fetch and "forest.segment_queries" in self.cache: + del self.cache["forest.segment_queries"] + + if "forest.segment_queries" not in self.cache: + response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) + self._handle_rendering_permissions(response) + + return self.cache["forest.segment_queries"] + + def _handle_rendering_permissions(self, rendering_permissions): + # forest.stats + stat_hash = [] + for stat in rendering_permissions["stats"]: + stat_hash.append(f'{stat["type"]}:{_hash_chart(stat)}') + self.cache["forest.stats"] = stat_hash + + # forest.scopes + data = { + "scopes": _decode_scope_permissions(rendering_permissions["collections"]), + "team": rendering_permissions["team"], + } + self.cache["forest.scopes"] = data + + # forest.segment_queries + self.cache["forest.segment_queries"] = _decode_segment_query_permissions(rendering_permissions["collections"]) + async def _find_action_from_endpoint( self, collection: Collection, get_params: Dict, http_method: str ) -> Optional[ForestServerAction]: diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py index e86ccda40..28de72167 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py @@ -11,6 +11,23 @@ ################## +# TODO: adapt after new version of permissions api +def _decode_segment_query_permissions(raw_permission: Dict[Any, Any]): + segment_queries = {} + for collection_name, value in raw_permission.items(): + segment_queries[collection_name] = [] + for segment_query in value.get("segments", []): + segment_queries[collection_name].append( + _dict_hash( + { + "query": segment_query, + # "connection_name": connection_name # TODO: review when connectionName in permissions + } + ) + ) + return segment_queries + + def _decode_scope_permissions(raw_permission: Dict[Any, Any]) -> Dict[str, ConditionTree]: scopes = {} for collection_name, value in raw_permission.items(): @@ -57,6 +74,7 @@ def _dict_hash(data: Dict[Any, Any]) -> str: def _hash_chart(chart: Dict[Any, Any]) -> str: known_chart_keys = [ + # "connectionName", # TODO: to enable with next backend version of permissions "type", "apiRoute", "smartRoute", From 873763dc0e3007778e69e3822ca72a12db86cf96 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 12:01:00 +0100 Subject: [PATCH 14/71] chore: draft of live query segments --- .../forestadmin/agent_toolkit/agent.py | 1 + .../resources/collections/crud.py | 61 +++++++++++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index 2fc452a1d..760929e8e 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -81,6 +81,7 @@ async def __mk_resources(self): ), "authentication": Authentication(self._ip_white_list_service, self.options), "crud": CrudResource( + self.customizer.composite_datasource, await self.customizer.get_datasource(), self._permission_service, self._ip_white_list_service, diff --git 
a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index daa044358..86823ca21 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -1,8 +1,9 @@ import asyncio -from typing import Any, Awaitable, Dict, List, Literal, Tuple, Union, cast +from typing import Any, Awaitable, Dict, List, Literal, Optional, Tuple, Union, cast from uuid import UUID from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.options import Options from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource from forestadmin.agent_toolkit.resources.collections.decorators import ( authenticate, @@ -22,6 +23,8 @@ ) from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService from forestadmin.agent_toolkit.services.serializers import add_search_metadata from forestadmin.agent_toolkit.services.serializers.json_api import JsonApiException, JsonApiSerializer from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response, User @@ -29,7 +32,9 @@ from forestadmin.agent_toolkit.utils.id import unpack_id from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasource_customizer.collection_customizer import CollectionCustomizer -from forestadmin.datasource_toolkit.datasources import DatasourceException +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer +from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException from forestadmin.datasource_toolkit.exceptions import ForbiddenError from forestadmin.datasource_toolkit.interfaces.fields import ( ManyToOne, @@ -64,6 +69,17 @@ class CrudResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): + def __init__( + self, + datasource_composite: CompositeDatasource, + datasource: Union[Datasource, DatasourceCustomizer], + permission: PermissionService, + ip_white_list_service: IpWhiteListService, + options: Options, + ): + self._datasource_composite = datasource_composite + super().__init__(datasource, permission, ip_white_list_service, options) + @ip_white_list async def dispatch(self, request: Request, method_name: LiteralMethod) -> Response: method = getattr(self, method_name) @@ -160,6 +176,9 @@ async def list(self, request: RequestCollection) -> Response: scope_tree = await self.permission.get_scope(request.user, request.collection) try: paginated_filter = build_paginated_filter(request, scope_tree) + condition_tree = await self._handle_live_query_segment(request, paginated_filter.condition_tree) + paginated_filter = paginated_filter.override({"condition_tree": condition_tree}) + except FilterException as e: ForestLogger.log("exception", e) return HttpResponseBuilder.build_client_error_response([e]) @@ -192,6 +211,8 @@ async def 
csv(self, request: RequestCollection) -> Response: scope_tree = await self.permission.get_scope(request.user, request.collection) try: paginated_filter = build_paginated_filter(request, scope_tree) + condition_tree = await self._handle_live_query_segment(request, paginated_filter.condition_tree) + paginated_filter = paginated_filter.override({"condition_tree": condition_tree}) paginated_filter.page = None except FilterException as e: ForestLogger.log("exception", e) @@ -221,9 +242,12 @@ async def count(self, request: RequestCollection) -> Response: return HttpResponseBuilder.build_success_response({"meta": {"count": "deactivated"}}) scope_tree = await self.permission.get_scope(request.user, request.collection) - filter = build_filter(request, scope_tree) + filter_ = build_filter(request, scope_tree) + filter_ = filter_.override( + {"condition_tree": await self._handle_live_query_segment(request, filter_.condition_tree)} + ) aggregation = Aggregation({"operation": "Count"}) - result = await request.collection.aggregate(request.user, filter, aggregation) + result = await request.collection.aggregate(request.user, filter_, aggregation) try: count = result[0]["value"] except IndexError: @@ -425,3 +449,32 @@ def _serialize_records_with_relationships( schema = JsonApiSerializer.get(collection) return schema(projections=projection).dump(records if many is True else records[0], many=many) + + async def _handle_live_query_segment( + self, request: RequestCollection, condition_tree: Optional[ConditionTree] + ) -> ConditionTree: + + if request.query.get("segmentQuery") is not None: + # if "connectionName" not in request.query: + # # TODO: correct exception + # raise Exception + + await self.permission.can_live_query_segment(request) + await self.inject_context_variables_in_live_query_segment(request) + # TODO: remove connectionName mock + rslt = await self._datasource_composite.execute_native_query( + "django" if request.collection.name.startswith("app_") else "sqlalchemy", + request.query["segmentQuery"], + ) + + trees = [] + if condition_tree: + trees.append(condition_tree) + trees.append( + ConditionTreeLeaf( + SchemaUtils.get_primary_keys(request.collection.schema)[0], + Operator.IN, + [entry["id"] for entry in rslt], + ) + ) + return ConditionTreeFactory.intersect(trees) From 1cb7070be7916d2043c374a64fe67dd35ec01446 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 25 Nov 2024 12:01:21 +0100 Subject: [PATCH 15/71] chore(example): update example projects --- .../django/django_demo/.forestadmin-schema.json | 10 +++++----- .../app/forest/custom_datasources/typicode.py | 5 +++++ src/_example/django/django_demo/app/forest_admin.py | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/_example/django/django_demo/.forestadmin-schema.json b/src/_example/django/django_demo/.forestadmin-schema.json index 6503219e8..b2850a5df 100644 --- a/src/_example/django/django_demo/.forestadmin-schema.json +++ b/src/_example/django/django_demo/.forestadmin-schema.json @@ -2726,7 +2726,7 @@ "enums": null, "field": "body", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -2740,7 +2740,7 @@ "enums": null, "field": "email", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -2768,7 +2768,7 @@ "enums": null, "field": "name", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, 
"isReadOnly": false, "isRequired": false, @@ -3486,7 +3486,7 @@ "enums": null, "field": "body", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -3514,7 +3514,7 @@ "enums": null, "field": "title", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, diff --git a/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py b/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py index f6f9bfd04..dcf49f57d 100644 --- a/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py +++ b/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py @@ -107,14 +107,17 @@ def __init__(self, datasource: Datasource[Self]): "name": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "email": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "body": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, } ) @@ -140,10 +143,12 @@ def __init__(self, datasource: Datasource[Self]): "title": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "body": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, } ) diff --git a/src/_example/django/django_demo/app/forest_admin.py b/src/_example/django/django_demo/app/forest_admin.py index 6d43fbc27..a3cef98cf 100644 --- a/src/_example/django/django_demo/app/forest_admin.py +++ b/src/_example/django/django_demo/app/forest_admin.py @@ -50,9 +50,9 @@ def customize_forest(agent: DjangoAgent): # customize_forest_logging() - agent.add_datasource(DjangoDatasource(support_polymorphic_relations=True)) + agent.add_datasource(DjangoDatasource(support_polymorphic_relations=True, live_query_connection="django")) agent.add_datasource(TypicodeDatasource()) - agent.add_datasource(SqlAlchemyDatasource(Base, DB_URI)) + agent.add_datasource(SqlAlchemyDatasource(Base, DB_URI, live_query_connection="sqlalchemy")) agent.customize_collection("address").add_segment("France", segment_addr_fr("address")) agent.customize_collection("app_address").add_segment("France", segment_addr_fr("app_address")) From 3abfb99541dbcfc532faf3feacdcaacaccb037df Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 09:49:33 +0100 Subject: [PATCH 16/71] chore: add parameters to execute native query and handle it --- .../resources/collections/crud.py | 17 ++++++++++++----- .../resources/collections/native_query.py | 6 ++++-- .../context_variable_injector_mixin.py | 18 +++++++++++++----- .../utils/context_variable_injector.py | 18 +++++++++++++++++- .../datasource_django/datasource.py | 7 +++++-- .../datasource_sqlalchemy/datasource.py | 11 +++++++++-- .../context/relaxed_wrappers/collection.py | 4 +++- .../datasource_composite.py | 7 +++---- .../datasource_toolkit/datasources.py | 2 +- .../decorators/datasource_decorator.py | 9 +++++++-- .../interfaces/models/collections.py | 4 +++- 11 files changed, 77 insertions(+), 26 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 86823ca21..3d2bbb670 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ 
b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -452,7 +452,7 @@ def _serialize_records_with_relationships( async def _handle_live_query_segment( self, request: RequestCollection, condition_tree: Optional[ConditionTree] - ) -> ConditionTree: + ) -> Optional[ConditionTree]: if request.query.get("segmentQuery") is not None: # if "connectionName" not in request.query: @@ -460,11 +460,17 @@ async def _handle_live_query_segment( # raise Exception await self.permission.can_live_query_segment(request) - await self.inject_context_variables_in_live_query_segment(request) + vars = await self.inject_and_get_context_variables_in_live_query_segment(request) + # TODO: remove connectionName mock + if request.collection.name.startswith("app_"): + connection_name = "django" + elif request.collection.name.startswith("sqlalchemy_"): + connection_name = "dj_sqlachemy" + else: + connection_name = "sqlalchemy" # TODO: remove connectionName mock rslt = await self._datasource_composite.execute_native_query( - "django" if request.collection.name.startswith("app_") else "sqlalchemy", - request.query["segmentQuery"], + connection_name, request.query["segmentQuery"], vars ) trees = [] @@ -477,4 +483,5 @@ async def _handle_live_query_segment( [entry["id"] for entry in rslt], ) ) - return ConditionTreeFactory.intersect(trees) + return ConditionTreeFactory.intersect(trees) + return condition_tree diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index d6d511e96..448f2884a 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -42,7 +42,9 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) @authenticate async def handle_native_query(self, request: Request) -> Response: await self.permission.can_chart(request) - await self.inject_context_variables_in_live_query_chart(request) + variables = await self.inject_and_get_context_variables_in_live_query_chart(request) return HttpResponseBuilder.build_success_response( - await self.composite_datasource.execute_native_query(request.body["connectionName"], request.body["query"]) + await self.composite_datasource.execute_native_query( + request.body["connectionName"], request.body["query"], variables + ) ) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py index 1d5974922..af1c93953 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict from forestadmin.agent_toolkit.utils.context_variable_injector import ContextVariableInjector from forestadmin.agent_toolkit.utils.context_variable_instantiator import ContextVariablesInstantiator @@ -39,21 +39,24 @@ async def inject_context_variables_in_filter(self, request: "RequestCollection") request.body["filter"] = injected_filter.to_plain_object() - async def inject_context_variables_in_live_query_segment(self, request: "RequestCollection"): + async def inject_and_get_context_variables_in_live_query_segment( + 
self, request: "RequestCollection" + ) -> Dict[str, str]: # TODO: handle context variables from front or not ?? if request.query.get("segmentQuery") is None: - return + return {} context_variables_dct = request.query.pop("contextVariables", {}) context_variables = await ContextVariablesInstantiator.build_context_variables( request.user, context_variables_dct, self.permission ) - request.query["segmentQuery"] = ContextVariableInjector.inject_context_in_value( + request.query["segmentQuery"], vars = ContextVariableInjector.format_query_and_get_vars( request.query["segmentQuery"], context_variables ) + return vars - async def inject_context_variables_in_live_query_chart(self, request: "Request"): + async def inject_and_get_context_variables_in_live_query_chart(self, request: "Request") -> Dict[str, str]: context_variables_dct = request.body.get("contextVariables", {}) context_variables = await ContextVariablesInstantiator.build_context_variables( request.user, context_variables_dct, self.permission @@ -62,3 +65,8 @@ async def inject_context_variables_in_live_query_chart(self, request: "Request") request.body["query"] = ContextVariableInjector.inject_context_in_value( request.body["query"], context_variables ) + + request.query["query"], vars = ContextVariableInjector.format_query_and_get_vars( + request.query["query"], context_variables + ) + return vars diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py index 35de62534..9a3b89d9b 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py @@ -1,5 +1,5 @@ import re -from typing import Optional +from typing import Dict, Optional, Tuple from forestadmin.agent_toolkit.utils.context_variables import ContextVariables from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree @@ -29,3 +29,19 @@ def inject_context_in_value(value, context_variable: ContextVariables): return value return re.sub(r"{{([^}]+)}}", lambda match: str(context_variable.get_value(match.group(1))), value) + + @staticmethod + def format_query_and_get_vars(value, context_variable: ContextVariables) -> Tuple[str, Dict[str, str]]: + if not isinstance(value, str): + return value + + vars = {} + + def _match(match): + # TODO: find a better way to handle `parameters` (and like '%') over all datasources + vars[match.group(1).replace(".", "__")] = context_variable.get_value(match.group(1)) + return f"%({match.group(1).replace(".", "__")})s" + + ret = re.sub(r"{{([^}]+)}}", _match, value.replace("%", "\\%")) + + return ret, vars diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 62a673e0f..f227e4925 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -58,7 +58,9 @@ def _create_collections(self): collection = DjangoCollection(self, model, self.support_polymorphic_relations) self.add_collection(collection) - async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: if ( self._django_live_query_connections is None or 
connection_name not in self._django_live_query_connections.keys() @@ -80,7 +82,8 @@ async def execute_native_query(self, connection_name: str, native_query: str) -> def _execute_native_query(): cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore try: - rows = cursor.execute(native_query) + # TODO: find a better way to handle `parameters` (and like '%') over all datasources + rows = cursor.execute(native_query.replace("\\%", "%%"), parameters) ret = [] for row in rows: return_row = {} diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 4f1bddca2..327c1f1c2 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -57,7 +57,9 @@ def build_mappers(self) -> Dict[str, Mapper]: mappers[mapper.persist_selectable.name] = mapper return mappers - async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: if connection_name != self.get_native_query_connections()[0]: # TODO: verify raise SqlAlchemyDatasourceException( @@ -67,8 +69,13 @@ async def execute_native_query(self, connection_name: str, native_query: str) -> session = self.Session() query = native_query if isinstance(query, str): + query = native_query + # TODO: find a better way to handle `parameters` (and like '%') over all datasources + for key in parameters.keys(): + query = query.replace(f"%({key})s", f":{key}") + query = query.replace("\\%", "%") query = text(query) - rows = session.execute(query) + rows = session.execute(query, parameters) return [*rows.mappings()] except Exception as exc: # TODO: verify diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index 20e18e9f9..7d116345c 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -59,7 +59,9 @@ def schema(self): def name(self): return self.datasource.name - async def execute_native_query(self, native_query: str) -> Any: + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[Dict[str, Any]]: raise RelaxedDatasourceException("Cannot use this method. 
Please use 'collection.get_native_driver' instead.") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index b81b52207..f1e6caf71 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, Dict, List from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.datasources import Datasource @@ -23,7 +23,6 @@ def schema(self) -> DatasourceSchema: def get_native_query_connections(self) -> List[str]: native_queries = [] for datasource in self._datasources: - native_queries.extend(datasource.get_native_query_connections()) return native_queries @@ -71,10 +70,10 @@ async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceToolkitException(f"Chart {name} is not defined in the datasource.") - async def execute_native_query(self, connection_name: str, native_query: str) -> Any: + async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: for datasource in self._datasources: if connection_name in datasource.get_native_query_connections(): - return await datasource.execute_native_query(connection_name, native_query) + return await datasource.execute_native_query(connection_name, native_query, parameters) raise DatasourceToolkitException( f"Cannot find {connection_name} in datasources. " diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index d08868ba1..6ba4939ad 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -49,5 +49,5 @@ def add_collection(self, collection: BoundCollection) -> None: async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceException(f"Chart {name} not exists on this datasource.") - async def execute_native_query(self, connection_name: str, native_query: str) -> Any: + async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: raise NotImplementedError(f"'execute_native_query' is not implemented on {self.__class__.__name__}") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py index 62450bfa9..64ebddda3 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Any, Dict, List, Union from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.collections import Collection @@ -11,7 +11,7 @@ class DatasourceDecorator(Datasource): def __init__( self, child_datasource: Union[Datasource, "DatasourceDecorator"], class_collection_decorator: type ) -> None: - super().__init__() + super().__init__(child_datasource.get_native_query_connections()) self.child_datasource = child_datasource self.class_collection_decorator = 
class_collection_decorator self._decorators: Dict[Collection, CollectionDecorator] = {} @@ -38,3 +38,8 @@ def get_charts(self): async def render_chart(self, caller: User, name: str) -> Chart: return await self.child_datasource.render_chart(caller, name) + + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[Dict[str, Any]]: + return await self.child_datasource.execute_native_query(connection_name, native_query, parameters) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index cf71ed388..9e750a7fa 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -65,5 +65,7 @@ def schema(self) -> DatasourceSchema: raise NotImplementedError @abc.abstractmethod - async def execute_native_query(self, connection_name: str, native_query: str) -> List[RecordsDataAlias]: + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: raise NotImplementedError From 7830f63b34c6f4bff5d931cc0b2f925f7e00d880 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 15:04:21 +0100 Subject: [PATCH 17/71] chore: add comments for query formatting --- .../utils/context_variable_injector.py | 15 +++++++++------ .../forestadmin/datasource_django/datasource.py | 4 +++- .../datasource_sqlalchemy/datasource.py | 6 ++++-- .../forestadmin/datasource_toolkit/datasources.py | 3 +++ 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py index 9a3b89d9b..6a0c6033f 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py @@ -31,17 +31,20 @@ def inject_context_in_value(value, context_variable: ContextVariables): return re.sub(r"{{([^}]+)}}", lambda match: str(context_variable.get_value(match.group(1))), value) @staticmethod - def format_query_and_get_vars(value, context_variable: ContextVariables) -> Tuple[str, Dict[str, str]]: + def format_query_and_get_vars(value: str, context_variable: ContextVariables) -> Tuple[str, Dict[str, str]]: if not isinstance(value, str): return value - - vars = {} + variables = {} + # to allow datasources to rework variables: + # - '%' are replaced by '\%' + # - '{{var}}' are replaced by '%(var)s' + # - '.'
in vars are replaced by '__', and also in the returned mapping + # - and the mapping of vars is returned def _match(match): - # TODO: find a better way to handle `parameters` (and like '%') over all datasources - vars[match.group(1).replace(".", "__")] = context_variable.get_value(match.group(1)) + variables[match.group(1).replace(".", "__")] = context_variable.get_value(match.group(1)) return f"%({match.group(1).replace(".", "__")})s" ret = re.sub(r"{{([^}]+)}}", _match, value.replace("%", "\\%")) - return ret, vars + return (ret, variables) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index f227e4925..ad72bfc08 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -82,8 +82,10 @@ async def execute_native_query( def _execute_native_query(): cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore try: - # TODO: find a better way to handle `parameters` (and like '%') over all datasources + # replace '\s' by '%%' + # %(var)s is already the correct syntax rows = cursor.execute(native_query.replace("\\%", "%%"), parameters) + ret = [] for row in rows: return_row = {} diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 327c1f1c2..de967c909 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -70,10 +70,12 @@ async def execute_native_query( query = native_query if isinstance(query, str): query = native_query - # TODO: find a better way to handle `parameters` (and like '%') over all datasources for key in parameters.keys(): + # replace '%(...)s' by ':...' 
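Side note on the convention established in the comments above: the injector escapes every literal '%' as '\%', rewrites '{{var}}' placeholders as '%(var)s', and replaces '.' with '__' in both the placeholder and the returned mapping; each datasource then adapts that neutral form to its own driver (the Django datasource only turns '\%' into '%%', since cursor.execute already understands '%(var)s', while the SQLAlchemy datasource rewrites '%(var)s' into ':var' and '\%' into '%' before wrapping the query with text()). A minimal self-contained sketch of the injector side, where a plain dict stands in for the real ContextVariables object (an assumption made for brevity):

    import re
    from typing import Dict, Tuple

    def format_query_and_get_vars_sketch(value: str, context: Dict[str, str]) -> Tuple[str, Dict[str, str]]:
        variables: Dict[str, str] = {}

        def _match(match):
            key = match.group(1).replace(".", "__")   # '.' is not valid inside a '%(...)s' placeholder name
            variables[key] = context[match.group(1)]
            return f"%({key})s"

        # escape literal '%' first so only the generated placeholders remain meaningful
        return re.sub(r"{{([^}]+)}}", _match, value.replace("%", "\\%")), variables

    query, params = format_query_and_get_vars_sketch(
        "select id from customer where email like '%@{{currentUser.team}}.com'",
        {"currentUser.team": "forest"},
    )
    # query  == "select id from customer where email like '\%@%(currentUser__team)s.com'"
    # params == {"currentUser__team": "forest"}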
query = query.replace(f"%({key})s", f":{key}") - query = query.replace("\\%", "%") + # replace '\%' by '%' + query = query.replace("\\%", "%") + query = text(query) rows = session.execute(query, parameters) return [*rows.mappings()] diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 6ba4939ad..4ddc1b44e 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -50,4 +50,7 @@ async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceException(f"Chart {name} not exists on this datasource.") async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: + # in native_query, there is the following syntax: + # - parameters to inject by 'execute' method are in the format '%(var)s' + # - '%' (in 'like' comparisons) are replaced by '\%' (to avoid conflict with previous rule) raise NotImplementedError(f"'execute_native_query' is not implemented on {self.__class__.__name__}") From 24533e435dd1bfd783dfc8b825ab2846349bf45f Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 15:36:25 +0100 Subject: [PATCH 18/71] chore: add error handling --- .../resources/collections/crud.py | 33 ++++++++++--------- .../resources/collections/native_query.py | 7 ++++ .../resources/collections/requests.py | 2 +- .../permissions/permission_service.py | 1 + .../agent_toolkit/utils/context.py | 2 ++ .../datasource_composite.py | 6 ++-- 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 3d2bbb670..0bea1ff53 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -35,7 +35,7 @@ from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException -from forestadmin.datasource_toolkit.exceptions import ForbiddenError +from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError from forestadmin.datasource_toolkit.interfaces.fields import ( ManyToOne, OneToOne, @@ -453,25 +453,26 @@ def _serialize_records_with_relationships( async def _handle_live_query_segment( self, request: RequestCollection, condition_tree: Optional[ConditionTree] ) -> Optional[ConditionTree]: + # TODO: remove connectionName mock + if request.collection.name.startswith("app_"): + request.query["connectionName"] = "django" + elif request.collection.name.startswith("sqlalchemy_"): + request.query["connectionName"] = "dj_sqlachemy" + else: + request.query["connectionName"] = "sqlalchemy" + # TODO: remove connectionName mock if request.query.get("segmentQuery") is not None: - # if "connectionName" not in request.query: - # # TODO: correct exception - # raise Exception + if "connectionName" not in request.query: + raise BusinessError("Missing 'connectionName' parameter.") await self.permission.can_live_query_segment(request) - vars = await self.inject_and_get_context_variables_in_live_query_segment(request) - # TODO: remove connectionName 
mock - if request.collection.name.startswith("app_"): - connection_name = "django" - elif request.collection.name.startswith("sqlalchemy_"): - connection_name = "dj_sqlachemy" - else: - connection_name = "sqlalchemy" - # TODO: remove connectionName mock - rslt = await self._datasource_composite.execute_native_query( - connection_name, request.query["segmentQuery"], vars + variables = await self.inject_and_get_context_variables_in_live_query_segment(request) + native_query_result = await self._datasource_composite.execute_native_query( + request.query["connectionName"], request.query["segmentQuery"], variables ) + if len(native_query_result) > 0 and "id" not in native_query_result[0]: + raise BusinessError("Live query must return an 'id' field.") trees = [] if condition_tree: @@ -480,7 +481,7 @@ async def _handle_live_query_segment( ConditionTreeLeaf( SchemaUtils.get_primary_keys(request.collection.schema)[0], Operator.IN, - [entry["id"] for entry in rslt], + [entry["id"] for entry in native_query_result], ) ) return ConditionTreeFactory.intersect(trees) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 448f2884a..4859d1967 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -10,6 +10,7 @@ from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer +from forestadmin.datasource_toolkit.exceptions import BusinessError from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, Datasource DatasourceAlias = Union[Datasource[BoundCollection], DatasourceCustomizer] @@ -42,6 +43,12 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) @authenticate async def handle_native_query(self, request: Request) -> Response: await self.permission.can_chart(request) + assert request.body is not None + if "connectionName" not in request.body: + raise BusinessError("Missing 'connectionName' in parameter.") + if "query" not in request.body: + raise BusinessError("Missing 'query' in parameter.") + variables = await self.inject_and_get_context_variables_in_live_query_chart(request) return HttpResponseBuilder.build_success_response( await self.composite_datasource.execute_native_query( diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py index 598aa0eb0..37c6c0809 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py @@ -39,7 +39,7 @@ class RequestCollection(Request): def __init__( self, method: RequestMethod, - collection: Union[Collection, CollectionCustomizer], + collection: Collection, headers: Dict[str, str], client_ip: str, query: Dict[str, str], diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index 824fe1e31..999d102b2 100644 --- 
a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -59,6 +59,7 @@ async def can(self, caller: User, collection: Collection, action: str, allow_fet return is_allowed async def can_chart(self, request: RequestCollection) -> bool: + # TODO: verify after new permissions hash_request = request.body["type"] + ":" + _hash_chart(request.body) is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, False) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py index ed3ac7a4a..bb21dffcf 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py @@ -144,6 +144,8 @@ def build_method_not_allowed_response() -> Response: @staticmethod def _get_error_status(error: Exception): + if isinstance(error, BusinessError): + return 400 if isinstance(error, ValidationError): return 400 if isinstance(error, ForbiddenError): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index f1e6caf71..06c98a878 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -2,7 +2,7 @@ from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException +from forestadmin.datasource_toolkit.exceptions import BusinessError, DatasourceToolkitException from forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, DatasourceSchema @@ -75,7 +75,7 @@ async def execute_native_query(self, connection_name: str, native_query: str, pa if connection_name in datasource.get_native_query_connections(): return await datasource.execute_native_query(connection_name, native_query, parameters) - raise DatasourceToolkitException( - f"Cannot find {connection_name} in datasources. " + raise BusinessError( + f"Cannot find connection '{connection_name}' in datasources. 
" f"Existing connection names are: {','.join(self.get_native_query_connections())}" ) From 44919e599818ef88c55f475bf4b5e21c09c4f017 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 16:49:58 +0100 Subject: [PATCH 19/71] chore: fix existing tests --- .../agent_toolkit/utils/context.py | 4 +- .../tests/resources/collections/test_crud.py | 257 +++++++++++++++--- .../permissions/test_permission_service.py | 7 +- src/agent_toolkit/tests/test_agent_toolkit.py | 29 +- .../test_composite_datasource.py | 12 - src/flask_agent/tests/test_flask_agent.py | 1 + 6 files changed, 259 insertions(+), 51 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py index bb21dffcf..fbadca4ee 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py @@ -144,14 +144,14 @@ def build_method_not_allowed_response() -> Response: @staticmethod def _get_error_status(error: Exception): - if isinstance(error, BusinessError): - return 400 if isinstance(error, ValidationError): return 400 if isinstance(error, ForbiddenError): return 403 if isinstance(error, UnprocessableError): return 422 + if isinstance(error, BusinessError): + return 400 if isinstance(error, HTTPError): return error.code diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index c167d6a8e..3ed9d8c47 100644 --- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -25,6 +25,7 @@ from forestadmin.agent_toolkit.utils.context import Request, RequestMethod, User from forestadmin.agent_toolkit.utils.csv import CsvException from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException from forestadmin.datasource_toolkit.exceptions import ValidationError from forestadmin.datasource_toolkit.interfaces.fields import FieldType, Operator, PrimitiveType @@ -262,7 +263,13 @@ def test_dispatch(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.get = AsyncMock() crud_resource.list = AsyncMock() crud_resource.csv = AsyncMock() @@ -306,7 +313,13 @@ def test_dispatch_error(self, mock_request_collection: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.get = AsyncMock() with patch.object( @@ -357,7 +370,13 @@ def test_get( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( 
return_value={"data": {"type": "order", "attributes": mock_order}} ) @@ -430,7 +449,13 @@ def test_get_no_data( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.get(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "read") @@ -470,7 +495,13 @@ def test_get_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock(side_effect=JsonApiException) response = self.loop.run_until_complete(crud_resource.get(request)) @@ -517,7 +548,13 @@ def test_get_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.get(request)) @@ -545,7 +582,13 @@ def test_get_with_polymorphic_relation_should_add_projection_star(self, mocked_j headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": {"type": "tag", "attributes": {"id": 10, "taggable_id": 10, "taggable_type": "product"}}, @@ -585,7 +628,13 @@ def test_add( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) mocked_json_serializer_get.return_value.dump = Mock( return_value={"data": {"type": "order", "attributes": mock_order}} @@ -629,7 +678,13 @@ def test_add_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.extract_data = AsyncMock(return_value=(mock_order, [])) # JsonApiException @@ -724,7 +779,13 @@ def test_add_with_relation( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) 
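As a quick orientation for the many constructor changes in these tests: CrudResource now takes the composite datasource as its first argument, ahead of the datasource it already used, and the composite is simply the individual datasources registered one after the other. A hedged sketch of that wiring (the connection names are made up for illustration, and the sketch assumes connections are reported in registration order, as the tests in this series suggest):

    from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource
    from forestadmin.datasource_toolkit.datasources import Datasource

    composite = CompositeDatasource()
    composite.add_datasource(Datasource(live_query_connections=["db_one"]))
    composite.add_datasource(Datasource(live_query_connections=["db_two"]))

    # native queries are routed by connection name; asking for an unknown name
    # raises an error listing the existing connections
    assert composite.get_native_query_connections() == ["db_one", "db_two"]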
mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) mocked_json_serializer_get.return_value.dump = Mock( return_value={"data": {"type": "order", "attributes": mock_order}} @@ -802,7 +863,13 @@ def test_add_should_create_and_associate_polymorphic_many_to_one(self, mocked_js return_value={"taggable_id": 14, "taggable_type": "order", "tag": "aaaaa"} ) mocked_json_serializer_get.return_value.dump = Mock(return_value={}) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object( self.collection_tag, "create", new_callable=AsyncMock, return_value=[{}] @@ -838,7 +905,13 @@ def test_add_should_create_and_associate_polymorphic_one_to_one(self, mocked_jso mocked_json_serializer_get.return_value.load = Mock(return_value={"cost": 12.3, "important": True, "tags": 22}) mocked_json_serializer_get.return_value.dump = Mock(return_value={}) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object( self.collection_order, "create", @@ -888,7 +961,13 @@ def test_add_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "create", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.add(request)) @@ -924,7 +1003,13 @@ def test_list(self, mocked_json_serializer_get: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": [ @@ -966,7 +1051,13 @@ def test_list_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.list(request)) @@ -990,7 +1081,13 @@ def test_list_should_return_to_many_relations_as_link(self): def test_list_with_polymorphic_many_to_one_should_query_all_relation_record_columns( self, mocked_json_serializer_get: Mock ): - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ 
"data": [ @@ -1081,7 +1178,13 @@ def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": [ @@ -1115,7 +1218,13 @@ def test_list_errors(self, mocked_json_serializer_get: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # FilterException response = self.loop.run_until_complete(crud_resource.list(request)) @@ -1182,7 +1291,13 @@ def test_count(self): client_ip="127.0.0.1", ) self.collection_order.aggregate = AsyncMock(return_value=[{"value": 1000, "group": {}}]) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.count(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "browse") @@ -1212,7 +1327,13 @@ def test_deactivate_count(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order._schema["countable"] = False response = self.loop.run_until_complete(crud_resource.count(request)) @@ -1256,7 +1377,13 @@ def test_edit( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) @@ -1308,7 +1435,13 @@ def test_edit_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # CollectionResourceException with patch( @@ -1360,7 +1493,13 @@ def test_edit_should_not_throw_and_do_nothing_on_empty_record(self): mock_order = {"id": 10, "cost": 201} self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = 
self.loop.run_until_complete(crud_resource.update(request)) self.permission_service.can.reset_mock() @@ -1384,7 +1523,13 @@ def test_edit_should_not_update_pk_if_not_set_in_attributes(self): client_ip="127.0.0.1", ) mock_order = {"id": 10, "cost": 201} - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() @@ -1407,7 +1552,13 @@ def test_edit_should_update_pk_if_set_in_attributes(self): client_ip="127.0.0.1", ) mock_order = {"id": 11, "cost": 201} - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() @@ -1434,7 +1585,13 @@ def test_update_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.update(request)) @@ -1474,7 +1631,13 @@ def test_delete( client_ip="127.0.0.1", ) self.collection_order.delete = AsyncMock() - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.delete(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "delete") self.permission_service.can.reset_mock() @@ -1490,7 +1653,13 @@ def test_delete_error(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # CollectionResourceException with patch( @@ -1524,7 +1693,13 @@ def test_delete_list( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.delete = AsyncMock() response = self.loop.run_until_complete(crud_resource.delete_list(request)) @@ -1549,7 +1724,13 @@ def test_csv(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) 
self.collection_order.list = AsyncMock(return_value=mock_orders) response = self.loop.run_until_complete(crud_resource.csv(request)) @@ -1580,7 +1761,13 @@ def test_csv_errors(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # FilterException response = self.loop.run_until_complete(crud_resource.csv(request)) @@ -1654,7 +1841,13 @@ def test_csv_should_not_apply_pagination(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=mock_orders) response = self.loop.run_until_complete(crud_resource.csv(request)) diff --git a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py index cecd5da05..921fa77ce 100644 --- a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py +++ b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py @@ -383,7 +383,7 @@ def test_can_chart_should_retry_on_not_allowed_at_first_try(self): def mock_get_rendering_permissions(rendering_id, options): if mock["call_count"] == 0: - ret = {"stats": {}} + ret = {"stats": {}, "collections": {}, "team": {}} elif mock["call_count"] == 1: ret = http_patches["get_rendering_permissions"].kwargs["return_value"] mock["call_count"] = mock["call_count"] + 1 @@ -416,7 +416,10 @@ def test_can_chart_should_raise_forbidden_error_on_not_allowed_chart(self): ) with patch.object( - ForestHttpApi, "get_rendering_permissions", new_callable=AsyncMock, return_value={"stats": {}} + ForestHttpApi, + "get_rendering_permissions", + new_callable=AsyncMock, + return_value={"stats": {}, "collections": {}, "team": {}}, ): self.assertRaisesRegex( ForbiddenError, diff --git a/src/agent_toolkit/tests/test_agent_toolkit.py b/src/agent_toolkit/tests/test_agent_toolkit.py index 0301ca2ad..2aceffeb6 100644 --- a/src/agent_toolkit/tests/test_agent_toolkit.py +++ b/src/agent_toolkit/tests/test_agent_toolkit.py @@ -18,6 +18,7 @@ @patch("forestadmin.agent_toolkit.agent.CrudRelatedResource") @patch("forestadmin.agent_toolkit.agent.StatsResource") @patch("forestadmin.agent_toolkit.agent.ActionResource") +@patch("forestadmin.agent_toolkit.agent.NativeQueryResource") @patch("forestadmin.agent_toolkit.agent.SchemaEmitter.get_serialized_schema", new_callable=AsyncMock) @patch("forestadmin.agent_toolkit.agent.ForestHttpApi.send_schema", new_callable=AsyncMock) class TestAgent(TestCase): @@ -34,6 +35,7 @@ def test_create( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -81,6 +83,7 @@ def test_property_resources( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -97,10 +100,14 @@ def test_property_resources( 
mocked_authentication_resource.assert_called_once_with(agent._ip_white_list_service, agent.options) mocked_capabilities_resource.assert_called_once_with( - "fake_datasource", agent._ip_white_list_service, agent.options + agent.customizer.composite_datasource, agent._ip_white_list_service, agent.options ) mocked_crud_resource.assert_called_once_with( - "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options + agent.customizer.composite_datasource, + "fake_datasource", + agent._permission_service, + agent._ip_white_list_service, + agent.options, ) mocked_crud_related_resource.assert_called_once_with( "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options @@ -111,8 +118,16 @@ def test_property_resources( mocked_action_resource.assert_called_once_with( "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options ) + mocked_native_query_resource.assert_called_once_with( + agent.customizer.composite_datasource, + "fake_datasource", + agent._permission_service, + agent._ip_white_list_service, + agent.options, + ) - assert len(resources) == 8 + assert len(resources) == 9 + assert "native_query" in resources assert "capabilities" in resources assert "authentication" in resources assert "crud" in resources @@ -126,6 +141,7 @@ def test_property_meta( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -147,6 +163,7 @@ def test_add_datasource( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -167,6 +184,7 @@ def test_add_chart( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -190,6 +208,7 @@ def test_remove_collections( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -210,6 +229,7 @@ def test_customize_datasource( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -234,6 +254,7 @@ def test_start( mocked_create_json_api_schema, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -279,6 +300,7 @@ def test_start_dont_crash_if_schema_generation_or_sending_fail( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -315,6 +337,7 @@ def test_use_should_add_a_plugin( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, diff --git a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py index 53d9229a9..9691b8704 100644 --- 
a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -117,18 +117,6 @@ def test_get_collection_should_list_collection_names_if_collection_not_found(sel "Unknown", ) - def test_should_log_if_multiple_datasources_have_same_name(self): - ds1 = Datasource(name="test") - ds2 = Datasource(name="test") - self.composite_ds.add_datasource(ds1) - with patch("forestadmin.datasource_toolkit.datasource_customizer.datasource_composite.ForestLogger.log") as log: - self.composite_ds.add_datasource(ds2) - log.assert_any_call( - "warning", - "A datasource with the name 'test' already exists. You can use the optional parameter 'name' when " - "creating a datasource.", - ) - class TestCompositeDatasourceCharts(BaseTestCompositeDatasource): def setUp(self) -> None: diff --git a/src/flask_agent/tests/test_flask_agent.py b/src/flask_agent/tests/test_flask_agent.py index c119f82b7..ae8cd5048 100644 --- a/src/flask_agent/tests/test_flask_agent.py +++ b/src/flask_agent/tests/test_flask_agent.py @@ -69,6 +69,7 @@ def test_build_blueprint(self, mocked_blueprint, mock_base_agent_resources, mock calls = [ call("", methods=["GET"]), call("/_internal/capabilities", methods=["POST"]), + call("/_internal/native_query", methods=["POST"]), call("/authentication/callback", methods=["GET"]), call("/_actions////hooks/load", methods=["POST"]), call("/_actions////hooks/change", methods=["POST"]), From b90ac03a19651ff636c164ad37b88fb9c755b5d8 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 16:53:38 +0100 Subject: [PATCH 20/71] chore: fix linting --- .../forestadmin/agent_toolkit/resources/collections/stats.py | 1 - .../agent_toolkit/utils/context_variable_injector.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py index 5b78966f4..1ac2fa581 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py @@ -1,4 +1,3 @@ -import json from datetime import date, datetime from typing import Any, Dict, List, Literal, Optional, Union, cast from uuid import uuid1 diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py index 6a0c6033f..2bcc3c076 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py @@ -43,7 +43,7 @@ def format_query_and_get_vars(value: str, context_variable: ContextVariables) -> def _match(match): variables[match.group(1).replace(".", "__")] = context_variable.get_value(match.group(1)) - return f"%({match.group(1).replace(".", "__")})s" + return f"%({match.group(1).replace('.', '__')})s" ret = re.sub(r"{{([^}]+)}}", _match, value.replace("%", "\\%")) From 3f32b27a6ecde6cf184eb61f027dd4f262ddafec Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 16:55:22 +0100 Subject: [PATCH 21/71] chore: fix linting --- src/agent_toolkit/tests/resources/collections/test_crud.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index 3ed9d8c47..3cb3dcf23 100644 
--- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -25,7 +25,6 @@ from forestadmin.agent_toolkit.utils.context import Request, RequestMethod, User from forestadmin.agent_toolkit.utils.csv import CsvException from forestadmin.datasource_toolkit.collections import Collection -from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException from forestadmin.datasource_toolkit.exceptions import ValidationError from forestadmin.datasource_toolkit.interfaces.fields import FieldType, Operator, PrimitiveType From f51ee5eb5c31e00971cd733a2b317491bae9e952 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 27 Nov 2024 16:59:49 +0100 Subject: [PATCH 22/71] chore: fix test --- src/agent_toolkit/tests/resources/collections/test_crud.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index 3cb3dcf23..9abbd8672 100644 --- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -25,6 +25,7 @@ from forestadmin.agent_toolkit.utils.context import Request, RequestMethod, User from forestadmin.agent_toolkit.utils.csv import CsvException from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException from forestadmin.datasource_toolkit.exceptions import ValidationError from forestadmin.datasource_toolkit.interfaces.fields import FieldType, Operator, PrimitiveType @@ -227,6 +228,7 @@ def setUpClass(cls) -> None: ) # cls.datasource = Mock(Datasource) cls.datasource = Datasource() + cls.datasource_composite = CompositeDatasource() cls.datasource.get_collection = lambda x: cls.datasource._collections[x] cls._create_collections() cls.datasource._collections = { @@ -236,6 +238,7 @@ def setUpClass(cls) -> None: "product": cls.collection_product, "tag": cls.collection_tag, } + cls.datasource_composite.add_datasource(cls.datasource) for collection in cls.datasource.collections: create_json_api_schema(collection) From c726991c7690d461ca80e099604397aea6f9ef78 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 28 Nov 2024 16:51:59 +0100 Subject: [PATCH 23/71] chore: add tests for capability route --- .../resources/test_capabilities_resource.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/agent_toolkit/tests/resources/test_capabilities_resource.py b/src/agent_toolkit/tests/resources/test_capabilities_resource.py index de94e0312..0ba6a52eb 100644 --- a/src/agent_toolkit/tests/resources/test_capabilities_resource.py +++ b/src/agent_toolkit/tests/resources/test_capabilities_resource.py @@ -71,7 +71,7 @@ def setUpClass(cls) -> None: is_production=False, ) # type:ignore - cls.datasource = Datasource() + cls.datasource = Datasource(live_query_connections=["test1", "test2"]) Collection.__abstractmethods__ = set() # type:ignore # to instantiate abstract class cls.book_collection = Collection("Book", cls.datasource) # type:ignore cls.book_collection.add_fields( @@ -127,7 +127,7 @@ def test_dispatch_should_not_dispatch_to_capabilities_when_no_post_request(self) 
self.assertEqual(response.status, 405) - def test_dispatch_should_dispatch_POST_to_capabilities(self): + def test_dispatch_should_return_correct_collection_and_fields_capabilities(self): request = Request( method=RequestMethod.POST, query={}, @@ -156,3 +156,18 @@ def test_dispatch_should_dispatch_POST_to_capabilities(self): } ], ) + + def test_should_return_correct_datasource_connections_capabilities(self): + request = Request( + method=RequestMethod.POST, + query={}, + body={"collectionNames": ["Book"]}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + response: Response = self.loop.run_until_complete(self.capabilities_resource.dispatch(request, "capabilities")) + self.assertEqual(response.status, 200) + response_content = json.loads(response.body) + + self.assertEqual(response_content["nativeQueryConnections"], [{"name": "test1"}, {"name": "test2"}]) From f9ce5a5edf50cad0b15c942a1e39b0cb04ab48d5 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 28 Nov 2024 16:54:35 +0100 Subject: [PATCH 24/71] chore: add tests for django datasource --- .../datasource_django/datasource.py | 10 +- .../tests/test_django_datasource.py | 114 +++++++++++++++++- 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index ad72bfc08..be68ac1a5 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -61,17 +61,15 @@ def _create_collections(self): async def execute_native_query( self, connection_name: str, native_query: str, parameters: Dict[str, str] ) -> List[RecordsDataAlias]: - if ( - self._django_live_query_connections is None - or connection_name not in self._django_live_query_connections.keys() - ): + if connection_name not in self._django_live_query_connections.keys(): # TODO: verify - # This one should never occur + # This one should never occur while datasource composite works fine raise DjangoDatasourceException( f"Native query connection '{connection_name}' is not known by DjangoDatasource." 
) if self._django_live_query_connections[connection_name] not in connections: + # This one should never occur # TODO: verify raise DjangoDatasourceException( f"Connection to database '{self._django_live_query_connections[connection_name]}' for alias " @@ -82,7 +80,7 @@ async def execute_native_query( def _execute_native_query(): cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore try: - # replace '\s' by '%%' + # replace '\%' by '%%' # %(var)s is already the correct syntax rows = cursor.execute(native_query.replace("\\%", "%%"), parameters) diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index dc64674c9..d2cbdb625 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -1,8 +1,10 @@ -from unittest import TestCase +import asyncio from unittest.mock import Mock, call, patch +from django.test import SimpleTestCase, TestCase from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.datasource import DjangoDatasource +from forestadmin.datasource_django.exception import DjangoDatasourceException mock_collection1 = Mock(DjangoCollection) mock_collection1.name = "first" @@ -66,3 +68,113 @@ def test_django_datasource_should_ignore_proxy_models(self): """ignoring proxy models means no collections added twice or more""" datasource = DjangoDatasource() self.assertEqual(len([c.name for c in datasource.collections if c.name == "auth_user"]), 1) + + +class TestDjangoDatasourceConnectionQueryCreation(SimpleTestCase): + def test_should_not_create_native_query_connection_if_no_params(self): + ds = DjangoDatasource() + self.assertEqual(ds.get_native_query_connections(), []) + + def test_should_create_native_query_connection_to_default_if_string_is_set(self): + ds = DjangoDatasource(live_query_connection="django") + self.assertEqual(ds.get_native_query_connections(), ["django"]) + self.assertEqual(ds._django_live_query_connections["django"], "default") + + def test_should_log_when_creating_connection_with_string_param_and_multiple_databases_are_set_up(self): + with patch("forestadmin.datasource_django.datasource.ForestLogger.log") as log_fn: + DjangoDatasource(live_query_connection="django") + # TODO: adapt error message + log_fn.assert_any_call( + "info", + "You enabled live query as django for django 'default' database. " + "To use it over multiple databases, read the related documentation here: http://link.", + ) + + def test_should_raise_if_connection_query_target_non_existent_database(self): + self.assertRaisesRegex( + DjangoDatasourceException, + r"Connection to database 'plouf' for alias 'plif' is not found in django databases\. 
" + r"Existing connections are default,other", + DjangoDatasource, + live_query_connection={"django": "default", "plif": "plouf"}, + ) + + +class TestDjangoDatasourceNativeQueryExecution(TestCase): + fixtures = ["person.json", "book.json", "rating.json", "tag.json"] + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.loop = asyncio.new_event_loop() + cls.dj_datasource = DjangoDatasource(live_query_connection={"django": "default", "other": "other"}) + + def test_should_raise_if_connection_is_not_known_by_datasource(self): + self.assertRaisesRegex( + DjangoDatasourceException, + r"Native query connection 'foo' is not known by DjangoDatasource.", + self.loop.run_until_complete, + self.dj_datasource.execute_native_query("foo", "select * from blabla", {}), + ) + + async def test_should_correctly_execute_query(self): + result = await self.dj_datasource.execute_native_query( + "django", "select * from test_app_person order by person_pk;", {} + ) + self.assertEqual( + result, + [ + { + "person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + { + "person_pk": 2, + "first_name": "J.K.", + "last_name": "Rowling", + "birth_date": "1965-07-31", + "auth_user_id": None, + }, + ], + ) + + async def test_should_correctly_execute_query_with_formatting(self): + result = await self.dj_datasource.execute_native_query( + "django", + "select * from test_app_person where first_name = %(first_name)s order by person_pk;", + {"first_name": "Isaac"}, + ) + self.assertEqual( + result, + [ + { + "person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + ], + ) + + async def test_should_correctly_execute_query_with_percent(self): + result = await self.dj_datasource.execute_native_query( + "django", + "select * from test_app_person where first_name like 'Is\\%' order by person_pk;", + {}, + ) + self.assertEqual( + result, + [ + { + "person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + ], + ) From 51f710a70338edc87e7ffeacb995737282a55846 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 28 Nov 2024 17:26:35 +0100 Subject: [PATCH 25/71] chore: add last test on django datasource --- src/datasource_django/tests/test_django_datasource.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index d2cbdb625..780cf1851 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -178,3 +178,11 @@ async def test_should_correctly_execute_query_with_percent(self): }, ], ) + + def test_should_correctly_raise_exception_during_sql_error(self): + self.assertRaisesRegex( + DjangoDatasourceException, + r"no such table: blabla", + self.loop.run_until_complete, + self.dj_datasource.execute_native_query("django", "select * from blabla", {}), + ) From a89f2d6871ccbc9ad8afeafb08c018792cfbd097 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 29 Nov 2024 10:09:18 +0100 Subject: [PATCH 26/71] chore: remove useless verification --- .../forestadmin/datasource_django/datasource.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index be68ac1a5..a21772085 100644 --- 
a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -62,21 +62,11 @@ async def execute_native_query( self, connection_name: str, native_query: str, parameters: Dict[str, str] ) -> List[RecordsDataAlias]: if connection_name not in self._django_live_query_connections.keys(): - # TODO: verify # This one should never occur while datasource composite works fine raise DjangoDatasourceException( f"Native query connection '{connection_name}' is not known by DjangoDatasource." ) - if self._django_live_query_connections[connection_name] not in connections: - # This one should never occur - # TODO: verify - raise DjangoDatasourceException( - f"Connection to database '{self._django_live_query_connections[connection_name]}' for alias " - f"'{connection_name}' is not found in django connections. " - f"Existing connections are {','.join([*connections])}" - ) - def _execute_native_query(): cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore try: From 2a1cfa301dfaba49328539643cd53804e7fbbdc6 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 29 Nov 2024 10:10:06 +0100 Subject: [PATCH 27/71] chore: update example project --- .../django/django_demo/app/forest_admin.py | 14 ++++++++++++-- .../django/django_demo/django_demo/settings.py | 8 ++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/_example/django/django_demo/app/forest_admin.py b/src/_example/django/django_demo/app/forest_admin.py index a3cef98cf..ee01871cc 100644 --- a/src/_example/django/django_demo/app/forest_admin.py +++ b/src/_example/django/django_demo/app/forest_admin.py @@ -50,12 +50,22 @@ def customize_forest(agent: DjangoAgent): # customize_forest_logging() - agent.add_datasource(DjangoDatasource(support_polymorphic_relations=True, live_query_connection="django")) + + agent.add_datasource( + DjangoDatasource( + support_polymorphic_relations=True, live_query_connection={"django": "default", "dj_sqlachemy": "other"} + ) + ) agent.add_datasource(TypicodeDatasource()) - agent.add_datasource(SqlAlchemyDatasource(Base, DB_URI, live_query_connection="sqlalchemy")) + agent.add_datasource( + SqlAlchemyDatasource(Base, DB_URI, live_query_connection="sqlalchemy"), + ) agent.customize_collection("address").add_segment("France", segment_addr_fr("address")) agent.customize_collection("app_address").add_segment("France", segment_addr_fr("app_address")) + agent.customize_collection("app_customer_blocked_customer").rename_field("from_customer", "from").rename_field( + "to_customer", "to" + ) # # ## ADDRESS agent.customize_collection("app_address").add_segment( diff --git a/src/_example/django/django_demo/django_demo/settings.py b/src/_example/django/django_demo/django_demo/settings.py index 6018be425..2efd1ea59 100644 --- a/src/_example/django/django_demo/django_demo/settings.py +++ b/src/_example/django/django_demo/django_demo/settings.py @@ -123,10 +123,10 @@ "NAME": os.path.join(BASE_DIR, "db.sqlite3"), "ATOMIC_REQUESTS": True, }, - # "other": { - # "ENGINE": "django.db.backends.sqlite3", - # "NAME": os.path.join(BASE_DIR, "db_flask_example.sqlite"), - # }, + "other": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db_sqlalchemy.sql"), + }, } DATABASE_ROUTERS = ["django_demo.db_router.DBRouter"] From eda63fcd8a4a19898e849781aa0628d8f1f723bf Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 29 Nov 2024 16:15:59 +0100 Subject: [PATCH 28/71] chore: add 
tests on sqlalchemy datasource --- .../django/django_demo/app/forest_admin.py | 3 - .../tests/fixture/models.py | 16 +++- .../tests/test_sqlalchemy_datasource.py | 95 ++++++++++++++++++- 3 files changed, 108 insertions(+), 6 deletions(-) diff --git a/src/_example/django/django_demo/app/forest_admin.py b/src/_example/django/django_demo/app/forest_admin.py index ee01871cc..76dc72bf9 100644 --- a/src/_example/django/django_demo/app/forest_admin.py +++ b/src/_example/django/django_demo/app/forest_admin.py @@ -63,9 +63,6 @@ def customize_forest(agent: DjangoAgent): agent.customize_collection("address").add_segment("France", segment_addr_fr("address")) agent.customize_collection("app_address").add_segment("France", segment_addr_fr("app_address")) - agent.customize_collection("app_customer_blocked_customer").rename_field("from_customer", "from").rename_field( - "to_customer", "to" - ) # # ## ADDRESS agent.customize_collection("app_address").add_segment( diff --git a/src/datasource_sqlalchemy/tests/fixture/models.py b/src/datasource_sqlalchemy/tests/fixture/models.py index ffb6e102a..43b437719 100644 --- a/src/datasource_sqlalchemy/tests/fixture/models.py +++ b/src/datasource_sqlalchemy/tests/fixture/models.py @@ -3,9 +3,11 @@ import os from datetime import date, datetime +import sqlalchemy # type: ignore from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, create_engine, func, types -from sqlalchemy.orm import Session, declarative_base, relationship, validates +from sqlalchemy.orm import Session, relationship, validates +use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2" test_db_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", "test_db.sql")) engine = create_engine(f"sqlite:///{test_db_path}", echo=False) fixtures_dir = os.path.abspath(os.path.join(__file__, "..")) @@ -42,7 +44,17 @@ def __import__(cls, d): return cls(**params) -Base = declarative_base(cls=_Base) +if use_sqlalchemy_2: + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase, _Base): + pass + +else: + from sqlalchemy.orm import declarative_base + + Base = declarative_base(cls=_Base) + Base.metadata.bind = engine diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py index ec5662028..b853de5ba 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py @@ -1,8 +1,9 @@ +import asyncio +import os from unittest import TestCase from unittest.mock import Mock, patch from flask import Flask -from flask_sqlalchemy import SQLAlchemy from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException from sqlalchemy.orm import DeclarativeMeta @@ -50,6 +51,8 @@ def test_create_datasources_not_search_engine_when_db_uri_is_supply(self): @patch("forestadmin.datasource_sqlalchemy.datasource.sessionmaker") def test_create_datasource_with_flask_sqlalchemy_integration_should_find_engine(self, mocked_sessionmaker): + from flask_sqlalchemy import SQLAlchemy + app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///memory" db = SQLAlchemy() @@ -93,3 +96,93 @@ def test_with_models(self): assert len(datasource._collections) == 4 assert datasource.get_collection("address").datasource == datasource + + +class TestSQLAlchemyDatasourceConnectionQueryCreation(TestCase): + def 
test_should_not_create_native_query_connection_if_no_params(self): + ds = SqlAlchemyDatasource(models.Base) + self.assertEqual(ds.get_native_query_connections(), []) + + def test_should_create_native_query_connection_to_default_if_string_is_set(self): + ds = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy") + self.assertEqual(ds.get_native_query_connections(), ["sqlalchemy"]) + + +class TestSQLAlchemyDatasourceNativeQueryExecution(TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.loop = asyncio.new_event_loop() + if os.path.exists(models.test_db_path): + os.remove(models.test_db_path) + models.create_test_database() + models.load_fixtures() + cls.sql_alchemy_datasource = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy") + + def test_should_raise_if_connection_is_not_known_by_datasource(self): + self.assertRaisesRegex( + SqlAlchemyDatasourceException, + r"The native query connection 'foo' doesn't belongs to this datasource.", + self.loop.run_until_complete, + self.sql_alchemy_datasource.execute_native_query("foo", "select * from blabla", {}), + ) + + def test_should_correctly_execute_query(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", "select * from customer where id <= 2 order by id;", {} + ) + ) + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + {"id": 2, "first_name": "Thomas", "last_name": "Odom", "age": 92}, + ], + ) + + def test_should_correctly_execute_query_with_formatting(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", + """select * + from customer + where first_name = %(first_name)s + and last_name = %(last_name)s + order by id""", + {"first_name": "David", "last_name": "Myers"}, + ) + ) + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + ], + ) + + def test_should_correctly_execute_query_with_percent(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", + """select * + from customer + where first_name like 'Dav\\%' + order by id""", + {}, + ) + ) + + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + ], + ) + + def test_should_correctly_raise_exception_during_sql_error(self): + self.assertRaisesRegex( + SqlAlchemyDatasourceException, + r"no such table: blabla", + self.loop.run_until_complete, + self.sql_alchemy_datasource.execute_native_query("sqlalchemy", "select * from blabla", {}), + ) From ad6510af14cb368d7bcb0d456805d14208b39a07 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Fri, 29 Nov 2024 17:37:26 +0100 Subject: [PATCH 29/71] chore: fix concurency issue with test database --- .../tests/fixture/models.py | 18 +++++++++++------- .../tests/test_sqlalchemy_collections.py | 15 ++++++++------- .../tests/test_sqlalchemy_datasource.py | 14 ++++++++++---- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/datasource_sqlalchemy/tests/fixture/models.py b/src/datasource_sqlalchemy/tests/fixture/models.py index 43b437719..9f5bdb56a 100644 --- a/src/datasource_sqlalchemy/tests/fixture/models.py +++ b/src/datasource_sqlalchemy/tests/fixture/models.py @@ -8,8 +8,6 @@ from sqlalchemy.orm import Session, relationship, validates use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2" -test_db_path = 
os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", "test_db.sql")) -engine = create_engine(f"sqlite:///{test_db_path}", echo=False) fixtures_dir = os.path.abspath(os.path.join(__file__, "..")) # to import/export json as fixtures @@ -55,7 +53,13 @@ class Base(DeclarativeBase, _Base): Base = declarative_base(cls=_Base) -Base.metadata.bind = engine + +def get_models_base(db_file_name): + test_db_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", f"{db_file_name}.sql")) + engine = create_engine(f"sqlite:///{test_db_path}", echo=False) + Base.metadata.bind = engine + Base.metadata.file_path = test_db_path + return Base class ORDER_STATUS(str, enum.Enum): @@ -116,7 +120,7 @@ class CustomersAddresses(Base): address_id = Column(Integer, ForeignKey("address.id"), primary_key=True) -def load_fixtures(): +def load_fixtures(base): with open(os.path.join(fixtures_dir, "addresses.json"), "r") as fin: data = json.load(fin) addresses = [Address.__import__(d) for d in data] @@ -133,7 +137,7 @@ def load_fixtures(): data = json.load(fin) orders = [Order.__import__(d) for d in data] - with Session(Base.metadata.bind) as session: + with Session(base.metadata.bind) as session: session.bulk_save_objects(addresses) session.bulk_save_objects(customers) session.bulk_save_objects(customers_addresses) @@ -141,5 +145,5 @@ def load_fixtures(): session.commit() -def create_test_database(): - Base.metadata.create_all(engine) +def create_test_database(base): + base.metadata.create_all(base.metadata.bind) diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py index b0e27d722..366515f8d 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py @@ -143,10 +143,11 @@ class TestSqlAlchemyCollectionWithModels(TestCase): @classmethod def setUpClass(cls): cls.loop = asyncio.new_event_loop() - if os.path.exists(models.test_db_path): - os.remove(models.test_db_path) - models.create_test_database() - models.load_fixtures() + cls.sql_alchemy_base = models.get_models_base("test_collection_operations") + if os.path.exists(cls.sql_alchemy_base.metadata.file_path): + os.remove(cls.sql_alchemy_base.metadata.file_path) + models.create_test_database(cls.sql_alchemy_base) + models.load_fixtures(cls.sql_alchemy_base) cls.datasource = SqlAlchemyDatasource(models.Base) cls.mocked_caller = User( rendering_id=1, @@ -162,7 +163,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - os.remove(models.test_db_path) + os.remove(cls.sql_alchemy_base.metadata.file_path) cls.loop.close() def test_get_columns(self): @@ -306,7 +307,7 @@ def test_aggregate(self): def test_get_native_driver_should_return_connection(self): with self.datasource.get_collection("order").get_native_driver() as connection: self.assertIsInstance(connection, Session) - self.assertEqual(str(connection.bind.url), f"sqlite:///{models.test_db_path}") + self.assertEqual(str(connection.bind.url), f"sqlite:///{self.sql_alchemy_base.metadata.file_path}") rows = connection.execute(text('select id,amount from "order" where id = 3')).all() self.assertEqual(rows, [(3, 5285)]) @@ -314,7 +315,7 @@ def test_get_native_driver_should_return_connection(self): def test_get_native_driver_should_work_without_declaring_request_as_text(self): with self.datasource.get_collection("order").get_native_driver() as connection: self.assertIsInstance(connection, Session) - 
self.assertEqual(str(connection.bind.url), f"sqlite:///{models.test_db_path}") + self.assertEqual(str(connection.bind.url), f"sqlite:///{self.sql_alchemy_base.metadata.file_path}") rows = connection.execute('select id,amount from "order" where id = 3').all() self.assertEqual(rows, [(3, 5285)]) diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py index b853de5ba..362602c41 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py @@ -113,12 +113,18 @@ class TestSQLAlchemyDatasourceNativeQueryExecution(TestCase): def setUpClass(cls) -> None: super().setUpClass() cls.loop = asyncio.new_event_loop() - if os.path.exists(models.test_db_path): - os.remove(models.test_db_path) - models.create_test_database() - models.load_fixtures() + cls.sql_alchemy_base = models.get_models_base("test_datasource_native_query") + if os.path.exists(cls.sql_alchemy_base.metadata.file_path): + os.remove(cls.sql_alchemy_base.metadata.file_path) + models.create_test_database(cls.sql_alchemy_base) + models.load_fixtures(cls.sql_alchemy_base) cls.sql_alchemy_datasource = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy") + @classmethod + def tearDownClass(cls): + os.remove(cls.sql_alchemy_base.metadata.file_path) + cls.loop.close() + def test_should_raise_if_connection_is_not_known_by_datasource(self): self.assertRaisesRegex( SqlAlchemyDatasourceException, From 84e6d808a01871120d5949f9103951467c2ae6ce Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 2 Dec 2024 17:01:11 +0100 Subject: [PATCH 30/71] chore: fixes after front and backend pluging --- .../agent_toolkit/resources/collections/crud.py | 9 --------- .../resources/collections/native_query.py | 15 ++++++++++++--- .../resources/context_variable_injector_mixin.py | 4 ++-- .../services/permissions/permission_service.py | 7 ++----- .../services/permissions/permissions_functions.py | 8 ++++---- 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 0bea1ff53..14c1e8098 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -453,15 +453,6 @@ def _serialize_records_with_relationships( async def _handle_live_query_segment( self, request: RequestCollection, condition_tree: Optional[ConditionTree] ) -> Optional[ConditionTree]: - # TODO: remove connectionName mock - if request.collection.name.startswith("app_"): - request.query["connectionName"] = "django" - elif request.collection.name.startswith("sqlalchemy_"): - request.query["connectionName"] = "dj_sqlachemy" - else: - request.query["connectionName"] = "sqlalchemy" - # TODO: remove connectionName mock - if request.query.get("segmentQuery") is not None: if "connectionName" not in request.query: raise BusinessError("Missing 'connectionName' parameter.") diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 4859d1967..16cc328e8 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -1,4 +1,5 @@ from 
typing import Literal, Union +from uuid import uuid4 from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.options import Options @@ -51,7 +52,15 @@ async def handle_native_query(self, request: Request) -> Response: variables = await self.inject_and_get_context_variables_in_live_query_chart(request) return HttpResponseBuilder.build_success_response( - await self.composite_datasource.execute_native_query( - request.body["connectionName"], request.body["query"], variables - ) + { + "data": { + "id": str(uuid4()), + "type": "stats", + "attributes": { + "value": await self.composite_datasource.execute_native_query( + request.body["connectionName"], request.body["query"], variables + ), + }, + } + } ) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py index af1c93953..fd06ca205 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -66,7 +66,7 @@ async def inject_and_get_context_variables_in_live_query_chart(self, request: "R request.body["query"], context_variables ) - request.query["query"], vars = ContextVariableInjector.format_query_and_get_vars( - request.query["query"], context_variables + request.body["query"], vars = ContextVariableInjector.format_query_and_get_vars( + request.body["query"], context_variables ) return vars diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index 999d102b2..b69b1368a 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -80,12 +80,9 @@ async def can_chart(self, request: RequestCollection) -> bool: async def can_live_query_segment(self, request: RequestCollection) -> bool: live_query = request.query["segmentQuery"] - # connection_name = request.query["connectionName"] + connection_name = request.query["connectionName"] hash_live_query = _dict_hash( - { - "query": live_query, - # "connection_name": connection_name # TODO: review when connectionName in permissions - } + {"query": live_query, "connection_name": connection_name} # TODO: review when connectionName in permissions ) is_allowed = hash_live_query in (await self._get_segment_queries(request.user.rendering_id, False)).get( request.collection.name diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py index 28de72167..b3b39dc33 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py @@ -16,12 +16,12 @@ def _decode_segment_query_permissions(raw_permission: Dict[Any, Any]): segment_queries = {} for collection_name, value in raw_permission.items(): segment_queries[collection_name] = [] - for segment_query in value.get("segments", []): + for segment_query in value.get("liveQuerySegments", []): segment_queries[collection_name].append( _dict_hash( { - "query": segment_query, - # "connection_name": connection_name # TODO: review 
when connectionName in permissions + "query": segment_query["query"], + "connection_name": segment_query["connectionName"], } ) ) @@ -74,7 +74,7 @@ def _dict_hash(data: Dict[Any, Any]) -> str: def _hash_chart(chart: Dict[Any, Any]) -> str: known_chart_keys = [ - # "connectionName", # TODO: to enable with next backend version of permissions + "connectionName", "type", "apiRoute", "smartRoute", From 8ce15fa8827ad86af735ee61450811ffe10cb4d1 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 3 Dec 2024 11:02:34 +0100 Subject: [PATCH 31/71] chore: handle every type of chart --- .../resources/collections/native_query.py | 81 +++++++++++++++++-- .../permissions/permission_service.py | 5 +- .../datasource_toolkit/interfaces/chart.py | 2 - 3 files changed, 77 insertions(+), 11 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 16cc328e8..21a8861d5 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Any, Dict, List, Literal, Optional, Union from uuid import uuid4 from forestadmin.agent_toolkit.forest_logger import ForestLogger @@ -11,7 +11,15 @@ from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer -from forestadmin.datasource_toolkit.exceptions import BusinessError +from forestadmin.datasource_toolkit.exceptions import BusinessError, UnprocessableError, ValidationError +from forestadmin.datasource_toolkit.interfaces.chart import ( + Chart, + DistributionChart, + LeaderboardChart, + ObjectiveChart, + TimeBasedChart, + ValueChart, +) from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, Datasource DatasourceAlias = Union[Datasource[BoundCollection], DatasourceCustomizer] @@ -49,18 +57,81 @@ async def handle_native_query(self, request: Request) -> Response: raise BusinessError("Missing 'connectionName' in parameter.") if "query" not in request.body: raise BusinessError("Missing 'query' in parameter.") + if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: + raise ValidationError(f"Unknown chart type '{request.body.get("type")}'.") variables = await self.inject_and_get_context_variables_in_live_query_chart(request) + native_query_results = await self.composite_datasource.execute_native_query( + request.body["connectionName"], request.body["query"], variables + ) + + chart_result: Chart + if request.body["type"] == "Line": + chart_result = self._handle_line_chart(native_query_results) + elif request.body["type"] == "Objective": + chart_result = self._handle_objective_chart(native_query_results) + elif request.body["type"] == "Leaderboard": + chart_result = self._handle_leaderboard_chart(native_query_results) + elif request.body["type"] == "Pie": + chart_result = self._handle_pie_chart(native_query_results) + elif request.body["type"] == "Value": + chart_result = self._handle_value_chart(native_query_results) + return HttpResponseBuilder.build_success_response( { "data": { "id": str(uuid4()), "type": "stats", "attributes": { 
- "value": await self.composite_datasource.execute_native_query( - request.body["connectionName"], request.body["query"], variables - ), + "value": chart_result, # type:ignore }, } } ) + + def _handle_line_chart(self, native_query_results: List[Dict[str, Any]]) -> TimeBasedChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Line' chart must return 'key' and 'value' fields.") + + return [{"label": res["key"], "values": {"value": res["value"]}} for res in native_query_results] + + def _handle_objective_chart(self, native_query_results: List[Dict[str, Any]]) -> ObjectiveChart: + if len(native_query_results) == 1: + if "value" not in native_query_results[0] or "objective" not in native_query_results[0]: + raise UnprocessableError( + "Native query for 'Objective' chart must return 'value' and 'objective' fields." + ) + else: + raise UnprocessableError("Native query for 'Objective' chart must return only one row.") + + return { + "value": native_query_results[0]["value"], + "objective": native_query_results[0]["objective"], + } + + def _handle_leaderboard_chart(self, native_query_results: List[Dict[str, Any]]) -> LeaderboardChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Leaderboard' chart must return 'key' and 'value' fields.") + + return [{"key": res["key"], "value": res["value"]} for res in native_query_results] + + def _handle_pie_chart(self, native_query_results: List[Dict[str, Any]]) -> DistributionChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Pie' chart must return 'key' and 'value' fields.") + + return [{"key": res["key"], "value": res["value"]} for res in native_query_results] + + def _handle_value_chart(self, native_query_results: List[Dict[str, Any]]) -> ValueChart: + if len(native_query_results) == 1: + if "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Value' chart must return 'value' field.") + else: + raise UnprocessableError("Native query for 'Value' chart must return only one row.") + + ret = {"countCurrent": native_query_results[0]["value"]} + if "previous" in native_query_results[0]: + ret["countPrevious"] = native_query_results[0]["previous"] + return ret # type:ignore diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index b69b1368a..e941804b4 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -59,7 +59,6 @@ async def can(self, caller: User, collection: Collection, action: str, allow_fet return is_allowed async def can_chart(self, request: RequestCollection) -> bool: - # TODO: verify after new permissions hash_request = request.body["type"] + ":" + _hash_chart(request.body) is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, False) @@ -81,9 +80,7 @@ async def can_chart(self, request: RequestCollection) -> bool: async def can_live_query_segment(self, request: RequestCollection) -> bool: live_query = request.query["segmentQuery"] connection_name = 
request.query["connectionName"] - hash_live_query = _dict_hash( - {"query": live_query, "connection_name": connection_name} # TODO: review when connectionName in permissions - ) + hash_live_query = _dict_hash({"query": live_query, "connection_name": connection_name}) is_allowed = hash_live_query in (await self._get_segment_queries(request.user.rendering_id, False)).get( request.collection.name ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py index f4170d17a..5d5f1b400 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py @@ -31,5 +31,3 @@ LeaderboardChart, SmartChart, ] - -a: TimeBasedChart = [{"label": "&", "values": {"value"}}] From ce1f66f02e4cae7a707f3f60612ca98c2d8f5c51 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 3 Dec 2024 15:16:39 +0100 Subject: [PATCH 32/71] chore: improve error handling --- .../agent_toolkit/resources/collections/crud.py | 12 ++++-------- .../resources/collections/native_query.py | 4 ++-- .../services/permissions/permissions_functions.py | 1 - .../forestadmin/datasource_django/datasource.py | 8 +++----- .../tests/test_django_datasource.py | 5 +++-- .../forestadmin/datasource_sqlalchemy/datasource.py | 7 +++---- .../tests/test_sqlalchemy_datasource.py | 5 +++-- .../context/relaxed_wrappers/collection.py | 3 +++ .../datasource_customizer/datasource_composite.py | 4 ++-- .../forestadmin/datasource_toolkit/exceptions.py | 12 ++++++++++++ 10 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 14c1e8098..99d7d6cd4 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -35,7 +35,7 @@ from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException -from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError +from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError, NativeQueryException from forestadmin.datasource_toolkit.interfaces.fields import ( ManyToOne, OneToOne, @@ -454,8 +454,8 @@ async def _handle_live_query_segment( self, request: RequestCollection, condition_tree: Optional[ConditionTree] ) -> Optional[ConditionTree]: if request.query.get("segmentQuery") is not None: - if "connectionName" not in request.query: - raise BusinessError("Missing 'connectionName' parameter.") + if "connectionName" not in request.query or request.query["connectionName"] == "": + raise NativeQueryException("Missing 'connectionName' parameter.") await self.permission.can_live_query_segment(request) variables = await self.inject_and_get_context_variables_in_live_query_segment(request) @@ -469,11 +469,7 @@ async def _handle_live_query_segment( if condition_tree: trees.append(condition_tree) trees.append( - ConditionTreeLeaf( - SchemaUtils.get_primary_keys(request.collection.schema)[0], - Operator.IN, - [entry["id"] for entry in native_query_result], - ) + 
ConditionTreeFactory.match_ids(request.collection.schema, [[*r.values()] for r in native_query_result]) ) return ConditionTreeFactory.intersect(trees) return condition_tree diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 21a8861d5..54b06c73d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -53,8 +53,8 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) async def handle_native_query(self, request: Request) -> Response: await self.permission.can_chart(request) assert request.body is not None - if "connectionName" not in request.body: - raise BusinessError("Missing 'connectionName' in parameter.") + if "connectionName" not in request.body or request.body["connectionName"] is None: + raise BusinessError("Setting a 'Native query connection' is mandatory.") if "query" not in request.body: raise BusinessError("Missing 'query' in parameter.") if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py index b3b39dc33..259f0f40d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py @@ -11,7 +11,6 @@ ################## -# TODO: adapt after new version of permissions api def _decode_segment_query_permissions(raw_permission: Dict[Any, Any]): segment_queries = {} for collection_name, value in raw_permission.items(): diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index a21772085..c8ae66964 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -8,6 +8,7 @@ from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoDatasource +from forestadmin.datasource_toolkit.exceptions import NativeQueryException from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias @@ -63,9 +64,7 @@ async def execute_native_query( ) -> List[RecordsDataAlias]: if connection_name not in self._django_live_query_connections.keys(): # This one should never occur while datasource composite works fine - raise DjangoDatasourceException( - f"Native query connection '{connection_name}' is not known by DjangoDatasource." 
- ) + raise NativeQueryException(f"Native query connection '{connection_name}' is not known by DjangoDatasource.") def _execute_native_query(): cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore @@ -85,7 +84,6 @@ def _execute_native_query(): ret.append(return_row) return ret except Exception as e: - # TODO: verify - raise DjangoDatasourceException(str(e)) + raise NativeQueryException(str(e)) return await sync_to_async(_execute_native_query)() diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index 780cf1851..02e66c036 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -5,6 +5,7 @@ from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.datasource import DjangoDatasource from forestadmin.datasource_django.exception import DjangoDatasourceException +from forestadmin.datasource_toolkit.exceptions import NativeQueryException mock_collection1 = Mock(DjangoCollection) mock_collection1.name = "first" @@ -111,7 +112,7 @@ def setUpClass(cls) -> None: def test_should_raise_if_connection_is_not_known_by_datasource(self): self.assertRaisesRegex( - DjangoDatasourceException, + NativeQueryException, r"Native query connection 'foo' is not known by DjangoDatasource.", self.loop.run_until_complete, self.dj_datasource.execute_native_query("foo", "select * from blabla", {}), @@ -181,7 +182,7 @@ async def test_should_correctly_execute_query_with_percent(self): def test_should_correctly_raise_exception_during_sql_error(self): self.assertRaisesRegex( - DjangoDatasourceException, + NativeQueryException, r"no such table: blabla", self.loop.run_until_complete, self.dj_datasource.execute_native_query("django", "select * from blabla", {}), diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index de967c909..82284b1ef 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -3,6 +3,7 @@ from forestadmin.datasource_sqlalchemy.collections import SqlAlchemyCollection from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException from forestadmin.datasource_sqlalchemy.interfaces import BaseSqlAlchemyDatasource +from forestadmin.datasource_toolkit.exceptions import NativeQueryException from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias from sqlalchemy import create_engine, text from sqlalchemy.orm import Mapper, sessionmaker @@ -61,8 +62,7 @@ async def execute_native_query( self, connection_name: str, native_query: str, parameters: Dict[str, str] ) -> List[RecordsDataAlias]: if connection_name != self.get_native_query_connections()[0]: - # TODO: verify - raise SqlAlchemyDatasourceException( + raise NativeQueryException( f"The native query connection '{connection_name}' doesn't belongs to this datasource." 
) try: @@ -80,8 +80,7 @@ async def execute_native_query( rows = session.execute(query, parameters) return [*rows.mappings()] except Exception as exc: - # TODO: verify - raise SqlAlchemyDatasourceException(str(exc)) + raise NativeQueryException(str(exc)) # unused code, can be use full but can be remove # from forestadmin.datasource_toolkit.datasources import DatasourceException diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py index 362602c41..a1cf7982d 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py @@ -6,6 +6,7 @@ from flask import Flask from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException +from forestadmin.datasource_toolkit.exceptions import NativeQueryException from sqlalchemy.orm import DeclarativeMeta from .fixture import models @@ -127,7 +128,7 @@ def tearDownClass(cls): def test_should_raise_if_connection_is_not_known_by_datasource(self): self.assertRaisesRegex( - SqlAlchemyDatasourceException, + NativeQueryException, r"The native query connection 'foo' doesn't belongs to this datasource.", self.loop.run_until_complete, self.sql_alchemy_datasource.execute_native_query("foo", "select * from blabla", {}), @@ -187,7 +188,7 @@ def test_should_correctly_execute_query_with_percent(self): def test_should_correctly_raise_exception_during_sql_error(self): self.assertRaisesRegex( - SqlAlchemyDatasourceException, + NativeQueryException, r"no such table: blabla", self.loop.run_until_complete, self.sql_alchemy_datasource.execute_native_query("sqlalchemy", "select * from blabla", {}), diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index 7d116345c..166ec77f6 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -59,6 +59,9 @@ def schema(self): def name(self): return self.datasource.name + def get_native_query_connections(self): + raise NotImplementedError + async def execute_native_query( self, connection_name: str, native_query: str, parameters: Dict[str, str] ) -> List[Dict[str, Any]]: diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index 06c98a878..f03cb55ca 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -2,7 +2,7 @@ from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.exceptions import BusinessError, DatasourceToolkitException +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException, NativeQueryException from forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, DatasourceSchema @@ -75,7 +75,7 @@ async def execute_native_query(self, 
connection_name: str, native_query: str, pa if connection_name in datasource.get_native_query_connections(): return await datasource.execute_native_query(connection_name, native_query, parameters) - raise BusinessError( + raise NativeQueryException( f"Cannot find connection '{connection_name}' in datasources. " f"Existing connection names are: {','.join(self.get_native_query_connections())}" ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py index 462226e21..1c91bf651 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py @@ -25,6 +25,18 @@ def __init__( super().__init__(message, *args) +class NativeQueryException(BusinessError): + def __init__( + self, + message: str = "", + headers: Optional[Dict[str, Any]] = None, + name: str = "NativeQueryError", + data: Optional[Dict[str, Any]] = {}, + *args: object, + ) -> None: + super().__init__(message, headers, name, data, *args) + + class ValidationError(BusinessError): def __init__( self, From c764794d64dfad09abe1bed0c8f6c85822933345 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 3 Dec 2024 15:40:37 +0100 Subject: [PATCH 33/71] chore: add native_query tests on composite datasource --- .../test_composite_datasource.py | 70 ++++++++++++++++++- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py index 9691b8704..cac1af06b 100644 --- a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -12,7 +12,7 @@ from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource -from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException, NativeQueryException from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, PrimitiveType @@ -39,8 +39,25 @@ def setUp(self) -> None: self.ds1_charts = {"charts": {"chart1": Mock()}} self.ds2_charts = {"charts": {"chart2": Mock()}} - DS1 = type("DS1", (Datasource,), {"schema": PropertyMock(return_value=self.ds1_charts)}) - DS2 = type("DS2", (Datasource,), {"schema": PropertyMock(return_value=self.ds2_charts)}) + self.ds1_connections = ["db1", "db2"] + self.ds2_connections = ["db3"] + + DS1 = type( + "DS1", + (Datasource,), + { + "schema": PropertyMock(return_value=self.ds1_charts), + "get_native_query_connections": Mock(return_value=self.ds1_connections), + }, + ) + DS2 = type( + "DS2", + (Datasource,), + { + "schema": PropertyMock(return_value=self.ds2_charts), + "get_native_query_connections": Mock(return_value=self.ds2_connections), + }, + ) self.datasource_1: Datasource = DS1() self.collection_person = Collection("Person", self.datasource_1) @@ -160,3 +177,50 @@ def test_render_chart_should_call_render_chart_on_good_datasource(self): with patch.object(self.datasource_2, "render_chart", new_callable=AsyncMock) as mock_render_chart: self.loop.run_until_complete(self.composite_ds.render_chart(self.mocked_caller, "chart2")) 
mock_render_chart.assert_awaited_with(self.mocked_caller, "chart2") + + +class TestCompositeDatasourceNativeQuery(BaseTestCompositeDatasource): + + def setUp(self) -> None: + super().setUp() + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_duplicated_live_query_connection(self): + with patch.object(self.datasource_2, "get_native_query_connections", return_value=["db1"]): + self.assertRaisesRegex( + DatasourceToolkitException, + r"Native query connection 'db1' already exists.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_get_native_query_connection_should_return_all_connections(self): + self.composite_ds.add_datasource(self.datasource_2) + connections = self.composite_ds.get_native_query_connections() + self.assertIn("db1", connections) + self.assertIn("db2", connections) + self.assertIn("db3", connections) + + def test_execute_native_query_should_raise_if_connection_is_unknown(self): + self.composite_ds.add_datasource(self.datasource_2) + + self.assertRaisesRegex( + NativeQueryException, + r"Cannot find connection 'bla' in datasources. Existing connection names are: db1,db2,db3", + self.loop.run_until_complete, + self.composite_ds.execute_native_query("bla", "select * from ...", {}), + ) + + def test_execute_native_query_should_call_to_correct_datasource(self): + self.composite_ds.add_datasource(self.datasource_2) + + with patch.object(self.datasource_1, "execute_native_query", new_callable=AsyncMock) as mock_ds1_exec: + self.loop.run_until_complete(self.composite_ds.execute_native_query("db1", "select * from ...", {})) + mock_ds1_exec.assert_any_await("db1", "select * from ...", {}) + + self.loop.run_until_complete(self.composite_ds.execute_native_query("db2", "select * from ...", {})) + mock_ds1_exec.assert_any_await("db2", "select * from ...", {}) + + with patch.object(self.datasource_2, "execute_native_query", new_callable=AsyncMock) as mock_ds2_exec: + self.loop.run_until_complete(self.composite_ds.execute_native_query("db3", "select * from ...", {})) + mock_ds2_exec.assert_any_await("db3", "select * from ...", {}) From 4245b1b47288b4cf94010cf0485ce4bf0d6abc6a Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 3 Dec 2024 17:42:01 +0100 Subject: [PATCH 34/71] chore: add tests on native query resource --- .../resources/collections/native_query.py | 2 +- .../context_variable_injector_mixin.py | 4 - .../collections/test_native_query.py | 854 ++++++++++++++++++ 3 files changed, 855 insertions(+), 5 deletions(-) create mode 100644 src/agent_toolkit/tests/resources/collections/test_native_query.py diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 54b06c73d..39b64fcde 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Union from uuid import uuid4 from forestadmin.agent_toolkit.forest_logger import ForestLogger diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py index fd06ca205..61b4a4fee 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py 
+++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -62,10 +62,6 @@ async def inject_and_get_context_variables_in_live_query_chart(self, request: "R request.user, context_variables_dct, self.permission ) - request.body["query"] = ContextVariableInjector.inject_context_in_value( - request.body["query"], context_variables - ) - request.body["query"], vars = ContextVariableInjector.format_query_and_get_vars( request.body["query"], context_variables ) diff --git a/src/agent_toolkit/tests/resources/collections/test_native_query.py b/src/agent_toolkit/tests/resources/collections/test_native_query.py new file mode 100644 index 000000000..6ea865d31 --- /dev/null +++ b/src/agent_toolkit/tests/resources/collections/test_native_query.py @@ -0,0 +1,854 @@ +import asyncio +import importlib +import json +import sys +from unittest import TestCase +from unittest.mock import AsyncMock, Mock, patch + +import forestadmin.agent_toolkit.resources.collections.native_query + +if sys.version_info >= (3, 9): + import zoneinfo +else: + from backports import zoneinfo + +import forestadmin.agent_toolkit.resources.collections.crud +from forestadmin.agent_toolkit.options import Options +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService +from forestadmin.agent_toolkit.utils.context import Request, RequestMethod, User +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import ForbiddenError + + +def authenticate_mock(fn): + async def wrapped2(self, request): + request.user = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, + ) + + return await fn(self, request) + + return wrapped2 + + +def ip_white_list_mock(fn): + async def wrapped(self, request: Request, *args, **kwargs): + return await fn(self, request, *args, **kwargs) + + return wrapped + + +patch("forestadmin.agent_toolkit.resources.collections.decorators.authenticate", authenticate_mock).start() +patch("forestadmin.agent_toolkit.resources.collections.decorators.ip_white_list", ip_white_list_mock).start() +# how to mock decorators, and why they are not testable : +# https://dev.to/stack-labs/how-to-mock-a-decorator-in-python-55jc + +importlib.reload(forestadmin.agent_toolkit.resources.collections.native_query) +from forestadmin.agent_toolkit.resources.collections.native_query import NativeQueryResource # noqa: E402 + + +class TestNativeQueryResourceBase(TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.loop = asyncio.new_event_loop() + cls.options = Options( + auth_secret="fake_secret", + env_secret="fake_secret", + server_url="http://fake:5000", + prefix="", + is_production=False, + ) + # cls.datasource = Mock(Datasource) + cls.datasource = Datasource(["db1", "db2"]) + cls.datasource_composite = CompositeDatasource() + cls.datasource_composite.add_datasource(cls.datasource) + + def setUp(self): + self.datasource_composite.execute_native_query = AsyncMock() + self.ip_white_list_service = Mock(IpWhiteListService) + self.ip_white_list_service.is_enable = AsyncMock(return_value=False) + + self.permission_service = Mock(PermissionService) 
+ self.permission_service.can_chart = AsyncMock(return_value=None) + + self.native_query_resource = NativeQueryResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + +class TestNativeQueryResourceOnError(TestNativeQueryResourceBase): + def test_should_return_error_if_cannot_chart_on_permission(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.permission_service, + "can_chart", + new_callable=AsyncMock, + side_effect=ForbiddenError("You don't have permission to access this chart."), + ) as mock_can_chart: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 403) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ForbiddenError", + "detail": "You don't have permission to access this chart.", + "status": 403, + "data": {}, + } + ] + }, + ) + + mock_can_chart.assert_awaited_once_with(request) + + def test_should_return_error_if_connectionName_is_not_here(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Setting a 'Native query connection' is mandatory.", + "status": 400, + "data": {}, + } + ] + }, + ) + + def test_should_return_error_if_query_is_not_here(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1"}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Missing 'query' in parameter.", + "status": 400, + "data": {}, + } + ] + }, + ) + + def test_should_return_error_if_chart_type_is_unknown_or_missing(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": "select count(*) from orders;"}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Unknown chart type 'None'.", + "status": 400, + "data": {}, + } + ] + }, + ) + + request.body["type"] = "unknown" + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + 
self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Unknown chart type 'unknown'.", + "status": 400, + "data": {}, + } + ] + }, + ) + + +class TestNativeQueryResourceValueChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_value_chart(self): + native_query = "select count(*) as value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, "execute_native_query", new_callable=AsyncMock, return_value=[{"value": 100}] + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual(response_body["data"]["attributes"], {"value": {"countCurrent": 100}}) + + def test_should_return_error_if_value_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, "execute_native_query", new_callable=AsyncMock, return_value=[{"not_value": 100}] + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Value' chart must return 'value' field.", + "status": 422, + "data": {}, + } + ], + ) + + def test_should_return_error_if_value_query_does_not_return_one_row(self): + native_query = "select count(*) as not_value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100}, {"value": 100}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Value' chart must return only one row.", + "status": 422, + "data": {}, + } + ], + ) + + def test_should_correctly_handle_value_chart_with_previous(self): + native_query = "select count(*) as value, 0 as previous from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + 
body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100, "previous": 0}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual(response_body["data"]["attributes"], {"value": {"countCurrent": 100, "countPrevious": 0}}) + + +class TestNativeQueryResourceLineChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_line_chart(self): + native_query = "select count(*) as value, date as key from orders group by date, order by date;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Line"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"key": "2020-01", "value": 100}, {"key": "2020-02", "value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + {"value": [{"label": "2020-01", "values": {"value": 100}}, {"label": "2020-02", "values": {"value": 110}}]}, + ) + + def test_should_return_error_if_line_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, date as key from orders group by date, order by date;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Line"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"key": "2020-01", "not_value": 100}, {"key": "2020-02", "not_value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Line' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = "select count(*) as value, date as not_key from orders group by date, order by date;" + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"not_key": "2020-01", "value": 100}, 
{"not_key": "2020-02", "value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Line' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourceObjectiveChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_objective_chart(self): + native_query = "select count(*) as value, 1000 as objective from orders " + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 150, "objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual(response_body["data"]["attributes"], {"value": {"value": 150, "objective": 1000}}) + + def test_should_return_error_if_objective_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, 1000 as objective from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"not_value": 150, "objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return 'value' and 'objective' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = "select count(*) as value, 1000 as not_objective from orders;" + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 150, "not_objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + 
self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return 'value' and 'objective' fields.", + "status": 422, + "data": {}, + } + ], + ) + + def test_should_return_error_if_objective_query_does_not_return_one_row(self): + native_query = "select count(*) as value, 1000 as objective from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100, "objective": 1000}, {"value": 100, "objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return only one row.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourceLeaderboardChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_leaderboard_chart(self): + native_query = "select sum(score) as value, customer as key from results group by customer order by value desc;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Leaderboard"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "key": "Jean"}, + {"value": 140, "key": "Elsa"}, + {"value": 0, "key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + {"value": [{"key": "Jean", "value": 150}, {"key": "Elsa", "value": 140}, {"key": "Gautier", "value": 0}]}, + ) + + def test_should_return_error_if_leaderboard_query_return_fields_are_not_good(self): + native_query = ( + "select sum(score) as not_value, customer as key from results group by customer order by value desc;" + ) + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Leaderboard"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"not_value": 150, "key": "Jean"}, + {"not_value": 140, "key": "Elsa"}, + {"not_value": 0, "key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 
422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Leaderboard' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = ( + "select sum(score) as value, customer as not_key from results group by customer order by value desc;" + ) + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "not_key": "Jean"}, + {"value": 140, "not_key": "Elsa"}, + {"value": 0, "not_key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Leaderboard' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourcePieChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_pie_chart(self): + native_query = "select count(*) as value, status as key from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "key": "pending"}, + {"value": 140, "key": "delivering"}, + {"value": 10, "key": "lost"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + { + "value": [ + {"value": 150, "key": "pending"}, + {"value": 140, "key": "delivering"}, + {"value": 10, "key": "lost"}, + ] + }, + ) + + def test_should_return_error_if_pie_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, status as key from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"not_value": 150, "key": "pending"}, + {"not_value": 140, "key": "delivering"}, + {"not_value": 10, "key": "lost"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + 
response_body = json.loads(response.body)
+        self.assertEqual(
+            response_body["errors"],
+            [
+                {
+                    "name": "UnprocessableError",
+                    "detail": "Native query for 'Pie' chart must return 'key' and 'value' fields.",
+                    "status": 422,
+                    "data": {},
+                }
+            ],
+        )
+
+        native_query = "select count(*) as value, status as not_key from orders;"
+        request.body["query"] = native_query
+        with patch.object(
+            self.datasource_composite,
+            "execute_native_query",
+            new_callable=AsyncMock,
+            return_value=[
+                {"value": 150, "not_key": "pending"},
+                {"value": 140, "not_key": "delivering"},
+                {"value": 10, "not_key": "lost"},
+            ],
+        ) as mock_exec_native_query:
+            response = self.loop.run_until_complete(
+                self.native_query_resource.dispatch(request, "native_query"),
+            )
+            mock_exec_native_query.assert_awaited_once_with("db1", native_query, {})
+
+        self.assertEqual(response.status, 422)
+        self.assertEqual(response.headers, {"content-type": "application/json"})
+        response_body = json.loads(response.body)
+        self.assertEqual(
+            response_body["errors"],
+            [
+                {
+                    "name": "UnprocessableError",
+                    "detail": "Native query for 'Pie' chart must return 'key' and 'value' fields.",
+                    "status": 422,
+                    "data": {},
+                }
+            ],
+        )
+
+
+class TestNativeQueryResourceContextVariables(TestNativeQueryResourceBase):
+
+    def test_should_correctly_handle_variable_context(self):
+        native_query = "select count(*) as value, status as key from orders where customer = {{recordId}};"
+        request = Request(
+            method=RequestMethod.POST,
+            headers={},
+            body={"connectionName": "db1", "query": native_query, "type": "Pie", "contextVariables": {"recordId": 1}},
+            client_ip="127.0.0.1",
+            query={},
+        )
+        with patch.object(
+            self.datasource_composite,
+            "execute_native_query",
+            new_callable=AsyncMock,
+            return_value=[],
+        ) as mock_exec_native_query:
+            self.loop.run_until_complete(
+                self.native_query_resource.dispatch(request, "native_query"),
+            )
+            mock_exec_native_query.assert_awaited_once_with(
+                "db1",
+                "select count(*) as value, status as key from orders where customer = %(recordId)s;",
+                {"recordId": 1},
+            )
+
+    def test_should_correctly_handle_variable_context_and_like_percent_comparison(self):
+        native_query = (
+            "select count(*) as value, status as key from orders where customer = {{recordId}} "
+            "and customer_name like '%henry%';"
+        )
+        request = Request(
+            method=RequestMethod.POST,
+            headers={},
+            body={"connectionName": "db1", "query": native_query, "type": "Pie", "contextVariables": {"recordId": 1}},
+            client_ip="127.0.0.1",
+            query={},
+        )
+        with patch.object(
+            self.datasource_composite,
+            "execute_native_query",
+            new_callable=AsyncMock,
+            return_value=[],
+        ) as mock_exec_native_query:
+            self.loop.run_until_complete(
+                self.native_query_resource.dispatch(request, "native_query"),
+            )
+            mock_exec_native_query.assert_awaited_once_with(
+                "db1",
+                "select count(*) as value, status as key from orders where customer = %(recordId)s "
+                "and customer_name like '\\%henry\\%';",
+                {"recordId": 1},
+            )

From 180d9327cc2e9db9c32143a287cde85f8028568b Mon Sep 17 00:00:00 2001
From: Julien Barreau
Date: Tue, 3 Dec 2024 17:44:53 +0100
Subject: [PATCH 35/71] chore: fix linting

---
 .../agent_toolkit/resources/collections/native_query.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py
index 39b64fcde..323ffca31 100644
--- 
a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -58,7 +58,7 @@ async def handle_native_query(self, request: Request) -> Response: if "query" not in request.body: raise BusinessError("Missing 'query' in parameter.") if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: - raise ValidationError(f"Unknown chart type '{request.body.get("type")}'.") + raise ValidationError(f"Unknown chart type '{request.body.get('type')}'.") variables = await self.inject_and_get_context_variables_in_live_query_chart(request) native_query_results = await self.composite_datasource.execute_native_query( From c577da6140e7b29b67aa39ee507dcf86df1b62b0 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 4 Dec 2024 13:59:10 +0100 Subject: [PATCH 36/71] chore: add tests for segments --- .../resources/collections/crud.py | 2 +- .../resources/collections/native_query.py | 2 +- .../context_variable_injector_mixin.py | 3 - .../services/permissions/permissions_types.py | 1 + .../tests/resources/collections/test_crud.py | 350 +++++++++++++++++- 5 files changed, 337 insertions(+), 21 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 99d7d6cd4..d8fd4a8f6 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -454,7 +454,7 @@ async def _handle_live_query_segment( self, request: RequestCollection, condition_tree: Optional[ConditionTree] ) -> Optional[ConditionTree]: if request.query.get("segmentQuery") is not None: - if "connectionName" not in request.query or request.query["connectionName"] == "": + if request.query.get("connectionName") in ["", None]: raise NativeQueryException("Missing 'connectionName' parameter.") await self.permission.can_live_query_segment(request) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 323ffca31..6b7ebd8e6 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -53,7 +53,7 @@ async def dispatch(self, request: Request, method_name: Literal["native_query"]) async def handle_native_query(self, request: Request) -> Response: await self.permission.can_chart(request) assert request.body is not None - if "connectionName" not in request.body or request.body["connectionName"] is None: + if request.body.get("connectionName") is None: raise BusinessError("Setting a 'Native query connection' is mandatory.") if "query" not in request.body: raise BusinessError("Missing 'query' in parameter.") diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py index 61b4a4fee..d6c6b73f1 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -42,9 +42,6 @@ async def inject_context_variables_in_filter(self, request: "RequestCollection") async def 
inject_and_get_context_variables_in_live_query_segment( self, request: "RequestCollection" ) -> Dict[str, str]: - # TODO: handle context variables from front or not ?? - if request.query.get("segmentQuery") is None: - return {} context_variables_dct = request.query.pop("contextVariables", {}) context_variables = await ContextVariablesInstantiator.build_context_variables( diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py index 5163890b9..34117a132 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py @@ -3,6 +3,7 @@ from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch +# this file is never imported class PermissionBody(TypedDict): actions: Set[str] actions_by_user: Dict[str, Set[int]] diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index 9abbd8672..3495b4aed 100644 --- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -27,26 +27,29 @@ from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException -from forestadmin.datasource_toolkit.exceptions import ValidationError +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, NativeQueryException, ValidationError from forestadmin.datasource_toolkit.interfaces.fields import FieldType, Operator, PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.projections import Projection from forestadmin.datasource_toolkit.validations.records import RecordValidatorException +FAKE_USER = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, +) + def authenticate_mock(fn): async def wrapped2(self, request): - request.user = User( - rendering_id=1, - user_id=1, - tags={}, - email="dummy@user.fr", - first_name="dummy", - last_name="user", - team="operational", - timezone=zoneinfo.ZoneInfo("Europe/Paris"), - request={"ip": "127.0.0.1"}, - ) + request.user = FAKE_USER return await fn(self, request) @@ -227,7 +230,7 @@ def setUpClass(cls) -> None: is_production=False, ) # cls.datasource = Mock(Datasource) - cls.datasource = Datasource() + cls.datasource = Datasource(["db_connection"]) cls.datasource_composite = CompositeDatasource() cls.datasource.get_collection = lambda x: cls.datasource._collections[x] cls._create_collections() @@ -249,6 +252,7 @@ def setUp(self): self.permission_service = Mock(PermissionService) self.permission_service.get_scope = AsyncMock(return_value=ConditionTreeLeaf("id", Operator.GREATER_THAN, 0)) self.permission_service.can = AsyncMock(return_value=None) + self.permission_service.can_live_query_segment = AsyncMock(return_value=None) @classmethod def tearDownClass(cls) 
-> None: @@ -1163,9 +1167,9 @@ def test_list_with_polymorphic_many_to_one_should_query_all_relation_record_colu ) def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: Mock): mock_orders = [ - {"id": 10, "cost": 200, "important": "02_PENDING"}, - {"id": 11, "cost": 201, "important": "02_PENDING"}, - {"id": 13, "cost": 20, "important": "01_URGENT"}, + {"id": 10, "cost": 200, "important": True}, + {"id": 11, "cost": 201, "important": True}, + {"id": 13, "cost": 20, "important": False}, ] request = RequestCollection( RequestMethod.GET, @@ -1204,6 +1208,39 @@ def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: self.assertEqual(paginated_filter.sort[0], {"field": "important", "ascending": True}) self.assertEqual(paginated_filter.sort[1], {"field": "cost", "ascending": False}) + def test_list_should_handle_live_query_segment(self): + mock_orders = [{"id": 10, "cost": 200}, {"id": 11, "cost": 201}] + + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.list = AsyncMock(return_value=mock_orders) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.list(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + @patch( "forestadmin.agent_toolkit.resources.collections.crud.JsonApiSerializer.get", return_value=Mock, @@ -1321,6 +1358,37 @@ def test_count(self): assert response_content["count"] == 0 self.collection_order.aggregate.assert_called() + def test_count_should_handle_live_query_segment(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.aggregate = AsyncMock(return_value=[{"value": 1000, "group": {}}]) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.count(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + def test_deactivate_count(self): request = RequestCollection( RequestMethod.GET, @@ -1860,3 +1928,253 @@ def test_csv_should_not_apply_pagination(self): self.assertIsNone(self.collection_order.list.await_args[0][1].page) self.collection_order.list.assert_awaited() self.assertIsNone(self.collection_order.list.await_args[0][1].page) + + def test_csv_should_handle_live_query_segment(self): + mock_orders = [{"id": 10, "cost": 200}, {"id": 11, "cost": 201}] + + request 
= RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.list = AsyncMock(return_value=mock_orders) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.csv(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + + # live queries + + def test_handle_native_query_should_handle_live_query_segments(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ) as mock_exec_native_query: + condition_tree = self.loop.run_until_complete(crud_resource._handle_live_query_segment(request, None)) + self.assertEqual(condition_tree, ConditionTreeLeaf("id", "in", [10, 11])) + mock_exec_native_query.assert_awaited_once_with( + "db_connection", "select id from order where important is true;", {} + ) + + def test_handle_native_query_should_inject_context_variable_and_handle_like_percent(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from user where first_name like 'Ga%' or id = {{currentUser.id}};", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.permission_service, + "get_user_data", + new_callable=AsyncMock, + return_value={ + "id": 1, + "firstName": "dummy", + "lastName": "user", + "fullName": "dummy user", + "email": "dummy@user.fr", + "tags": {}, + "roleId": 8, + "permissionLevel": "admin", + }, + ): + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ) as mock_exec_native_query: + condition_tree = self.loop.run_until_complete(crud_resource._handle_live_query_segment(request, None)) + self.assertEqual(condition_tree, ConditionTreeLeaf("id", "in", [10, 11])) + mock_exec_native_query.assert_awaited_once_with( + "db_connection", + "select id from user where first_name like 'Ga\\%' or id = %(currentUser__id)s;", + {"currentUser__id": 1}, + ) + + def 
test_handle_native_query_should_intersect_existing_condition_tree(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from user where id=10 or id=11;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ): + condition_tree = self.loop.run_until_complete( + crud_resource._handle_live_query_segment(request, ConditionTreeLeaf("id", "equal", 25)) + ) + + self.assertEqual( + condition_tree, + ConditionTreeBranch( + "and", + [ + ConditionTreeLeaf("id", "equal", 25), + ConditionTreeLeaf("id", "in", [10, 11]), + ], + ), + ) + + def test_handle_native_query_should_raise_error_if_live_query_params_are_incorrect(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + self.assertRaisesRegex( + NativeQueryException, + "Missing 'connectionName' parameter.", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + request.query["connectionName"] = None + self.assertRaisesRegex( + NativeQueryException, + "Missing 'connectionName' parameter.", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + request.query["connectionName"] = "" + self.assertRaisesRegex( + NativeQueryException, + "Missing 'connectionName' parameter.", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + def test_handle_native_query_should_raise_error_if_not_permission(self): + + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + with patch.object( + self.permission_service, + "can_live_query_segment", + new_callable=AsyncMock, + side_effect=ForbiddenError("You don't have permission to access this segment query."), + ): + self.assertRaisesRegex( + ForbiddenError, + "You don't have permission to access this segment query.", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) From 5c13d6603c219b4d9b7adcb4e93f1fd0bd3a95bf Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 4 Dec 2024 15:39:16 +0100 Subject: [PATCH 37/71] chore: details on pk field for segments 
--- .../resources/collections/crud.py | 12 +++++-- .../tests/resources/collections/test_crud.py | 36 +++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index d8fd4a8f6..81231e23a 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -61,6 +61,7 @@ from forestadmin.datasource_toolkit.interfaces.query.projections.factory import ProjectionFactory from forestadmin.datasource_toolkit.interfaces.records import CompositeIdAlias, RecordsDataAlias from forestadmin.datasource_toolkit.utils.collections import CollectionUtils +from forestadmin.datasource_toolkit.utils.records import RecordUtils from forestadmin.datasource_toolkit.utils.schema import SchemaUtils from forestadmin.datasource_toolkit.validations.field import FieldValidatorException from forestadmin.datasource_toolkit.validations.records import RecordValidator, RecordValidatorException @@ -462,14 +463,19 @@ async def _handle_live_query_segment( native_query_result = await self._datasource_composite.execute_native_query( request.query["connectionName"], request.query["segmentQuery"], variables ) - if len(native_query_result) > 0 and "id" not in native_query_result[0]: - raise BusinessError("Live query must return an 'id' field.") + + pk_field = SchemaUtils.get_primary_keys(request.collection.schema)[0] + if len(native_query_result) > 0 and pk_field not in native_query_result[0]: + raise NativeQueryException(f"Live query must return the primary key field ('{pk_field}').") trees = [] if condition_tree: trees.append(condition_tree) trees.append( - ConditionTreeFactory.match_ids(request.collection.schema, [[*r.values()] for r in native_query_result]) + ConditionTreeFactory.match_ids( + request.collection.schema, + [RecordUtils.get_primary_key(request.collection.schema, r) for r in native_query_result], + ) ) return ConditionTreeFactory.intersect(trees) return condition_tree diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index 3495b4aed..33d9f4e83 100644 --- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -2178,3 +2178,39 @@ def test_handle_native_query_should_raise_error_if_not_permission(self): self.loop.run_until_complete, crud_resource._handle_live_query_segment(request, None), ) + + def test_handle_native_query_should_raise_error_if_pk_not_returned(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id as bla, cost from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"bla": 10, "cost": 100}, {"bla": 11, "cost": 100}], + ): + self.assertRaisesRegex( + NativeQueryException, + r"Live query must return the primary key field 
\('id'\).", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) From cdbf0165fbdaef48dc44439b00dbba5b35687fde Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 4 Dec 2024 15:56:57 +0100 Subject: [PATCH 38/71] chore: fix linting --- .../forestadmin/agent_toolkit/resources/collections/crud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 81231e23a..adff49245 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -35,7 +35,7 @@ from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException -from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError, NativeQueryException +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, NativeQueryException from forestadmin.datasource_toolkit.interfaces.fields import ( ManyToOne, OneToOne, From 324c22f17b5926bf1a7ead63af321cd45b3823da Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 4 Dec 2024 18:02:47 +0100 Subject: [PATCH 39/71] chore: add tests on permissions --- .../permissions/test_permission_service.py | 184 +++++++++++++++++- 1 file changed, 182 insertions(+), 2 deletions(-) diff --git a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py index 921fa77ce..9f840d12e 100644 --- a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py +++ b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py @@ -38,7 +38,7 @@ class BaseTestPermissionService(TestCase): @classmethod def setUpClass(cls) -> None: cls.loop = asyncio.new_event_loop() - cls.datasource = Datasource() + cls.datasource = Datasource(["database_1"]) Collection.__abstractmethods__ = set() # to instantiate abstract class cls.booking_collection = Collection("Booking", cls.datasource) cls.booking_collection.add_fields( @@ -174,7 +174,18 @@ def mock_forest_http_api(self, scope=None) -> PatchHttpApiDict: "get_rendering_permissions", new_callable=AsyncMock, return_value={ - "collections": {"Booking": {"scope": scope, "segments": []}}, + "collections": { + "Booking": { + "scope": scope, + "segments": ["select id from booking where title is null"], + "liveQuerySegments": [ + { + "query": "select id from booking where title is null", + "connectionName": "database_1", + } + ], + } + }, "stats": [ { "type": "Pie", @@ -191,6 +202,11 @@ def mock_forest_http_api(self, scope=None) -> PatchHttpApiDict: "aggregateFieldName": None, "sourceCollectionName": "Booking", }, + { + "type": "Pie", + "query": "select sum(amount) as value, status as key from app_order group by status", + "connectionName": "database_1", + }, ], "team": { "id": 1, @@ -428,6 +444,87 @@ def test_can_chart_should_raise_forbidden_error_on_not_allowed_chart(self): self.permission_service.can_chart(request), ) + def test_can_chart_should_handle_live_query_chart(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + rendering_permission_mock: 
AsyncMock = http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "database_1", + "contextVariables": {}, + "query": "select sum(amount) as value, status as key from app_order group by status", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + is_allowed = self.loop.run_until_complete(self.permission_service.can_chart(request)) + rendering_permission_mock.assert_awaited_once() + self.assertTrue(is_allowed) + + http_patches["get_rendering_permissions"].stop() + + def test_can_chart_should_raise_forbidden_error_on_not_wrong_live_query_connection(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "wrong_database", + "contextVariables": {}, + "query": "select sum(amount) as value, status as key from app_order group by status", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this chart.", + self.loop.run_until_complete, + self.permission_service.can_chart(request), + ) + + http_patches["get_rendering_permissions"].stop() + + def test_can_chart_should_raise_forbidden_error_on_mismatching_query(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "database_1", + "contextVariables": {}, + "query": "select sum(amount) as value, status as key from app_order group by status ; ", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this chart.", + self.loop.run_until_complete, + self.permission_service.can_chart(request), + ) + + http_patches["get_rendering_permissions"].stop() + class Test04GetScopePermissionService(BaseTestPermissionService): def test_get_scope_should_return_null_when_no_scope_in_permissions(self): @@ -649,3 +746,86 @@ def test_can_smart_action_should_throw_when_action_is_unknown(self): http_patches["get_environment_permissions"].stop() http_patches["get_users"].stop() + + +class Test06CanLiveQuerySegment(BaseTestPermissionService): + def test_can_live_query_segment_should_allow_correct_live_query_segment(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + rendering_permission_mock: AsyncMock = http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": "database_1", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + is_allowed = self.loop.run_until_complete(self.permission_service.can_live_query_segment(request)) + rendering_permission_mock.assert_awaited_once() + self.assertTrue(is_allowed) + + http_patches["get_rendering_permissions"].stop() + + def 
test_can_live_query_segment_should_raise_forbidden_when_wrong_connection_name(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": "database_2", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this segment query.", + self.loop.run_until_complete, + self.permission_service.can_live_query_segment(request), + ) + + http_patches["get_rendering_permissions"].stop() + + def test_can_live_query_segment_should_raise_forbidden_when_mismatching_query(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": "database_1", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null;", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this segment query.", + self.loop.run_until_complete, + self.permission_service.can_live_query_segment(request), + ) + + http_patches["get_rendering_permissions"].stop() From 1a3084ddf1661f1c5ee5b0433bb0ba03421f0260 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 5 Dec 2024 14:27:36 +0100 Subject: [PATCH 40/71] chore: add test on django routes --- src/django_agent/tests/test_http_routes.py | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/django_agent/tests/test_http_routes.py b/src/django_agent/tests/test_http_routes.py index 5abb2f24a..86c57a5f1 100644 --- a/src/django_agent/tests/test_http_routes.py +++ b/src/django_agent/tests/test_http_routes.py @@ -18,6 +18,7 @@ def setUpClass(cls) -> None: cls.mocked_resources = {} for key in [ "capabilities", + "native_query", "authentication", "crud", "crud_related", @@ -457,3 +458,45 @@ def test_stat_list(self): self.assertEqual(request_param.method, RequestMethod.POST) self.assertEqual(request_param.query["collection_name"], "customer") self.assertEqual(request_param.body, {"post_attr": "post_value"}) + + +class TestDjangoAgentNativeQueryRoutes(TestDjangoAgentRoutes): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.native_query_resource = cls.loop.run_until_complete(cls.django_agent.get_resources())["native_query"] + + def test_native_query(self): + response = self.client.post( + f"/{self.conf_prefix}forest/_internal/native_query?timezone=Europe%2FParis", + json.dumps( + { + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + } + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'{"mock": "ok"}') + self.native_query_resource.dispatch.assert_awaited() + call_args = self.native_query_resource.dispatch.await_args[0] + self.assertEqual(call_args[1], "native_query") + self.assertEqual( + call_args[0], + Request( + RequestMethod.POST, + body={ + "connectionName": "django", 
+ "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + query={"timezone": "Europe/Paris"}, + headers={"Cookie": "", "Content-Length": "146", "Content-Type": "application/json"}, + client_ip="127.0.0.1", + ), + ) From e005930252b4a3390d51419ff907c505d559e0c3 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 5 Dec 2024 14:27:46 +0100 Subject: [PATCH 41/71] chore: add test on flask routes --- .../tests/test_flask_agent_blueprint.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/flask_agent/tests/test_flask_agent_blueprint.py b/src/flask_agent/tests/test_flask_agent_blueprint.py index 7bbf1360a..6f6bb2ba2 100644 --- a/src/flask_agent/tests/test_flask_agent_blueprint.py +++ b/src/flask_agent/tests/test_flask_agent_blueprint.py @@ -13,6 +13,7 @@ def setUpClass(cls) -> None: cls.loop = asyncio.new_event_loop() cls.mocked_resources = {} for key in [ + "native_query", "capabilities", "authentication", "crud", @@ -76,6 +77,44 @@ def test_capabilities(self): ), ) + def test_native_query(self): + response = self.client.post( + "/forest/_internal/native_query?timezone=Europe%2FParis", + json={ + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + ) + assert response.status_code == 200 + assert response.json == {"mock": "ok"} + self.mocked_resources["native_query"].dispatch.assert_awaited() + call_args = self.mocked_resources["native_query"].dispatch.await_args.args + self.assertEqual(call_args[1], "native_query") + headers = {**call_args[0].headers} + del headers["User-Agent"] + call_args[0].headers = headers + self.assertEqual( + call_args[0], + Request( + RequestMethod.POST, + body={ + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + query={"timezone": "Europe/Paris"}, + headers={ + "Host": "localhost", + "Content-Type": "application/json", + "Content-Length": "146", + }, + client_ip="127.0.0.1", + ), + ) + def test_hook_load(self): response = self.client.post("/forest/_actions/customer/1/action_name/hooks/load") assert response.status_code == 200 From 5f3f93bc0a7d3d50a447cd18c5a4f4d5a3e72455 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 5 Dec 2024 16:06:53 +0100 Subject: [PATCH 42/71] chore: adapt error messages --- .../forestadmin/agent_toolkit/resources/collections/crud.py | 2 +- .../agent_toolkit/resources/collections/native_query.py | 2 +- src/agent_toolkit/tests/resources/collections/test_crud.py | 6 +++--- .../tests/resources/collections/test_native_query.py | 2 +- .../datasource_customizer/datasource_composite.py | 5 +---- .../datasource_customizer/test_composite_datasource.py | 2 +- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index adff49245..04dc4ab3a 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -456,7 +456,7 @@ async def _handle_live_query_segment( ) -> Optional[ConditionTree]: if request.query.get("segmentQuery") is not None: if request.query.get("connectionName") in ["", None]: - raise NativeQueryException("Missing 'connectionName' parameter.") + raise 
NativeQueryException("Missing native query connection attribute") await self.permission.can_live_query_segment(request) variables = await self.inject_and_get_context_variables_in_live_query_segment(request) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 6b7ebd8e6..f2363e2cf 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -54,7 +54,7 @@ async def handle_native_query(self, request: Request) -> Response: await self.permission.can_chart(request) assert request.body is not None if request.body.get("connectionName") is None: - raise BusinessError("Setting a 'Native query connection' is mandatory.") + raise BusinessError("Missing native query connection attribute") if "query" not in request.body: raise BusinessError("Missing 'query' in parameter.") if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index 33d9f4e83..70a0abe89 100644 --- a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -2120,7 +2120,7 @@ def test_handle_native_query_should_raise_error_if_live_query_params_are_incorre self.assertRaisesRegex( NativeQueryException, - "Missing 'connectionName' parameter.", + "Missing native query connection attribute", self.loop.run_until_complete, crud_resource._handle_live_query_segment(request, None), ) @@ -2128,7 +2128,7 @@ def test_handle_native_query_should_raise_error_if_live_query_params_are_incorre request.query["connectionName"] = None self.assertRaisesRegex( NativeQueryException, - "Missing 'connectionName' parameter.", + "Missing native query connection attribute", self.loop.run_until_complete, crud_resource._handle_live_query_segment(request, None), ) @@ -2136,7 +2136,7 @@ def test_handle_native_query_should_raise_error_if_live_query_params_are_incorre request.query["connectionName"] = "" self.assertRaisesRegex( NativeQueryException, - "Missing 'connectionName' parameter.", + "Missing native query connection attribute", self.loop.run_until_complete, crud_resource._handle_live_query_segment(request, None), ) diff --git a/src/agent_toolkit/tests/resources/collections/test_native_query.py b/src/agent_toolkit/tests/resources/collections/test_native_query.py index 6ea865d31..a76ea3ed6 100644 --- a/src/agent_toolkit/tests/resources/collections/test_native_query.py +++ b/src/agent_toolkit/tests/resources/collections/test_native_query.py @@ -147,7 +147,7 @@ def test_should_return_error_if_connectionName_is_not_here(self): "errors": [ { "name": "ValidationError", - "detail": "Setting a 'Native query connection' is mandatory.", + "detail": "Missing native query connection attribute", "status": 400, "data": {}, } diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py index f03cb55ca..8ac4931e6 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -75,7 +75,4 @@ 
async def execute_native_query(self, connection_name: str, native_query: str, pa if connection_name in datasource.get_native_query_connections(): return await datasource.execute_native_query(connection_name, native_query, parameters) - raise NativeQueryException( - f"Cannot find connection '{connection_name}' in datasources. " - f"Existing connection names are: {','.join(self.get_native_query_connections())}" - ) + raise NativeQueryException(f"Native query connection '{connection_name}' is unknown") diff --git a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py index cac1af06b..dae0e2c3f 100644 --- a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -206,7 +206,7 @@ def test_execute_native_query_should_raise_if_connection_is_unknown(self): self.assertRaisesRegex( NativeQueryException, - r"Cannot find connection 'bla' in datasources. Existing connection names are: db1,db2,db3", + r"Native query connection 'bla' is unknown", self.loop.run_until_complete, self.composite_ds.execute_native_query("bla", "select * from ...", {}), ) From ab41081c7783333eaf5c2852c328c2c0207730e5 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 9 Dec 2024 09:24:51 +0100 Subject: [PATCH 43/71] chore: can link add datasource methods --- src/agent_toolkit/forestadmin/agent_toolkit/agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index 760929e8e..d014a2031 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -131,7 +131,9 @@ async def get_resources(self): await self.__mk_resources() return self._resources - def add_datasource(self, datasource: Datasource[BoundCollection], options: Optional[DataSourceOptions] = None): + def add_datasource( + self, datasource: Datasource[BoundCollection], options: Optional[DataSourceOptions] = None + ) -> Self: """Add a datasource Args: @@ -142,6 +144,7 @@ def add_datasource(self, datasource: Datasource[BoundCollection], options: Optio options = {} self.customizer.add_datasource(datasource, options) self._resources = None + return self def use(self, plugin: type, options: Optional[Dict] = {}) -> Self: """Load a plugin across all collections From 481423b532a7c9de6d0cba4c1c8cc7bd7578c3d9 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 15:11:00 +0100 Subject: [PATCH 44/71] chore: add sql query checker --- .../resources/collections/crud.py | 2 + .../resources/collections/native_query.py | 2 + .../agent_toolkit/utils/sql_query_checker.py | 37 +++++++++++++ .../tests/utils/test_sql_query_checker.py | 55 +++++++++++++++++++ 4 files changed, 96 insertions(+) create mode 100644 src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py create mode 100644 src/agent_toolkit/tests/utils/test_sql_query_checker.py diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 04dc4ab3a..8a3a09e7d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -30,6 +30,7 @@ from forestadmin.agent_toolkit.utils.context 
import HttpResponseBuilder, Request, RequestMethod, Response, User from forestadmin.agent_toolkit.utils.csv import Csv, CsvException from forestadmin.agent_toolkit.utils.id import unpack_id +from forestadmin.agent_toolkit.utils.sql_query_checker import SqlQueryChecker from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasource_customizer.collection_customizer import CollectionCustomizer from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource @@ -459,6 +460,7 @@ async def _handle_live_query_segment( raise NativeQueryException("Missing native query connection attribute") await self.permission.can_live_query_segment(request) + SqlQueryChecker.check_query(request.query["segmentQuery"]) variables = await self.inject_and_get_context_variables_in_live_query_segment(request) native_query_result = await self._datasource_composite.execute_native_query( request.query["connectionName"], request.query["segmentQuery"], variables diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index f2363e2cf..2bcfe621d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -9,6 +9,7 @@ from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response +from forestadmin.agent_toolkit.utils.sql_query_checker import SqlQueryChecker from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.exceptions import BusinessError, UnprocessableError, ValidationError @@ -60,6 +61,7 @@ async def handle_native_query(self, request: Request) -> Response: if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: raise ValidationError(f"Unknown chart type '{request.body.get('type')}'.") + SqlQueryChecker.check_query(request.body["query"]) variables = await self.inject_and_get_context_variables_in_live_query_chart(request) native_query_results = await self.composite_datasource.execute_native_query( request.body["connectionName"], request.body["query"], variables diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py new file mode 100644 index 000000000..975f3c1ad --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py @@ -0,0 +1,37 @@ +import re + +from forestadmin.datasource_toolkit.exceptions import NativeQueryException + + +class EmptySQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("You cannot execute an empty SQL query.") + + +class ChainedSQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("You cannot chain SQL queries.") + + +class NonSelectSQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("Only SELECT queries are allowed.") + + 
+class SqlQueryChecker: + QUERY_SELECT = re.compile(r"^SELECT\s(.|\n)*FROM\s(.|\n)*$", re.IGNORECASE) + + @staticmethod + def check_query(input_query: str) -> bool: + input_query_trimmed = input_query.strip() + + if len(input_query_trimmed) == 0: + raise EmptySQLQueryException() + + if ";" in input_query_trimmed and input_query_trimmed.index(";") != len(input_query_trimmed) - 1: + raise ChainedSQLQueryException() + + if not SqlQueryChecker.QUERY_SELECT.match(input_query_trimmed): + raise NonSelectSQLQueryException() + + return True diff --git a/src/agent_toolkit/tests/utils/test_sql_query_checker.py b/src/agent_toolkit/tests/utils/test_sql_query_checker.py new file mode 100644 index 000000000..dcfecb359 --- /dev/null +++ b/src/agent_toolkit/tests/utils/test_sql_query_checker.py @@ -0,0 +1,55 @@ +from unittest import TestCase + +from forestadmin.agent_toolkit.utils.sql_query_checker import ( + ChainedSQLQueryException, + EmptySQLQueryException, + NonSelectSQLQueryException, + SqlQueryChecker, +) + + +class TestSqlQueryChecker(TestCase): + def test_normal_sql_query_should_be_ok(self): + self.assertTrue( + SqlQueryChecker.check_query( + """ + Select status, sum(amount) as value + from order + where status != "rejected" + group by status having status != "rejected"; + """ + ) + ) + + def test_should_raise_on_linked_query(self): + self.assertRaisesRegex( + ChainedSQLQueryException, + r"You cannot chain SQL queries\.", + SqlQueryChecker.check_query, + """ + Select status, sum(amount) as value + from order + where status != "rejected" + group by status having status != "rejected"; delete from user_debts; + """, + ) + + def test_should_raise_on_empty_query(self): + self.assertRaisesRegex( + EmptySQLQueryException, + r"You cannot execute an empty SQL query\.", + SqlQueryChecker.check_query, + """ + + """, + ) + + def test_should_raise_on_non_select_query(self): + self.assertRaisesRegex( + NonSelectSQLQueryException, + r"Only SELECT queries are allowed\.", + SqlQueryChecker.check_query, + """ + delete from user_debts; + """, + ) From 3e6f394bba71890c67b3e6a646035ee58c97ae3a Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 16:53:04 +0100 Subject: [PATCH 45/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 6bdb31b3b..caaefc90d 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -23,6 +23,8 @@ runs: - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} + env: + PYTHON_KEYRING_BACKEND: keyring.backends.fail.Keyring run: poetry install --no-interaction --with test - name: Test with pytest shell: bash From a25d815630a8ab665487df54dca6acbb841d5f59 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:00:23 +0100 Subject: [PATCH 46/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index caaefc90d..a927e0f8a 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -24,7 +24,7 @@ runs: shell: bash working-directory: ${{ inputs.current_package }} env: - PYTHON_KEYRING_BACKEND: keyring.backends.fail.Keyring + PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring run: poetry install --no-interaction --with test - name: Test with pytest 
shell: bash From 715485bc577bf4a869f6db50ada0fdfcae2797f8 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:09:59 +0100 Subject: [PATCH 47/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index a927e0f8a..caa3d2087 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -23,9 +23,7 @@ runs: - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} - env: - PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring - run: poetry install --no-interaction --with test + run: export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring && poetry install --no-interaction --with test - name: Test with pytest shell: bash working-directory: ${{ inputs.current_package }} From d099793831d9a7492d6bb8f0e8137f044b35cb3d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:20:19 +0100 Subject: [PATCH 48/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index caa3d2087..9714b516d 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -15,7 +15,7 @@ runs: uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - # cache: 'poetry' + cache: 'poetry' - name: Install poetry uses: snok/install-poetry@v1 with: @@ -23,7 +23,7 @@ runs: - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} - run: export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring && poetry install --no-interaction --with test + run: poetry install --no-interaction --with test - name: Test with pytest shell: bash working-directory: ${{ inputs.current_package }} From 8dfa13e51bbd079f9a7ebee42df495a3ca973be9 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:24:09 +0100 Subject: [PATCH 49/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 9714b516d..87421efe3 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -11,15 +11,16 @@ runs: using: "composite" steps: - uses: actions/checkout@v3 + - name: Install poetry + # run: pipx install poetry + uses: snok/install-poetry@v1 + with: + version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} cache: 'poetry' - - name: Install poetry - uses: snok/install-poetry@v1 - with: - version: 1.7.1 - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} From 11284c4ca6e8a974fa6aba9d8f327411aa2991e3 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:30:20 +0100 Subject: [PATCH 50/71] chore: try to improve ci speed on test py3.12 --- .github/actions/tests/action.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 87421efe3..1dfb8a3bf 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,7 +12,6 @@ runs: steps: - uses: actions/checkout@v3 - name: 
Install poetry - # run: pipx install poetry uses: snok/install-poetry@v1 with: version: 1.7.1 @@ -20,7 +19,7 @@ runs: uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - cache: 'poetry' + # cache: 'poetry' - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} From b4ee9a100d0103acbcf0302c35557986301686cf Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:49:34 +0100 Subject: [PATCH 51/71] chore(ci): try to cache poetry --- .github/actions/tests/action.yml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 1dfb8a3bf..24efbf83c 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,14 +12,23 @@ runs: steps: - uses: actions/checkout@v3 - name: Install poetry - uses: snok/install-poetry@v1 - with: - version: 1.7.1 + # uses: snok/install-poetry@v1 + # with: + # version: 1.7.1 + run: | + pipx install poetry + echo "POETRY_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV + - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} # cache: 'poetry' + - name: Cache poetry + uses: actions/cache@v3 + with: + path: ${{ env.POETRY_CACHE_DIR }} + key: ${{ runner.os }}-poetry-${{ steps.setup-python.outputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} From 02d53b9991183cec9270ce2580ecd52a45e995fb Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:52:44 +0100 Subject: [PATCH 52/71] chore(ci): try to cache poetry --- .github/actions/tests/action.yml | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 24efbf83c..b0c61aa0d 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,13 +12,9 @@ runs: steps: - uses: actions/checkout@v3 - name: Install poetry - # uses: snok/install-poetry@v1 - # with: - # version: 1.7.1 - run: | - pipx install poetry - echo "POETRY_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV - + uses: snok/install-poetry@v1 + with: + version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: @@ -27,7 +23,7 @@ runs: - name: Cache poetry uses: actions/cache@v3 with: - path: ${{ env.POETRY_CACHE_DIR }} + path: '.cache/pypoetry' key: ${{ runner.os }}-poetry-${{ steps.setup-python.outputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash From b9083f79a370d9cfb259e04df55ad90ef8e1fb50 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 17:57:19 +0100 Subject: [PATCH 53/71] chore(ci): try to cache poetry --- .github/actions/changes/action.yml | 4 ++++ .github/actions/tests/action.yml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/actions/changes/action.yml b/.github/actions/changes/action.yml index e2975f68f..9b831c861 100644 --- a/.github/actions/changes/action.yml +++ b/.github/actions/changes/action.yml @@ -10,7 +10,11 @@ runs: - 'src/datasource_toolkit/**' ./src/datasource_sqlalchemy: - 'src/datasource_sqlalchemy/**' + ./src/datasource_django: + - 'src/datasource_django/**' ./src/agent_toolkit: - 'src/agent_toolkit/**' ./src/flask_agent: - 
'src/flask_agent/**' + ./src/django_agent: + - 'src/django_agent/**' diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index b0c61aa0d..82dfba8c5 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -23,7 +23,7 @@ runs: - name: Cache poetry uses: actions/cache@v3 with: - path: '.cache/pypoetry' + path: '~/.cache/pypoetry' key: ${{ runner.os }}-poetry-${{ steps.setup-python.outputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash From 4bf3330826211960d4b6b84b9b8dbc01bd53a367 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 10 Dec 2024 18:16:31 +0100 Subject: [PATCH 54/71] chore(ci): try to cache poetry --- .github/actions/tests/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 82dfba8c5..094d894ea 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -24,7 +24,7 @@ runs: uses: actions/cache@v3 with: path: '~/.cache/pypoetry' - key: ${{ runner.os }}-poetry-${{ steps.setup-python.outputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} + key: ${{ runner.os }}-poetry-${{ inputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} From df08234bac16e86525d611dd22540795bcba93cf Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 11:24:42 +0100 Subject: [PATCH 55/71] chore(CI): use cache v4 --- .github/actions/coverage/action.yml | 2 +- .github/actions/release/action.yml | 2 +- .github/actions/tests/action.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/actions/coverage/action.yml b/.github/actions/coverage/action.yml index cb258d0ab..7eb42229b 100644 --- a/.github/actions/coverage/action.yml +++ b/.github/actions/coverage/action.yml @@ -45,7 +45,7 @@ runs: # debug # - name: Archive code coverage final results - # uses: actions/upload-artifact@v2 + # uses: actions/upload-artifact@v4 # with: # name: coverage.xml # path: ./src/coverage.xml diff --git a/.github/actions/release/action.yml b/.github/actions/release/action.yml index 78f9501ad..948ba3d0b 100644 --- a/.github/actions/release/action.yml +++ b/.github/actions/release/action.yml @@ -40,7 +40,7 @@ runs: - uses: actions/setup-node@v2 with: node-version: 14.17.6 - # - uses: actions/cache@v2 + # - uses: actions/cache@v4 # with: # path: '**/node_modules' # key: ${{ runner.os }}-modules-${{ hashFiles('**/yarn.lock') }} diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 094d894ea..c4eb8a5cc 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -21,7 +21,7 @@ runs: python-version: ${{ inputs.python-version }} # cache: 'poetry' - name: Cache poetry - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: '~/.cache/pypoetry' key: ${{ runner.os }}-poetry-${{ inputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} From 1f08a1d90802571e223d73bf3739a7c3a03632d1 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 13:47:54 +0100 Subject: [PATCH 56/71] chore: handle sse cache invalidation for segments query --- .../services/permissions/sse_cache_invalidation.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git 
a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py index f531a307e..d2912225e 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py @@ -1,23 +1,28 @@ +from __future__ import annotations + import time from threading import Thread -from typing import Dict +from typing import TYPE_CHECKING, Dict, List import urllib3 from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.options import Options from sseclient import SSEClient +if TYPE_CHECKING: + from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService + class SSECacheInvalidation(Thread): - _MESSAGE__CACHE_KEYS: Dict[str, str] = { + _MESSAGE__CACHE_KEYS: Dict[str, List[str]] = { "refresh-users": ["forest.users"], "refresh-roles": ["forest.collections"], - "refresh-renderings": ["forest.collections", "forest.stats", "forest.scopes"], + "refresh-renderings": ["forest.collections", "forest.stats", "forest.scopes", "forest.segment_queries"], # "refresh-customizations": None, # work with nocode actions # TODO: add one for ip whitelist when server implement it } - def __init__(self, permission_service: "PermissionService", options: Options, *args, **kwargs): # noqa: F821 + def __init__(self, permission_service: "PermissionService", options: Options, *args, **kwargs): super().__init__(name="SSECacheInvalidationThread", daemon=True, *args, **kwargs) self.permission_service = permission_service self.options: Options = options From 21dc9269325f18de65e891d26db0d2ec11749317 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 13:55:25 +0100 Subject: [PATCH 57/71] chore: add link to doc --- .../forestadmin/datasource_django/datasource.py | 3 +-- src/datasource_django/tests/test_django_datasource.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index c8ae66964..02cdcdb67 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -38,8 +38,7 @@ def _handle_live_query_connections_param( ForestLogger.log( "info", f"You enabled live query as {live_query_connections} for django 'default' database." 
- " To use it over multiple databases, read the related documentation here: http://link.", - # TODO: link + " To use it over multiple databases, read the related documentation here: https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/django#enable-support-of-live-queries.", ) else: ret = live_query_connections diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index 02e66c036..0fb23ed90 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -84,11 +84,12 @@ def test_should_create_native_query_connection_to_default_if_string_is_set(self) def test_should_log_when_creating_connection_with_string_param_and_multiple_databases_are_set_up(self): with patch("forestadmin.datasource_django.datasource.ForestLogger.log") as log_fn: DjangoDatasource(live_query_connection="django") - # TODO: adapt error message log_fn.assert_any_call( "info", "You enabled live query as django for django 'default' database. " - "To use it over multiple databases, read the related documentation here: http://link.", + "To use it over multiple databases, read the related documentation here: " + "https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/" + "django#enable-support-of-live-queries.", ) def test_should_raise_if_connection_query_target_non_existent_database(self): From e4603bf931c5f968080fae39c661a78dccbda8ed Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 14:04:55 +0100 Subject: [PATCH 58/71] chore: fix linting --- .../forestadmin/datasource_django/datasource.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 02cdcdb67..2a943fcb0 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -38,7 +38,9 @@ def _handle_live_query_connections_param( ForestLogger.log( "info", f"You enabled live query as {live_query_connections} for django 'default' database." 
- " To use it over multiple databases, read the related documentation here: https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/django#enable-support-of-live-queries.", + " To use it over multiple databases, read the related documentation here: " + "https://docs.forestadmin.com/developer-guide-agents-python/" + "data-sources/provided-data-sources/django#enable-support-of-live-queries.", ) else: ret = live_query_connections From ffb52b283fc2b0943468e39554526eeef61316a4 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 15:52:23 +0100 Subject: [PATCH 59/71] chore: refactor permission for key rendering --- .../permissions/permission_service.py | 54 ++++++++----------- .../permissions/sse_cache_invalidation.py | 2 +- .../permissions/test_permission_service.py | 5 +- .../forestadmin/django_agent/views/index.py | 6 ++- src/django_agent/tests/test_http_routes.py | 18 ++++--- .../forestadmin/flask_agent/agent.py | 2 +- .../tests/test_flask_agent_blueprint.py | 3 +- 7 files changed, 45 insertions(+), 45 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index e941804b4..1c4542381 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -29,7 +29,7 @@ class PermissionService: def __init__(self, options: RoleOptions): self.options = options - self.cache: TTLCache[int, Any] = TTLCache(maxsize=256, ttl=options["permission_cache_duration"]) + self.cache: TTLCache[str, Any] = TTLCache(maxsize=256, ttl=options["permission_cache_duration"]) def invalidate_cache(self, key: str): if key in self.cache: @@ -60,11 +60,11 @@ async def can(self, caller: User, collection: Collection, action: str, allow_fet async def can_chart(self, request: RequestCollection) -> bool: hash_request = request.body["type"] + ":" + _hash_chart(request.body) - is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, False) + is_allowed = hash_request in (await self._get_rendering_data(request.user.rendering_id, False))["stats"] # Refetch if is_allowed is False: - is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, True) + is_allowed = hash_request in (await self._get_rendering_data(request.user.rendering_id, True))["stats"] # still not allowed - throw forbidden message if is_allowed is False: @@ -136,7 +136,7 @@ async def get_scope( caller: User, collection: Union[Collection, CollectionCustomizer], ) -> Optional[ConditionTree]: - permissions = await self._get_scope_and_team_data(caller.rendering_id) + permissions = await self._get_rendering_data(caller.rendering_id) scope = permissions["scopes"].get(collection.name) if scope is None: return None @@ -167,21 +167,9 @@ async def get_user_data(self, user_id: int): return self.cache["forest.users"][user_id] async def get_team(self, rendering_id: int): - permissions = await self._get_scope_and_team_data(rendering_id) + permissions = await self._get_rendering_data(rendering_id) return permissions["team"] - async def _get_chart_data(self, rendering_id: int, force_fetch: bool = False) -> Dict: - if force_fetch and "forest.stats" in self.cache: - del self.cache["forest.stats"] - - if "forest.stats" not in self.cache: - ForestLogger.log("debug", f"Loading rendering permissions for rendering 
{rendering_id}") - response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - - self._handle_rendering_permissions(response) - - return self.cache["forest.stats"] - async def _get_collection_permissions_data(self, force_fetch: bool = False): if force_fetch and "forest.collections" in self.cache: del self.cache["forest.collections"] @@ -199,39 +187,43 @@ async def _get_collection_permissions_data(self, force_fetch: bool = False): return self.cache["forest.collections"] - async def _get_scope_and_team_data(self, rendering_id: int): - if "forest.scopes" not in self.cache: + async def _get_rendering_data(self, rendering_id: int, force_fetch: bool = False): + if force_fetch and "forest.rendering" in self.cache: + del self.cache["forest.rendering"] + + if "forest.rendering" not in self.cache: response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) self._handle_rendering_permissions(response) - return self.cache["forest.scopes"] + return self.cache["forest.rendering"] async def _get_segment_queries(self, rendering_id: int, force_fetch: bool): - if force_fetch and "forest.segment_queries" in self.cache: - del self.cache["forest.segment_queries"] + if force_fetch and "forest.rendering" in self.cache: + del self.cache["forest.rendering"] - if "forest.segment_queries" not in self.cache: + if "forest.rendering" not in self.cache: response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) self._handle_rendering_permissions(response) - return self.cache["forest.segment_queries"] + return self.cache["forest.rendering"]["segment_queries"] def _handle_rendering_permissions(self, rendering_permissions): + rendering_cache = {} + # forest.stats stat_hash = [] for stat in rendering_permissions["stats"]: stat_hash.append(f'{stat["type"]}:{_hash_chart(stat)}') - self.cache["forest.stats"] = stat_hash + rendering_cache["stats"] = stat_hash # forest.scopes - data = { - "scopes": _decode_scope_permissions(rendering_permissions["collections"]), - "team": rendering_permissions["team"], - } - self.cache["forest.scopes"] = data + rendering_cache["scopes"] = _decode_scope_permissions(rendering_permissions["collections"]) + rendering_cache["team"] = rendering_permissions["team"] # forest.segment_queries - self.cache["forest.segment_queries"] = _decode_segment_query_permissions(rendering_permissions["collections"]) + rendering_cache["segment_queries"] = _decode_segment_query_permissions(rendering_permissions["collections"]) + + self.cache["forest.rendering"] = rendering_cache async def _find_action_from_endpoint( self, collection: Collection, get_params: Dict, http_method: str diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py index d2912225e..29dc247d6 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py @@ -17,7 +17,7 @@ class SSECacheInvalidation(Thread): _MESSAGE__CACHE_KEYS: Dict[str, List[str]] = { "refresh-users": ["forest.users"], "refresh-roles": ["forest.collections"], - "refresh-renderings": ["forest.collections", "forest.stats", "forest.scopes", "forest.segment_queries"], + "refresh-renderings": ["forest.collections", "forest.rendering"], # "refresh-customizations": None, # work with nocode actions # TODO: add one for ip whitelist when 
server implement it } diff --git a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py index 9f840d12e..a63508880 100644 --- a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py +++ b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py @@ -232,12 +232,11 @@ def test_invalidate_cache_should_delete_corresponding_key(self): self.assertNotIn("forest.users", self.permission_service.cache) def test_dont_call_api_when_something_is_cached(self): - http_patches: PatchHttpApiDict = self.mock_forest_http_api() http_patches: PatchHttpApiDict = self.mock_forest_http_api() http_mocks: MockHttpApiDict = {name: patch.start() for name, patch in http_patches.items()} - response_1 = self.loop.run_until_complete(self.permission_service._get_chart_data(1)) - response_2 = self.loop.run_until_complete(self.permission_service._get_chart_data(1)) + response_1 = self.loop.run_until_complete(self.permission_service._get_rendering_data(1)) + response_2 = self.loop.run_until_complete(self.permission_service._get_rendering_data(1)) self.assertEqual(response_1, response_2) response_1 = self.loop.run_until_complete(self.permission_service.get_user_data(1)) diff --git a/src/django_agent/forestadmin/django_agent/views/index.py b/src/django_agent/forestadmin/django_agent/views/index.py index 7e888e703..59cd970b2 100644 --- a/src/django_agent/forestadmin/django_agent/views/index.py +++ b/src/django_agent/forestadmin/django_agent/views/index.py @@ -10,5 +10,9 @@ async def index(request: HttpRequest): @transaction.non_atomic_requests async def scope_cache_invalidation(request: HttpRequest): - DjangoAgentApp.get_agent()._permission_service.invalidate_cache("forest.scopes") + DjangoAgentApp.get_agent()._permission_service.invalidate_cache("forest.rendering") return HttpResponse(status=204) + + +# This is so ugly... 
But django.views.decorators.csrf.csrf_exempt is not asyncio ready +scope_cache_invalidation.csrf_exempt = True diff --git a/src/django_agent/tests/test_http_routes.py b/src/django_agent/tests/test_http_routes.py index 86c57a5f1..7f6690d3f 100644 --- a/src/django_agent/tests/test_http_routes.py +++ b/src/django_agent/tests/test_http_routes.py @@ -95,12 +95,18 @@ def test_index(self): self.assertEqual(response.content, b"") def test_scope_cache_invalidation(self): - response = self.client.get( - f"/{self.conf_prefix}forest/scope-cache-invalidation", - HTTP_X_FORWARDED_FOR="179.114.131.49", - ) - self.assertEqual(response.status_code, 204) - self.assertEqual(response.content, b"") + with patch.object( + self.django_agent._permission_service, + "invalidate_cache", + spy=self.django_agent._permission_service.invalidate_cache, + ) as spy_invalidate: + response = self.client.get( + f"/{self.conf_prefix}forest/scope-cache-invalidation", + HTTP_X_FORWARDED_FOR="179.114.131.49", + ) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.content, b"") + spy_invalidate.assert_called_once_with("forest.rendering") class TestDjangoAgentAuthenticationRoutes(TestDjangoAgentRoutes): diff --git a/src/flask_agent/forestadmin/flask_agent/agent.py b/src/flask_agent/forestadmin/flask_agent/agent.py index ce87e0507..d3fbff658 100644 --- a/src/flask_agent/forestadmin/flask_agent/agent.py +++ b/src/flask_agent/forestadmin/flask_agent/agent.py @@ -210,7 +210,7 @@ async def csv_related(**_) -> FlaskResponse: # type: ignore @blueprint.route("/scope-cache-invalidation", methods=["POST"]) async def scope_cache_invalidation(**_) -> FlaskResponse: # type: ignore - agent._permission_service.invalidate_cache("forest.scopes") + agent._permission_service.invalidate_cache("forest.rendering") rsp = FlaskResponse(status=204) return rsp diff --git a/src/flask_agent/tests/test_flask_agent_blueprint.py b/src/flask_agent/tests/test_flask_agent_blueprint.py index 6f6bb2ba2..f7f235e86 100644 --- a/src/flask_agent/tests/test_flask_agent_blueprint.py +++ b/src/flask_agent/tests/test_flask_agent_blueprint.py @@ -279,6 +279,5 @@ def test_datasource_chart(self): def test_invalidate_cache(self): with patch.object(self.agent._permission_service, "invalidate_cache") as mocked_invalidate_cache: response = self.client.post("/forest/scope-cache-invalidation") - mocked_invalidate_cache.assert_called_with("forest.scopes") - assert response.status_code == 204 + mocked_invalidate_cache.assert_called_with("forest.rendering") assert response.status_code == 204 From d0657dea0cec2fa62441c82f61ce23282f53b4da Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 16:15:38 +0100 Subject: [PATCH 60/71] chore(ci): disable cache on poetry install --- .github/actions/tests/action.yml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index c4eb8a5cc..e14dacc07 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -20,15 +20,18 @@ runs: with: python-version: ${{ inputs.python-version }} # cache: 'poetry' - - name: Cache poetry - uses: actions/cache@v4 - with: - path: '~/.cache/pypoetry' - key: ${{ runner.os }}-poetry-${{ inputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} + # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. 
+ # - name: Cache poetry + # id: cache-poetry-install + # uses: actions/cache@v4 + # with: + # path: '~/.cache/pypoetry' + # key: ${{ runner.os }}-poetry-${{ inputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} run: poetry install --no-interaction --with test + # if: steps.cache-poetry-install.outputs.cache-hit != 'true' - name: Test with pytest shell: bash working-directory: ${{ inputs.current_package }} From 8e7a89501c2da30458e2b3508b318ead5d5579c8 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 16:16:29 +0100 Subject: [PATCH 61/71] chore(ci): try to improve poetry setup --- .github/actions/tests/action.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index e14dacc07..e49b52cf4 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,14 +12,15 @@ runs: steps: - uses: actions/checkout@v3 - name: Install poetry - uses: snok/install-poetry@v1 - with: - version: 1.7.1 + run: pipx install poetry + # uses: snok/install-poetry@v1 + # with: + # version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - # cache: 'poetry' + cache: 'poetry' # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. # - name: Cache poetry # id: cache-poetry-install From b4ab040f810eb41a68e54e5bd1dd5c062c262c27 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 16:20:21 +0100 Subject: [PATCH 62/71] chore(ci): restore old way to install poetry --- .github/actions/tests/action.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index e49b52cf4..2a09ee300 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,15 +12,15 @@ runs: steps: - uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry - # uses: snok/install-poetry@v1 - # with: - # version: 1.7.1 + # run: pipx install poetry + uses: snok/install-poetry@v1 + with: + version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - cache: 'poetry' + # cache: 'poetry' # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. 
# - name: Cache poetry # id: cache-poetry-install From 66f437d1c57d722c10a5b20cc1c203921e1a5c6d Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 16:22:11 +0100 Subject: [PATCH 63/71] chore(ci): another try to improve poetry setup --- .github/actions/tests/action.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 2a09ee300..dac480d17 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -12,15 +12,16 @@ runs: steps: - uses: actions/checkout@v3 - name: Install poetry - # run: pipx install poetry - uses: snok/install-poetry@v1 - with: - version: 1.7.1 + run: pipx install poetry + shell: bash + # uses: snok/install-poetry@v1 + # with: + # version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - # cache: 'poetry' + cache: 'poetry' # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. # - name: Cache poetry # id: cache-poetry-install From 62d81073e14c3b9626bdda86079f4f2cb6fd12fa Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Mon, 16 Dec 2024 16:27:54 +0100 Subject: [PATCH 64/71] chore(ci): disable setup python cache on poetry because of concurrency with other packages --- .github/actions/tests/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index dac480d17..c07f9ab62 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -21,7 +21,7 @@ runs: uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} - cache: 'poetry' + # cache: 'poetry' # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. 
# - name: Cache poetry # id: cache-poetry-install From d3f41b8220390a261b55a370487ac2e4caf835d7 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 10:25:21 +0100 Subject: [PATCH 65/71] chore: don't show stacktrace on 403 forbidden --- .../agent_toolkit/resources/collections/native_query.py | 4 +++- .../forestadmin/agent_toolkit/resources/collections/stats.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py index 2bcfe621d..601e4a3b8 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -12,7 +12,7 @@ from forestadmin.agent_toolkit.utils.sql_query_checker import SqlQueryChecker from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer -from forestadmin.datasource_toolkit.exceptions import BusinessError, UnprocessableError, ValidationError +from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError, UnprocessableError, ValidationError from forestadmin.datasource_toolkit.interfaces.chart import ( Chart, DistributionChart, @@ -45,6 +45,8 @@ def __init__( async def dispatch(self, request: Request, method_name: Literal["native_query"]) -> Response: try: return await self.handle_native_query(request) # type:ignore + except ForbiddenError as exc: + return HttpResponseBuilder.build_client_error_response([exc]) except Exception as exc: ForestLogger.log("exception", exc) return HttpResponseBuilder.build_client_error_response([exc]) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py index 1ac2fa581..816d6c4c2 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py @@ -15,7 +15,7 @@ from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.utils.context import FileResponse, HttpResponseBuilder, Request, RequestMethod, Response -from forestadmin.datasource_toolkit.exceptions import ForestException +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, ForestException from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch @@ -63,6 +63,8 @@ async def dispatch( ) try: return await meth(request_collection) + except ForbiddenError as exc: + return HttpResponseBuilder.build_client_error_response([exc]) except Exception as exc: ForestLogger.log("exception", exc) return HttpResponseBuilder.build_client_error_response([exc]) From 35949f926a4846d222e257362f551a64404bd13b Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 14:57:20 +0100 Subject: [PATCH 66/71] chore: update example --- 
src/_example/django/django_demo/app/forest/customer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_example/django/django_demo/app/forest/customer.py b/src/_example/django/django_demo/app/forest/customer.py index 1ec1e02c3..3a32ca7cb 100644 --- a/src/_example/django/django_demo/app/forest/customer.py +++ b/src/_example/django/django_demo/app/forest/customer.py @@ -69,7 +69,7 @@ async def get_customer_spending_values(records: List[RecordsDataAlias], context: def customer_full_name() -> ComputedDefinition: async def _get_customer_fullname_values(records: List[RecordsDataAlias], context: CollectionCustomizationContext): - return [f"{record['first_name']} - {record['last_name']}" for record in records] + return [f"{record.get('first_name', '')} - {record.get('last_name', '')}" for record in records] return { "column_type": "String", From 36de6a9b71888dee028a7b0ff809ea267f2a4de5 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Tue, 17 Dec 2024 17:02:47 +0100 Subject: [PATCH 67/71] chore: theses are in a different PR --- .github/actions/changes/action.yml | 4 ---- .github/actions/coverage/action.yml | 2 +- .github/actions/release/action.yml | 2 +- .github/actions/tests/action.yml | 11 ----------- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/.github/actions/changes/action.yml b/.github/actions/changes/action.yml index 9b831c861..e2975f68f 100644 --- a/.github/actions/changes/action.yml +++ b/.github/actions/changes/action.yml @@ -10,11 +10,7 @@ runs: - 'src/datasource_toolkit/**' ./src/datasource_sqlalchemy: - 'src/datasource_sqlalchemy/**' - ./src/datasource_django: - - 'src/datasource_django/**' ./src/agent_toolkit: - 'src/agent_toolkit/**' ./src/flask_agent: - 'src/flask_agent/**' - ./src/django_agent: - - 'src/django_agent/**' diff --git a/.github/actions/coverage/action.yml b/.github/actions/coverage/action.yml index 7eb42229b..cb258d0ab 100644 --- a/.github/actions/coverage/action.yml +++ b/.github/actions/coverage/action.yml @@ -45,7 +45,7 @@ runs: # debug # - name: Archive code coverage final results - # uses: actions/upload-artifact@v4 + # uses: actions/upload-artifact@v2 # with: # name: coverage.xml # path: ./src/coverage.xml diff --git a/.github/actions/release/action.yml b/.github/actions/release/action.yml index 948ba3d0b..78f9501ad 100644 --- a/.github/actions/release/action.yml +++ b/.github/actions/release/action.yml @@ -40,7 +40,7 @@ runs: - uses: actions/setup-node@v2 with: node-version: 14.17.6 - # - uses: actions/cache@v4 + # - uses: actions/cache@v2 # with: # path: '**/node_modules' # key: ${{ runner.os }}-modules-${{ hashFiles('**/yarn.lock') }} diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index c07f9ab62..0aeb8f0a9 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -14,26 +14,15 @@ runs: - name: Install poetry run: pipx install poetry shell: bash - # uses: snok/install-poetry@v1 - # with: - # version: 1.7.1 - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} # cache: 'poetry' - # 315MB of cache * 6 sub package * 5 python versions ~= 10GB. 
- # - name: Cache poetry - # id: cache-poetry-install - # uses: actions/cache@v4 - # with: - # path: '~/.cache/pypoetry' - # key: ${{ runner.os }}-poetry-${{ inputs.python-version }}-${{ inputs.current_package }}-${{ hashFiles('**/poetry.lock') }} - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} run: poetry install --no-interaction --with test - # if: steps.cache-poetry-install.outputs.cache-hit != 'true' - name: Test with pytest shell: bash working-directory: ${{ inputs.current_package }} From 4b9fdc826f40d2d495a03f8fcf62ba43a6f8ce4b Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Wed, 18 Dec 2024 14:35:41 +0100 Subject: [PATCH 68/71] chore: permission refactor wasn't complete --- .../permissions/permission_service.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index 1c4542381..a9f0f020f 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -81,13 +81,15 @@ async def can_live_query_segment(self, request: RequestCollection) -> bool: live_query = request.query["segmentQuery"] connection_name = request.query["connectionName"] hash_live_query = _dict_hash({"query": live_query, "connection_name": connection_name}) - is_allowed = hash_live_query in (await self._get_segment_queries(request.user.rendering_id, False)).get( - request.collection.name - ) + is_allowed = hash_live_query in ( + (await self._get_rendering_data(request.user.rendering_id, False))["segment_queries"] + ).get(request.collection.name) # Refetch if is_allowed is False: - is_allowed = hash_live_query in await self._get_segment_queries(request.user.rendering_id, True) + is_allowed = ( + hash_live_query in (await self._get_rendering_data(request.user.rendering_id, True))["segment_queries"] + ) # still not allowed - throw forbidden message if is_allowed is False: @@ -197,16 +199,6 @@ async def _get_rendering_data(self, rendering_id: int, force_fetch: bool = False return self.cache["forest.rendering"] - async def _get_segment_queries(self, rendering_id: int, force_fetch: bool): - if force_fetch and "forest.rendering" in self.cache: - del self.cache["forest.rendering"] - - if "forest.rendering" not in self.cache: - response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - self._handle_rendering_permissions(response) - - return self.cache["forest.rendering"]["segment_queries"] - def _handle_rendering_permissions(self, rendering_permissions): rendering_cache = {} From 35e47778c000a155ea3db923634ce046cfa7c211 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 19 Dec 2024 11:29:44 +0100 Subject: [PATCH 69/71] chore: fix for review --- .../datasource_django/datasource.py | 22 +++++++++++-------- .../decorators/decorator_stack.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 2a943fcb0..5a3b12015 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -4,7 +4,6 @@ from asgiref.sync import sync_to_async from django.apps import apps from 
django.db import connections -from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoDatasource @@ -18,6 +17,19 @@ def __init__( support_polymorphic_relations: bool = False, live_query_connection: Optional[Union[str, Dict[str, str]]] = None, ) -> None: + """ Create a django datasource. + More information here: + https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/django + + + Args: + support_polymorphic_relations (bool, optional, default to `False`): Enable introspection over \ + polymorphic relation (AKA GenericForeignKey). Defaults to False. + live_query_connection (Union[str, Dict[str, str]], optional, default to `None`): Set a connectionName to \ + use live queries. If a string is given, this connection will be map to django 'default' database. \ + Otherwise, you must use a dict `{'connectionName': 'DjangoDatabaseName'}`. \ + None doesn't enable this feature. + """ self._django_live_query_connections: Dict[str, str] = self._handle_live_query_connections_param( live_query_connection ) @@ -34,14 +46,6 @@ def _handle_live_query_connections_param( if isinstance(live_query_connections, str): ret = {live_query_connections: "default"} - if len(connections.all()) > 1: - ForestLogger.log( - "info", - f"You enabled live query as {live_query_connections} for django 'default' database." - " To use it over multiple databases, read the related documentation here: " - "https://docs.forestadmin.com/developer-guide-agents-python/" - "data-sources/provided-data-sources/django#enable-support-of-live-queries.", - ) else: ret = live_query_connections diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py index 760a708e8..5a1d297fd 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/decorator_stack.py @@ -28,7 +28,7 @@ class DecoratorStack: def __init__(self, datasource: Datasource) -> None: self._customizations: List = list() - last = self.base_datasource = datasource + last = datasource # Step 0: Do not query datasource when we know the result with yield an empty set. 
last = self.override = DatasourceDecorator(last, OverrideCollectionDecorator) # type: ignore From df361312b3a2b6091ce8f3f097f40e2c23f1dd81 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 19 Dec 2024 11:46:12 +0100 Subject: [PATCH 70/71] chore: remove conflict --- .github/actions/tests/action.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/actions/tests/action.yml b/.github/actions/tests/action.yml index 0aeb8f0a9..6bdb31b3b 100644 --- a/.github/actions/tests/action.yml +++ b/.github/actions/tests/action.yml @@ -11,14 +11,15 @@ runs: using: "composite" steps: - uses: actions/checkout@v3 - - name: Install poetry - run: pipx install poetry - shell: bash - name: Set up Python ${{ inputs.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ inputs.python-version }} # cache: 'poetry' + - name: Install poetry + uses: snok/install-poetry@v1 + with: + version: 1.7.1 - name: Install package dependencies shell: bash working-directory: ${{ inputs.current_package }} From cf138b9305c4a616e77930f6345de19b714b2741 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 19 Dec 2024 14:07:43 +0100 Subject: [PATCH 71/71] chore: remove test related to removed log --- src/datasource_django/tests/test_django_datasource.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index 0fb23ed90..ae0b98ac0 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -81,17 +81,6 @@ def test_should_create_native_query_connection_to_default_if_string_is_set(self) self.assertEqual(ds.get_native_query_connections(), ["django"]) self.assertEqual(ds._django_live_query_connections["django"], "default") - def test_should_log_when_creating_connection_with_string_param_and_multiple_databases_are_set_up(self): - with patch("forestadmin.datasource_django.datasource.ForestLogger.log") as log_fn: - DjangoDatasource(live_query_connection="django") - log_fn.assert_any_call( - "info", - "You enabled live query as django for django 'default' database. " - "To use it over multiple databases, read the related documentation here: " - "https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/" - "django#enable-support-of-live-queries.", - ) - def test_should_raise_if_connection_query_target_non_existent_database(self): self.assertRaisesRegex( DjangoDatasourceException,