Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ def _build_leaf_condition(cls, leaf: ConditionTreeLeaf) -> models.Q:
value = leaf.value
if key == "__isnull":
value = True

if key == "__in" and isinstance(value, list) and None in value:
q_obj = cls.build(ConditionTreeBranch("or", [ConditionTreeLeaf(leaf.field, "equal", v) for v in value]))
return ~q_obj if should_negate else q_obj

if should_negate:
return ~models.Q(**{f"{field}{key}": value})
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class FilterOperator(BaseFilterOperator):
# operator: (lookup_expr, negate needed)
Operator.EQUAL: ("__exact", False),
Operator.NOT_EQUAL: ("__exact", True),
Operator.BLANK: ("__isnull", False),
Operator.CONTAINS: ("__icontains", False),
Operator.NOT_CONTAINS: ("__icontains", True),
Operator.STARTS_WITH: ("__istartswith", False),
Expand Down
47 changes: 47 additions & 0 deletions src/datasource_django/tests/test_django_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,53 @@ async def test_datetime_and_date_should_be_correctly_serialized(self):
],
)

async def test_in_and_not_in_should_works_contains_none(self):
ret = await self.rating_collection.list(
self.mocked_caller,
PaginatedFilter({"condition_tree": ConditionTreeLeaf("rating_pk", Operator.IN, [1, None])}),
Projection("rating_pk", "rated_at", "rating", "book:author:birth_date"),
)

self.assertEqual(
ret,
[
{
"rating_pk": 1,
"rating": 1,
"rated_at": datetime.datetime(2022, 12, 25, 10, 10, 10, tzinfo=datetime.timezone.utc),
"book": {"author": {"birth_date": datetime.date(1920, 2, 1)}},
},
],
)

ret = await self.rating_collection.list(
self.mocked_caller,
PaginatedFilter(
{
"condition_tree": ConditionTreeBranch(
"and",
[
ConditionTreeLeaf("rating_pk", Operator.IN, [1, 2]),
ConditionTreeLeaf("rating_pk", Operator.NOT_IN, [2, None]),
],
)
}
),
Projection("rating_pk", "rated_at", "rating", "book:author:birth_date"),
)

self.assertEqual(
ret,
[
{
"rating_pk": 1,
"rating": 1,
"rated_at": datetime.datetime(2022, 12, 25, 10, 10, 10, tzinfo=datetime.timezone.utc),
"book": {"author": {"birth_date": datetime.date(1920, 2, 1)}},
},
],
)


