Skip to content

Commit aff6405

Browse files
authored
Improves ClassList integration with Pydantic and adds JSON conversion for Projects (#79)
* made ClassList generic and added Pydantic validation schema * added coercion for lists * added JSON conversion for Projects * moved pydantic imports into core schema generator * fixed and improved write_script and its test * review fixes * review fixes
1 parent 072f954 commit aff6405

File tree

9 files changed

+649
-217
lines changed

9 files changed

+649
-217
lines changed

RATapi/classlist.py

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import contextlib
77
import warnings
88
from collections.abc import Sequence
9-
from typing import Any, Union
9+
from typing import Any, Generic, TypeVar, Union
1010

1111
import numpy as np
1212
import prettytable
1313

14+
T = TypeVar("T")
1415

15-
class ClassList(collections.UserList):
16+
17+
class ClassList(collections.UserList, Generic[T]):
1618
"""List of instances of a particular class.
1719
1820
This class subclasses collections.UserList to construct a list intended to store ONLY instances of a particular
@@ -30,14 +32,14 @@ class ClassList(collections.UserList):
3032
3133
Parameters
3234
----------
33-
init_list : Sequence [object] or object, optional
35+
init_list : Sequence [T] or T, optional
3436
An instance, or list of instance(s), of the class to be used in this ClassList.
3537
name_field : str, optional
3638
The field used to define unique objects in the ClassList (default is "name").
3739
3840
"""
3941

40-
def __init__(self, init_list: Union[Sequence[object], object] = None, name_field: str = "name") -> None:
42+
def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "name") -> None:
4143
self.name_field = name_field
4244

4345
# Set input as list if necessary
@@ -81,20 +83,20 @@ def __str__(self):
8183
output = str(self.data)
8284
return output
8385

84-
def __getitem__(self, index: Union[int, slice, str, object]) -> object:
86+
def __getitem__(self, index: Union[int, slice, str, T]) -> T:
8587
"""Get an item by its index, name, a slice, or the object itself."""
8688
if isinstance(index, (int, slice)):
8789
return self.data[index]
88-
elif isinstance(index, (str, object)):
90+
elif isinstance(index, (str, self._class_handle)):
8991
return self.data[self.index(index)]
9092
else:
9193
raise IndexError("ClassLists can only be indexed by integers, slices, name strings, or objects.")
9294

93-
def __setitem__(self, index: int, item: object) -> None:
95+
def __setitem__(self, index: int, item: T) -> None:
9496
"""Replace the object at an existing index of the ClassList."""
9597
self._setitem(index, item)
9698

97-
def _setitem(self, index: int, item: object) -> None:
99+
def _setitem(self, index: int, item: T) -> None:
98100
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
99101
self._check_classes([item])
100102
self._check_unique_name_fields([item])
@@ -108,11 +110,11 @@ def _delitem(self, index: int) -> None:
108110
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
109111
del self.data[index]
110112

111-
def __iadd__(self, other: Sequence[object]) -> "ClassList":
113+
def __iadd__(self, other: Sequence[T]) -> "ClassList":
112114
"""Define in-place addition using the "+=" operator."""
113115
return self._iadd(other)
114116

115-
def _iadd(self, other: Sequence[object]) -> "ClassList":
117+
def _iadd(self, other: Sequence[T]) -> "ClassList":
116118
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
117119
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
118120
other = [other]
@@ -135,13 +137,13 @@ def __imul__(self, n: int) -> None:
135137
"""Define in-place multiplication using the "*=" operator."""
136138
raise TypeError(f"unsupported operand type(s) for *=: '{self.__class__.__name__}' and '{n.__class__.__name__}'")
137139

