diff --git a/src/_example/django/django_demo/.forestadmin-schema.json b/src/_example/django/django_demo/.forestadmin-schema.json index 6503219e8..b2850a5df 100644 --- a/src/_example/django/django_demo/.forestadmin-schema.json +++ b/src/_example/django/django_demo/.forestadmin-schema.json @@ -2726,7 +2726,7 @@ "enums": null, "field": "body", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -2740,7 +2740,7 @@ "enums": null, "field": "email", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -2768,7 +2768,7 @@ "enums": null, "field": "name", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -3486,7 +3486,7 @@ "enums": null, "field": "body", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, @@ -3514,7 +3514,7 @@ "enums": null, "field": "title", "inverseOf": null, - "isFilterable": false, + "isFilterable": true, "isPrimaryKey": false, "isReadOnly": false, "isRequired": false, diff --git a/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py b/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py index f6f9bfd04..dcf49f57d 100644 --- a/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py +++ b/src/_example/django/django_demo/app/forest/custom_datasources/typicode.py @@ -107,14 +107,17 @@ def __init__(self, datasource: Datasource[Self]): "name": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "email": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "body": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, } ) @@ -140,10 
+143,12 @@ def __init__(self, datasource: Datasource[Self]): "title": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, "body": { "type": FieldType.COLUMN, "column_type": "String", + "filter_operators": set([Operator.EQUAL]), }, } ) diff --git a/src/_example/django/django_demo/app/forest/customer.py b/src/_example/django/django_demo/app/forest/customer.py index 1ec1e02c3..3a32ca7cb 100644 --- a/src/_example/django/django_demo/app/forest/customer.py +++ b/src/_example/django/django_demo/app/forest/customer.py @@ -69,7 +69,7 @@ async def get_customer_spending_values(records: List[RecordsDataAlias], context: def customer_full_name() -> ComputedDefinition: async def _get_customer_fullname_values(records: List[RecordsDataAlias], context: CollectionCustomizationContext): - return [f"{record['first_name']} - {record['last_name']}" for record in records] + return [f"{record.get('first_name', '')} - {record.get('last_name', '')}" for record in records] return { "column_type": "String", diff --git a/src/_example/django/django_demo/app/forest_admin.py b/src/_example/django/django_demo/app/forest_admin.py index 6d43fbc27..76dc72bf9 100644 --- a/src/_example/django/django_demo/app/forest_admin.py +++ b/src/_example/django/django_demo/app/forest_admin.py @@ -50,9 +50,16 @@ def customize_forest(agent: DjangoAgent): # customize_forest_logging() - agent.add_datasource(DjangoDatasource(support_polymorphic_relations=True)) + + agent.add_datasource( + DjangoDatasource( + support_polymorphic_relations=True, live_query_connection={"django": "default", "dj_sqlachemy": "other"} + ) + ) agent.add_datasource(TypicodeDatasource()) - agent.add_datasource(SqlAlchemyDatasource(Base, DB_URI)) + agent.add_datasource( + SqlAlchemyDatasource(Base, DB_URI, live_query_connection="sqlalchemy"), + ) agent.customize_collection("address").add_segment("France", segment_addr_fr("address")) agent.customize_collection("app_address").add_segment("France", 
segment_addr_fr("app_address")) diff --git a/src/_example/django/django_demo/django_demo/settings.py b/src/_example/django/django_demo/django_demo/settings.py index 6018be425..2efd1ea59 100644 --- a/src/_example/django/django_demo/django_demo/settings.py +++ b/src/_example/django/django_demo/django_demo/settings.py @@ -123,10 +123,10 @@ "NAME": os.path.join(BASE_DIR, "db.sqlite3"), "ATOMIC_REQUESTS": True, }, - # "other": { - # "ENGINE": "django.db.backends.sqlite3", - # "NAME": os.path.join(BASE_DIR, "db_flask_example.sqlite"), - # }, + "other": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db_sqlalchemy.sql"), + }, } DATABASE_ROUTERS = ["django_demo.db_router.DBRouter"] diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py index ccc4dcb88..d014a2031 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/agent.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/agent.py @@ -9,6 +9,7 @@ from forestadmin.agent_toolkit.resources.collections.charts_datasource import ChartsDatasourceResource from forestadmin.agent_toolkit.resources.collections.crud import CrudResource from forestadmin.agent_toolkit.resources.collections.crud_related import CrudRelatedResource +from forestadmin.agent_toolkit.resources.collections.native_query import NativeQueryResource from forestadmin.agent_toolkit.resources.collections.stats import StatsResource from forestadmin.agent_toolkit.resources.security.resources import Authentication from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService @@ -37,6 +38,7 @@ class Resources(TypedDict): actions: ActionResource collection_charts: ChartsCollectionResource datasource_charts: ChartsDatasourceResource + native_query: NativeQueryResource class Agent: @@ -73,10 +75,13 @@ def __del__(self): async def __mk_resources(self): self._resources: Resources = { "capabilities": CapabilitiesResource( - await 
self.customizer.get_datasource(), self._ip_white_list_service, self.options + self.customizer.composite_datasource, + self._ip_white_list_service, + self.options, ), "authentication": Authentication(self._ip_white_list_service, self.options), "crud": CrudResource( + self.customizer.composite_datasource, await self.customizer.get_datasource(), self._permission_service, self._ip_white_list_service, @@ -112,6 +117,13 @@ async def __mk_resources(self): self._ip_white_list_service, self.options, ), + "native_query": NativeQueryResource( + self.customizer.composite_datasource, + await self.customizer.get_datasource(), + self._permission_service, + self._ip_white_list_service, + self.options, + ), } async def get_resources(self): @@ -119,7 +131,9 @@ async def get_resources(self): await self.__mk_resources() return self._resources - def add_datasource(self, datasource: Datasource[BoundCollection], options: Optional[DataSourceOptions] = None): + def add_datasource( + self, datasource: Datasource[BoundCollection], options: Optional[DataSourceOptions] = None + ) -> Self: """Add a datasource Args: @@ -130,6 +144,7 @@ def add_datasource(self, datasource: Datasource[BoundCollection], options: Optio options = {} self.customizer.add_datasource(datasource, options) self._resources = None + return self def use(self, plugin: type, options: Optional[Dict] = {}) -> Self: """Load a plugin across all collections diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py index 6e6aa7680..997e8ba19 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/capabilities.py @@ -7,6 +7,7 @@ from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response from 
forestadmin.agent_toolkit.utils.forest_schema.generator_field import SchemaFieldGenerator +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer from forestadmin.datasource_toolkit.exceptions import BusinessError from forestadmin.datasource_toolkit.interfaces.fields import Column, is_column @@ -19,12 +20,12 @@ class CapabilitiesResource(IpWhitelistResource): def __init__( self, - datasource: DatasourceAlias, + composite_datasource: CompositeDatasource, ip_white_list_service: IpWhiteListService, options: Options, ): super().__init__(ip_white_list_service, options) - self.datasource = datasource + self.composite_datasource: CompositeDatasource = composite_datasource @ip_white_list async def dispatch(self, request: Request, method_name: LiteralMethod) -> Response: @@ -40,14 +41,18 @@ async def dispatch(self, request: Request, method_name: LiteralMethod) -> Respon @check_method(RequestMethod.POST) @authenticate async def capabilities(self, request: Request) -> Response: - ret = {"collections": []} + ret = {"collections": [], "nativeQueryConnections": []} requested_collections = request.body.get("collectionNames", []) for collection_name in requested_collections: ret["collections"].append(self._get_collection_capability(collection_name)) + + ret["nativeQueryConnections"] = [ + {"name": connection} for connection in self.composite_datasource.get_native_query_connections() + ] return HttpResponseBuilder.build_success_response(ret) def _get_collection_capability(self, collection_name: str) -> Dict[str, Any]: - collection = self.datasource.get_collection(collection_name) + collection = self.composite_datasource.get_collection(collection_name) fields = [] for field_name, field_schema in collection.schema["fields"].items(): if is_column(field_schema): diff --git 
a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py index 942872632..8a3a09e7d 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/crud.py @@ -1,8 +1,9 @@ import asyncio -from typing import Any, Awaitable, Dict, List, Literal, Tuple, Union, cast +from typing import Any, Awaitable, Dict, List, Literal, Optional, Tuple, Union, cast from uuid import UUID from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.options import Options from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource from forestadmin.agent_toolkit.resources.collections.decorators import ( authenticate, @@ -21,15 +22,21 @@ parse_timezone, ) from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException +from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService from forestadmin.agent_toolkit.services.serializers import add_search_metadata from forestadmin.agent_toolkit.services.serializers.json_api import JsonApiException, JsonApiSerializer from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response, User from forestadmin.agent_toolkit.utils.csv import Csv, CsvException from forestadmin.agent_toolkit.utils.id import unpack_id +from forestadmin.agent_toolkit.utils.sql_query_checker import SqlQueryChecker from forestadmin.datasource_toolkit.collections import Collection from forestadmin.datasource_toolkit.datasource_customizer.collection_customizer import 
CollectionCustomizer -from forestadmin.datasource_toolkit.datasources import DatasourceException -from forestadmin.datasource_toolkit.exceptions import ForbiddenError +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer +from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, NativeQueryException from forestadmin.datasource_toolkit.interfaces.fields import ( ManyToOne, OneToOne, @@ -55,6 +62,7 @@ from forestadmin.datasource_toolkit.interfaces.query.projections.factory import ProjectionFactory from forestadmin.datasource_toolkit.interfaces.records import CompositeIdAlias, RecordsDataAlias from forestadmin.datasource_toolkit.utils.collections import CollectionUtils +from forestadmin.datasource_toolkit.utils.records import RecordUtils from forestadmin.datasource_toolkit.utils.schema import SchemaUtils from forestadmin.datasource_toolkit.validations.field import FieldValidatorException from forestadmin.datasource_toolkit.validations.records import RecordValidator, RecordValidatorException @@ -62,7 +70,18 @@ LiteralMethod = Literal["list", "count", "add", "get", "delete_list", "csv"] -class CrudResource(BaseCollectionResource): +class CrudResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): + def __init__( + self, + datasource_composite: CompositeDatasource, + datasource: Union[Datasource, DatasourceCustomizer], + permission: PermissionService, + ip_white_list_service: IpWhiteListService, + options: Options, + ): + self._datasource_composite = datasource_composite + super().__init__(datasource, permission, ip_white_list_service, options) + @ip_white_list async def dispatch(self, request: Request, method_name: LiteralMethod) -> Response: method = getattr(self, method_name) @@ -159,6 +178,9 @@ 
async def list(self, request: RequestCollection) -> Response: scope_tree = await self.permission.get_scope(request.user, request.collection) try: paginated_filter = build_paginated_filter(request, scope_tree) + condition_tree = await self._handle_live_query_segment(request, paginated_filter.condition_tree) + paginated_filter = paginated_filter.override({"condition_tree": condition_tree}) + except FilterException as e: ForestLogger.log("exception", e) return HttpResponseBuilder.build_client_error_response([e]) @@ -191,6 +213,8 @@ async def csv(self, request: RequestCollection) -> Response: scope_tree = await self.permission.get_scope(request.user, request.collection) try: paginated_filter = build_paginated_filter(request, scope_tree) + condition_tree = await self._handle_live_query_segment(request, paginated_filter.condition_tree) + paginated_filter = paginated_filter.override({"condition_tree": condition_tree}) paginated_filter.page = None except FilterException as e: ForestLogger.log("exception", e) @@ -220,9 +244,12 @@ async def count(self, request: RequestCollection) -> Response: return HttpResponseBuilder.build_success_response({"meta": {"count": "deactivated"}}) scope_tree = await self.permission.get_scope(request.user, request.collection) - filter = build_filter(request, scope_tree) + filter_ = build_filter(request, scope_tree) + filter_ = filter_.override( + {"condition_tree": await self._handle_live_query_segment(request, filter_.condition_tree)} + ) aggregation = Aggregation({"operation": "Count"}) - result = await request.collection.aggregate(request.user, filter, aggregation) + result = await request.collection.aggregate(request.user, filter_, aggregation) try: count = result[0]["value"] except IndexError: @@ -424,3 +451,33 @@ def _serialize_records_with_relationships( schema = JsonApiSerializer.get(collection) return schema(projections=projection).dump(records if many is True else records[0], many=many) + + async def _handle_live_query_segment( + self, 
request: RequestCollection, condition_tree: Optional[ConditionTree] + ) -> Optional[ConditionTree]: + if request.query.get("segmentQuery") is not None: + if request.query.get("connectionName") in ["", None]: + raise NativeQueryException("Missing native query connection attribute") + + await self.permission.can_live_query_segment(request) + SqlQueryChecker.check_query(request.query["segmentQuery"]) + variables = await self.inject_and_get_context_variables_in_live_query_segment(request) + native_query_result = await self._datasource_composite.execute_native_query( + request.query["connectionName"], request.query["segmentQuery"], variables + ) + + pk_field = SchemaUtils.get_primary_keys(request.collection.schema)[0] + if len(native_query_result) > 0 and pk_field not in native_query_result[0]: + raise NativeQueryException(f"Live query must return the primary key field ('{pk_field}').") + + trees = [] + if condition_tree: + trees.append(condition_tree) + trees.append( + ConditionTreeFactory.match_ids( + request.collection.schema, + [RecordUtils.get_primary_key(request.collection.schema, r) for r in native_query_result], + ) + ) + return ConditionTreeFactory.intersect(trees) + return condition_tree diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py new file mode 100644 index 000000000..601e4a3b8 --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/native_query.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, List, Literal, Union +from uuid import uuid4 + +from forestadmin.agent_toolkit.forest_logger import ForestLogger +from forestadmin.agent_toolkit.options import Options +from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource +from forestadmin.agent_toolkit.resources.collections.decorators import authenticate, check_method, ip_white_list +from 
forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService +from forestadmin.agent_toolkit.utils.context import HttpResponseBuilder, Request, RequestMethod, Response +from forestadmin.agent_toolkit.utils.sql_query_checker import SqlQueryChecker +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasource_customizer.datasource_customizer import DatasourceCustomizer +from forestadmin.datasource_toolkit.exceptions import BusinessError, ForbiddenError, UnprocessableError, ValidationError +from forestadmin.datasource_toolkit.interfaces.chart import ( + Chart, + DistributionChart, + LeaderboardChart, + ObjectiveChart, + TimeBasedChart, + ValueChart, +) +from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, Datasource + +DatasourceAlias = Union[Datasource[BoundCollection], DatasourceCustomizer] + + +LiteralMethod = Literal["native_query"] + + +class NativeQueryResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): + def __init__( + self, + composite_datasource: CompositeDatasource, + datasource: DatasourceAlias, + permission: PermissionService, + ip_white_list_service: IpWhiteListService, + options: Options, + ): + super().__init__(datasource, permission, ip_white_list_service, options) + self.composite_datasource: CompositeDatasource = composite_datasource + + @ip_white_list + async def dispatch(self, request: Request, method_name: Literal["native_query"]) -> Response: + try: + return await self.handle_native_query(request) # type:ignore + except ForbiddenError as exc: + return HttpResponseBuilder.build_client_error_response([exc]) + except Exception as exc: + 
ForestLogger.log("exception", exc) + return HttpResponseBuilder.build_client_error_response([exc]) + + @check_method(RequestMethod.POST) + @authenticate + async def handle_native_query(self, request: Request) -> Response: + await self.permission.can_chart(request) + assert request.body is not None + if request.body.get("connectionName") is None: + raise BusinessError("Missing native query connection attribute") + if "query" not in request.body: + raise BusinessError("Missing 'query' in parameter.") + if request.body.get("type") not in ["Line", "Objective", "Leaderboard", "Pie", "Value"]: + raise ValidationError(f"Unknown chart type '{request.body.get('type')}'.") + + SqlQueryChecker.check_query(request.body["query"]) + variables = await self.inject_and_get_context_variables_in_live_query_chart(request) + native_query_results = await self.composite_datasource.execute_native_query( + request.body["connectionName"], request.body["query"], variables + ) + + chart_result: Chart + if request.body["type"] == "Line": + chart_result = self._handle_line_chart(native_query_results) + elif request.body["type"] == "Objective": + chart_result = self._handle_objective_chart(native_query_results) + elif request.body["type"] == "Leaderboard": + chart_result = self._handle_leaderboard_chart(native_query_results) + elif request.body["type"] == "Pie": + chart_result = self._handle_pie_chart(native_query_results) + elif request.body["type"] == "Value": + chart_result = self._handle_value_chart(native_query_results) + + return HttpResponseBuilder.build_success_response( + { + "data": { + "id": str(uuid4()), + "type": "stats", + "attributes": { + "value": chart_result, # type:ignore + }, + } + } + ) + + def _handle_line_chart(self, native_query_results: List[Dict[str, Any]]) -> TimeBasedChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Line' chart must return 'key' 
and 'value' fields.") + + return [{"label": res["key"], "values": {"value": res["value"]}} for res in native_query_results] + + def _handle_objective_chart(self, native_query_results: List[Dict[str, Any]]) -> ObjectiveChart: + if len(native_query_results) == 1: + if "value" not in native_query_results[0] or "objective" not in native_query_results[0]: + raise UnprocessableError( + "Native query for 'Objective' chart must return 'value' and 'objective' fields." + ) + else: + raise UnprocessableError("Native query for 'Objective' chart must return only one row.") + + return { + "value": native_query_results[0]["value"], + "objective": native_query_results[0]["objective"], + } + + def _handle_leaderboard_chart(self, native_query_results: List[Dict[str, Any]]) -> LeaderboardChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Leaderboard' chart must return 'key' and 'value' fields.") + + return [{"key": res["key"], "value": res["value"]} for res in native_query_results] + + def _handle_pie_chart(self, native_query_results: List[Dict[str, Any]]) -> DistributionChart: + if len(native_query_results) >= 1: + if "key" not in native_query_results[0] or "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Pie' chart must return 'key' and 'value' fields.") + + return [{"key": res["key"], "value": res["value"]} for res in native_query_results] + + def _handle_value_chart(self, native_query_results: List[Dict[str, Any]]) -> ValueChart: + if len(native_query_results) == 1: + if "value" not in native_query_results[0]: + raise UnprocessableError("Native query for 'Value' chart must return 'value' field.") + else: + raise UnprocessableError("Native query for 'Value' chart must return only one row.") + + ret = {"countCurrent": native_query_results[0]["value"]} + if "previous" in native_query_results[0]: + ret["countPrevious"] = 
native_query_results[0]["previous"] + return ret # type:ignore diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py index 598aa0eb0..37c6c0809 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/requests.py @@ -39,7 +39,7 @@ class RequestCollection(Request): def __init__( self, method: RequestMethod, - collection: Union[Collection, CollectionCustomizer], + collection: Collection, headers: Dict[str, str], client_ip: str, query: Dict[str, str], diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py index 465ef794a..816d6c4c2 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py @@ -1,4 +1,3 @@ -import json from datetime import date, datetime from typing import Any, Dict, List, Literal, Optional, Union, cast from uuid import uuid1 @@ -14,12 +13,10 @@ ) from forestadmin.agent_toolkit.resources.collections.filter import build_filter from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection, RequestCollectionException +from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.utils.context import FileResponse, HttpResponseBuilder, Request, RequestMethod, Response -from forestadmin.agent_toolkit.utils.context_variable_injector import ContextVariableInjector -from forestadmin.agent_toolkit.utils.context_variable_instantiator import ContextVariablesInstantiator -from forestadmin.datasource_toolkit.exceptions import ForestException +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, 
ForestException from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.factory import ConditionTreeFactory from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf @@ -28,7 +25,7 @@ from forestadmin.datasource_toolkit.utils.schema import SchemaUtils -class StatsResource(BaseCollectionResource): +class StatsResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): FREQUENCIES = {"Day": "d", "Week": "W-MON", "Month": "BMS", "Year": "BYS"} FORMAT = {"Day": "%d/%m/%Y", "Week": "W%V-%Y", "Month": "%b %Y", "Year": "%Y"} @@ -66,6 +63,8 @@ async def dispatch( ) try: return await meth(request_collection) + except ForbiddenError as exc: + return HttpResponseBuilder.build_client_error_response([exc]) except Exception as exc: ForestLogger.log("exception", exc) return HttpResponseBuilder.build_client_error_response([exc]) @@ -234,7 +233,7 @@ def _use_interval_res(tree: ConditionTree) -> None: async def _get_filter(self, request: RequestCollection) -> Filter: scope_tree = await self.permission.get_scope(request.user, request.collection) - await self.__inject_context_variables(request) + await self.inject_context_variables_in_filter(request) return build_filter(request, scope_tree) async def _compute_value(self, request: RequestCollection, filter: Filter) -> int: @@ -246,27 +245,3 @@ async def _compute_value(self, request: RequestCollection, filter: Filter) -> in if len(rows): res = int(rows[0]["value"]) return res - - async def __inject_context_variables(self, request: RequestCollection): - context_variables_dct = request.body.pop("contextVariables", {}) - if request.body.get("filter") is None: - return - - 
context_variables = await ContextVariablesInstantiator.build_context_variables( - request.user, context_variables_dct, self.permission - ) - condition_tree = request.body["filter"] - if isinstance(request.body["filter"], str): - condition_tree = json.loads(condition_tree) - - condition_tree = ConditionTreeFactory.from_plain_object(condition_tree) - - injected_filter: ConditionTree = condition_tree.replace( - lambda leaf: ConditionTreeLeaf( - leaf.field, - leaf.operator, - ContextVariableInjector.inject_context_in_value(leaf.value, context_variables), - ) - ) - - request.body["filter"] = injected_filter.to_plain_object() diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py new file mode 100644 index 000000000..d6c6b73f1 --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/context_variable_injector_mixin.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Dict + +from forestadmin.agent_toolkit.utils.context_variable_injector import ContextVariableInjector +from forestadmin.agent_toolkit.utils.context_variable_instantiator import ContextVariablesInstantiator +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.factory import ConditionTreeFactory +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf + +if TYPE_CHECKING: + from forestadmin.agent_toolkit.resources.collections.requests import RequestCollection + from forestadmin.agent_toolkit.utils.context import Request + + +class ContextVariableInjectorResourceMixin: + async def inject_context_variables_in_filter(self, request: "RequestCollection"): + context_variables_dct = request.body.pop("contextVariables", {}) + if 
request.body.get("filter") is None: + return + + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + condition_tree = request.body["filter"] + if isinstance(request.body["filter"], str): + condition_tree = json.loads(condition_tree) + + condition_tree = ConditionTreeFactory.from_plain_object(condition_tree) + + injected_filter: ConditionTree = condition_tree.replace( + lambda leaf: ConditionTreeLeaf( + leaf.field, + leaf.operator, + ContextVariableInjector.inject_context_in_value(leaf.value, context_variables), + ) + ) + + request.body["filter"] = injected_filter.to_plain_object() + + async def inject_and_get_context_variables_in_live_query_segment( + self, request: "RequestCollection" + ) -> Dict[str, str]: + context_variables_dct = request.query.pop("contextVariables", {}) + + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + + request.query["segmentQuery"], vars = ContextVariableInjector.format_query_and_get_vars( + request.query["segmentQuery"], context_variables + ) + return vars + + async def inject_and_get_context_variables_in_live_query_chart(self, request: "Request") -> Dict[str, str]: + context_variables_dct = request.body.get("contextVariables", {}) + context_variables = await ContextVariablesInstantiator.build_context_variables( + request.user, context_variables_dct, self.permission + ) + + request.body["query"], vars = ContextVariableInjector.format_query_and_get_vars( + request.body["query"], context_variables + ) + return vars diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py index 78ffbfc0f..a9f0f020f 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py +++ 
b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permission_service.py @@ -8,6 +8,8 @@ _decode_actions_permissions, _decode_crud_permissions, _decode_scope_permissions, + _decode_segment_query_permissions, + _dict_hash, _hash_chart, ) from forestadmin.agent_toolkit.services.permissions.smart_actions_checker import SmartActionChecker @@ -27,7 +29,7 @@ class PermissionService: def __init__(self, options: RoleOptions): self.options = options - self.cache: TTLCache[int, Any] = TTLCache(maxsize=256, ttl=options["permission_cache_duration"]) + self.cache: TTLCache[str, Any] = TTLCache(maxsize=256, ttl=options["permission_cache_duration"]) def invalidate_cache(self, key: str): if key in self.cache: @@ -58,11 +60,11 @@ async def can(self, caller: User, collection: Collection, action: str, allow_fet async def can_chart(self, request: RequestCollection) -> bool: hash_request = request.body["type"] + ":" + _hash_chart(request.body) - is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, False) + is_allowed = hash_request in (await self._get_rendering_data(request.user.rendering_id, False))["stats"] # Refetch if is_allowed is False: - is_allowed = hash_request in await self._get_chart_data(request.user.rendering_id, True) + is_allowed = hash_request in (await self._get_rendering_data(request.user.rendering_id, True))["stats"] # still not allowed - throw forbidden message if is_allowed is False: @@ -75,6 +77,33 @@ ) return is_allowed + async def can_live_query_segment(self, request: RequestCollection) -> bool: + live_query = request.query["segmentQuery"] + connection_name = request.query["connectionName"] + hash_live_query = _dict_hash({"query": live_query, "connection_name": connection_name}) + is_allowed = hash_live_query in ( + (await self._get_rendering_data(request.user.rendering_id, False))["segment_queries"] + ).get(request.collection.name, []) + + # Refetch + if 
is_allowed is False: + is_allowed = ( + hash_live_query in (await self._get_rendering_data(request.user.rendering_id, True))["segment_queries"].get(request.collection.name, []) + ) + + # still not allowed - throw forbidden message + if is_allowed is False: + ForestLogger.log( + "debug", + f"User {request.user.user_id} cannot retrieve segment queries on rendering {request.user.rendering_id}", + ) + raise ForbiddenError("You don't have permission to access this segment query.") + ForestLogger.log( + "debug", + f"User {request.user.user_id} can retrieve segment queries on rendering {request.user.rendering_id}", + ) + return is_allowed + async def can_smart_action( self, request: RequestCollection, collection: Collection, filter_: Filter, allow_fetch: bool = True ): @@ -109,7 +138,7 @@ async def get_scope( caller: User, collection: Union[Collection, CollectionCustomizer], ) -> Optional[ConditionTree]: - permissions = await self._get_scope_and_team_data(caller.rendering_id) + permissions = await self._get_rendering_data(caller.rendering_id) scope = permissions["scopes"].get(collection.name) if scope is None: return None @@ -140,24 +169,9 @@ async def get_user_data(self, user_id: int): return self.cache["forest.users"][user_id] async def get_team(self, rendering_id: int): - permissions = await self._get_scope_and_team_data(rendering_id) + permissions = await self._get_rendering_data(rendering_id) return permissions["team"] - async def _get_chart_data(self, rendering_id: int, force_fetch: bool = False) -> Dict: - if force_fetch and "forest.stats" in self.cache: - del self.cache["forest.stats"] - - if "forest.stats" not in self.cache: - ForestLogger.log("debug", f"Loading rendering permissions for rendering {rendering_id}") - response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - - stat_hash = [] - for stat in response["stats"]: - stat_hash.append(f'{stat["type"]}:{_hash_chart(stat)}') - self.cache["forest.stats"] = stat_hash - - return self.cache["forest.stats"] - async 
def _get_collection_permissions_data(self, force_fetch: bool = False): if force_fetch and "forest.collections" in self.cache: del self.cache["forest.collections"] @@ -175,14 +189,33 @@ async def _get_collection_permissions_data(self, force_fetch: bool = False): return self.cache["forest.collections"] - async def _get_scope_and_team_data(self, rendering_id: int): - if "forest.scopes" not in self.cache: + async def _get_rendering_data(self, rendering_id: int, force_fetch: bool = False): + if force_fetch and "forest.rendering" in self.cache: + del self.cache["forest.rendering"] + + if "forest.rendering" not in self.cache: response = await ForestHttpApi.get_rendering_permissions(rendering_id, self.options) - data = {"scopes": _decode_scope_permissions(response["collections"]), "team": response["team"]} + self._handle_rendering_permissions(response) + + return self.cache["forest.rendering"] + + def _handle_rendering_permissions(self, rendering_permissions): + rendering_cache = {} + + # forest.stats + stat_hash = [] + for stat in rendering_permissions["stats"]: + stat_hash.append(f'{stat["type"]}:{_hash_chart(stat)}') + rendering_cache["stats"] = stat_hash + + # forest.scopes + rendering_cache["scopes"] = _decode_scope_permissions(rendering_permissions["collections"]) + rendering_cache["team"] = rendering_permissions["team"] - self.cache["forest.scopes"] = data + # forest.segment_queries + rendering_cache["segment_queries"] = _decode_segment_query_permissions(rendering_permissions["collections"]) - return self.cache["forest.scopes"] + self.cache["forest.rendering"] = rendering_cache async def _find_action_from_endpoint( self, collection: Collection, get_params: Dict, http_method: str diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py index e86ccda40..259f0f40d 100644 --- 
a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_functions.py @@ -11,6 +11,22 @@ ################## +def _decode_segment_query_permissions(raw_permission: Dict[Any, Any]): + segment_queries = {} + for collection_name, value in raw_permission.items(): + segment_queries[collection_name] = [] + for segment_query in value.get("liveQuerySegments", []): + segment_queries[collection_name].append( + _dict_hash( + { + "query": segment_query["query"], + "connection_name": segment_query["connectionName"], + } + ) + ) + return segment_queries + + def _decode_scope_permissions(raw_permission: Dict[Any, Any]) -> Dict[str, ConditionTree]: scopes = {} for collection_name, value in raw_permission.items(): @@ -57,6 +73,7 @@ def _dict_hash(data: Dict[Any, Any]) -> str: def _hash_chart(chart: Dict[Any, Any]) -> str: known_chart_keys = [ + "connectionName", "type", "apiRoute", "smartRoute", diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py index 5163890b9..34117a132 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/permissions_types.py @@ -3,6 +3,7 @@ from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch +# this file is never imported class PermissionBody(TypedDict): actions: Set[str] actions_by_user: Dict[str, Set[int]] diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py index f531a307e..29dc247d6 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py +++ 
b/src/agent_toolkit/forestadmin/agent_toolkit/services/permissions/sse_cache_invalidation.py @@ -1,23 +1,28 @@ +from __future__ import annotations + import time from threading import Thread -from typing import Dict +from typing import TYPE_CHECKING, Dict, List import urllib3 from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.options import Options from sseclient import SSEClient +if TYPE_CHECKING: + from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService + class SSECacheInvalidation(Thread): - _MESSAGE__CACHE_KEYS: Dict[str, str] = { + _MESSAGE__CACHE_KEYS: Dict[str, List[str]] = { "refresh-users": ["forest.users"], "refresh-roles": ["forest.collections"], - "refresh-renderings": ["forest.collections", "forest.stats", "forest.scopes"], + "refresh-renderings": ["forest.collections", "forest.rendering"], # "refresh-customizations": None, # work with nocode actions # TODO: add one for ip whitelist when server implement it } - def __init__(self, permission_service: "PermissionService", options: Options, *args, **kwargs): # noqa: F821 + def __init__(self, permission_service: "PermissionService", options: Options, *args, **kwargs): super().__init__(name="SSECacheInvalidationThread", daemon=True, *args, **kwargs) self.permission_service = permission_service self.options: Options = options diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py index ed3ac7a4a..fbadca4ee 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context.py @@ -150,6 +150,8 @@ def _get_error_status(error: Exception): return 403 if isinstance(error, UnprocessableError): return 422 + if isinstance(error, BusinessError): + return 400 if isinstance(error, HTTPError): return error.code diff --git 
a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py index 35de62534..2bcc3c076 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/context_variable_injector.py @@ -1,5 +1,5 @@ import re -from typing import Optional +from typing import Dict, Optional, Tuple from forestadmin.agent_toolkit.utils.context_variables import ContextVariables from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree @@ -29,3 +29,22 @@ def inject_context_in_value(value, context_variable: ContextVariables): return value return re.sub(r"{{([^}]+)}}", lambda match: str(context_variable.get_value(match.group(1))), value) + + @staticmethod + def format_query_and_get_vars(value: str, context_variable: ContextVariables) -> Tuple[str, Dict[str, str]]: + if not isinstance(value, str): + return value + variables = {} + # to allow datasources to rework variables: + # - '%' are replaced by '\%' + # - '{{var}}' are replaced by '%(var)s' + # - '.' 
in vars are replaced by '__', and also in the returned mapping + # - and the mapping of vars is returned + + def _match(match): + variables[match.group(1).replace(".", "__")] = context_variable.get_value(match.group(1)) + return f"%({match.group(1).replace('.', '__')})s" + + ret = re.sub(r"{{([^}]+)}}", _match, value.replace("%", "\\%")) + + return (ret, variables) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py new file mode 100644 index 000000000..975f3c1ad --- /dev/null +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/sql_query_checker.py @@ -0,0 +1,37 @@ +import re + +from forestadmin.datasource_toolkit.exceptions import NativeQueryException + + +class EmptySQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("You cannot execute an empty SQL query.") + + +class ChainedSQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("You cannot chain SQL queries.") + + +class NonSelectSQLQueryException(NativeQueryException): + def __init__(self, *args: object) -> None: + super().__init__("Only SELECT queries are allowed.") + + +class SqlQueryChecker: + QUERY_SELECT = re.compile(r"^SELECT\s(.|\n)*FROM\s(.|\n)*$", re.IGNORECASE) + + @staticmethod + def check_query(input_query: str) -> bool: + input_query_trimmed = input_query.strip() + + if len(input_query_trimmed) == 0: + raise EmptySQLQueryException() + + if ";" in input_query_trimmed and input_query_trimmed.index(";") != len(input_query_trimmed) - 1: + raise ChainedSQLQueryException() + + if not SqlQueryChecker.QUERY_SELECT.match(input_query_trimmed): + raise NonSelectSQLQueryException() + + return True diff --git a/src/agent_toolkit/tests/resources/collections/test_crud.py b/src/agent_toolkit/tests/resources/collections/test_crud.py index c167d6a8e..70a0abe89 100644 --- 
a/src/agent_toolkit/tests/resources/collections/test_crud.py +++ b/src/agent_toolkit/tests/resources/collections/test_crud.py @@ -25,27 +25,31 @@ from forestadmin.agent_toolkit.utils.context import Request, RequestMethod, User from forestadmin.agent_toolkit.utils.csv import CsvException from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasources import Datasource, DatasourceException -from forestadmin.datasource_toolkit.exceptions import ValidationError +from forestadmin.datasource_toolkit.exceptions import ForbiddenError, NativeQueryException, ValidationError from forestadmin.datasource_toolkit.interfaces.fields import FieldType, Operator, PrimitiveType +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.projections import Projection from forestadmin.datasource_toolkit.validations.records import RecordValidatorException +FAKE_USER = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, +) + def authenticate_mock(fn): async def wrapped2(self, request): - request.user = User( - rendering_id=1, - user_id=1, - tags={}, - email="dummy@user.fr", - first_name="dummy", - last_name="user", - team="operational", - timezone=zoneinfo.ZoneInfo("Europe/Paris"), - request={"ip": "127.0.0.1"}, - ) + request.user = FAKE_USER return await fn(self, request) @@ -226,7 +230,8 @@ def setUpClass(cls) -> None: is_production=False, ) # cls.datasource = Mock(Datasource) - cls.datasource = Datasource() + cls.datasource = Datasource(["db_connection"]) + 
cls.datasource_composite = CompositeDatasource() cls.datasource.get_collection = lambda x: cls.datasource._collections[x] cls._create_collections() cls.datasource._collections = { @@ -236,6 +241,7 @@ def setUpClass(cls) -> None: "product": cls.collection_product, "tag": cls.collection_tag, } + cls.datasource_composite.add_datasource(cls.datasource) for collection in cls.datasource.collections: create_json_api_schema(collection) @@ -246,6 +252,7 @@ def setUp(self): self.permission_service = Mock(PermissionService) self.permission_service.get_scope = AsyncMock(return_value=ConditionTreeLeaf("id", Operator.GREATER_THAN, 0)) self.permission_service.can = AsyncMock(return_value=None) + self.permission_service.can_live_query_segment = AsyncMock(return_value=None) @classmethod def tearDownClass(cls) -> None: @@ -262,7 +269,13 @@ def test_dispatch(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.get = AsyncMock() crud_resource.list = AsyncMock() crud_resource.csv = AsyncMock() @@ -306,7 +319,13 @@ def test_dispatch_error(self, mock_request_collection: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.get = AsyncMock() with patch.object( @@ -357,7 +376,13 @@ def test_get( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + 
self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={"data": {"type": "order", "attributes": mock_order}} ) @@ -430,7 +455,13 @@ def test_get_no_data( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.get(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "read") @@ -470,7 +501,13 @@ def test_get_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock(side_effect=JsonApiException) response = self.loop.run_until_complete(crud_resource.get(request)) @@ -517,7 +554,13 @@ def test_get_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.get(request)) @@ -545,7 +588,13 @@ def test_get_with_polymorphic_relation_should_add_projection_star(self, mocked_j headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, 
self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": {"type": "tag", "attributes": {"id": 10, "taggable_id": 10, "taggable_type": "product"}}, @@ -585,7 +634,13 @@ def test_add( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) mocked_json_serializer_get.return_value.dump = Mock( return_value={"data": {"type": "order", "attributes": mock_order}} @@ -629,7 +684,13 @@ def test_add_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) crud_resource.extract_data = AsyncMock(return_value=(mock_order, [])) # JsonApiException @@ -724,7 +785,13 @@ def test_add_with_relation( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) mocked_json_serializer_get.return_value.dump = Mock( return_value={"data": {"type": "order", "attributes": mock_order}} @@ -802,7 +869,13 @@ def test_add_should_create_and_associate_polymorphic_many_to_one(self, 
mocked_js return_value={"taggable_id": 14, "taggable_type": "order", "tag": "aaaaa"} ) mocked_json_serializer_get.return_value.dump = Mock(return_value={}) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object( self.collection_tag, "create", new_callable=AsyncMock, return_value=[{}] @@ -838,7 +911,13 @@ def test_add_should_create_and_associate_polymorphic_one_to_one(self, mocked_jso mocked_json_serializer_get.return_value.load = Mock(return_value={"cost": 12.3, "important": True, "tags": 22}) mocked_json_serializer_get.return_value.dump = Mock(return_value={}) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object( self.collection_order, "create", @@ -888,7 +967,13 @@ def test_add_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "create", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.add(request)) @@ -924,7 +1009,13 @@ def test_list(self, mocked_json_serializer_get: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + 
self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": [ @@ -966,7 +1057,13 @@ def test_list_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.list(request)) @@ -990,7 +1087,13 @@ def test_list_should_return_to_many_relations_as_link(self): def test_list_with_polymorphic_many_to_one_should_query_all_relation_record_columns( self, mocked_json_serializer_get: Mock ): - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": [ @@ -1064,9 +1167,9 @@ def test_list_with_polymorphic_many_to_one_should_query_all_relation_record_colu ) def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: Mock): mock_orders = [ - {"id": 10, "cost": 200, "important": "02_PENDING"}, - {"id": 11, "cost": 201, "important": "02_PENDING"}, - {"id": 13, "cost": 20, "important": "01_URGENT"}, + {"id": 10, "cost": 200, "important": True}, + {"id": 11, "cost": 201, "important": True}, + {"id": 13, "cost": 20, "important": False}, ] request = RequestCollection( RequestMethod.GET, @@ -1081,7 +1184,13 @@ def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: headers={}, client_ip="127.0.0.1", ) - crud_resource = 
CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) mocked_json_serializer_get.return_value.dump = Mock( return_value={ "data": [ @@ -1099,6 +1208,39 @@ def test_list_should_parse_multi_field_sorting(self, mocked_json_serializer_get: self.assertEqual(paginated_filter.sort[0], {"field": "important", "ascending": True}) self.assertEqual(paginated_filter.sort[1], {"field": "cost", "ascending": False}) + def test_list_should_handle_live_query_segment(self): + mock_orders = [{"id": 10, "cost": 200}, {"id": 11, "cost": 201}] + + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.list = AsyncMock(return_value=mock_orders) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.list(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + @patch( "forestadmin.agent_toolkit.resources.collections.crud.JsonApiSerializer.get", return_value=Mock, @@ -1115,7 +1257,13 @@ def test_list_errors(self, mocked_json_serializer_get: Mock): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = 
CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # FilterException response = self.loop.run_until_complete(crud_resource.list(request)) @@ -1182,7 +1330,13 @@ def test_count(self): client_ip="127.0.0.1", ) self.collection_order.aggregate = AsyncMock(return_value=[{"value": 1000, "group": {}}]) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.count(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "browse") @@ -1204,6 +1358,37 @@ def test_count(self): assert response_content["count"] == 0 self.collection_order.aggregate.assert_called() + def test_count_should_handle_live_query_segment(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.aggregate = AsyncMock(return_value=[{"value": 1000, "group": {}}]) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.count(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + def test_deactivate_count(self): request = RequestCollection( 
RequestMethod.GET, @@ -1212,7 +1397,13 @@ def test_deactivate_count(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order._schema["countable"] = False response = self.loop.run_until_complete(crud_resource.count(request)) @@ -1256,7 +1447,13 @@ def test_edit( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() mocked_json_serializer_get.return_value.load = Mock(return_value=mock_order) @@ -1308,7 +1505,13 @@ def test_edit_errors( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # CollectionResourceException with patch( @@ -1360,7 +1563,13 @@ def test_edit_should_not_throw_and_do_nothing_on_empty_record(self): mock_order = {"id": 10, "cost": 201} self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = 
self.loop.run_until_complete(crud_resource.update(request)) self.permission_service.can.reset_mock() @@ -1384,7 +1593,13 @@ def test_edit_should_not_update_pk_if_not_set_in_attributes(self): client_ip="127.0.0.1", ) mock_order = {"id": 10, "cost": 201} - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() @@ -1407,7 +1622,13 @@ def test_edit_should_update_pk_if_set_in_attributes(self): client_ip="127.0.0.1", ) mock_order = {"id": 11, "cost": 201} - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=[mock_order]) self.collection_order.update = AsyncMock() @@ -1434,7 +1655,13 @@ def test_update_should_return_to_many_relations_as_link(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) with patch.object(self.collection_order, "list", new_callable=AsyncMock, return_value=mock_orders): response = self.loop.run_until_complete(crud_resource.update(request)) @@ -1474,7 +1701,13 @@ def test_delete( client_ip="127.0.0.1", ) self.collection_order.delete = AsyncMock() - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + 
self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) response = self.loop.run_until_complete(crud_resource.delete(request)) self.permission_service.can.assert_any_await(request.user, request.collection, "delete") self.permission_service.can.reset_mock() @@ -1490,7 +1723,13 @@ def test_delete_error(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) # CollectionResourceException with patch( @@ -1524,7 +1763,13 @@ def test_delete_list( headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.delete = AsyncMock() response = self.loop.run_until_complete(crud_resource.delete_list(request)) @@ -1549,7 +1794,13 @@ def test_csv(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=mock_orders) response = self.loop.run_until_complete(crud_resource.csv(request)) @@ -1580,7 +1831,13 @@ def test_csv_errors(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + 
self.ip_white_list_service, + self.options, + ) # FilterException response = self.loop.run_until_complete(crud_resource.csv(request)) @@ -1654,7 +1911,13 @@ def test_csv_should_not_apply_pagination(self): headers={}, client_ip="127.0.0.1", ) - crud_resource = CrudResource(self.datasource, self.permission_service, self.ip_white_list_service, self.options) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) self.collection_order.list = AsyncMock(return_value=mock_orders) response = self.loop.run_until_complete(crud_resource.csv(request)) @@ -1665,3 +1928,289 @@ def test_csv_should_not_apply_pagination(self): self.assertIsNone(self.collection_order.list.await_args[0][1].page) self.collection_order.list.assert_awaited() self.assertIsNone(self.collection_order.list.await_args[0][1].page) + + def test_csv_should_handle_live_query_segment(self): + mock_orders = [{"id": 10, "cost": 200}, {"id": 11, "cost": 201}] + + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost", + "search": "test", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + self.collection_order.list = AsyncMock(return_value=mock_orders) + + with patch.object( + crud_resource, "_handle_live_query_segment", new_callable=AsyncMock + ) as mock_handle_live_queries: + self.loop.run_until_complete(crud_resource.csv(request)) + mock_handle_live_queries.assert_awaited_once_with(request, ConditionTreeLeaf("id", "greater_than", 0)) + + # live queries + + def 
test_handle_native_query_should_handle_live_query_segments(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ) as mock_exec_native_query: + condition_tree = self.loop.run_until_complete(crud_resource._handle_live_query_segment(request, None)) + self.assertEqual(condition_tree, ConditionTreeLeaf("id", "in", [10, 11])) + mock_exec_native_query.assert_awaited_once_with( + "db_connection", "select id from order where important is true;", {} + ) + + def test_handle_native_query_should_inject_context_variable_and_handle_like_percent(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from user where first_name like 'Ga%' or id = {{currentUser.id}};", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.permission_service, + "get_user_data", + new_callable=AsyncMock, + return_value={ + "id": 1, + "firstName": "dummy", + "lastName": "user", + "fullName": "dummy user", + "email": "dummy@user.fr", + 
"tags": {}, + "roleId": 8, + "permissionLevel": "admin", + }, + ): + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ) as mock_exec_native_query: + condition_tree = self.loop.run_until_complete(crud_resource._handle_live_query_segment(request, None)) + self.assertEqual(condition_tree, ConditionTreeLeaf("id", "in", [10, 11])) + mock_exec_native_query.assert_awaited_once_with( + "db_connection", + "select id from user where first_name like 'Ga\\%' or id = %(currentUser__id)s;", + {"currentUser__id": 1}, + ) + + def test_handle_native_query_should_intersect_existing_condition_tree(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from user where id=10 or id=11;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"id": 10}, {"id": 11}], + ): + condition_tree = self.loop.run_until_complete( + crud_resource._handle_live_query_segment(request, ConditionTreeLeaf("id", "equal", 25)) + ) + + self.assertEqual( + condition_tree, + ConditionTreeBranch( + "and", + [ + ConditionTreeLeaf("id", "equal", 25), + ConditionTreeLeaf("id", "in", [10, 11]), + ], + ), + ) + + def test_handle_native_query_should_raise_error_if_live_query_params_are_incorrect(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": 
"test_live_query", + "segmentQuery": "select id from order where important is true;", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + self.assertRaisesRegex( + NativeQueryException, + "Missing native query connection attribute", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + request.query["connectionName"] = None + self.assertRaisesRegex( + NativeQueryException, + "Missing native query connection attribute", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + request.query["connectionName"] = "" + self.assertRaisesRegex( + NativeQueryException, + "Missing native query connection attribute", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + + def test_handle_native_query_should_raise_error_if_not_permission(self): + + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + with patch.object( + self.permission_service, + "can_live_query_segment", + new_callable=AsyncMock, + side_effect=ForbiddenError("You don't have permission to access this segment query."), + ): + self.assertRaisesRegex( + ForbiddenError, + "You don't have permission to access this segment query.", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) + 
+ def test_handle_native_query_should_raise_error_if_pk_not_returned(self): + request = RequestCollection( + RequestMethod.GET, + self.collection_order, + query={ + "collection_name": "order", + "timezone": "Europe/Paris", + "fields[order]": "id,cost,important", + "segmentName": "test_live_query", + "segmentQuery": "select id as bla, cost from order where important is true;", + "connectionName": "db_connection", + }, + headers={}, + client_ip="127.0.0.1", + user=FAKE_USER, + ) + crud_resource = CrudResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"bla": 10, "cost": 100}, {"bla": 11, "cost": 100}], + ): + self.assertRaisesRegex( + NativeQueryException, + r"Live query must return the primary key field \('id'\).", + self.loop.run_until_complete, + crud_resource._handle_live_query_segment(request, None), + ) diff --git a/src/agent_toolkit/tests/resources/collections/test_native_query.py b/src/agent_toolkit/tests/resources/collections/test_native_query.py new file mode 100644 index 000000000..a76ea3ed6 --- /dev/null +++ b/src/agent_toolkit/tests/resources/collections/test_native_query.py @@ -0,0 +1,854 @@ +import asyncio +import importlib +import json +import sys +from unittest import TestCase +from unittest.mock import AsyncMock, Mock, patch + +import forestadmin.agent_toolkit.resources.collections.native_query + +if sys.version_info >= (3, 9): + import zoneinfo +else: + from backports import zoneinfo + +import forestadmin.agent_toolkit.resources.collections.crud +from forestadmin.agent_toolkit.options import Options +from forestadmin.agent_toolkit.services.permissions.ip_whitelist_service import IpWhiteListService +from forestadmin.agent_toolkit.services.permissions.permission_service import PermissionService +from forestadmin.agent_toolkit.utils.context import 
Request, RequestMethod, User +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import ForbiddenError + + +def authenticate_mock(fn): + async def wrapped2(self, request): + request.user = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, + ) + + return await fn(self, request) + + return wrapped2 + + +def ip_white_list_mock(fn): + async def wrapped(self, request: Request, *args, **kwargs): + return await fn(self, request, *args, **kwargs) + + return wrapped + + +patch("forestadmin.agent_toolkit.resources.collections.decorators.authenticate", authenticate_mock).start() +patch("forestadmin.agent_toolkit.resources.collections.decorators.ip_white_list", ip_white_list_mock).start() +# how to mock decorators, and why they are not testable : +# https://dev.to/stack-labs/how-to-mock-a-decorator-in-python-55jc + +importlib.reload(forestadmin.agent_toolkit.resources.collections.native_query) +from forestadmin.agent_toolkit.resources.collections.native_query import NativeQueryResource # noqa: E402 + + +class TestNativeQueryResourceBase(TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.loop = asyncio.new_event_loop() + cls.options = Options( + auth_secret="fake_secret", + env_secret="fake_secret", + server_url="http://fake:5000", + prefix="", + is_production=False, + ) + # cls.datasource = Mock(Datasource) + cls.datasource = Datasource(["db1", "db2"]) + cls.datasource_composite = CompositeDatasource() + cls.datasource_composite.add_datasource(cls.datasource) + + def setUp(self): + self.datasource_composite.execute_native_query = AsyncMock() + self.ip_white_list_service = Mock(IpWhiteListService) + self.ip_white_list_service.is_enable = 
AsyncMock(return_value=False) + + self.permission_service = Mock(PermissionService) + self.permission_service.can_chart = AsyncMock(return_value=None) + + self.native_query_resource = NativeQueryResource( + self.datasource_composite, + self.datasource, + self.permission_service, + self.ip_white_list_service, + self.options, + ) + + +class TestNativeQueryResourceOnError(TestNativeQueryResourceBase): + def test_should_return_error_if_cannot_chart_on_permission(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.permission_service, + "can_chart", + new_callable=AsyncMock, + side_effect=ForbiddenError("You don't have permission to access this chart."), + ) as mock_can_chart: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 403) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ForbiddenError", + "detail": "You don't have permission to access this chart.", + "status": 403, + "data": {}, + } + ] + }, + ) + + mock_can_chart.assert_awaited_once_with(request) + + def test_should_return_error_if_connectionName_is_not_here(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Missing native query connection attribute", + "status": 400, + "data": {}, + } + ] + }, + ) + + def 
test_should_return_error_if_query_is_not_here(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1"}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Missing 'query' in parameter.", + "status": 400, + "data": {}, + } + ] + }, + ) + + def test_should_return_error_if_chart_type_is_unknown_or_missing(self): + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": "select count(*) from orders;"}, + client_ip="127.0.0.1", + query={}, + ) + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Unknown chart type 'None'.", + "status": 400, + "data": {}, + } + ] + }, + ) + + request.body["type"] = "unknown" + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + self.assertEqual(response.status, 400) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body, + { + "errors": [ + { + "name": "ValidationError", + "detail": "Unknown chart type 'unknown'.", + "status": 400, + "data": {}, + } + ] + }, + ) + + +class TestNativeQueryResourceValueChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_value_chart(self): + native_query = "select 
count(*) as value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, "execute_native_query", new_callable=AsyncMock, return_value=[{"value": 100}] + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual(response_body["data"]["attributes"], {"value": {"countCurrent": 100}}) + + def test_should_return_error_if_value_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, "execute_native_query", new_callable=AsyncMock, return_value=[{"not_value": 100}] + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Value' chart must return 'value' field.", + "status": 422, + "data": {}, + } + ], + ) + + def 
test_should_return_error_if_value_query_does_not_return_one_row(self): + native_query = "select count(*) as not_value from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100}, {"value": 100}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Value' chart must return only one row.", + "status": 422, + "data": {}, + } + ], + ) + + def test_should_correctly_handle_value_chart_with_previous(self): + native_query = "select count(*) as value, 0 as previous from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Value"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100, "previous": 0}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + 
self.assertEqual(response_body["data"]["attributes"], {"value": {"countCurrent": 100, "countPrevious": 0}}) + + +class TestNativeQueryResourceLineChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_line_chart(self): + native_query = "select count(*) as value, date as key from orders group by date, order by date;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Line"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"key": "2020-01", "value": 100}, {"key": "2020-02", "value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + {"value": [{"label": "2020-01", "values": {"value": 100}}, {"label": "2020-02", "values": {"value": 110}}]}, + ) + + def test_should_return_error_if_line_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, date as key from orders group by date, order by date;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Line"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"key": "2020-01", "not_value": 100}, {"key": "2020-02", "not_value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + 
self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Line' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = "select count(*) as value, date as not_key from orders group by date, order by date;" + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"not_key": "2020-01", "value": 100}, {"not_key": "2020-02", "value": 110}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Line' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourceObjectiveChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_objective_chart(self): + native_query = "select count(*) as value, 1000 as objective from orders " + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 150, "objective": 1000}], 
+ ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual(response_body["data"]["attributes"], {"value": {"value": 150, "objective": 1000}}) + + def test_should_return_error_if_objective_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, 1000 as objective from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"not_value": 150, "objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return 'value' and 'objective' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = "select count(*) as value, 1000 as not_objective from orders;" + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 150, "not_objective": 1000}], + ) as 
mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return 'value' and 'objective' fields.", + "status": 422, + "data": {}, + } + ], + ) + + def test_should_return_error_if_objective_query_does_not_return_one_row(self): + native_query = "select count(*) as value, 1000 as objective from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Objective"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[{"value": 100, "objective": 1000}, {"value": 100, "objective": 1000}], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Objective' chart must return only one row.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourceLeaderboardChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_leaderboard_chart(self): + native_query = "select sum(score) as value, customer as key from results group by customer order by value desc;" + request = Request( + 
method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Leaderboard"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "key": "Jean"}, + {"value": 140, "key": "Elsa"}, + {"value": 0, "key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + {"value": [{"key": "Jean", "value": 150}, {"key": "Elsa", "value": 140}, {"key": "Gautier", "value": 0}]}, + ) + + def test_should_return_error_if_leaderboard_query_return_fields_are_not_good(self): + native_query = ( + "select sum(score) as not_value, customer as key from results group by customer order by value desc;" + ) + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Leaderboard"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"not_value": 150, "key": "Jean"}, + {"not_value": 140, "key": "Elsa"}, + {"not_value": 0, "key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": 
"application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Leaderboard' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = ( + "select sum(score) as value, customer as not_key from results group by customer order by value desc;" + ) + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "not_key": "Jean"}, + {"value": 140, "not_key": "Elsa"}, + {"value": 0, "not_key": "Gautier"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Leaderboard' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourcePieChart(TestNativeQueryResourceBase): + def test_should_correctly_handle_pie_chart(self): + native_query = "select count(*) as value, status as key from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"value": 150, "key": "pending"}, + {"value": 140, "key": "delivering"}, + {"value": 10, "key": "lost"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + 
self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 200) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertIn("id", response_body["data"]) + self.assertEqual(response_body["data"]["type"], "stats") + self.assertEqual( + response_body["data"]["attributes"], + { + "value": [ + {"value": 150, "key": "pending"}, + {"value": 140, "key": "delivering"}, + {"value": 10, "key": "lost"}, + ] + }, + ) + + def test_should_return_error_if_pie_query_return_fields_are_not_good(self): + native_query = "select count(*) as not_value, status as key from orders;" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie"}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[ + {"not_value": 150, "key": "pending"}, + {"not_value": 140, "key": "delivering"}, + {"not_value": 10, "key": "lost"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Pie' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + native_query = "select count(*) as value, status as not_key from orders;" + request.body["query"] = native_query + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + 
return_value=[ + {"value": 150, "not_key": "pending"}, + {"value": 140, "not_key": "delivering"}, + {"value": 10, "not_key": "lost"}, + ], + ) as mock_exec_native_query: + response = self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with("db1", native_query, {}) + + self.assertEqual(response.status, 422) + self.assertEqual(response.headers, {"content-type": "application/json"}) + response_body = json.loads(response.body) + self.assertEqual( + response_body["errors"], + [ + { + "name": "UnprocessableError", + "detail": "Native query for 'Pie' chart must return 'key' and 'value' fields.", + "status": 422, + "data": {}, + } + ], + ) + + +class TestNativeQueryResourceVariableConectextVariables(TestNativeQueryResourceBase): + + def test_should_correctly_handle_variable_context(self): + native_query = "select count(*) as value, status as key from orders where customer = {{recordId}};" + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie", "contextVariables": {"recordId": 1}}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[], + ) as mock_exec_native_query: + self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with( + "db1", + "select count(*) as value, status as key from orders where customer = %(recordId)s;", + {"recordId": 1}, + ) + + def test_should_correctly_handle_variable_context_and_like_percent_comparison(self): + native_query = ( + "select count(*) as value, status as key from orders where customer = {{recordId}} " + "and customer_name like '%henry%';" + ) + request = Request( + method=RequestMethod.POST, + headers={}, + body={"connectionName": "db1", "query": native_query, "type": "Pie", 
"contextVariables": {"recordId": 1}}, + client_ip="127.0.0.1", + query={}, + ) + with patch.object( + self.datasource_composite, + "execute_native_query", + new_callable=AsyncMock, + return_value=[], + ) as mock_exec_native_query: + self.loop.run_until_complete( + self.native_query_resource.dispatch(request, "native_query"), + ) + mock_exec_native_query.assert_awaited_once_with( + "db1", + "select count(*) as value, status as key from orders where customer = %(recordId)s " + "and customer_name like '\\%henry\\%';", + {"recordId": 1}, + ) diff --git a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py index cecd5da05..a63508880 100644 --- a/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py +++ b/src/agent_toolkit/tests/resources/services/permissions/test_permission_service.py @@ -38,7 +38,7 @@ class BaseTestPermissionService(TestCase): @classmethod def setUpClass(cls) -> None: cls.loop = asyncio.new_event_loop() - cls.datasource = Datasource() + cls.datasource = Datasource(["database_1"]) Collection.__abstractmethods__ = set() # to instantiate abstract class cls.booking_collection = Collection("Booking", cls.datasource) cls.booking_collection.add_fields( @@ -174,7 +174,18 @@ def mock_forest_http_api(self, scope=None) -> PatchHttpApiDict: "get_rendering_permissions", new_callable=AsyncMock, return_value={ - "collections": {"Booking": {"scope": scope, "segments": []}}, + "collections": { + "Booking": { + "scope": scope, + "segments": ["select id from booking where title is null"], + "liveQuerySegments": [ + { + "query": "select id from booking where title is null", + "connectionName": "database_1", + } + ], + } + }, "stats": [ { "type": "Pie", @@ -191,6 +202,11 @@ def mock_forest_http_api(self, scope=None) -> PatchHttpApiDict: "aggregateFieldName": None, "sourceCollectionName": "Booking", }, + { + "type": "Pie", + "query": 
"select sum(amount) as value, status as key from app_order group by status", + "connectionName": "database_1", + }, ], "team": { "id": 1, @@ -216,12 +232,11 @@ def test_invalidate_cache_should_delete_corresponding_key(self): self.assertNotIn("forest.users", self.permission_service.cache) def test_dont_call_api_when_something_is_cached(self): - http_patches: PatchHttpApiDict = self.mock_forest_http_api() http_patches: PatchHttpApiDict = self.mock_forest_http_api() http_mocks: MockHttpApiDict = {name: patch.start() for name, patch in http_patches.items()} - response_1 = self.loop.run_until_complete(self.permission_service._get_chart_data(1)) - response_2 = self.loop.run_until_complete(self.permission_service._get_chart_data(1)) + response_1 = self.loop.run_until_complete(self.permission_service._get_rendering_data(1)) + response_2 = self.loop.run_until_complete(self.permission_service._get_rendering_data(1)) self.assertEqual(response_1, response_2) response_1 = self.loop.run_until_complete(self.permission_service.get_user_data(1)) @@ -383,7 +398,7 @@ def test_can_chart_should_retry_on_not_allowed_at_first_try(self): def mock_get_rendering_permissions(rendering_id, options): if mock["call_count"] == 0: - ret = {"stats": {}} + ret = {"stats": {}, "collections": {}, "team": {}} elif mock["call_count"] == 1: ret = http_patches["get_rendering_permissions"].kwargs["return_value"] mock["call_count"] = mock["call_count"] + 1 @@ -416,7 +431,10 @@ def test_can_chart_should_raise_forbidden_error_on_not_allowed_chart(self): ) with patch.object( - ForestHttpApi, "get_rendering_permissions", new_callable=AsyncMock, return_value={"stats": {}} + ForestHttpApi, + "get_rendering_permissions", + new_callable=AsyncMock, + return_value={"stats": {}, "collections": {}, "team": {}}, ): self.assertRaisesRegex( ForbiddenError, @@ -425,6 +443,87 @@ def test_can_chart_should_raise_forbidden_error_on_not_allowed_chart(self): self.permission_service.can_chart(request), ) + def 
test_can_chart_should_handle_live_query_chart(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + rendering_permission_mock: AsyncMock = http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "database_1", + "contextVariables": {}, + "query": "select sum(amount) as value, status as key from app_order group by status", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + is_allowed = self.loop.run_until_complete(self.permission_service.can_chart(request)) + rendering_permission_mock.assert_awaited_once() + self.assertTrue(is_allowed) + + http_patches["get_rendering_permissions"].stop() + + def test_can_chart_should_raise_forbidden_error_on_not_wrong_live_query_connection(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "wrong_database", + "contextVariables": {}, + "query": "select sum(amount) as value, status as key from app_order group by status", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this chart.", + self.loop.run_until_complete, + self.permission_service.can_chart(request), + ) + + http_patches["get_rendering_permissions"].stop() + + def test_can_chart_should_raise_forbidden_error_on_mismatching_query(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.POST, + collection=self.booking_collection, + body={ + "connectionName": "database_1", + "contextVariables": {}, + "query": "select sum(amount) 
as value, status as key from app_order group by status ; ", + "type": "Pie", + }, + query={}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this chart.", + self.loop.run_until_complete, + self.permission_service.can_chart(request), + ) + + http_patches["get_rendering_permissions"].stop() + class Test04GetScopePermissionService(BaseTestPermissionService): def test_get_scope_should_return_null_when_no_scope_in_permissions(self): @@ -646,3 +745,86 @@ def test_can_smart_action_should_throw_when_action_is_unknown(self): http_patches["get_environment_permissions"].stop() http_patches["get_users"].stop() + + +class Test06CanLiveQuerySegment(BaseTestPermissionService): + def test_can_live_query_segment_should_allow_correct_live_query_segment(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + rendering_permission_mock: AsyncMock = http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": "database_1", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + is_allowed = self.loop.run_until_complete(self.permission_service.can_live_query_segment(request)) + rendering_permission_mock.assert_awaited_once() + self.assertTrue(is_allowed) + + http_patches["get_rendering_permissions"].stop() + + def test_can_live_query_segment_should_raise_forbidden_when_wrong_connection_name(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": 
"database_2", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this segment query.", + self.loop.run_until_complete, + self.permission_service.can_live_query_segment(request), + ) + + http_patches["get_rendering_permissions"].stop() + + def test_can_live_query_segment_should_raise_forbidden_when_mismatching_query(self): + http_patches: PatchHttpApiDict = self.mock_forest_http_api() + http_patches["get_rendering_permissions"].start() + + request = RequestCollection( + method=RequestMethod.GET, + collection=self.booking_collection, + body=None, + query={ + "fields[Booking]": "id,title", + "connectionName": "database_1", + "segmentName": "no_title", + "segmentQuery": "select id from booking where title is null;", + }, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + + self.assertRaisesRegex( + ForbiddenError, + r"🌳🌳🌳You don't have permission to access this segment query.", + self.loop.run_until_complete, + self.permission_service.can_live_query_segment(request), + ) + + http_patches["get_rendering_permissions"].stop() diff --git a/src/agent_toolkit/tests/resources/test_capabilities_resource.py b/src/agent_toolkit/tests/resources/test_capabilities_resource.py index de94e0312..0ba6a52eb 100644 --- a/src/agent_toolkit/tests/resources/test_capabilities_resource.py +++ b/src/agent_toolkit/tests/resources/test_capabilities_resource.py @@ -71,7 +71,7 @@ def setUpClass(cls) -> None: is_production=False, ) # type:ignore - cls.datasource = Datasource() + cls.datasource = Datasource(live_query_connections=["test1", "test2"]) Collection.__abstractmethods__ = set() # type:ignore # to instantiate abstract class cls.book_collection = Collection("Book", cls.datasource) # type:ignore cls.book_collection.add_fields( @@ -127,7 +127,7 @@ def 
test_dispatch_should_not_dispatch_to_capabilities_when_no_post_request(self) self.assertEqual(response.status, 405) - def test_dispatch_should_dispatch_POST_to_capabilities(self): + def test_dispatch_should_return_correct_collection_and_fields_capabilities(self): request = Request( method=RequestMethod.POST, query={}, @@ -156,3 +156,18 @@ def test_dispatch_should_dispatch_POST_to_capabilities(self): } ], ) + + def test_should_return_correct_datasource_connections_capabilities(self): + request = Request( + method=RequestMethod.POST, + query={}, + body={"collectionNames": ["Book"]}, + headers={}, + client_ip="127.0.0.1", + user=self.mocked_caller, + ) + response: Response = self.loop.run_until_complete(self.capabilities_resource.dispatch(request, "capabilities")) + self.assertEqual(response.status, 200) + response_content = json.loads(response.body) + + self.assertEqual(response_content["nativeQueryConnections"], [{"name": "test1"}, {"name": "test2"}]) diff --git a/src/agent_toolkit/tests/test_agent_toolkit.py b/src/agent_toolkit/tests/test_agent_toolkit.py index 0301ca2ad..2aceffeb6 100644 --- a/src/agent_toolkit/tests/test_agent_toolkit.py +++ b/src/agent_toolkit/tests/test_agent_toolkit.py @@ -18,6 +18,7 @@ @patch("forestadmin.agent_toolkit.agent.CrudRelatedResource") @patch("forestadmin.agent_toolkit.agent.StatsResource") @patch("forestadmin.agent_toolkit.agent.ActionResource") +@patch("forestadmin.agent_toolkit.agent.NativeQueryResource") @patch("forestadmin.agent_toolkit.agent.SchemaEmitter.get_serialized_schema", new_callable=AsyncMock) @patch("forestadmin.agent_toolkit.agent.ForestHttpApi.send_schema", new_callable=AsyncMock) class TestAgent(TestCase): @@ -34,6 +35,7 @@ def test_create( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -81,6 +83,7 @@ def test_property_resources( self, 
mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -97,10 +100,14 @@ def test_property_resources( mocked_authentication_resource.assert_called_once_with(agent._ip_white_list_service, agent.options) mocked_capabilities_resource.assert_called_once_with( - "fake_datasource", agent._ip_white_list_service, agent.options + agent.customizer.composite_datasource, agent._ip_white_list_service, agent.options ) mocked_crud_resource.assert_called_once_with( - "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options + agent.customizer.composite_datasource, + "fake_datasource", + agent._permission_service, + agent._ip_white_list_service, + agent.options, ) mocked_crud_related_resource.assert_called_once_with( "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options @@ -111,8 +118,16 @@ def test_property_resources( mocked_action_resource.assert_called_once_with( "fake_datasource", agent._permission_service, agent._ip_white_list_service, agent.options ) + mocked_native_query_resource.assert_called_once_with( + agent.customizer.composite_datasource, + "fake_datasource", + agent._permission_service, + agent._ip_white_list_service, + agent.options, + ) - assert len(resources) == 8 + assert len(resources) == 9 + assert "native_query" in resources assert "capabilities" in resources assert "authentication" in resources assert "crud" in resources @@ -126,6 +141,7 @@ def test_property_meta( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -147,6 +163,7 @@ def test_add_datasource( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, 
mocked_stats_resource, mocked_crud_related_resource, @@ -167,6 +184,7 @@ def test_add_chart( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -190,6 +208,7 @@ def test_remove_collections( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -210,6 +229,7 @@ def test_customize_datasource( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -234,6 +254,7 @@ def test_start( mocked_create_json_api_schema, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -279,6 +300,7 @@ def test_start_dont_crash_if_schema_generation_or_sending_fail( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, @@ -315,6 +337,7 @@ def test_use_should_add_a_plugin( self, mocked_schema_emitter__get_serialized_schema, mocked_forest_http_api__send_schema, + mocked_native_query_resource, mocked_action_resource, mocked_stats_resource, mocked_crud_related_resource, diff --git a/src/agent_toolkit/tests/utils/test_sql_query_checker.py b/src/agent_toolkit/tests/utils/test_sql_query_checker.py new file mode 100644 index 000000000..dcfecb359 --- /dev/null +++ b/src/agent_toolkit/tests/utils/test_sql_query_checker.py @@ -0,0 +1,55 @@ +from unittest import TestCase + +from forestadmin.agent_toolkit.utils.sql_query_checker import ( + ChainedSQLQueryException, + EmptySQLQueryException, 
+ NonSelectSQLQueryException, + SqlQueryChecker, +) + + +class TestSqlQueryChecker(TestCase): + def test_normal_sql_query_should_be_ok(self): + self.assertTrue( + SqlQueryChecker.check_query( + """ + Select status, sum(amount) as value + from order + where status != "rejected" + group by status having status != "rejected"; + """ + ) + ) + + def test_should_raise_on_linked_query(self): + self.assertRaisesRegex( + ChainedSQLQueryException, + r"You cannot chain SQL queries\.", + SqlQueryChecker.check_query, + """ + Select status, sum(amount) as value + from order + where status != "rejected" + group by status having status != "rejected"; delete from user_debts; + """, + ) + + def test_should_raise_on_empty_query(self): + self.assertRaisesRegex( + EmptySQLQueryException, + r"You cannot execute an empty SQL query\.", + SqlQueryChecker.check_query, + """ + + """, + ) + + def test_should_raise_on_non_select_query(self): + self.assertRaisesRegex( + NonSelectSQLQueryException, + r"Only SELECT queries are allowed\.", + SqlQueryChecker.check_query, + """ + delete from user_debts; + """, + ) diff --git a/src/datasource_django/forestadmin/datasource_django/datasource.py b/src/datasource_django/forestadmin/datasource_django/datasource.py index 453fcf68b..5a3b12015 100644 --- a/src/datasource_django/forestadmin/datasource_django/datasource.py +++ b/src/datasource_django/forestadmin/datasource_django/datasource.py @@ -1,17 +1,94 @@ +from datetime import date +from typing import Dict, List, Optional, Union + +from asgiref.sync import sync_to_async from django.apps import apps +from django.db import connections from forestadmin.datasource_django.collection import DjangoCollection +from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoDatasource +from forestadmin.datasource_toolkit.exceptions import NativeQueryException +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias 
class DjangoDatasource(BaseDjangoDatasource): - def __init__(self, support_polymorphic_relations: bool = False) -> None: - super().__init__() + def __init__( + self, + support_polymorphic_relations: bool = False, + live_query_connection: Optional[Union[str, Dict[str, str]]] = None, + ) -> None: + """ Create a django datasource. + More information here: + https://docs.forestadmin.com/developer-guide-agents-python/data-sources/provided-data-sources/django + + + Args: + support_polymorphic_relations (bool, optional, default to `False`): Enable introspection over \ + polymorphic relation (AKA GenericForeignKey). Defaults to False. + live_query_connection (Union[str, Dict[str, str]], optional, default to `None`): Set a connectionName to \ + use live queries. If a string is given, this connection will be map to django 'default' database. \ + Otherwise, you must use a dict `{'connectionName': 'DjangoDatabaseName'}`. \ + None doesn't enable this feature. + """ + self._django_live_query_connections: Dict[str, str] = self._handle_live_query_connections_param( + live_query_connection + ) + super().__init__([*self._django_live_query_connections.keys()]) + self.support_polymorphic_relations = support_polymorphic_relations self._create_collections() + def _handle_live_query_connections_param( + self, live_query_connections: Optional[Union[str, Dict[str, str]]] + ) -> Dict[str, str]: + if live_query_connections is None: + return {} + + if isinstance(live_query_connections, str): + ret = {live_query_connections: "default"} + else: + ret = live_query_connections + + for forest_name, db_name in ret.items(): + if db_name not in connections: + raise DjangoDatasourceException( + f"Connection to database '{db_name}' for alias '{forest_name}' is not found in django databases. 
" + f"Existing connections are {','.join([*connections])}" + ) + return ret + def _create_collections(self): models = apps.get_models(include_auto_created=True) for model in models: if model._meta.proxy is False: collection = DjangoCollection(self, model, self.support_polymorphic_relations) self.add_collection(collection) + + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: + if connection_name not in self._django_live_query_connections.keys(): + # This one should never occur while datasource composite works fine + raise NativeQueryException(f"Native query connection '{connection_name}' is not known by DjangoDatasource.") + + def _execute_native_query(): + cursor = connections[self._django_live_query_connections[connection_name]].cursor() # type: ignore + try: + # replace '\%' by '%%' + # %(var)s is already the correct syntax + rows = cursor.execute(native_query.replace("\\%", "%%"), parameters) + + ret = [] + for row in rows: + return_row = {} + for i, field_name in enumerate(rows.description): + value = row[i] + if isinstance(value, date): + value = value.isoformat() + return_row[field_name[0]] = value + ret.append(return_row) + return ret + except Exception as e: + raise NativeQueryException(str(e)) + + return await sync_to_async(_execute_native_query)() diff --git a/src/datasource_django/tests/test_django_datasource.py b/src/datasource_django/tests/test_django_datasource.py index dc64674c9..ae0b98ac0 100644 --- a/src/datasource_django/tests/test_django_datasource.py +++ b/src/datasource_django/tests/test_django_datasource.py @@ -1,8 +1,11 @@ -from unittest import TestCase +import asyncio from unittest.mock import Mock, call, patch +from django.test import SimpleTestCase, TestCase from forestadmin.datasource_django.collection import DjangoCollection from forestadmin.datasource_django.datasource import DjangoDatasource +from forestadmin.datasource_django.exception import 
DjangoDatasourceException +from forestadmin.datasource_toolkit.exceptions import NativeQueryException mock_collection1 = Mock(DjangoCollection) mock_collection1.name = "first" @@ -66,3 +69,111 @@ def test_django_datasource_should_ignore_proxy_models(self): """ignoring proxy models means no collections added twice or more""" datasource = DjangoDatasource() self.assertEqual(len([c.name for c in datasource.collections if c.name == "auth_user"]), 1) + + +class TestDjangoDatasourceConnectionQueryCreation(SimpleTestCase): + def test_should_not_create_native_query_connection_if_no_params(self): + ds = DjangoDatasource() + self.assertEqual(ds.get_native_query_connections(), []) + + def test_should_create_native_query_connection_to_default_if_string_is_set(self): + ds = DjangoDatasource(live_query_connection="django") + self.assertEqual(ds.get_native_query_connections(), ["django"]) + self.assertEqual(ds._django_live_query_connections["django"], "default") + + def test_should_raise_if_connection_query_target_non_existent_database(self): + self.assertRaisesRegex( + DjangoDatasourceException, + r"Connection to database 'plouf' for alias 'plif' is not found in django databases\. 
" + r"Existing connections are default,other", + DjangoDatasource, + live_query_connection={"django": "default", "plif": "plouf"}, + ) + + +class TestDjangoDatasourceNativeQueryExecution(TestCase): + fixtures = ["person.json", "book.json", "rating.json", "tag.json"] + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.loop = asyncio.new_event_loop() + cls.dj_datasource = DjangoDatasource(live_query_connection={"django": "default", "other": "other"}) + + def test_should_raise_if_connection_is_not_known_by_datasource(self): + self.assertRaisesRegex( + NativeQueryException, + r"Native query connection 'foo' is not known by DjangoDatasource.", + self.loop.run_until_complete, + self.dj_datasource.execute_native_query("foo", "select * from blabla", {}), + ) + + async def test_should_correctly_execute_query(self): + result = await self.dj_datasource.execute_native_query( + "django", "select * from test_app_person order by person_pk;", {} + ) + self.assertEqual( + result, + [ + { + "person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + { + "person_pk": 2, + "first_name": "J.K.", + "last_name": "Rowling", + "birth_date": "1965-07-31", + "auth_user_id": None, + }, + ], + ) + + async def test_should_correctly_execute_query_with_formatting(self): + result = await self.dj_datasource.execute_native_query( + "django", + "select * from test_app_person where first_name = %(first_name)s order by person_pk;", + {"first_name": "Isaac"}, + ) + self.assertEqual( + result, + [ + { + "person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + ], + ) + + async def test_should_correctly_execute_query_with_percent(self): + result = await self.dj_datasource.execute_native_query( + "django", + "select * from test_app_person where first_name like 'Is\\%' order by person_pk;", + {}, + ) + self.assertEqual( + result, + [ + { + 
"person_pk": 1, + "first_name": "Isaac", + "last_name": "Asimov", + "birth_date": "1920-02-01", + "auth_user_id": None, + }, + ], + ) + + def test_should_correctly_raise_exception_during_sql_error(self): + self.assertRaisesRegex( + NativeQueryException, + r"no such table: blabla", + self.loop.run_until_complete, + self.dj_datasource.execute_native_query("django", "select * from blabla", {}), + ) diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py index 24668e05e..82284b1ef 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/datasource.py @@ -1,15 +1,17 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from forestadmin.datasource_sqlalchemy.collections import SqlAlchemyCollection from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException from forestadmin.datasource_sqlalchemy.interfaces import BaseSqlAlchemyDatasource -from sqlalchemy import create_engine +from forestadmin.datasource_toolkit.exceptions import NativeQueryException +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias +from sqlalchemy import create_engine, text from sqlalchemy.orm import Mapper, sessionmaker class SqlAlchemyDatasource(BaseSqlAlchemyDatasource): - def __init__(self, Base: Any, db_uri: Optional[str] = None) -> None: - super().__init__() + def __init__(self, Base: Any, db_uri: Optional[str] = None, live_query_connection: Optional[str] = None) -> None: + super().__init__([live_query_connection] if live_query_connection is not None else None) self._base = Base self.__is_using_flask_sqlalchemy = hasattr(Base, "Model") @@ -56,6 +58,30 @@ def build_mappers(self) -> Dict[str, Mapper]: mappers[mapper.persist_selectable.name] = mapper return mappers + async def execute_native_query( + self, 
connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: + if connection_name != self.get_native_query_connections()[0]: + raise NativeQueryException( + f"The native query connection '{connection_name}' doesn't belongs to this datasource." + ) + try: + session = self.Session() + query = native_query + if isinstance(query, str): + query = native_query + for key in parameters.keys(): + # replace '%(...)s' by ':...' + query = query.replace(f"%({key})s", f":{key}") + # replace '\%' by '%' + query = query.replace("\\%", "%") + + query = text(query) + rows = session.execute(query, parameters) + return [*rows.mappings()] + except Exception as exc: + raise NativeQueryException(str(exc)) + # unused code, can be use full but can be remove # from forestadmin.datasource_toolkit.datasources import DatasourceException # from forestadmin.datasource_toolkit.interfaces.fields import FieldType, ManyToMany, ManyToOne diff --git a/src/datasource_sqlalchemy/tests/fixture/models.py b/src/datasource_sqlalchemy/tests/fixture/models.py index ffb6e102a..9f5bdb56a 100644 --- a/src/datasource_sqlalchemy/tests/fixture/models.py +++ b/src/datasource_sqlalchemy/tests/fixture/models.py @@ -3,11 +3,11 @@ import os from datetime import date, datetime +import sqlalchemy # type: ignore from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, create_engine, func, types -from sqlalchemy.orm import Session, declarative_base, relationship, validates +from sqlalchemy.orm import Session, relationship, validates -test_db_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", "test_db.sql")) -engine = create_engine(f"sqlite:///{test_db_path}", echo=False) +use_sqlalchemy_2 = sqlalchemy.__version__.split(".")[0] == "2" fixtures_dir = os.path.abspath(os.path.join(__file__, "..")) # to import/export json as fixtures @@ -42,8 +42,24 @@ def __import__(cls, d): return cls(**params) -Base = declarative_base(cls=_Base) 
-Base.metadata.bind = engine +if use_sqlalchemy_2: + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase, _Base): + pass + +else: + from sqlalchemy.orm import declarative_base + + Base = declarative_base(cls=_Base) + + +def get_models_base(db_file_name): + test_db_path = os.path.abspath(os.path.join(__file__, "..", "..", "..", "..", "..", f"{db_file_name}.sql")) + engine = create_engine(f"sqlite:///{test_db_path}", echo=False) + Base.metadata.bind = engine + Base.metadata.file_path = test_db_path + return Base class ORDER_STATUS(str, enum.Enum): @@ -104,7 +120,7 @@ class CustomersAddresses(Base): address_id = Column(Integer, ForeignKey("address.id"), primary_key=True) -def load_fixtures(): +def load_fixtures(base): with open(os.path.join(fixtures_dir, "addresses.json"), "r") as fin: data = json.load(fin) addresses = [Address.__import__(d) for d in data] @@ -121,7 +137,7 @@ def load_fixtures(): data = json.load(fin) orders = [Order.__import__(d) for d in data] - with Session(Base.metadata.bind) as session: + with Session(base.metadata.bind) as session: session.bulk_save_objects(addresses) session.bulk_save_objects(customers) session.bulk_save_objects(customers_addresses) @@ -129,5 +145,5 @@ def load_fixtures(): session.commit() -def create_test_database(): - Base.metadata.create_all(engine) +def create_test_database(base): + base.metadata.create_all(base.metadata.bind) diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py index b0e27d722..366515f8d 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py @@ -143,10 +143,11 @@ class TestSqlAlchemyCollectionWithModels(TestCase): @classmethod def setUpClass(cls): cls.loop = asyncio.new_event_loop() - if os.path.exists(models.test_db_path): - os.remove(models.test_db_path) - models.create_test_database() - 
models.load_fixtures() + cls.sql_alchemy_base = models.get_models_base("test_collection_operations") + if os.path.exists(cls.sql_alchemy_base.metadata.file_path): + os.remove(cls.sql_alchemy_base.metadata.file_path) + models.create_test_database(cls.sql_alchemy_base) + models.load_fixtures(cls.sql_alchemy_base) cls.datasource = SqlAlchemyDatasource(models.Base) cls.mocked_caller = User( rendering_id=1, @@ -162,7 +163,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - os.remove(models.test_db_path) + os.remove(cls.sql_alchemy_base.metadata.file_path) cls.loop.close() def test_get_columns(self): @@ -306,7 +307,7 @@ def test_aggregate(self): def test_get_native_driver_should_return_connection(self): with self.datasource.get_collection("order").get_native_driver() as connection: self.assertIsInstance(connection, Session) - self.assertEqual(str(connection.bind.url), f"sqlite:///{models.test_db_path}") + self.assertEqual(str(connection.bind.url), f"sqlite:///{self.sql_alchemy_base.metadata.file_path}") rows = connection.execute(text('select id,amount from "order" where id = 3')).all() self.assertEqual(rows, [(3, 5285)]) @@ -314,7 +315,7 @@ def test_get_native_driver_should_return_connection(self): def test_get_native_driver_should_work_without_declaring_request_as_text(self): with self.datasource.get_collection("order").get_native_driver() as connection: self.assertIsInstance(connection, Session) - self.assertEqual(str(connection.bind.url), f"sqlite:///{models.test_db_path}") + self.assertEqual(str(connection.bind.url), f"sqlite:///{self.sql_alchemy_base.metadata.file_path}") rows = connection.execute('select id,amount from "order" where id = 3').all() self.assertEqual(rows, [(3, 5285)]) diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py index ec5662028..a1cf7982d 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py +++ 
b/src/datasource_sqlalchemy/tests/test_sqlalchemy_datasource.py @@ -1,10 +1,12 @@ +import asyncio +import os from unittest import TestCase from unittest.mock import Mock, patch from flask import Flask -from flask_sqlalchemy import SQLAlchemy from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyDatasourceException +from forestadmin.datasource_toolkit.exceptions import NativeQueryException from sqlalchemy.orm import DeclarativeMeta from .fixture import models @@ -50,6 +52,8 @@ def test_create_datasources_not_search_engine_when_db_uri_is_supply(self): @patch("forestadmin.datasource_sqlalchemy.datasource.sessionmaker") def test_create_datasource_with_flask_sqlalchemy_integration_should_find_engine(self, mocked_sessionmaker): + from flask_sqlalchemy import SQLAlchemy + app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///memory" db = SQLAlchemy() @@ -93,3 +97,99 @@ def test_with_models(self): assert len(datasource._collections) == 4 assert datasource.get_collection("address").datasource == datasource + + +class TestSQLAlchemyDatasourceConnectionQueryCreation(TestCase): + def test_should_not_create_native_query_connection_if_no_params(self): + ds = SqlAlchemyDatasource(models.Base) + self.assertEqual(ds.get_native_query_connections(), []) + + def test_should_create_native_query_connection_to_default_if_string_is_set(self): + ds = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy") + self.assertEqual(ds.get_native_query_connections(), ["sqlalchemy"]) + + +class TestSQLAlchemyDatasourceNativeQueryExecution(TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.loop = asyncio.new_event_loop() + cls.sql_alchemy_base = models.get_models_base("test_datasource_native_query") + if os.path.exists(cls.sql_alchemy_base.metadata.file_path): + os.remove(cls.sql_alchemy_base.metadata.file_path) + 
models.create_test_database(cls.sql_alchemy_base) + models.load_fixtures(cls.sql_alchemy_base) + cls.sql_alchemy_datasource = SqlAlchemyDatasource(models.Base, live_query_connection="sqlalchemy") + + @classmethod + def tearDownClass(cls): + os.remove(cls.sql_alchemy_base.metadata.file_path) + cls.loop.close() + + def test_should_raise_if_connection_is_not_known_by_datasource(self): + self.assertRaisesRegex( + NativeQueryException, + r"The native query connection 'foo' doesn't belongs to this datasource.", + self.loop.run_until_complete, + self.sql_alchemy_datasource.execute_native_query("foo", "select * from blabla", {}), + ) + + def test_should_correctly_execute_query(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", "select * from customer where id <= 2 order by id;", {} + ) + ) + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + {"id": 2, "first_name": "Thomas", "last_name": "Odom", "age": 92}, + ], + ) + + def test_should_correctly_execute_query_with_formatting(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", + """select * + from customer + where first_name = %(first_name)s + and last_name = %(last_name)s + order by id""", + {"first_name": "David", "last_name": "Myers"}, + ) + ) + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + ], + ) + + def test_should_correctly_execute_query_with_percent(self): + result = self.loop.run_until_complete( + self.sql_alchemy_datasource.execute_native_query( + "sqlalchemy", + """select * + from customer + where first_name like 'Dav\\%' + order by id""", + {}, + ) + ) + + self.assertEqual( + result, + [ + {"id": 1, "first_name": "David", "last_name": "Myers", "age": 112}, + ], + ) + + def test_should_correctly_raise_exception_during_sql_error(self): + self.assertRaisesRegex( + 
NativeQueryException, + r"no such table: blabla", + self.loop.run_until_complete, + self.sql_alchemy_datasource.execute_native_query("sqlalchemy", "select * from blabla", {}), + ) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py index 21781b7c6..166ec77f6 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/context/relaxed_wrappers/collection.py @@ -51,6 +51,22 @@ def get_collection(self, name: str) -> "RelaxedCollection": def add_collection(self, collection: "RelaxedCollection") -> None: raise RelaxedDatasourceException("Cannot modify existing datasources") + @property + def schema(self): + return self.datasource.schema + + @property + def name(self): + return self.datasource.name + + def get_native_query_connections(self): + raise NotImplementedError + + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[Dict[str, Any]]: + raise RelaxedDatasourceException("Cannot use this method. 
Please use 'collection.get_native_driver' instead.") + class RelaxedCollection(Collection): def __init__(self, collection: Collection): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py new file mode 100644 index 000000000..8ac4931e6 --- /dev/null +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_composite.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List + +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException, NativeQueryException +from forestadmin.datasource_toolkit.interfaces.chart import Chart +from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection, DatasourceSchema + + +class CompositeDatasource(Datasource): + def __init__(self) -> None: + super().__init__() + self._datasources: List[Datasource] = [] + + @property + def schema(self) -> DatasourceSchema: + charts = {} + for datasource in self._datasources: + charts.update(datasource.schema["charts"]) + + return {"charts": charts} + + def get_native_query_connections(self) -> List[str]: + native_queries = [] + for datasource in self._datasources: + native_queries.extend(datasource.get_native_query_connections()) + return native_queries + + @property + def collections(self) -> List[BoundCollection]: + ret = [] + for datasource in self._datasources: + ret.extend(datasource.collections) + return ret + + def get_collection(self, name: str) -> Any: + for datasource in self._datasources: + try: + return datasource.get_collection(name) + except Exception: + pass + + collection_names = [c.name for c in self.collections] + collection_names.sort() + raise DatasourceToolkitException( + f"Collection {name} not found. 
List of available collection: {', '.join(collection_names)}" + ) + + def add_datasource(self, datasource: Datasource): + existing_collection_names = [c.name for c in self.collections] + for collection in datasource.collections: + if collection.name in existing_collection_names: + raise DatasourceToolkitException(f"Collection '{collection.name}' already exists.") + + for connection in datasource.schema["charts"].keys(): + if connection in self.schema["charts"].keys(): + raise DatasourceToolkitException(f"Chart '{connection}' already exists.") + + existing_native_query_connection_names = self.get_native_query_connections() + for connection in datasource.get_native_query_connections(): + if connection in existing_native_query_connection_names: + raise DatasourceToolkitException(f"Native query connection '{connection}' already exists.") + + self._datasources.append(datasource) + + async def render_chart(self, caller: User, name: str) -> Chart: + for datasource in self._datasources: + if name in datasource.schema["charts"]: + return await datasource.render_chart(caller, name) + + raise DatasourceToolkitException(f"Chart {name} is not defined in the datasource.") + + async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: + for datasource in self._datasources: + if connection_name in datasource.get_native_query_connections(): + return await datasource.execute_native_query(connection_name, native_query, parameters) + + raise NativeQueryException(f"Native query connection '{connection_name}' is unknown") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py index 38deb3e44..2904eb167 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py +++ 
b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasource_customizer/datasource_customizer.py @@ -1,6 +1,7 @@ from typing import Dict, Optional from forestadmin.datasource_toolkit.datasource_customizer.collection_customizer import CollectionCustomizer +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource from forestadmin.datasource_toolkit.datasource_customizer.types import DataSourceOptions from forestadmin.datasource_toolkit.datasources import Datasource from forestadmin.datasource_toolkit.decorators.chart.types import DataSourceChartDefinition @@ -12,7 +13,7 @@ class DatasourceCustomizer: def __init__(self) -> None: - self.composite_datasource: Datasource = Datasource() + self.composite_datasource: CompositeDatasource = CompositeDatasource() self.stack = DecoratorStack(self.composite_datasource) @property @@ -52,8 +53,7 @@ async def _add_datasource(): rename_decorator.rename_collections(_options.get("rename", {})) datasource = rename_decorator - for collection in datasource.collections: - self.composite_datasource.add_collection(collection) + self.composite_datasource.add_datasource(datasource) self.stack.queue_customization(_add_datasource) return self diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py index 9623aa5f3..4ddc1b44e 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/datasources.py @@ -1,10 +1,11 @@ -from typing import Dict, List +from typing import Any, Dict, List, Optional from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException from forestadmin.datasource_toolkit.interfaces.chart import Chart from forestadmin.datasource_toolkit.interfaces.models.collections import BoundCollection from 
forestadmin.datasource_toolkit.interfaces.models.collections import Datasource as DatasourceInterface +from forestadmin.datasource_toolkit.interfaces.models.collections import DatasourceSchema class DatasourceException(DatasourceToolkitException): @@ -12,12 +13,19 @@ class DatasourceException(DatasourceToolkitException): class Datasource(DatasourceInterface[BoundCollection]): - def __init__(self) -> None: + def __init__(self, live_query_connections: Optional[List[str]] = None) -> None: self._collections: Dict[str, BoundCollection] = {} + self._live_query_connections = live_query_connections + self._schema: DatasourceSchema = { + "charts": {}, + } + + def get_native_query_connections(self) -> List[str]: + return self._live_query_connections or [] @property - def schema(self): - return {"charts": {}} + def schema(self) -> DatasourceSchema: + return self._schema @property def collections(self) -> List[BoundCollection]: @@ -40,3 +48,9 @@ def add_collection(self, collection: BoundCollection) -> None: async def render_chart(self, caller: User, name: str) -> Chart: raise DatasourceException(f"Chart {name} not exists on this datasource.") + + async def execute_native_query(self, connection_name: str, native_query: str, parameters: Dict[str, str]) -> Any: + # in native_query, there is the following syntax: + # - parameters to inject by 'execute' method are in the format '%(var)s' + # - '%' (in 'like' comparisons) are replaced by '\%' (to avoid conflict with previous rule) + raise NotImplementedError(f"'execute_native_query' is not implemented on {self.__class__.__name__}") diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py index 62450bfa9..64ebddda3 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py +++ 
b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/datasource_decorator.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Any, Dict, List, Union from forestadmin.agent_toolkit.utils.context import User from forestadmin.datasource_toolkit.collections import Collection @@ -11,7 +11,7 @@ class DatasourceDecorator(Datasource): def __init__( self, child_datasource: Union[Datasource, "DatasourceDecorator"], class_collection_decorator: type ) -> None: - super().__init__() + super().__init__(child_datasource.get_native_query_connections()) self.child_datasource = child_datasource self.class_collection_decorator = class_collection_decorator self._decorators: Dict[Collection, CollectionDecorator] = {} @@ -38,3 +38,8 @@ def get_charts(self): async def render_chart(self, caller: User, name: str) -> Chart: return await self.child_datasource.render_chart(caller, name) + + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[Dict[str, Any]]: + return await self.child_datasource.execute_native_query(connection_name, native_query, parameters) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py index 462226e21..1c91bf651 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/exceptions.py @@ -25,6 +25,18 @@ def __init__( super().__init__(message, *args) +class NativeQueryException(BusinessError): + def __init__( + self, + message: str = "", + headers: Optional[Dict[str, Any]] = None, + name: str = "NativeQueryError", + data: Optional[Dict[str, Any]] = {}, + *args: object, + ) -> None: + super().__init__(message, headers, name, data, *args) + + class ValidationError(BusinessError): def __init__( self, diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py 
b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py index f4170d17a..5d5f1b400 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/chart.py @@ -31,5 +31,3 @@ LeaderboardChart, SmartChart, ] - -a: TimeBasedChart = [{"label": "&", "values": {"value"}}] diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py index 818a19cc4..9e750a7fa 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/models/collections.py @@ -3,6 +3,7 @@ from forestadmin.datasource_toolkit.interfaces.actions import Action from forestadmin.datasource_toolkit.interfaces.fields import FieldAlias +from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias from typing_extensions import Self @@ -15,6 +16,10 @@ class CollectionSchema(TypedDict): charts: Dict[str, Callable] +class DatasourceSchema(TypedDict): + charts: Dict[str, Callable] + + class Collection(abc.ABC): @property @abc.abstractmethod @@ -46,6 +51,21 @@ def collections(self) -> List[BoundCollection]: def get_collection(self, name: str) -> BoundCollection: raise NotImplementedError + @abc.abstractmethod + def get_native_query_connections(self) -> List[str]: + raise NotImplementedError + @abc.abstractmethod def add_collection(self, collection: BoundCollection) -> None: raise NotImplementedError + + @property + @abc.abstractmethod + def schema(self) -> DatasourceSchema: + raise NotImplementedError + + @abc.abstractmethod + async def execute_native_query( + self, connection_name: str, native_query: str, parameters: Dict[str, str] + ) -> List[RecordsDataAlias]: + raise NotImplementedError diff --git 
a/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py new file mode 100644 index 000000000..dae0e2c3f --- /dev/null +++ b/src/datasource_toolkit/tests/datasource_customizer/test_composite_datasource.py @@ -0,0 +1,226 @@ +import asyncio +import sys +from unittest import TestCase +from unittest.mock import AsyncMock, Mock, PropertyMock, patch + +if sys.version_info >= (3, 9): + import zoneinfo +else: + from backports import zoneinfo + +from forestadmin.agent_toolkit.utils.context import User +from forestadmin.datasource_toolkit.collections import Collection +from forestadmin.datasource_toolkit.datasource_customizer.datasource_composite import CompositeDatasource +from forestadmin.datasource_toolkit.datasources import Datasource +from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException, NativeQueryException +from forestadmin.datasource_toolkit.interfaces.fields import Column, FieldType, PrimitiveType + + +class BaseTestCompositeDatasource(TestCase): + @classmethod + def setUpClass(cls) -> None: + cls.loop = asyncio.new_event_loop() + Collection.__abstractmethods__ = set() # to instantiate abstract class + + cls.mocked_caller = User( + rendering_id=1, + user_id=1, + tags={}, + email="dummy@user.fr", + first_name="dummy", + last_name="user", + team="operational", + timezone=zoneinfo.ZoneInfo("Europe/Paris"), + request={"ip": "127.0.0.1"}, + ) + + def setUp(self) -> None: + self.composite_ds = CompositeDatasource() + self.ds1_charts = {"charts": {"chart1": Mock()}} + self.ds2_charts = {"charts": {"chart2": Mock()}} + + self.ds1_connections = ["db1", "db2"] + self.ds2_connections = ["db3"] + + DS1 = type( + "DS1", + (Datasource,), + { + "schema": PropertyMock(return_value=self.ds1_charts), + "get_native_query_connections": Mock(return_value=self.ds1_connections), + }, + ) + DS2 = type( + "DS2", + (Datasource,), + { + "schema": 
PropertyMock(return_value=self.ds2_charts), + "get_native_query_connections": Mock(return_value=self.ds2_connections), + }, + ) + + self.datasource_1: Datasource = DS1() + self.collection_person = Collection("Person", self.datasource_1) + self.collection_person.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "first_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + "last_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + } + ) + self.datasource_1.add_collection(self.collection_person) + + self.datasource_2: Datasource = DS2() + self.collection_order = Collection("Order", self.datasource_2) + self.collection_order.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "customer_id": Column(column_type=PrimitiveType.NUMBER, type=FieldType.COLUMN), + "price": Column(column_type=PrimitiveType.NUMBER, type=FieldType.COLUMN), + } + ) + + self.datasource_2.add_collection(self.collection_order) + + +class TestCompositeDatasource(BaseTestCompositeDatasource): + def setUp(self) -> None: + super().setUp() + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_multiple_collection_with_same_name(self): + collection_person = Collection("Person", self.datasource_2) + collection_person.add_fields( + { + "id": Column(column_type=PrimitiveType.NUMBER, is_primary_key=True, type=FieldType.COLUMN), + "first_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + "last_name": Column(column_type=PrimitiveType.STRING, type=FieldType.COLUMN), + } + ) + self.datasource_2.add_collection(collection_person) + + self.assertRaisesRegex( + DatasourceToolkitException, + r"Collection 'Person' already exists\.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_collection_should_return_collection_of_all_datasources(self): + 
self.composite_ds.add_datasource(self.datasource_2) + + self.assertEqual(len(self.composite_ds.collections), 2) + self.assertIn("Person", [c.name for c in self.composite_ds.collections]) + self.assertIn("Order", [c.name for c in self.composite_ds.collections]) + + def test_get_collection_should_search_in_all_datasources(self): + self.composite_ds.add_datasource(self.datasource_2) + + collection = self.composite_ds.get_collection("Person") + self.assertEqual(collection.name, "Person") + self.assertEqual(collection.datasource, self.datasource_1) + + collection = self.composite_ds.get_collection("Order") + self.assertEqual(collection.name, "Order") + self.assertEqual(collection.datasource, self.datasource_2) + + def test_get_collection_should_list_collection_names_if_collection_not_found(self): + self.composite_ds.add_datasource(self.datasource_2) + + self.assertRaisesRegex( + DatasourceToolkitException, + "Collection Unknown not found. List of available collection: Order, Person", + self.composite_ds.get_collection, + "Unknown", + ) + + +class TestCompositeDatasourceCharts(BaseTestCompositeDatasource): + def setUp(self) -> None: + self.ds1_charts = {"charts": {"chart1": Mock()}} + self.ds2_charts = {"charts": {"chart2": Mock()}} + super().setUp() + + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_duplicated_chart(self): + self.ds1_charts["charts"]["chart2"] = Mock() + + self.assertRaisesRegex( + DatasourceToolkitException, + r"Chart 'chart2' already exists.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_schema_should_contains_all_charts(self): + self.composite_ds.add_datasource(self.datasource_2) + self.assertIn("chart1", self.composite_ds.schema["charts"]) + self.assertIn("chart2", self.composite_ds.schema["charts"]) + + def test_render_chart_should_raise_if_chart_is_unknown(self): + self.composite_ds.add_datasource(self.datasource_2) + self.assertRaisesRegex( + DatasourceToolkitException, 
+ "Chart unknown is not defined in the datasource.", + self.loop.run_until_complete, + self.composite_ds.render_chart(self.mocked_caller, "unknown"), + ) + + def test_render_chart_should_call_render_chart_on_good_datasource(self): + self.composite_ds.add_datasource(self.datasource_2) + + with patch.object(self.datasource_1, "render_chart", new_callable=AsyncMock) as mock_render_chart: + self.loop.run_until_complete(self.composite_ds.render_chart(self.mocked_caller, "chart1")) + mock_render_chart.assert_awaited_with(self.mocked_caller, "chart1") + + with patch.object(self.datasource_2, "render_chart", new_callable=AsyncMock) as mock_render_chart: + self.loop.run_until_complete(self.composite_ds.render_chart(self.mocked_caller, "chart2")) + mock_render_chart.assert_awaited_with(self.mocked_caller, "chart2") + + +class TestCompositeDatasourceNativeQuery(BaseTestCompositeDatasource): + + def setUp(self) -> None: + super().setUp() + self.composite_ds.add_datasource(self.datasource_1) + + def test_add_datasource_should_raise_if_duplicated_live_query_connection(self): + with patch.object(self.datasource_2, "get_native_query_connections", return_value=["db1"]): + self.assertRaisesRegex( + DatasourceToolkitException, + r"Native query connection 'db1' already exists.", + self.composite_ds.add_datasource, + self.datasource_2, + ) + + def test_get_native_query_connection_should_return_all_connections(self): + self.composite_ds.add_datasource(self.datasource_2) + connections = self.composite_ds.get_native_query_connections() + self.assertIn("db1", connections) + self.assertIn("db2", connections) + self.assertIn("db3", connections) + + def test_execute_native_query_should_raise_if_connection_is_unknown(self): + self.composite_ds.add_datasource(self.datasource_2) + + self.assertRaisesRegex( + NativeQueryException, + r"Native query connection 'bla' is unknown", + self.loop.run_until_complete, + self.composite_ds.execute_native_query("bla", "select * from ...", {}), + ) + + def 
test_execute_native_query_should_call_to_correct_datasource(self): + self.composite_ds.add_datasource(self.datasource_2) + + with patch.object(self.datasource_1, "execute_native_query", new_callable=AsyncMock) as mock_ds1_exec: + self.loop.run_until_complete(self.composite_ds.execute_native_query("db1", "select * from ...", {})) + mock_ds1_exec.assert_any_await("db1", "select * from ...", {}) + + self.loop.run_until_complete(self.composite_ds.execute_native_query("db2", "select * from ...", {})) + mock_ds1_exec.assert_any_await("db2", "select * from ...", {}) + + with patch.object(self.datasource_2, "execute_native_query", new_callable=AsyncMock) as mock_ds2_exec: + self.loop.run_until_complete(self.composite_ds.execute_native_query("db3", "select * from ...", {})) + mock_ds2_exec.assert_any_await("db3", "select * from ...", {}) diff --git a/src/django_agent/forestadmin/django_agent/urls.py b/src/django_agent/forestadmin/django_agent/urls.py index 60e9098e7..da2765286 100644 --- a/src/django_agent/forestadmin/django_agent/urls.py +++ b/src/django_agent/forestadmin/django_agent/urls.py @@ -1,7 +1,7 @@ from django.conf import settings from django.urls import path -from .views import actions, authentication, capabilities, charts, crud, crud_related, index, stats +from .views import actions, authentication, capabilities, charts, crud, crud_related, index, native_query, stats app_name = "django_agent" @@ -16,6 +16,7 @@ # generic path(f"{prefix}forest/", index.index, name="index"), path(f"{prefix}forest/_internal/capabilities", capabilities.capabilities, name="capabilities"), + path(f"{prefix}forest/_internal/native_query", native_query.native_query, name="capabilities"), path(f"{prefix}forest/scope-cache-invalidation", index.scope_cache_invalidation, name="scope_invalidation"), # authentication path(f"{prefix}forest/authentication", authentication.authentication, name="authentication"), diff --git a/src/django_agent/forestadmin/django_agent/views/index.py 
b/src/django_agent/forestadmin/django_agent/views/index.py index 7e888e703..59cd970b2 100644 --- a/src/django_agent/forestadmin/django_agent/views/index.py +++ b/src/django_agent/forestadmin/django_agent/views/index.py @@ -10,5 +10,9 @@ async def index(request: HttpRequest): @transaction.non_atomic_requests async def scope_cache_invalidation(request: HttpRequest): - DjangoAgentApp.get_agent()._permission_service.invalidate_cache("forest.scopes") + DjangoAgentApp.get_agent()._permission_service.invalidate_cache("forest.rendering") return HttpResponse(status=204) + + +# This is so ugly... But django.views.decorators.csrf.csrf_exempt is not asyncio ready +scope_cache_invalidation.csrf_exempt = True diff --git a/src/django_agent/forestadmin/django_agent/views/native_query.py b/src/django_agent/forestadmin/django_agent/views/native_query.py new file mode 100644 index 000000000..a49f95512 --- /dev/null +++ b/src/django_agent/forestadmin/django_agent/views/native_query.py @@ -0,0 +1,15 @@ +from asgiref.sync import async_to_sync +from django.http import HttpRequest +from forestadmin.django_agent.apps import DjangoAgentApp +from forestadmin.django_agent.utils.converter import convert_request, convert_response + + +@async_to_sync +async def native_query(request: HttpRequest, **kwargs): + resource = (await DjangoAgentApp.get_agent().get_resources())["native_query"] + response = await resource.dispatch(convert_request(request, kwargs), "native_query") + return convert_response(response) + + +# This is so ugly... 
But django.views.decorators.csrf.csrf_exempt is not asyncio ready +native_query.csrf_exempt = True diff --git a/src/django_agent/tests/test_http_routes.py b/src/django_agent/tests/test_http_routes.py index 5abb2f24a..7f6690d3f 100644 --- a/src/django_agent/tests/test_http_routes.py +++ b/src/django_agent/tests/test_http_routes.py @@ -18,6 +18,7 @@ def setUpClass(cls) -> None: cls.mocked_resources = {} for key in [ "capabilities", + "native_query", "authentication", "crud", "crud_related", @@ -94,12 +95,18 @@ def test_index(self): self.assertEqual(response.content, b"") def test_scope_cache_invalidation(self): - response = self.client.get( - f"/{self.conf_prefix}forest/scope-cache-invalidation", - HTTP_X_FORWARDED_FOR="179.114.131.49", - ) - self.assertEqual(response.status_code, 204) - self.assertEqual(response.content, b"") + with patch.object( + self.django_agent._permission_service, + "invalidate_cache", + spy=self.django_agent._permission_service.invalidate_cache, + ) as spy_invalidate: + response = self.client.get( + f"/{self.conf_prefix}forest/scope-cache-invalidation", + HTTP_X_FORWARDED_FOR="179.114.131.49", + ) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.content, b"") + spy_invalidate.assert_called_once_with("forest.rendering") class TestDjangoAgentAuthenticationRoutes(TestDjangoAgentRoutes): @@ -457,3 +464,45 @@ def test_stat_list(self): self.assertEqual(request_param.method, RequestMethod.POST) self.assertEqual(request_param.query["collection_name"], "customer") self.assertEqual(request_param.body, {"post_attr": "post_value"}) + + +class TestDjangoAgentNativeQueryRoutes(TestDjangoAgentRoutes): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.native_query_resource = cls.loop.run_until_complete(cls.django_agent.get_resources())["native_query"] + + def test_native_query(self): + response = self.client.post( + f"/{self.conf_prefix}forest/_internal/native_query?timezone=Europe%2FParis", + json.dumps( + { 
+ "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + } + ), + content_type="application/json", + ) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'{"mock": "ok"}') + self.native_query_resource.dispatch.assert_awaited() + call_args = self.native_query_resource.dispatch.await_args[0] + self.assertEqual(call_args[1], "native_query") + self.assertEqual( + call_args[0], + Request( + RequestMethod.POST, + body={ + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + query={"timezone": "Europe/Paris"}, + headers={"Cookie": "", "Content-Length": "146", "Content-Type": "application/json"}, + client_ip="127.0.0.1", + ), + ) diff --git a/src/flask_agent/forestadmin/flask_agent/agent.py b/src/flask_agent/forestadmin/flask_agent/agent.py index e5c267797..d3fbff658 100644 --- a/src/flask_agent/forestadmin/flask_agent/agent.py +++ b/src/flask_agent/forestadmin/flask_agent/agent.py @@ -16,6 +16,7 @@ from forestadmin.agent_toolkit.resources.base import BaseResource from forestadmin.agent_toolkit.resources.capabilities import LiteralMethod as CapabilitiesLiteralMethod from forestadmin.agent_toolkit.resources.collections.crud import LiteralMethod as CrudLiteralMethod +from forestadmin.agent_toolkit.resources.collections.native_query import LiteralMethod as NativeQueryLiteralMethod from forestadmin.agent_toolkit.resources.security.resources import LiteralMethod as AuthLiteralMethod from forestadmin.agent_toolkit.utils.context import Request from forestadmin.agent_toolkit.utils.forest_schema.type import AgentMeta @@ -116,7 +117,13 @@ async def _get_collection_response( request: FlaskRequest, resource: BaseResource, method: Optional[ - Union[AuthLiteralMethod, CrudLiteralMethod, ActionLiteralMethod, CapabilitiesLiteralMethod] + Union[ + 
AuthLiteralMethod, + CrudLiteralMethod, + ActionLiteralMethod, + CapabilitiesLiteralMethod, + NativeQueryLiteralMethod, + ] ] = None, detail: bool = False, ) -> FlaskResponse: @@ -133,6 +140,10 @@ async def index() -> FlaskResponse: # type: ignore async def capabilities() -> FlaskResponse: # type: ignore return await _get_collection_response(request, (await agent.get_resources())["capabilities"], "capabilities") + @blueprint.route("/_internal/native_query", methods=["POST"]) + async def native_query() -> FlaskResponse: # type: ignore + return await _get_collection_response(request, (await agent.get_resources())["native_query"], "native_query") + @blueprint.route("/authentication/callback", methods=["GET"]) async def callback() -> FlaskResponse: # type: ignore return await _get_collection_response(request, (await agent.get_resources())["authentication"], "callback") @@ -199,7 +210,7 @@ async def csv_related(**_) -> FlaskResponse: # type: ignore @blueprint.route("/scope-cache-invalidation", methods=["POST"]) async def scope_cache_invalidation(**_) -> FlaskResponse: # type: ignore - agent._permission_service.invalidate_cache("forest.scopes") + agent._permission_service.invalidate_cache("forest.rendering") rsp = FlaskResponse(status=204) return rsp diff --git a/src/flask_agent/tests/test_flask_agent.py b/src/flask_agent/tests/test_flask_agent.py index c119f82b7..ae8cd5048 100644 --- a/src/flask_agent/tests/test_flask_agent.py +++ b/src/flask_agent/tests/test_flask_agent.py @@ -69,6 +69,7 @@ def test_build_blueprint(self, mocked_blueprint, mock_base_agent_resources, mock calls = [ call("", methods=["GET"]), call("/_internal/capabilities", methods=["POST"]), + call("/_internal/native_query", methods=["POST"]), call("/authentication/callback", methods=["GET"]), call("/_actions////hooks/load", methods=["POST"]), call("/_actions////hooks/change", methods=["POST"]), diff --git a/src/flask_agent/tests/test_flask_agent_blueprint.py 
b/src/flask_agent/tests/test_flask_agent_blueprint.py index 7bbf1360a..f7f235e86 100644 --- a/src/flask_agent/tests/test_flask_agent_blueprint.py +++ b/src/flask_agent/tests/test_flask_agent_blueprint.py @@ -13,6 +13,7 @@ def setUpClass(cls) -> None: cls.loop = asyncio.new_event_loop() cls.mocked_resources = {} for key in [ + "native_query", "capabilities", "authentication", "crud", @@ -76,6 +77,44 @@ def test_capabilities(self): ), ) + def test_native_query(self): + response = self.client.post( + "/forest/_internal/native_query?timezone=Europe%2FParis", + json={ + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + ) + assert response.status_code == 200 + assert response.json == {"mock": "ok"} + self.mocked_resources["native_query"].dispatch.assert_awaited() + call_args = self.mocked_resources["native_query"].dispatch.await_args.args + self.assertEqual(call_args[1], "native_query") + headers = {**call_args[0].headers} + del headers["User-Agent"] + call_args[0].headers = headers + self.assertEqual( + call_args[0], + Request( + RequestMethod.POST, + body={ + "connectionName": "django", + "contextVariables": {}, + "query": "select status as key, sum(amount) as value from order group by key", + "type": "Pie", + }, + query={"timezone": "Europe/Paris"}, + headers={ + "Host": "localhost", + "Content-Type": "application/json", + "Content-Length": "146", + }, + client_ip="127.0.0.1", + ), + ) + def test_hook_load(self): response = self.client.post("/forest/_actions/customer/1/action_name/hooks/load") assert response.status_code == 200 @@ -240,6 +279,5 @@ def test_datasource_chart(self): def test_invalidate_cache(self): with patch.object(self.agent._permission_service, "invalidate_cache") as mocked_invalidate_cache: response = self.client.post("/forest/scope-cache-invalidation") - mocked_invalidate_cache.assert_called_with("forest.scopes") - assert response.status_code 
== 204 + mocked_invalidate_cache.assert_called_with("forest.rendering") assert response.status_code == 204