class TestDjangoCollectionCRUDListPolymorphism(TestDjangoCollectionCRUDList):
def setUp(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def test_build_should_correctly_introspect_uuid(self):
self.assertEqual(
field_schema["filter_operators"],
{
Operator.BLANK,
Operator.EQUAL,
Operator.MISSING,
Operator.NOT_EQUAL,
Expand All @@ -141,7 +140,6 @@ def test_introspected_field_should_respect_django_capabilities(self):
self.assertEqual(
field_schema["filter_operators"],
{
Operator.BLANK,
Operator.EQUAL,
Operator.MISSING,
Operator.NOT_EQUAL,
Expand Down Expand Up @@ -323,7 +321,6 @@ def test_build_should_handle_polymorphic_many_to_one(self):
Operator.NOT_EQUAL,
Operator.ENDS_WITH,
Operator.PRESENT,
Operator.BLANK,
Operator.EQUAL,
Operator.NOT_IN,
Operator.IN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
projections_to_records,
)
from forestadmin.datasource_sqlalchemy.utils.relationships import Relationships, merge_relationships
from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType, RelationAlias
from forestadmin.datasource_toolkit.interfaces.fields import Column, PrimitiveType, RelationAlias
from forestadmin.datasource_toolkit.interfaces.query.aggregation import AggregateResult, Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.leaf import ConditionTreeLeaf
from forestadmin.datasource_toolkit.interfaces.query.filter.paginated import PaginatedFilter
from forestadmin.datasource_toolkit.interfaces.query.filter.unpaginated import Filter
from forestadmin.datasource_toolkit.interfaces.query.projections import Projection
from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias
from forestadmin.datasource_toolkit.validations.type_getter import TypeGetter
from sqlalchemy import Table
from sqlalchemy import column as SqlAlchemyColumn
from sqlalchemy import text
Expand Down Expand Up @@ -208,7 +207,7 @@ async def update(self, caller: User, filter_: Optional[Filter], patch: RecordsDa

def _cast_condition_tree(self, tree: ConditionTree) -> ConditionTree:
if isinstance(tree, ConditionTreeLeaf):
if TypeGetter.get(tree.value, None) == PrimitiveType.DATE:
if cast(Column, self.schema["fields"][tree.field])["column_type"] == PrimitiveType.DATE:
iso_format = tree.value
if isinstance(iso_format, str):
if iso_format[-1] == "Z":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from forestadmin.datasource_sqlalchemy.utils.relationships import Relationships, merge_relationships
from forestadmin.datasource_sqlalchemy.utils.type_converter import FilterOperator
from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException
from forestadmin.datasource_toolkit.interfaces.fields import Operator
from forestadmin.datasource_toolkit.interfaces.query.aggregation import Aggregation
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.base import ConditionTree
from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import Aggregator, ConditionTreeBranch
Expand All @@ -17,7 +18,7 @@
from forestadmin.datasource_toolkit.interfaces.records import RecordsDataAlias
from sqlalchemy import and_
from sqlalchemy import column as SqlAlchemyColumn
from sqlalchemy import delete, or_, select, update
from sqlalchemy import delete, not_, or_, select, update
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.elements import BooleanClauseList, UnaryExpression

Expand All @@ -31,6 +32,19 @@ class ConditionTreeFactory:

@classmethod
def _build_leaf_condition(cls, collection: BaseSqlAlchemyCollection, leaf: ConditionTreeLeaf) -> Tuple[Any, Any]:
if leaf.operator in [Operator.IN, Operator.NOT_IN] and isinstance(leaf.value, list) and None in leaf.value:
operator, relationships = cls._build_branch_condition(
collection,
ConditionTreeBranch(
"or",
[ConditionTreeLeaf(leaf.field, Operator.EQUAL, v) for v in leaf.value],
),
)
return (
operator if leaf.operator == Operator.IN else not_(operator),
relationships,
)

projection = leaf.projection
columns, relationships = collection.get_columns(projection)
operator = FilterOperator.get_operator(columns, leaf.operator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from forestadmin.datasource_toolkit.exceptions import DatasourceToolkitException
from forestadmin.datasource_toolkit.interfaces.fields import ColumnAlias, Operator, PrimitiveType
from forestadmin.datasource_toolkit.utils.operators import BaseFilterOperator
from sqlalchemy import ARRAY # type: ignore
from sqlalchemy import column as SqlAlchemyColumn # type: ignore
from sqlalchemy import func, not_, or_ # type: ignore
from sqlalchemy import ARRAY
from sqlalchemy import column as SqlAlchemyColumn
from sqlalchemy import func, not_
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import UUID

Expand Down Expand Up @@ -68,7 +68,6 @@ class FilterOperator(BaseFilterOperator):
OPERATORS = {
Operator.EQUAL: "_equal_operator",
Operator.NOT_EQUAL: "_not_equal_operator",
Operator.BLANK: "_blank_operator",
Operator.CONTAINS: "_contains_operator",
Operator.NOT_CONTAINS: "_not_contains_operator",
Operator.STARTS_WITH: "_starts_with_operator",
Expand All @@ -93,13 +92,6 @@ def _equal_operator(column: SqlAlchemyColumn):
def _not_equal_operator(column: SqlAlchemyColumn):
return column.__ne__

@staticmethod
def _blank_operator(column: SqlAlchemyColumn):
def wrapped(_: str):
return or_([column.is_(None), column.__eq__("")])

return wrapped

@staticmethod
def _contains_operator(column: SqlAlchemyColumn):
def wrapped(value: str):
Expand Down
30 changes: 30 additions & 0 deletions src/datasource_sqlalchemy/tests/test_sqlalchemy_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from unittest import TestCase
from unittest.mock import Mock, patch

from forestadmin.datasource_toolkit.interfaces.query.condition_tree.nodes.branch import ConditionTreeBranch

if sys.version_info >= (3, 9):
import zoneinfo
else:
Expand Down Expand Up @@ -213,6 +215,34 @@ def test_list_with_filter(self):
)
assert len(results) == 2

def test_list_filter_in_and_not_in_with_null_in_values_should_work(self):
collection = self.datasource.get_collection("order")
filter_ = PaginatedFilter({"condition_tree": ConditionTreeLeaf("id", Operator.IN, [1, None])})

results = self.loop.run_until_complete(
collection.list(self.mocked_caller, filter_, Projection("id", "created_at"))
)

self.assertEqual(len(results), 1)
self.assertEqual(results[0]["id"], 1)

collection = self.datasource.get_collection("order")
filter_ = PaginatedFilter(
{
"condition_tree": ConditionTreeBranch(
"and",
[ConditionTreeLeaf("id", Operator.IN, [1, 2]), ConditionTreeLeaf("id", Operator.NOT_IN, [2, None])],
)
}
)

results = self.loop.run_until_complete(
collection.list(self.mocked_caller, filter_, Projection("id", "created_at"))
)

self.assertEqual(len(results), 1)
self.assertEqual(results[0]["id"], 1)

def test_create(self):
order = {
"id": 11,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class FieldEnum(enum.Enum):
"validations": [],
"filter_operators": {
Operator.NOT_IN,
Operator.BLANK,
Operator.EQUAL,
Operator.PRESENT,
Operator.MISSING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

class BaseFilterOperator:
COMMON_OPERATORS: Set[Operator] = {
Operator.BLANK,
Operator.EQUAL,
Operator.MISSING,
Operator.NOT_EQUAL,
Expand Down