138-
def append(self, obj: object = None, **kwargs) -> None:
140+
def append(self, obj: T = None, **kwargs) -> None:
139141
"""Append a new object to the ClassList using either the object itself, or keyword arguments to set attribute
140142
values.
141143
142144
Parameters
143145
----------
144-
obj : object, optional
146+
obj : T, optional
145147
An instance of the class specified by self._class_handle.
146148
**kwargs : dict[str, Any], optional
147149
The input keyword arguments for a new object in the ClassList.
@@ -180,15 +182,15 @@ def append(self, obj: object = None, **kwargs) -> None:
180182
self._validate_name_field(kwargs)
181183
self.data.append(self._class_handle(**kwargs))
182184

183-
def insert(self, index: int, obj: object = None, **kwargs) -> None:
185+
def insert(self, index: int, obj: T = None, **kwargs) -> None:
184186
"""Insert a new object into the ClassList at a given index using either the object itself, or keyword arguments
185187
to set attribute values.
186188
187189
Parameters
188190
----------
189191
index: int
190192
The index at which to insert a new object in the ClassList.
191-
obj : object, optional
193+
obj : T, optional
192194
An instance of the class specified by self._class_handle.
193195
**kwargs : dict[str, Any], optional
194196
The input keyword arguments for a new object in the ClassList.
@@ -227,26 +229,26 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
227229
self._validate_name_field(kwargs)
228230
self.data.insert(index, self._class_handle(**kwargs))
229231

230-
def remove(self, item: Union[object, str]) -> None:
232+
def remove(self, item: Union[T, str]) -> None:
231233
"""Remove an object from the ClassList using either the object itself or its name_field value."""
232234
item = self._get_item_from_name_field(item)
233235
self.data.remove(item)
234236

235-
def count(self, item: Union[object, str]) -> int:
237+
def count(self, item: Union[T, str]) -> int:
236238
"""Return the number of times an object appears in the ClassList using either the object itself or its
237239
name_field value.
238240
"""
239241
item = self._get_item_from_name_field(item)
240242
return self.data.count(item)
241243

242-
def index(self, item: Union[object, str], offset: bool = False, *args) -> int:
244+
def index(self, item: Union[T, str], offset: bool = False, *args) -> int:
243245
"""Return the index of a particular object in the ClassList using either the object itself or its
244246
name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
245247
"""
246248
item = self._get_item_from_name_field(item)
247249
return self.data.index(item, *args) + int(offset)
248250

249-
def extend(self, other: Sequence[object]) -> None:
251+
def extend(self, other: Sequence[T]) -> None:
250252
"""Extend the ClassList by adding another sequence."""
251253
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
252254
other = [other]
@@ -319,7 +321,7 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
319321
f"which is already specified at index {names.index(name)} of the ClassList",
320322
)
321323

322-
def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
324+
def _check_unique_name_fields(self, input_list: Sequence[T]) -> None:
323325
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
324326
objects.
325327
@@ -376,7 +378,7 @@ def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
376378
f"{newline.join(error for error in error_list)}"
377379
)
378380

379-
def _check_classes(self, input_list: Sequence[object]) -> None:
381+
def _check_classes(self, input_list: Sequence[T]) -> None:
380382
"""Raise a ValueError if any object in a list of objects is not of the type specified by self._class_handle.
381383
382384
Parameters
@@ -401,17 +403,17 @@ def _check_classes(self, input_list: Sequence[object]) -> None:
401403
f"In the input list:\n{newline.join(error for error in error_list)}\n"
402404
)
403405

404-
def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
406+
def _get_item_from_name_field(self, value: Union[T, str]) -> Union[T, str]:
405407
"""Return the object with the given value of the name_field attribute in the ClassList.
406408
407409
Parameters
408410
----------
409-
value : object or str
411+
value : T or str
410412
Either an object in the ClassList, or the value of the name_field attribute of an object in the ClassList.
411413
412414
Returns
413415
-------
414-
instance : object or str
416+
instance : T or str
415417
Either the object with the value of the name_field attribute given by value, or the input value if an
416418
object with that value of the name_field attribute cannot be found.
417419
@@ -424,7 +426,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
424426
return next((model for model in self.data if getattr(model, self.name_field).lower() == lower_value), value)
425427

426428
@staticmethod
427-
def _determine_class_handle(input_list: Sequence[object]):
429+
def _determine_class_handle(input_list: Sequence[T]):
428430
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
429431
element which satisfies "issubclass" for all the other elements.
430432
@@ -448,3 +450,50 @@ def _determine_class_handle(input_list: Sequence[object]):
448450
class_handle = type(input_list[0])
449451

