66import contextlib
77import warnings
88from collections .abc import Sequence
9- from typing import Any , Union
9+ from typing import Any , Generic , TypeVar , Union
1010
1111import numpy as np
1212import prettytable
1313
14+ T = TypeVar ("T" )
1415
15- class ClassList (collections .UserList ):
16+
17+ class ClassList (collections .UserList , Generic [T ]):
1618 """List of instances of a particular class.
1719
1820 This class subclasses collections.UserList to construct a list intended to store ONLY instances of a particular
@@ -30,14 +32,14 @@ class ClassList(collections.UserList):
3032
3133 Parameters
3234 ----------
33- init_list : Sequence [object ] or object , optional
35+ init_list : Sequence [T ] or T , optional
3436 An instance, or list of instance(s), of the class to be used in this ClassList.
3537 name_field : str, optional
3638 The field used to define unique objects in the ClassList (default is "name").
3739
3840 """
3941
40- def __init__ (self , init_list : Union [Sequence [object ], object ] = None , name_field : str = "name" ) -> None :
42+ def __init__ (self , init_list : Union [Sequence [T ], T ] = None , name_field : str = "name" ) -> None :
4143 self .name_field = name_field
4244
4345 # Set input as list if necessary
@@ -81,20 +83,20 @@ def __str__(self):
8183 output = str (self .data )
8284 return output
8385
84- def __getitem__ (self , index : Union [int , slice , str , object ]) -> object :
86+ def __getitem__ (self , index : Union [int , slice , str , T ]) -> T :
8587 """Get an item by its index, name, a slice, or the object itself."""
8688 if isinstance (index , (int , slice )):
8789 return self .data [index ]
88- elif isinstance (index , (str , object )):
90+ elif isinstance (index , (str , self . _class_handle )):
8991 return self .data [self .index (index )]
9092 else :
9193 raise IndexError ("ClassLists can only be indexed by integers, slices, name strings, or objects." )
9294
93- def __setitem__ (self , index : int , item : object ) -> None :
95+ def __setitem__ (self , index : int , item : T ) -> None :
9496 """Replace the object at an existing index of the ClassList."""
9597 self ._setitem (index , item )
9698
97- def _setitem (self , index : int , item : object ) -> None :
99+ def _setitem (self , index : int , item : T ) -> None :
98100 """Auxiliary routine of "__setitem__" used to enable wrapping."""
99101 self ._check_classes ([item ])
100102 self ._check_unique_name_fields ([item ])
@@ -108,11 +110,11 @@ def _delitem(self, index: int) -> None:
108110 """Auxiliary routine of "__delitem__" used to enable wrapping."""
109111 del self .data [index ]
110112
111- def __iadd__ (self , other : Sequence [object ]) -> "ClassList" :
113+ def __iadd__ (self , other : Sequence [T ]) -> "ClassList" :
112114 """Define in-place addition using the "+=" operator."""
113115 return self ._iadd (other )
114116
115- def _iadd (self , other : Sequence [object ]) -> "ClassList" :
117+ def _iadd (self , other : Sequence [T ]) -> "ClassList" :
116118 """Auxiliary routine of "__iadd__" used to enable wrapping."""
117119 if other and not (isinstance (other , Sequence ) and not isinstance (other , str )):
118120 other = [other ]
@@ -135,13 +137,13 @@ def __imul__(self, n: int) -> None:
135137 """Define in-place multiplication using the "*=" operator."""
136138 raise TypeError (f"unsupported operand type(s) for *=: '{ self .__class__ .__name__ } ' and '{ n .__class__ .__name__ } '" )
137139
138- def append (self , obj : object = None , ** kwargs ) -> None :
140+ def append (self , obj : T = None , ** kwargs ) -> None :
139141 """Append a new object to the ClassList using either the object itself, or keyword arguments to set attribute
140142 values.
141143
142144 Parameters
143145 ----------
144- obj : object , optional
146+ obj : T , optional
145147 An instance of the class specified by self._class_handle.
146148 **kwargs : dict[str, Any], optional
147149 The input keyword arguments for a new object in the ClassList.
@@ -180,15 +182,15 @@ def append(self, obj: object = None, **kwargs) -> None:
180182 self ._validate_name_field (kwargs )
181183 self .data .append (self ._class_handle (** kwargs ))
182184
183- def insert (self , index : int , obj : object = None , ** kwargs ) -> None :
185+ def insert (self , index : int , obj : T = None , ** kwargs ) -> None :
184186 """Insert a new object into the ClassList at a given index using either the object itself, or keyword arguments
185187 to set attribute values.
186188
187189 Parameters
188190 ----------
189191 index: int
190192 The index at which to insert a new object in the ClassList.
191- obj : object , optional
193+ obj : T , optional
192194 An instance of the class specified by self._class_handle.
193195 **kwargs : dict[str, Any], optional
194196 The input keyword arguments for a new object in the ClassList.
@@ -227,26 +229,26 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
227229 self ._validate_name_field (kwargs )
228230 self .data .insert (index , self ._class_handle (** kwargs ))
229231
230- def remove (self , item : Union [object , str ]) -> None :
232+ def remove (self , item : Union [T , str ]) -> None :
231233 """Remove an object from the ClassList using either the object itself or its name_field value."""
232234 item = self ._get_item_from_name_field (item )
233235 self .data .remove (item )
234236
235- def count (self , item : Union [object , str ]) -> int :
237+ def count (self , item : Union [T , str ]) -> int :
236238 """Return the number of times an object appears in the ClassList using either the object itself or its
237239 name_field value.
238240 """
239241 item = self ._get_item_from_name_field (item )
240242 return self .data .count (item )
241243
242- def index (self , item : Union [object , str ], offset : bool = False , * args ) -> int :
244+ def index (self , item : Union [T , str ], offset : bool = False , * args ) -> int :
243245 """Return the index of a particular object in the ClassList using either the object itself or its
244246 name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
245247 """
246248 item = self ._get_item_from_name_field (item )
247249 return self .data .index (item , * args ) + int (offset )
248250
249- def extend (self , other : Sequence [object ]) -> None :
251+ def extend (self , other : Sequence [T ]) -> None :
250252 """Extend the ClassList by adding another sequence."""
251253 if other and not (isinstance (other , Sequence ) and not isinstance (other , str )):
252254 other = [other ]
@@ -319,7 +321,7 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
319321 f"which is already specified at index { names .index (name )} of the ClassList" ,
320322 )
321323
322- def _check_unique_name_fields (self , input_list : Sequence [object ]) -> None :
324+ def _check_unique_name_fields (self , input_list : Sequence [T ]) -> None :
323325 """Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
324326 objects.
325327
@@ -376,7 +378,7 @@ def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
376378 f"{ newline .join (error for error in error_list )} "
377379 )
378380
379- def _check_classes (self , input_list : Sequence [object ]) -> None :
381+ def _check_classes (self , input_list : Sequence [T ]) -> None :
380382 """Raise a ValueError if any object in a list of objects is not of the type specified by self._class_handle.
381383
382384 Parameters
@@ -401,17 +403,17 @@ def _check_classes(self, input_list: Sequence[object]) -> None:
401403 f"In the input list:\n { newline .join (error for error in error_list )} \n "
402404 )
403405
404- def _get_item_from_name_field (self , value : Union [object , str ]) -> Union [object , str ]:
406+ def _get_item_from_name_field (self , value : Union [T , str ]) -> Union [T , str ]:
405407 """Return the object with the given value of the name_field attribute in the ClassList.
406408
407409 Parameters
408410 ----------
409- value : object or str
411+ value : T or str
410412 Either an object in the ClassList, or the value of the name_field attribute of an object in the ClassList.
411413
412414 Returns
413415 -------
414- instance : object or str
416+ instance : T or str
415417 Either the object with the value of the name_field attribute given by value, or the input value if an
416418 object with that value of the name_field attribute cannot be found.
417419
@@ -424,7 +426,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
424426 return next ((model for model in self .data if getattr (model , self .name_field ).lower () == lower_value ), value )
425427
426428 @staticmethod
427- def _determine_class_handle (input_list : Sequence [object ]):
429+ def _determine_class_handle (input_list : Sequence [T ]):
428430 """When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
429431 element which satisfies "issubclass" for all the other elements.
430432
@@ -448,3 +450,50 @@ def _determine_class_handle(input_list: Sequence[object]):
448450 class_handle = type (input_list [0 ])
449451
450452 return class_handle
453+
454+ # Pydantic core schema which allows ClassLists to be validated
455+ # in short: it validates that each ClassList is indeed a ClassList,
456+ # and then validates ClassList.data as though it were a typed list
457+ # e.g. ClassList[str] data is validated like list[str]
458+ @classmethod
459+ def __get_pydantic_core_schema__ (cls , source : Any , handler ):
460+ # import here so that the ClassList can be instantiated and used without Pydantic installed
461+ from pydantic import ValidatorFunctionWrapHandler
462+ from pydantic .types import (
463+ core_schema , # import core_schema through here rather than making pydantic_core a dependency
464+ )
465+ from typing_extensions import get_args , get_origin
466+
467+ # if annotated with a class, get the item type of that class
468+ origin = get_origin (source )
469+ item_tp = Any if origin is None else get_args (source )[0 ]
470+
471+ list_schema = handler .generate_schema (list [item_tp ])
472+
473+ def coerce (v : Any , handler : ValidatorFunctionWrapHandler ) -> ClassList [T ]:
474+ """If a sequence is given, try to coerce it to a ClassList."""
475+ if isinstance (v , Sequence ):
476+ classlist = ClassList ()
477+ if len (v ) > 0 and isinstance (v [0 ], dict ):
478+ # we want to be OK if the type is a model and is passed as a dict;
479+ # pydantic will coerce it or fall over later
480+ classlist ._class_handle = dict
481+ elif item_tp is not Any :
482+ classlist ._class_handle = item_tp
483+ classlist .extend (v )
484+ v = classlist
485+ v = handler (v )
486+ return v
487+
488+ def validate_items (v : ClassList [T ], handler : ValidatorFunctionWrapHandler ) -> ClassList [T ]:
489+ v .data = handler (v .data )
490+ return v
491+
492+ schema = core_schema .chain_schema (
493+ [
494+ core_schema .no_info_wrap_validator_function (coerce , core_schema .is_instance_schema (cls )),
495+ core_schema .no_info_wrap_validator_function (validate_items , list_schema ),
496+ ],
497+ )
498+
499+ return schema
0 commit comments