diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py index e4c1b9ee2..9a824a373 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/resources/collections/stats.py @@ -1,8 +1,7 @@ -from datetime import date, datetime +from datetime import date from typing import Any, Dict, List, Literal, Optional, Union, cast from uuid import uuid1 -import pandas as pd from forestadmin.agent_toolkit.forest_logger import ForestLogger from forestadmin.agent_toolkit.resources.collections.base_collection_resource import BaseCollectionResource from forestadmin.agent_toolkit.resources.collections.decorators import ( @@ -16,20 +15,21 @@ from forestadmin.agent_toolkit.resources.context_variable_injector_mixin import ContextVariableInjectorResourceMixin from forestadmin.agent_toolkit.utils.context import FileResponse, HttpResponseBuilder, Request, RequestMethod, Response from forestadmin.datasource_toolkit.exceptions import ForbiddenError, ForestException -from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation +from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation, DateOperation from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.factory import FilterFactory from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter +from forestadmin.datasource_toolkit.utils.date_utils import ( + DATE_OPERATION_STR_FORMAT_FN, + make_formatted_date_range, + parse_date, +) from forestadmin.datasource_toolkit.utils.schema import SchemaUtils class StatsResource(BaseCollectionResource, ContextVariableInjectorResourceMixin): - FREQUENCIES = {"Day": "d", "Week": "W-MON", "Month": "BMS", "Year": "BYS"} - - FORMAT = {"Day": "%d/%m/%Y", "Week": "W%V-%G", "Month": "%b %Y", "Year": "%Y"} - def stats_method(self, type: str): return { "Value": self.value, @@ -135,12 +135,13 @@ async def line(self, request: RequestCollection) -> Response: if key not in request.body: raise ForestException(f"The parameter {key} is not defined") + date_operation = DateOperation(request.body["timeRange"]) current_filter = await self._get_filter(request) aggregation = Aggregation( { "operation": request.body["aggregator"], "field": request.body.get("aggregateFieldName"), - "groups": [{"field": request.body["groupByFieldName"], "operation": request.body["timeRange"]}], + "groups": [{"field": request.body["groupByFieldName"], "operation": date_operation}], } ) rows = await request.collection.aggregate(request.user, current_filter, aggregation) @@ -149,34 +150,23 @@ async def line(self, request: RequestCollection) -> Response: for row in rows: label = row["group"][request.body["groupByFieldName"]] if label is not None: - if isinstance(label, str): - label = datetime.fromisoformat(label).date() - elif isinstance(label, datetime): - label = label.date() - elif isinstance(label, date): - pass - else: - ForestLogger.log( - "warning", - f"The time chart label type must be 'str' or 'date', not {type(label)}. Skipping this record.", - ) + label = parse_date(label) dates.append(label) - values_label[label.strftime(self.FORMAT[request.body["timeRange"]])] = row["value"] + values_label[DATE_OPERATION_STR_FORMAT_FN[date_operation](label)] = row["value"] dates.sort() end = dates[-1] start = dates[0] - data_points: List[Dict[str, Union[date, Dict[str, int]]]] = [] - for dt in pd.date_range( # type: ignore - start=start, end=end, freq=self.FREQUENCIES[request.body["timeRange"]] - ).to_pydatetime(): - label = dt.strftime(self.FORMAT[request.body["timeRange"]]) + data_points: List[Dict[str, Union[date, Dict[str, int], str]]] = [] + + for label in make_formatted_date_range(start, end, date_operation): data_points.append( { "label": label, "values": {"value": values_label.get(label, 0)}, } ) + return self._build_success_response(data_points) @check_method(RequestMethod.POST) diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/action_fields.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/action_fields.py index bef60b9d6..cb99eced3 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/action_fields.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/action_fields.py @@ -119,7 +119,9 @@ def is_checkbox_group_field( return field is not None and field.get("widget", "") == "CheckboxGroup" @staticmethod - def is_dropdown_field(field: ActionField) -> TypeGuard[ + def is_dropdown_field( + field: ActionField, + ) -> TypeGuard[ Union[ DropdownDynamicSearchFieldConfiguration[str], DropdownDynamicSearchFieldConfiguration[int], @@ -129,7 +131,9 @@ def is_dropdown_field(field: ActionField) -> TypeGuard[ return field is not None and field.get("widget", "") == "Dropdown" @staticmethod - def is_user_dropdown_field(field: ActionField) -> TypeGuard[ + def is_user_dropdown_field( + field: ActionField, + ) -> TypeGuard[ Union[ PlainStringListDynamicFieldUserDropdownFieldConfiguration, PlainStringDynamicFieldUserDropdownFieldConfiguration, diff --git a/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/filterable.py b/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/filterable.py index 5cf10fd37..76545f343 100644 --- a/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/filterable.py +++ b/src/agent_toolkit/forestadmin/agent_toolkit/utils/forest_schema/filterable.py @@ -4,7 +4,6 @@ class FrontendFilterableUtils: - @classmethod def is_filterable(cls, operators: Set[Operator]) -> bool: return operators is not None and len(operators) > 0 diff --git a/src/agent_toolkit/pyproject.toml b/src/agent_toolkit/pyproject.toml index 2c38eb7e3..5134d8dbd 100644 --- a/src/agent_toolkit/pyproject.toml +++ b/src/agent_toolkit/pyproject.toml @@ -23,21 +23,7 @@ pyjwt = "^2" cachetools = "~=5.2" sseclient-py = "^1.5" forestadmin-datasource-toolkit = "1.22.11" -[[tool.poetry.dependencies.pandas]] -version = ">=1.4.0" -python = "<3.13.0" -[[tool.poetry.dependencies.pandas]] -version = ">=2.2.3" -python = ">=3.13.0" - -[[tool.poetry.dependencies.numpy]] -python = ">=3.8.0,<3.12" -version = ">=1.24.0" - -[[tool.poetry.dependencies.numpy]] -python = ">=3.13" -version = ">=1.3.0" [tool.poetry.dependencies."backports.zoneinfo"] version = "~0.2.1" diff --git a/src/agent_toolkit/tests/resources/collections/test_stats_resources.py b/src/agent_toolkit/tests/resources/collections/test_stats_resources.py index d2e131bdb..190cdca3e 100644 --- a/src/agent_toolkit/tests/resources/collections/test_stats_resources.py +++ b/src/agent_toolkit/tests/resources/collections/test_stats_resources.py @@ -686,6 +686,54 @@ def test_line_should_return_chart_with_month_filter(self): {"label": "Feb 2022", "values": {"value": 15}}, ) + def test_line_should_return_chart_with_quarter_filter(self): + request = self.mk_request("Quarter") + with patch.object( + self.book_collection, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 10, "group": {"date": "2022-03-31 00:00:00"}}, + {"value": 20, "group": {"date": "2022-06-30 00:00:00"}}, + {"value": 30, "group": {"date": "2022-09-30 00:00:00"}}, + {"value": 40, "group": {"date": "2022-12-31 00:00:00"}}, + ], + ): + response = self.loop.run_until_complete(self.stat_resource.line(request)) + + content_body = json.loads(response.body) + self.assertEqual(response.status, 200) + self.assertEqual(content_body["data"]["type"], "stats") + self.assertEqual(len(content_body["data"]["attributes"]["value"]), 4) + self.assertEqual(content_body["data"]["attributes"]["value"][0], {"label": "Q1-2022", "values": {"value": 10}}) + self.assertEqual(content_body["data"]["attributes"]["value"][1], {"label": "Q2-2022", "values": {"value": 20}}) + self.assertEqual(content_body["data"]["attributes"]["value"][2], {"label": "Q3-2022", "values": {"value": 30}}) + self.assertEqual(content_body["data"]["attributes"]["value"][3], {"label": "Q4-2022", "values": {"value": 40}}) + + def test_line_should_return_chart_with_quarter_filter_should_also_work_with_date_as_quarter_start(self): + request = self.mk_request("Quarter") + with patch.object( + self.book_collection, + "aggregate", + new_callable=AsyncMock, + return_value=[ + {"value": 10, "group": {"date": "2022-01-01 00:00:00"}}, + {"value": 20, "group": {"date": "2022-04-01 00:00:00"}}, + {"value": 30, "group": {"date": "2022-07-01 00:00:00"}}, + {"value": 40, "group": {"date": "2022-10-01 00:00:00"}}, + ], + ): + response = self.loop.run_until_complete(self.stat_resource.line(request)) + + content_body = json.loads(response.body) + self.assertEqual(response.status, 200) + self.assertEqual(content_body["data"]["type"], "stats") + self.assertEqual(len(content_body["data"]["attributes"]["value"]), 4) + self.assertEqual(content_body["data"]["attributes"]["value"][0], {"label": "Q1-2022", "values": {"value": 10}}) + self.assertEqual(content_body["data"]["attributes"]["value"][1], {"label": "Q2-2022", "values": {"value": 20}}) + self.assertEqual(content_body["data"]["attributes"]["value"][2], {"label": "Q3-2022", "values": {"value": 30}}) + self.assertEqual(content_body["data"]["attributes"]["value"][3], {"label": "Q4-2022", "values": {"value": 40}}) + def test_line_should_return_chart_with_year_filter(self): request = self.mk_request("Year") with patch.object( diff --git a/src/datasource_django/forestadmin/datasource_django/utils/query_factory.py b/src/datasource_django/forestadmin/datasource_django/utils/query_factory.py index 6a6e3f2bf..0489c1aaa 100644 --- a/src/datasource_django/forestadmin/datasource_django/utils/query_factory.py +++ b/src/datasource_django/forestadmin/datasource_django/utils/query_factory.py @@ -2,6 +2,7 @@ from datetime import date, datetime from typing import Any, Dict, List, Optional, Set, Tuple +import pandas as pd from django.db import models from forestadmin.datasource_django.exception import DjangoDatasourceException from forestadmin.datasource_django.interface import BaseDjangoCollection @@ -311,6 +312,7 @@ class DjangoQueryGroupByHelper: DateOperation.DAY: "__day", DateOperation.WEEK: "__week", DateOperation.MONTH: "__month", + DateOperation.QUARTER: "__quarter", DateOperation.YEAR: "__year", } @@ -331,6 +333,11 @@ def get_operation_suffixes(cls, group: PlainAggregationGroup) -> List[str]: cls.DATE_OPERATION_SUFFIX_MAPPING[DateOperation.YEAR], cls.DATE_OPERATION_SUFFIX_MAPPING[DateOperation.WEEK], ] + if group["operation"] == DateOperation.QUARTER: + return [ + cls.DATE_OPERATION_SUFFIX_MAPPING[DateOperation.YEAR], + cls.DATE_OPERATION_SUFFIX_MAPPING[DateOperation.QUARTER], + ] if group["operation"] == DateOperation.DAY: return [ cls.DATE_OPERATION_SUFFIX_MAPPING[DateOperation.YEAR], @@ -380,5 +387,11 @@ def _make_date_from_record(cls, row: AggregateResult, date_field: str, date_oper row_date = datetime.strptime(str_year_week + "-1", "%Y-W%W-%w") return row_date.date() + if date_operation == DateOperation.QUARTER: + end_of_quarter_date = ( + pd.Timestamp(row[f"{date_field}__year"], (row[f"{date_field}__quarter"] * 3), 1) + pd.offsets.MonthEnd() + ) + return end_of_quarter_date.date() + if date_operation == DateOperation.DAY: return date(row[f"{date_field}__year"], row[f"{date_field}__month"], row[f"{date_field}__day"]) diff --git a/src/datasource_django/pyproject.toml b/src/datasource_django/pyproject.toml index b74d7859e..0091a8c2e 100644 --- a/src/datasource_django/pyproject.toml +++ b/src/datasource_django/pyproject.toml @@ -20,6 +20,13 @@ typing-extensions = "~=4.2" django = ">=3.2,<5.2" forestadmin-datasource-toolkit = "1.22.11" forestadmin-agent-toolkit = "1.22.11" +[[tool.poetry.dependencies.pandas]] +version = ">=1.4.0" +python = "<3.13.0" + +[[tool.poetry.dependencies.pandas]] +version = ">=2.2.3" +python = ">=3.13.0" [tool.pytest.ini_options] DJANGO_SETTINGS_MODULE = "test_project_datasource.settings" diff --git a/src/datasource_django/tests/test_django_collection.py b/src/datasource_django/tests/test_django_collection.py index 26c979494..ed8298843 100644 --- a/src/datasource_django/tests/test_django_collection.py +++ b/src/datasource_django/tests/test_django_collection.py @@ -512,6 +512,21 @@ async def test_should_work_by_year(self): ], ) + async def test_should_work_by_quarter(self): + ret = await self.rating_collection.aggregate( + self.mocked_caller, + Filter({}), + Aggregation( + { + "operation": "Sum", + "field": "rating", + "groups": [{"field": "rated_at", "operation": DateOperation.QUARTER}], + } + ), + ) + self.assertIn({"value": 16, "group": {"rated_at": datetime.date(2023, 3, 31)}}, ret) + self.assertIn({"value": 1, "group": {"rated_at": datetime.date(2022, 12, 31)}}, ret) + async def test_should_work_by_month(self): ret = await self.rating_collection.aggregate( self.mocked_caller, diff --git a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/utils/aggregation.py b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/utils/aggregation.py index be77891af..da1fd7969 100644 --- a/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/utils/aggregation.py +++ b/src/datasource_sqlalchemy/forestadmin/datasource_sqlalchemy/utils/aggregation.py @@ -6,9 +6,9 @@ from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation, Aggregator, DateOperation from forestadmin.datasource_toolkit.interfaces.query.projections import Projection -from sqlalchemy import DATE, cast +from sqlalchemy import DATE, Integer, cast from sqlalchemy import column as SqlAlchemyColumn -from sqlalchemy import func, text +from sqlalchemy import extract, func, text from sqlalchemy.engine import Dialect @@ -82,12 +82,22 @@ def build_group( class DateAggregation: @staticmethod def build_postgres(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchemyColumn: + return func.date_trunc(operation.value.lower(), column) @staticmethod - def build_sqllite(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchemyColumn: + def build_sqlite(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchemyColumn: if operation == DateOperation.WEEK: return func.DATE(column, "weekday 1", "-7 days") + elif operation == DateOperation.QUARTER: + return func.date( + func.strftime("%Y", column) + + "-" + + func.printf("%02d", (func.floor((func.cast(func.strftime("%m", column), Integer) - 1) / 3) + 1) * 3) + + "-01", + "+1 month", + "-1 day", + ) elif operation == DateOperation.YEAR: format = "%Y-01-01" elif operation == DateOperation.MONTH: @@ -107,6 +117,15 @@ def build_mysql(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchem format = "%Y-%m-01" elif operation == DateOperation.WEEK: return cast(func.date_sub(column, text(f"INTERVAL(WEEKDAY({column})) DAY")), DATE) + elif operation == DateOperation.QUARTER: + return func.last_day( + func.str_to_date( + func.concat( + func.year(column), "-", func.lpad(func.ceiling(extract("month", column) / 3) * 3, 2, "0"), "-01" + ), + "%Y-%m-%d", + ) + ) elif operation == DateOperation.DAY: format = "%Y-%m-%d" else: @@ -121,6 +140,14 @@ def build_mssql(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchem return func.datefromparts(func.extract("year", column), func.extract("month", column), "01") elif operation == DateOperation.WEEK: return cast(func.dateadd(text("day"), -func.extract("dw", column) + 2, column), DATE) + elif operation == DateOperation.QUARTER: + return func.eomonth( + func.datefromparts( + func.extract("YEAR", column), + func.datepart(text("QUARTER"), column) * text("3"), + text("1"), + ) + ) elif operation == DateOperation.DAY: return func.datefromparts( func.extract("year", column), @@ -131,7 +158,7 @@ def build_mssql(column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchem @classmethod def build(cls, dialect: Dialect, column: SqlAlchemyColumn, operation: DateOperation) -> SqlAlchemyColumn: if dialect.name == "sqlite": - return cls.build_sqllite(column, operation) + return cls.build_sqlite(column, operation) elif dialect.name in ["mysql", "mariadb"]: return cls.build_mysql(column, operation) elif dialect.name == "postgresql": diff --git a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py index 7397a5455..b7040108c 100644 --- a/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py +++ b/src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py @@ -5,8 +5,6 @@ from unittest import TestCase from unittest.mock import Mock, patch -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch - if sys.version_info >= (3, 9): import zoneinfo else: @@ -18,6 +16,7 @@ from forestadmin.datasource_sqlalchemy.exceptions import SqlAlchemyCollectionException from forestadmin.datasource_toolkit.interfaces.fields import Operator from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation +from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter from forestadmin.datasource_toolkit.interfaces.query.projections import Projection @@ -141,7 +140,7 @@ def test__normalize_projection(self, mocked_sqlalchemy_collection_factory, mocke assert projection == Projection("city", "customers:first_name") -class TestSqlAlchemyCollectionWithModels(TestCase): +class BaseTestSqlAlchemyCollectionWithModels(TestCase): @classmethod def setUpClass(cls): cls.loop = asyncio.new_event_loop() @@ -168,6 +167,8 @@ def tearDownClass(cls): os.remove(cls.sql_alchemy_base.metadata.file_path) cls.loop.close() + +class TestSqlAlchemyCollectionWithModels(BaseTestSqlAlchemyCollectionWithModels): def test_get_columns(self): collection = self.datasource.get_collection("order") columns, relationships = collection.get_columns(Projection("amount", "status", "customer:first_name")) @@ -254,6 +255,102 @@ def test_list_filter_relation(self): self.assertEqual(len(results), 1) self.assertEqual(results[0]["customer"]["id"], 1) + def test_list_with_filter_with_aggregator(self): + collection = self.datasource.get_collection("order") + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ConditionTreeLeaf("id", Operator.LESS_THAN, 6), ConditionTreeLeaf("id", Operator.GREATER_THAN, 1)], + ), + } + ) + + results = self.loop.run_until_complete( + collection.list(self.mocked_caller, filter_, Projection("id", "created_at")) + ) + self.assertEqual(len(results), 4) + + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "or", [ConditionTreeLeaf("id", "equal", 6), ConditionTreeLeaf("id", "equal", 1)] + ), + } + ) + results = self.loop.run_until_complete( + collection.list(self.mocked_caller, filter_, Projection("id", "created_at")) + ) + self.assertEqual(len(results), 2) + + def test_list_should_handle_sort(self): + collection = self.datasource.get_collection("order") + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ConditionTreeLeaf("id", Operator.LESS_THAN, 6), ConditionTreeLeaf("id", Operator.GREATER_THAN, 1)], + ), + "sort": [{"field": "id", "ascending": True}], + } + ) + + results = self.loop.run_until_complete( + collection.list(self.mocked_caller, filter_, Projection("id", "created_at")) + ) + self.assertEqual(len(results), 4) + for i in range(2, 6): + self.assertEqual(results[i - 2]["id"], i) + + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ConditionTreeLeaf("id", Operator.LESS_THAN, 6), ConditionTreeLeaf("id", Operator.GREATER_THAN, 1)], + ), + "sort": [{"field": "id", "ascending": False}], + } + ) + + results = self.loop.run_until_complete( + collection.list(self.mocked_caller, filter_, Projection("id", "created_at")) + ) + self.assertEqual(len(results), 4) + self.assertEqual(results[0]["id"], 5) + self.assertEqual(results[1]["id"], 4) + self.assertEqual(results[2]["id"], 3) + self.assertEqual(results[3]["id"], 2) + + def test_list_should_handle_multiple_sort(self): + collection = self.datasource.get_collection("order") + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ConditionTreeLeaf("id", Operator.LESS_THAN, 6), ConditionTreeLeaf("id", Operator.GREATER_THAN, 1)], + ), + "sort": [ + {"field": "status", "ascending": True}, + {"field": "customer:id", "ascending": True}, + {"field": "amount", "ascending": False}, + ], + } + ) + + results = self.loop.run_until_complete( + collection.list(self.mocked_caller, filter_, Projection("id", "customer:id", "status", "amount")) + ) + self.assertEqual(len(results), 4) + self.assertEqual( + results, + [ + {"id": 5, "status": models.ORDER_STATUS.DELIVERED, "amount": 9526, "customer": {"id": 8}}, + {"id": 3, "status": models.ORDER_STATUS.DELIVERED, "amount": 5285, "customer": {"id": 9}}, + {"id": 4, "status": models.ORDER_STATUS.DELIVERED, "amount": 4684, "customer": {"id": 9}}, + {"id": 2, "status": models.ORDER_STATUS.DELIVERED, "amount": 2664, "customer": {"id": 10}}, + ], + ) + def test_create(self): order = { "id": 11, @@ -345,6 +442,164 @@ def test_aggregate(self): assert [*filter(lambda item: item["group"]["customer_id"] == 9, results)][0]["value"] == 4984.5 assert [*filter(lambda item: item["group"]["customer_id"] == 10, results)][0]["value"] == 3408.5 + def test_aggregate_by_date_year(self): + filter_ = PaginatedFilter({"condition_tree": ConditionTreeLeaf("id", Operator.LESS_THAN, 11)}) + collection = self.datasource.get_collection("order") + + results = self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Year"}], + } + ), + ) + ) + + self.assertEqual(len(results), 3) + self.assertIn({"value": 5881.666666666667, "group": {"created_at": "2022-01-01"}}, results) + self.assertIn({"value": 5278.5, "group": {"created_at": "2023-01-01"}}, results) + self.assertIn({"value": 4433.8, "group": {"created_at": "2021-01-01"}}, results) + + def test_aggregate_by_date_quarter(self): + filter_ = PaginatedFilter({"condition_tree": ConditionTreeLeaf("id", Operator.LESS_THAN, 11)}) + collection = self.datasource.get_collection("order") + + results = self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Quarter"}], + } + ), + ) + ) + + self.assertEqual(len(results), 7) + self.assertIn({"value": 9744.0, "group": {"created_at": "2021-09-30"}}, results) + self.assertIn({"value": 7676.0, "group": {"created_at": "2022-09-30"}}, results) + self.assertIn({"value": 5285.0, "group": {"created_at": "2022-06-30"}}, results) + self.assertIn({"value": 5278.5, "group": {"created_at": "2023-03-31"}}, results) + self.assertIn({"value": 4753.5, "group": {"created_at": "2021-03-31"}}, results) + self.assertIn({"value": 4684.0, "group": {"created_at": "2022-12-31"}}, results) + self.assertIn({"value": 1459.0, "group": {"created_at": "2021-06-30"}}, results) + + def test_aggregate_by_date_month(self): + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ + ConditionTreeLeaf("id", Operator.LESS_THAN, 11), + ConditionTreeLeaf("id", Operator.GREATER_THAN, 4), + ], + ) + } + ) + collection = self.datasource.get_collection("order") + + results = self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Month"}], + } + ), + ) + ) + + self.assertEqual(len(results), 6) + self.assertIn({"value": 9744.0, "group": {"created_at": "2021-07-01"}}, results) + self.assertIn({"value": 9526.0, "group": {"created_at": "2023-02-01"}}, results) + self.assertIn({"value": 7676.0, "group": {"created_at": "2022-08-01"}}, results) + self.assertIn({"value": 5354.0, "group": {"created_at": "2021-01-01"}}, results) + self.assertIn({"value": 4153.0, "group": {"created_at": "2021-03-01"}}, results) + self.assertIn({"value": 254.0, "group": {"created_at": "2021-05-01"}}, results) + + def test_aggregate_by_date_week(self): + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ + ConditionTreeLeaf("id", Operator.LESS_THAN, 11), + ConditionTreeLeaf("id", Operator.GREATER_THAN, 4), + ], + ) + } + ) + collection = self.datasource.get_collection("order") + + results = self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Week"}], + } + ), + ) + ) + + self.assertEqual(len(results), 6) + self.assertIn({"value": 9744.0, "group": {"created_at": "2021-06-28"}}, results) + self.assertIn({"value": 9526.0, "group": {"created_at": "2023-02-20"}}, results) + self.assertIn({"value": 7676.0, "group": {"created_at": "2022-08-01"}}, results) + self.assertIn({"value": 5354.0, "group": {"created_at": "2021-01-11"}}, results) + self.assertIn({"value": 4153.0, "group": {"created_at": "2021-03-08"}}, results) + self.assertIn({"value": 254.0, "group": {"created_at": "2021-05-24"}}, results) + + def test_aggregate_by_date_day(self): + filter_ = PaginatedFilter( + { + "condition_tree": ConditionTreeBranch( + "and", + [ + ConditionTreeLeaf("id", Operator.LESS_THAN, 11), + ConditionTreeLeaf("id", Operator.GREATER_THAN, 4), + ], + ) + } + ) + collection = self.datasource.get_collection("order") + + results = self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Day"}], + } + ), + ) + ) + + self.assertEqual(len(results), 6) + self.assertIn({"value": 9744.0, "group": {"created_at": "2021-07-05"}}, results) + self.assertIn({"value": 9526.0, "group": {"created_at": "2023-02-27"}}, results) + self.assertIn({"value": 7676.0, "group": {"created_at": "2022-08-07"}}, results) + self.assertIn({"value": 5354.0, "group": {"created_at": "2021-01-13"}}, results) + self.assertIn({"value": 4153.0, "group": {"created_at": "2021-03-13"}}, results) + self.assertIn({"value": 254.0, "group": {"created_at": "2021-05-30"}}, results) + def test_get_native_driver_should_return_connection(self): with self.datasource.get_collection("order").get_native_driver() as connection: self.assertIsInstance(connection, Session) @@ -362,6 +617,781 @@ def test_get_native_driver_should_work_without_declaring_request_as_text(self): self.assertEqual(rows, [(3, 5285)]) +class TestSQLAlchemyOnSQLite(BaseTestSqlAlchemyCollectionWithModels): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dialect = "sqlite" + + def test_can_aggregate_date_by_year(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Year"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'strftime(:strftime_1, "order".created_at) AS created_at__grouped__ \n' + 'FROM "order" GROUP BY strftime(:strftime_1, "order".created_at) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-01-01"], + ) + + def test_can_aggregate_date_by_quarter(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Quarter"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date(strftime(:strftime_1, "order".created_at) || ' + ':strftime_2 || printf(:printf_1, (floor((CAST(strftime(:strftime_3, "order".created_at) AS INTEGER) ' + "- :param_1) / CAST(:param_2 AS NUMERIC)) + :floor_1) * :param_3) || :param_4, :date_1, :date_2) " + "AS created_at__grouped__ \n" + 'FROM "order" GROUP BY date(strftime(:strftime_1, "order".created_at) || :strftime_2 || ' + 'printf(:printf_1, (floor((CAST(strftime(:strftime_3, "order".created_at) AS INTEGER) - :param_1) / ' + "CAST(:param_2 AS NUMERIC)) + :floor_1) * :param_3) || :param_4, :date_1, :date_2) " + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y", "-", "%02d", "%m", 1, 3, 1, 3, "-01", "+1 month", "-1 day"], + ) + + def test_can_aggregate_date_by_month(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Month"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, strftime(:strftime_1, "order".created_at) ' + "AS created_at__grouped__ \n" + 'FROM "order" GROUP BY strftime(:strftime_1, "order".created_at) ORDER BY __aggregate__ DESC', + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-%m-01"], + ) + + def test_can_aggregate_date_by_week(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Week"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, DATE("order".created_at, :DATE_1, :DATE_2) AS ' + 'created_at__grouped__ \nFROM "order" ' + 'GROUP BY DATE("order".created_at, :DATE_1, :DATE_2) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["weekday 1", "-7 days"], + ) + + def test_can_aggregate_date_by_day(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Day"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, strftime(:strftime_1, "order".created_at) ' + "AS created_at__grouped__ \n" + 'FROM "order" GROUP BY strftime(:strftime_1, "order".created_at) ORDER BY __aggregate__ DESC', + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-%m-%d"], + ) + + +class TestSQLAlchemyOnPostgres(BaseTestSqlAlchemyCollectionWithModels): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dialect = "postgresql" + + def test_can_aggregate_date_by_year(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Year"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_trunc(:date_trunc_1, "order".created_at) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_trunc(:date_trunc_1, "order".created_at) ORDER BY __aggregate__ DESC NULLS LAST', + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["year"], + ) + + def test_can_aggregate_date_by_quarter(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Quarter"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_trunc(:date_trunc_1, "order".created_at) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_trunc(:date_trunc_1, "order".created_at) ORDER BY __aggregate__ DESC NULLS LAST', + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["quarter"], + ) + + def test_can_aggregate_date_by_month(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Month"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_trunc(:date_trunc_1, "order".created_at) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_trunc(:date_trunc_1, "order".created_at) ORDER BY __aggregate__ DESC NULLS LAST', + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["month"], + ) + + def test_can_aggregate_date_by_week(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Week"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_trunc(:date_trunc_1, "order".created_at) AS ' + "created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_trunc(:date_trunc_1, "order".created_at) ' + "ORDER BY __aggregate__ DESC NULLS LAST", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["week"], + ) + + def test_can_aggregate_date_by_day(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Day"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_trunc(:date_trunc_1, "order".created_at) AS ' + "created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_trunc(:date_trunc_1, "order".created_at) ' + "ORDER BY __aggregate__ DESC NULLS LAST", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["day"], + ) + + +class TestSQLAlchemyOnMySQL(BaseTestSqlAlchemyCollectionWithModels): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dialect = "mysql" # same as 'mariadb' + + def test_can_aggregate_date_by_year(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Year"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_format("order".created_at, :date_format_1) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_format("order".created_at, :date_format_1) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-01-01"], + ) + + def test_can_aggregate_date_by_quarter(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Quarter"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, last_day(str_to_date(concat(year("order".created_at), ' + ':concat_1, lpad(ceiling(EXTRACT(month FROM "order".created_at) / CAST(:param_1 AS NUMERIC)' + ") * :ceiling_1, :lpad_1, :lpad_2), :concat_2), :str_to_date_1)) AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY last_day(str_to_date(concat(year("order".created_at), :concat_1, ' + 'lpad(ceiling(EXTRACT(month FROM "order".created_at) / CAST(:param_1 AS NUMERIC)' + ") * :ceiling_1, :lpad_1, :lpad_2), :concat_2), :str_to_date_1)) " + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["-", 3, 3, 2, "0", "-01", "%Y-%m-%d"], + ) + + def test_can_aggregate_date_by_month(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Month"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_format("order".created_at, :date_format_1) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_format("order".created_at, :date_format_1) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-%m-01"], + ) + + def test_can_aggregate_date_by_week(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Week"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, CAST(date_sub("order".created_at, ' + "INTERVAL(WEEKDAY(order.created_at)) DAY) AS DATE) AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY CAST(date_sub("order".created_at, INTERVAL(WEEKDAY(order.created_at)) DAY) AS DATE) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + [], + ) + + def test_can_aggregate_date_by_day(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Day"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, date_format("order".created_at, :date_format_1) ' + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY date_format("order".created_at, :date_format_1) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["%Y-%m-%d"], + ) + + +class TestSQLAlchemyOnMSSQL(BaseTestSqlAlchemyCollectionWithModels): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.dialect = "mssql" + + def test_can_aggregate_date_by_year(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Year"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'datefromparts(EXTRACT(year FROM "order".created_at), ' + ":datefromparts_1, :datefromparts_2) " + "AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY datefromparts(EXTRACT(year FROM "order".created_at), :datefromparts_1, :datefromparts_2) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["01", "01"], + ) + + def test_can_aggregate_date_by_quarter(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Quarter"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'eomonth(datefromparts(EXTRACT(YEAR FROM "order".created_at), ' + 'datepart(QUARTER, "order".created_at) * 3, 1)) AS created_at__grouped__ \n' + 'FROM "order" ' + 'GROUP BY eomonth(datefromparts(EXTRACT(YEAR FROM "order".created_at), ' + 'datepart(QUARTER, "order".created_at) * 3, 1)) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + [], + ) + + def test_can_aggregate_date_by_month(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Month"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'datefromparts(EXTRACT(year FROM "order".created_at), EXTRACT(month FROM "order".created_at), ' + ":datefromparts_1) AS created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY datefromparts(EXTRACT(year FROM "order".created_at), EXTRACT(month FROM "order".created_at), ' + ":datefromparts_1) " + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + ["01"], + ) + + def test_can_aggregate_date_by_week(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Week"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'CAST(dateadd(day, -EXTRACT(dw FROM "order".created_at) + :param_1, "order".created_at) AS DATE) AS ' + "created_at__grouped__ \n" + 'FROM "order" ' + 'GROUP BY CAST(dateadd(day, -EXTRACT(dw FROM "order".created_at) + :param_1, "order".created_at) ' + "AS DATE) " + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + [2], + ) + + def test_can_aggregate_date_by_day(self): + with patch.object(self.datasource.Session, "begin") as mock_begin: + mock_session = Mock() + mock_session.execute = Mock(return_value=[]) + mock_session.bind.dialect.name = self.dialect + mock_begin.return_value.__enter__.return_value = mock_session + + filter_ = PaginatedFilter({}) + collection = self.datasource.get_collection("order") + self.loop.run_until_complete( + collection.aggregate( + self.mocked_caller, + filter_, + Aggregation( + { + "operation": "Avg", + "field": "amount", + "groups": [{"field": "created_at", "operation": "Day"}], + } + ), + ) + ) + query = mock_session.execute.call_args.args[0] + sql_query = str(query) + self.assertEqual( + sql_query, + 'SELECT avg("order".amount) AS __aggregate__, ' + 'datefromparts(EXTRACT(year FROM "order".created_at), EXTRACT(month FROM "order".created_at), ' + 'EXTRACT(day FROM "order".created_at)) AS created_at__grouped__ \n' + 'FROM "order" ' + 'GROUP BY datefromparts(EXTRACT(year FROM "order".created_at), EXTRACT(month FROM "order".created_at), ' + 'EXTRACT(day FROM "order".created_at)) ' + "ORDER BY __aggregate__ DESC", + ) + self.assertEqual( + [p.value for p in query._get_embedded_bindparams()], + [], + ) + + class testSqlAlchemyCollectionFactory(TestCase): def test_create(self): mocked_collection = Mock() diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/result_builder.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/result_builder.py index 1d3545e28..9fb944d3a 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/result_builder.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/chart/result_builder.py @@ -1,8 +1,6 @@ -import enum from datetime import date, datetime from typing import Dict, List, Optional, TypedDict, Union -import pandas as pd from forestadmin.datasource_toolkit.interfaces.chart import ( DistributionChart, LeaderboardChart, @@ -14,52 +12,16 @@ ValueChart, ) from forestadmin.datasource_toolkit.interfaces.query.aggregation import DateOperation, DateOperationLiteral -from forestadmin.datasource_toolkit.interfaces.query.condition_tree.transforms.time import Frequency - - -class _DateRangeFrequency(enum.Enum): - Day: str = "days" - Week: str = "weeks" - Month: str = "months" - Year: str = "years" - +from forestadmin.datasource_toolkit.utils.date_utils import ( + DATE_OPERATION_STR_FORMAT_FN, + make_formatted_date_range, + parse_date, +) MultipleTimeBasedLines = List[TypedDict("Line", {"label": str, "values": List[Union[int, float, None]]})] -def _parse_date(date_input: Union[str, date, datetime]) -> date: - if isinstance(date_input, str): - return datetime.fromisoformat(date_input).date() - elif isinstance(date_input, datetime): - return date_input.date() - elif isinstance(date_input, date): - return date_input - - -def _make_formatted_date_range( - first: Union[date, datetime], last: Union[date, datetime], frequency: _DateRangeFrequency, format_: str -): - current = first - used = set() - while current <= last: - yield current.strftime(format_) - used.add(current.strftime(format_)) - current = (current + pd.DateOffset(**{frequency.value: 1})).date() - - if last.strftime(format_) not in used: - yield last.strftime(format_) - - class ResultBuilder: - FREQUENCIES = {"Day": Frequency.DAY, "Week": Frequency.WEEK, "Month": Frequency.MONTH, "Year": Frequency.YEAR} - - FORMATS: Dict[DateOperation, str] = { - DateOperation.DAY: "%d/%m/%Y", - DateOperation.WEEK: "W%V-%G", - DateOperation.MONTH: "%b %Y", - DateOperation.YEAR: "%Y", - } - @staticmethod def value(value: Union[int, float], previous_value: Optional[Union[int, float]] = None) -> ValueChart: return ValueChart(countCurrent=value, countPrevious=previous_value) @@ -181,12 +143,12 @@ def _build_time_base_chart_result( """ if len(points) == 0: return [] - points_in_date_time = [{"date": _parse_date(point["date"]), "value": point["value"]} for point in points] - format_ = ResultBuilder.FORMATS[DateOperation(time_range)] + points_in_date_time = [{"date": parse_date(point["date"]), "value": point["value"]} for point in points] + format_fn = DATE_OPERATION_STR_FORMAT_FN[DateOperation(time_range)] formatted = {} for point in points_in_date_time: - label = point["date"].strftime(format_) + label = format_fn(point["date"]) if point["value"] is not None: formatted[label] = formatted.get(label, 0) + point["value"] @@ -194,8 +156,6 @@ def _build_time_base_chart_result( dates = sorted([p["date"] for p in points_in_date_time]) first = dates[0] last = dates[-1] - for label in _make_formatted_date_range( - first, last, _DateRangeFrequency[DateOperation(time_range).value], format_ - ): + for label in make_formatted_date_range(first, last, DateOperation(time_range)): data_points.append({"label": label, "values": {"value": formatted.get(label, 0)}}) return data_points diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py index 9d6e7d632..80424a279 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/interfaces/query/aggregation.py @@ -30,12 +30,13 @@ class Aggregator(enum.Enum): class DateOperation(enum.Enum): YEAR = "Year" + QUARTER = "Quarter" MONTH = "Month" WEEK = "Week" DAY = "Day" -DateOperationLiteral = Literal["Year", "Month", "Week", "Day"] +DateOperationLiteral = Literal["Year", "Quarter", "Month", "Week", "Day"] class AggregateResult(TypedDict): diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/utils/date_utils.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/utils/date_utils.py new file mode 100644 index 000000000..a9622df96 --- /dev/null +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/utils/date_utils.py @@ -0,0 +1,49 @@ +from datetime import date, datetime +from typing import Callable, Dict, Iterator, Union + +import pandas as pd +from forestadmin.datasource_toolkit.interfaces.query.aggregation import DateOperation + +DATE_OPERATION_STR_FORMAT_FN: Dict[DateOperation, Callable[[Union[date, datetime]], str]] = { + DateOperation.DAY: lambda d: d.strftime("%d/%m/%Y"), + DateOperation.WEEK: lambda d: d.strftime("W%V-%G"), + DateOperation.MONTH: lambda d: d.strftime("%b %Y"), + DateOperation.YEAR: lambda d: d.strftime("%Y"), + DateOperation.QUARTER: lambda d: f"Q{pd.Timestamp(d).quarter}-{d.year}", +} + +_DATE_OPERATION_OFFSET: Dict[DateOperation, pd.DateOffset] = { + DateOperation.YEAR: pd.DateOffset(years=1), + DateOperation.QUARTER: pd.DateOffset(months=3), + DateOperation.MONTH: pd.DateOffset(months=1), + DateOperation.WEEK: pd.DateOffset(weeks=1), + DateOperation.DAY: pd.DateOffset(days=1), +} + + +def parse_date(date_input: Union[str, date, datetime]) -> date: + if isinstance(date_input, str): + return datetime.fromisoformat(date_input).date() + elif isinstance(date_input, datetime): + return date_input.date() + elif isinstance(date_input, date): + return date_input + + +def make_formatted_date_range( + first: Union[date, datetime], + last: Union[date, datetime], + date_operation: DateOperation, +) -> Iterator[str]: + current = first + used = set() + format_fn = DATE_OPERATION_STR_FORMAT_FN[date_operation] + + while current <= last: + formatted = format_fn(current) + yield formatted + used.add(formatted) + current = (current + _DATE_OPERATION_OFFSET[date_operation]).date() + + if format_fn(last) not in used: + yield format_fn(last) diff --git a/src/datasource_toolkit/tests/decorators/chart/test_chart_result_builder.py b/src/datasource_toolkit/tests/decorators/chart/test_chart_result_builder.py index 43c225121..33b6caf04 100644 --- a/src/datasource_toolkit/tests/decorators/chart/test_chart_result_builder.py +++ b/src/datasource_toolkit/tests/decorators/chart/test_chart_result_builder.py @@ -109,6 +109,23 @@ def test_time_based_should_return_correct_format_week(self): {"label": "W02-1986", "values": {"value": 7}}, ] + def test_time_based_should_return_correct_format_quarter(self): + result = ResultBuilder.time_based( + DateOperation.QUARTER, + { + "2023-01-07": 3, + "2023-01-08": 4, + "2023-07-26": 1, + "2023-12-31": 1, + }, + ) + assert result == [ + {"label": "Q1-2023", "values": {"value": 7}}, + {"label": "Q2-2023", "values": {"value": 0}}, + {"label": "Q3-2023", "values": {"value": 1}}, + {"label": "Q4-2023", "values": {"value": 1}}, + ] + def test_time_based_should_return_correct_format_week_iso_year(self): result = ResultBuilder.time_based( DateOperation.WEEK,