450452
return class_handle
453+
454+
# Pydantic core schema which allows ClassLists to be validated
455+
# in short: it validates that each ClassList is indeed a ClassList,
456+
# and then validates ClassList.data as though it were a typed list
457+
# e.g. ClassList[str] data is validated like list[str]
458+
@classmethod
459+
def __get_pydantic_core_schema__(cls, source: Any, handler):
460+
# import here so that the ClassList can be instantiated and used without Pydantic installed
461+
from pydantic import ValidatorFunctionWrapHandler
462+
from pydantic.types import (
463+
core_schema, # import core_schema through here rather than making pydantic_core a dependency
464+
)
465+
from typing_extensions import get_args, get_origin
466+
467+
# if annotated with a class, get the item type of that class
468+
origin = get_origin(source)
469+
item_tp = Any if origin is None else get_args(source)[0]
470+
471+
list_schema = handler.generate_schema(list[item_tp])
472+
473+
def coerce(v: Any, handler: ValidatorFunctionWrapHandler) -> ClassList[T]:
474+
"""If a sequence is given, try to coerce it to a ClassList."""
475+
if isinstance(v, Sequence):
476+
classlist = ClassList()
477+
if len(v) > 0 and isinstance(v[0], dict):
478+
# we want to be OK if the type is a model and is passed as a dict;
479+
# pydantic will coerce it or fall over later
480+
classlist._class_handle = dict
481+
elif item_tp is not Any:
482+
classlist._class_handle = item_tp
483+
classlist.extend(v)
484+
v = classlist
485+
v = handler(v)
486+
return v
487+
488+
def validate_items(v: ClassList[T], handler: ValidatorFunctionWrapHandler) -> ClassList[T]:
489+
v.data = handler(v.data)
490+
return v
491+
492+
schema = core_schema.chain_schema(
493+
[
494+
core_schema.no_info_wrap_validator_function(coerce, core_schema.is_instance_schema(cls)),
495+
core_schema.no_info_wrap_validator_function(validate_items, list_schema),
496+
],
497+
)
498+
499+
return schema

RATapi/models.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The models module. Contains the pydantic models used by RAT to store project parameters."""
22

33
import pathlib
4-
from typing import Any, Union
4+
from typing import Any
55

66
import numpy as np
77
import prettytable
@@ -91,6 +91,18 @@ class Contrast(RATModel):
9191
resample: bool = False
9292
model: list[str] = []
9393

94+
@model_validator(mode="before")
95+
@classmethod
96+
def domain_ratio_error(cls, data: Any):
97+
"""If the extra input 'domain_ratio' is given, give a more descriptive error."""
98+
99+
if isinstance(data, dict) and data.get("domain_ratio", False):
100+
raise ValueError(
101+
"The Contrast class does not support domain ratios. Use the ContrastWithRatio class instead."
102+
)
103+
104+
return data
105+
94106
def __str__(self):
95107
table = prettytable.PrettyTable()
96108
table.field_names = [key.replace("_", " ") for key in self.__dict__]
@@ -155,7 +167,7 @@ class CustomFile(RATModel):
155167
filename: str = ""
156168
function_name: str = ""
157169
language: Languages = Languages.Python
158-
path: Union[str, pathlib.Path] = ""
170+
path: pathlib.Path = pathlib.Path(".")
159171

160172
def model_post_init(self, __context: Any) -> None:
161173
"""If a "filename" is supplied but the "function_name" field is not set, the "function_name" should be set to
@@ -291,6 +303,16 @@ class Layer(RATModel, populate_by_name=True):
291303
hydration: str = ""
292304
hydrate_with: Hydration = Hydration.BulkOut
293305

306+
@model_validator(mode="before")
307+
@classmethod
308+
def sld_imaginary_error(cls, data: Any):
309+
"""If the extra input 'sld_imaginary' is given, give a more descriptive error."""
310+
311+
if isinstance(data, dict) and data.get("SLD_imaginary", False):
312+
raise ValueError("The Layer class does not support imaginary SLD. Use the AbsorptionLayer class instead.")
313+
314+
return data
315+
294316

295317
class AbsorptionLayer(RATModel, populate_by_name=True):
296318
"""Combines parameters into defined layers including absorption terms."""

0 commit comments

Comments
 (0)