diff --git a/src/py/mat3ra/code/entity.py b/src/py/mat3ra/code/entity.py index 5e1817e6..61eb6eab 100644 --- a/src/py/mat3ra/code/entity.py +++ b/src/py/mat3ra/code/entity.py @@ -2,7 +2,8 @@ import jsonschema from mat3ra.utils import object as object_utils -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_snake from typing_extensions import Self from . import BaseUnderscoreJsonPropsHandler @@ -86,6 +87,14 @@ def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) -> return self.model_copy(update=extra_context or {}, deep=deep) +class InMemoryEntitySnakeCase(InMemoryEntityPydantic): + model_config = ConfigDict( + arbitrary_types_allowed=True, + alias_generator=to_snake, + populate_by_name=True, + ) + + # TODO: remove in the next PR class InMemoryEntity(BaseUnderscoreJsonPropsHandler): jsonSchema: Optional[Dict] = None diff --git a/tests/py/unit/__init__.py b/tests/py/unit/__init__.py index 6da01dbd..db5a3a21 100644 --- a/tests/py/unit/__init__.py +++ b/tests/py/unit/__init__.py @@ -1,7 +1,8 @@ import json from enum import Enum +from typing import Optional -from mat3ra.code.entity import InMemoryEntityPydantic +from mat3ra.code.entity import InMemoryEntityPydantic, InMemoryEntitySnakeCase from pydantic import BaseModel REFERENCE_OBJECT_VALID = {"key1": "value1", "key2": 1} @@ -68,3 +69,32 @@ class SampleModelWithEnum(BaseModel): class SampleEntityWithEnum(SampleModelWithEnum, InMemoryEntityPydantic): pass + + +class CamelCaseSchema(BaseModel): + applicationName: str + applicationVersion: Optional[str] = None + executableName: Optional[str] = None + + +class SnakeCaseEntity(CamelCaseSchema, InMemoryEntitySnakeCase): + pass + + +SNAKE_CASE_CONFIG = { + "application_name": "espresso", + "application_version": "7.2", + "executable_name": "pw.x", +} + +CAMEL_CASE_CONFIG = { + "applicationName": "espresso", + "applicationVersion": "7.2", + "executableName": "pw.x", +} + +MIXED_CASE_CONFIG = { + "application_name": "espresso", + "applicationVersion": "7.2", + "executable_name": "pw.x", +} diff --git a/tests/py/unit/test_entity.py b/tests/py/unit/test_entity.py index 45d7f8f8..1d9d73f9 100644 --- a/tests/py/unit/test_entity.py +++ b/tests/py/unit/test_entity.py @@ -1,6 +1,11 @@ import json +import pytest + from . import ( + CAMEL_CASE_CONFIG, + CAMEL_CASE_CONFIG as EXPECTED_CAMEL_CASE_OUTPUT, + MIXED_CASE_CONFIG, REFERENCE_OBJECT_DOUBLE_NESTED_VALID, REFERENCE_OBJECT_INVALID, REFERENCE_OBJECT_NESTED_VALID, @@ -10,6 +15,7 @@ REFERENCE_OBJECT_VALID_UPDATED, REFERENCE_OBJECT_VALID_WITH_EXTRA_KEY, REFERENCE_OBJECT_VALID_WITH_MISSING_KEY, + SNAKE_CASE_CONFIG, ExampleClass, ExampleDefaultableClass, ExampleDoubleNestedKeyAsClassInstancesClass, @@ -20,6 +26,7 @@ ExampleSchema, SampleEnum, SampleEntityWithEnum, + SnakeCaseEntity, ) @@ -211,3 +218,30 @@ def test_clone_deep(): cloned_entity_deep.key1 = "adjusted_value" assert entity.key1 == "value1" assert cloned_entity_deep.key1 == "adjusted_value" + + +@pytest.mark.parametrize( + "config,expected_output", + [ + (SNAKE_CASE_CONFIG, EXPECTED_CAMEL_CASE_OUTPUT), + (CAMEL_CASE_CONFIG, EXPECTED_CAMEL_CASE_OUTPUT), + (MIXED_CASE_CONFIG, EXPECTED_CAMEL_CASE_OUTPUT), + ], +) +def test_create_entity_snake_case(config, expected_output): + entity = SnakeCaseEntity(**config) + assert entity.applicationName == expected_output["applicationName"] + assert entity.applicationVersion == expected_output["applicationVersion"] + assert entity.executableName == expected_output["executableName"] + + result_dict = entity.to_dict() + assert result_dict == expected_output + assert "applicationName" in result_dict + assert "application_name" not in result_dict + + result_json = json.loads(entity.to_json()) + assert result_json == expected_output + + entity_from_create = SnakeCaseEntity.create(config) + assert entity_from_create.to_dict() == expected_output +