diff --git a/src/py/mat3ra/code/entity.py b/src/py/mat3ra/code/entity.py index 61eb6eab..116668bf 100644 --- a/src/py/mat3ra/code/entity.py +++ b/src/py/mat3ra/code/entity.py @@ -1,39 +1,15 @@ from typing import Any, Dict, List, Optional, Type, TypeVar -import jsonschema -from mat3ra.utils import object as object_utils from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_snake from typing_extensions import Self -from . import BaseUnderscoreJsonPropsHandler from .mixins import DefaultableMixin, HasDescriptionMixin, HasMetadataMixin, NamedMixin T = TypeVar("T", bound="InMemoryEntityPydantic") B = TypeVar("B", bound="BaseModel") -# TODO: remove in the next PR -class ValidationErrorCode: - IN_MEMORY_ENTITY_DATA_INVALID = "IN_MEMORY_ENTITY_DATA_INVALID" - - -# TODO: remove in the next PR -class ErrorDetails: - def __init__(self, error: Optional[Dict[str, Any]], json: Dict[str, Any], schema: Dict): - self.error = error - self.json = json - self.schema = schema - - -# TODO: remove in the next PR -class EntityError(Exception): - def __init__(self, code: ValidationErrorCode, details: Optional[ErrorDetails] = None): - super().__init__(code) - self.code = code - self.details = details - - class InMemoryEntityPydantic(BaseModel): model_config = {"arbitrary_types_allowed": True} @@ -90,82 +66,41 @@ def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) -> class InMemoryEntitySnakeCase(InMemoryEntityPydantic): model_config = ConfigDict( arbitrary_types_allowed=True, + # Generate snake_case aliases for all fields (e.g. myField -> my_field) alias_generator=to_snake, + # Allow populating fields using either the original name or the snake_case alias populate_by_name=True, ) + @staticmethod + def _create_property_from_camel_case(camel_name: str): + def getter(self): + return getattr(self, camel_name) -# TODO: remove in the next PR -class InMemoryEntity(BaseUnderscoreJsonPropsHandler): - jsonSchema: Optional[Dict] = None - - @classmethod - def get_cls(cls) -> str: - return cls.__name__ - - @property - def cls(self) -> str: - return self.__class__.__name__ - - def get_cls_name(self) -> str: - return self.__class__.__name__ - - @classmethod - def create(cls, config: Dict[str, Any]) -> Any: - return cls(config) + def setter(self, value: Any): + setattr(self, camel_name, value) - def to_json(self, exclude: List[str] = []) -> Dict[str, Any]: - return self.clean(object_utils.clone_deep(object_utils.omit(self._json, exclude))) + return property(getter, setter) - def clone(self, extra_context: Dict[str, Any] = {}) -> Any: - config = self.to_json() - config.update(extra_context) - # To avoid: - # Argument 1 to "__init__" of "BaseUnderscoreJsonPropsHandler" has incompatible type "Dict[str, Any]"; - # expected "BaseUnderscoreJsonPropsHandler" - return self.__class__(config) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not issubclass(cls, BaseModel): + return - @staticmethod - def validate_data(data: Dict[str, Any], clean: bool = False): - if clean: - print("Error: clean is not supported for InMemoryEntity.validateData") - if InMemoryEntity.jsonSchema: - jsonschema.validate(data, InMemoryEntity.jsonSchema) - - def validate(self) -> None: - if self._json: - self.__class__.validate_data(self._json) - - def clean(self, config: Dict[str, Any]) -> Dict[str, Any]: - # Not implemented, consider the below for the implementation - # https://stackoverflow.com/questions/44694835/remove-properties-from-json-object-not-present-in-schema - return config - - def is_valid(self) -> bool: try: - self.validate() - return True - except EntityError: - return False - - # Properties - @property - def id(self) -> str: - return self.prop("_id", "") + model_fields = cls.model_fields + except Exception: + return - @id.setter - def id(self, id: str) -> None: - self.set_prop("_id", id) + for field_name, field_info in model_fields.items(): + if field_name == to_snake(field_name): + continue - @property - def slug(self) -> str: - return self.prop("slug", "") + snake_case_name = to_snake(field_name) + if hasattr(cls, snake_case_name): + continue - def get_as_entity_reference(self, by_id_only: bool = False) -> Dict[str, str]: - if by_id_only: - return {"_id": self.id} - else: - return {"_id": self.id, "slug": self.slug, "cls": self.get_cls_name()} + setattr(cls, snake_case_name, cls._create_property_from_camel_case(field_name)) class HasDescriptionHasMetadataNamedDefaultableInMemoryEntityPydantic( diff --git a/tests/py/unit/__init__.py b/tests/py/unit/__init__.py index db5a3a21..cd12f9e5 100644 --- a/tests/py/unit/__init__.py +++ b/tests/py/unit/__init__.py @@ -98,3 +98,14 @@ class SnakeCaseEntity(CamelCaseSchema, InMemoryEntitySnakeCase): "applicationVersion": "7.2", "executable_name": "pw.x", } + + +class AutoSnakeCaseTestSchema(BaseModel): + contextProviders: list = [] + applicationName: str + applicationVersion: Optional[str] = None + executableName: Optional[str] = None + + +class AutoSnakeCaseTestEntity(AutoSnakeCaseTestSchema, InMemoryEntitySnakeCase): + pass diff --git a/tests/py/unit/test_entity.py b/tests/py/unit/test_entity.py index 1d9d73f9..9d6adab9 100644 --- a/tests/py/unit/test_entity.py +++ b/tests/py/unit/test_entity.py @@ -244,4 +244,3 @@ def test_create_entity_snake_case(config, expected_output): entity_from_create = SnakeCaseEntity.create(config) assert entity_from_create.to_dict() == expected_output - diff --git a/tests/py/unit/test_entity_snake_case.py b/tests/py/unit/test_entity_snake_case.py new file mode 100644 index 00000000..0a3249c2 --- /dev/null +++ b/tests/py/unit/test_entity_snake_case.py @@ -0,0 +1,91 @@ +import pytest +from mat3ra.utils import assertion +from . import AutoSnakeCaseTestEntity + +BASE = { + "applicationName": "camelCasedValue", + "applicationVersion": "camelCasedVersion", + "executableName": "camelCasedExecutable", + "contextProviders": [], +} + +INSTANTIATION = [ + {"application_name": BASE["applicationName"], "application_version": BASE["applicationVersion"], + "executable_name": BASE["executableName"]}, + {"applicationName": BASE["applicationName"], "applicationVersion": BASE["applicationVersion"], + "executableName": BASE["executableName"]}, + {"application_name": BASE["applicationName"], "applicationVersion": BASE["applicationVersion"], + "executable_name": BASE["executableName"]}, +] + +UPDATES = [ + ( + {"application_name": "new_value", "context_providers": ["item_snake"]}, + {"applicationName": "new_value", "contextProviders": ["item_snake"]}, + {"application_name": "new_value", "context_providers": ["item_snake"]}, + ), + ( + {"applicationName": "newValueCamel", "contextProviders": ["itemCamel"]}, + {"applicationName": "newValueCamel", "contextProviders": ["itemCamel"]}, + {"application_name": "newValueCamel", "context_providers": ["itemCamel"]}, + ), + ( + {"application_name": "new_value_snake", "applicationVersion": "newVersionCamel"}, + {"applicationName": "new_value_snake", "applicationVersion": "newVersionCamel"}, + {"application_name": "new_value_snake", "application_version": "newVersionCamel"}, + ), + ( + {"application_name": "new_val", "application_version": "new_version", + "executable_name": "new_exec", "context_providers": ["a", "b"]}, + {"applicationName": "new_val", "applicationVersion": "new_version", + "executableName": "new_exec", "contextProviders": ["a", "b"]}, + {"application_name": "new_val", "application_version": "new_version", + "executable_name": "new_exec", "context_providers": ["a", "b"]}, + ), +] + + +def camel(entity): + return dict( + applicationName=entity.applicationName, + applicationVersion=entity.applicationVersion, + executableName=entity.executableName, + contextProviders=entity.contextProviders, + ) + + +def snake(entity): + return dict( + application_name=entity.application_name, + application_version=entity.application_version, + executable_name=entity.executable_name, + context_providers=entity.context_providers, + ) + + +@pytest.mark.parametrize("cfg", INSTANTIATION) +def test_instantiation(cfg): + entity = AutoSnakeCaseTestEntity(**cfg) + assertion.assert_deep_almost_equal(BASE, camel(entity)) + assertion.assert_deep_almost_equal( + dict(application_name=BASE["applicationName"], + application_version=BASE["applicationVersion"], + executable_name=BASE["executableName"], + context_providers=[]), + snake(entity), + ) + + +@pytest.mark.parametrize("updates, exp_camel, exp_snake", UPDATES) +def test_updates(updates, exp_camel, exp_snake): + entity = AutoSnakeCaseTestEntity(**BASE) + for k, v in updates.items(): + setattr(entity, k, v) + assertion.assert_deep_almost_equal({**BASE, **exp_camel}, camel(entity)) + assertion.assert_deep_almost_equal( + {**snake(AutoSnakeCaseTestEntity(**BASE)), **exp_snake}, + snake(entity), + ) + out = entity.to_dict() + assertion.assert_deep_almost_equal({**BASE, **exp_camel}, out) + assert "application_name" not in out and "context_providers" not in out