diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..c131bb3c7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: local + hooks: + - id: ruff-check + name: ruff-check + entry: uv run ruff check --fix . + language: system + pass_filenames: false + - id: ruff-format + name: ruff-format + entry: uv run ruff format . + language: system + pass_filenames: false + # - id: ty-check + # name: ty-check + # entry: uv run ty check . + # language: system + # pass_filenames: false diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 000000000..eed2f3797 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,6 @@ +{ + "recommendations": [ + // Linting and formatting + "charliermarsh.ruff", + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..eb0d949d7 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,13 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python", + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + } + }, + // Disable flake8 extension completely + "flake8.enabled": false, +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 5b9aa88be..dab3c605e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,14 +13,15 @@ import citrine import os import sys -sys.path.insert(0, os.path.abspath('../../src/citrine')) + +sys.path.insert(0, os.path.abspath("../../src/citrine")) # -- Project information ----------------------------------------------------- -project = 'Citrine Python' -copyright = '2019, Citrine Informatics' -author = 'Citrine Informatics' +project = "Citrine Python" +copyright = "2019, Citrine Informatics" +author = "Citrine Informatics" # The short X.Y version. version = citrine.__version__ @@ -33,10 +34,10 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinxcontrib.apidoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', - 'sphinx_rtd_theme' + "sphinxcontrib.apidoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx_rtd_theme", ] # Use the sphinxcontrib.apidoc extension to wire in the sphinx-apidoc invocation @@ -44,17 +45,19 @@ # build. # # See: https://github.com/sphinx-contrib/apidoc -apidoc_module_dir = '../../src/citrine' -apidoc_output_dir = 'reference' -apidoc_excluded_paths = ['tests'] +apidoc_module_dir = "../../src/citrine" +apidoc_output_dir = "reference" +apidoc_excluded_paths = ["tests"] apidoc_separate_modules = True # Use intersphinx to link to classes in Sphinx docs for other libraries # See: https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html -intersphinx_mapping = {'gemd-python': ('https://citrineinformatics.github.io/gemd-python/', None)} +intersphinx_mapping = { + "gemd-python": ("https://citrineinformatics.github.io/gemd-python/", None) +} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -71,19 +74,16 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. 
They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # These paths are either relative to html_static_path or fully qualified paths (e.g. https://...) html_css_files = [ - 'css/custom.css', + "css/custom.css", ] -autodoc_member_order = 'groupwise' +autodoc_member_order = "groupwise" autodoc_mock_imports = [] # autodoc_mock_imports allows Sphinx to ignore any external modules listed in the array -html_favicon = '_static/favicon.png' -html_logo = '_static/logo.png' -html_theme_options = { - 'sticky_navigation': False, - 'logo_only': True -} +html_favicon = "_static/favicon.png" +html_logo = "_static/logo.png" +html_theme_options = {"sticky_navigation": False, "logo_only": True} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..f2d2a87db --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src"] + +[project] +name = "citrine" +version = "0.1.0" +description = "Python library for the Citrine Platform" +requires-python = ">=3.12" + +dependencies = [ + "arrow", + "boto3", + "deprecation", + "gemd", + "requests", + "tqdm", + "urllib3", + +] + +[project.optional-dependencies] +test = [ + "factory-boy", + "mock", + "pandas", + "pre-commit", + "pytest", + "requests-mock", + "ruff", + "ty" +] \ No newline at end of file diff --git a/scripts/validate_version_bump.py b/scripts/validate_version_bump.py index 180c8b731..c37d97f4f 100755 --- a/scripts/validate_version_bump.py +++ b/scripts/validate_version_bump.py @@ -9,7 +9,7 @@ def main(): repo_dir = popen("git rev-parse --show-toplevel", mode="r").read().rstrip() - version_path = relpath(f'{repo_dir}/src/citrine/__version__.py', getcwd()) + version_path = relpath(f"{repo_dir}/src/citrine/__version__.py", getcwd()) try: with open(version_path, "r") as fh: @@ -18,10 +18,13 @@ def main(): raise ValueError(f"Couldn't extract version from {version_path}") from e try: - with popen(f"git fetch origin && git show origin/main:src/citrine/__version__.py", mode="r") as fh: + with popen( + "git fetch origin && git show origin/main:src/citrine/__version__.py", + mode="r", + ) as fh: old_version = extract_version(fh) except Exception as e: - raise ValueError(f"Couldn't extract version from main branch") from e + raise ValueError("Couldn't extract version from main branch") from e if new_version.major != old_version.major: number = "major" @@ -50,7 +53,7 @@ def main(): def extract_version(handle: TextIO) -> Version: text = handle.read() if not re.search(r"\S", text): - raise ValueError(f"No content") + raise ValueError("No content") match = re.search(r"""^\s*__version__\s*=\s*(['"])(\S+)\1""", text, re.MULTILINE) if match: return Version(match.group(2)) diff --git a/src/citrine/__init__.py b/src/citrine/__init__.py index effc6d496..f728b81f0 100644 --- a/src/citrine/__init__.py +++ b/src/citrine/__init__.py @@ -5,5 +5,6 @@ https://citrineinformatics.github.io/citrine-python/index.html """ + from citrine.citrine import Citrine # noqa: F401 from .__version__ import __version__ # noqa: F401 diff --git a/src/citrine/_rest/ai_resource_metadata.py b/src/citrine/_rest/ai_resource_metadata.py index 4b3175ed7..ccfe734df 100644 --- a/src/citrine/_rest/ai_resource_metadata.py +++ b/src/citrine/_rest/ai_resource_metadata.py @@ -2,31 +2,40 @@ from citrine._serialization import properties -class
AIResourceMetadata(): +class AIResourceMetadata: """Abstract class for representing common metadata for Resources.""" - created_by = properties.Optional(properties.UUID, 'created_by', serializable=False) + created_by = properties.Optional(properties.UUID, "created_by", serializable=False) """:Optional[UUID]: id of the user who created the resource""" - create_time = properties.Optional(properties.Datetime, 'create_time', serializable=False) + create_time = properties.Optional( + properties.Datetime, "create_time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was created""" - updated_by = properties.Optional(properties.UUID, 'updated_by', serializable=False) + updated_by = properties.Optional(properties.UUID, "updated_by", serializable=False) """:Optional[UUID]: id of the user who most recently updated the resource, if it has been updated""" - update_time = properties.Optional(properties.Datetime, 'update_time', serializable=False) + update_time = properties.Optional( + properties.Datetime, "update_time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was most recently updated, if it has been updated""" - archived = properties.Boolean('archived', default=False) + archived = properties.Boolean("archived", default=False) """:bool: whether the resource is archived (hidden but not deleted)""" - archived_by = properties.Optional(properties.UUID, 'archived_by', serializable=False) + archived_by = properties.Optional( + properties.UUID, "archived_by", serializable=False + ) """:Optional[UUID]: id of the user who archived the resource, if it has been archived""" - archive_time = properties.Optional(properties.Datetime, 'archive_time', serializable=False) + archive_time = properties.Optional( + properties.Datetime, "archive_time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was archived, if it has been archived""" - status = properties.Optional(properties.String(), 'status', serializable=False) + status = properties.Optional(properties.String(), "status", serializable=False) """:Optional[str]: short description of the resource's status""" - status_detail = properties.List(properties.Object(StatusDetail), 'status_detail', default=[], - serializable=False) + status_detail = properties.List( + properties.Object(StatusDetail), "status_detail", default=[], serializable=False + ) """:List[StatusDetail]: a list of structured status info, containing the message and level""" diff --git a/src/citrine/_rest/collection.py b/src/citrine/_rest/collection.py index 6f83805ed..b76933de0 100644 --- a/src/citrine/_rest/collection.py +++ b/src/citrine/_rest/collection.py @@ -9,12 +9,12 @@ from citrine.exceptions import ModuleRegistrationFailedException, NonRetryableException from citrine.resources.response import Response -ResourceType = TypeVar('ResourceType', bound=Resource) +ResourceType = TypeVar("ResourceType", bound=Resource) # Python does not support a TypeVar being used as a bound for another TypeVar. # Thus, this will never be particularly type safe on its own. The solution is to # have subclasses override the create method. 
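The comment above is worth a concrete illustration. A minimal sketch, assuming a hypothetical `Widget` resource (not part of this diff), of the override pattern it recommends — a subclass re-declares the create/register method so type checkers see a precise element type:

```python
from citrine._rest.collection import Collection
from citrine._rest.resource import Resource


class Widget(Resource["Widget"]):
    """Hypothetical element type, for illustration only."""


class WidgetCollection(Collection[Widget]):
    """Pins ResourceType to Widget for this collection."""

    def build(self, data: dict) -> Widget:
        return Widget.build(data)

    def register(self, model: Widget) -> Widget:
        # Same behavior as the base class; the override exists purely to
        # narrow the loosely-bound CreationType to Widget.
        return super().register(model)
```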
-CreationType = TypeVar('CreationType', bound='Resource') +CreationType = TypeVar("CreationType", bound="Resource") class Collection(Generic[ResourceType], Pageable): @@ -24,7 +24,7 @@ class Collection(Generic[ResourceType], Pageable): _dataset_agnostic_path_template: str = NotImplemented _individual_key: str = NotImplemented _resource: ResourceType = NotImplemented - _collection_key: str = 'entries' + _collection_key: str = "entries" _paginator: Paginator = Paginator() _api_version: str = "v1" @@ -33,17 +33,27 @@ def _put_resource_ref(self, subpath: str, uid: Union[UUID, str]): ref = ResourceRef(uid) return self.session.put_resource(url, ref.dump(), version=self._api_version) - def _get_path(self, - uid: Optional[Union[UUID, str]] = None, - *, - ignore_dataset: bool = False, - action: Union[str, Sequence[str]] = [], - query_terms: Dict[str, str] = {}, - ) -> str: + def _get_path( + self, + uid: Optional[Union[UUID, str]] = None, + *, + ignore_dataset: bool = False, + action: Union[str, Sequence[str]] = [], + query_terms: Dict[str, str] = {}, + ) -> str: """Construct a url from __base_path__ and, optionally, id and/or action.""" - base = self._dataset_agnostic_path_template if ignore_dataset else self._path_template - return resource_path(path_template=base, uid=uid, action=action, query_terms=query_terms, - **self.__dict__) + base = ( + self._dataset_agnostic_path_template + if ignore_dataset + else self._path_template + ) + return resource_path( + path_template=base, + uid=uid, + action=action, + query_terms=query_terms, + **self.__dict__, + ) @abstractmethod def build(self, data: dict): @@ -52,7 +62,9 @@ def build(self, data: dict): def get(self, uid: Union[UUID, str]) -> ResourceType: """Get a particular element of the collection.""" if uid is None: - raise ValueError("Cannot get when uid=None. Are you using a registered resource?") + raise ValueError( + "Cannot get when uid=None. Are you using a registered resource?" + ) path = self._get_path(uid) data = self.session.get_resource(path, version=self._api_version) data = data[self._individual_key] if self._individual_key else data @@ -62,7 +74,9 @@ def register(self, model: CreationType) -> CreationType: """Create a new element of the collection by registering an existing resource.""" path = self._get_path() try: - data = self.session.post_resource(path, model.dump(), version=self._api_version) + data = self.session.post_resource( + path, model.dump(), version=self._api_version + ) data = data[self._individual_key] if self._individual_key else data return self.build(data) except NonRetryableException as e: @@ -89,14 +103,18 @@ def list(self, *, per_page: int = 100) -> Iterator[ResourceType]: Use list() to force evaluation of all results into an in-memory list. 
""" - return self._paginator.paginate(page_fetcher=self._fetch_page, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=self._fetch_page, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def update(self, model: CreationType) -> CreationType: """Update a particular element of the collection.""" url = self._get_path(model.uid) - updated = self.session.put_resource(url, model.dump(), version=self._api_version) + updated = self.session.put_resource( + url, model.dump(), version=self._api_version + ) data = updated[self._individual_key] if self._individual_key else updated return self.build(data) @@ -106,8 +124,9 @@ def delete(self, uid: Union[UUID, str]) -> Response: data = self.session.delete_resource(url, version=self._api_version) return Response(body=data) - def _build_collection_elements(self, - collection: Iterable[dict]) -> Iterator[ResourceType]: + def _build_collection_elements( + self, collection: Iterable[dict] + ) -> Iterator[ResourceType]: """ For each element in the collection, build the appropriate resource type. diff --git a/src/citrine/_rest/engine_resource.py b/src/citrine/_rest/engine_resource.py index 9abd9f568..35aed8ef3 100644 --- a/src/citrine/_rest/engine_resource.py +++ b/src/citrine/_rest/engine_resource.py @@ -5,32 +5,39 @@ from citrine._serialization.include_parent_properties import IncludeParentProperties from citrine.resources.status_detail import StatusDetail -Self = TypeVar('Self', bound='Resource') +Self = TypeVar("Self", bound="Resource") class EngineResourceWithoutStatus(Resource[Self]): """Base resource for metadata from stand-alone AI Engine modules.""" - created_by = properties.Optional(properties.UUID, 'metadata.created.user', serializable=False) + created_by = properties.Optional( + properties.UUID, "metadata.created.user", serializable=False + ) """:Optional[UUID]: id of the user who created the resource""" - create_time = properties.Optional(properties.Datetime, 'metadata.created.time', - serializable=False) + create_time = properties.Optional( + properties.Datetime, "metadata.created.time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was created""" - updated_by = properties.Optional(properties.UUID, 'metadata.updated.user', - serializable=False) + updated_by = properties.Optional( + properties.UUID, "metadata.updated.user", serializable=False + ) """:Optional[UUID]: id of the user who most recently updated the resource, if it has been updated""" - update_time = properties.Optional(properties.Datetime, 'metadata.updated.time', - serializable=False) + update_time = properties.Optional( + properties.Datetime, "metadata.updated.time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was most recently updated, if it has been updated""" - archived_by = properties.Optional(properties.UUID, 'metadata.archived.user', - serializable=False) + archived_by = properties.Optional( + properties.UUID, "metadata.archived.user", serializable=False + ) """:Optional[UUID]: id of the user who archived the resource, if it has been archived""" - archive_time = properties.Optional(properties.Datetime, 'metadata.archived.time', - serializable=False) + archive_time = properties.Optional( + properties.Datetime, "metadata.archived.time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was archived, if it has been archived""" @@ -59,10 +66,16 @@ def _post_dump(self, data: 
dict) -> dict: class EngineResource(EngineResourceWithoutStatus[Self], IncludeParentProperties[Self]): """Base resource for metadata from stand-alone AI Engine modules.""" - status = properties.Optional(properties.String(), 'metadata.status.name', serializable=False) + status = properties.Optional( + properties.String(), "metadata.status.name", serializable=False + ) """:Optional[str]: short description of the resource's status""" - status_detail = properties.List(properties.Object(StatusDetail), 'metadata.status.detail', - default=[], serializable=False) + status_detail = properties.List( + properties.Object(StatusDetail), + "metadata.status.detail", + default=[], + serializable=False, + ) """:List[StatusDetail]: a list of structured status info, containing the message and level""" @classmethod @@ -75,10 +88,14 @@ class VersionedEngineResource(EngineResource[Self], IncludeParentProperties[Self """Base resource for metadata from stand-alone AI Engine modules which support versioning.""" """:Integer: The version number of the resource.""" - version = properties.Optional(properties.Integer, 'metadata.version', serializable=False) + version = properties.Optional( + properties.Integer, "metadata.version", serializable=False + ) """:Boolean: The draft status of the resource.""" - draft = properties.Optional(properties.Boolean, 'metadata.draft', serializable=False) + draft = properties.Optional( + properties.Boolean, "metadata.draft", serializable=False + ) @classmethod def build(cls, data: dict): diff --git a/src/citrine/_rest/pageable.py b/src/citrine/_rest/pageable.py index 129155cf7..fa12bb66c 100644 --- a/src/citrine/_rest/pageable.py +++ b/src/citrine/_rest/pageable.py @@ -2,32 +2,34 @@ from uuid import UUID -class Pageable(): +class Pageable: """Class that allows paging.""" _collection_key: str = NotImplemented _api_version: str = "v1" - def _get_path(self, - uid: Optional[Union[UUID, str]] = None, - *, - ignore_dataset: bool = False, - action: Union[str, Sequence[str]] = [], - query_terms: Dict[str, str] = {}, - ) -> str: + def _get_path( + self, + uid: Optional[Union[UUID, str]] = None, + *, + ignore_dataset: bool = False, + action: Union[str, Sequence[str]] = [], + query_terms: Dict[str, str] = {}, + ) -> str: """Construct a url from __base_path__ and, optionally, id.""" raise NotImplementedError # pragma: no cover - def _fetch_page(self, - path: Optional[str] = None, - fetch_func: Optional[Callable[..., dict]] = None, - page: Optional[int] = None, - per_page: Optional[int] = None, - json_body: Optional[dict] = None, - additional_params: Optional[dict] = None, - *, - version: Optional[str] = None - ) -> Tuple[Iterable[dict], str]: + def _fetch_page( + self, + path: Optional[str] = None, + fetch_func: Optional[Callable[..., dict]] = None, + page: Optional[int] = None, + per_page: Optional[int] = None, + json_body: Optional[dict] = None, + additional_params: Optional[dict] = None, + *, + version: Optional[str] = None, + ) -> Tuple[Iterable[dict], str]: """ Fetch visible elements. This does not handle pagination. 
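The `_fetch_page`/`Paginator.paginate` pair share a small contract: the page fetcher accepts `page` and `per_page` keywords and returns `(items, next_uri)`. A minimal sketch against an in-memory list — `Item`, `fake_fetch_page`, and `build_items` are invented names, and the exact stopping behavior depends on `paginate` internals not shown in this hunk:

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple

from citrine._rest.paginator import Paginator


@dataclass(frozen=True)
class Item:
    uid: int  # Paginator compares and deduplicates elements via this attribute


def fake_fetch_page(*, page: int, per_page: int) -> Tuple[Iterable[dict], str]:
    # Stand-in for Pageable._fetch_page: slice one page out of a fixed dataset.
    data = [{"uid": i} for i in range(7)]
    start = (page - 1) * per_page
    return data[start:start + per_page], ""  # empty next_uri


def build_items(chunk: Iterable[dict]) -> List[Item]:
    # Stand-in for Collection._build_collection_elements.
    return [Item(**d) for d in chunk]


items = list(
    Paginator().paginate(
        page_fetcher=fake_fetch_page,
        collection_builder=build_items,
        per_page=3,
    )
)
print([i.uid for i in items])  # expected: [0, 1, 2, 3, 4, 5, 6]
```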
@@ -85,7 +87,7 @@ def _fetch_page(self, data = fetch_func(path, params=params, version=version, **json_body) try: - next_uri = data.get('next', "") + next_uri = data.get("next", "") except AttributeError: next_uri = "" @@ -99,10 +101,12 @@ def _fetch_page(self, return collection, next_uri - def _page_params(self, - page: Optional[int], - per_page: Optional[int], - module_type: Optional[str] = None) -> Dict[str, int]: + def _page_params( + self, + page: Optional[int], + per_page: Optional[int], + module_type: Optional[str] = None, + ) -> Dict[str, int]: params = {} if page is not None: params["page"] = page diff --git a/src/citrine/_rest/paginator.py b/src/citrine/_rest/paginator.py index 60dd6cac7..eca25fed2 100644 --- a/src/citrine/_rest/paginator.py +++ b/src/citrine/_rest/paginator.py @@ -1,7 +1,7 @@ from typing import TypeVar, Generic, Callable, Optional, Iterable, Any, Tuple, Iterator from uuid import uuid4 -ResourceType = TypeVar('ResourceType') +ResourceType = TypeVar("ResourceType") class Paginator(Generic[ResourceType]): @@ -12,12 +12,14 @@ class Paginator(Generic[ResourceType]): that will be extracted for comparison purposes (to avoid looping on the same items). """ - def paginate(self, - page_fetcher: Callable[[Optional[int], int], Tuple[Iterable[dict], str]], - collection_builder: Callable[[Iterable[dict]], Iterable[ResourceType]], - per_page: int = 100, - search_params: Optional[dict] = None, - deduplicate: bool = True) -> Iterator[ResourceType]: + def paginate( + self, + page_fetcher: Callable[[Optional[int], int], Tuple[Iterable[dict], str]], + collection_builder: Callable[[Iterable[dict]], Iterable[ResourceType]], + per_page: int = 100, + search_params: Optional[dict] = None, + deduplicate: bool = True, + ) -> Iterator[ResourceType]: """ A generic support class to paginate requests into an iterable of a built object. @@ -53,26 +55,27 @@ def paginate(self, """ # To avoid setting default to {} -> reduce mutation risk, and to make more extensible. Also # making 'search_params' key of outermost dict for keyword expansion by page_fetcher func - search_params = {} if search_params is None else {'search_params': search_params} + search_params = ( + {} if search_params is None else {"search_params": search_params} + ) first_entity = None page_idx = 1 uids = set() while True: - subset_collection, next_uri = page_fetcher(page=page_idx, per_page=per_page, - **search_params) + subset_collection, next_uri = page_fetcher( + page=page_idx, per_page=per_page, **search_params + ) subset = collection_builder(subset_collection) count = 0 for idx, element in enumerate(subset): - # escaping from infinite loops where page/per_page are not # honored and are returning the same results regardless of page: current_entity = self._comparison_fields(element) - if first_entity is not None and \ - first_entity == current_entity: + if first_entity is not None and first_entity == current_entity: # TODO: raise an exception once the APIs that ignore pagination are fixed break @@ -106,4 +109,4 @@ def _comparison_fields(self, entity: ResourceType) -> Any: If the 'uid' here isn't found, default to comparing the entire entity. 
""" - return getattr(entity, 'uid', entity) + return getattr(entity, "uid", entity) diff --git a/src/citrine/_rest/resource.py b/src/citrine/_rest/resource.py index 6bc63d202..9a4ae9599 100644 --- a/src/citrine/_rest/resource.py +++ b/src/citrine/_rest/resource.py @@ -30,7 +30,7 @@ class ResourceTypeEnum(BaseEnumeration): TABLE_DEFINITION = "TABLE_DEFINITION" -Self = TypeVar('Self', bound='Resource') +Self = TypeVar("Self", bound="Resource") class Resource(Serializable[Self]): @@ -42,13 +42,10 @@ class Resource(Serializable[Self]): def access_control_dict(self) -> dict: """Return an access control entity representation of this resource. Internal use only.""" - return { - "type": self._resource_type.value, - "id": str(self.uid) - } + return {"type": self._resource_type.value, "id": str(self.uid)} -GEMDSelf = TypeVar('GEMDSelf', bound='GEMDResource') +GEMDSelf = TypeVar("GEMDSelf", bound="GEMDResource") class GEMDResource(Resource[GEMDSelf]): @@ -58,8 +55,10 @@ class GEMDResource(Resource[GEMDSelf]): def build(cls, data: dict) -> GEMDSelf: """Convert a raw, nested dictionary into Objects.""" if "context" in data and len(data) == 2: + def _inflate(x): return DictSerializable.class_mapping[x["type"]].build(x) + key = next(k for k in data if k != "context") idx = make_index([_inflate(x) for x in data["context"] + [data[key]]]) lst = [idx[k] for k in idx] @@ -72,8 +71,12 @@ def _inflate(x): return idx[root.to_link()] else: if data.get("type") is not None: - if not issubclass(cls, DictSerializable.class_mapping.get(data.get("type"))): - raise ValueError(f"{cls.__name__} passed a {data.get('type')} dictionary.") + if not issubclass( + cls, DictSerializable.class_mapping.get(data.get("type")) + ): + raise ValueError( + f"{cls.__name__} passed a {data.get('type')} dictionary." 
+ ) return super().build(data) def as_dict(self) -> dict: @@ -91,25 +94,27 @@ def as_dict(self) -> dict: return result -class ResourceRef(Serializable['ResourceRef']): +class ResourceRef(Serializable["ResourceRef"]): """A reference to a resource by UID.""" # json key 'module_uid' is a legacy of when this object was only used for modules - uid = properties.UUID('module_uid') + uid = properties.UUID("module_uid") def __init__(self, uid: Union[UUID, str]): self.uid = uid -class PredictorRef(Serializable['PredictorRef']): +class PredictorRef(Serializable["PredictorRef"]): """A reference to a resource by UID.""" - uid = properties.UUID('predictor_id') + uid = properties.UUID("predictor_id") version = properties.Optional( properties.Union([properties.Integer(), properties.String()]), - 'predictor_version' + "predictor_version", ) - def __init__(self, uid: Union[UUID, str], version: Optional[Union[int, str]] = None): + def __init__( + self, uid: Union[UUID, str], version: Optional[Union[int, str]] = None + ): self.uid = uid self.version = version diff --git a/src/citrine/_serialization/include_parent_properties.py b/src/citrine/_serialization/include_parent_properties.py index 1e9dfcd58..0f4a35f3f 100644 --- a/src/citrine/_serialization/include_parent_properties.py +++ b/src/citrine/_serialization/include_parent_properties.py @@ -2,7 +2,7 @@ from citrine._serialization.serializable import Serializable -Self = TypeVar('Self', bound='Serializable') +Self = TypeVar("Self", bound="Serializable") class IncludeParentProperties(Serializable[Self]): @@ -14,6 +14,7 @@ def build_with_parent(cls, data: dict, base_cls) -> Self: resource = super().build(data) from citrine._serialization import properties + metadata_properties = properties.Object(base_cls).deserialize(data) resource.__dict__.update(metadata_properties.__dict__) diff --git a/src/citrine/_serialization/polymorphic_serializable.py b/src/citrine/_serialization/polymorphic_serializable.py index bd7028266..eb2389f90 100644 --- a/src/citrine/_serialization/polymorphic_serializable.py +++ b/src/citrine/_serialization/polymorphic_serializable.py @@ -4,7 +4,7 @@ from citrine._serialization.serializable import Serializable -SelfType = TypeVar('SelfType', bound='PolymorphicSerializable') +SelfType = TypeVar("SelfType", bound="PolymorphicSerializable") class PolymorphicSerializable(Generic[SelfType]): diff --git a/src/citrine/_serialization/properties.py b/src/citrine/_serialization/properties.py index 95cceef28..650016f97 100644 --- a/src/citrine/_serialization/properties.py +++ b/src/citrine/_serialization/properties.py @@ -1,4 +1,5 @@ """Property objects for typed setting and ser/de.""" + from abc import abstractmethod import typing from datetime import datetime @@ -17,10 +18,10 @@ from citrine._serialization.serializable import Serializable from citrine._serialization.polymorphic_serializable import PolymorphicSerializable -SerializedType = typing.TypeVar('SerializedType') -DeserializedType = typing.TypeVar('DeserializedType') -SerializedInteger = typing.TypeVar('SerializedInteger', int, str) -SerializedFloat = typing.TypeVar('SerializedFloat', float, str) +SerializedType = typing.TypeVar("SerializedType") +DeserializedType = typing.TypeVar("DeserializedType") +SerializedInteger = typing.TypeVar("SerializedInteger", int, str) +SerializedFloat = typing.TypeVar("SerializedFloat", float, str) class Property(typing.Generic[DeserializedType, SerializedType]): @@ -51,20 +52,23 @@ class Property(typing.Generic[DeserializedType, SerializedType]): """ - def 
__init__(self, - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): + def __init__( + self, + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): self.serialization_path = serialization_path if override: self._key: None = None else: - self._key: str = '__' + str(uuid.uuid4()) # Make this object key human-readable + self._key: str = "__" + str( + uuid.uuid4() + ) # Make this object key human-readable self.serializable: bool = serializable self.deserializable: bool = deserializable self.default: typing.Optional[DeserializedType] = default @@ -74,42 +78,48 @@ def __init__(self, @property @abstractmethod - def underlying_types(self) -> typing.Union[DeserializedType, typing.Tuple[DeserializedType]]: + def underlying_types( + self, + ) -> typing.Union[DeserializedType, typing.Tuple[DeserializedType]]: """Return the python types handled by this property.""" @property @abstractmethod - def serialized_types(self) -> typing.Union[SerializedType, typing.Tuple[SerializedType]]: + def serialized_types( + self, + ) -> typing.Union[SerializedType, typing.Tuple[SerializedType]]: """Return the types used to serialize this property.""" def _error_source(self, base_class: type) -> str: """Construct a string of the base class name and the parameter that failed.""" if base_class is not None: - return ' for {}:{}'.format(base_class.__name__, self.serialization_path) + return " for {}:{}".format(base_class.__name__, self.serialization_path) elif self.serialization_path: - return ' for {}'.format(self.serialization_path) + return " for {}".format(self.serialization_path) else: - return '' + return "" - def serialize(self, value: DeserializedType, - base_class: typing.Optional[type] = None) -> SerializedType: + def serialize( + self, value: DeserializedType, base_class: typing.Optional[type] = None + ) -> SerializedType: if not isinstance(value, self.underlying_types): base_name = self._error_source(base_class) raise ValueError( - f'{type(value)} {value} is not one of valid types: ' - f'{self.underlying_types}{base_name}' + f"{type(value)} {value} is not one of valid types: " + f"{self.underlying_types}{base_name}" ) return self._serialize(value) - def deserialize(self, value: SerializedType, - base_class: typing.Optional[type] = None) -> DeserializedType: + def deserialize( + self, value: SerializedType, base_class: typing.Optional[type] = None + ) -> DeserializedType: if not isinstance(value, self.serialized_types): if isinstance(value, self.underlying_types): return value # Don't worry if it was already deserialized base_name = self._error_source(base_class) raise ValueError( - f'{type(value)} {value} is not one of valid types: ' - f'{self.serialized_types}{base_name}' + f"{type(value)} {value} is not one of valid types: " + f"{self.serialized_types}{base_name}" ) return self._deserialize(value) @@ -124,13 +134,14 @@ def _deserialize(self, value: SerializedType) -> DeserializedType: def deserialize_from_dict(self, data: dict) -> DeserializedType: value = data # `serialization_path` is expected to be a sequence of nested dictionary keys - fields = self.serialization_path.split('.') + fields = self.serialization_path.split(".") for field in fields: next_value = 
value.get(field) if next_value is None: if self.default is None and not self.optional: msg = "Unable to deserialize {} into {}, missing a required field: {}".format( - data, self.underlying_types, field) + data, self.underlying_types, field + ) raise ValueError(msg) # This occurs if a `field` is unexpectedly not present in the data dictionary # or if its value is null. @@ -144,10 +155,10 @@ def deserialize_from_dict(self, data: dict) -> DeserializedType: def serialize_to_dict(self, data: dict, value: DeserializedType) -> dict: if self.serialization_path is None: - raise ValueError('No serialization path set!') + raise ValueError("No serialization path set!") _data = data - fields = self.serialization_path.split('.') + fields = self.serialization_path.split(".") for field in fields[:-1]: _data = _data.setdefault(field, {}) _data[fields[-1]] = self.serialize(value, base_class=None) # Always a dict @@ -192,11 +203,10 @@ def __set__(self, obj, value: typing.Union[SerializedType, DeserializedType]): setattr(obj, self._key, value_to_set) def __str__(self): - return '<Property {!r}>'.format(self.serialization_path) + return "<Property {!r}>".format(self.serialization_path) class PropertyCollection(Property[DeserializedType, SerializedType]): - def __set__(self, obj, value: typing.Union[SerializedType, DeserializedType]): """ Property setter for container property types. @@ -240,8 +250,9 @@ def _set_elements(self, value: typing.Union[SerializedType, DeserializedType]): @lru_cache(maxsize=1024) -def _get_key_and_base_class(prop: Property, klass: typing.Any) -> \ - typing.Tuple[typing.Optional[str], typing.Optional[str]]: +def _get_key_and_base_class( + prop: Property, klass: typing.Any +) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]: """ Return the base class and class attribute name for the object and property.
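The dotted `serialization_path` handling above is the mechanism that maps a flat class attribute onto nested JSON. A minimal sketch, assuming a hypothetical `Sample` class (the path mirrors the `metadata.created.user` paths used by the engine resources later in this diff):

```python
import uuid

from citrine._serialization import properties
from citrine._serialization.serializable import Serializable


class Sample(Serializable["Sample"]):
    # Each dot in serialization_path becomes one level of dict nesting.
    created_by = properties.Optional(properties.UUID, "metadata.created.user")


s = Sample()
s.created_by = uuid.uuid4()
print(s.dump())  # {'metadata': {'created': {'user': '<uuid string>'}}}
```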
@@ -257,7 +268,6 @@ def _get_key_and_base_class(prop: Property, klass: typing.Any) -> \ class Integer(Property[int, SerializedInteger]): - @property def underlying_types(self): return int @@ -268,22 +278,21 @@ def serialized_types(self): def _deserialize(self, value: SerializedInteger) -> int: if isinstance(value, bool): - raise TypeError('value must be a Number, not a boolean.') + raise TypeError("value must be a Number, not a boolean.") else: return int(value) def _serialize(self, value: int) -> SerializedInteger: if isinstance(value, bool): - raise TypeError('Boolean cannot be serialized to integer.') + raise TypeError("Boolean cannot be serialized to integer.") else: return value def __str__(self): - return '<Integer {!r}>'.format(self.serialization_path) + return "<Integer {!r}>".format(self.serialization_path) class Float(Property[float, SerializedFloat]): - @property def underlying_types(self): return float @@ -295,7 +304,7 @@ def serialized_types(self): @classmethod def _deserialize(cls, value: SerializedFloat) -> float: if isinstance(value, bool): - raise TypeError('value must be a Number, not a boolean.') + raise TypeError("value must be a Number, not a boolean.") else: return float(value) @@ -304,11 +313,10 @@ def _serialize(cls, value: float) -> SerializedFloat: return value def __str__(self): - return '<Float {!r}>'.format(self.serialization_path) + return "<Float {!r}>".format(self.serialization_path) class Raw(Property[typing.Any, typing.Any]): - @property def underlying_types(self): return object @@ -326,11 +334,10 @@ def _serialize(cls, value: typing.Any) -> typing.Any: return value def __str__(self): - return '<Raw {!r}>'.format(self.serialization_path) + return "<Raw {!r}>".format(self.serialization_path) class String(Property[str, str]): - @property def underlying_types(self): return str @@ -342,18 +349,17 @@ def serialized_types(self): def _deserialize(self, value: str) -> str: value = self.default if value is None else value if value is None: - raise ValueError('Value must not be none!') + raise ValueError("Value must not be none!") return str(value) def _serialize(self, value: str) -> str: return str(value) def __str__(self): - return '<String {!r}>'.format(self.serialization_path) + return "<String {!r}>".format(self.serialization_path) class Boolean(Property[bool, bool]): - @property def underlying_types(self): return bool @@ -369,11 +375,10 @@ def _serialize(self, value: str) -> bool: return bool(value) def __str__(self): - return '<Boolean {!r}>'.format(self.serialization_path) + return "<Boolean {!r}>".format(self.serialization_path) class UUID(Property[uuid.UUID, str]): - @property def underlying_types(self): return uuid.UUID @@ -390,7 +395,6 @@ def _serialize(self, value: uuid.UUID) -> str: class Datetime(Property[datetime, int]): - @property def underlying_types(self): return datetime @@ -413,25 +417,28 @@ def _serialize(self, value: datetime) -> int: class List(PropertyCollection[list, list]): - - def __init__(self, - element_type: typing.Union[Property, typing.Type[Property]], - serialization_path:
typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) + self.element_type = ( + element_type if isinstance(element_type, Property) else element_type() + ) @property def underlying_types(self): @@ -467,24 +474,28 @@ def _set_elements(self, value): class Set(PropertyCollection[set, typing.Iterable]): - - def __init__(self, - element_type: typing.Union[Property, typing.Type[Property]], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) - self.element_type = element_type if isinstance(element_type, Property) else element_type() + def __init__( + self, + element_type: typing.Union[Property, typing.Type[Property]], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) + self.element_type = ( + element_type if isinstance(element_type, Property) else element_type() + ) @property def underlying_types(self): @@ -529,38 +540,58 @@ class Union(Property[typing.Any, typing.Any]): Attempted de/serialization is done in the order in which types are provided in the constructor. """ - def __init__(self, - element_types: typing.Sequence[typing.Union[Property, typing.Type[Property]]], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) + def __init__( + self, + element_types: typing.Sequence[typing.Union[Property, typing.Type[Property]]], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) if not isinstance(element_types, typing.Iterable): raise ValueError("element types must be iterable: {}".format(element_types)) - self.element_types: typing.List[Property, ...] = \ - [el if isinstance(el, Property) else el() for el in element_types] + self.element_types: typing.List[Property, ...] 
= [ + el if isinstance(el, Property) else el() for el in element_types + ] @property def underlying_types(self): all_underlying_types = [prop.underlying_types for prop in self.element_types] - return tuple(set(chain(*[typ if isinstance(typ, tuple) - else (typ,) for typ in all_underlying_types]))) + return tuple( + set( + chain( + *[ + typ if isinstance(typ, tuple) else (typ,) + for typ in all_underlying_types + ] + ) + ) + ) @property def serialized_types(self): all_serialized_types = [prop.serialized_types for prop in self.element_types] - return tuple(set(chain(*[typ if isinstance(typ, tuple) - else (typ,) for typ in all_serialized_types]))) + return tuple( + set( + chain( + *[ + typ if isinstance(typ, tuple) else (typ,) + for typ in all_serialized_types + ] + ) + ) + ) def _serialize(self, value: typing.Any) -> typing.Any: for prop in self.element_types: @@ -568,8 +599,10 @@ def _serialize(self, value: typing.Any) -> typing.Any: return prop.serialize(value) except ValueError: pass - raise ValueError("An unexpected error occurred while trying to serialize {} to one " - "of the following types: {}.".format(value, self.serialized_types)) + raise ValueError( + "An unexpected error occurred while trying to serialize {} to one " + "of the following types: {}.".format(value, self.serialized_types) + ) def _deserialize(self, value: typing.Any) -> typing.Any: for prop in self.element_types: @@ -577,33 +610,39 @@ def _deserialize(self, value: typing.Any) -> typing.Any: return prop.deserialize(value) except ValueError: pass - raise ValueError("An unexpected error occurred while trying to deserialize {} to " - "one of the following types: {}.".format(value, self.underlying_types)) + raise ValueError( + "An unexpected error occurred while trying to deserialize {} to " + "one of the following types: {}.".format(value, self.underlying_types) + ) class SpecifiedMixedList(PropertyCollection[list, list]): """A finite list in which the type of each entry is specified.""" - def __init__(self, - element_types: typing.Sequence[typing.Union[Property, typing.Type[Property]]], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) + def __init__( + self, + element_types: typing.Sequence[typing.Union[Property, typing.Type[Property]]], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) if not isinstance(element_types, list): raise ValueError("element types must be a list: {}".format(element_types)) - self.element_types: typing.List[Property, ...] = \ - [el if isinstance(el, Property) else el() for el in element_types] + self.element_types: typing.List[Property, ...] 
= [ + el if isinstance(el, Property) else el() for el in element_types + ] @property def underlying_types(self): @@ -615,28 +654,32 @@ def serialized_types(self): def _deserialize(self, value: list) -> tuple: if len(value) > len(self.element_types): - raise ValueError("Cannot deserialize value {}, as it has more elements " - "than expected for list {}".format(value, self.element_types)) + raise ValueError( + "Cannot deserialize value {}, as it has more elements " + "than expected for list {}".format(value, self.element_types) + ) deserialized = [] for element, element_type in zip(value, self.element_types): deserialized.append(element_type.deserialize(element)) # If there are more element types than elements, append default values - for element_type in self.element_types[len(value):]: + for element_type in self.element_types[len(value) :]: deserialized.append(element_type.default) return deserialized def _serialize(self, value: tuple) -> list: if len(value) > len(self.element_types): - raise ValueError("Cannot serialize value {}, as it has more elements " - "than expected for list {}".format(value, self.element_types)) + raise ValueError( + "Cannot serialize value {}, as it has more elements " + "than expected for list {}".format(value, self.element_types) + ) serialized = [] for element, element_type in zip(value, self.element_types): serialized.append(element_type.serialize(element)) # If there are more element types than elements, append serialized default values - for element_type in self.element_types[len(value):]: + for element_type in self.element_types[len(value) :]: serialized.append(element_type.serialize(element_type.default)) return serialized @@ -644,8 +687,10 @@ def _serialize(self, value: tuple) -> list: def _set_elements(self, value): elems = [] if len(value) > len(self.element_types): - raise ValueError("Cannot serialize value {}, as it has more elements " - "than expected for list {}".format(value, self.element_types)) + raise ValueError( + "Cannot serialize value {}, as it has more elements " + "than expected for list {}".format(value, self.element_types) + ) for element, element_type in zip(value, self.element_types): if isinstance(element_type, PropertyCollection): val_to_append = element_type._set_elements(element) @@ -656,30 +701,32 @@ def _set_elements(self, value): elems.append(val_to_append) # If there are more element types than elements, append serialized default values - for element_type in self.element_types[len(value):]: + for element_type in self.element_types[len(value) :]: elems.append(element_type.default) return elems class Enumeration(Property[BaseEnumeration, str]): - - def __init__(self, - klass: typing.Type[typing.Any], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) + def __init__( + self, + klass: typing.Type[typing.Any], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + 
override=override, + use_init=use_init, + ) self.klass = klass @property @@ -703,33 +750,37 @@ def _fields_map(klass: typing.Type) -> typing.Dict[str, Property]: return { k: v for x in reversed(klass.__mro__) # Classes at the front trump - for k, v in x.__dict__.items() if isinstance(v, Property) + for k, v in x.__dict__.items() + if isinstance(v, Property) } class Object(PropertyCollection[typing.Any, dict]): - - def __init__(self, - klass: typing.Type[typing.Any], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) + def __init__( + self, + klass: typing.Type[typing.Any], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) self.klass = klass # We need to use __dict__ here because other access methods will invoke __get__ self.fields: typing.Dict[str, Property] = _fields_map(self.klass) - self.polymorphic = "get_type" in self.klass.__dict__ and\ - issubclass(self.klass, PolymorphicSerializable) + self.polymorphic = "get_type" in self.klass.__dict__ and issubclass( + self.klass, PolymorphicSerializable + ) @property def underlying_types(self): @@ -746,8 +797,10 @@ def _deserialize(self, data: dict) -> typing.Any: # Maybe there are no fields because we hit a gemd-python class if issubclass(self.klass, DictSerializable): return DictSerializable.build(data) - raise AttributeError("Tried to deserialize to {!r}, which has no fields and is not an" - " explicitly serializable class".format(self.klass)) + raise AttributeError( + "Tried to deserialize to {!r}, which has no fields and is not an" + " explicitly serializable class".format(self.klass) + ) values = {} init_props = set() @@ -761,12 +814,14 @@ def _deserialize(self, data: dict) -> typing.Any: if len(init_props) > 0: try: - instance = self.klass(**{k: v for k, v in values.items() if k in init_props}) + instance = self.klass( + **{k: v for k, v in values.items() if k in init_props} + ) except TypeError as e: # Check if it's because the signature was wrong sig = signature(self.klass.__init__) for arg, param in sig.parameters.items(): - if arg not in init_props | {'self'}: + if arg not in init_props | {"self"}: if param.default is param.empty: raise AttributeError( f"{self.klass} has at least 1 property marked as `use_init`, " @@ -798,8 +853,10 @@ def _serialize(self, obj: typing.Any) -> dict: try: return obj.dump() except AttributeError: - raise AttributeError("Tried to serialize object {!r} of type {}, which has " - "neither fields nor a dump() method.".format(obj, type(obj))) + raise AttributeError( + "Tried to serialize object {!r} of type {}, which has " + "neither fields nor a dump() method.".format(obj, type(obj)) + ) for property_name, field in self.fields.items(): if field.serializable: value = getattr(obj, property_name) @@ -807,7 +864,7 @@ def _serialize(self, obj: typing.Any) -> dict: return serialized def __str__(self): - return
'<Object[{}] {!r}>'.format(self.klass.__name__, self.serialization_path) + return "<Object[{}] {!r}>".format(self.klass.__name__, self.serialization_path) def _set_elements(self, value): if issubclass(type(value), self.klass): @@ -835,23 +892,25 @@ class LinkOrElse(PropertyCollection[typing.Union[Serializable, LinkByUID], dict] generic Link object. """ - def __init__(self, - klass: typing.Type[typing.Any] = Serializable, - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): + def __init__( + self, + klass: typing.Type[typing.Any] = Serializable, + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): super().__init__( serialization_path=serialization_path, serializable=serializable, deserializable=deserializable, default=default, override=override, - use_init=use_init) + use_init=use_init, + ) self.klass = klass @property @@ -869,13 +928,15 @@ def _serialize(self, value: typing.Any) -> dict: return value.dump() def _deserialize(self, value: dict): - if 'type' in value: - target = DictSerializable.class_mapping[value['type']] + if "type" in value: + target = DictSerializable.class_mapping[value["type"]] try: return target.build(value) except TypeError as e: # TODO: Consider migrating this ValueError to a TypeError for 3 - match = re.search(r"__init__.* missing (\d+) required \w+ arguments: (.+)", str(e)) + match = re.search( + r"__init__.* missing (\d+) required \w+ arguments: (.+)", str(e) + ) if match: raise ValueError( f"{match.group(1)} missing required " @@ -883,31 +944,37 @@ def _deserialize(self, value: dict): ) else: raise e - raise Exception("Serializable object that is being pointed to must have a self-contained " - "build() method that does not call deserialize().") + raise Exception( + "Serializable object that is being pointed to must have a self-contained " + "build() method that does not call deserialize()."
+ ) def _set_elements(self, value): return value -class Optional(PropertyCollection[typing.Optional[typing.Any], typing.Optional[typing.Any]]): - - def __init__(self, - prop: typing.Union[Property, typing.Type[Property]], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[DeserializedType] = None, - override: bool = False, - use_init: bool = False - ): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init) +class Optional( + PropertyCollection[typing.Optional[typing.Any], typing.Optional[typing.Any]] +): + def __init__( + self, + prop: typing.Union[Property, typing.Type[Property]], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[DeserializedType] = None, + override: bool = False, + use_init: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) self.prop = prop if isinstance(prop, Property) else prop() self.optional = True @@ -927,14 +994,18 @@ def serialized_types(self): else: return constituent_types, type(None) - def _deserialize(self, data: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]: + def _deserialize( + self, data: typing.Optional[typing.Any] + ) -> typing.Optional[typing.Any]: return self.prop.deserialize(data) if data is not None else None - def _serialize(self, obj: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]: + def _serialize( + self, obj: typing.Optional[typing.Any] + ) -> typing.Optional[typing.Any]: return self.prop.serialize(obj) if obj is not None else None def __str__(self): - return '<Optional[{}] {!r}>'.format(self.prop, self.serialization_path) + return "<Optional[{}] {!r}>".format(self.prop, self.serialization_path) def _set_elements(self, value): elem = None @@ -958,27 +1029,32 @@ class Mapping(PropertyCollection[dict, dict]): key value pairs and converts them to a dict.
""" - def __init__(self, - keys_type: typing.Union[Property, typing.Type[Property]], - values_type: typing.Union[Property, typing.Type[Property]], - serialization_path: typing.Optional[str] = None, - *, - serializable: bool = True, - deserializable: bool = True, - default: typing.Optional[dict] = None, - override: bool = False, - use_init: bool = False, - ser_as_list_of_pairs: bool = False): - super().__init__(serialization_path=serialization_path, - serializable=serializable, - deserializable=deserializable, - default=default, - override=override, - use_init=use_init - ) + def __init__( + self, + keys_type: typing.Union[Property, typing.Type[Property]], + values_type: typing.Union[Property, typing.Type[Property]], + serialization_path: typing.Optional[str] = None, + *, + serializable: bool = True, + deserializable: bool = True, + default: typing.Optional[dict] = None, + override: bool = False, + use_init: bool = False, + ser_as_list_of_pairs: bool = False, + ): + super().__init__( + serialization_path=serialization_path, + serializable=serializable, + deserializable=deserializable, + default=default, + override=override, + use_init=use_init, + ) self.keys_type = keys_type if isinstance(keys_type, Property) else keys_type() - self.values_type = values_type if isinstance(values_type, Property) else values_type() + self.values_type = ( + values_type if isinstance(values_type, Property) else values_type() + ) self.ser_as_list_of_pairs = ser_as_list_of_pairs @property diff --git a/src/citrine/_serialization/serializable.py b/src/citrine/_serialization/serializable.py index b8cdff16f..4eccc254d 100644 --- a/src/citrine/_serialization/serializable.py +++ b/src/citrine/_serialization/serializable.py @@ -1,7 +1,7 @@ from typing import Generic, TypeVar -Self = TypeVar('Self', bound='Serializable') +Self = TypeVar("Self", bound="Serializable") class Serializable(Generic[Self]): @@ -16,12 +16,14 @@ def _pre_build(cls, data: dict) -> dict: def build(cls, data: dict) -> Self: """Build an instance of this object from given data.""" from citrine._serialization import properties + pre_built = cls._pre_build(data) return properties.Object(cls).deserialize(pre_built) def dump(self) -> dict: """Dump this instance.""" from citrine._serialization import properties + serialized = properties.Object(type(self)).serialize(self) return self._post_dump(serialized) diff --git a/src/citrine/_session.py b/src/citrine/_session.py index c403a574b..bda26521a 100644 --- a/src/citrine/_session.py +++ b/src/citrine/_session.py @@ -21,7 +21,8 @@ NotFound, Unauthorized, UnauthorizedRefreshToken, - WorkflowNotReadyException) + WorkflowNotReadyException, +) # Choose a 5-second buffer so that there's no chance of the access token # expiring during the check for expiration @@ -32,25 +33,29 @@ class Session(requests.Session): """Wrapper around requests.Session that is both refresh-token and schema aware.""" - def __init__(self, - refresh_token: str = None, - *, - scheme: str = None, - host: str = None, - port: Optional[str] = None): + def __init__( + self, + refresh_token: str = None, + *, + scheme: str = None, + host: str = None, + port: Optional[str] = None, + ): super().__init__() if refresh_token is None: - refresh_token = environ.get('CITRINE_API_KEY') + refresh_token = environ.get("CITRINE_API_KEY") if scheme is None: - scheme = 'https' + scheme = "https" if host is None: - host = environ.get('CITRINE_API_HOST') + host = environ.get("CITRINE_API_HOST") if host is None: - raise ValueError("No host passed and environmental " - 
"variable CITRINE_API_HOST not set.") + raise ValueError( + "No host passed and environmental " + "variable CITRINE_API_HOST not set." + ) self.scheme: str = scheme - self.authority = ':'.join(([host] if host else []) + ([port] if port else [])) + self.authority = ":".join(([host] if host else []) + ([port] if port else [])) self.refresh_token: str = refresh_token self.access_token: Optional[str] = None self.access_token_expiration: datetime = datetime.now(timezone.utc) @@ -59,17 +64,16 @@ def __init__(self, platform.python_implementation(), platform.python_version(), requests.__version__, - citrine.__version__) + citrine.__version__, + ) # Following scheme:[//authority]path[?query][#fragment] (https://en.wikipedia.org/wiki/URL) - self.headers.update({ - "Content-Type": "application/json", - "User-Agent": agent}) + self.headers.update({"Content-Type": "application/json", "User-Agent": agent}) # Default parameters for S3 connectivity. Can be changed by tests. self.s3_endpoint_url = None self.s3_use_ssl = True - self.s3_addressing_style = 'auto' + self.s3_addressing_style = "auto" # Feature flag for enabling the use of Dataset idempotent PUT. Will be removed # in a future release. @@ -78,32 +82,38 @@ def __init__(self, # Custom adapter so we can use custom retry parameters. The default HTTP status # codes for retries are [503, 413, 429]. We're using status_force list to add # additional codes to retry on, focusing on specific CloudFlare 5XX errors. - retries = Retry(total=10, - connect=5, - read=5, - status=5, - backoff_factor=0.25, - status_forcelist=[500, 502, 504, 520, 521, 522, 524, 527]) + retries = Retry( + total=10, + connect=5, + read=5, + status=5, + backoff_factor=0.25, + status_forcelist=[500, 502, 504, 520, 521, 522, 524, 527], + ) adapter = requests.adapters.HTTPAdapter(max_retries=retries) - self.mount('https://', adapter) - self.mount('http://', adapter) + self.mount("https://", adapter) + self.mount("http://", adapter) # Requests has its own set of exceptions that do not inherit from the # built-in exceptions. 
The built-in ConnectionError handles 4 different # child exceptions: https://docs.python.org/3/library/exceptions.html#ConnectionError - self.retry_errs = (ConnectionError, - requests.exceptions.ConnectionError, - requests.exceptions.ChunkedEncodingError) + self.retry_errs = ( + ConnectionError, + requests.exceptions.ConnectionError, + requests.exceptions.ChunkedEncodingError, + ) self._refresh_access_token() - def _versioned_base_url(self, version: str = 'v1'): - return urlunsplit(( - self.scheme, - self.authority, - format_escaped_url('api/{}/', version), - '', # query string - '' # fragment - )) + def _versioned_base_url(self, version: str = "v1"): + return urlunsplit( + ( + self.scheme, + self.authority, + format_escaped_url("api/{}/", version), + "", # query string + "", # fragment + ) + ) def _is_access_token_expired(self): buffered_expire = self.access_token_expiration - EXPIRATION_BUFFER @@ -111,21 +121,22 @@ def _is_access_token_expired(self): def _refresh_access_token(self) -> None: """Optionally refresh our access token (if the previous one is about to expire).""" - data = {'refresh_token': self.refresh_token} + data = {"refresh_token": self.refresh_token} - response = self._request_with_retry('POST', self._versioned_base_url() + 'tokens/refresh', - json=data) + response = self._request_with_retry( + "POST", self._versioned_base_url() + "tokens/refresh", json=data + ) if response.status_code != 200: raise UnauthorizedRefreshToken() - self.access_token = response.json()['access_token'] + self.access_token = response.json()["access_token"] self.access_token_expiration = datetime.fromtimestamp( jwt.decode( self.access_token, options={"verify_signature": False}, - algorithms=["HS256"] - )['exp'], - timezone.utc + algorithms=["HS256"], + )["exp"], + timezone.utc, ) # Explicitly set an updated 'auth', so as to not rely on implicit cookie handling. 
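Editor's aside (not part of the diff): the Retry block reformatted above is the whole of the session's resilience policy, so here is a standalone sketch of the same configuration for review convenience. The delay comment assumes urllib3's documented backoff of roughly backoff_factor * 2**(n - 1) seconds before the n-th retry; the exact sequence varies by urllib3 version.

import requests
from urllib3.util.retry import Retry

retries = Retry(
    total=10,             # overall cap across connect/read/status retries
    connect=5,
    read=5,
    status=5,
    backoff_factor=0.25,  # sleeps grow as ~0.25 * 2**k seconds between attempts
    status_forcelist=[500, 502, 504, 520, 521, 522, 524, 527],
)
adapter = requests.adapters.HTTPAdapter(max_retries=retries)
session = requests.Session()
session.mount("https://", adapter)  # every https:// request now uses the policy
session.mount("http://", adapter)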
@@ -139,76 +150,84 @@ def _request_with_retry(self, method, uri, **kwargs): try: response = self.request(method, uri, **kwargs) except self.retry_errs as e: - logger.warning('{} seen, retrying request'.format(repr(e))) + logger.warning("{} seen, retrying request".format(repr(e))) response = self.request(method, uri, **kwargs) return response - def checked_request(self, method: str, path: str, - version: str = 'v1', **kwargs) -> requests.Response: + def checked_request( + self, method: str, path: str, version: str = "v1", **kwargs + ) -> requests.Response: """Check response status code and throw an exception if relevant.""" - logger.debug('BEGIN request details:') - logger.debug('\tmethod: {}'.format(method)) - logger.debug('\tpath: {}'.format(path)) - logger.debug('\tversion: {}'.format(version)) + logger.debug("BEGIN request details:") + logger.debug("\tmethod: {}".format(method)) + logger.debug("\tpath: {}".format(path)) + logger.debug("\tversion: {}".format(version)) for k, v in kwargs.items(): - logger.debug(f'\t{k}: {v}') + logger.debug(f"\t{k}: {v}") if self._is_access_token_expired(): self._refresh_access_token() - uri = self._versioned_base_url(version) + path.lstrip('/') + uri = self._versioned_base_url(version) + path.lstrip("/") - logger.debug('\turi: {}'.format(uri)) + logger.debug("\turi: {}".format(uri)) for k, v in kwargs.items(): - logger.debug('\t{}: {}'.format(k, v)) - logger.debug('END request details.') + logger.debug("\t{}: {}".format(k, v)) + logger.debug("END request details.") response = self._request_with_retry(method, uri, **kwargs) try: - if response.status_code == 401 and response.json().get("reason") == "invalid-token": + if ( + response.status_code == 401 + and response.json().get("reason") == "invalid-token" + ): self._refresh_access_token() response = self._request_with_retry(method, uri, **kwargs) except AttributeError: # Catch AttributeErrors and log response # The 401 status will be handled further down - logger.error("Failed to decode json from response: {}".format(response.text)) + logger.error( + "Failed to decode json from response: {}".format(response.text) + ) except ValueError: # Ignore ValueErrors thrown by attempting to decode json bodies. 
This # might occur if we get a 401 response without a JSON body
            pass

        if 200 <= response.status_code <= 299:
-            logger.info('%s %s %s', response.status_code, method, path)
+            logger.info("%s %s %s", response.status_code, method, path)
            return response
        else:
            stacktrace = self._extract_response_stacktrace(response)
            if stacktrace is not None:
-                logger.error('Response arrived with stacktrace:')
+                logger.error("Response arrived with stacktrace:")
                logger.error(stacktrace)
            if response.status_code == 400:
-                logger.error('%s %s %s', response.status_code, method, path)
+                logger.error("%s %s %s", response.status_code, method, path)
                logger.error(response.text)
                raise BadRequest(path, response)
            elif response.status_code == 401:
-                logger.error('%s %s %s', response.status_code, method, path)
+                logger.error("%s %s %s", response.status_code, method, path)
                raise Unauthorized(path, response)
            elif response.status_code == 403:
-                logger.error('%s %s %s', response.status_code, method, path)
+                logger.error("%s %s %s", response.status_code, method, path)
                raise Unauthorized(path, response)
            elif response.status_code == 404:
-                logger.error('%s %s %s', response.status_code, method, path)
+                logger.error("%s %s %s", response.status_code, method, path)
                raise NotFound(path, response)
            elif response.status_code == 409:
-                logger.debug('%s %s %s', response.status_code, method, path)
+                logger.debug("%s %s %s", response.status_code, method, path)
                raise Conflict(path, response)
            elif response.status_code == 425:
-                logger.debug('%s %s %s', response.status_code, method, path)
-                msg = 'Cant execute at this time. Try again later. Error: {}'.format(response.text)
+                logger.debug("%s %s %s", response.status_code, method, path)
+                msg = "Can't execute at this time. Try again later. Error: {}".format(
+                    response.text
+                )
                raise WorkflowNotReadyException(msg)
            else:
-                logger.error('%s %s %s', response.status_code, method, path)
+                logger.error("%s %s %s", response.status_code, method, path)
                raise CitrineException(response.text)

    @staticmethod
@@ -216,7 +235,7 @@ def _extract_response_stacktrace(response: Response) -> Optional[str]:
        try:
            json_value = response.json()
            if isinstance(json_value, dict):
-                return json_value.get('debug_stacktrace')
+                return json_value.get("debug_stacktrace")
        except ValueError:
            pass
        return None
@@ -258,56 +277,63 @@ def _extract_response_json(path, response) -> dict:
                lacked the required 'application/json' Content-Type in the header.""")
        except JSONDecodeError as err:
-            logger.info('Response at path %s with status code %s failed json parsing with'
-                        ' exception %s. Returning empty value.',
-                        path,
-                        response.status_code,
-                        err.msg)
+            logger.info(
+                "Response at path %s with status code %s failed json parsing with"
+                " exception %s. Returning empty value.",
+                path,
+                response.status_code,
+                err.msg,
+            )
        return extracted_response

    @staticmethod
-    def cursor_paged_resource(base_method: Callable[..., dict], path: str,
-                              forward: bool = True, per_page: int = 100,
-                              version: str = 'v2', **kwargs) -> Iterator[dict]:
+    def cursor_paged_resource(
+        base_method: Callable[..., dict],
+        path: str,
+        forward: bool = True,
+        per_page: int = 100,
+        version: str = "v2",
+        **kwargs,
+    ) -> Iterator[dict]:
        """
        Returns a flat generator of results for an API query.

        Results are fetched in chunks of size `per_page` and loaded lazily.
""" - params = kwargs.get('params', {}) - params['forward'] = forward - params['ascending'] = forward - params['per_page'] = per_page - kwargs['params'] = params + params = kwargs.get("params", {}) + params["forward"] = forward + params["ascending"] = forward + params["per_page"] = per_page + kwargs["params"] = params while True: response_json = base_method(path, version=version, **kwargs) - for obj in response_json['contents']: + for obj in response_json["contents"]: yield obj - cursor = response_json.get('next') + cursor = response_json.get("next") if cursor is None: break - params['cursor'] = cursor + params["cursor"] = cursor def checked_post(self, path: str, json: dict, **kwargs) -> Response: """Execute a POST request to a URL and utilize error filtering on the response.""" - return self.checked_request('POST', path, json=json, **kwargs) + return self.checked_request("POST", path, json=json, **kwargs) def checked_put(self, path: str, json: dict, **kwargs) -> Response: """Execute a PUT request to a URL and utilize error filtering on the response.""" - return self.checked_request('PUT', path, json=json, **kwargs) + return self.checked_request("PUT", path, json=json, **kwargs) def checked_patch(self, path: str, json: dict, **kwargs) -> Response: """Execute a PATCH request to a URL and utilize error filtering on the response.""" - return self.checked_request('PATCH', path, json=json, **kwargs) + return self.checked_request("PATCH", path, json=json, **kwargs) def checked_delete(self, path: str, **kwargs) -> Response: """Execute a DELETE request to a URL and utilize error filtering on the response.""" - return self.checked_request('DELETE', path, **kwargs) + return self.checked_request("DELETE", path, **kwargs) def checked_get(self, path: str, **kwargs) -> Response: """Execute a GET request to a URL and utilize error filtering on the response.""" - return self.checked_request('GET', path, **kwargs) + return self.checked_request("GET", path, **kwargs) class BearerAuth(requests.auth.AuthBase): diff --git a/src/citrine/_utils/batcher.py b/src/citrine/_utils/batcher.py index dcc171807..f92e3a62a 100644 --- a/src/citrine/_utils/batcher.py +++ b/src/citrine/_utils/batcher.py @@ -11,7 +11,9 @@ class Batcher(ABC): """Base class for Data Concepts batching routines.""" @abstractmethod - def batch(self, objects: Iterable[DataConcepts], batch_size) -> List[List[DataConcepts]]: + def batch( + self, objects: Iterable[DataConcepts], batch_size + ) -> List[List[DataConcepts]]: """Collect a list of DataConcepts into batches according to some batching algorithm.""" @staticmethod @@ -28,24 +30,30 @@ def by_dependency(): class BatchByType(Batcher): """Batching by object type.""" - def batch(self, objects: Iterable[DataConcepts], batch_size) -> List[List[DataConcepts]]: + def batch( + self, objects: Iterable[DataConcepts], batch_size + ) -> List[List[DataConcepts]]: """Collect object batches by type, following an order that will satisfy prereqs.""" batches = list() by_type = defaultdict(list) seen = {} for obj in objects: - if obj.to_link() in seen: # Repeat in the iterable; don't add it to the batch + if ( + obj.to_link() in seen + ): # Repeat in the iterable; don't add it to the batch if seen[obj.to_link()] != obj: # verify that it's a replicate raise ValueError(f"Colliding objects for {obj.to_link()}") else: by_type[obj.typ].append(obj) for scope in obj.uids: seen[obj.to_link(scope)] = obj - typ_groups = sorted(list(by_type.values()), key=lambda x: writable_sort_order(x[0])) + typ_groups = sorted( + 
list(by_type.values()), key=lambda x: writable_sort_order(x[0]) + ) for typ_group in typ_groups: num_batches = len(typ_group) // batch_size for batch_num in range(num_batches + 1): - batch = typ_group[batch_num * batch_size: (batch_num + 1) * batch_size] + batch = typ_group[batch_num * batch_size : (batch_num + 1) * batch_size] batches.append(batch) for i in reversed(range(len(batches) - 1)): if len(batches[i]) + len(batches[i + 1]) <= batch_size: @@ -58,7 +66,9 @@ def batch(self, objects: Iterable[DataConcepts], batch_size) -> List[List[DataCo class BatchByDependency(Batcher): """Batching by clusters where nothing references anything outside the cluster.""" - def batch(self, objects: Iterable[DataConcepts], batch_size) -> List[List[DataConcepts]]: + def batch( + self, objects: Iterable[DataConcepts], batch_size + ) -> List[List[DataConcepts]]: """Collect object batches that are internally consistent for dry_run object tests.""" # Collect shallow dependences, UID references, and type-based clusters depends = dict() @@ -72,20 +82,27 @@ def batch(self, objects: Iterable[DataConcepts], batch_size) -> List[List[DataCo # Deep dependencies w/ objects only, build inverse index # This takes a second loop because we need to build up all derefs first supported_by = defaultdict(list) - type_groups = sorted(list(by_type.values()), key=lambda x: writable_sort_order(x[0])) + type_groups = sorted( + list(by_type.values()), key=lambda x: writable_sort_order(x[0]) + ) for type_group in type_groups: for obj in type_group: # Collect objects of interest that we are supposed to load # Note depends contains both objects & links; obj_set is everything in the call - local_set = {index.get(x, x) for x in depends[obj] if index.get(x, x) in obj_set} + local_set = { + index.get(x, x) for x in depends[obj] if index.get(x, x) in obj_set + } full_set = set(local_set) if len(full_set) > batch_size: - raise ValueError(f"Object {obj.name} has more than {batch_size} dependencies.") + raise ValueError( + f"Object {obj.name} has more than {batch_size} dependencies." 
+ ) for subobj in local_set: full_set.update(depends[subobj]) - depends[obj] = sorted(list(full_set), - key=lambda x: writable_sort_order(x)) + depends[obj] = sorted( + list(full_set), key=lambda x: writable_sort_order(x) + ) for dependant in reversed(depends[obj]): supported_by[dependant].append(obj) diff --git a/src/citrine/_utils/functions.py b/src/citrine/_utils/functions.py index a2c6c3a2a..6257e477a 100644 --- a/src/citrine/_utils/functions.py +++ b/src/citrine/_utils/functions.py @@ -22,20 +22,28 @@ def get_object_id(object_or_id): if isinstance(object_or_id, LinkByUID): if object_or_id.scope == CITRINE_SCOPE: return object_or_id.id - raise ValueError("LinkByUID must be scoped to citrine scope {}, " - "instead is {}".format(CITRINE_SCOPE, object_or_id.scope)) - raise TypeError("{} must be a data concepts object or LinkByUID".format(object_or_id)) + raise ValueError( + "LinkByUID must be scoped to citrine scope {}, instead is {}".format( + CITRINE_SCOPE, object_or_id.scope + ) + ) + raise TypeError( + "{} must be a data concepts object or LinkByUID".format(object_or_id) + ) def validate_type(data_dict: dict, type_name: str) -> dict: """Ensure that dict has field 'type' with given value.""" data_dict_copy = data_dict.copy() - if 'type' in data_dict_copy: - if data_dict_copy['type'] != type_name: + if "type" in data_dict_copy: + if data_dict_copy["type"] != type_name: raise Exception( - "Object type must be {}, but was instead {}.".format(type_name, data_dict['type'])) + "Object type must be {}, but was instead {}.".format( + type_name, data_dict["type"] + ) + ) else: - data_dict_copy['type'] = type_name + data_dict_copy["type"] = type_name return data_dict_copy @@ -69,7 +77,7 @@ def replace_objects_with_links(json: dict) -> dict: def object_to_link(obj: Any) -> Any: """See if an object is a dictionary that can be converted into a Link, and if so, convert.""" if isinstance(obj, dict): - if 'type' in obj and 'uids' in obj and obj['type'] != LinkByUID.typ: + if "type" in obj and "uids" in obj and obj["type"] != LinkByUID.typ: return object_to_link_by_uid(obj) else: return replace_objects_with_links(obj) @@ -81,8 +89,9 @@ def object_to_link(obj: Any) -> Any: def object_to_link_by_uid(json: dict) -> dict: """Convert an object dictionary into a LinkByUID dictionary, if possible.""" from citrine.resources.data_concepts import CITRINE_SCOPE - if 'uids' in json: - uids = json['uids'] + + if "uids" in json: + uids = json["uids"] if not isinstance(uids, dict) or not uids: return json if CITRINE_SCOPE in uids: @@ -111,8 +120,9 @@ def rewrite_s3_links_locally(url: str, s3_endpoint_url: str = None) -> str: if s3_endpoint_url is not None: # Given an explicit endpoint to use instead parsed_s3_endpoint = urlparse(s3_endpoint_url) - return parsed_url._replace(scheme=parsed_s3_endpoint.scheme, - netloc=parsed_s3_endpoint.netloc).geturl() + return parsed_url._replace( + scheme=parsed_s3_endpoint.scheme, netloc=parsed_s3_endpoint.netloc + ).geturl() else: # Else return the URL unmodified return url @@ -131,7 +141,7 @@ def write_file_locally(content, local_path: Union[str, Path]): raise ValueError(f"A filename must be provided in the path ({local_path})") local_path.parent.mkdir(parents=True, exist_ok=True) - local_path.open(mode='wb').write(content) + local_path.open(mode="wb").write(content) class MigratedClassMeta(ABCMeta): @@ -170,23 +180,32 @@ class MyClass(NewMyClass, deprecated_in="1.2.3", removed_in="2.0.0", def __new__(mcs, *args, deprecated_in=None, removed_in=None, **kwargs): # noqa: D102 return 
super().__new__(mcs, *args, **kwargs) - def __init__(cls, name, bases, *args, deprecated_in=None, removed_in=None, **kwargs): + def __init__( + cls, name, bases, *args, deprecated_in=None, removed_in=None, **kwargs + ): super().__init__(name, bases, *args, **kwargs) if not any(isinstance(b, MigratedClassMeta) for b in bases): # First generation if len(bases) != 1: - raise TypeError(f"Migrated Classes must reference precisely one target. " - f"{bases} found.") + raise TypeError( + f"Migrated Classes must reference precisely one target. " + f"{bases} found." + ) if deprecated_in is None or removed_in is None: - raise TypeError("Migrated Classes must include `deprecated_in` " - "and `removed_in` arguments.") + raise TypeError( + "Migrated Classes must include `deprecated_in` " + "and `removed_in` arguments." + ) cls._deprecation_info[cls] = (bases[0], deprecated_in, removed_in) def _new(*args_, **kwargs_): - warn(f"Importing {name} from {cls.__module__} is deprecated as of " - f"{deprecated_in} and will be removed in {removed_in}. " - f"Please import {bases[0].__name__} from {bases[0].__module__} instead.", - DeprecationWarning, stacklevel=2) + warn( + f"Importing {name} from {cls.__module__} is deprecated as of " + f"{deprecated_in} and will be removed in {removed_in}. " + f"Please import {bases[0].__name__} from {bases[0].__module__} instead.", + DeprecationWarning, + stacklevel=2, + ) return bases[0](*args_[1:], **kwargs_) cls.__new__ = _new @@ -195,18 +214,24 @@ def _new(*args_, **kwargs_): if base in cls._deprecation_info: # Second generation alias, this_deprecated_in, this_removed_in = cls._deprecation_info[base] - warn(f"Importing {base.__name__} from {base.__module__} is deprecated as of " - f"{this_deprecated_in} and will be removed in {this_removed_in}. " - f"Please import {alias.__name__} from {alias.__module__} instead.", - DeprecationWarning, stacklevel=2) + warn( + f"Importing {base.__name__} from {base.__module__} is deprecated as of " + f"{this_deprecated_in} and will be removed in {this_removed_in}. " + f"Please import {alias.__name__} from {alias.__module__} instead.", + DeprecationWarning, + stacklevel=2, + ) def __instancecheck__(cls, instance): - return any(cls.__subclasscheck__(c) - for c in {type(instance), instance.__class__}) + return any( + cls.__subclasscheck__(c) for c in {type(instance), instance.__class__} + ) def __subclasscheck__(cls, subclass): try: - return issubclass(subclass, cls._deprecation_info.get(cls, (type(None), ))[0]) + return issubclass( + subclass, cls._deprecation_info.get(cls, (type(None),))[0] + ) except RecursionError: return False @@ -216,16 +241,15 @@ def generate_shared_meta(target: type): if issubclass(MigratedClassMeta, type(target)): return MigratedClassMeta else: + class _CustomMeta(MigratedClassMeta, type(target)): pass + return _CustomMeta def migrate_deprecated_argument( - new_arg: Optional[Any], - new_arg_name: str, - old_arg: Optional[Any], - old_arg_name: str + new_arg: Optional[Any], new_arg_name: str, old_arg: Optional[Any], old_arg_name: str ) -> Any: """ Facilitates the migration of an argument's name. 
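Editor's aside (not part of the diff): migrate_deprecated_argument is easiest to review with a caller in view. A minimal usage sketch; list_datasets and its keyword names are invented for illustration, not taken from this codebase.

from typing import Optional
from uuid import UUID

from citrine._utils.functions import migrate_deprecated_argument


def list_datasets(*, team_id: Optional[UUID] = None, project_id: Optional[UUID] = None):
    # Warns if the deprecated keyword is used, returns whichever value was
    # supplied, and raises if both or neither are given.
    team_id = migrate_deprecated_argument(team_id, "team_id", project_id, "project_id")
    ...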
@@ -253,22 +277,22 @@ def migrate_deprecated_argument(

    """
    if old_arg is not None:
-        warn(f"\'{old_arg_name}\' is deprecated in favor of \'{new_arg_name}\'",
-             DeprecationWarning)
+        warn(
+            f"'{old_arg_name}' is deprecated in favor of '{new_arg_name}'",
+            DeprecationWarning,
+        )
        if new_arg is None:
            return old_arg
        else:
-            raise ValueError(f"Cannot specify both \'{new_arg_name}\' and \'{new_arg_name}\'")
+            raise ValueError(
+                f"Cannot specify both '{new_arg_name}' and '{old_arg_name}'"
+            )
    elif new_arg is None:
-        raise ValueError(f"Please specify \'{new_arg_name}\'")
+        raise ValueError(f"Please specify '{new_arg_name}'")
    return new_arg


-def format_escaped_url(
-    template: str,
-    *args,
-    **kwargs
-) -> str:
+def format_escaped_url(template: str, *args, **kwargs) -> str:
    """
    Escape arguments with percent encoding and bind them to a template of a URL.
@@ -290,21 +314,23 @@ def format_escaped_url(
        the formatted URL

    """
-    return template.format(*[quote(str(x), safe='') for x in args],
-                           **{k: quote(str(v), safe='') for (k, v) in kwargs.items()}
-                           )
-
-
-def resource_path(*,
-                  path_template: str,
-                  uid: Optional[Union[UUID, str]] = None,
-                  action: Union[str, Sequence[str]] = [],
-                  query_terms: Dict[str, str] = {},
-                  **kwargs
-                  ) -> str:
+    return template.format(
+        *[quote(str(x), safe="") for x in args],
+        **{k: quote(str(v), safe="") for (k, v) in kwargs.items()},
+    )
+
+
+def resource_path(
+    *,
+    path_template: str,
+    uid: Optional[Union[UUID, str]] = None,
+    action: Union[str, Sequence[str]] = [],
+    query_terms: Dict[str, str] = {},
+    **kwargs,
+) -> str:
    """Construct a url from a base path and, optionally, id and/or action."""
    base = urlparse(path_template)
-    path = base.path.split('/')
+    path = base.path.split("/")

    if uid is not None:
        path.append("{uid}")
@@ -316,28 +342,37 @@ def resource_path(*,
        path.extend(["{}"] * len(action))

    query = urlencode(query_terms)
-    new_url = base._replace(path='/'.join(path), query=query).geturl()
+    new_url = base._replace(path="/".join(path), query=query).geturl()
    return format_escaped_url(new_url, *action, **kwargs, uid=uid)


-def _data_manager_deprecation_checks(session, project_id: UUID, team_id: UUID, obj_type: str):
+def _data_manager_deprecation_checks(
+    session, project_id: UUID, team_id: UUID, obj_type: str
+):
    if team_id is None:
        if project_id is None:
            raise TypeError("Missing one required argument: team_id.")
-        warn(f"{obj_type} now belong to Teams, so the project_id parameter was deprecated in "
-             "3.4.0, and will be removed in 4.0. Please provide the team_id instead.",
-             DeprecationWarning)
+        warn(
+            f"{obj_type} now belong to Teams, so the project_id parameter was deprecated in "
+            "3.4.0, and will be removed in 4.0. Please provide the team_id instead.",
+            DeprecationWarning,
+        )
        # avoiding a circular import
        from citrine.resources.project import Project
-        team_id = Project.get_team_id_from_project_id(session=session, project_id=project_id)
+
+        team_id = Project.get_team_id_from_project_id(
+            session=session, project_id=project_id
+        )
    return team_id


def _pad_positional_args(args, n):
    if len(args) > 0:
-        warn("Positional arguments are deprecated and will be removed in v4.0. Please use keyword "
-             "arguments instead.",
-             DeprecationWarning)
-        return args + (None, ) * (n - len(args))
+        warn(
+            "Positional arguments are deprecated and will be removed in v4.0. 
Please use keyword " + "arguments instead.", + DeprecationWarning, + ) + return args + (None,) * (n - len(args)) diff --git a/src/citrine/citrine.py b/src/citrine/citrine.py index e89d61d57..20331a3f5 100644 --- a/src/citrine/citrine.py +++ b/src/citrine/citrine.py @@ -26,28 +26,30 @@ class Citrine: """ - def __init__(self, - api_key: str = None, - *, - scheme: str = None, - host: str = None, - port: Optional[str] = None): + def __init__( + self, + api_key: str = None, + *, + scheme: str = None, + host: str = None, + port: Optional[str] = None, + ): if api_key is None: - api_key = environ.get('CITRINE_API_KEY') + api_key = environ.get("CITRINE_API_KEY") if scheme is None: - scheme = 'https' + scheme = "https" if host is None: - host = environ.get('CITRINE_API_HOST') + host = environ.get("CITRINE_API_HOST") if host is None: - raise ValueError("No host passed and environmental " - "variable CITRINE_API_HOST not set.") + raise ValueError( + "No host passed and environmental " + "variable CITRINE_API_HOST not set." + ) - self.session: Session = Session(refresh_token=api_key, - scheme=scheme, - host=host, - port=port - ) + self.session: Session = Session( + refresh_token=api_key, scheme=scheme, host=host, port=port + ) @property def projects(self) -> ProjectCollection: diff --git a/src/citrine/exceptions.py b/src/citrine/exceptions.py index a528cce80..84f969bcb 100644 --- a/src/citrine/exceptions.py +++ b/src/citrine/exceptions.py @@ -1,4 +1,5 @@ """Citrine-specific exceptions.""" + from types import SimpleNamespace from typing import Optional, List from urllib.parse import urlencode @@ -54,11 +55,13 @@ def __init__(self, path: str, response: Optional[Response] = None): resp_json = response.json() if isinstance(resp_json, dict): from citrine.resources.api_error import ApiError + self.api_error = ApiError.build(resp_json) validation_error_msgs = [ "{} ({})".format(f.failure_message, f.failure_id) - for f in self.api_error.validation_errors] + for f in self.api_error.validation_errors + ] if self.api_error.message is not None: self.detailed_error_info.append(self.api_error.message) @@ -120,8 +123,12 @@ def build(*, message: str, method: str, path: str, params: dict = {}): status_code=404, request=SimpleNamespace(method=method.upper()), reason="Not Found", - json=lambda self: {"code": 404, "message": message, "validation_errors": []} - ) + json=lambda self: { + "code": 404, + "message": message, + "validation_errors": [], + }, + ), ) @@ -173,5 +180,6 @@ class ModuleRegistrationFailedException(NonRetryableException): def __init__(self, moduleType: str, exc: Exception): err = 'The "{0}" failed to register. 
{1}: {2}'.format( - moduleType, exc.__class__.__name__, str(exc)) + moduleType, exc.__class__.__name__, str(exc) + ) super().__init__(err) diff --git a/src/citrine/gemd_queries/criteria.py b/src/citrine/gemd_queries/criteria.py index 8c32dffb3..bf19b55e9 100644 --- a/src/citrine/gemd_queries/criteria.py +++ b/src/citrine/gemd_queries/criteria.py @@ -1,4 +1,5 @@ """Definitions for GemdQuery objects, and their sub-objects.""" + from typing import List, Type from gemd.enumeration.base_enumeration import BaseEnumeration @@ -8,12 +9,19 @@ from citrine._serialization import properties from citrine.gemd_queries.filter import PropertyFilterType -__all__ = ['MaterialClassification', 'TextSearchType', 'TagFilterType', - 'AndOperator', 'OrOperator', - 'PropertiesCriteria', 'NameCriteria', - 'MaterialRunClassificationCriteria', 'MaterialTemplatesCriteria', - 'TagsCriteria', 'ConnectivityClassCriteria' - ] +__all__ = [ + "MaterialClassification", + "TextSearchType", + "TagFilterType", + "AndOperator", + "OrOperator", + "PropertiesCriteria", + "NameCriteria", + "MaterialRunClassificationCriteria", + "MaterialTemplatesCriteria", + "TagsCriteria", + "ConnectivityClassCriteria", +] class MaterialClassification(BaseEnumeration): @@ -48,14 +56,19 @@ class Criteria(PolymorphicSerializable): def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" classes: List[Type[Criteria]] = [ - AndOperator, OrOperator, - PropertiesCriteria, NameCriteria, MaterialRunClassificationCriteria, - MaterialTemplatesCriteria, TagsCriteria, ConnectivityClassCriteria + AndOperator, + OrOperator, + PropertiesCriteria, + NameCriteria, + MaterialRunClassificationCriteria, + MaterialTemplatesCriteria, + TagsCriteria, + ConnectivityClassCriteria, ] - return {klass.typ: klass for klass in classes}[data['type']] + return {klass.typ: klass for klass in classes}[data["type"]] -class AndOperator(Serializable['AndOperator'], Criteria): +class AndOperator(Serializable["AndOperator"], Criteria): """ Combine multiple criteria, requiring EACH to be true for a match. @@ -67,10 +80,10 @@ class AndOperator(Serializable['AndOperator'], Criteria): """ criteria = properties.List(properties.Object(Criteria), "criteria") - typ = properties.String('type', default="and_operator", deserializable=False) + typ = properties.String("type", default="and_operator", deserializable=False) -class OrOperator(Serializable['OrOperator'], Criteria): +class OrOperator(Serializable["OrOperator"], Criteria): """ Combine multiple criteria, requiring ANY to be true for a match. @@ -82,10 +95,10 @@ class OrOperator(Serializable['OrOperator'], Criteria): """ criteria = properties.List(properties.Object(Criteria), "criteria") - typ = properties.String('type', default="or_operator", deserializable=False) + typ = properties.String("type", default="or_operator", deserializable=False) -class PropertiesCriteria(Serializable['PropertiesCriteria'], Criteria): +class PropertiesCriteria(Serializable["PropertiesCriteria"], Criteria): """ Look for materials with a particular Property and optionally Value types & ranges. 
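Editor's aside (not part of the diff): get_type above is the dispatch point for every criteria subtype that follows. A minimal round-trip sketch, assuming the Serializable.build/dump pair shown earlier in this diff behaves as defined:

from citrine.gemd_queries.criteria import Criteria, OrOperator

data = {"type": "or_operator", "criteria": []}
klass = Criteria.get_type(data)  # the "type" key selects the subtype -> OrOperator
crit = klass.build(data)         # deserialize into the concrete class
assert isinstance(crit, OrOperator)
assert crit.dump()["type"] == "or_operator"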
@@ -98,14 +111,16 @@ class PropertiesCriteria(Serializable['PropertiesCriteria'], Criteria): """ - property_templates_filter = properties.Set(properties.UUID, "property_templates_filter") + property_templates_filter = properties.Set( + properties.UUID, "property_templates_filter" + ) value_type_filter = properties.Optional( properties.Object(PropertyFilterType), "value_type_filter" ) - typ = properties.String('type', default="properties_criteria", deserializable=False) + typ = properties.String("type", default="properties_criteria", deserializable=False) -class NameCriteria(Serializable['NameCriteria'], Criteria): +class NameCriteria(Serializable["NameCriteria"], Criteria): """ Look for materials with particular names. @@ -118,14 +133,13 @@ class NameCriteria(Serializable['NameCriteria'], Criteria): """ - name = properties.String('name') - search_type = properties.Enumeration(TextSearchType, 'search_type') - typ = properties.String('type', default="name_criteria", deserializable=False) + name = properties.String("name") + search_type = properties.Enumeration(TextSearchType, "search_type") + typ = properties.String("type", default="name_criteria", deserializable=False) class MaterialRunClassificationCriteria( - Serializable['MaterialRunClassificationCriteria'], - Criteria + Serializable["MaterialRunClassificationCriteria"], Criteria ): """ Look for materials with particular classification, defined by MaterialClassification. @@ -138,16 +152,14 @@ class MaterialRunClassificationCriteria( """ classifications = properties.Set( - properties.Enumeration(MaterialClassification), 'classifications' + properties.Enumeration(MaterialClassification), "classifications" ) typ = properties.String( - 'type', - default="material_run_classification_criteria", - deserializable=False + "type", default="material_run_classification_criteria", deserializable=False ) -class MaterialTemplatesCriteria(Serializable['MaterialTemplatesCriteria'], Criteria): +class MaterialTemplatesCriteria(Serializable["MaterialTemplatesCriteria"], Criteria): """ Look for materials with particular Material Templates and tags. @@ -163,14 +175,15 @@ class MaterialTemplatesCriteria(Serializable['MaterialTemplatesCriteria'], Crite """ material_templates_identifiers = properties.Set( - properties.UUID, - "material_templates_identifiers" + properties.UUID, "material_templates_identifiers" + ) + tag_filters = properties.Set(properties.String, "tag_filters") + typ = properties.String( + "type", default="material_template_criteria", deserializable=False ) - tag_filters = properties.Set(properties.String, 'tag_filters') - typ = properties.String('type', default="material_template_criteria", deserializable=False) -class TagsCriteria(Serializable['TagsCriteria'], Criteria): +class TagsCriteria(Serializable["TagsCriteria"], Criteria): """ Look for materials with particular tags. 
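Editor's aside (not part of the diff): the two boolean operators and the leaf criteria above compose into arbitrary trees. A sketch, assuming the property descriptors accept plain attribute assignment after a no-argument construction; the UUIDs are placeholders:

from uuid import uuid4

from citrine.gemd_queries.criteria import AndOperator, OrOperator, PropertiesCriteria

crit_a = PropertiesCriteria()
crit_a.property_templates_filter = {uuid4()}  # placeholder property template id
crit_b = PropertiesCriteria()
crit_b.property_templates_filter = {uuid4()}

both = AndOperator()
both.criteria = [crit_a, crit_b]  # matches only when every child matches
either = OrOperator()
either.criteria = [both, crit_a]  # operators nest, so trees of any depth are possible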
@@ -186,12 +199,12 @@ class TagsCriteria(Serializable['TagsCriteria'], Criteria): """ - tags = properties.Set(properties.String, 'tags') - filter_type = properties.Enumeration(TagFilterType, 'filter_type') - typ = properties.String('type', default="tags_criteria", deserializable=False) + tags = properties.Set(properties.String, "tags") + filter_type = properties.Enumeration(TagFilterType, "filter_type") + typ = properties.String("type", default="tags_criteria", deserializable=False) -class ConnectivityClassCriteria(Serializable['ConnectivityClassCriteria'], Criteria): +class ConnectivityClassCriteria(Serializable["ConnectivityClassCriteria"], Criteria): """ Look for materials with particular connectivity classes. @@ -204,6 +217,8 @@ class ConnectivityClassCriteria(Serializable['ConnectivityClassCriteria'], Crite """ - is_consumed = properties.Optional(properties.Boolean, 'is_consumed') - is_produced = properties.Optional(properties.Boolean, 'is_produced') - typ = properties.String('type', default="connectivity_class_criteria", deserializable=False) + is_consumed = properties.Optional(properties.Boolean, "is_consumed") + is_produced = properties.Optional(properties.Boolean, "is_produced") + typ = properties.String( + "type", default="connectivity_class_criteria", deserializable=False + ) diff --git a/src/citrine/gemd_queries/filter.py b/src/citrine/gemd_queries/filter.py index 0b8a555ca..0be4b8fd8 100644 --- a/src/citrine/gemd_queries/filter.py +++ b/src/citrine/gemd_queries/filter.py @@ -1,11 +1,12 @@ """Definitions for GemdQuery objects, and their sub-objects.""" + from typing import List, Type from citrine._serialization.serializable import Serializable from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization import properties -__all__ = ['AllRealFilter', 'AllIntegerFilter', 'NominalCategoricalFilter'] +__all__ = ["AllRealFilter", "AllIntegerFilter", "NominalCategoricalFilter"] class PropertyFilterType(PolymorphicSerializable): @@ -16,12 +17,13 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" classes: List[Type[PropertyFilterType]] = [ NominalCategoricalFilter, - AllRealFilter, AllIntegerFilter + AllRealFilter, + AllIntegerFilter, ] - return {klass.typ: klass for klass in classes}[data['type']] + return {klass.typ: klass for klass in classes}[data["type"]] -class AllRealFilter(Serializable['AllRealFilter'], PropertyFilterType): +class AllRealFilter(Serializable["AllRealFilter"], PropertyFilterType): """ Filter for any real value that fits certain constraints. @@ -36,13 +38,13 @@ class AllRealFilter(Serializable['AllRealFilter'], PropertyFilterType): """ - lower = properties.Float('lower') - upper = properties.Float('upper') - unit = properties.String('unit') - typ = properties.String('type', default="all_real_filter", deserializable=False) + lower = properties.Float("lower") + upper = properties.Float("upper") + unit = properties.String("unit") + typ = properties.String("type", default="all_real_filter", deserializable=False) -class AllIntegerFilter(Serializable['AllIntegerFilter'], PropertyFilterType): +class AllIntegerFilter(Serializable["AllIntegerFilter"], PropertyFilterType): """ Filter for any integer value that fits certain constraints. 
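Editor's aside (not part of the diff): a value filter plugs into PropertiesCriteria via its value_type_filter field. A sketch; the UUID and the unit string are placeholders, not values from this codebase:

from uuid import UUID

from citrine.gemd_queries.criteria import PropertiesCriteria
from citrine.gemd_queries.filter import AllRealFilter

density = AllRealFilter()
density.lower = 0.0        # accept real values in [0.0, 25.0]
density.upper = 25.0
density.unit = "g / cm^3"  # placeholder unit string

crit = PropertiesCriteria()
crit.property_templates_filter = {UUID("00000000-0000-0000-0000-000000000000")}
crit.value_type_filter = density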
@@ -57,13 +59,15 @@ class AllIntegerFilter(Serializable['AllIntegerFilter'], PropertyFilterType): """ - lower = properties.Float('lower') - upper = properties.Float('upper') - inclusive = properties.Optional(properties.Boolean, 'inclusive', default=True) - typ = properties.String('type', default="all_integer_filter", deserializable=False) + lower = properties.Float("lower") + upper = properties.Float("upper") + inclusive = properties.Optional(properties.Boolean, "inclusive", default=True) + typ = properties.String("type", default="all_integer_filter", deserializable=False) -class NominalCategoricalFilter(Serializable['NominalCategoricalFilter'], PropertyFilterType): +class NominalCategoricalFilter( + Serializable["NominalCategoricalFilter"], PropertyFilterType +): """ Filter based upon a fixed list of Categorical Values. @@ -74,5 +78,7 @@ class NominalCategoricalFilter(Serializable['NominalCategoricalFilter'], Propert """ - categories = properties.Set(properties.String, 'categories') - typ = properties.String('type', default="nominal_categorical_filter", deserializable=False) + categories = properties.Set(properties.String, "categories") + typ = properties.String( + "type", default="nominal_categorical_filter", deserializable=False + ) diff --git a/src/citrine/gemd_queries/gemd_query.py b/src/citrine/gemd_queries/gemd_query.py index 28b3deaff..25e31dbbb 100644 --- a/src/citrine/gemd_queries/gemd_query.py +++ b/src/citrine/gemd_queries/gemd_query.py @@ -1,4 +1,5 @@ """Definitions for GemdQuery objects, and their sub-objects.""" + from gemd.enumeration.base_enumeration import BaseEnumeration from citrine._serialization.serializable import Serializable @@ -27,7 +28,7 @@ class GemdObjectType(BaseEnumeration): MEASUREMENT_SPEC_TYPE = "measurement_spec", "MEASUREMENT_SPEC_TYPE" -class GemdQuery(Serializable['GemdQuery']): +class GemdQuery(Serializable["GemdQuery"]): """ This describes what data objects to fetch (or graph of data objects). @@ -48,16 +49,16 @@ class GemdQuery(Serializable['GemdQuery']): datasets = properties.Set(properties.UUID, "datasets", default=set()) object_types = properties.Set( properties.Enumeration(GemdObjectType), - 'object_types', - default={x for x in GemdObjectType} + "object_types", + default={x for x in GemdObjectType}, ) - schema_version = properties.Integer('schema_version', default=1) + schema_version = properties.Integer("schema_version", default=1) @classmethod def _pre_build(cls, data: dict) -> dict: """Run data modification before building.""" - version = data.get('schema_version') - if data.get('schema_version') != 1: + version = data.get("schema_version") + if data.get("schema_version") != 1: raise ValueError( f"This version of the library only supports schema_version 1, not '{version}'" ) diff --git a/src/citrine/gemtables/columns.py b/src/citrine/gemtables/columns.py index 8c1612059..d97a687f7 100644 --- a/src/citrine/gemtables/columns.py +++ b/src/citrine/gemtables/columns.py @@ -1,4 +1,5 @@ """Column definitions for GEM Tables.""" + from typing import Type, Optional, List, Union from gemd.enumeration.base_enumeration import BaseEnumeration @@ -46,11 +47,14 @@ def _make_data_source(variable_rep: Union[str, Variable]) -> str: elif isinstance(variable_rep, Variable): return variable_rep.name else: - raise TypeError("Columns can only be linked by str or Variable." 
-                        "Instead got {}.".format(variable_rep))
+        raise TypeError(
+            "Columns can only be linked by str or Variable. Instead got {}.".format(
+                variable_rep
+            )
+        )


-class Column(PolymorphicSerializable['Column']):
+class Column(PolymorphicSerializable["Column"]):
    """A column in the GEM Table, defined as some operation on a variable.

    Abstract type that returns the proper type given a serialized dict.
@@ -69,11 +73,18 @@ def get_type(cls, data) -> Type[Serializable]:
            raise ValueError("Can only get types from dicts with a 'type' key")
        types: List[Type[Serializable]] = [
            IdentityColumn,
-            MeanColumn, StdDevColumn, QuantileColumn, OriginalUnitsColumn,
-            MostLikelyCategoryColumn, MostLikelyProbabilityColumn,
-            FlatCompositionColumn, ComponentQuantityColumn,
-            NthBiggestComponentNameColumn, NthBiggestComponentQuantityColumn,
-            MolecularStructureColumn, ConcatColumn
+            MeanColumn,
+            StdDevColumn,
+            QuantileColumn,
+            OriginalUnitsColumn,
+            MostLikelyCategoryColumn,
+            MostLikelyProbabilityColumn,
+            FlatCompositionColumn,
+            ComponentQuantityColumn,
+            NthBiggestComponentNameColumn,
+            NthBiggestComponentQuantityColumn,
+            MolecularStructureColumn,
+            ConcatColumn,
        ]
        res = next((x for x in types if x.typ == data["type"]), None)
        if res is None:
@@ -81,7 +92,7 @@ def get_type(cls, data) -> Type[Serializable]:
        return res


-class MeanColumn(Serializable['MeanColumn'], Column):
+class MeanColumn(Serializable["MeanColumn"], Column):
    """Column containing the mean of a real-valued variable.

    Parameters
@@ -97,13 +108,13 @@

    """

-    data_source = properties.String('data_source')
+    data_source = properties.String("data_source")
    target_units = properties.Optional(properties.String, "target_units")
-    typ = properties.String('type', default="mean_column", deserializable=False)
+    typ = properties.String("type", default="mean_column", deserializable=False)

-    def __init__(self, *,
-                 data_source: Union[str, Variable],
-                 target_units: Optional[str] = None):
+    def __init__(
+        self, *, data_source: Union[str, Variable], target_units: Optional[str] = None
+    ):
        self.data_source = _make_data_source(data_source)
        self.target_units = target_units

@@ -124,13 +135,13 @@ class StdDevColumn(Serializable["StdDevColumn"], Column):

    """

-    data_source = properties.String('data_source')
+    data_source = properties.String("data_source")
    target_units = properties.Optional(properties.String, "target_units")
-    typ = properties.String('type', default="std_dev_column", deserializable=False)
+    typ = properties.String("type", default="std_dev_column", deserializable=False)

-    def __init__(self, *,
-                 data_source: Union[str, Variable],
-                 target_units: Optional[str] = None):
+    def __init__(
+        self, *, data_source: Union[str, Variable], target_units: Optional[str] = None
+    ):
        self.data_source = _make_data_source(data_source)
        self.target_units = target_units

@@ -167,15 +178,18 @@ class QuantileColumn(Serializable["QuantileColumn"], Column):

    """

-    data_source = properties.String('data_source')
+    data_source = properties.String("data_source")
    quantile = properties.Float("quantile")
    target_units = properties.Optional(properties.String, "target_units")
-    typ = properties.String('type', default="quantile_column", deserializable=False)
-
-    def __init__(self, *,
-                 data_source: Union[str, Variable],
-                 quantile: float,
-                 target_units: Optional[str] = None):
+    typ = properties.String("type", default="quantile_column", deserializable=False)
+
+    def __init__(
+        self,
+        *,
+        data_source: Union[str, Variable],
+        quantile: 
float, + target_units: Optional[str] = None, + ): self.data_source = _make_data_source(data_source) self.quantile = quantile self.target_units = target_units @@ -191,8 +205,10 @@ class OriginalUnitsColumn(Serializable["OriginalUnitsColumn"], Column): """ - data_source = properties.String('data_source') - typ = properties.String('type', default="original_units_column", deserializable=False) + data_source = properties.String("data_source") + typ = properties.String( + "type", default="original_units_column", deserializable=False + ) def __init__(self, *, data_source: Union[str, Variable]): self.data_source = _make_data_source(data_source) @@ -208,8 +224,10 @@ class MostLikelyCategoryColumn(Serializable["MostLikelyCategoryColumn"], Column) """ - data_source = properties.String('data_source') - typ = properties.String('type', default="most_likely_category_column", deserializable=False) + data_source = properties.String("data_source") + typ = properties.String( + "type", default="most_likely_category_column", deserializable=False + ) def __init__(self, *, data_source: Union[str, Variable]): self.data_source = _make_data_source(data_source) @@ -225,8 +243,10 @@ class MostLikelyProbabilityColumn(Serializable["MostLikelyProbabilityColumn"], C """ - data_source = properties.String('data_source') - typ = properties.String('type', default="most_likely_probability_column", deserializable=False) + data_source = properties.String("data_source") + typ = properties.String( + "type", default="most_likely_probability_column", deserializable=False + ) def __init__(self, *, data_source: Union[str, Variable]): self.data_source = _make_data_source(data_source) @@ -248,13 +268,15 @@ class FlatCompositionColumn(Serializable["FlatCompositionColumn"], Column): """ - data_source = properties.String('data_source') - sort_order = properties.Enumeration(CompositionSortOrder, 'sort_order') - typ = properties.String('type', default="flat_composition_column", deserializable=False) + data_source = properties.String("data_source") + sort_order = properties.Enumeration(CompositionSortOrder, "sort_order") + typ = properties.String( + "type", default="flat_composition_column", deserializable=False + ) - def __init__(self, *, - data_source: Union[str, Variable], - sort_order: CompositionSortOrder): + def __init__( + self, *, data_source: Union[str, Variable], sort_order: CompositionSortOrder + ): self.data_source = _make_data_source(data_source) self.sort_order = sort_order @@ -275,21 +297,28 @@ class ComponentQuantityColumn(Serializable["ComponentQuantityColumn"], Column): """ - data_source = properties.String('data_source') + data_source = properties.String("data_source") component_name = properties.String("component_name") normalize = properties.Boolean("normalize") - typ = properties.String('type', default="component_quantity_column", deserializable=False) - - def __init__(self, *, - data_source: Union[str, Variable], - component_name: str, - normalize: bool = False): + typ = properties.String( + "type", default="component_quantity_column", deserializable=False + ) + + def __init__( + self, + *, + data_source: Union[str, Variable], + component_name: str, + normalize: bool = False, + ): self.data_source = _make_data_source(data_source) self.component_name = component_name self.normalize = normalize -class NthBiggestComponentNameColumn(Serializable["NthBiggestComponentNameColumn"], Column): +class NthBiggestComponentNameColumn( + Serializable["NthBiggestComponentNameColumn"], Column +): """Name of the Nth biggest component. 
If there are fewer than N components in the composition, then this column will be empty. @@ -303,18 +332,20 @@ class NthBiggestComponentNameColumn(Serializable["NthBiggestComponentNameColumn" """ - data_source = properties.String('data_source') + data_source = properties.String("data_source") n = properties.Integer("n") - typ = properties.String('type', default="biggest_component_name_column", deserializable=False) + typ = properties.String( + "type", default="biggest_component_name_column", deserializable=False + ) - def __init__(self, *, - data_source: Union[str, Variable], - n: int): + def __init__(self, *, data_source: Union[str, Variable], n: int): self.data_source = _make_data_source(data_source) self.n = n -class NthBiggestComponentQuantityColumn(Serializable["NthBiggestComponentQuantityColumn"], Column): +class NthBiggestComponentQuantityColumn( + Serializable["NthBiggestComponentQuantityColumn"], Column +): """Quantity of the Nth biggest component. If there are fewer than N components in the composition, then this column will be empty. @@ -330,22 +361,22 @@ class NthBiggestComponentQuantityColumn(Serializable["NthBiggestComponentQuantit """ - data_source = properties.String('data_source') + data_source = properties.String("data_source") n = properties.Integer("n") normalize = properties.Boolean("normalize") - typ = properties.String('type', - default="biggest_component_quantity_column", deserializable=False) + typ = properties.String( + "type", default="biggest_component_quantity_column", deserializable=False + ) - def __init__(self, *, - data_source: Union[str, Variable], - n: int, - normalize: bool = False): + def __init__( + self, *, data_source: Union[str, Variable], n: int, normalize: bool = False + ): self.data_source = _make_data_source(data_source) self.n = n self.normalize = normalize -class IdentityColumn(Serializable['IdentityColumn'], Column): +class IdentityColumn(Serializable["IdentityColumn"], Column): """Column containing the value of a string-valued variable. Parameters @@ -355,14 +386,14 @@ class IdentityColumn(Serializable['IdentityColumn'], Column): """ - data_source = properties.String('data_source') - typ = properties.String('type', default="identity_column", deserializable=False) + data_source = properties.String("data_source") + typ = properties.String("type", default="identity_column", deserializable=False) def __init__(self, *, data_source: Union[str, Variable]): self.data_source = _make_data_source(data_source) -class MolecularStructureColumn(Serializable['MolecularStructureColumn'], Column): +class MolecularStructureColumn(Serializable["MolecularStructureColumn"], Column): """Column containing a representation of a molecular structure. 
Parameters @@ -374,16 +405,20 @@ class MolecularStructureColumn(Serializable['MolecularStructureColumn'], Column) """ - data_source = properties.String('data_source') - format = properties.Enumeration(ChemicalDisplayFormat, 'format') - typ = properties.String('type', default="molecular_structure_column", deserializable=False) + data_source = properties.String("data_source") + format = properties.Enumeration(ChemicalDisplayFormat, "format") + typ = properties.String( + "type", default="molecular_structure_column", deserializable=False + ) - def __init__(self, *, data_source: Union[str, Variable], format: ChemicalDisplayFormat): + def __init__( + self, *, data_source: Union[str, Variable], format: ChemicalDisplayFormat + ): self.data_source = _make_data_source(data_source) self.format = format -class ConcatColumn(Serializable['ConcatColumn'], Column): +class ConcatColumn(Serializable["ConcatColumn"], Column): """Column that concatenates multiple values produced by a list- or set-valued variable. The input subcolumn need not exist elsewhere in the table config, and its parameters have @@ -399,9 +434,9 @@ class ConcatColumn(Serializable['ConcatColumn'], Column): """ - data_source = properties.String('data_source') - subcolumn = properties.Object(Column, 'subcolumn') - typ = properties.String('type', default="concat_column", deserializable=False) + data_source = properties.String("data_source") + subcolumn = properties.Object(Column, "subcolumn") + typ = properties.String("type", default="concat_column", deserializable=False) def __init__(self, *, data_source: Union[str, Variable], subcolumn: Column): self.data_source = _make_data_source(data_source) diff --git a/src/citrine/gemtables/rows.py b/src/citrine/gemtables/rows.py index b74374bd2..a41fee51c 100644 --- a/src/citrine/gemtables/rows.py +++ b/src/citrine/gemtables/rows.py @@ -1,4 +1,5 @@ """Row definitions for GEM Tables.""" + from typing import Type, List, Set, Union from uuid import UUID @@ -12,7 +13,7 @@ from citrine.resources.data_concepts import _make_link_by_uid -class Row(PolymorphicSerializable['Row']): +class Row(PolymorphicSerializable["Row"]): """A rule for defining rows in a GEM Table. Abstract type that returns the proper type given a serialized dict. @@ -29,16 +30,14 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" if "type" not in data: raise ValueError("Can only get types from dicts with a 'type' key") - types: List[Type[Serializable]] = [ - MaterialRunByTemplate - ] + types: List[Type[Serializable]] = [MaterialRunByTemplate] res = next((x for x in types if x.typ == data["type"]), None) if res is None: raise ValueError("Unrecognized type: {}".format(data["type"])) return res -class MaterialRunByTemplate(Serializable['MaterialRunByTemplate'], Row): +class MaterialRunByTemplate(Serializable["MaterialRunByTemplate"], Row): """Rows corresponding to MaterialRuns, marked by their template. 
Parameters @@ -52,14 +51,13 @@ class MaterialRunByTemplate(Serializable['MaterialRunByTemplate'], Row): """ templates = properties.List(properties.Object(LinkByUID), "templates") - typ = properties.String('type', default="material_run_by_template", deserializable=False) + typ = properties.String( + "type", default="material_run_by_template", deserializable=False + ) tags = properties.Optional(properties.Set(properties.String), "tags") template_type = Union[UUID, str, LinkByUID, MaterialTemplate] - def __init__(self, *, - templates: List[template_type], - tags: Set[str] = None): - + def __init__(self, *, templates: List[template_type], tags: Set[str] = None): self.templates = [_make_link_by_uid(x) for x in templates] self.tags = tags diff --git a/src/citrine/gemtables/variables.py b/src/citrine/gemtables/variables.py index 646434326..f02b39492 100644 --- a/src/citrine/gemtables/variables.py +++ b/src/citrine/gemtables/variables.py @@ -1,4 +1,5 @@ """Variable definitions for GEM Tables.""" + from typing import Type, Optional, List, Union, Tuple from uuid import UUID @@ -53,7 +54,7 @@ class DataObjectTypeSelector(BaseEnumeration): ANY = "any" -class Variable(PolymorphicSerializable['Variable']): +class Variable(PolymorphicSerializable["Variable"]): """A variable that can be assigned values present in material histories. Abstract type that returns the proper type given a serialized dict. @@ -71,13 +72,24 @@ def get_type(cls, data) -> Type[Serializable]: if "type" not in data: raise ValueError("Can only get types from dicts with a 'type' key") types: List[Type[Serializable]] = [ - TerminalMaterialInfo, AttributeByTemplate, AttributeByTemplateAfterProcessTemplate, - AttributeByTemplateAndObjectTemplate, LocalAttribute, LocalAttributeAndObject, - IngredientIdentifierByProcessTemplateAndName, IngredientLabelByProcessAndName, - IngredientLabelsSetByProcessAndName, IngredientQuantityByProcessAndName, - TerminalMaterialIdentifier, AttributeInOutput, - IngredientIdentifierInOutput, IngredientLabelsSetInOutput, IngredientQuantityInOutput, - LocalIngredientIdentifier, LocalIngredientLabelsSet, LocalIngredientQuantity, + TerminalMaterialInfo, + AttributeByTemplate, + AttributeByTemplateAfterProcessTemplate, + AttributeByTemplateAndObjectTemplate, + LocalAttribute, + LocalAttributeAndObject, + IngredientIdentifierByProcessTemplateAndName, + IngredientLabelByProcessAndName, + IngredientLabelsSetByProcessAndName, + IngredientQuantityByProcessAndName, + TerminalMaterialIdentifier, + AttributeInOutput, + IngredientIdentifierInOutput, + IngredientLabelsSetInOutput, + IngredientQuantityInOutput, + LocalIngredientIdentifier, + LocalIngredientLabelsSet, + LocalIngredientQuantity, XOR, ] res = next((x for x in types if x.typ == data["type"]), None) @@ -87,7 +99,7 @@ def get_type(cls, data) -> Type[Serializable]: return res -class TerminalMaterialInfo(Serializable['TerminalMaterialInfo'], Variable): +class TerminalMaterialInfo(Serializable["TerminalMaterialInfo"], Variable): """Metadata from the terminal material of the material history. 
Parameters @@ -102,21 +114,18 @@ class TerminalMaterialInfo(Serializable['TerminalMaterialInfo'], Variable): """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - field = properties.String('field') - typ = properties.String('type', default="root_info", deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + field = properties.String("field") + typ = properties.String("type", default="root_info", deserializable=False) - def __init__(self, - name: str, *, - headers: List[str], - field: str): + def __init__(self, name: str, *, headers: List[str], field: str): self.name = name self.headers = headers self.field = field -class AttributeByTemplate(Serializable['AttributeByTemplate'], Variable): +class AttributeByTemplate(Serializable["AttributeByTemplate"], Variable): """Attribute marked by an attribute template. Parameters @@ -137,38 +146,48 @@ class AttributeByTemplate(Serializable['AttributeByTemplate'], Variable): """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - template = properties.Object(LinkByUID, 'template') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + template = properties.Object(LinkByUID, "template") attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="attribute_by_template", deserializable=False) + typ = properties.String( + "type", default="attribute_by_template", deserializable=False + ) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - template: attribute_type, - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + template: attribute_type, + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.template = _make_link_by_uid(template) - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector class AttributeByTemplateAfterProcessTemplate( - Serializable['AttributeByTemplateAfterProcessTemplate'], Variable): + Serializable["AttributeByTemplateAfterProcessTemplate"], Variable +): """Attribute of an object marked by an attribute template and a parent process template. 
Parameters @@ -191,42 +210,52 @@ class AttributeByTemplateAfterProcessTemplate( """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - attribute_template = properties.Object(LinkByUID, 'attribute_template') - process_template = properties.Object(LinkByUID, 'process_template') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + attribute_template = properties.Object(LinkByUID, "attribute_template") + process_template = properties.Object(LinkByUID, "process_template") attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="attribute_after_process", deserializable=False) + typ = properties.String( + "type", default="attribute_after_process", deserializable=False + ) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] process_type = Union[UUID, str, LinkByUID, ProcessTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - attribute_template: attribute_type, - process_template: process_type, - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + attribute_template: attribute_type, + process_template: process_type, + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.attribute_template = _make_link_by_uid(attribute_template) self.process_template = _make_link_by_uid(process_template) - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector class AttributeByTemplateAndObjectTemplate( - Serializable['AttributeByTemplateAndObjectTemplate'], Variable): + Serializable["AttributeByTemplateAndObjectTemplate"], Variable +): """Attribute marked by an attribute template and an object template. For example, one property may be measured by two different measurement techniques. 
In this @@ -254,41 +283,48 @@ class AttributeByTemplateAndObjectTemplate( """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - attribute_template = properties.Object(LinkByUID, 'attribute_template') - object_template = properties.Object(LinkByUID, 'object_template') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + attribute_template = properties.Object(LinkByUID, "attribute_template") + object_template = properties.Object(LinkByUID, "object_template") attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="attribute_by_object", deserializable=False) + typ = properties.String("type", default="attribute_by_object", deserializable=False) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] object_type = Union[UUID, str, LinkByUID, BaseTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - attribute_template: attribute_type, - object_template: object_type, - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + attribute_template: attribute_type, + object_template: object_type, + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.attribute_template = _make_link_by_uid(attribute_template) self.object_template = _make_link_by_uid(object_template) - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector -class LocalAttribute(Serializable['LocalAttribute'], Variable): +class LocalAttribute(Serializable["LocalAttribute"], Variable): """[ALPHA] Attribute marked by an attribute template for the root of a material history tree. 
Parameters @@ -309,37 +345,44 @@ class LocalAttribute(Serializable['LocalAttribute'], Variable): """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - template = properties.Object(LinkByUID, 'template') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + template = properties.Object(LinkByUID, "template") attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="local_attribute", deserializable=False) + typ = properties.String("type", default="local_attribute", deserializable=False) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - template: attribute_type, - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + template: attribute_type, + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.template = _make_link_by_uid(template) - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector -class LocalAttributeAndObject(Serializable['LocalAttributeAndObject'], Variable): +class LocalAttributeAndObject(Serializable["LocalAttributeAndObject"], Variable): """[ALPHA] Attribute marked by an attribute template for the root of a material history tree. 
Parameters @@ -362,42 +405,52 @@ class LocalAttributeAndObject(Serializable['LocalAttributeAndObject'], Variable) """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - template = properties.Object(LinkByUID, 'template') - object_template = properties.Object(LinkByUID, 'object_template') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + template = properties.Object(LinkByUID, "template") + object_template = properties.Object(LinkByUID, "object_template") attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="local_attribute_and_object", deserializable=False) + typ = properties.String( + "type", default="local_attribute_and_object", deserializable=False + ) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] object_type = Union[UUID, str, LinkByUID, BaseTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - template: attribute_type, - object_template: object_type, - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + template: attribute_type, + object_template: object_type, + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.template = _make_link_by_uid(template) self.object_template = _make_link_by_uid(object_template) - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector class IngredientIdentifierByProcessTemplateAndName( - Serializable['IngredientIdentifierByProcessAndName'], Variable): + Serializable["IngredientIdentifierByProcessAndName"], Variable +): """Ingredient identifier associated with a process template and a name. 
Parameters @@ -417,24 +470,28 @@ class IngredientIdentifierByProcessTemplateAndName( """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - process_template = properties.Object(LinkByUID, 'process_template') - ingredient_name = properties.String('ingredient_name') - scope = properties.String('scope') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + process_template = properties.Object(LinkByUID, "process_template") + ingredient_name = properties.String("ingredient_name") + scope = properties.String("scope") type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="ing_id_by_process_and_name", deserializable=False) + typ = properties.String( + "type", default="ing_id_by_process_and_name", deserializable=False + ) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, - *, - headers: List[str], - process_template: process_type, - ingredient_name: str, - scope: str, # Note that the default is set server side - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + process_template: process_type, + ingredient_name: str, + scope: str, # Note that the default is set server side + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.process_template = _make_link_by_uid(process_template) @@ -443,7 +500,9 @@ def __init__(self, self.type_selector = type_selector -class IngredientLabelByProcessAndName(Serializable['IngredientLabelByProcessAndName'], Variable): +class IngredientLabelByProcessAndName( + Serializable["IngredientLabelByProcessAndName"], Variable +): """A boolean variable indicating whether a given label is applied. Matches by process template, ingredient name, and the label string to check. 
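A construction sketch for the identifier variable just reformatted; the ids are hypothetical, and note that scope is a required keyword here because, per the inline comment, its default is applied server side:

    from citrine.gemtables.variables import IngredientIdentifierByProcessTemplateAndName
    from gemd.entity.link_by_uid import LinkByUID

    ethanol_id = IngredientIdentifierByProcessTemplateAndName(
        "ethanol id",
        headers=["Ethanol", "Id"],
        process_template=LinkByUID(scope="templates", id="mixing"),  # hypothetical
        ingredient_name="ethanol",
        scope="my-id-scope",
    )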
@@ -469,24 +528,28 @@ class IngredientLabelByProcessAndName(Serializable['IngredientLabelByProcessAndN """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - process_template = properties.Object(LinkByUID, 'process_template') - ingredient_name = properties.String('ingredient_name') - label = properties.String('label') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + process_template = properties.Object(LinkByUID, "process_template") + ingredient_name = properties.String("ingredient_name") + label = properties.String("label") type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="ing_label_by_process_and_name", deserializable=False) + typ = properties.String( + "type", default="ing_label_by_process_and_name", deserializable=False + ) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, - *, - headers: List[str], - process_template: process_type, - ingredient_name: str, - label: str, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + process_template: process_type, + ingredient_name: str, + label: str, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.process_template = _make_link_by_uid(process_template) @@ -496,8 +559,8 @@ def __init__(self, class IngredientLabelsSetByProcessAndName( - Serializable['IngredientLabelsSetByProcessAndName'], - Variable): + Serializable["IngredientLabelsSetByProcessAndName"], Variable +): """The set of labels on an ingredient when used in a process. For example, the ingredient "ethanol" might be labeled "solvent", "alcohol" and "VOC". @@ -516,22 +579,24 @@ class IngredientLabelsSetByProcessAndName( """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - process_template = properties.Object(LinkByUID, 'process_template') - ingredient_name = properties.String('ingredient_name') - typ = properties.String('type', - default="ing_labels_set_by_process_and_name", - deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + process_template = properties.Object(LinkByUID, "process_template") + ingredient_name = properties.String("ingredient_name") + typ = properties.String( + "type", default="ing_labels_set_by_process_and_name", deserializable=False + ) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, - *, - headers: List[str], - process_template: process_type, - ingredient_name: str): + def __init__( + self, + name: str, + *, + headers: List[str], + process_template: process_type, + ingredient_name: str, + ): self.name = name self.headers = headers self.process_template = _make_link_by_uid(process_template) @@ -539,7 +604,8 @@ def __init__(self, class IngredientQuantityByProcessAndName( - Serializable['IngredientQuantityByProcessAndName'], Variable): + Serializable["IngredientQuantityByProcessAndName"], Variable +): """The quantity of an ingredient associated with a process template and a name. 
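The two label variables above differ in shape: IngredientLabelByProcessAndName yields a boolean per row, while IngredientLabelsSetByProcessAndName yields the full set of labels. A sketch of each against the reformatted signatures (ids are hypothetical):

    from citrine.gemtables.variables import (
        IngredientLabelByProcessAndName,
        IngredientLabelsSetByProcessAndName,
    )
    from gemd.entity.link_by_uid import LinkByUID

    mixing = LinkByUID(scope="templates", id="mixing")  # hypothetical
    is_solvent = IngredientLabelByProcessAndName(
        "ethanol is solvent",
        headers=["Ethanol", "Is solvent?"],
        process_template=mixing,
        ingredient_name="ethanol",
        label="solvent",  # boolean: does this label apply?
    )
    all_labels = IngredientLabelsSetByProcessAndName(
        "ethanol labels",
        headers=["Ethanol", "Labels"],
        process_template=mixing,
        ingredient_name="ethanol",
    )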
Parameters @@ -565,27 +631,32 @@ class IngredientQuantityByProcessAndName( """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - process_template = properties.Object(LinkByUID, 'process_template') - ingredient_name = properties.String('ingredient_name') - quantity_dimension = properties.Enumeration(IngredientQuantityDimension, 'quantity_dimension') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + process_template = properties.Object(LinkByUID, "process_template") + ingredient_name = properties.String("ingredient_name") + quantity_dimension = properties.Enumeration( + IngredientQuantityDimension, "quantity_dimension" + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="ing_quantity_by_process_and_name", - deserializable=False) + typ = properties.String( + "type", default="ing_quantity_by_process_and_name", deserializable=False + ) unit = properties.Optional(properties.String, "unit") process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, - *, - headers: List[str], - process_template: process_type, - ingredient_name: str, - quantity_dimension: IngredientQuantityDimension, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, - unit: Optional[str] = None): + def __init__( + self, + name: str, + *, + headers: List[str], + process_template: process_type, + ingredient_name: str, + quantity_dimension: IngredientQuantityDimension, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + unit: Optional[str] = None, + ): self.name = name self.headers = headers self.process_template = _make_link_by_uid(process_template) @@ -593,19 +664,22 @@ def __init__(self, self.type_selector = type_selector # Cast to make sure the string is valid - self.quantity_dimension = IngredientQuantityDimension.from_str(quantity_dimension, - exception=True) + self.quantity_dimension = IngredientQuantityDimension.from_str( + quantity_dimension, exception=True + ) if quantity_dimension == IngredientQuantityDimension.ABSOLUTE: if unit is None: - raise ValueError("Absolute Quantity variables require that 'unit' is set") + raise ValueError( + "Absolute Quantity variables require that 'unit' is set" + ) else: if unit is not None and unit != "": raise ValueError("Fractional variables cannot take a 'unit'") self.unit = unit -class TerminalMaterialIdentifier(Serializable['TerminalMaterialIdentifier'], Variable): +class TerminalMaterialIdentifier(Serializable["TerminalMaterialIdentifier"], Variable): """A unique identifier of the terminal material of the material history, by scope. 
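The unit validation reformatted above requires a unit for ABSOLUTE quantities and forbids one for fractional dimensions. A sketch, assuming IngredientQuantityDimension is importable from the same module (the id is hypothetical):

    from citrine.gemtables.variables import (
        IngredientQuantityByProcessAndName,
        IngredientQuantityDimension,
    )
    from gemd.entity.link_by_uid import LinkByUID

    mass = IngredientQuantityByProcessAndName(
        "ethanol mass",
        headers=["Ethanol", "Mass"],
        process_template=LinkByUID(scope="templates", id="mixing"),  # hypothetical
        ingredient_name="ethanol",
        quantity_dimension=IngredientQuantityDimension.ABSOLUTE,
        unit="kg",  # required for ABSOLUTE; omitting it raises ValueError
    )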
Parameters @@ -619,22 +693,18 @@ class TerminalMaterialIdentifier(Serializable['TerminalMaterialIdentifier'], Var """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - scope = properties.String('scope') - typ = properties.String('type', default="root_id", deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + scope = properties.String("scope") + typ = properties.String("type", default="root_id", deserializable=False) - def __init__(self, - name: str, - *, - headers: List[str], - scope: str = CITRINE_SCOPE): + def __init__(self, name: str, *, headers: List[str], scope: str = CITRINE_SCOPE): self.name = name self.headers = headers self.scope = scope -class AttributeInOutput(Serializable['AttributeInOutput'], Variable): +class AttributeInOutput(Serializable["AttributeInOutput"], Variable): """Attribute marked by an attribute template in the trunk of the history tree. The search for an attribute that marks the given attribute template starts at the terminal @@ -677,41 +747,52 @@ class AttributeInOutput(Serializable['AttributeInOutput'], Variable): """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - attribute_template = properties.Object(LinkByUID, 'attribute_template') - process_templates = properties.List(properties.Object(LinkByUID), 'process_templates') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + attribute_template = properties.Object(LinkByUID, "attribute_template") + process_templates = properties.List( + properties.Object(LinkByUID), "process_templates" + ) attribute_constraints = properties.Optional( properties.List( properties.SpecifiedMixedList( [properties.Object(LinkByUID), properties.Object(BaseBounds)] ) - ), 'attribute_constraints') + ), + "attribute_constraints", + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="attribute_in_trunk", deserializable=False) + typ = properties.String("type", default="attribute_in_trunk", deserializable=False) attribute_type = Union[UUID, str, LinkByUID, AttributeTemplate] process_type = Union[UUID, str, LinkByUID, ProcessTemplate] constraint_type = Tuple[attribute_type, BaseBounds] - def __init__(self, - name: str, - *, - headers: List[str], - attribute_template: attribute_type, - process_templates: List[process_type], - attribute_constraints: Optional[List[constraint_type]] = None, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + attribute_template: attribute_type, + process_templates: List[process_type], + attribute_constraints: Optional[List[constraint_type]] = None, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.attribute_template = _make_link_by_uid(attribute_template) self.process_templates = [_make_link_by_uid(x) for x in process_templates] - self.attribute_constraints = None if attribute_constraints is None \ + self.attribute_constraints = ( + None + if attribute_constraints is None else [(_make_link_by_uid(x[0]), x[1]) for x in attribute_constraints] + ) self.type_selector = type_selector -class IngredientIdentifierInOutput(Serializable['IngredientIdentifierInOutput'], Variable): +class IngredientIdentifierInOutput( + Serializable["IngredientIdentifierInOutput"], Variable +): """Ingredient 
identifier in the trunk of a material history tree. The search for an ingredient starts at the terminal of the material history tree and @@ -757,23 +838,28 @@ class IngredientIdentifierInOutput(Serializable['IngredientIdentifierInOutput'], """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - ingredient_name = properties.String('ingredient_name') - process_templates = properties.List(properties.Object(LinkByUID), 'process_templates') - scope = properties.String('scope') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + ingredient_name = properties.String("ingredient_name") + process_templates = properties.List( + properties.Object(LinkByUID), "process_templates" + ) + scope = properties.String("scope") type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="ing_id_in_output", deserializable=False) + typ = properties.String("type", default="ing_id_in_output", deserializable=False) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - ingredient_name: str, - process_templates: List[process_type], - scope: str = CITRINE_SCOPE, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + ingredient_name: str, + process_templates: List[process_type], + scope: str = CITRINE_SCOPE, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.ingredient_name = ingredient_name @@ -782,7 +868,9 @@ def __init__(self, self.type_selector = type_selector -class IngredientLabelsSetInOutput(Serializable['IngredientLabelsSetInOutput'], Variable): +class IngredientLabelsSetInOutput( + Serializable["IngredientLabelsSetInOutput"], Variable +): """The set of labels on an ingredient in the trunk of a material history tree. 
The search for an ingredient starts at the terminal of the material history tree and proceeds @@ -825,26 +913,33 @@ class IngredientLabelsSetInOutput(Serializable['IngredientLabelsSetInOutput'], V """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - process_templates = properties.List(properties.Object(LinkByUID), 'process_templates') - ingredient_name = properties.String('ingredient_name') - typ = properties.String('type', default="ing_label_set_in_output", deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + process_templates = properties.List( + properties.Object(LinkByUID), "process_templates" + ) + ingredient_name = properties.String("ingredient_name") + typ = properties.String( + "type", default="ing_label_set_in_output", deserializable=False + ) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - process_templates: List[process_type], - ingredient_name: str): + def __init__( + self, + name: str, + *, + headers: List[str], + process_templates: List[process_type], + ingredient_name: str, + ): self.name = name self.headers = headers self.process_templates = [_make_link_by_uid(x) for x in process_templates] self.ingredient_name = ingredient_name -class IngredientQuantityInOutput(Serializable['IngredientQuantityInOutput'], Variable): +class IngredientQuantityInOutput(Serializable["IngredientQuantityInOutput"], Variable): """Ingredient quantity in the trunk of a material history tree. The search for an ingredient starts at the terminal of the material history tree and proceeds @@ -898,25 +993,34 @@ class IngredientQuantityInOutput(Serializable['IngredientQuantityInOutput'], Var """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - ingredient_name = properties.String('ingredient_name') - quantity_dimension = properties.Enumeration(IngredientQuantityDimension, 'quantity_dimension') - process_templates = properties.List(properties.Object(LinkByUID), 'process_templates') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + ingredient_name = properties.String("ingredient_name") + quantity_dimension = properties.Enumeration( + IngredientQuantityDimension, "quantity_dimension" + ) + process_templates = properties.List( + properties.Object(LinkByUID), "process_templates" + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") unit = properties.Optional(properties.String, "unit") - typ = properties.String('type', default="ing_quantity_in_output", deserializable=False) + typ = properties.String( + "type", default="ing_quantity_in_output", deserializable=False + ) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - ingredient_name: str, - quantity_dimension: IngredientQuantityDimension, - process_templates: List[process_type], - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, - unit: Optional[str] = None): + def __init__( + self, + name: str, + *, + headers: List[str], + ingredient_name: str, + quantity_dimension: IngredientQuantityDimension, + process_templates: List[process_type], + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + unit: Optional[str] = None, + ): self.name = name self.headers = headers self.ingredient_name = ingredient_name @@ -924,19 +1028,22 @@ def __init__(self, 
self.type_selector = type_selector # Cast to make sure the string is valid - self.quantity_dimension = IngredientQuantityDimension.from_str(quantity_dimension, - exception=True) + self.quantity_dimension = IngredientQuantityDimension.from_str( + quantity_dimension, exception=True + ) if quantity_dimension == IngredientQuantityDimension.ABSOLUTE: if unit is None: - raise ValueError("Absolute Quantity variables require that 'unit' is set") + raise ValueError( + "Absolute Quantity variables require that 'unit' is set" + ) else: if unit is not None and unit != "": raise ValueError("Fractional variables cannot take a 'unit'") self.unit = unit -class LocalIngredientIdentifier(Serializable['LocalIngredientIdentifier'], Variable): +class LocalIngredientIdentifier(Serializable["LocalIngredientIdentifier"], Variable): """Ingredient identifier for the root process of a material history tree. Get ingredient identifier by name. Stop traversal when encountering any ingredient. @@ -961,21 +1068,24 @@ class LocalIngredientIdentifier(Serializable['LocalIngredientIdentifier'], Varia """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - ingredient_name = properties.String('ingredient_name') - scope = properties.String('scope') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + ingredient_name = properties.String("ingredient_name") + scope = properties.String("scope") type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") - typ = properties.String('type', default="local_ing_id", deserializable=False) + typ = properties.String("type", default="local_ing_id", deserializable=False) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - ingredient_name: str, - scope: str = CITRINE_SCOPE, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN): + def __init__( + self, + name: str, + *, + headers: List[str], + ingredient_name: str, + scope: str = CITRINE_SCOPE, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + ): self.name = name self.headers = headers self.ingredient_name = ingredient_name @@ -983,7 +1093,7 @@ def __init__(self, self.type_selector = type_selector -class LocalIngredientLabelsSet(Serializable['LocalIngredientLabelsSet'], Variable): +class LocalIngredientLabelsSet(Serializable["LocalIngredientLabelsSet"], Variable): """The set of labels on an ingredient for the root process of a material history tree. 
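A sketch of the local identifier variant just reformatted, which stops traversal at the first ingredient encountered; scope falls back to CITRINE_SCOPE when omitted:

    from citrine.gemtables.variables import LocalIngredientIdentifier

    resin_id = LocalIngredientIdentifier(
        "resin id",
        headers=["Resin", "Id"],
        ingredient_name="resin",
    )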
Define a variable contains the set of labels that is present on the ingredient @@ -1004,23 +1114,20 @@ class LocalIngredientLabelsSet(Serializable['LocalIngredientLabelsSet'], Variabl """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - ingredient_name = properties.String('ingredient_name') - typ = properties.String('type', default="local_ing_label_set", deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + ingredient_name = properties.String("ingredient_name") + typ = properties.String("type", default="local_ing_label_set", deserializable=False) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - ingredient_name: str): + def __init__(self, name: str, *, headers: List[str], ingredient_name: str): self.name = name self.headers = headers self.ingredient_name = ingredient_name -class LocalIngredientQuantity(Serializable['LocalIngredientQuantity'], Variable): +class LocalIngredientQuantity(Serializable["LocalIngredientQuantity"], Variable): """The quantity of an ingredient for the root process of a material history tree. Get ingredient quantity by name. Stop traversal when encountering any ingredient. @@ -1050,42 +1157,50 @@ class LocalIngredientQuantity(Serializable['LocalIngredientQuantity'], Variable) """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - ingredient_name = properties.String('ingredient_name') - quantity_dimension = properties.Enumeration(IngredientQuantityDimension, 'quantity_dimension') + name = properties.String("name") + headers = properties.List(properties.String, "headers") + ingredient_name = properties.String("ingredient_name") + quantity_dimension = properties.Enumeration( + IngredientQuantityDimension, "quantity_dimension" + ) type_selector = properties.Enumeration(DataObjectTypeSelector, "type_selector") unit = properties.Optional(properties.String, "unit") - typ = properties.String('type', default="local_ing_quantity", deserializable=False) + typ = properties.String("type", default="local_ing_quantity", deserializable=False) process_type = Union[UUID, str, LinkByUID, ProcessTemplate] - def __init__(self, - name: str, *, - headers: List[str], - ingredient_name: str, - quantity_dimension: IngredientQuantityDimension, - type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, - unit: Optional[str] = None): + def __init__( + self, + name: str, + *, + headers: List[str], + ingredient_name: str, + quantity_dimension: IngredientQuantityDimension, + type_selector: DataObjectTypeSelector = DataObjectTypeSelector.PREFER_RUN, + unit: Optional[str] = None, + ): self.name = name self.headers = headers self.ingredient_name = ingredient_name self.type_selector = type_selector # Cast to make sure the string is valid - self.quantity_dimension = IngredientQuantityDimension.from_str(quantity_dimension, - exception=True) + self.quantity_dimension = IngredientQuantityDimension.from_str( + quantity_dimension, exception=True + ) if quantity_dimension == IngredientQuantityDimension.ABSOLUTE: if unit is None: - raise ValueError("Absolute Quantity variables require that 'unit' is set") + raise ValueError( + "Absolute Quantity variables require that 'unit' is set" + ) else: if unit is not None and unit != "": raise ValueError("Fractional variables cannot take a 'unit'") self.unit = unit -class XOR(Serializable['XOR'], Variable): +class 
XOR(Serializable["XOR"], Variable): """Logical exclusive OR for GEM table variables. This variable combines the results of 2 or more variables into a single variable according to @@ -1114,10 +1229,10 @@ class XOR(Serializable['XOR'], Variable): """ - name = properties.String('name') - headers = properties.List(properties.String, 'headers') - variables = properties.List(properties.Object(Variable), 'variables') - typ = properties.String('type', default="xor", deserializable=False) + name = properties.String("name") + headers = properties.List(properties.String, "headers") + variables = properties.List(properties.Object(Variable), "variables") + typ = properties.String("type", default="xor", deserializable=False) def __init__(self, name, *, headers, variables): self.name = name diff --git a/src/citrine/informatics/catalyst/assistant.py b/src/citrine/informatics/catalyst/assistant.py index d9bf5939b..c0863ec93 100644 --- a/src/citrine/informatics/catalyst/assistant.py +++ b/src/citrine/informatics/catalyst/assistant.py @@ -13,14 +13,20 @@ class AssistantRequest(Serializable["AssistantRequest"]): question = properties.String("question") predictor = properties.Object(GraphPredictor, "config") temperature = properties.Optional(properties.Float, "temperature", default=0.0) - language_model = properties.Optional(properties.Enumeration(LanguageModelChoice), - "language_model", default=LanguageModelChoice.GPT_4) - - def __init__(self, *, - question: str, - predictor: GraphPredictor, - temperature: Optional[float] = 0.0, - language_model: Optional[LanguageModelChoice] = LanguageModelChoice.GPT_4): + language_model = properties.Optional( + properties.Enumeration(LanguageModelChoice), + "language_model", + default=LanguageModelChoice.GPT_4, + ) + + def __init__( + self, + *, + question: str, + predictor: GraphPredictor, + temperature: Optional[float] = 0.0, + language_model: Optional[LanguageModelChoice] = LanguageModelChoice.GPT_4, + ): self.question = question self.predictor = predictor self.temperature = temperature @@ -36,32 +42,36 @@ class AssistantResponse(PolymorphicSerializable["AssistantResponse"]): """The parent type for all Model Assistant responses.""" @classmethod - def get_type(cls, data) -> Type['AssistantResponse']: + def get_type(cls, data) -> Type["AssistantResponse"]: """Return the subtype.""" type_dict = { "message": AssistantResponseMessage, "modified-config": AssistantResponseConfig, "unsupported": AssistantResponseUnsupported, "input-error": AssistantResponseInputErrors, - "exec-error": AssistantResponseExecError + "exec-error": AssistantResponseExecError, } - typ = type_dict.get(data['type']) + typ = type_dict.get(data["type"]) if typ is not None: return typ else: raise ValueError( - f'{data["type"]} is not a valid assistant response type. ' - f'Must be in {type_dict.keys()}.' + f"{data['type']} is not a valid assistant response type. " + f"Must be in {type_dict.keys()}." 
) -class AssistantResponseMessage(Serializable["AssistantResponseMessage"], AssistantResponse): +class AssistantResponseMessage( + Serializable["AssistantResponseMessage"], AssistantResponse +): """A successful model assistant invocation, whose response is only text.""" message = properties.String("data.message") -class AssistantResponseConfig(Serializable["AssistantResponseConfig"], AssistantResponse): +class AssistantResponseConfig( + Serializable["AssistantResponseConfig"], AssistantResponse +): """A successful model assistant invocation, whose response includes a modified predictor.""" predictor = properties.Object(GraphPredictor, "data.config") @@ -72,8 +82,9 @@ def _pre_build(cls, data): return data -class AssistantResponseUnsupported(Serializable["AssistantResponseUnsupported"], - AssistantResponse): +class AssistantResponseUnsupported( + Serializable["AssistantResponseUnsupported"], AssistantResponse +): """A successful model assistant invocation, but for an unsupported query. This will cover any user query which the model assistant could not map to a functionality it @@ -83,7 +94,9 @@ class AssistantResponseUnsupported(Serializable["AssistantResponseUnsupported"], message = properties.String("data.message") -class AssistantResponseInputError(Serializable["AssistantResponseInputError"], AssistantResponse): +class AssistantResponseInputError( + Serializable["AssistantResponseInputError"], AssistantResponse +): """A single input failure. Contains the error message, and the field it applies to. @@ -93,18 +106,23 @@ class AssistantResponseInputError(Serializable["AssistantResponseInputError"], A error = properties.String("error") -class AssistantResponseInputErrors(Serializable["AssistantResponseInputErrors"], - AssistantResponse): +class AssistantResponseInputErrors( + Serializable["AssistantResponseInputErrors"], AssistantResponse +): """A failed model assistant invocation, due to malformed input. This should only happen if there's some field omitted by the client, or one of its values is outside acceptable ranges. """ - errors = properties.List(properties.Object(AssistantResponseInputError), "data.errors") + errors = properties.List( + properties.Object(AssistantResponseInputError), "data.errors" + ) -class AssistantResponseExecError(Serializable["AssistantResponseExecError"], AssistantResponse): +class AssistantResponseExecError( + Serializable["AssistantResponseExecError"], AssistantResponse +): """A failed model assistant invocation, due to some internal issue. This most likely indicates the assistant got some unexpected output when asking its query. 
It diff --git a/src/citrine/informatics/catalyst/insights.py b/src/citrine/informatics/catalyst/insights.py index a0dc8a827..185631e1f 100644 --- a/src/citrine/informatics/catalyst/insights.py +++ b/src/citrine/informatics/catalyst/insights.py @@ -16,14 +16,18 @@ class InsightsRequest(Serializable["InsightsRequest"]): default=LanguageModelChoice.GPT_35_TURBO, ) n_documents = properties.Optional(properties.Integer, "n_documents", default=5) - response_size = properties.Optional(properties.Integer, "response_size", default=100) + response_size = properties.Optional( + properties.Integer, "response_size", default=100 + ) def __init__( self, *, question: str, temperature: Optional[float] = 0.0, - language_model: Optional[LanguageModelChoice] = LanguageModelChoice.GPT_35_TURBO, + language_model: Optional[ + LanguageModelChoice + ] = LanguageModelChoice.GPT_35_TURBO, n_documents: Optional[int] = 5, response_size: Optional[int] = 100, ): diff --git a/src/citrine/informatics/constraints/categorical_constraint.py b/src/citrine/informatics/constraints/categorical_constraint.py index 3b381da6e..e950165e4 100644 --- a/src/citrine/informatics/constraints/categorical_constraint.py +++ b/src/citrine/informatics/constraints/categorical_constraint.py @@ -4,10 +4,12 @@ from citrine._serialization.serializable import Serializable from citrine.informatics.constraints.constraint import Constraint -__all__ = ['AcceptableCategoriesConstraint'] +__all__ = ["AcceptableCategoriesConstraint"] -class AcceptableCategoriesConstraint(Serializable['AcceptableCategoriesConstraint'], Constraint): +class AcceptableCategoriesConstraint( + Serializable["AcceptableCategoriesConstraint"], Constraint +): """ A constraint on a categorical material attribute to be in a set of acceptable values. @@ -20,16 +22,13 @@ class AcceptableCategoriesConstraint(Serializable['AcceptableCategoriesConstrain """ - descriptor_key = properties.String('descriptor_key') - acceptable_categories = properties.List(properties.String(), 'acceptable_classes') - typ = properties.String('type', default='AcceptableCategoriesConstraint') + descriptor_key = properties.String("descriptor_key") + acceptable_categories = properties.List(properties.String(), "acceptable_classes") + typ = properties.String("type", default="AcceptableCategoriesConstraint") - def __init__(self, - *, - descriptor_key: str, - acceptable_categories: List[str]): + def __init__(self, *, descriptor_key: str, acceptable_categories: List[str]): self.descriptor_key = descriptor_key self.acceptable_categories = acceptable_categories def __str__(self): - return '<AcceptableCategoriesConstraint {!r}>'.format(self.descriptor_key) + return "<AcceptableCategoriesConstraint {!r}>".format(self.descriptor_key) diff --git a/src/citrine/informatics/constraints/constraint.py b/src/citrine/informatics/constraints/constraint.py index 057978555..644d0f564 100644 --- a/src/citrine/informatics/constraints/constraint.py +++ b/src/citrine/informatics/constraints/constraint.py @@ -2,10 +2,10 @@ from citrine._serialization.polymorphic_serializable import PolymorphicSerializable -__all__ = ['Constraint'] +__all__ = ["Constraint"] -class Constraint(PolymorphicSerializable['Constraint']): +class Constraint(PolymorphicSerializable["Constraint"]): """A Citrine Constraint places restrictions on a design space. Abstract type that returns the proper type given a serialized dict. 
@@ -24,13 +24,14 @@ def get_type(cls, data): from .integer_range_constraint import IntegerRangeConstraint from .categorical_constraint import AcceptableCategoriesConstraint from .ingredient_ratio_constraint import IngredientRatioConstraint + return { - 'Categorical': AcceptableCategoriesConstraint, # Kept for backwards compatibility. - 'AcceptableCategoriesConstraint': AcceptableCategoriesConstraint, - 'IngredientCountConstraint': IngredientCountConstraint, - 'IngredientFractionConstraint': IngredientFractionConstraint, - 'LabelFractionConstraint': LabelFractionConstraint, - 'ScalarRange': ScalarRangeConstraint, - 'IntegerRange': IntegerRangeConstraint, - 'IngredientRatio': IngredientRatioConstraint, - }[data['type']] + "Categorical": AcceptableCategoriesConstraint, # Kept for backwards compatibility. + "AcceptableCategoriesConstraint": AcceptableCategoriesConstraint, + "IngredientCountConstraint": IngredientCountConstraint, + "IngredientFractionConstraint": IngredientFractionConstraint, + "LabelFractionConstraint": LabelFractionConstraint, + "ScalarRange": ScalarRangeConstraint, + "IntegerRange": IntegerRangeConstraint, + "IngredientRatio": IngredientRatioConstraint, + }[data["type"]] diff --git a/src/citrine/informatics/constraints/ingredient_count_constraint.py b/src/citrine/informatics/constraints/ingredient_count_constraint.py index 7bdc0ef74..9b6dbfd66 100644 --- a/src/citrine/informatics/constraints/ingredient_count_constraint.py +++ b/src/citrine/informatics/constraints/ingredient_count_constraint.py @@ -5,10 +5,10 @@ from citrine.informatics.constraints.constraint import Constraint from citrine.informatics.descriptors import FormulationDescriptor -__all__ = ['IngredientCountConstraint'] +__all__ = ["IngredientCountConstraint"] -class IngredientCountConstraint(Serializable['IngredientCountConstraint'], Constraint): +class IngredientCountConstraint(Serializable["IngredientCountConstraint"], Constraint): """Represents a constraint on the total number of ingredients in a formulation. 
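The mapping reformatted above is the single dispatch point for constraint deserialization; the legacy "Categorical" key is kept so older payloads still resolve. A round-trip sketch, assuming the usual PolymorphicSerializable.build entry point:

    from citrine.informatics.constraints import Constraint, ScalarRangeConstraint

    data = {
        "type": "ScalarRange",  # a key in the mapping above
        "descriptor_key": "Temperature",
        "min": 0.0,
        "max": 100.0,
        "min_inclusive": True,
        "max_inclusive": True,
    }
    constraint = Constraint.build(data)  # dispatches through get_type
    assert isinstance(constraint, ScalarRangeConstraint)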
Parameters @@ -26,17 +26,22 @@ class IngredientCountConstraint(Serializable['IngredientCountConstraint'], Const """ - formulation_descriptor = properties.Object(FormulationDescriptor, 'formulation_descriptor') - min = properties.Integer('min') - max = properties.Integer('max') - label = properties.Optional(properties.String, 'label') - typ = properties.String('type', default='IngredientCountConstraint') - - def __init__(self, *, - formulation_descriptor: FormulationDescriptor, - min: int, - max: int, - label: Optional[str] = None): + formulation_descriptor = properties.Object( + FormulationDescriptor, "formulation_descriptor" + ) + min = properties.Integer("min") + max = properties.Integer("max") + label = properties.Optional(properties.String, "label") + typ = properties.String("type", default="IngredientCountConstraint") + + def __init__( + self, + *, + formulation_descriptor: FormulationDescriptor, + min: int, + max: int, + label: Optional[str] = None, + ): self.formulation_descriptor: FormulationDescriptor = formulation_descriptor self.min: int = min self.max: int = max diff --git a/src/citrine/informatics/constraints/ingredient_fraction_constraint.py b/src/citrine/informatics/constraints/ingredient_fraction_constraint.py index d289b36ec..b95e349bd 100644 --- a/src/citrine/informatics/constraints/ingredient_fraction_constraint.py +++ b/src/citrine/informatics/constraints/ingredient_fraction_constraint.py @@ -3,10 +3,12 @@ from citrine.informatics.constraints.constraint import Constraint from citrine.informatics.descriptors import FormulationDescriptor -__all__ = ['IngredientFractionConstraint'] +__all__ = ["IngredientFractionConstraint"] -class IngredientFractionConstraint(Serializable['IngredientFractionConstraint'], Constraint): +class IngredientFractionConstraint( + Serializable["IngredientFractionConstraint"], Constraint +): """Represents a constraint on an ingredient fraction in a formulation. 
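A construction sketch for the count constraint above; FormulationDescriptor.hierarchical() is assumed available, as in recent citrine releases:

    from citrine.informatics.constraints import IngredientCountConstraint
    from citrine.informatics.descriptors import FormulationDescriptor

    descriptor = FormulationDescriptor.hierarchical()
    # require between 2 and 5 ingredients labeled "solvent" in a candidate formulation
    solvent_count = IngredientCountConstraint(
        formulation_descriptor=descriptor, min=2, max=5, label="solvent"
    )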
Parameters @@ -27,19 +29,24 @@ class IngredientFractionConstraint(Serializable['IngredientFractionConstraint'], """ - formulation_descriptor = properties.Object(FormulationDescriptor, 'formulation_descriptor') - ingredient = properties.String('ingredient') - min = properties.Optional(properties.Float, 'min') - max = properties.Optional(properties.Float, 'max') - is_required = properties.Boolean('is_required') - typ = properties.String('type', default='IngredientFractionConstraint') - - def __init__(self, *, - formulation_descriptor: FormulationDescriptor, - ingredient: str, - min: float, - max: float, - is_required: bool = True): + formulation_descriptor = properties.Object( + FormulationDescriptor, "formulation_descriptor" + ) + ingredient = properties.String("ingredient") + min = properties.Optional(properties.Float, "min") + max = properties.Optional(properties.Float, "max") + is_required = properties.Boolean("is_required") + typ = properties.String("type", default="IngredientFractionConstraint") + + def __init__( + self, + *, + formulation_descriptor: FormulationDescriptor, + ingredient: str, + min: float, + max: float, + is_required: bool = True, + ): self.formulation_descriptor: FormulationDescriptor = formulation_descriptor self.ingredient: str = ingredient self.min: float = min diff --git a/src/citrine/informatics/constraints/ingredient_ratio_constraint.py b/src/citrine/informatics/constraints/ingredient_ratio_constraint.py index f85364ce1..ccceda411 100644 --- a/src/citrine/informatics/constraints/ingredient_ratio_constraint.py +++ b/src/citrine/informatics/constraints/ingredient_ratio_constraint.py @@ -6,10 +6,10 @@ from citrine.informatics.constraints.constraint import Constraint from citrine.informatics.descriptors import FormulationDescriptor -__all__ = ['IngredientRatioConstraint'] +__all__ = ["IngredientRatioConstraint"] -class IngredientRatioConstraint(Serializable['IngredientRatioConstraint'], Constraint): +class IngredientRatioConstraint(Serializable["IngredientRatioConstraint"], Constraint): """A formulation constraint operating on the ratio of quantities of ingredients and a basis. Example: "6 to 7 parts ingredient A per 100 parts ingredient B" becomes @@ -37,37 +37,47 @@ class IngredientRatioConstraint(Serializable['IngredientRatioConstraint'], Const """ - formulation_descriptor = properties.Object(FormulationDescriptor, 'formulation_descriptor') - min = properties.Float('min') - max = properties.Float('max') + formulation_descriptor = properties.Object( + FormulationDescriptor, "formulation_descriptor" + ) + min = properties.Float("min") + max = properties.Float("max") # The backend provides ingredients and labels as dictionaries, but presently only allows one # between them. To clarify customer interaction, we only allow a single one of each to be set. # Since our serde library doesn't allow extracting from a dict with unknown keys, we do it by # hiding the dictionaries and exposing properties. _ingredients = properties.Mapping( - properties.String, properties.Float, 'ingredients', default={}) - _labels = properties.Mapping(properties.String, properties.Float, 'labels', default={}) + properties.String, properties.Float, "ingredients", default={} + ) + _labels = properties.Mapping( + properties.String, properties.Float, "labels", default={} + ) # The backend provides basis ingredients and basis labels as a dictionary from the key to a # multiplier. 
However, for ingredient ratio constraints, the multiplier in the denominator # should always be one, so we can't allow users to enter it. We need to use properties for this # behavior. It also allows us to display deprecation warnings for the coming type change. _basis_ingredients = properties.Mapping( - properties.String, properties.Float, 'basis_ingredients', default={}) + properties.String, properties.Float, "basis_ingredients", default={} + ) _basis_labels = properties.Mapping( - properties.String, properties.Float, 'basis_labels', default={}) - - typ = properties.String('type', default='IngredientRatio') - - def __init__(self, *, - formulation_descriptor: FormulationDescriptor, - min: float, - max: float, - ingredient: Optional[Tuple[str, float]] = None, - label: Optional[Tuple[str, float]] = None, - basis_ingredients: Set[str] = set(), - basis_labels: Set[str] = set()): + properties.String, properties.Float, "basis_labels", default={} + ) + + typ = properties.String("type", default="IngredientRatio") + + def __init__( + self, + *, + formulation_descriptor: FormulationDescriptor, + min: float, + max: float, + ingredient: Optional[Tuple[str, float]] = None, + label: Optional[Tuple[str, float]] = None, + basis_ingredients: Set[str] = set(), + basis_labels: Set[str] = set(), + ): self.formulation_descriptor = formulation_descriptor self.min = min self.max = max @@ -109,16 +119,22 @@ def basis_ingredients(self, value: Set[str]): @property def basis_ingredient_names(self) -> Set[str]: """Retrieve the names of all ingredients in the denominator of the ratio.""" - warnings.warn("basis_ingredient_names is deprecated as of 3.0.0 and will be dropped in " - "4.0. Please use basis_ingredients instead.", DeprecationWarning) + warnings.warn( + "basis_ingredient_names is deprecated as of 3.0.0 and will be dropped in " + "4.0. Please use basis_ingredients instead.", + DeprecationWarning, + ) return self.basis_ingredients # This is for symmetry; it's not strictly necessary. @basis_ingredient_names.setter def basis_ingredient_names(self, value: Set[str]): """Set the names of all ingredients in the denominator of the ratio.""" - warnings.warn("basis_ingredient_names is deprecated as of 3.0.0 and will be dropped in " - "4.0. Please use basis_ingredients instead.", DeprecationWarning) + warnings.warn( + "basis_ingredient_names is deprecated as of 3.0.0 and will be dropped in " + "4.0. Please use basis_ingredients instead.", + DeprecationWarning, + ) self.basis_ingredients = value @property @@ -134,16 +150,22 @@ def basis_labels(self, value: Set[str]): @property def basis_label_names(self) -> Set[str]: """Retrieve the names of all labels in the denominator of the ratio.""" - warnings.warn("basis_label_names is deprecated as of 3.0.0 and will be dropped in 4.0. " - "Please use basis_labels instead.", DeprecationWarning) + warnings.warn( + "basis_label_names is deprecated as of 3.0.0 and will be dropped in 4.0. " + "Please use basis_labels instead.", + DeprecationWarning, + ) return self.basis_labels # This is for symmetry; it's not strictly necessary. @basis_label_names.setter def basis_label_names(self, value: Set[str]): """Set the names of all labels in the denominator of the ratio.""" - warnings.warn("basis_label_names is deprecated as of 3.0.0 and will be dropped in 4.0. " - "Please use basis_labels instead.", DeprecationWarning) + warnings.warn( + "basis_label_names is deprecated as of 3.0.0 and will be dropped in 4.0. 
" + "Please use basis_labels instead.", + DeprecationWarning, + ) self.basis_labels = value def _numerator_read(self, num_dict): diff --git a/src/citrine/informatics/constraints/integer_range_constraint.py b/src/citrine/informatics/constraints/integer_range_constraint.py index a3ba9d699..fbb5d3bd9 100644 --- a/src/citrine/informatics/constraints/integer_range_constraint.py +++ b/src/citrine/informatics/constraints/integer_range_constraint.py @@ -4,10 +4,10 @@ from citrine._serialization.serializable import Serializable from citrine.informatics.constraints.constraint import Constraint -__all__ = ['IntegerRangeConstraint'] +__all__ = ["IntegerRangeConstraint"] -class IntegerRangeConstraint(Serializable['IntegerRangeConstraint'], Constraint): +class IntegerRangeConstraint(Serializable["IntegerRangeConstraint"], Constraint): """[ALPHA] Represents an inequality constraint on an integer-valued material attribute. Warning: IntegerRangeConstraints are not fully supported by the Citrine Platform web interface @@ -28,18 +28,21 @@ class IntegerRangeConstraint(Serializable['IntegerRangeConstraint'], Constraint) """ - descriptor_key = properties.String('descriptor_key') - lower_bound = properties.Optional(properties.Float, 'min') - upper_bound = properties.Optional(properties.Float, 'max') - typ = properties.String('type', default='IntegerRange') - - def __init__(self, *, - descriptor_key: str, - lower_bound: Optional[int] = None, - upper_bound: Optional[int] = None): + descriptor_key = properties.String("descriptor_key") + lower_bound = properties.Optional(properties.Float, "min") + upper_bound = properties.Optional(properties.Float, "max") + typ = properties.String("type", default="IntegerRange") + + def __init__( + self, + *, + descriptor_key: str, + lower_bound: Optional[int] = None, + upper_bound: Optional[int] = None, + ): self.descriptor_key = descriptor_key self.lower_bound = lower_bound self.upper_bound = upper_bound def __str__(self): - return ''.format(self.descriptor_key) + return "".format(self.descriptor_key) diff --git a/src/citrine/informatics/constraints/label_fraction_constraint.py b/src/citrine/informatics/constraints/label_fraction_constraint.py index 816d1ffd0..ef2fdb09a 100644 --- a/src/citrine/informatics/constraints/label_fraction_constraint.py +++ b/src/citrine/informatics/constraints/label_fraction_constraint.py @@ -3,10 +3,10 @@ from citrine.informatics.constraints.constraint import Constraint from citrine.informatics.descriptors import FormulationDescriptor -__all__ = ['LabelFractionConstraint'] +__all__ = ["LabelFractionConstraint"] -class LabelFractionConstraint(Serializable['LabelFractionConstraint'], Constraint): +class LabelFractionConstraint(Serializable["LabelFractionConstraint"], Constraint): """Represents a constraint on the total amount of ingredients with a given label. 
Parameters @@ -27,19 +27,24 @@ class LabelFractionConstraint(Serializable['LabelFractionConstraint'], Constrain """ - formulation_descriptor = properties.Object(FormulationDescriptor, 'formulation_descriptor') - label = properties.String('label') - min = properties.Optional(properties.Float, 'min') - max = properties.Optional(properties.Float, 'max') - is_required = properties.Boolean('is_required') - typ = properties.String('type', default='LabelFractionConstraint') - - def __init__(self, *, - formulation_descriptor: FormulationDescriptor, - label: str, - min: float, - max: float, - is_required: bool = True): + formulation_descriptor = properties.Object( + FormulationDescriptor, "formulation_descriptor" + ) + label = properties.String("label") + min = properties.Optional(properties.Float, "min") + max = properties.Optional(properties.Float, "max") + is_required = properties.Boolean("is_required") + typ = properties.String("type", default="LabelFractionConstraint") + + def __init__( + self, + *, + formulation_descriptor: FormulationDescriptor, + label: str, + min: float, + max: float, + is_required: bool = True, + ): self.formulation_descriptor: FormulationDescriptor = formulation_descriptor self.label: str = label self.min: float = min diff --git a/src/citrine/informatics/constraints/scalar_range_constraint.py b/src/citrine/informatics/constraints/scalar_range_constraint.py index 48793406b..c761b1709 100644 --- a/src/citrine/informatics/constraints/scalar_range_constraint.py +++ b/src/citrine/informatics/constraints/scalar_range_constraint.py @@ -4,10 +4,10 @@ from citrine._serialization.serializable import Serializable from citrine.informatics.constraints.constraint import Constraint -__all__ = ['ScalarRangeConstraint'] +__all__ = ["ScalarRangeConstraint"] -class ScalarRangeConstraint(Serializable['ScalarRangeConstraint'], Constraint): +class ScalarRangeConstraint(Serializable["ScalarRangeConstraint"], Constraint): """Represents an inequality constraint on a real-valued material attribute. 
Parameters @@ -25,19 +25,22 @@ class ScalarRangeConstraint(Serializable['ScalarRangeConstraint'], Constraint): """ - descriptor_key = properties.String('descriptor_key') - lower_bound = properties.Optional(properties.Float, 'min') - upper_bound = properties.Optional(properties.Float, 'max') - lower_inclusive = properties.Boolean('min_inclusive') - upper_inclusive = properties.Boolean('max_inclusive') - typ = properties.String('type', default='ScalarRange') - - def __init__(self, *, - descriptor_key: str, - lower_bound: Optional[float] = None, - upper_bound: Optional[float] = None, - lower_inclusive: Optional[bool] = None, - upper_inclusive: Optional[bool] = None): + descriptor_key = properties.String("descriptor_key") + lower_bound = properties.Optional(properties.Float, "min") + upper_bound = properties.Optional(properties.Float, "max") + lower_inclusive = properties.Boolean("min_inclusive") + upper_inclusive = properties.Boolean("max_inclusive") + typ = properties.String("type", default="ScalarRange") + + def __init__( + self, + *, + descriptor_key: str, + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + lower_inclusive: Optional[bool] = None, + upper_inclusive: Optional[bool] = None, + ): self.descriptor_key = descriptor_key self.lower_bound = lower_bound @@ -54,4 +57,4 @@ def __init__(self, *, self.upper_inclusive = upper_inclusive def __str__(self): - return '<ScalarRangeConstraint {!r}>'.format(self.descriptor_key) + return "<ScalarRangeConstraint {!r}>".format(self.descriptor_key) diff --git a/src/citrine/informatics/data_sources.py b/src/citrine/informatics/data_sources.py index 77bed62d6..0c172ba82 100644 --- a/src/citrine/informatics/data_sources.py +++ b/src/citrine/informatics/data_sources.py @@ -1,4 +1,5 @@ """Tools for working with Descriptors.""" + from abc import abstractmethod from typing import Type, List, Mapping, Optional, Union from uuid import UUID @@ -12,15 +13,15 @@ from citrine.resources.gemtables import GemTable __all__ = [ - 'DataSource', - 'CSVDataSource', - 'GemTableDataSource', - 'ExperimentDataSourceRef', - 'SnapshotDataSource', + "DataSource", + "CSVDataSource", + "GemTableDataSource", + "ExperimentDataSourceRef", + "SnapshotDataSource", ] -class DataSource(PolymorphicSerializable['DataSource']): +class DataSource(PolymorphicSerializable["DataSource"]): """A source of data for the AI engine.
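For orientation, the constraint classes reformatted above are built with keyword-only bounds; here is a minimal usage sketch with illustrative key and bounds (the import path follows the module file in this diff, and IntegerRangeConstraint takes the same shape with integer bounds):

# Sketch: constrain a real-valued descriptor to 0 <= x < 5 (illustrative values).
from citrine.informatics.constraints.scalar_range_constraint import ScalarRangeConstraint

x_range = ScalarRangeConstraint(
    descriptor_key="x",
    lower_bound=0.0,
    upper_bound=5.0,
    lower_inclusive=True,
    upper_inclusive=False,
)
print(x_range)  # <ScalarRangeConstraint 'x'>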
Data source provides a polymorphic interface for specifying different kinds of data as the @@ -36,7 +37,12 @@ def __eq__(self, other): @classmethod def _subclass_list(self) -> List[Type[Serializable]]: - return [CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource] + return [ + CSVDataSource, + GemTableDataSource, + ExperimentDataSourceRef, + SnapshotDataSource, + ] @classmethod def get_type(cls, data) -> Type[Serializable]: @@ -57,7 +63,9 @@ def _data_source_type(self) -> str: def from_data_source_id(cls, data_source_id: str) -> "DataSource": """Build a DataSource from a datasource_id.""" terms = data_source_id.split("::") - res = next((x for x in cls._subclass_list() if x._data_source_type == terms[0]), None) + res = next( + (x for x in cls._subclass_list() if x._data_source_type == terms[0]), None + ) if res is None: raise ValueError(f"Unrecognized type: {terms[0]}") return res._data_source_id_builder(*terms[1:]) @@ -72,7 +80,7 @@ def to_data_source_id(self) -> str: """Generate the data_source_id for this DataSource.""" -class CSVDataSource(Serializable['CSVDataSource'], DataSource): +class CSVDataSource(Serializable["CSVDataSource"], DataSource): """A data source based on a CSV file stored on the data platform. Parameters @@ -89,22 +97,27 @@ class CSVDataSource(Serializable['CSVDataSource'], DataSource): """ - typ = properties.String('type', default='csv_data_source', deserializable=False) + typ = properties.String("type", default="csv_data_source", deserializable=False) file_link = properties.Object(FileLink, "file_link") column_definitions = properties.Mapping( - properties.String, properties.Object(Descriptor), "column_definitions") + properties.String, properties.Object(Descriptor), "column_definitions" + ) identifiers = properties.Optional(properties.List(properties.String), "identifiers") _data_source_type = "csv" - def __init__(self, - *, - file_link: FileLink, - column_definitions: Mapping[str, Descriptor], - identifiers: Optional[List[str]] = None): - warn("CSVDataSource is deprecated as of 3.28.0 and will be removed in 4.0.0. Please use " - "another type of data source, such as GemTableDataSource.", - category=DeprecationWarning) + def __init__( + self, + *, + file_link: FileLink, + column_definitions: Mapping[str, Descriptor], + identifiers: Optional[List[str]] = None, + ): + warn( + "CSVDataSource is deprecated as of 3.28.0 and will be removed in 4.0.0. 
Please use " + "another type of data source, such as GemTableDataSource.", + category=DeprecationWarning, + ) self.file_link = file_link self.column_definitions = column_definitions self.identifiers = identifiers @@ -112,20 +125,23 @@ def __init__(self, @classmethod def _data_source_id_builder(cls, *args) -> DataSource: # TODO Figure out how to populate the column definitions - warn("A CSVDataSource was derived from a data_source_id " - "but is missing its column_definitions and identities", - UserWarning) + warn( + "A CSVDataSource was derived from a data_source_id " + "but is missing its column_definitions and identities", + UserWarning, + ) return CSVDataSource( - file_link=FileLink(url=args[0], filename=args[1]), - column_definitions={} + file_link=FileLink(url=args[0], filename=args[1]), column_definitions={} ) def to_data_source_id(self) -> str: """Generate the data_source_id for this DataSource.""" - return f"{self._data_source_type}::{self.file_link.url}::{self.file_link.filename}" + return ( + f"{self._data_source_type}::{self.file_link.url}::{self.file_link.filename}" + ) -class GemTableDataSource(Serializable['GemTableDataSource'], DataSource): +class GemTableDataSource(Serializable["GemTableDataSource"], DataSource): """A data source based on a GEM Table hosted on the data platform. Parameters @@ -138,16 +154,15 @@ class GemTableDataSource(Serializable['GemTableDataSource'], DataSource): """ - typ = properties.String('type', default='hosted_table_data_source', deserializable=False) + typ = properties.String( + "type", default="hosted_table_data_source", deserializable=False + ) table_id = properties.UUID("table_id") table_version = properties.Integer("table_version") _data_source_type = "gemd" - def __init__(self, - *, - table_id: UUID, - table_version: Union[int, str]): + def __init__(self, *, table_id: UUID, table_version: Union[int, str]): self.table_id: UUID = table_id self.table_version: Union[int, str] = table_version @@ -172,7 +187,7 @@ def from_gemtable(cls, table: GemTable) -> "GemTableDataSource": return GemTableDataSource(table_id=table.uid, table_version=table.version) -class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSource): +class ExperimentDataSourceRef(Serializable["ExperimentDataSourceRef"], DataSource): """A reference to a data source based on an experiment result hosted on the data platform. Parameters @@ -182,7 +197,9 @@ class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSourc """ - typ = properties.String('type', default='experiments_data_source', deserializable=False) + typ = properties.String( + "type", default="experiments_data_source", deserializable=False + ) datasource_id = properties.UUID("datasource_id") _data_source_type = "experiments" @@ -199,7 +216,7 @@ def to_data_source_id(self) -> str: return f"{self._data_source_type}::{self.datasource_id}" -class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource): +class SnapshotDataSource(Serializable["SnapshotDataSource"], DataSource): """A reference to a data source based on a Snapshot on the data platform. 
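The from_data_source_id/to_data_source_id pair above round-trips a "::"-delimited id string. A sketch, assuming the GEM table form is "gemd::<table_id>::<table_version>" by analogy with the CSV form spelled out above (the UUID is made up):

from uuid import UUID
from citrine.informatics.data_sources import DataSource, GemTableDataSource

source = GemTableDataSource(
    table_id=UUID("12345678-1234-1234-1234-123456789012"),
    table_version=2,
)
ds_id = source.to_data_source_id()  # assumed "gemd::<table_id>::2"
rebuilt = DataSource.from_data_source_id(ds_id)
assert isinstance(rebuilt, GemTableDataSource)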
Parameters @@ -209,7 +226,9 @@ class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource): """ - typ = properties.String('type', default='snapshot_data_source', deserializable=False) + typ = properties.String( + "type", default="snapshot_data_source", deserializable=False + ) snapshot_id = properties.UUID("snapshot_id") _data_source_type = "snapshot" diff --git a/src/citrine/informatics/descriptors.py b/src/citrine/informatics/descriptors.py index 692dfe91e..6d28b5bd6 100644 --- a/src/citrine/informatics/descriptors.py +++ b/src/citrine/informatics/descriptors.py @@ -1,4 +1,5 @@ """Tools for working with Descriptors.""" + from typing import Type, Set, Union from gemd.enumeration.base_enumeration import BaseEnumeration @@ -7,14 +8,16 @@ from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization import properties -__all__ = ['Descriptor', - 'RealDescriptor', - 'IntegerDescriptor', - 'ChemicalFormulaDescriptor', - 'MolecularStructureDescriptor', - 'CategoricalDescriptor', - 'FormulationDescriptor', - 'FormulationKey'] +__all__ = [ + "Descriptor", + "RealDescriptor", + "IntegerDescriptor", + "ChemicalFormulaDescriptor", + "MolecularStructureDescriptor", + "CategoricalDescriptor", + "FormulationDescriptor", + "FormulationKey", +] class FormulationKey(BaseEnumeration): @@ -29,13 +32,13 @@ class FormulationKey(BaseEnumeration): FLAT = "Flat Formulation" -class Descriptor(PolymorphicSerializable['Descriptor']): +class Descriptor(PolymorphicSerializable["Descriptor"]): """A Descriptor describes the range of values that a quantity can take on. Abstract type that returns the proper type given a serialized dict. """ - key = properties.String('descriptor_key') + key = properties.String("descriptor_key") @classmethod def get_type(cls, data) -> Type[Serializable]: @@ -68,14 +71,17 @@ def _equals(self, other, attrs): [self.__getattribute__(key) for key in attrs] try: - return all([ - self.__getattribute__(key) == other.__getattribute__(key) for key in attrs - ]) + return all( + [ + self.__getattribute__(key) == other.__getattribute__(key) + for key in attrs + ] + ) except AttributeError: return False -class RealDescriptor(Serializable['RealDescriptor'], Descriptor): +class RealDescriptor(Serializable["RealDescriptor"], Descriptor): """A descriptor to hold real-valued numbers. 
Parameters @@ -91,20 +97,17 @@ class RealDescriptor(Serializable['RealDescriptor'], Descriptor): """ - lower_bound = properties.Float('lower_bound') - upper_bound = properties.Float('upper_bound') - units = properties.String('units', default='') - typ = properties.String('type', default='Real', deserializable=False) + lower_bound = properties.Float("lower_bound") + upper_bound = properties.Float("upper_bound") + units = properties.String("units", default="") + typ = properties.String("type", default="Real", deserializable=False) def __eq__(self, other): - return self._equals(other, ["key", "lower_bound", "upper_bound", "units", "typ"]) - - def __init__(self, - key: str, - *, - lower_bound: float, - upper_bound: float, - units: str): + return self._equals( + other, ["key", "lower_bound", "upper_bound", "units", "typ"] + ) + + def __init__(self, key: str, *, lower_bound: float, upper_bound: float, units: str): self.key: str = key self.lower_bound: float = lower_bound self.upper_bound: float = upper_bound @@ -115,10 +118,11 @@ def __str__(self): def __repr__(self): return "RealDescriptor({}, {}, {})".format( - self.key, self.lower_bound, self.upper_bound, self.units) + self.key, self.lower_bound, self.upper_bound, self.units + ) -class IntegerDescriptor(Serializable['IntegerDescriptor'], Descriptor): +class IntegerDescriptor(Serializable["IntegerDescriptor"], Descriptor): """[ALPHA] A descriptor to hold integer-valued numbers. Warning: IntegerDescriptors are not fully supported by the Citrine Platform web interface @@ -135,9 +139,9 @@ class IntegerDescriptor(Serializable['IntegerDescriptor'], Descriptor): """ - lower_bound = properties.Integer('lower_bound') - upper_bound = properties.Integer('upper_bound') - typ = properties.String('type', default='Integer', deserializable=False) + lower_bound = properties.Integer("lower_bound") + upper_bound = properties.Integer("upper_bound") + typ = properties.String("type", default="Integer", deserializable=False) def __eq__(self, other): return self._equals(other, ["key", "lower_bound", "upper_bound", "typ"]) @@ -151,10 +155,12 @@ def __str__(self): return "<IntegerDescriptor {!r}>".format(self.key) def __repr__(self): - return "IntegerDescriptor({}, {}, {})".format(self.key, self.lower_bound, self.upper_bound) + return "IntegerDescriptor({}, {}, {})".format( + self.key, self.lower_bound, self.upper_bound + ) -class ChemicalFormulaDescriptor(Serializable['ChemicalFormulaDescriptor'], Descriptor): +class ChemicalFormulaDescriptor(Serializable["ChemicalFormulaDescriptor"], Descriptor): """Captures domain-specific context about a stoichiometric chemical formula. Parameters @@ -164,7 +170,7 @@ class ChemicalFormulaDescriptor(Serializable['ChemicalFormulaDescr """ - typ = properties.String('type', default='Inorganic', deserializable=False) + typ = properties.String("type", default="Inorganic", deserializable=False) def __eq__(self, other): return self._equals(other, ["key", "typ"]) @@ -179,7 +185,9 @@ def __repr__(self): return "ChemicalFormulaDescriptor(key={})".format(self.key) -class MolecularStructureDescriptor(Serializable['MolecularStructureDescriptor'], Descriptor): +class MolecularStructureDescriptor( + Serializable["MolecularStructureDescriptor"], Descriptor +): """ Material descriptor for an organic molecule.
@@ -192,7 +200,7 @@ class MolecularStructureDescriptor(Serializable['MolecularStructureDescriptor'], """ - typ = properties.String('type', default='Organic', deserializable=False) + typ = properties.String("type", default="Organic", deserializable=False) def __eq__(self, other): return self._equals(other, ["key", "typ"]) @@ -207,7 +215,7 @@ def __repr__(self): return "MolecularStructureDescriptor(key={})".format(self.key) -class CategoricalDescriptor(Serializable['CategoricalDescriptor'], Descriptor): +class CategoricalDescriptor(Serializable["CategoricalDescriptor"], Descriptor): """A descriptor to hold categorical variables. An exhaustive list of categorical values may be supplied. @@ -221,8 +229,8 @@ class CategoricalDescriptor(Serializable['CategoricalDescriptor'], Descriptor): """ - typ = properties.String('type', default='Categorical', deserializable=False) - categories = properties.Set(properties.String, 'descriptor_values') + typ = properties.String("type", default="Categorical", deserializable=False) + categories = properties.Set(properties.String, "descriptor_values") def __eq__(self, other): return self._equals(other, ["key", "categories", "typ"]) @@ -238,10 +246,12 @@ def __str__(self): return "<CategoricalDescriptor {!r}>".format(self.key) def __repr__(self): - return "CategoricalDescriptor(key={}, categories={})".format(self.key, self.categories) + return "CategoricalDescriptor(key={}, categories={})".format( + self.key, self.categories + ) -class FormulationDescriptor(Serializable['FormulationDescriptor'], Descriptor): +class FormulationDescriptor(Serializable["FormulationDescriptor"], Descriptor): """A descriptor to hold formulations. Parameters @@ -254,7 +264,7 @@ class FormulationDescriptor(Serializable['FormulationDescriptor'], Descriptor): """ typ = properties.String( - 'type', default=FormulationKey.HIERARCHICAL.value, deserializable=False + "type", default=FormulationKey.HIERARCHICAL.value, deserializable=False ) def __init__(self, key: Union[FormulationKey, str]): diff --git a/src/citrine/informatics/design_candidate.py b/src/citrine/informatics/design_candidate.py index 94fbcd275..2561ed04b 100644 --- a/src/citrine/informatics/design_candidate.py +++ b/src/citrine/informatics/design_candidate.py @@ -6,26 +6,26 @@ __all__ = [ - 'DesignCandidate', - 'HierarchicalDesignCandidate', - 'DesignMaterial', - 'HierarchicalDesignMaterial', - 'SampleSearchSpaceResultCandidate', - 'DesignVariable', - 'MeanAndStd', - 'TopCategories', - 'Mixture', - 'ChemicalFormula', - 'MolecularStructure', + "DesignCandidate", + "HierarchicalDesignCandidate", + "DesignMaterial", + "HierarchicalDesignMaterial", + "SampleSearchSpaceResultCandidate", + "DesignVariable", + "MeanAndStd", + "TopCategories", + "Mixture", + "ChemicalFormula", + "MolecularStructure", ] class DesignCandidateComment(Serializable["DesignCandidateComment"]): - message = properties.String('message') + message = properties.String("message") """:str: the text of the comment""" - created_by = properties.UUID('created.user') + created_by = properties.UUID("created.user") """:UUID: id of the user who created the comment""" - create_time = properties.Datetime('created.time') + create_time = properties.Datetime("created.time") """:datetime: date and time at which the comment was created""" @@ -47,7 +47,7 @@ def get_type(cls, data) -> Type[Serializable]: "C": TopCategories, "M": Mixture, "F": ChemicalFormula, - "S": MolecularStructure + "S": MolecularStructure, }[data["type"]] @@ -57,11 +57,11 @@ class MeanAndStd(Serializable["MeanAndStd"], DesignVariable):
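For reference, the descriptor constructors above are used like this. RealDescriptor's keyword-only signature appears verbatim in the hunk; the CategoricalDescriptor call assumes its usual (key, categories) constructor, which this diff does not show:

from citrine.informatics.descriptors import CategoricalDescriptor, RealDescriptor

temperature = RealDescriptor(
    "temperature", lower_bound=0.0, upper_bound=500.0, units="degC"
)
solvent = CategoricalDescriptor("solvent", categories={"water", "ethanol"})
print(solvent)  # <CategoricalDescriptor 'solvent'>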
This does not imply that the distribution is Normal. """ - mean = properties.Float('m') + mean = properties.Float("m") """:float: mean of the continuous distribution""" - std = properties.Float('s') + std = properties.Float("s") """:float: standard deviation of the continuous distribution""" - typ = properties.String('type', default='R', deserializable=False) + typ = properties.String("type", default="R", deserializable=False) """:str: polymorphic type code""" def __init__(self, *, mean: float, std: float): @@ -77,9 +77,9 @@ class TopCategories(Serializable["CategoriesAndProbabilities"], DesignVariable): may have non-zero probabilities. """ - probabilities = properties.Mapping(properties.String, properties.Float, 'cp') + probabilities = properties.Mapping(properties.String, properties.Float, "cp") """:Dict[str, float]: mapping from category names to their probabilities""" - typ = properties.String('type', default='C', deserializable=False) + typ = properties.String("type", default="C", deserializable=False) """:str: polymorphic type code""" def __init__(self, *, probabilities: dict): @@ -94,11 +94,13 @@ class Mixture(Serializable["Mixture"], DesignVariable): truncation (but there may be rounding). """ - quantities = properties.Mapping(properties.String, properties.Float, 'q') + quantities = properties.Mapping(properties.String, properties.Float, "q") """:Dict[str, float]: mapping from ingredient identifiers to their quantities""" - labels = properties.Mapping(properties.String, properties.Set(properties.String), 'l') + labels = properties.Mapping( + properties.String, properties.Set(properties.String), "l" + ) """:Dict[str, Set[str]]: mapping from label identifiers to their associated ingredients""" - typ = properties.String('type', default='M', deserializable=False) + typ = properties.String("type", default="M", deserializable=False) """:str: polymorphic type code""" def __init__(self, *, quantities: dict, labels: Optional[dict] = None): @@ -110,9 +112,9 @@ def __init__(self, *, quantities: dict, labels: Optional[dict] = None): class ChemicalFormula(Serializable["ChemicalFormula"], DesignVariable): """Chemical formula as a string.""" - formula = properties.String('f') + formula = properties.String("f") """:str: chemical formula""" - typ = properties.String('type', default='F', deserializable=False) + typ = properties.String("type", default="F", deserializable=False) """:str: polymorphic type code""" def __init__(self, *, formula: str): @@ -123,9 +125,9 @@ def __init__(self, *, formula: str): class MolecularStructure(Serializable["MolecularStructure"], DesignVariable): """SMILES string representation of a molecular structure.""" - smiles = properties.String('s') + smiles = properties.String("s") """:str: SMILES string""" - typ = properties.String('type', default='S', deserializable=False) + typ = properties.String("type", default="S", deserializable=False) """:str: polymorphic type code""" def __init__(self, *, smiles: str): @@ -136,15 +138,21 @@ def __init__(self, *, smiles: str): class DesignMaterial(Serializable["DesignMaterial"]): """Description of the material that was designed, as a set of DesignVariables.""" - material_id = properties.UUID('identifiers.id') + material_id = properties.UUID("identifiers.id") """:UUID: unique internal Citrine id of the material""" - identifiers = properties.List(properties.String, 'identifiers.external', default=[]) + identifiers = properties.List(properties.String, "identifiers.external", default=[]) """:List[str]: globally unique identifiers assigned 
to the material""" - process_template = properties.Optional(properties.UUID, 'identifiers.process_template') + process_template = properties.Optional( + properties.UUID, "identifiers.process_template" + ) """:Optional[UUID]: GEMD process template that describes the process to create this material""" - material_template = properties.Optional(properties.UUID, 'identifiers.material_template') + material_template = properties.Optional( + properties.UUID, "identifiers.material_template" + ) """:Optional[UUID]: GEMD material template that describes this material""" - values = properties.Mapping(properties.String, properties.Object(DesignVariable), 'vars') + values = properties.Mapping( + properties.String, properties.Object(DesignVariable), "vars" + ) """:Dict[str, DesignVariable]: mapping from descriptor keys to the value for this material""" def __init__(self, *, values: dict): @@ -161,11 +169,13 @@ class HierarchicalDesignMaterial(Serializable["HierarchicalDesignMaterial"]): that associates each material (by Citrine ID) with the ingredients that comprise it. """ - root = properties.Object(DesignMaterial, 'terminal') + root = properties.Object(DesignMaterial, "terminal") """:DesignMaterial: root material containing features and predicted properties""" - sub_materials = properties.List(properties.Object(DesignMaterial), 'sub_materials') + sub_materials = properties.List(properties.Object(DesignMaterial), "sub_materials") """:List[DesignMaterial]: all other materials appearing in the history of the root""" - mixtures = properties.Mapping(properties.UUID, properties.Object(Mixture), 'mixtures') + mixtures = properties.Mapping( + properties.UUID, properties.Object(Mixture), "mixtures" + ) """:Dict[UUID, Mixture]: mapping from Citrine ID to components the material is composed of""" @@ -175,26 +185,28 @@ class DesignCandidate(Serializable["DesignCandidate"]): This class represents the candidate computed by a design execution. 
""" - uid = properties.UUID('id') + uid = properties.UUID("id") """:UUID: unique external Citrine id of the material""" - material_id = properties.UUID('material_id') + material_id = properties.UUID("material_id") """:UUID: unique internal Citrine id of the material""" - identifiers = properties.List(properties.String(), 'identifiers') + identifiers = properties.List(properties.String(), "identifiers") """:List[str]: globally unique identifiers assigned to the material""" - primary_score = properties.Float('primary_score') + primary_score = properties.Float("primary_score") """:float: numerical score describing how well the candidate satisfies the objectives and constraints (higher is better)""" - material = properties.Object(DesignMaterial, 'material') + material = properties.Object(DesignMaterial, "material") """:DesignMaterial: the material returned by the design workflow""" - name = properties.String('name') + name = properties.String("name") """:str: the name of the candidate""" - hidden = properties.Boolean('hidden') + hidden = properties.Boolean("hidden") """:str: whether the candidate is marked hidden""" - pinned_by = properties.Optional(properties.UUID, 'pinned.user') + pinned_by = properties.Optional(properties.UUID, "pinned.user") """:Optional[UUID]: id of the user who pinned the candidate, if it's been pinned""" - pinned_time = properties.Optional(properties.Datetime, 'pinned.time') + pinned_time = properties.Optional(properties.Datetime, "pinned.time") """:Optional[datetime]: date and time at which the candidate was pinned, if it's been pinned""" - comments = properties.List(properties.Object(DesignCandidateComment), 'comments', default=[]) + comments = properties.List( + properties.Object(DesignCandidateComment), "comments", default=[] + ) """:list[DesignCandidateComment]: the list of comments on the candidate, with metadata.""" @@ -204,9 +216,9 @@ class HierarchicalDesignCandidate(Serializable["HierarchicalDesignCandidate"]): This class represents the candidate computed by a design execution. """ - uid = properties.UUID('id') + uid = properties.UUID("id") """:UUID: unique external Citrine ID of the material""" - primary_score = properties.Float('primary_score') + primary_score = properties.Float("primary_score") """:float: numerical score describing how well the candidate satisfies the objectives and constraints (higher is better)""" rank = properties.Integer("rank") @@ -215,15 +227,17 @@ class HierarchicalDesignCandidate(Serializable["HierarchicalDesignCandidate"]): """:HierarchicalDesignMaterial: the material returned by the design workflow""" -class SampleSearchSpaceResultCandidate(Serializable["SampleSearchSpaceResultCandidate"]): +class SampleSearchSpaceResultCandidate( + Serializable["SampleSearchSpaceResultCandidate"] +): """A hierarchical candidate material generated by the Citrine Platform. This class represents the candidate computed by a design execution. 
""" - uid = properties.UUID('id') + uid = properties.UUID("id") """:UUID: unique external Citrine ID of the material""" - execution_uid = properties.UUID('id') + execution_uid = properties.UUID("id") """:UUID: unique external Citrine ID of the execution""" material = properties.Object(HierarchicalDesignMaterial, "material") """:HierarchicalDesignMaterial: the material returned by the design workflow""" diff --git a/src/citrine/informatics/design_spaces/data_source_design_space.py b/src/citrine/informatics/design_spaces/data_source_design_space.py index dc8ef4409..86b838a37 100644 --- a/src/citrine/informatics/design_spaces/data_source_design_space.py +++ b/src/citrine/informatics/design_spaces/data_source_design_space.py @@ -3,10 +3,10 @@ from citrine.informatics.data_sources import DataSource from citrine.informatics.design_spaces.design_space import DesignSpace -__all__ = ['DataSourceDesignSpace'] +__all__ = ["DataSourceDesignSpace"] -class DataSourceDesignSpace(EngineResource['DataSourceDesignSpace'], DesignSpace): +class DataSourceDesignSpace(EngineResource["DataSourceDesignSpace"], DesignSpace): """An enumeration of candidates stored in a data source. Parameters @@ -20,19 +20,16 @@ class DataSourceDesignSpace(EngineResource['DataSourceDesignSpace'], DesignSpace """ - data_source = properties.Object(DataSource, 'data.instance.data_source') + data_source = properties.Object(DataSource, "data.instance.data_source") - typ = properties.String('data.instance.type', default='DataSourceDesignSpace', - deserializable=False) + typ = properties.String( + "data.instance.type", default="DataSourceDesignSpace", deserializable=False + ) - def __init__(self, - name: str, - *, - description: str, - data_source: DataSource): + def __init__(self, name: str, *, description: str, data_source: DataSource): self.name: str = name self.description: str = description self.data_source: DataSource = data_source def __str__(self): - return ''.format(self.name) + return "".format(self.name) diff --git a/src/citrine/informatics/design_spaces/design_space.py b/src/citrine/informatics/design_spaces/design_space.py index 295d4094e..edffe5d18 100644 --- a/src/citrine/informatics/design_spaces/design_space.py +++ b/src/citrine/informatics/design_spaces/design_space.py @@ -1,4 +1,5 @@ """Tools for working with design spaces.""" + from typing import Optional, Type from uuid import UUID @@ -7,31 +8,34 @@ from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization.serializable import Serializable from citrine._session import Session -from citrine.resources.sample_design_space_execution import \ - SampleDesignSpaceExecutionCollection +from citrine.resources.sample_design_space_execution import ( + SampleDesignSpaceExecutionCollection, +) -__all__ = ['DesignSpace'] +__all__ = ["DesignSpace"] -class DesignSpace(PolymorphicSerializable['DesignSpace'], AsynchronousObject): +class DesignSpace(PolymorphicSerializable["DesignSpace"], AsynchronousObject): """A Citrine Design Space describes the set of materials that can be made. Abstract type that returns the proper type given a serialized dict. 
""" - uid = properties.Optional(properties.UUID, 'id', serializable=False) + uid = properties.Optional(properties.UUID, "id", serializable=False) """:Optional[UUID]: Citrine Platform unique identifier""" - name = properties.String('data.name') - description = properties.Optional(properties.String(), 'data.description') + name = properties.String("data.name") + description = properties.Optional(properties.String(), "data.description") - locked_by = properties.Optional(properties.UUID, 'metadata.locked.user', - serializable=False) + locked_by = properties.Optional( + properties.UUID, "metadata.locked.user", serializable=False + ) """:Optional[UUID]: id of the user whose action cause the design space to be locked, if it is locked""" - lock_time = properties.Optional(properties.Datetime, 'metadata.locked.time', - serializable=False) + lock_time = properties.Optional( + properties.Datetime, "metadata.locked.time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was locked, if it is locked""" @@ -45,7 +49,7 @@ def wrap_instance(subspace_data: dict) -> dict: "data": { "name": subspace_data.get("name", ""), "description": subspace_data.get("description", ""), - "instance": subspace_data + "instance": subspace_data, } } @@ -66,13 +70,13 @@ def get_type(cls, data) -> Type[Serializable]: from .hierarchical_design_space import HierarchicalDesignSpace return { - 'Univariate': ProductDesignSpace, - 'ProductDesignSpace': ProductDesignSpace, - 'EnumeratedDesignSpace': EnumeratedDesignSpace, - 'FormulationDesignSpace': FormulationDesignSpace, - 'DataSourceDesignSpace': DataSourceDesignSpace, - 'HierarchicalDesignSpace': HierarchicalDesignSpace - }[data['data']['instance']['type']] + "Univariate": ProductDesignSpace, + "ProductDesignSpace": ProductDesignSpace, + "EnumeratedDesignSpace": EnumeratedDesignSpace, + "FormulationDesignSpace": FormulationDesignSpace, + "DataSourceDesignSpace": DataSourceDesignSpace, + "HierarchicalDesignSpace": HierarchicalDesignSpace, + }[data["data"]["instance"]["type"]] @property def is_locked(self) -> bool: diff --git a/src/citrine/informatics/design_spaces/design_space_settings.py b/src/citrine/informatics/design_spaces/design_space_settings.py index 3c28ef0bd..b1f29eb6f 100644 --- a/src/citrine/informatics/design_spaces/design_space_settings.py +++ b/src/citrine/informatics/design_spaces/design_space_settings.py @@ -17,8 +17,8 @@ class DefaultDesignSpaceMode(BaseEnumeration): * HIERARCHICAL results in a hierarchical design space resembling the shape of training data """ - ATTRIBUTE = 'ATTRIBUTE' - HIERARCHICAL = 'HIERARCHICAL' + ATTRIBUTE = "ATTRIBUTE" + HIERARCHICAL = "HIERARCHICAL" class DesignSpaceSettings(Resource["DesignSpaceSettings"]): @@ -27,10 +27,12 @@ class DesignSpaceSettings(Resource["DesignSpaceSettings"]): predictor_id = properties.UUID("predictor_id") predictor_version = properties.Optional( properties.Union([properties.Integer(), properties.String()]), - 'predictor_version' + "predictor_version", ) mode = properties.Optional(properties.Enumeration(DefaultDesignSpaceMode), "mode") - exclude_intermediates = properties.Optional(properties.Boolean(), "exclude_intermediates") + exclude_intermediates = properties.Optional( + properties.Boolean(), "exclude_intermediates" + ) include_ingredient_fraction_constraints = properties.Optional( properties.Boolean(), "include_ingredient_fraction_constraints" ) @@ -44,21 +46,25 @@ class DesignSpaceSettings(Resource["DesignSpaceSettings"]): properties.Boolean(), 
"include_parameter_constraints" ) - def __init__(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Optional[Union[int, str]] = None, - mode: Optional[DefaultDesignSpaceMode] = None, - exclude_intermediates: Optional[bool] = None, - include_ingredient_fraction_constraints: Optional[bool] = None, - include_label_fraction_constraints: Optional[bool] = None, - include_label_count_constraints: Optional[bool] = None, - include_parameter_constraints: Optional[bool] = None): + def __init__( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, + mode: Optional[DefaultDesignSpaceMode] = None, + exclude_intermediates: Optional[bool] = None, + include_ingredient_fraction_constraints: Optional[bool] = None, + include_label_fraction_constraints: Optional[bool] = None, + include_label_count_constraints: Optional[bool] = None, + include_parameter_constraints: Optional[bool] = None, + ): self.predictor_id = predictor_id self.predictor_version = predictor_version self.mode = mode self.exclude_intermediates = exclude_intermediates - self.include_ingredient_fraction_constraints = include_ingredient_fraction_constraints + self.include_ingredient_fraction_constraints = ( + include_ingredient_fraction_constraints + ) self.include_label_fraction_constraints = include_label_fraction_constraints self.include_label_count_constraints = include_label_count_constraints self.include_parameter_constraints = include_parameter_constraints diff --git a/src/citrine/informatics/design_spaces/enumerated_design_space.py b/src/citrine/informatics/design_spaces/enumerated_design_space.py index 5b5115cfa..e4712a929 100644 --- a/src/citrine/informatics/design_spaces/enumerated_design_space.py +++ b/src/citrine/informatics/design_spaces/enumerated_design_space.py @@ -6,10 +6,10 @@ from citrine.informatics.descriptors import Descriptor from citrine.informatics.design_spaces.design_space import DesignSpace -__all__ = ['EnumeratedDesignSpace'] +__all__ = ["EnumeratedDesignSpace"] -class EnumeratedDesignSpace(EngineResource['EnumeratedDesignSpace'], DesignSpace): +class EnumeratedDesignSpace(EngineResource["EnumeratedDesignSpace"], DesignSpace): """An explicit enumeration of candidate materials to score. 
Enumerated design spaces are intended to capture small spaces with fewer than @@ -29,29 +29,38 @@ class EnumeratedDesignSpace(EngineResource['EnumeratedDesignSpace'], DesignSpace """ - descriptors = properties.List(properties.Object(Descriptor), 'data.instance.descriptors') - _data = properties.List(properties.Mapping(properties.String, - properties.Union([properties.String(), - properties.Integer(), - properties.Float()])), - 'data.instance.data') + descriptors = properties.List( + properties.Object(Descriptor), "data.instance.descriptors" + ) + _data = properties.List( + properties.Mapping( + properties.String, + properties.Union( + [properties.String(), properties.Integer(), properties.Float()] + ), + ), + "data.instance.data", + ) - typ = properties.String('data.instance.type', default='EnumeratedDesignSpace', - deserializable=False) + typ = properties.String( + "data.instance.type", default="EnumeratedDesignSpace", deserializable=False + ) - def __init__(self, - name: str, - *, - description: str, - descriptors: List[Descriptor], - data: List[Mapping[str, Union[int, float, str]]]): + def __init__( + self, + name: str, + *, + description: str, + descriptors: List[Descriptor], + data: List[Mapping[str, Union[int, float, str]]], + ): self.name: str = name self.description: str = description self.descriptors: List[Descriptor] = descriptors self.data: List[Mapping[str, Union[int, float, str]]] = data def __str__(self): - return '<EnumeratedDesignSpace {!r}>'.format(self.name) + return "<EnumeratedDesignSpace {!r}>".format(self.name) @property def data(self) -> List[Mapping[str, Union[int, float, str]]]: @@ -63,7 +72,9 @@ def data(self, value: List[Mapping[str, Union[int, float, str]]]): for item in value: for el in item.values(): if isinstance(el, (int, float)): - warn("Providing numeric data values is deprecated as of 3.4.7, and will be " - "dropped in 4.0.0. Please use strings instead.", - DeprecationWarning) + warn( + "Providing numeric data values is deprecated as of 3.4.7, and will be " + "dropped in 4.0.0. Please use strings instead.", + DeprecationWarning, + ) self._data = value diff --git a/src/citrine/informatics/design_spaces/formulation_design_space.py b/src/citrine/informatics/design_spaces/formulation_design_space.py index a77e65b2d..85c20914f 100644 --- a/src/citrine/informatics/design_spaces/formulation_design_space.py +++ b/src/citrine/informatics/design_spaces/formulation_design_space.py @@ -6,10 +6,10 @@ from citrine.informatics.descriptors import FormulationDescriptor from citrine.informatics.design_spaces.design_space import DesignSpace -__all__ = ['FormulationDesignSpace'] +__all__ = ["FormulationDesignSpace"] -class FormulationDesignSpace(EngineResource['FormulationDesignSpace'], DesignSpace): +class FormulationDesignSpace(EngineResource["FormulationDesignSpace"], DesignSpace): """Design space composed of mixtures of ingredients.
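Given the deprecation warning just added to the EnumeratedDesignSpace.data setter, candidate values should be passed as strings. An illustrative construction:

from citrine.informatics.descriptors import RealDescriptor
from citrine.informatics.design_spaces.enumerated_design_space import (
    EnumeratedDesignSpace,
)

x = RealDescriptor("x", lower_bound=0.0, upper_bound=1.0, units="")
grid = EnumeratedDesignSpace(
    "tiny grid",
    description="three explicit candidates",
    descriptors=[x],
    data=[{"x": "0.0"}, {"x": "0.5"}, {"x": "1.0"}],  # strings, not numbers
)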
Parameters @@ -37,32 +37,33 @@ class FormulationDesignSpace(EngineResource['FormulationDesignSpa """ formulation_descriptor = properties.Object( - FormulationDescriptor, - 'data.instance.formulation_descriptor' + FormulationDescriptor, "data.instance.formulation_descriptor" ) - ingredients = properties.Set(properties.String, 'data.instance.ingredients') - labels = properties.Optional(properties.Mapping( - properties.String, - properties.Set(properties.String) - ), 'data.instance.labels') - constraints = properties.Set(properties.Object(Constraint), 'data.instance.constraints') - resolution = properties.Float('data.instance.resolution') + ingredients = properties.Set(properties.String, "data.instance.ingredients") + labels = properties.Optional( + properties.Mapping(properties.String, properties.Set(properties.String)), + "data.instance.labels", + ) + constraints = properties.Set( + properties.Object(Constraint), "data.instance.constraints" + ) + resolution = properties.Float("data.instance.resolution") typ = properties.String( - 'data.instance.type', - default='FormulationDesignSpace', - deserializable=False + "data.instance.type", default="FormulationDesignSpace", deserializable=False ) - def __init__(self, - name: str, - *, - description: str, - formulation_descriptor: FormulationDescriptor, - ingredients: Set[str], - constraints: Set[Constraint], - labels: Optional[Mapping[str, Set[str]]] = None, - resolution: float = 0.0001): + def __init__( + self, + name: str, + *, + description: str, + formulation_descriptor: FormulationDescriptor, + ingredients: Set[str], + constraints: Set[Constraint], + labels: Optional[Mapping[str, Set[str]]] = None, + resolution: float = 0.0001, + ): self.name: str = name self.description: str = description self.formulation_descriptor: FormulationDescriptor = formulation_descriptor @@ -72,4 +73,4 @@ def __init__(self, self.resolution: float = resolution def __str__(self): - return '<FormulationDesignSpace {!r}>'.format(self.name) + return "<FormulationDesignSpace {!r}>".format(self.name) diff --git a/src/citrine/informatics/design_spaces/hierarchical_design_space.py b/src/citrine/informatics/design_spaces/hierarchical_design_space.py index 205441820..122c8bc72 100644 --- a/src/citrine/informatics/design_spaces/hierarchical_design_space.py +++ b/src/citrine/informatics/design_spaces/hierarchical_design_space.py @@ -10,11 +10,7 @@ from citrine.informatics.design_spaces.design_space import DesignSpace from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings -__all__ = [ - "TemplateLink", - "MaterialNodeDefinition", - "HierarchicalDesignSpace" -] +__all__ = ["TemplateLink", "MaterialNodeDefinition", "HierarchicalDesignSpace"] class TemplateLink(Serializable["TemplateLink"]): @@ -35,16 +31,20 @@ class TemplateLink(Serializable["TemplateLink"]): material_template = properties.UUID("material_template") process_template = properties.UUID("process_template") - material_template_name = properties.Optional(properties.String, "material_template_name") - process_template_name = properties.Optional(properties.String, "process_template_name") + material_template_name = properties.Optional( + properties.String, "material_template_name" + ) + process_template_name = properties.Optional( + properties.String, "process_template_name" + ) def __init__( - self, - *, - material_template: UUID, - process_template: UUID, - material_template_name: Optional[str] = None, - process_template_name: Optional[str] = None + self, + *, + material_template: UUID, + process_template: UUID, +
material_template_name: Optional[str] = None, + process_template_name: Optional[str] = None, ): self.material_template: UUID = material_template self.process_template: UUID = process_template @@ -87,19 +87,21 @@ class MaterialNodeDefinition(Serializable["MaterialNodeDefinition"]): display_name = properties.Optional(properties.String, "display_name") def __init__( - self, - *, - name: str, - scope: Optional[str] = None, - attributes: Optional[List[Dimension]] = None, - formulation_subspace: Optional[FormulationDesignSpace] = None, - template_link: Optional[TemplateLink] = None, - display_name: Optional[str] = None + self, + *, + name: str, + scope: Optional[str] = None, + attributes: Optional[List[Dimension]] = None, + formulation_subspace: Optional[FormulationDesignSpace] = None, + template_link: Optional[TemplateLink] = None, + display_name: Optional[str] = None, ): self.name = name self.scope: Optional[str] = scope self.attributes = attributes or list() - self.formulation_subspace: Optional[FormulationDesignSpace] = formulation_subspace + self.formulation_subspace: Optional[FormulationDesignSpace] = ( + formulation_subspace + ) self.template_link: Optional[TemplateLink] = template_link self.display_name: Optional[str] = display_name @@ -151,7 +153,9 @@ class HierarchicalDesignSpace(EngineResource["HierarchicalDesignSpace"], DesignS """ - _settings = properties.Optional(properties.Object(DesignSpaceSettings), "metadata.settings") + _settings = properties.Optional( + properties.Object(DesignSpaceSettings), "metadata.settings" + ) root = properties.Object(MaterialNodeDefinition, "data.instance.root") subspaces = properties.List( @@ -165,13 +169,13 @@ class HierarchicalDesignSpace(EngineResource["HierarchicalDesignSpace"], DesignS ) def __init__( - self, - name: str, - *, - description: str, - root: MaterialNodeDefinition, - subspaces: Optional[List[MaterialNodeDefinition]] = None, - data_sources: Optional[List[DataSource]] = None + self, + name: str, + *, + description: str, + root: MaterialNodeDefinition, + subspaces: Optional[List[MaterialNodeDefinition]] = None, + data_sources: Optional[List[DataSource]] = None, ): self.name: str = name self.description: str = description @@ -189,8 +193,7 @@ def _post_dump(self, data: dict) -> dict: data["instance"]["root"] = self.__unwrap_node(root_node) data["instance"]["subspaces"] = [ - self.__unwrap_node(sub_node) - for sub_node in data['instance']['subspaces'] + self.__unwrap_node(sub_node) for sub_node in data["instance"]["subspaces"] ] return data @@ -200,26 +203,29 @@ def _pre_build(cls, data: dict) -> dict: data["data"]["instance"]["root"] = cls.__wrap_node(root_node) data["data"]["instance"]["subspaces"] = [ - cls.__wrap_node(sub_node) for sub_node in data['data']['instance']['subspaces'] + cls.__wrap_node(sub_node) + for sub_node in data["data"]["instance"]["subspaces"] ] return data @staticmethod def __wrap_node(node_data: dict) -> dict: - formulation_subspace = node_data.pop('formulation', None) + formulation_subspace = node_data.pop("formulation", None) if formulation_subspace: - node_data['formulation'] = DesignSpace.wrap_instance(formulation_subspace) + node_data["formulation"] = DesignSpace.wrap_instance(formulation_subspace) return node_data @staticmethod def __unwrap_node(node_data: dict) -> dict: - formulation_subspace = node_data.pop('formulation', None) + formulation_subspace = node_data.pop("formulation", None) if formulation_subspace: - node_data['formulation'] = formulation_subspace['data']['instance'] - 
node_data['formulation']['name'] = formulation_subspace['data']['name'] - node_data['formulation']['description'] = formulation_subspace['data']['description'] + node_data["formulation"] = formulation_subspace["data"]["instance"] + node_data["formulation"]["name"] = formulation_subspace["data"]["name"] + node_data["formulation"]["description"] = formulation_subspace["data"][ + "description" + ] return node_data def __repr__(self): - return f'<HierarchicalDesignSpace {self.name!r}>' + return f"<HierarchicalDesignSpace {self.name!r}>" diff --git a/src/citrine/informatics/design_spaces/product_design_space.py b/src/citrine/informatics/design_spaces/product_design_space.py index d52f6a640..b5bcfb1a3 100644 --- a/src/citrine/informatics/design_spaces/product_design_space.py +++ b/src/citrine/informatics/design_spaces/product_design_space.py @@ -7,10 +7,10 @@ from citrine.informatics.design_spaces.design_space_settings import DesignSpaceSettings from citrine.informatics.dimensions import Dimension -__all__ = ['ProductDesignSpace'] +__all__ = ["ProductDesignSpace"] -class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): +class ProductDesignSpace(EngineResource["ProductDesignSpace"], DesignSpace): """A Cartesian product of design spaces. Factors can be other design spaces and/or univariate dimensions. @@ -29,23 +29,29 @@ class ProductDesignSpace(EngineResource['ProductDesignSpace'], DesignSpace): """ - _settings = properties.Optional(properties.Object(DesignSpaceSettings), "metadata.settings") + _settings = properties.Optional( + properties.Object(DesignSpaceSettings), "metadata.settings" + ) - subspaces = properties.List(properties.Object(DesignSpace), 'data.instance.subspaces', - default=[]) + subspaces = properties.List( + properties.Object(DesignSpace), "data.instance.subspaces", default=[] + ) dimensions = properties.Optional( - properties.List(properties.Object(Dimension)), 'data.instance.dimensions' + properties.List(properties.Object(Dimension)), "data.instance.dimensions" ) - typ = properties.String('data.instance.type', default='ProductDesignSpace', - deserializable=False) + typ = properties.String( + "data.instance.type", default="ProductDesignSpace", deserializable=False + ) - def __init__(self, - name: str, - *, - description: str, - subspaces: Optional[List[Union[UUID, DesignSpace]]] = None, - dimensions: Optional[List[Dimension]] = None): + def __init__( + self, + name: str, + *, + description: str, + subspaces: Optional[List[Union[UUID, DesignSpace]]] = None, + dimensions: Optional[List[Dimension]] = None, + ): self.name: str = name self.description: str = description self.subspaces: List[Union[UUID, DesignSpace]] = subspaces or [] @@ -57,18 +63,20 @@ def _post_dump(self, data: dict) -> dict: if self._settings: data["settings"] = self._settings.dump() - for i, subspace in enumerate(data['instance']['subspaces']): + for i, subspace in enumerate(data["instance"]["subspaces"]): if isinstance(subspace, dict): # embedded design spaces are not modules, so only serialize their config - data['instance']['subspaces'][i] = subspace['instance'] + data["instance"]["subspaces"][i] = subspace["instance"] return data @classmethod def _pre_build(cls, data: dict) -> dict: - for i, subspace_data in enumerate(data['data']['instance']['subspaces']): + for i, subspace_data in enumerate(data["data"]["instance"]["subspaces"]): if isinstance(subspace_data, dict): - data['data']['instance']['subspaces'][i] = DesignSpace.wrap_instance(subspace_data) + data["data"]["instance"]["subspaces"][i] = DesignSpace.wrap_instance( + subspace_data + ) return data def
__str__(self): - return '<ProductDesignSpace {!r}>'.format(self.name) + return "<ProductDesignSpace {!r}>".format(self.name) diff --git a/src/citrine/informatics/design_spaces/sample_design_space.py b/src/citrine/informatics/design_spaces/sample_design_space.py index 64c22a9a9..e7832bbac 100644 --- a/src/citrine/informatics/design_spaces/sample_design_space.py +++ b/src/citrine/informatics/design_spaces/sample_design_space.py @@ -2,7 +2,7 @@ from citrine._serialization.serializable import Serializable -class SampleDesignSpaceInput(Serializable['SampleDesignSpaceInput']): +class SampleDesignSpaceInput(Serializable["SampleDesignSpaceInput"]): """A Citrine Sample Design Space Execution Input. Parameters diff --git a/src/citrine/informatics/dimensions.py b/src/citrine/informatics/dimensions.py index db73e50aa..6737ac1a4 100644 --- a/src/citrine/informatics/dimensions.py +++ b/src/citrine/informatics/dimensions.py @@ -1,15 +1,25 @@ """Tools for working with Dimensions.""" + from typing import Optional, Type, List from citrine._serialization import properties from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization.serializable import Serializable -from citrine.informatics.descriptors import Descriptor, RealDescriptor, IntegerDescriptor +from citrine.informatics.descriptors import ( + Descriptor, + RealDescriptor, + IntegerDescriptor, +) -__all__ = ['Dimension', 'ContinuousDimension', 'IntegerDimension', 'EnumeratedDimension'] +__all__ = [ + "Dimension", + "ContinuousDimension", + "IntegerDimension", + "EnumeratedDimension", +] -class Dimension(PolymorphicSerializable['Dimension']): +class Dimension(PolymorphicSerializable["Dimension"]): """A Dimension describes the values that a quantity can take in the context of a design space. Abstract type that returns the proper type given a serialized dict. @@ -20,13 +30,13 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" return { - 'ContinuousDimension': ContinuousDimension, - 'IntegerDimension': IntegerDimension, - 'EnumeratedDimension': EnumeratedDimension - }[data['type']] + "ContinuousDimension": ContinuousDimension, + "IntegerDimension": IntegerDimension, + "EnumeratedDimension": EnumeratedDimension, + }[data["type"]] -class ContinuousDimension(Serializable['ContinuousDimension'], Dimension): +class ContinuousDimension(Serializable["ContinuousDimension"], Dimension): """A continuous, real-valued dimension.
Parameters @@ -40,21 +50,28 @@ class ContinuousDimension(Serializable['ContinuousDimension'], Dimension): """ - descriptor = properties.Object(RealDescriptor, 'descriptor') - lower_bound = properties.Float('lower_bound') - upper_bound = properties.Float('upper_bound') - typ = properties.String('type', default='ContinuousDimension', deserializable=False) - - def __init__(self, - descriptor: RealDescriptor, *, - lower_bound: Optional[float] = None, - upper_bound: Optional[float] = None): + descriptor = properties.Object(RealDescriptor, "descriptor") + lower_bound = properties.Float("lower_bound") + upper_bound = properties.Float("upper_bound") + typ = properties.String("type", default="ContinuousDimension", deserializable=False) + + def __init__( + self, + descriptor: RealDescriptor, + *, + lower_bound: Optional[float] = None, + upper_bound: Optional[float] = None, + ): self.descriptor: RealDescriptor = descriptor - self.lower_bound = lower_bound if lower_bound is not None else descriptor.lower_bound - self.upper_bound = upper_bound if upper_bound is not None else descriptor.upper_bound + self.lower_bound = ( + lower_bound if lower_bound is not None else descriptor.lower_bound + ) + self.upper_bound = ( + upper_bound if upper_bound is not None else descriptor.upper_bound + ) -class IntegerDimension(Serializable['IntegerDimension'], Dimension): +class IntegerDimension(Serializable["IntegerDimension"], Dimension): """An integer-valued dimension with inclusive lower and upper bounds. Parameters @@ -68,21 +85,28 @@ class IntegerDimension(Serializable['IntegerDimension'], Dimension): """ - descriptor = properties.Object(IntegerDescriptor, 'descriptor') - lower_bound = properties.Integer('lower_bound') - upper_bound = properties.Integer('upper_bound') - typ = properties.String('type', default='IntegerDimension', deserializable=False) - - def __init__(self, - descriptor: IntegerDescriptor, *, - lower_bound: Optional[int] = None, - upper_bound: Optional[int] = None): + descriptor = properties.Object(IntegerDescriptor, "descriptor") + lower_bound = properties.Integer("lower_bound") + upper_bound = properties.Integer("upper_bound") + typ = properties.String("type", default="IntegerDimension", deserializable=False) + + def __init__( + self, + descriptor: IntegerDescriptor, + *, + lower_bound: Optional[int] = None, + upper_bound: Optional[int] = None, + ): self.descriptor: IntegerDescriptor = descriptor - self.lower_bound = lower_bound if lower_bound is not None else descriptor.lower_bound - self.upper_bound = upper_bound if upper_bound is not None else descriptor.upper_bound + self.lower_bound = ( + lower_bound if lower_bound is not None else descriptor.lower_bound + ) + self.upper_bound = ( + upper_bound if upper_bound is not None else descriptor.upper_bound + ) -class EnumeratedDimension(Serializable['EnumeratedDimension'], Dimension): +class EnumeratedDimension(Serializable["EnumeratedDimension"], Dimension): """A finite, enumerated dimension. 
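As the __init__ bodies above show, ContinuousDimension and IntegerDimension fall back to the descriptor's own bounds when none are given; a sketch with illustrative values:

from citrine.informatics.descriptors import RealDescriptor
from citrine.informatics.dimensions import ContinuousDimension

temp = RealDescriptor("temperature", lower_bound=0.0, upper_bound=500.0, units="degC")
full_range = ContinuousDimension(temp)  # bounds default to 0.0 and 500.0
narrow = ContinuousDimension(temp, lower_bound=20.0, upper_bound=25.0)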
Parameters @@ -94,12 +118,10 @@ class EnumeratedDimension(Serializable['EnumeratedDimension'], Dimension): """ - descriptor = properties.Object(Descriptor, 'descriptor') - values = properties.List(properties.String(), 'list') - typ = properties.String('type', default='EnumeratedDimension', deserializable=False) + descriptor = properties.Object(Descriptor, "descriptor") + values = properties.List(properties.String(), "list") + typ = properties.String("type", default="EnumeratedDimension", deserializable=False) - def __init__(self, - descriptor: Descriptor, *, - values: List[str]): + def __init__(self, descriptor: Descriptor, *, values: List[str]): self.descriptor: Descriptor = descriptor self.values: List[str] = values diff --git a/src/citrine/informatics/executions/design_execution.py b/src/citrine/informatics/executions/design_execution.py index 0511cd06f..89262bbc0 100644 --- a/src/citrine/informatics/executions/design_execution.py +++ b/src/citrine/informatics/executions/design_execution.py @@ -6,7 +6,10 @@ from citrine._serialization import properties from citrine._utils.functions import format_escaped_url from citrine.informatics.descriptors import Descriptor -from citrine.informatics.design_candidate import DesignCandidate, HierarchicalDesignCandidate +from citrine.informatics.design_candidate import ( + DesignCandidate, + HierarchicalDesignCandidate, +) from citrine.informatics.predict_request import PredictRequest from citrine.informatics.scores import Score from citrine.informatics.executions.execution import Execution @@ -21,61 +24,75 @@ class DesignExecution(Resource["DesignExecution"], Execution): """ _paginator: Paginator = Paginator() - _collection_key = 'response' - workflow_id = properties.UUID('workflow_id', serializable=False) + _collection_key = "response" + workflow_id = properties.UUID("workflow_id", serializable=False) """:UUID: Unique identifier of the workflow that was executed""" version_number = properties.Integer("version_number", serializable=False) """:int: Integer identifier that increases each time the workflow is executed. 
The first execution has version_number = 1.""" - score = properties.Object(Score, 'score') + score = properties.Object(Score, "score") """:Score: score by which this execution was evaluated""" - descriptors = properties.List(properties.Object(Descriptor), 'descriptors') + descriptors = properties.List(properties.Object(Descriptor), "descriptors") """:List[Descriptor]: all of the descriptors in the candidates generated by this execution""" def _path(self): return format_escaped_url( - '/projects/{project_id}/design-workflows/{workflow_id}/executions/{execution_id}', + "/projects/{project_id}/design-workflows/{workflow_id}/executions/{execution_id}", project_id=self.project_id, workflow_id=self.workflow_id, - execution_id=self.uid + execution_id=self.uid, ) @classmethod - def _build_candidates(cls, subset_collection: Iterable[dict]) -> Iterable[DesignCandidate]: + def _build_candidates( + cls, subset_collection: Iterable[dict] + ) -> Iterable[DesignCandidate]: for candidate in subset_collection: yield DesignCandidate.build(candidate) def candidates(self, *, per_page: int = 100) -> Iterable[DesignCandidate]: """Fetch the Design Candidates for the particular execution, paginated.""" - path = self._path() + '/candidates' + path = self._path() + "/candidates" - fetcher = partial(self._fetch_page, path=path, fetch_func=self._session.get_resource) + fetcher = partial( + self._fetch_page, path=path, fetch_func=self._session.get_resource + ) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_candidates, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_candidates, + per_page=per_page, + ) @classmethod def _build_hierarchical_candidates( - cls, subset_collection: Iterable[dict]) -> Iterable[HierarchicalDesignCandidate]: + cls, subset_collection: Iterable[dict] + ) -> Iterable[HierarchicalDesignCandidate]: for candidate in subset_collection: yield HierarchicalDesignCandidate.build(candidate) - def hierarchical_candidates(self, *, per_page: int = 100) -> Iterable[DesignCandidate]: + def hierarchical_candidates( + self, *, per_page: int = 100 + ) -> Iterable[DesignCandidate]: """Fetch the Design Candidates for the particular execution, paginated.""" - path = self._path() + '/candidate-histories' + path = self._path() + "/candidate-histories" - fetcher = partial(self._fetch_page, path=path, fetch_func=self._session.get_resource) + fetcher = partial( + self._fetch_page, path=path, fetch_func=self._session.get_resource + ) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_hierarchical_candidates, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_hierarchical_candidates, + per_page=per_page, + ) - def predict(self, - predict_request: PredictRequest) -> DesignCandidate: + def predict(self, predict_request: PredictRequest) -> DesignCandidate: """Invoke a prediction on a design candidate.""" - path = self._path() + '/predict' + path = self._path() + "/predict" - res = self._session.post_resource(path, predict_request.dump(), version=self._api_version) + res = self._session.post_resource( + path, predict_request.dump(), version=self._api_version + ) return DesignCandidate.build(res) diff --git a/src/citrine/informatics/executions/execution.py b/src/citrine/informatics/executions/execution.py index ea873f9eb..5eeb86d99 100644 --- a/src/citrine/informatics/executions/execution.py +++ 
b/src/citrine/informatics/executions/execution.py @@ -17,37 +17,44 @@ class Execution(Pageable, AsynchronousObject, ABC): """ _paginator: Paginator = Paginator() - _collection_key = 'response' + _collection_key = "response" _in_progress_statuses = ["INPROGRESS"] _succeeded_statuses = ["SUCCEEDED"] _failed_statuses = ["FAILED"] _session: Optional[Session] = None project_id: Optional[UUID] = None - uid: UUID = properties.UUID('id', serializable=False) + uid: UUID = properties.UUID("id", serializable=False) """:UUID: Unique identifier of the execution""" - status = properties.Optional(properties.String(), 'status', serializable=False) + status = properties.Optional(properties.String(), "status", serializable=False) """:Optional[str]: short description of the execution's status""" status_description = properties.Optional( - properties.String(), 'status_description', serializable=False) + properties.String(), "status_description", serializable=False + ) """:Optional[str]: more detailed description of the execution's status""" status_detail = properties.List( - properties.Object(StatusDetail), 'status_detail', default=[], serializable=False + properties.Object(StatusDetail), "status_detail", default=[], serializable=False ) """:List[StatusDetail]: a list of structured status info, containing the message and level""" - created_by = properties.Optional(properties.UUID, 'created_by', serializable=False) + created_by = properties.Optional(properties.UUID, "created_by", serializable=False) """:Optional[UUID]: id of the user who created the resource""" - updated_by = properties.Optional(properties.UUID, 'updated_by', serializable=False) + updated_by = properties.Optional(properties.UUID, "updated_by", serializable=False) """:Optional[UUID]: id of the user who most recently updated the resource, if it has been updated""" - create_time = properties.Optional(properties.Datetime, 'create_time', serializable=False) + create_time = properties.Optional( + properties.Datetime, "create_time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was created""" - update_time = properties.Optional(properties.Datetime, 'update_time', serializable=False) + update_time = properties.Optional( + properties.Datetime, "update_time", serializable=False + ) """:Optional[datetime]: date and time at which the resource was most recently updated, if it has been updated""" def __str__(self): - return f'<{self.__class__.__name__} {str(self.uid)!r}>' + return f"<{self.__class__.__name__} {str(self.uid)!r}>" def _path(self): - raise NotImplementedError("Subclasses must implement the _path method") # pragma: no cover + raise NotImplementedError( + "Subclasses must implement the _path method" + ) # pragma: no cover diff --git a/src/citrine/informatics/executions/generative_design_execution.py b/src/citrine/informatics/executions/generative_design_execution.py index 694bea7a0..22aa4416c 100644 --- a/src/citrine/informatics/executions/generative_design_execution.py +++ b/src/citrine/informatics/executions/generative_design_execution.py @@ -18,7 +18,7 @@ class GenerativeDesignExecution(Resource["GenerativeDesignExecution"], Execution def _path(self): return format_escaped_url( - '/projects/{project_id}/generative-design/executions/', + "/projects/{project_id}/generative-design/executions/", project_id=self.project_id, ) @@ -31,11 +31,15 @@ def _build_results( def results(self, *, per_page: int = 100) -> Iterable[GenerativeDesignResult]: """Fetch the Generative Design Results for the particular execution, 
paginated.""" - path = self._path() + f'{self.uid}/results' - fetcher = partial(self._fetch_page, path=path, fetch_func=self._session.get_resource) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_results, - per_page=per_page) + path = self._path() + f"{self.uid}/results" + fetcher = partial( + self._fetch_page, path=path, fetch_func=self._session.get_resource + ) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_results, + per_page=per_page, + ) def result( self, @@ -43,7 +47,7 @@ def result( result_id: UUID, ) -> GenerativeDesignResult: """Fetch a Generative Design Result for the particular UID.""" - path = self._path() + f'{self.uid}/results/{result_id}' + path = self._path() + f"{self.uid}/results/{result_id}" data = self._session.get_resource(path, version=self._api_version) result = GenerativeDesignResult.build(data) return result diff --git a/src/citrine/informatics/executions/predictor_evaluation.py b/src/citrine/informatics/executions/predictor_evaluation.py index f74dcd113..d1be0636b 100644 --- a/src/citrine/informatics/executions/predictor_evaluation.py +++ b/src/citrine/informatics/executions/predictor_evaluation.py @@ -14,7 +14,7 @@ from citrine.resources.status_detail import StatusDetail -class PredictorEvaluatorsResponse(Serializable['EvaluatorsPayload']): +class PredictorEvaluatorsResponse(Serializable["EvaluatorsPayload"]): """Container object for a default predictor evaluator response.""" evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators") @@ -23,39 +23,50 @@ def __init__(self, evaluators: List[PredictorEvaluator]): self.evaluators = evaluators -class PredictorEvaluationRequest(Serializable['EvaluatorsPayload']): +class PredictorEvaluationRequest(Serializable["EvaluatorsPayload"]): """Container object for a predictor evaluation request.""" predictor = properties.Object(PredictorRef, "predictor") evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators") - def __init__(self, - *, - evaluators: List[PredictorEvaluator], - predictor_id: Union[UUID, str], - predictor_version: Optional[Union[int, str]] = None): + def __init__( + self, + *, + evaluators: List[PredictorEvaluator], + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, + ): self.evaluators = evaluators self.predictor = PredictorRef(predictor_id, predictor_version) -class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation'], AsynchronousObject): +class PredictorEvaluation( + EngineResourceWithoutStatus["PredictorEvaluation"], AsynchronousObject +): """The evaluation of a predictor's performance.""" - uid: UUID = properties.UUID('id', serializable=False) + uid: UUID = properties.UUID("id", serializable=False) """:UUID: Unique identifier of the evaluation""" - evaluators = properties.List(properties.Object(PredictorEvaluator), "data.evaluators", - serializable=False) + evaluators = properties.List( + properties.Object(PredictorEvaluator), "data.evaluators", serializable=False + ) """:List[PredictorEvaluator]: the predictor evaluators that were executed. 
These are used when calling the ``results()`` method.""" - predictor_id = properties.UUID('metadata.predictor_id', serializable=False) + predictor_id = properties.UUID("metadata.predictor_id", serializable=False) """:UUID:""" - predictor_version = properties.Integer('metadata.predictor_version', serializable=False) - status = properties.String('metadata.status.major', serializable=False) + predictor_version = properties.Integer( + "metadata.predictor_version", serializable=False + ) + status = properties.String("metadata.status.major", serializable=False) """:str: short description of the evaluation's status""" - status_description = properties.String('metadata.status.minor', serializable=False) + status_description = properties.String("metadata.status.minor", serializable=False) """:str: more detailed description of the evaluation's status""" - status_detail = properties.List(properties.Object(StatusDetail), 'metadata.status.detail', - default=[], serializable=False) + status_detail = properties.List( + properties.Object(StatusDetail), + "metadata.status.detail", + default=[], + serializable=False, + ) """:List[StatusDetail]: a list of structured status info, containing the message and level""" project_id: Optional[UUID] = None @@ -66,9 +77,9 @@ class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation'], As def _path(self): return format_escaped_url( - '/projects/{project_id}/predictor-evaluations/{evaluation_id}', + "/projects/{project_id}/predictor-evaluations/{evaluation_id}", project_id=str(self.project_id), - evaluation_id=str(self.uid) + evaluation_id=str(self.uid), ) @lru_cache() diff --git a/src/citrine/informatics/executions/predictor_evaluation_execution.py b/src/citrine/informatics/executions/predictor_evaluation_execution.py index 42a2cdb4e..e981b61b9 100644 --- a/src/citrine/informatics/executions/predictor_evaluation_execution.py +++ b/src/citrine/informatics/executions/predictor_evaluation_execution.py @@ -7,7 +7,7 @@ from citrine._rest.resource import Resource -class PredictorEvaluationExecution(Resource['PredictorEvaluationExecution'], Execution): +class PredictorEvaluationExecution(Resource["PredictorEvaluationExecution"], Execution): """[DEPRECATED] The execution of a PredictorEvaluationWorkflow. Possible statuses are INPROGRESS, SUCCEEDED, and FAILED. @@ -15,19 +15,21 @@ class PredictorEvaluationExecution(Resource['PredictorEvaluationExecution'], Exe """ - evaluator_names = properties.List(properties.String, "evaluator_names", serializable=False) + evaluator_names = properties.List( + properties.String, "evaluator_names", serializable=False + ) """:List[str]: names of the predictor evaluators that were executed. 
These are used when calling the ``results()`` method.""" - workflow_id = properties.UUID('workflow_id', serializable=False) + workflow_id = properties.UUID("workflow_id", serializable=False) """:UUID: Unique identifier of the workflow that was executed""" - predictor_id = properties.UUID('predictor_id', serializable=False) - predictor_version = properties.Integer('predictor_version', serializable=False) + predictor_id = properties.UUID("predictor_id", serializable=False) + predictor_version = properties.Integer("predictor_version", serializable=False) def _path(self): return format_escaped_url( - '/projects/{project_id}/predictor-evaluation-executions/{execution_id}', + "/projects/{project_id}/predictor-evaluation-executions/{execution_id}", project_id=self.project_id, - execution_id=self.uid + execution_id=self.uid, ) @lru_cache() diff --git a/src/citrine/informatics/executions/sample_design_space_execution.py b/src/citrine/informatics/executions/sample_design_space_execution.py index 6db97626b..6c9efa61f 100644 --- a/src/citrine/informatics/executions/sample_design_space_execution.py +++ b/src/citrine/informatics/executions/sample_design_space_execution.py @@ -8,7 +8,7 @@ from citrine._utils.functions import format_escaped_url -class SampleDesignSpaceExecution(Resource['SampleDesignSpaceExecution'], Execution): +class SampleDesignSpaceExecution(Resource["SampleDesignSpaceExecution"], Execution): """The execution of a Sample Design Space task. Possible statuses are INPROGRESS, SUCCEEDED, and FAILED. @@ -16,12 +16,12 @@ class SampleDesignSpaceExecution(Resource['SampleDesignSpaceExecution'], Executi """ - _api_version = 'v3' + _api_version = "v3" design_space_id: Optional[UUID] = None def _path(self): return format_escaped_url( - '/projects/{project_id}/design-spaces/{design_space_id}/sample/', + "/projects/{project_id}/design-spaces/{design_space_id}/sample/", project_id=self.project_id, design_space_id=self.design_space_id, ) @@ -32,9 +32,9 @@ def _pre_build(cls, data: dict) -> dict: # Flatten the status object in order to match other workflow objects. 
return { **data, - "status_description": data['status']['minor'], - "status_detail": data['status']['detail'], - "status": data['status']["major"] + "status_description": data["status"]["minor"], + "status_detail": data["status"]["detail"], + "status": data["status"]["major"], } @classmethod @@ -51,11 +51,15 @@ def results( self, per_page: int = 100, ) -> Iterable[SampleSearchSpaceResultCandidate]: """Fetch the Sample Design Space Results for the particular execution, paginated.""" - path = self._path() + f'{self.uid}/results' - fetcher = partial(self._fetch_page, path=path, fetch_func=self._session.get_resource) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_results, - per_page=per_page) + path = self._path() + f"{self.uid}/results" + fetcher = partial( + self._fetch_page, path=path, fetch_func=self._session.get_resource + ) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_results, + per_page=per_page, + ) def result( self, @@ -63,7 +67,7 @@ result_id: UUID, ) -> SampleSearchSpaceResultCandidate: """Fetch a Sample Design Space Result for the particular UID.""" - path = self._path() + f'{self.uid}/results/{result_id}' + path = self._path() + f"{self.uid}/results/{result_id}" data = self._session.get_resource(path, version=self._api_version) result = SampleSearchSpaceResultCandidate.build(data) return result diff --git a/src/citrine/informatics/experiment_values.py b/src/citrine/informatics/experiment_values.py index d73a62f0c..fa529cd89 100644 --- a/src/citrine/informatics/experiment_values.py +++ b/src/citrine/informatics/experiment_values.py @@ -5,17 +5,17 @@ from citrine._serialization import properties __all__ = [ - 'ExperimentValue', - 'RealExperimentValue', - 'IntegerExperimentValue', - 'CategoricalExperimentValue', - 'MixtureExperimentValue', - 'ChemicalFormulaExperimentValue', - 'MolecularStructureExperimentValue' + "ExperimentValue", + "RealExperimentValue", + "IntegerExperimentValue", + "CategoricalExperimentValue", + "MixtureExperimentValue", + "ChemicalFormulaExperimentValue", + "MolecularStructureExperimentValue", ] -class ExperimentValue(PolymorphicSerializable['ExperimentValue']): +class ExperimentValue(PolymorphicSerializable["ExperimentValue"]): """A container for experiment values. Abstract type that returns the proper type given a serialized dict. 
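[Editor's note: a minimal usage sketch, not part of the diff. It assumes the `build` classmethod that `PolymorphicSerializable` subclasses use elsewhere in this patch (e.g. `DesignCandidate.build`); the "type" strings are the `typ` defaults declared later in this file.]

from citrine.informatics.experiment_values import ExperimentValue, RealExperimentValue

# A serialized dict carries a "type" key; build() returns the matching
# subclass ("RealValue" is the default declared by RealExperimentValue.typ).
value = ExperimentValue.build({"type": "RealValue", "value": 1.23})
assert isinstance(value, RealExperimentValue)
assert value.value == 1.23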
@@ -61,70 +61,77 @@ def _equals(self, other, attrs) -> bool: [self.__getattribute__(key) for key in attrs] try: - return all([ - self.__getattribute__(key) == other.__getattribute__(key) for key in attrs - ]) + return all( + [ + self.__getattribute__(key) == other.__getattribute__(key) + for key in attrs + ] + ) except AttributeError: return False -class RealExperimentValue(Serializable['RealExperimentValue'], ExperimentValue): +class RealExperimentValue(Serializable["RealExperimentValue"], ExperimentValue): """A floating point experiment result.""" - value = properties.Float('value') - typ = properties.String('type', default='RealValue', deserializable=False) + value = properties.Float("value") + typ = properties.String("type", default="RealValue", deserializable=False) def __init__(self, value: float): self.value = value -class IntegerExperimentValue(Serializable['IntegerExperimentValue'], ExperimentValue): +class IntegerExperimentValue(Serializable["IntegerExperimentValue"], ExperimentValue): """An integer value experiment result.""" - value = properties.Integer('value') - typ = properties.String('type', default='IntegerValue', deserializable=False) + value = properties.Integer("value") + typ = properties.String("type", default="IntegerValue", deserializable=False) def __init__(self, value: int): self.value = value -class CategoricalExperimentValue(Serializable['CategoricalExperimentValue'], ExperimentValue): +class CategoricalExperimentValue( + Serializable["CategoricalExperimentValue"], ExperimentValue +): """An experiment result with a categorical value.""" - value = properties.String('value') - typ = properties.String('type', default='CategoricalValue', deserializable=False) + value = properties.String("value") + typ = properties.String("type", default="CategoricalValue", deserializable=False) def __init__(self, value: str): self.value = value -class MixtureExperimentValue(Serializable['MixtureExperimentValue'], ExperimentValue): +class MixtureExperimentValue(Serializable["MixtureExperimentValue"], ExperimentValue): """An experiment result mapping ingredients and labels to real values.""" - value = properties.Mapping(properties.String, properties.Float, 'value') - typ = properties.String('type', default='MixtureValue', deserializable=False) + value = properties.Mapping(properties.String, properties.Float, "value") + typ = properties.String("type", default="MixtureValue", deserializable=False) def __init__(self, value: Dict[str, float]): self.value = value -class ChemicalFormulaExperimentValue(Serializable['ChemicalFormulaExperimentValue'], - ExperimentValue): +class ChemicalFormulaExperimentValue( + Serializable["ChemicalFormulaExperimentValue"], ExperimentValue +): """Experiment value for a chemical formula.""" - value = properties.String('value') - typ = properties.String('type', default='InorganicValue', deserializable=False) + value = properties.String("value") + typ = properties.String("type", default="InorganicValue", deserializable=False) def __init__(self, value: str): self.value = value -class MolecularStructureExperimentValue(Serializable['MolecularStructureExperimentValue'], - ExperimentValue): +class MolecularStructureExperimentValue( + Serializable["MolecularStructureExperimentValue"], ExperimentValue +): """Experiment value for a molecular structure.""" - value = properties.String('value') - typ = properties.String('type', default='OrganicValue', deserializable=False) + value = properties.String("value") + typ = properties.String("type", default="OrganicValue", 
deserializable=False) def __init__(self, value: str): self.value = value diff --git a/src/citrine/informatics/feature_effects.py b/src/citrine/informatics/feature_effects.py index d0cd04d6e..eab41995a 100644 --- a/src/citrine/informatics/feature_effects.py +++ b/src/citrine/informatics/feature_effects.py @@ -8,16 +8,17 @@ class ShapleyMaterial(Resource): """The feature effect of a material.""" - material_id = properties.UUID('material_id', serializable=False) - value = properties.Float('value', serializable=False) + material_id = properties.UUID("material_id", serializable=False) + value = properties.Float("value", serializable=False) class ShapleyFeature(Resource): """All feature effects for this feature by material.""" - feature = properties.String('feature', serializable=False) - materials = properties.List(properties.Object(ShapleyMaterial), 'materials', - serializable=False) + feature = properties.String("feature", serializable=False) + materials = properties.List( + properties.Object(ShapleyMaterial), "materials", serializable=False + ) @property def material_dict(self) -> Dict[UUID, float]: @@ -28,8 +29,10 @@ def material_dict(self) -> Dict[UUID, float]: class ShapleyOutput(Resource): """All feature effects for this output by feature.""" - output = properties.String('output', serializable=False) - features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False) + output = properties.String("output", serializable=False) + features = properties.List( + properties.Object(ShapleyFeature), "features", serializable=False + ) @property def feature_dict(self) -> Dict[str, Dict[UUID, float]]: @@ -40,14 +43,20 @@ def feature_dict(self) -> Dict[str, Dict[UUID, float]]: class FeatureEffects(Resource): """Captures information about the feature effects associated with a predictor.""" - predictor_id = properties.UUID('metadata.predictor_id', serializable=False) - predictor_version = properties.Integer('metadata.predictor_version', serializable=False) - status = properties.String('metadata.status', serializable=False) - failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason', - serializable=False) - - outputs = properties.Optional(properties.List(properties.Object(ShapleyOutput)), 'resultobj', - serializable=False) + predictor_id = properties.UUID("metadata.predictor_id", serializable=False) + predictor_version = properties.Integer( + "metadata.predictor_version", serializable=False + ) + status = properties.String("metadata.status", serializable=False) + failure_reason = properties.Optional( + properties.String(), "metadata.failure_reason", serializable=False + ) + + outputs = properties.Optional( + properties.List(properties.Object(ShapleyOutput)), + "resultobj", + serializable=False, + ) @classmethod def _pre_build(cls, data: dict) -> Dict: @@ -62,11 +71,10 @@ def _pre_build(cls, data: dict) -> Dict: features = [] for feature, values in feature_dict.items(): items = zip(material_ids, values) - materials = [{"material_id": mid, "value": value} for mid, value in items] - features.append({ - "feature": feature, - "materials": materials - }) + materials = [ + {"material_id": mid, "value": value} for mid, value in items + ] + features.append({"feature": feature, "materials": materials}) outputs.append({"output": output, "features": features}) diff --git a/src/citrine/informatics/generative_design.py b/src/citrine/informatics/generative_design.py index eb0432eb1..31e39c7b7 100644 --- a/src/citrine/informatics/generative_design.py +++ 
b/src/citrine/informatics/generative_design.py @@ -75,8 +75,8 @@ def _pre_build(cls, data: dict) -> dict: data.update(result) return data - uid = properties.UUID('id') - execution_id = properties.UUID('execution_id') + uid = properties.UUID("id") + execution_id = properties.UUID("execution_id") seed = properties.String("seed") """The seed used to generate the molecule.""" @@ -91,7 +91,7 @@ def __init__(self): pass # pragma: no cover -class GenerativeDesignInput(Serializable['GenerativeDesignInput']): +class GenerativeDesignInput(Serializable["GenerativeDesignInput"]): """A Citrine Generative Design Execution Input. Parameters @@ -117,20 +117,22 @@ class GenerativeDesignInput(Serializable['GenerativeDesignInput']): """ - seeds = properties.List(properties.String(), 'seeds') + seeds = properties.List(properties.String(), "seeds") fingerprint_type = properties.Enumeration(FingerprintType, "fingerprint_type") min_fingerprint_similarity = properties.Float("min_fingerprint_similarity") mutation_per_seed = properties.Integer("mutation_per_seed") structure_exclusions = properties.List( - properties.Enumeration(StructureExclusion), - "structure_exclusions" + properties.Enumeration(StructureExclusion), "structure_exclusions" ) min_substructure_counts = properties.Mapping( - properties.String(), properties.Integer(), "min_substructure_counts", + properties.String(), + properties.Integer(), + "min_substructure_counts", ) def __init__( - self, *, + self, + *, seeds: List[str], fingerprint_type: FingerprintType, min_fingerprint_similarity: float, diff --git a/src/citrine/informatics/objectives.py b/src/citrine/informatics/objectives.py index 55231bfe4..428ff920d 100644 --- a/src/citrine/informatics/objectives.py +++ b/src/citrine/informatics/objectives.py @@ -1,13 +1,14 @@ """Tools for working with Objectives.""" + from citrine._serialization import properties from citrine._serialization.serializable import Serializable from citrine._serialization.polymorphic_serializable import PolymorphicSerializable -__all__ = ['Objective', 'ScalarMaxObjective', 'ScalarMinObjective'] +__all__ = ["Objective", "ScalarMaxObjective", "ScalarMinObjective"] -class Objective(PolymorphicSerializable['Objective']): +class Objective(PolymorphicSerializable["Objective"]): """ An Objective describes a goal for a score associated with a single descriptor. @@ -20,13 +21,12 @@ class Objective(PolymorphicSerializable['Objective']): @classmethod def get_type(cls, data): """Return the subtype.""" - return { - 'ScalarMax': ScalarMaxObjective, - 'ScalarMin': ScalarMinObjective - }[data['type']] + return {"ScalarMax": ScalarMaxObjective, "ScalarMin": ScalarMinObjective}[ + data["type"] + ] -class ScalarMaxObjective(Serializable['ScalarMaxObjective'], Objective): +class ScalarMaxObjective(Serializable["ScalarMaxObjective"], Objective): """ Simple single-response maximization objective with optional bounds. 
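[Editor's note: a brief sketch, not part of the diff, showing how these objectives are constructed and dispatched. `Objective.get_type` appears in the hunk above; the descriptor key below is a hypothetical response name.]

from citrine.informatics.objectives import Objective, ScalarMaxObjective

# The "type" key of a serialized dict selects the subclass via get_type.
assert Objective.get_type({"type": "ScalarMax"}) is ScalarMaxObjective
# Maximize the response named by descriptor_key.
objective = ScalarMaxObjective(descriptor_key="Shear modulus")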
@@ -37,17 +37,17 @@ class ScalarMaxObjective(Serializable['ScalarMaxObjective'], Objective): """ - descriptor_key = properties.String('descriptor_key') - typ = properties.String('type', default='ScalarMax') + descriptor_key = properties.String("descriptor_key") + typ = properties.String("type", default="ScalarMax") def __init__(self, descriptor_key: str): self.descriptor_key = descriptor_key def __str__(self): - return '<ScalarMaxObjective {!r}>'.format(self.descriptor_key) + return "<ScalarMaxObjective {!r}>".format(self.descriptor_key) -class ScalarMinObjective(Serializable['ScalarMinObjective'], Objective): +class ScalarMinObjective(Serializable["ScalarMinObjective"], Objective): """ Simple single-response minimization objective with optional bounds. @@ -58,11 +58,11 @@ class ScalarMinObjective(Serializable['ScalarMinObjective'], Objective): """ - descriptor_key = properties.String('descriptor_key') - typ = properties.String('type', default='ScalarMin') + descriptor_key = properties.String("descriptor_key") + typ = properties.String("type", default="ScalarMin") def __init__(self, descriptor_key: str): self.descriptor_key = descriptor_key def __str__(self): - return '<ScalarMinObjective {!r}>'.format(self.descriptor_key) + return "<ScalarMinObjective {!r}>".format(self.descriptor_key) diff --git a/src/citrine/informatics/predict_request.py b/src/citrine/informatics/predict_request.py index b2c2adab0..bbc9f32af 100644 --- a/src/citrine/informatics/predict_request.py +++ b/src/citrine/informatics/predict_request.py @@ -12,18 +12,21 @@ class PredictRequest(Serializable["PredictRequest"]): This class represents the candidate computed by a design execution. """ - material_id = properties.UUID('material_id') - identifiers = properties.List(properties.String(), 'identifiers') - material = properties.Object(DesignMaterial, 'material') - created_from_id = properties.UUID('created_from_id') - random_seed = properties.Optional(properties.Integer, 'random_seed') + material_id = properties.UUID("material_id") + identifiers = properties.List(properties.String(), "identifiers") + material = properties.Object(DesignMaterial, "material") + created_from_id = properties.UUID("created_from_id") + random_seed = properties.Optional(properties.Integer, "random_seed") - def __init__(self, material_id: UUID, - identifiers: List[str], - material: DesignMaterial, - created_from_id: UUID, - *, - random_seed: Optional[int] = None): + def __init__( + self, + material_id: UUID, + identifiers: List[str], + material: DesignMaterial, + created_from_id: UUID, + *, + random_seed: Optional[int] = None, + ): self.material_id = material_id self.identifiers = identifiers self.material = material diff --git a/src/citrine/informatics/predictor_evaluation_metrics.py b/src/citrine/informatics/predictor_evaluation_metrics.py index 1d1a37a66..c9c1c6d55 100644 --- a/src/citrine/informatics/predictor_evaluation_metrics.py +++ b/src/citrine/informatics/predictor_evaluation_metrics.py @@ -6,15 +6,17 @@ from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization.serializable import Serializable -__all__ = ['PredictorEvaluationMetric', - 'RMSE', - 'NDME', - 'RSquared', - 'StandardRMSE', - 'PVA', - 'F1', - 'AreaUnderROC', - 'CoverageProbability'] +__all__ = [ + "PredictorEvaluationMetric", + "RMSE", + "NDME", + "RSquared", + "StandardRMSE", + "PVA", + "F1", + "AreaUnderROC", + "CoverageProbability", +] logger = getLogger(__name__) @@ -148,7 +150,9 @@ def __str__(self): return "Area Under the ROC" -class CoverageProbability(Serializable["CoverageProbability"], PredictorEvaluationMetric): 
+class CoverageProbability( + Serializable["CoverageProbability"], PredictorEvaluationMetric +): """Percentage of observations that fall within a given confidence interval. The coverage level can be specified to 3 digits, e.g., 0.123, but not 0.1234. @@ -173,9 +177,8 @@ def __init__(self, *, coverage_level: Union[str, float] = "0.683"): raise ValueError( "Invalid coverage level string '{requested_level}'. " "Coverage level must represent a floating point number between " - "0 and 1 (non-inclusive).".format( - requested_level=coverage_level - )) + "0 and 1 (non-inclusive).".format(requested_level=coverage_level) + ) elif isinstance(coverage_level, float): raw_float = coverage_level else: @@ -189,9 +192,9 @@ def __init__(self, *, coverage_level: Union[str, float] = "0.683"): "Coverage level can only be specified to 3 decimal places." "Requested level '{requested_level}' will be rounded " "to {rounded_level}.".format( - requested_level=coverage_level, - rounded_level=_level_float - )) + requested_level=coverage_level, rounded_level=_level_float + ) + ) self._level_str = "{:5.3f}".format(_level_float) diff --git a/src/citrine/informatics/predictor_evaluation_result.py b/src/citrine/informatics/predictor_evaluation_result.py index b71176a1f..b74c23f40 100644 --- a/src/citrine/informatics/predictor_evaluation_result.py +++ b/src/citrine/informatics/predictor_evaluation_result.py @@ -4,18 +4,23 @@ from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization.serializable import Serializable from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric -from citrine.informatics.predictor_evaluator import PredictorEvaluator, HoldoutSetEvaluator, \ - CrossValidationEvaluator - -__all__ = ['MetricValue', - 'RealMetricValue', - 'PredictedVsActualRealPoint', - 'PredictedVsActualCategoricalPoint', - 'RealPredictedVsActual', - 'CategoricalPredictedVsActual', - 'ResponseMetrics', - 'PredictorEvaluationResult', - 'CrossValidationResult'] +from citrine.informatics.predictor_evaluator import ( + PredictorEvaluator, + HoldoutSetEvaluator, + CrossValidationEvaluator, +) + +__all__ = [ + "MetricValue", + "RealMetricValue", + "PredictedVsActualRealPoint", + "PredictedVsActualCategoricalPoint", + "RealPredictedVsActual", + "CategoricalPredictedVsActual", + "ResponseMetrics", + "PredictorEvaluationResult", + "CrossValidationResult", +] class MetricValue(PolymorphicSerializable["MetricValue"]): @@ -31,7 +36,7 @@ def get_type(cls, data) -> Type[Serializable]: return { "RealMetricValue": RealMetricValue, "RealPredictedVsActual": RealPredictedVsActual, - "CategoricalPredictedVsActual": CategoricalPredictedVsActual + "CategoricalPredictedVsActual": CategoricalPredictedVsActual, }[data["type"]] @@ -42,11 +47,13 @@ class RealMetricValue(Serializable["RealMetricValue"], MetricValue): """:float: Mean value""" standard_error = properties.Optional(properties.Float(), "standard_error") """:Optional[float]: Standard error of the mean""" - typ = properties.String('type', default='RealMetricValue', deserializable=False) + typ = properties.String("type", default="RealMetricValue", deserializable=False) def __eq__(self, other): if isinstance(other, RealMetricValue): - return self.mean == other.mean and self.standard_error == other.standard_error + return ( + self.mean == other.mean and self.standard_error == other.standard_error + ) else: return False @@ -71,7 +78,9 @@ def __init__(self): pass # pragma: no cover -class 
PredictedVsActualCategoricalPoint(Serializable["PredictedVsActualCategoricalPoint"]): +class PredictedVsActualCategoricalPoint( + Serializable["PredictedVsActualCategoricalPoint"] +): """Predicted vs. actual data for a single categorical data point.""" uuid = properties.UUID("uuid") @@ -93,14 +102,20 @@ def __init__(self): pass # pragma: no cover -class CategoricalPredictedVsActual(Serializable["CategoricalPredictedVsActual"], MetricValue): +class CategoricalPredictedVsActual( + Serializable["CategoricalPredictedVsActual"], MetricValue +): """List of predicted vs. actual data points for a categorical value.""" - value = properties.List(properties.Object(PredictedVsActualCategoricalPoint), "value") + value = properties.List( + properties.Object(PredictedVsActualCategoricalPoint), "value" + ) """:List[PredictedVsActualCategoricalPoint]: List of predicted vs. actual data computed during a predictor evaluation. This is a flattened list that contains data for all trials and folds.""" - typ = properties.String('type', default='CategoricalPredictedVsActual', deserializable=False) + typ = properties.String( + "type", default="CategoricalPredictedVsActual", deserializable=False + ) def __iter__(self): return iter(self.value) @@ -116,7 +131,9 @@ class RealPredictedVsActual(Serializable["RealPredictedVsActual"], MetricValue): """:List[PredictedVsActualRealPoint]: List of predicted vs. actual data computed during a predictor evaluation. This is a flattened list that contains data for all trials and folds.""" - typ = properties.String('type', default='RealPredictedVsActual', deserializable=False) + typ = properties.String( + "type", default="RealPredictedVsActual", deserializable=False + ) def __iter__(self): return iter(self.value) @@ -133,7 +150,9 @@ class ResponseMetrics(Serializable["ResponseMetrics"]): """ - metrics = properties.Mapping(properties.String, properties.Object(MetricValue), "metrics") + metrics = properties.Mapping( + properties.String, properties.Object(MetricValue), "metrics" + ) """:Dict[str, MetricValue]: Metrics computed for a single response, keyed by the metric's ``__repr__``.""" @@ -166,7 +185,7 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" return { "CrossValidationResult": CrossValidationResult, - "HoldoutSetResult": HoldoutSetResult + "HoldoutSetResult": HoldoutSetResult, }[data["type"]] @property @@ -185,7 +204,9 @@ def metrics(self) -> Set[PredictorEvaluationMetric]: raise NotImplementedError # pragma: no cover -class CrossValidationResult(Serializable["CrossValidationResult"], PredictorEvaluationResult): +class CrossValidationResult( + Serializable["CrossValidationResult"], PredictorEvaluationResult +): """Result of performing a cross-validation evaluation on a predictor. 
Results for a cross-validated response can be accessed via ``cvResult['response_name']``, @@ -197,9 +218,12 @@ class CrossValidationResult(Serializable["CrossValidationResult"], PredictorEval """ _evaluator = properties.Object(CrossValidationEvaluator, "evaluator") - _response_results = properties.Mapping(properties.String, properties.Object(ResponseMetrics), - "response_results") - typ = properties.String('type', default='CrossValidationResult', deserializable=False) + _response_results = properties.Mapping( + properties.String, properties.Object(ResponseMetrics), "response_results" + ) + typ = properties.String( + "type", default="CrossValidationResult", deserializable=False + ) def __getitem__(self, item): return self._response_results[item] @@ -235,9 +259,10 @@ class HoldoutSetResult(Serializable["HoldoutSetResult"], PredictorEvaluationResu """ _evaluator = properties.Object(HoldoutSetEvaluator, "evaluator") - _response_results = properties.Mapping(properties.String, properties.Object(ResponseMetrics), - "response_results") - typ = properties.String('type', default='HoldoutSetResult', deserializable=False) + _response_results = properties.Mapping( + properties.String, properties.Object(ResponseMetrics), "response_results" + ) + typ = properties.String("type", default="HoldoutSetResult", deserializable=False) def __getitem__(self, item): return self._response_results[item] diff --git a/src/citrine/informatics/predictor_evaluator.py b/src/citrine/informatics/predictor_evaluator.py index 499f6407b..c02e66511 100644 --- a/src/citrine/informatics/predictor_evaluator.py +++ b/src/citrine/informatics/predictor_evaluator.py @@ -6,10 +6,7 @@ from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric from citrine.informatics.data_sources import DataSource -__all__ = ['PredictorEvaluator', - 'CrossValidationEvaluator', - 'HoldoutSetEvaluator' - ] +__all__ = ["PredictorEvaluator", "CrossValidationEvaluator", "HoldoutSetEvaluator"] class PredictorEvaluator(PolymorphicSerializable["PredictorEvaluator"]): @@ -20,7 +17,7 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" return { "CrossValidationEvaluator": CrossValidationEvaluator, - "HoldoutSetEvaluator": HoldoutSetEvaluator + "HoldoutSetEvaluator": HoldoutSetEvaluator, }[data["type"]] def __eq__(self, other): @@ -28,13 +25,15 @@ def __eq__(self, other): self_dict = self.dump() other_dict = other.dump() - self_dict['responses'] = set(self_dict.get('responses', [])) - self_dict['metrics'] = frozenset( - frozenset((k, v) for k, v in dct.items()) for dct in self_dict.get('metrics', []) + self_dict["responses"] = set(self_dict.get("responses", [])) + self_dict["metrics"] = frozenset( + frozenset((k, v) for k, v in dct.items()) + for dct in self_dict.get("metrics", []) ) - other_dict['responses'] = set(other_dict.get('responses', [])) - other_dict['metrics'] = frozenset( - frozenset((k, v) for k, v in dct.items()) for dct in other_dict.get('metrics', []) + other_dict["responses"] = set(other_dict.get("responses", [])) + other_dict["metrics"] = frozenset( + frozenset((k, v) for k, v in dct.items()) + for dct in other_dict.get("metrics", []) ) return self_dict == other_dict @@ -63,7 +62,9 @@ def name(self) -> str: raise NotImplementedError # pragma: no cover -class CrossValidationEvaluator(Serializable["CrossValidationEvaluator"], PredictorEvaluator): +class CrossValidationEvaluator( + Serializable["CrossValidationEvaluator"], PredictorEvaluator +): """Evaluate a predictor via cross validation. 
Performs cross-validation on requested predictor responses and computes the requested metrics @@ -105,21 +106,27 @@ class CrossValidationEvaluator(Serializable["CrossValidationEvaluator"], Predict _responses = properties.Set(properties.String, "responses") n_folds = properties.Integer("n_folds") n_trials = properties.Integer("n_trials") - _metrics = properties.Optional(properties.Set(properties.Object(PredictorEvaluationMetric)), - "metrics") - ignore_when_grouping = properties.Optional(properties.Set(properties.String), - "ignore_when_grouping") - typ = properties.String("type", default="CrossValidationEvaluator", deserializable=False) - - def __init__(self, - name: str, - *, - description: str = "", - responses: Set[str], - n_folds: int = 5, - n_trials: int = 3, - metrics: Optional[Set[PredictorEvaluationMetric]] = None, - ignore_when_grouping: Optional[Set[str]] = None): + _metrics = properties.Optional( + properties.Set(properties.Object(PredictorEvaluationMetric)), "metrics" + ) + ignore_when_grouping = properties.Optional( + properties.Set(properties.String), "ignore_when_grouping" + ) + typ = properties.String( + "type", default="CrossValidationEvaluator", deserializable=False + ) + + def __init__( + self, + name: str, + *, + description: str = "", + responses: Set[str], + n_folds: int = 5, + n_trials: int = 3, + metrics: Optional[Set[PredictorEvaluationMetric]] = None, + ignore_when_grouping: Optional[Set[str]] = None, + ): self.name: str = name self.description: str = description self._responses: Set[str] = responses @@ -163,16 +170,20 @@ class HoldoutSetEvaluator(Serializable["HoldoutSetEvaluator"], PredictorEvaluato description = properties.String("description") _responses = properties.Set(properties.String, "responses") data_source = properties.Object(DataSource, "data_source") - _metrics = properties.Optional(properties.Set(properties.Object(PredictorEvaluationMetric)), - "metrics") + _metrics = properties.Optional( + properties.Set(properties.Object(PredictorEvaluationMetric)), "metrics" + ) typ = properties.String("type", default="HoldoutSetEvaluator", deserializable=False) - def __init__(self, - name: str, *, - description: str = "", - responses: Set[str], - data_source: DataSource, - metrics: Optional[Set[PredictorEvaluationMetric]] = None): + def __init__( + self, + name: str, + *, + description: str = "", + responses: Set[str], + data_source: DataSource, + metrics: Optional[Set[PredictorEvaluationMetric]] = None, + ): self.name: str = name self.description: str = description self._responses: Set[str] = responses diff --git a/src/citrine/informatics/predictors/attribute_accumulation_predictor.py b/src/citrine/informatics/predictors/attribute_accumulation_predictor.py index 963da002d..16d4e88de 100644 --- a/src/citrine/informatics/predictors/attribute_accumulation_predictor.py +++ b/src/citrine/informatics/predictors/attribute_accumulation_predictor.py @@ -6,28 +6,34 @@ from citrine.informatics.predictors import PredictorNode -class AttributeAccumulationPredictor(Resource["AttributeAccumulationPredictor"], PredictorNode): +class AttributeAccumulationPredictor( + Resource["AttributeAccumulationPredictor"], PredictorNode +): """A node to support propagating attributes through training. You should never have to add this node yourself: the backend should be able to automatically create it when necessary. 
""" - attributes = _properties.List(_properties.Object(Descriptor), 'attributes') - sequential = _properties.Boolean('sequential') - - typ = _properties.String('type', default='AttributeAccumulation', deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - attributes: List[Descriptor], - sequential: bool): + attributes = _properties.List(_properties.Object(Descriptor), "attributes") + sequential = _properties.Boolean("sequential") + + typ = _properties.String( + "type", default="AttributeAccumulation", deserializable=False + ) + + def __init__( + self, + name: str, + *, + description: str, + attributes: List[Descriptor], + sequential: bool, + ): self.name = name self.description = description self.attributes = attributes self.sequential = sequential def __str__(self): - return ''.format(self.name) + return "".format(self.name) diff --git a/src/citrine/informatics/predictors/auto_ml_predictor.py b/src/citrine/informatics/predictors/auto_ml_predictor.py index 54874fcd9..40bf1ba8c 100644 --- a/src/citrine/informatics/predictors/auto_ml_predictor.py +++ b/src/citrine/informatics/predictors/auto_ml_predictor.py @@ -9,7 +9,7 @@ from citrine.informatics.descriptors import Descriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['AutoMLPredictor', 'AutoMLEstimator'] +__all__ = ["AutoMLPredictor", "AutoMLEstimator"] class AutoMLEstimator(BaseEnumeration): @@ -62,50 +62,58 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode): """ - inputs = _properties.List(_properties.Object(Descriptor), 'inputs') - outputs = _properties.List(_properties.Object(Descriptor), 'outputs') + inputs = _properties.List(_properties.Object(Descriptor), "inputs") + outputs = _properties.List(_properties.Object(Descriptor), "outputs") estimators = _properties.Set( _properties.Enumeration(AutoMLEstimator), - 'estimators', - default={AutoMLEstimator.RANDOM_FOREST} + "estimators", + default={AutoMLEstimator.RANDOM_FOREST}, ) _training_data = _properties.List( - _properties.Object(DataSource), - 'training_data', - default=[] + _properties.Object(DataSource), "training_data", default=[] ) - typ = _properties.String('type', default='AutoML', deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - outputs: List[Descriptor], - inputs: List[Descriptor], - estimators: Optional[Set[AutoMLEstimator]] = None, - training_data: Optional[List[DataSource]] = None): + typ = _properties.String("type", default="AutoML", deserializable=False) + + def __init__( + self, + name: str, + *, + description: str, + outputs: List[Descriptor], + inputs: List[Descriptor], + estimators: Optional[Set[AutoMLEstimator]] = None, + training_data: Optional[List[DataSource]] = None, + ): self.name: str = name self.description: str = description self.inputs: List[Descriptor] = inputs - self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST} + self.estimators: Set[AutoMLEstimator] = estimators or { + AutoMLEstimator.RANDOM_FOREST + } self.outputs = outputs # self.training_data: List[DataSource] = training_data or [] if training_data: self.training_data: List[DataSource] = training_data @property - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data must be accessed through the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'", + ) def training_data(self): """[DEPRECATED] Retrieve training data 
associated with this node.""" return self._training_data @training_data.setter - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data should only be added to the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.", + ) def training_data(self, value): self._training_data = value def __str__(self): - return '<AutoMLPredictor {!r}>'.format(self.name) + return "<AutoMLPredictor {!r}>".format(self.name) diff --git a/src/citrine/informatics/predictors/chemical_formula_featurizer.py b/src/citrine/informatics/predictors/chemical_formula_featurizer.py index 45213ab17..56f9192c6 100644 --- a/src/citrine/informatics/predictors/chemical_formula_featurizer.py +++ b/src/citrine/informatics/predictors/chemical_formula_featurizer.py @@ -6,7 +6,7 @@ from citrine.informatics.descriptors import ChemicalFormulaDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['ChemicalFormulaFeaturizer'] +__all__ = ["ChemicalFormulaFeaturizer"] class ChemicalFormulaFeaturizer(Resource["ChemicalFormulaFeaturizer"], PredictorNode): @@ -131,21 +131,25 @@ class ChemicalFormulaFeaturizer(Resource["ChemicalFormulaFeaturizer"], Predictor """ - input_descriptor = properties.Object(ChemicalFormulaDescriptor, 'input') - features = properties.List(properties.String, 'features') - excludes = properties.List(properties.String, 'excludes', default=[]) - _powers = properties.List(properties.Float, 'powers') - - typ = properties.String('type', default='ChemicalFormulaFeaturizer', deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - input_descriptor: ChemicalFormulaDescriptor, - features: Optional[List[str]] = None, - excludes: Optional[List[str]] = None, - powers: Optional[List[int]] = None): + input_descriptor = properties.Object(ChemicalFormulaDescriptor, "input") + features = properties.List(properties.String, "features") + excludes = properties.List(properties.String, "excludes", default=[]) + _powers = properties.List(properties.Float, "powers") + + typ = properties.String( + "type", default="ChemicalFormulaFeaturizer", deserializable=False + ) + + def __init__( + self, + name: str, + *, + description: str, + input_descriptor: ChemicalFormulaDescriptor, + features: Optional[List[str]] = None, + excludes: Optional[List[str]] = None, + powers: Optional[List[int]] = None, + ): self.name = name self.description = description self.input_descriptor = input_descriptor @@ -156,8 +160,10 @@ def __init__(self, @property def powers(self) -> List[int]: """The list of powers when computing generalized weighted means of element properties.""" - warn("The type of 'powers' will change to a list of floats in v4.0.0. To retrieve them as " - "floats now, use 'powers_as_float'.") + warn( + "The type of 'powers' will change to a list of floats in v4.0.0. To retrieve them as " + "floats now, use 'powers_as_float'." 
+ ) truncated = [int(p) for p in self._powers] if truncated != self._powers: diffs = [f"{x} => {y}" for x, y in zip(self._powers, truncated) if x != y] @@ -171,9 +177,11 @@ def powers(self, value: List[Union[int, float]]): @property def powers_as_float(self) -> List[float]: """Powers when computing generalized weighted means of element properties.""" - warn("'powers_as_float' will be deprecated in v4.0.0 for 'powers', and removed in v5.0.0", - PendingDeprecationWarning) + warn( + "'powers_as_float' will be deprecated in v4.0.0 for 'powers', and removed in v5.0.0", + PendingDeprecationWarning, + ) return self._powers def __str__(self): - return '<ChemicalFormulaFeaturizer {!r}>'.format(self.name) + return "<ChemicalFormulaFeaturizer {!r}>".format(self.name) diff --git a/src/citrine/informatics/predictors/expression_predictor.py b/src/citrine/informatics/predictors/expression_predictor.py index 47f5e36a9..2e126eab3 100644 --- a/src/citrine/informatics/predictors/expression_predictor.py +++ b/src/citrine/informatics/predictors/expression_predictor.py @@ -5,7 +5,7 @@ from citrine.informatics.descriptors import RealDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['ExpressionPredictor'] +__all__ = ["ExpressionPredictor"] class ExpressionPredictor(Resource["ExpressionPredictor"], PredictorNode): @@ -30,21 +30,23 @@ class ExpressionPredictor(Resource["ExpressionPredictor"], PredictorNode): """ - expression = _properties.String('expression') - output = _properties.Object(RealDescriptor, 'output') + expression = _properties.String("expression") + output = _properties.Object(RealDescriptor, "output") aliases = _properties.Mapping( - _properties.String, _properties.Object(RealDescriptor), 'aliases' + _properties.String, _properties.Object(RealDescriptor), "aliases" ) - typ = _properties.String('type', default='AnalyticExpression', deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - expression: str, - output: RealDescriptor, - aliases: Mapping[str, RealDescriptor]): + typ = _properties.String("type", default="AnalyticExpression", deserializable=False) + + def __init__( + self, + name: str, + *, + description: str, + expression: str, + output: RealDescriptor, + aliases: Mapping[str, RealDescriptor], + ): self.name: str = name self.description: str = description self.expression: str = expression @@ -52,4 +54,4 @@ def __init__(self, self.aliases: Mapping[str, RealDescriptor] = aliases def __str__(self): - return '<ExpressionPredictor {!r}>'.format(self.name) + return "<ExpressionPredictor {!r}>".format(self.name) diff --git a/src/citrine/informatics/predictors/graph_predictor.py b/src/citrine/informatics/predictors/graph_predictor.py index 324a2fbb0..3bf33f968 100644 --- a/src/citrine/informatics/predictors/graph_predictor.py +++ b/src/citrine/informatics/predictors/graph_predictor.py @@ -14,10 +14,12 @@ from citrine.informatics.reports import Report from citrine.resources.report import ReportResource -__all__ = ['GraphPredictor'] +__all__ = ["GraphPredictor"] -class GraphPredictor(VersionedEngineResource['GraphPredictor'], AsynchronousObject, Predictor): +class GraphPredictor( + VersionedEngineResource["GraphPredictor"], AsynchronousObject, Predictor +): """A predictor interface that stitches individual predictor nodes together. 
The GraphPredictor is the only predictor that can be registered on the Citrine Platform @@ -42,23 +44,25 @@ """ - uid = properties.Optional(properties.UUID, 'id', serializable=False) + uid = properties.Optional(properties.UUID, "id", serializable=False) """:Optional[UUID]: Citrine Platform unique identifier""" - name = properties.String('data.name') - description = properties.Optional(properties.String(), 'data.description') - predictors = properties.List(properties.Object(PredictorNode), 'data.instance.predictors') + name = properties.String("data.name") + description = properties.Optional(properties.String(), "data.description") + predictors = properties.List( + properties.Object(PredictorNode), "data.instance.predictors" + ) # the default seems to be defined in instances, not the class itself # this is tested in test_graph_default_training_data training_data = properties.List( - properties.Object(DataSource), 'data.instance.training_data', default=[] + properties.Object(DataSource), "data.instance.training_data", default=[] ) version = properties.Optional( properties.Union([properties.Integer(), properties.String()]), - 'metadata.version', - serializable=False + "metadata.version", + serializable=False, ) _api_version = "v3" @@ -69,26 +73,28 @@ _succeeded_statuses = ["READY"] _failed_statuses = ["INVALID", "ERROR"] - def __init__(self, - name: str, - *, - description: str, - predictors: List[PredictorNode], - training_data: Optional[List[DataSource]] = None): + def __init__( + self, + name: str, + *, + description: str, + predictors: List[PredictorNode], + training_data: Optional[List[DataSource]] = None, + ): self.name: str = name self.description: str = description self.training_data: List[DataSource] = training_data or [] self.predictors: List[PredictorNode] = predictors def __str__(self): - return '<GraphPredictor {!r}>'.format(self.name) + return "<GraphPredictor {!r}>".format(self.name) def _path(self): return format_escaped_url( - '/projects/{project_id}/predictors/{predictor_id}/versions/{version}', + "/projects/{project_id}/predictors/{predictor_id}/versions/{version}", project_id=self._project_id, predictor_id=str(self.uid), - version=self.version + version=self.version, ) @staticmethod @@ -101,29 +107,37 @@ def wrap_instance(predictor_data: dict) -> dict: "data": { "name": predictor_data.get("name", ""), "description": predictor_data.get("description", ""), - "instance": predictor_data + "instance": predictor_data, } } @property def report(self) -> Report: """Fetch the predictor report.""" - if self.uid is None or self._session is None or self._project_id is None \ - or getattr(self, "version", None) is None: + if ( + self.uid is None + or self._session is None + or self._project_id is None + or getattr(self, "version", None) is None + ): msg = "Cannot get the report for a predictor that wasn't read from the platform" raise ValueError(msg) report_resource = ReportResource(self._project_id, self._session) - return report_resource.get(predictor_id=self.uid, predictor_version=self.version) + return report_resource.get( + predictor_id=self.uid, predictor_version=self.version + ) @property def feature_effects(self) -> FeatureEffects: """Retrieve the feature effects for all outputs in the predictor's training data.""" - path = self._path() + '/shapley/query' + path = self._path() + "/shapley/query" response = self._session.post_resource(path, {}, 
version=self._api_version) return FeatureEffects.build(response) def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction: """Make a one-off prediction with this predictor.""" - path = self._path() + '/predict' - res = self._session.post_resource(path, predict_request.dump(), version=self._api_version) + path = self._path() + "/predict" + res = self._session.post_resource( + path, predict_request.dump(), version=self._api_version + ) return SinglePrediction.build(res) diff --git a/src/citrine/informatics/predictors/ingredient_fractions_predictor.py b/src/citrine/informatics/predictors/ingredient_fractions_predictor.py index 449486950..0ff5f854a 100644 --- a/src/citrine/informatics/predictors/ingredient_fractions_predictor.py +++ b/src/citrine/informatics/predictors/ingredient_fractions_predictor.py @@ -5,10 +5,12 @@ from citrine.informatics.descriptors import FormulationDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['IngredientFractionsPredictor'] +__all__ = ["IngredientFractionsPredictor"] -class IngredientFractionsPredictor(Resource["IngredientFractionsPredictor"], PredictorNode): +class IngredientFractionsPredictor( + Resource["IngredientFractionsPredictor"], PredictorNode +): """A predictor interface that computes ingredient fractions. Parameters @@ -25,21 +27,25 @@ class IngredientFractionsPredictor(Resource["IngredientFractionsPredictor"], Pre """ - input_descriptor = _properties.Object(FormulationDescriptor, 'input') - ingredients = _properties.Set(_properties.String, 'ingredients') - - typ = _properties.String('type', default='IngredientFractions', deserializable=False) - - def __init__(self, - name: str, - *, - description: str, - input_descriptor: FormulationDescriptor, - ingredients: Set[str]): + input_descriptor = _properties.Object(FormulationDescriptor, "input") + ingredients = _properties.Set(_properties.String, "ingredients") + + typ = _properties.String( + "type", default="IngredientFractions", deserializable=False + ) + + def __init__( + self, + name: str, + *, + description: str, + input_descriptor: FormulationDescriptor, + ingredients: Set[str], + ): self.name: str = name self.description: str = description self.input_descriptor: FormulationDescriptor = input_descriptor self.ingredients: Set[str] = ingredients def __str__(self): - return '<IngredientFractionsPredictor {!r}>'.format(self.name) + return "<IngredientFractionsPredictor {!r}>".format(self.name) diff --git a/src/citrine/informatics/predictors/ingredients_to_formulation_predictor.py b/src/citrine/informatics/predictors/ingredients_to_formulation_predictor.py index e068ee3b7..090f689fd 100644 --- a/src/citrine/informatics/predictors/ingredients_to_formulation_predictor.py +++ b/src/citrine/informatics/predictors/ingredients_to_formulation_predictor.py @@ -5,7 +5,7 @@ from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['IngredientsToFormulationPredictor'] +__all__ = ["IngredientsToFormulationPredictor"] class IngredientsToFormulationPredictor( @@ -29,25 +29,31 @@ class IngredientsToFormulationPredictor( """ id_to_quantity = properties.Mapping( - properties.String, properties.Object(RealDescriptor), 'id_to_quantity' + properties.String, properties.Object(RealDescriptor), "id_to_quantity" + ) + labels = properties.Mapping( + properties.String, properties.Set(properties.String), "labels" ) - labels = properties.Mapping(properties.String, properties.Set(properties.String), 'labels') - typ = properties.String('type', 
default='IngredientsToSimpleMixture', deserializable=False) + typ = properties.String( + "type", default="IngredientsToSimpleMixture", deserializable=False + ) - def __init__(self, - name: str, - *, - description: str, - id_to_quantity: Mapping[str, RealDescriptor], - labels: Mapping[str, Set[str]]): + def __init__( + self, + name: str, + *, + description: str, + id_to_quantity: Mapping[str, RealDescriptor], + labels: Mapping[str, Set[str]], + ): self.name: str = name self.description: str = description self.id_to_quantity: Mapping[str, RealDescriptor] = id_to_quantity self.labels: Mapping[str, Set[str]] = labels def __str__(self): - return ''.format(self.name) + return "".format(self.name) @property def output(self) -> FormulationDescriptor: diff --git a/src/citrine/informatics/predictors/label_fractions_predictor.py b/src/citrine/informatics/predictors/label_fractions_predictor.py index 2bdf3ff0a..7d8a27925 100644 --- a/src/citrine/informatics/predictors/label_fractions_predictor.py +++ b/src/citrine/informatics/predictors/label_fractions_predictor.py @@ -5,7 +5,7 @@ from citrine.informatics.descriptors import FormulationDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['LabelFractionsPredictor'] +__all__ = ["LabelFractionsPredictor"] class LabelFractionsPredictor(Resource["LabelFractionsPredictor"], PredictorNode): @@ -24,21 +24,23 @@ class LabelFractionsPredictor(Resource["LabelFractionsPredictor"], PredictorNode """ - input_descriptor = _properties.Object(FormulationDescriptor, 'input') - labels = _properties.Set(_properties.String, 'labels') + input_descriptor = _properties.Object(FormulationDescriptor, "input") + labels = _properties.Set(_properties.String, "labels") - typ = _properties.String('type', default='LabelFractions', deserializable=False) + typ = _properties.String("type", default="LabelFractions", deserializable=False) - def __init__(self, - name: str, - *, - description: str, - input_descriptor: FormulationDescriptor, - labels: Set[str]): + def __init__( + self, + name: str, + *, + description: str, + input_descriptor: FormulationDescriptor, + labels: Set[str], + ): self.name: str = name self.description: str = description self.input_descriptor: FormulationDescriptor = input_descriptor self.labels: Set[str] = labels def __str__(self): - return ''.format(self.name) + return "".format(self.name) diff --git a/src/citrine/informatics/predictors/mean_property_predictor.py b/src/citrine/informatics/predictors/mean_property_predictor.py index c71bb60df..91b7808ad 100644 --- a/src/citrine/informatics/predictors/mean_property_predictor.py +++ b/src/citrine/informatics/predictors/mean_property_predictor.py @@ -6,11 +6,13 @@ from citrine._serialization import properties as _properties from citrine.informatics.data_sources import DataSource from citrine.informatics.descriptors import ( - CategoricalDescriptor, FormulationDescriptor, RealDescriptor + CategoricalDescriptor, + FormulationDescriptor, + RealDescriptor, ) from citrine.informatics.predictors import PredictorNode -__all__ = ['MeanPropertyPredictor'] +__all__ = ["MeanPropertyPredictor"] class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode): @@ -64,40 +66,45 @@ class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode): """ - input_descriptor = _properties.Object(FormulationDescriptor, 'input') + input_descriptor = _properties.Object(FormulationDescriptor, "input") properties = _properties.List( _properties.Union( - [_properties.Object(RealDescriptor), 
_properties.Object(CategoricalDescriptor)] + [ + _properties.Object(RealDescriptor), + _properties.Object(CategoricalDescriptor), + ] ), - 'properties' + "properties", ) - p = _properties.Float('p') - impute_properties = _properties.Boolean('impute_properties') - label = _properties.Optional(_properties.String, 'label') + p = _properties.Float("p") + impute_properties = _properties.Boolean("impute_properties") + label = _properties.Optional(_properties.String, "label") default_properties = _properties.Optional( _properties.Mapping( _properties.String, - _properties.Union([_properties.String, _properties.Float]) + _properties.Union([_properties.String, _properties.Float]), ), - 'default_properties' + "default_properties", ) _training_data = _properties.List( - _properties.Object(DataSource), 'training_data', default=[] + _properties.Object(DataSource), "training_data", default=[] ) - typ = _properties.String('type', default='MeanProperty', deserializable=False) + typ = _properties.String("type", default="MeanProperty", deserializable=False) - def __init__(self, - name: str, - *, - description: str, - input_descriptor: FormulationDescriptor, - properties: List[Union[RealDescriptor, CategoricalDescriptor]], - p: float, - impute_properties: bool, - label: Optional[str] = None, - default_properties: Optional[Mapping[str, Union[str, float]]] = None, - training_data: Optional[List[DataSource]] = None): + def __init__( + self, + name: str, + *, + description: str, + input_descriptor: FormulationDescriptor, + properties: List[Union[RealDescriptor, CategoricalDescriptor]], + p: float, + impute_properties: bool, + label: Optional[str] = None, + default_properties: Optional[Mapping[str, Union[str, float]]] = None, + training_data: Optional[List[DataSource]] = None, + ): self.name: str = name self.description: str = description self.input_descriptor: FormulationDescriptor = input_descriptor @@ -105,23 +112,31 @@ def __init__(self, self.p: float = p self.impute_properties: bool = impute_properties self.label: Optional[str] = label - self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties + self.default_properties: Optional[Mapping[str, Union[str, float]]] = ( + default_properties + ) # self.training_data: List[DataSource] = training_data or [] if training_data: self.training_data: List[DataSource] = training_data def __str__(self): - return ''.format(self.name) + return "".format(self.name) @property - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data must be accessed through the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'", + ) def training_data(self): """[DEPRECATED] Retrieve training data associated with this node.""" return self._training_data @training_data.setter - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data should only be added to the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'", + ) def training_data(self, value): self._training_data = value diff --git a/src/citrine/informatics/predictors/molecular_structure_featurizer.py b/src/citrine/informatics/predictors/molecular_structure_featurizer.py index 3b16b50eb..3e461c5ce 100644 --- a/src/citrine/informatics/predictors/molecular_structure_featurizer.py +++ 
b/src/citrine/informatics/predictors/molecular_structure_featurizer.py @@ -8,10 +8,12 @@ from citrine.informatics.descriptors import MolecularStructureDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['MolecularStructureFeaturizer'] +__all__ = ["MolecularStructureFeaturizer"] -class MolecularStructureFeaturizer(Resource["MolecularStructureFeaturizer"], PredictorNode): +class MolecularStructureFeaturizer( + Resource["MolecularStructureFeaturizer"], PredictorNode +): """ A featurizer for molecular structures, powered by CDK. @@ -78,19 +80,21 @@ class MolecularStructureFeaturizer(Resource["MolecularStructureFeaturizer"], Pre """ - input_descriptor = _properties.Object(MolecularStructureDescriptor, 'descriptor') - features = _properties.List(_properties.String, 'features') - excludes = _properties.List(_properties.String, 'excludes') + input_descriptor = _properties.Object(MolecularStructureDescriptor, "descriptor") + features = _properties.List(_properties.String, "features") + excludes = _properties.List(_properties.String, "excludes") - typ = _properties.String('type', default='MoleculeFeaturizer', deserializable=False) + typ = _properties.String("type", default="MoleculeFeaturizer", deserializable=False) - def __init__(self, - name: str, - *, - description: str, - input_descriptor: MolecularStructureDescriptor, - features: Optional[List[str]] = None, - excludes: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + description: str, + input_descriptor: MolecularStructureDescriptor, + features: Optional[List[str]] = None, + excludes: Optional[List[str]] = None, + ): self.name: str = name self.description: str = description self.input_descriptor = input_descriptor @@ -98,4 +102,4 @@ def __init__(self, self.excludes = excludes if excludes is not None else [] def __str__(self): - return ''.format(self.name) + return "".format(self.name) diff --git a/src/citrine/informatics/predictors/node.py b/src/citrine/informatics/predictors/node.py index d20b6cdef..54bb7391f 100644 --- a/src/citrine/informatics/predictors/node.py +++ b/src/citrine/informatics/predictors/node.py @@ -17,18 +17,21 @@ class PredictorNode(PolymorphicSerializable["PredictorNode"], Predictor): description = properties.Optional(properties.String(), "description") @classmethod - def get_type(cls, data) -> Type['PredictorNode']: + def get_type(cls, data) -> Type["PredictorNode"]: """Return the subtype.""" from .auto_ml_predictor import AutoMLPredictor from .attribute_accumulation_predictor import AttributeAccumulationPredictor from .chemical_formula_featurizer import ChemicalFormulaFeaturizer from .expression_predictor import ExpressionPredictor from .ingredient_fractions_predictor import IngredientFractionsPredictor - from .ingredients_to_formulation_predictor import IngredientsToFormulationPredictor + from .ingredients_to_formulation_predictor import ( + IngredientsToFormulationPredictor, + ) from .label_fractions_predictor import LabelFractionsPredictor from .mean_property_predictor import MeanPropertyPredictor from .molecular_structure_featurizer import MolecularStructureFeaturizer from .simple_mixture_predictor import SimpleMixturePredictor + type_dict = { "AnalyticExpression": ExpressionPredictor, "AttributeAccumulation": AttributeAccumulationPredictor, @@ -41,11 +44,12 @@ def get_type(cls, data) -> Type['PredictorNode']: "MoleculeFeaturizer": MolecularStructureFeaturizer, "SimpleMixture": SimpleMixturePredictor, } - typ = type_dict.get(data['type']) + typ = 
type_dict.get(data["type"]) if typ is not None: return typ else: raise ValueError( - '{} is not a valid predictor node type. ' - 'Must be in {}.'.format(data['type'], type_dict.keys()) + "{} is not a valid predictor node type. Must be in {}.".format( + data["type"], type_dict.keys() + ) ) diff --git a/src/citrine/informatics/predictors/predictor.py b/src/citrine/informatics/predictors/predictor.py index 24090dcce..31f85f05a 100644 --- a/src/citrine/informatics/predictors/predictor.py +++ b/src/citrine/informatics/predictors/predictor.py @@ -1,4 +1,4 @@ -__all__ = ['Predictor'] +__all__ = ["Predictor"] class Predictor: diff --git a/src/citrine/informatics/predictors/simple_mixture_predictor.py b/src/citrine/informatics/predictors/simple_mixture_predictor.py index da2b1fe4b..7138fc094 100644 --- a/src/citrine/informatics/predictors/simple_mixture_predictor.py +++ b/src/citrine/informatics/predictors/simple_mixture_predictor.py @@ -8,7 +8,7 @@ from citrine.informatics.descriptors import FormulationDescriptor from citrine.informatics.predictors import PredictorNode -__all__ = ['SimpleMixturePredictor'] +__all__ = ["SimpleMixturePredictor"] class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode): @@ -30,22 +30,26 @@ class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode): """ - _training_data = properties.List(properties.Object(DataSource), 'training_data', default=[]) + _training_data = properties.List( + properties.Object(DataSource), "training_data", default=[] + ) - typ = properties.String('type', default='SimpleMixture', deserializable=False) + typ = properties.String("type", default="SimpleMixture", deserializable=False) - def __init__(self, - name: str, - *, - description: str, - training_data: Optional[List[DataSource]] = None): + def __init__( + self, + name: str, + *, + description: str, + training_data: Optional[List[DataSource]] = None, + ): self.name: str = name self.description: str = description if training_data: self.training_data: List[DataSource] = training_data def __str__(self): - return ''.format(self.name) + return "".format(self.name) @property def input_descriptor(self) -> FormulationDescriptor: @@ -58,14 +62,20 @@ def output_descriptor(self) -> FormulationDescriptor: return FormulationDescriptor.flat() @property - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data must be accessed through the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'", + ) def training_data(self): """[DEPRECATED] Retrieve training data associated with this node.""" return self._training_data @training_data.setter - @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", - details="Training data should only be added to the top-level GraphPredictor.'") + @deprecated( + deprecated_in="3.5.0", + removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'", + ) def training_data(self, value): self._training_data = value diff --git a/src/citrine/informatics/predictors/single_predict_request.py b/src/citrine/informatics/predictors/single_predict_request.py index c97f076c3..51af7aedd 100644 --- a/src/citrine/informatics/predictors/single_predict_request.py +++ b/src/citrine/informatics/predictors/single_predict_request.py @@ -5,7 +5,7 @@ from citrine._serialization.serializable import Serializable from citrine.informatics.design_candidate import DesignMaterial -__all__ 
= ['SinglePredictRequest'] +__all__ = ["SinglePredictRequest"] class SinglePredictRequest(Serializable["SinglePredictRequest"]): @@ -14,16 +14,19 @@ class SinglePredictRequest(Serializable["SinglePredictRequest"]): This class represents a request to make a prediction against a predictor. """ - material_id = properties.UUID('material_id') - identifiers = properties.List(properties.String(), 'identifiers') - material = properties.Object(DesignMaterial, 'material') - random_seed = properties.Optional(properties.Integer, 'random_seed') + material_id = properties.UUID("material_id") + identifiers = properties.List(properties.String(), "identifiers") + material = properties.Object(DesignMaterial, "material") + random_seed = properties.Optional(properties.Integer, "random_seed") - def __init__(self, material_id: UUID, - identifiers: List[str], - material: DesignMaterial, - *, - random_seed: Optional[int] = None): + def __init__( + self, + material_id: UUID, + identifiers: List[str], + material: DesignMaterial, + *, + random_seed: Optional[int] = None, + ): self.material_id = material_id self.identifiers = identifiers self.material = material diff --git a/src/citrine/informatics/predictors/single_prediction.py b/src/citrine/informatics/predictors/single_prediction.py index 9c44e09a1..a33e344d3 100644 --- a/src/citrine/informatics/predictors/single_prediction.py +++ b/src/citrine/informatics/predictors/single_prediction.py @@ -5,7 +5,7 @@ from citrine._serialization.serializable import Serializable from citrine.informatics.design_candidate import DesignMaterial -__all__ = ['SinglePrediction'] +__all__ = ["SinglePrediction"] class SinglePrediction(Serializable["SinglePrediction"]): @@ -14,13 +14,13 @@ class SinglePrediction(Serializable["SinglePrediction"]): This class represents the result of a prediction made using a predictor. """ - material_id = properties.UUID('material_id') - identifiers = properties.List(properties.String(), 'identifiers') - material = properties.Object(DesignMaterial, 'material') + material_id = properties.UUID("material_id") + identifiers = properties.List(properties.String(), "identifiers") + material = properties.Object(DesignMaterial, "material") - def __init__(self, material_id: UUID, - identifiers: List[str], - material: DesignMaterial): + def __init__( + self, material_id: UUID, identifiers: List[str], material: DesignMaterial + ): self.material_id = material_id self.identifiers = identifiers self.material = material diff --git a/src/citrine/informatics/reports.py b/src/citrine/informatics/reports.py index 143b4b0f5..5b7ce5cba 100644 --- a/src/citrine/informatics/reports.py +++ b/src/citrine/informatics/reports.py @@ -1,4 +1,5 @@ """Tools for working with reports.""" + from typing import Type, Dict, TypeVar, Iterable, Any, Set from abc import abstractmethod from itertools import groupby @@ -11,12 +12,12 @@ from citrine.informatics.descriptors import Descriptor from citrine.informatics.predictor_evaluation_result import ResponseMetrics -SelfType = TypeVar('SelfType', bound='Report') +SelfType = TypeVar("SelfType", bound="Report") logger = getLogger(__name__) -class Report(PolymorphicSerializable['Report'], AsynchronousObject): +class Report(PolymorphicSerializable["Report"], AsynchronousObject): """A Citrine Report contains information related to a module. Abstract type that returns the proper type given a serialized dict. @@ -50,17 +51,22 @@ class FeatureImportanceReport(Serializable["FeatureImportanceReport"]): should not be user-instantiated. 
""" - output_key = properties.String('response_key') + output_key = properties.String("response_key") """:str: output descriptor key for which these feature importances are applicable""" - importances = properties.Mapping(keys_type=properties.String, values_type=properties.Float, - serialization_path='importances') + importances = properties.Mapping( + keys_type=properties.String, + values_type=properties.Float, + serialization_path="importances", + ) """:dict[str, float]: map from feature name to its importance""" def __init__(self): pass # pragma: no cover def __str__(self): - return "".format(self.output_key) # pragma: no cover + return "".format( + self.output_key + ) # pragma: no cover class ModelEvaluationResult(Serializable["ModelEvaluationResult"]): @@ -70,18 +76,16 @@ class ModelEvaluationResult(Serializable["ModelEvaluationResult"]): and should not be user-instantiated. """ - model_settings = properties.Raw('model_settings') + model_settings = properties.Raw("model_settings") _response_results = properties.Mapping( - properties.String, - properties.Object(ResponseMetrics), - "response_results" + properties.String, properties.Object(ResponseMetrics), "response_results" ) def __init__(self): pass # pragma: no cover def __str__(self): - return '' # pragma: no cover + return "" # pragma: no cover def __getitem__(self, item): return self._response_results[item] @@ -102,51 +106,52 @@ class ModelSelectionReport(Serializable["ModelSelectionReport"]): should not be user-instantiated. """ - n_folds = properties.Integer('n_folds') + n_folds = properties.Integer("n_folds") evaluation_results = properties.List( - properties.Object(ModelEvaluationResult), - "evaluation_results" + properties.Object(ModelEvaluationResult), "evaluation_results" ) def __init__(self): pass # pragma: no cover def __str__(self): - return '' # pragma: no cover + return "" # pragma: no cover -class ModelSummary(Serializable['ModelSummary']): +class ModelSummary(Serializable["ModelSummary"]): """Summary of information about a single model in a predictor. ModelSummary objects are constructed from saved models and should not be user-instantiated. 
""" - name = properties.String('name') + name = properties.String("name") """:str: the name of the model""" - type_ = properties.String('type') + type_ = properties.String("type") """:str: the type of the model (e.g., "ML Model", "Featurizer", etc.)""" inputs = properties.List( - properties.Union([properties.Object(Descriptor), properties.String()]), - 'inputs' + properties.Union([properties.Object(Descriptor), properties.String()]), "inputs" ) """:List[Descriptor]: list of input descriptors""" outputs = properties.List( properties.Union([properties.Object(Descriptor), properties.String()]), - 'outputs' + "outputs", ) """:List[Descriptor]: list of output descriptors""" - model_settings = properties.Raw('model_settings') + model_settings = properties.Raw("model_settings") """:dict: model settings, as a dictionary (keys depend on the model type)""" feature_importances = properties.List( - properties.Object(FeatureImportanceReport), 'feature_importances') + properties.Object(FeatureImportanceReport), "feature_importances" + ) """:List[FeatureImportanceReport]: feature importance reports for each output""" selection_summary = properties.Optional( properties.Object(ModelSelectionReport), "selection_summary" ) """:Optional[ModelSelectionReport]: optional results of AutoML model selection""" - predictor_name = properties.String('predictor_configuration_name', default='') + predictor_name = properties.String("predictor_configuration_name", default="") """:str: the name of the predictor that created this model""" - predictor_uid = properties.Optional(properties.UUID(), 'predictor_configuration_uid') + predictor_uid = properties.Optional( + properties.UUID(), "predictor_configuration_uid" + ) """:Optional[UUID]: the unique Citrine id of the predictor that created this model""" training_data_count = properties.Optional(properties.Integer, "training_data_count") """:int: Number of rows in the training data for the model, if applicable.""" @@ -155,10 +160,10 @@ def __init__(self): pass # pragma: no cover def __str__(self): - return ''.format(self.name) # pragma: no cover + return "".format(self.name) # pragma: no cover -class PredictorReport(Serializable['PredictorReport'], Report): +class PredictorReport(Serializable["PredictorReport"], Report): """The performance metrics corresponding to a predictor. PredictorReport objects are constructed from saved models and should not be user-instantiated. @@ -168,13 +173,17 @@ class PredictorReport(Serializable['PredictorReport'], Report): _succeeded_statuses = ["OK"] _failed_statuses = ["ERROR"] - uid = properties.Optional(properties.UUID, 'id', serializable=False) + uid = properties.Optional(properties.UUID, "id", serializable=False) """:UUID: Unique Citrine id of the predictor report""" - status = properties.String('status') + status = properties.String("status") """:str: The status of the report. 
Possible statuses are PENDING, ERROR, and OK.""" - descriptors = properties.List(properties.Object(Descriptor), 'report.descriptors', default=[]) + descriptors = properties.List( + properties.Object(Descriptor), "report.descriptors", default=[] + ) """:List[Descriptor]: All descriptors that appear in the predictor""" - model_summaries = properties.List(properties.Object(ModelSummary), 'report.models', default=[]) + model_summaries = properties.List( + properties.Object(ModelSummary), "report.models", default=[] + ) """:List[ModelSummary]: Summaries of all models in the predictor""" def __init__(self): @@ -185,30 +194,40 @@ def post_build(self): self._fill_out_descriptors() for _, summary in enumerate(self.model_summaries): # Collapse settings on final trained model - summary.model_settings = self._collapse_model_settings(summary.model_settings) + summary.model_settings = self._collapse_model_settings( + summary.model_settings + ) if summary.selection_summary is not None: # Collapse settings on any child model evaluation results for result in summary.selection_summary.evaluation_results: - result.model_settings = self._collapse_model_settings(result.model_settings) + result.model_settings = self._collapse_model_settings( + result.model_settings + ) def _fill_out_descriptors(self): """Replace descriptor keys in `model_summaries` with full Descriptor objects.""" descriptor_map = dict() - for key, group in groupby(sorted(self.descriptors, key=lambda d: d.key), lambda d: d.key): + for key, group in groupby( + sorted(self.descriptors, key=lambda d: d.key), lambda d: d.key + ): descriptor_map[key] = self._get_sole_descriptor(group) for i, model in enumerate(self.model_summaries): for j, input_key in enumerate(model.inputs): try: model.inputs[j] = descriptor_map[input_key] except KeyError: - raise RuntimeError("Model {} contains input \'{}\', but no descriptor found " - "with that key".format(model.name, input_key)) + raise RuntimeError( + "Model {} contains input '{}', but no descriptor found " + "with that key".format(model.name, input_key) + ) for j, output_key in enumerate(model.outputs): try: model.outputs[j] = descriptor_map[output_key] except KeyError: - raise RuntimeError("Model {} contains output \'{}\', but no descriptor found " - "with that key".format(model.name, output_key)) + raise RuntimeError( + "Model {} contains output '{}', but no descriptor found " + "with that key".format(model.name, output_key) + ) @staticmethod def _get_sole_descriptor(it: Iterable): @@ -226,9 +245,12 @@ def _get_sole_descriptor(it: Iterable): as_list = list(it) if len(as_list) > 1: serialized_descriptors = [d.dump() for d in as_list] - logger.warning("Warning: found multiple descriptors with the key \'{}\', arbitrarily " - "selecting the first one. The descriptors are: {}" - .format(as_list[0].key, serialized_descriptors)) + logger.warning( + "Warning: found multiple descriptors with the key '{}', arbitrarily " + "selecting the first one. The descriptors are: {}".format( + as_list[0].key, serialized_descriptors + ) + ) return as_list[0] @staticmethod @@ -240,14 +262,15 @@ def _collapse_model_settings(raw_settings: Dict[str, Any]): top-level dictionary with keys given by "name" and values given by "value." 
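For illustration, a minimal sketch of the flattening this describes (the setting names below are invented, not taken from the platform):

    # Nested name/value/children records, as described above:
    raw_settings = {
        "name": "Algorithm",
        "value": "Random Forest",
        "children": [{"name": "Number of trees", "value": 64, "children": []}],
    }
    # Recursing over "name"/"value" pairs and descending into "children"
    # collapses this to: {"Algorithm": "Random Forest", "Number of trees": 64}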
""" + def _recurse_model_settings(settings: Dict[str, str], list_or_dict): """Recursively traverse the model settings, adding name-value pairs to dictionary.""" if isinstance(list_or_dict, list): for setting in list_or_dict: _recurse_model_settings(settings, setting) elif isinstance(list_or_dict, dict): - settings[list_or_dict['name']] = list_or_dict['value'] - _recurse_model_settings(settings, list_or_dict['children']) + settings[list_or_dict["name"]] = list_or_dict["value"] + _recurse_model_settings(settings, list_or_dict["children"]) collapsed = dict() _recurse_model_settings(collapsed, raw_settings) diff --git a/src/citrine/informatics/scores.py b/src/citrine/informatics/scores.py index ebb452c53..47d1b6ac0 100644 --- a/src/citrine/informatics/scores.py +++ b/src/citrine/informatics/scores.py @@ -1,4 +1,5 @@ """Tools for working with Scores.""" + from typing import List, Optional from citrine._serialization import properties @@ -7,30 +8,26 @@ from citrine.informatics.constraints import Constraint from citrine.informatics.objectives import Objective -__all__ = ['Score', 'LIScore', 'EIScore', 'EVScore'] +__all__ = ["Score", "LIScore", "EIScore", "EVScore"] -class Score(PolymorphicSerializable['Score']): +class Score(PolymorphicSerializable["Score"]): """A Score is used to rank materials according to objectives and constraints. Abstract type that returns the proper type given a serialized dict. """ - _name = properties.String('name') - _description = properties.String('description') + _name = properties.String("name") + _description = properties.String("description") @classmethod def get_type(cls, data): """Return the subtype.""" - return { - 'MLI': LIScore, - 'MEI': EIScore, - 'MEV': EVScore - }[data['type']] + return {"MLI": LIScore, "MEI": EIScore, "MEV": EVScore}[data["type"]] -class LIScore(Serializable['LIScore'], Score): +class LIScore(Serializable["LIScore"], Score): """Evaluates the likelihood of scoring better than some baselines for given objectives. Parameters @@ -46,15 +43,18 @@ class LIScore(Serializable['LIScore'], Score): """ - baselines = properties.List(properties.Float, 'baselines') - objectives = properties.List(properties.Object(Objective), 'objectives') - constraints = properties.List(properties.Object(Constraint), 'constraints') - typ = properties.String('type', default='MLI') - - def __init__(self, *, - objectives: List[Objective], - baselines: List[float], - constraints: Optional[List[Constraint]] = None): + baselines = properties.List(properties.Float, "baselines") + objectives = properties.List(properties.Object(Objective), "objectives") + constraints = properties.List(properties.Object(Constraint), "constraints") + typ = properties.String("type", default="MLI") + + def __init__( + self, + *, + objectives: List[Objective], + baselines: List[float], + constraints: Optional[List[Constraint]] = None, + ): self.objectives: List[Objective] = objectives self.baselines: List[float] = baselines self.constraints: List[Constraint] = constraints or [] @@ -62,10 +62,10 @@ def __init__(self, *, self._description = "" def __str__(self): - return '' + return "" -class EIScore(Serializable['EIScore'], Score): +class EIScore(Serializable["EIScore"], Score): """ Evaluates the expected magnitude of improvement beyond baselines for a given objective. 
@@ -81,15 +81,18 @@ class EIScore(Serializable['EIScore'], Score): """ - baselines = properties.List(properties.Float, 'baselines') - objectives = properties.List(properties.Object(Objective), 'objectives') - constraints = properties.List(properties.Object(Constraint), 'constraints') - typ = properties.String('type', default='MEI') - - def __init__(self, *, - objectives: List[Objective], - baselines: List[float], - constraints: Optional[List[Constraint]] = None): + baselines = properties.List(properties.Float, "baselines") + objectives = properties.List(properties.Object(Objective), "objectives") + constraints = properties.List(properties.Object(Constraint), "constraints") + typ = properties.String("type", default="MEI") + + def __init__( + self, + *, + objectives: List[Objective], + baselines: List[float], + constraints: Optional[List[Constraint]] = None, + ): self.objectives: List[Objective] = objectives self.baselines: List[float] = baselines self.constraints: List[Constraint] = constraints or [] @@ -97,10 +100,10 @@ def __init__(self, *, self._description = "" def __str__(self): - return '' + return "" -class EVScore(Serializable['EVScore'], Score): +class EVScore(Serializable["EVScore"], Score): """ Evaluates the expected value for given objectives. @@ -116,17 +119,20 @@ class EVScore(Serializable['EVScore'], Score): """ - objectives = properties.List(properties.Object(Objective), 'objectives') - constraints = properties.List(properties.Object(Constraint), 'constraints') - typ = properties.String('type', default='MEV') + objectives = properties.List(properties.Object(Objective), "objectives") + constraints = properties.List(properties.Object(Constraint), "constraints") + typ = properties.String("type", default="MEV") - def __init__(self, *, - objectives: List[Objective], - constraints: Optional[List[Constraint]] = None): + def __init__( + self, + *, + objectives: List[Objective], + constraints: Optional[List[Constraint]] = None, + ): self.objectives: List[Objective] = objectives self.constraints: List[Constraint] = constraints or [] self._name = "Expected Value" self._description = "" def __str__(self): - return '' + return "" diff --git a/src/citrine/informatics/workflows/analysis_workflow.py b/src/citrine/informatics/workflows/analysis_workflow.py index 08dad4b63..6e9634dc1 100644 --- a/src/citrine/informatics/workflows/analysis_workflow.py +++ b/src/citrine/informatics/workflows/analysis_workflow.py @@ -8,36 +8,43 @@ from citrine.gemd_queries.gemd_query import GemdQuery -class LatestBuild(Resource['LatestBuild']): +class LatestBuild(Resource["LatestBuild"]): """Info on the latest analysis workflow build.""" - status = properties.Optional(properties.String, 'status', serializable=False) - failures = properties.List(properties.String, 'failure_reason', default=[], serializable=False) - query = properties.Optional(properties.Object(GemdQuery), 'query', serializable=False) + status = properties.Optional(properties.String, "status", serializable=False) + failures = properties.List( + properties.String, "failure_reason", default=[], serializable=False + ) + query = properties.Optional( + properties.Object(GemdQuery), "query", serializable=False + ) -class AnalysisWorkflow(EngineResourceWithoutStatus['AnalysisWorkflow'], Workflow): +class AnalysisWorkflow(EngineResourceWithoutStatus["AnalysisWorkflow"], Workflow): """An analysis workflow stored on the platform. Note that plots are not fully supported. 
They're captured as raw JSON in order to facilitate cloning an existing workflow, but no facilities are provided to validate them in the client. """ - uid = properties.UUID('id', serializable=False) - name = properties.String('data.name') - description = properties.String('data.description') - snapshot_id = properties.Optional(properties.UUID, 'data.snapshot_id') - _plots = properties.List(properties.Raw, 'data.plots', default=[]) - - latest_build = properties.Optional(properties.Object(LatestBuild), 'metadata.latest_build', - serializable=False) - - def __init__(self, - *, - name: str, - description: str, - snapshot_id: Optional[Union[UUID, str]] = None, - plots: List[dict] = []): + uid = properties.UUID("id", serializable=False) + name = properties.String("data.name") + description = properties.String("data.description") + snapshot_id = properties.Optional(properties.UUID, "data.snapshot_id") + _plots = properties.List(properties.Raw, "data.plots", default=[]) + + latest_build = properties.Optional( + properties.Object(LatestBuild), "metadata.latest_build", serializable=False + ) + + def __init__( + self, + *, + name: str, + description: str, + snapshot_id: Optional[Union[UUID, str]] = None, + plots: List[dict] = [], + ): self.name = name self.description = description self.snapshot_id = snapshot_id @@ -49,11 +56,13 @@ def status(self) -> str: return self.latest_build.status def _post_dump(self, data: dict) -> dict: - data["data"]["plots"] = [plot["data"] for plot in data["data"].get("plots") or []] + data["data"]["plots"] = [ + plot["data"] for plot in data["data"].get("plots") or [] + ] return super()._post_dump(data) -class AnalysisWorkflowUpdatePayload(Resource['AnalysisWorkflowUpdatePayload']): +class AnalysisWorkflowUpdatePayload(Resource["AnalysisWorkflowUpdatePayload"]): """An object capturing the analysis workflow upload payload. Making this a separate payload makes it explicit that you can only update name and description. @@ -61,15 +70,17 @@ class AnalysisWorkflowUpdatePayload(Resource['AnalysisWorkflowUpdatePayload']): changing the other. """ - uid = properties.UUID('id', serializable=False) - name = properties.Optional(properties.String, 'name') - description = properties.Optional(properties.String, 'description') - - def __init__(self, - uid: Union[UUID, str], - *, - name: Optional[str] = None, - description: Optional[str] = None): + uid = properties.UUID("id", serializable=False) + name = properties.Optional(properties.String, "name") + description = properties.Optional(properties.String, "description") + + def __init__( + self, + uid: Union[UUID, str], + *, + name: Optional[str] = None, + description: Optional[str] = None, + ): self.uid = uid self.name = name self.description = description diff --git a/src/citrine/informatics/workflows/design_workflow.py b/src/citrine/informatics/workflows/design_workflow.py index a63848fe4..aea5f5a98 100644 --- a/src/citrine/informatics/workflows/design_workflow.py +++ b/src/citrine/informatics/workflows/design_workflow.py @@ -8,10 +8,10 @@ from citrine.resources.design_execution import DesignExecutionCollection from citrine._rest.ai_resource_metadata import AIResourceMetadata -__all__ = ['DesignWorkflow'] +__all__ = ["DesignWorkflow"] -class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata): +class DesignWorkflow(Resource["DesignWorkflow"], Workflow, AIResourceMetadata): """Object that generates scored materials that may approach higher values of the score. 
Parameters @@ -29,31 +29,39 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata): """ - design_space_id = properties.Optional(properties.UUID, 'design_space_id') - predictor_id = properties.Optional(properties.UUID, 'predictor_id') + design_space_id = properties.Optional(properties.UUID, "design_space_id") + predictor_id = properties.Optional(properties.UUID, "predictor_id") predictor_version = properties.Optional( - properties.Union([properties.Integer, properties.String]), 'predictor_version') - branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id') + properties.Union([properties.Integer, properties.String]), "predictor_version" + ) + branch_root_id: Optional[UUID] = properties.Optional( + properties.UUID, "branch_root_id" + ) """:Optional[UUID]: Root ID of the branch that contains this workflow.""" - branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version') + branch_version: Optional[int] = properties.Optional( + properties.Integer, "branch_version" + ) """:Optional[int]: Version number of the branch that contains this workflow.""" data_source = properties.Optional(properties.Object(DataSource), "data_source") - status_description = properties.String('status_description', serializable=False) + status_description = properties.String("status_description", serializable=False) """:str: more detailed description of the workflow's status""" - typ = properties.String('type', default='DesignWorkflow', deserializable=False) - - _branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id', - serializable=False) - - def __init__(self, - name: str, - *, - design_space_id: Optional[UUID] = None, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None, - data_source: Optional[DataSource] = None, - description: Optional[str] = None): + typ = properties.String("type", default="DesignWorkflow", deserializable=False) + + _branch_id: Optional[UUID] = properties.Optional( + properties.UUID, "branch_id", serializable=False + ) + + def __init__( + self, + name: str, + *, + design_space_id: Optional[UUID] = None, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + data_source: Optional[DataSource] = None, + description: Optional[str] = None, + ): self.name = name self.design_space_id = design_space_id self.predictor_id = predictor_id @@ -62,7 +70,7 @@ def __init__(self, self.description = description def __str__(self): - return '<DesignWorkflow {!r}>'.format(self.name) + return "<DesignWorkflow {!r}>".format(self.name) @classmethod def _pre_build(cls, data: dict) -> dict: @@ -84,10 +92,13 @@ def _post_dump(self, data: dict) -> dict: @property def design_executions(self) -> DesignExecutionCollection: """Return a resource representing all visible executions of this workflow.""" - if getattr(self, 'project_id', None) is None: - raise AttributeError('Cannot initialize execution without project reference!') + if getattr(self, "project_id", None) is None: + raise AttributeError( + "Cannot initialize execution without project reference!" 
+ ) return DesignExecutionCollection( - project_id=self.project_id, session=self._session, workflow_id=self.uid) + project_id=self.project_id, session=self._session, workflow_id=self.uid + ) @property def data_source_id(self) -> Optional[str]: diff --git a/src/citrine/informatics/workflows/predictor_evaluation_workflow.py b/src/citrine/informatics/workflows/predictor_evaluation_workflow.py index c6af228e4..87477a3ff 100644 --- a/src/citrine/informatics/workflows/predictor_evaluation_workflow.py +++ b/src/citrine/informatics/workflows/predictor_evaluation_workflow.py @@ -4,14 +4,17 @@ from citrine._serialization import properties from citrine.informatics.predictor_evaluator import PredictorEvaluator from citrine.informatics.workflows.workflow import Workflow -from citrine.resources.predictor_evaluation_execution import PredictorEvaluationExecutionCollection +from citrine.resources.predictor_evaluation_execution import ( + PredictorEvaluationExecutionCollection, +) from citrine._rest.ai_resource_metadata import AIResourceMetadata -__all__ = ['PredictorEvaluationWorkflow'] +__all__ = ["PredictorEvaluationWorkflow"] -class PredictorEvaluationWorkflow(Resource['PredictorEvaluationWorkflow'], - Workflow, AIResourceMetadata): +class PredictorEvaluationWorkflow( + Resource["PredictorEvaluationWorkflow"], Workflow, AIResourceMetadata +): """[DEPRECATED] A workflow that evaluates a predictor. Parameters @@ -27,28 +30,31 @@ class PredictorEvaluationWorkflow(Resource['PredictorEvaluationWorkflow'], evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators") - status_description = properties.String('status_description', serializable=False) + status_description = properties.String("status_description", serializable=False) """:str: more detailed description of the workflow's status""" - typ = properties.String('type', default='PredictorEvaluationWorkflow', deserializable=False) + typ = properties.String( + "type", default="PredictorEvaluationWorkflow", deserializable=False + ) _resource_type = ResourceTypeEnum.MODULE - def __init__(self, - name: str, - *, - description: str = "", - evaluators: List[PredictorEvaluator]): + def __init__( + self, name: str, *, description: str = "", evaluators: List[PredictorEvaluator] + ): self.name: str = name self.description: str = description self.evaluators: List[PredictorEvaluator] = evaluators def __str__(self): - return '<PredictorEvaluationWorkflow {!r}>'.format(self.name) + return "<PredictorEvaluationWorkflow {!r}>".format(self.name) @property def executions(self) -> PredictorEvaluationExecutionCollection: """Return a resource representing all visible executions of this workflow.""" - if getattr(self, 'project_id', None) is None: - raise AttributeError('Cannot initialize execution without project reference!') + if getattr(self, "project_id", None) is None: + raise AttributeError( + "Cannot initialize execution without project reference!" 
+ ) return PredictorEvaluationExecutionCollection( - project_id=self.project_id, session=self._session, workflow_id=self.uid) + project_id=self.project_id, session=self._session, workflow_id=self.uid + ) diff --git a/src/citrine/informatics/workflows/workflow.py b/src/citrine/informatics/workflows/workflow.py index 7ee23da63..3aab0fe63 100644 --- a/src/citrine/informatics/workflows/workflow.py +++ b/src/citrine/informatics/workflows/workflow.py @@ -1,4 +1,5 @@ """Tools for working with workflow resources.""" + from typing import Optional from uuid import UUID @@ -7,7 +8,7 @@ from citrine._serialization import properties -__all__ = ['Workflow'] +__all__ = ["Workflow"] class Workflow(AsynchronousObject): @@ -27,7 +28,7 @@ class Workflow(AsynchronousObject): project_id: Optional[UUID] = None """:Optional[UUID]: Unique ID of the project that contains this workflow.""" - name = properties.String('name') - description = properties.Optional(properties.String, 'description') - uid = properties.Optional(properties.UUID, 'id', serializable=False) + name = properties.String("name") + description = properties.Optional(properties.String, "description") + uid = properties.Optional(properties.UUID, "id", serializable=False) """:Optional[UUID]: Citrine Platform unique identifier""" diff --git a/src/citrine/jobs/job.py b/src/citrine/jobs/job.py index 397a99971..2ae8c5c93 100644 --- a/src/citrine/jobs/job.py +++ b/src/citrine/jobs/job.py @@ -15,7 +15,7 @@ logger = getLogger(__name__) -class JobSubmissionResponse(Resource['JobSubmissionResponse']): +class JobSubmissionResponse(Resource["JobSubmissionResponse"]): """A response to a submit-job request for the job submission framework. This is returned as a successful response from the remote service. @@ -35,7 +35,7 @@ class JobStatus(BaseEnumeration): FAILURE = "Failure" -class TaskNode(Resource['TaskNode']): +class TaskNode(Resource["TaskNode"]): """Individual task status. The TaskNode describes a component of an overall job. @@ -64,12 +64,12 @@ def status(self, value: Union[JobStatus, str]) -> None: if JobStatus.from_str(value, exception=False) is None: warn( f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.", - DeprecationWarning + DeprecationWarning, ) self._status = value -class JobStatusResponse(Resource['JobStatusResponse']): +class JobStatusResponse(Resource["JobStatusResponse"]): """A response to a job status check. The JobStatusResponse summarizes the status for the entire job. @@ -81,7 +81,7 @@ class JobStatusResponse(Resource['JobStatusResponse']): """:str: The status of the job. 
One of "Running", "Success", or "Failure".""" tasks = properties.List(Object(TaskNode), "tasks") """:List[TaskNode]: all of the constituent task required to complete this job""" - output = properties.Optional(properties.Mapping(String, String), 'output') + output = properties.Optional(properties.Mapping(String, String), "output") """:Optional[dict[str, str]]: job output properties and results""" @property @@ -95,28 +95,33 @@ def status(self) -> Union[JobStatus, str]: @status.setter def status(self, value: Union[JobStatus, str]) -> None: if resolved := JobStatus.from_str(value, exception=False): - if resolved not in [JobStatus.RUNNING, JobStatus.SUCCESS, JobStatus.FAILURE]: + if resolved not in [ + JobStatus.RUNNING, + JobStatus.SUCCESS, + JobStatus.FAILURE, + ]: warn( f"{value} is not a valid JobStatus for a JobStatusResponse; " f"this will become an error as of v4.0.0.", - DeprecationWarning + DeprecationWarning, ) else: warn( f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.", - DeprecationWarning + DeprecationWarning, ) self._status = value -def _poll_for_job_completion(session: Session, - job: Union[JobSubmissionResponse, UUID, str], - *, - team_id: Union[UUID, str], - timeout: float = 2 * 60, - polling_delay: float = 2.0, - raise_errors: bool = True, - ) -> JobStatusResponse: +def _poll_for_job_completion( + session: Session, + job: Union[JobSubmissionResponse, UUID, str], + *, + team_id: Union[UUID, str], + timeout: float = 2 * 60, + polling_delay: float = 2.0, + raise_errors: bool = True, +) -> JobStatusResponse: """ Polls for job completion given a timeout. @@ -147,8 +152,8 @@ def _poll_for_job_completion(session: Session, job_id = job.job_id else: job_id = job # pragma: no cover - path = format_escaped_url('teams/{}/execution/job-status', team_id) - params = {'job_id': job_id} + path = format_escaped_url("teams/{}/execution/job-status", team_id) + params = {"job_id": job_id} start_time = time() while True: response = session.get_resource(path=path, params=params) @@ -157,27 +162,32 @@ def _poll_for_job_completion(session: Session, break elif time() - start_time < timeout: logger.info( - f'Job still in progress, polling status again in {polling_delay:.2f} seconds.' + f"Job still in progress, polling status again in {polling_delay:.2f} seconds." ) sleep(polling_delay) else: - logger.error(f'Job exceeded user timeout of {timeout} seconds. ' - f'Note job on server is unaffected by this timeout.') - logger.debug('Last status: {}'.format(status.dump())) - raise PollingTimeoutError('Job {} timed out.'.format(job_id)) + logger.error( + f"Job exceeded user timeout of {timeout} seconds. " + f"Note job on server is unaffected by this timeout." + ) + logger.debug("Last status: {}".format(status.dump())) + raise PollingTimeoutError("Job {} timed out.".format(job_id)) if status.status == JobStatus.FAILURE: - logger.debug(f'Job terminated with Failure status: {status.dump()}') + logger.debug(f"Job terminated with Failure status: {status.dump()}") if raise_errors: failure_reasons = [] for task in status.tasks: if task.status == JobStatus.FAILURE: - logger.error(f'Task {task.id} failed with reason "{task.failure_reason}"') + logger.error( + f'Task {task.id} failed with reason "{task.failure_reason}"' + ) failure_reasons.append(task.failure_reason) raise JobFailureError( - message=f'Job {job_id} terminated with Failure status. ' - f'Failure reasons: {failure_reasons}', + message=f"Job {job_id} terminated with Failure status. 
" + f"Failure reasons: {failure_reasons}", job_id=job_id, - failure_reasons=failure_reasons) + failure_reasons=failure_reasons, + ) return status diff --git a/src/citrine/jobs/waiting.py b/src/citrine/jobs/waiting.py index 62207947a..69c194dda 100644 --- a/src/citrine/jobs/waiting.py +++ b/src/citrine/jobs/waiting.py @@ -5,8 +5,12 @@ from citrine._rest.collection import Collection from citrine._rest.asynchronous_object import AsynchronousObject from citrine.informatics.executions.design_execution import DesignExecution -from citrine.informatics.executions.generative_design_execution import GenerativeDesignExecution -from citrine.informatics.executions.sample_design_space_execution import SampleDesignSpaceExecution +from citrine.informatics.executions.generative_design_execution import ( + GenerativeDesignExecution, +) +from citrine.informatics.executions.sample_design_space_execution import ( + SampleDesignSpaceExecution, +) from citrine.informatics.executions import PredictorEvaluationExecution @@ -32,7 +36,7 @@ def wait_for_asynchronous_object( collection: Collection[AsynchronousObject], print_status_info: bool = False, timeout: float = 1800.0, - interval: float = 3.0 + interval: float = 3.0, ) -> AsynchronousObject: """ Wait until an asynchronous object has finished. @@ -77,11 +81,12 @@ def is_finished(): raise ConditionTimeoutError( "Timeout of {timeout_length} seconds " "reached, but task {uid} is still in progress".format( - timeout_length=timeout, uid=resource.uid) + timeout_length=timeout, uid=resource.uid + ) ) current_resource = collection.get(resource.uid) - if print_status_info and hasattr(current_resource, 'status_detail'): + if print_status_info and hasattr(current_resource, "status_detail"): print("\nStatus info:") pprint([detail.msg for detail in current_resource.status_detail]) return current_resource @@ -122,9 +127,13 @@ def wait_while_validating( If fails to validate within timeout """ - return wait_for_asynchronous_object(resource=module, collection=collection, - print_status_info=print_status_info, timeout=timeout, - interval=interval) + return wait_for_asynchronous_object( + resource=module, + collection=collection, + print_status_info=print_status_info, + timeout=timeout, + interval=interval, + ) def wait_while_executing( @@ -133,22 +142,22 @@ def wait_while_executing( Collection[PredictorEvaluationExecution], Collection[DesignExecution], Collection[GenerativeDesignExecution], - Collection[SampleDesignSpaceExecution] + Collection[SampleDesignSpaceExecution], ], execution: Union[ PredictorEvaluationExecution, DesignExecution, GenerativeDesignExecution, - SampleDesignSpaceExecution + SampleDesignSpaceExecution, ], print_status_info: bool = False, timeout: float = 1800.0, interval: float = 3.0, ) -> Union[ - PredictorEvaluationExecution, - DesignExecution, - GenerativeDesignExecution, - SampleDesignSpaceExecution, + PredictorEvaluationExecution, + DesignExecution, + GenerativeDesignExecution, + SampleDesignSpaceExecution, ]: """ Wait until execution is finished. 
@@ -178,6 +187,10 @@ def wait_while_executing( If fails to finish execution within timeout """ - return wait_for_asynchronous_object(resource=execution, collection=collection, - print_status_info=print_status_info, timeout=timeout, - interval=interval) + return wait_for_asynchronous_object( + resource=execution, + collection=collection, + print_status_info=print_status_info, + timeout=timeout, + interval=interval, + ) diff --git a/src/citrine/resources/_default_labels.py b/src/citrine/resources/_default_labels.py index f4ce943ea..b6ef59fad 100644 --- a/src/citrine/resources/_default_labels.py +++ b/src/citrine/resources/_default_labels.py @@ -2,7 +2,7 @@ from citrine.resources.data_concepts import CITRINE_TAG_PREFIX -_CITRINE_DEFAULT_LABEL_PREFIX = f'{CITRINE_TAG_PREFIX}::mat_label' +_CITRINE_DEFAULT_LABEL_PREFIX = f"{CITRINE_TAG_PREFIX}::mat_label" def _inject_default_label_tags( diff --git a/src/citrine/resources/analysis_workflow.py b/src/citrine/resources/analysis_workflow.py index ec9e8c794..c33326124 100644 --- a/src/citrine/resources/analysis_workflow.py +++ b/src/citrine/resources/analysis_workflow.py @@ -2,8 +2,10 @@ from typing import Iterator, Optional, Union from uuid import UUID -from citrine.informatics.workflows.analysis_workflow import AnalysisWorkflow, \ - AnalysisWorkflowUpdatePayload +from citrine.informatics.workflows.analysis_workflow import ( + AnalysisWorkflow, + AnalysisWorkflowUpdatePayload, +) from citrine._rest.collection import Collection from citrine._session import Session @@ -18,11 +20,11 @@ class AnalysisWorkflowCollection(Collection[AnalysisWorkflow]): """ - _api_version = 'v1' - _path_template = '/teams/{team_id}/analysis-workflows' + _api_version = "v1" + _path_template = "/teams/{team_id}/analysis-workflows" _individual_key = None _resource = AnalysisWorkflow - _collection_key = 'response' + _collection_key = "response" def __init__(self, session: Session, *, team_id: UUID): self.session = session @@ -59,11 +61,15 @@ def list(self, *, per_page: int = 20) -> Iterator[AnalysisWorkflow]: """List active analysis workflows.""" return self._list_with_params(per_page=per_page, filter="archived eq 'false'") - def _list_with_params(self, *, per_page: int, **kwargs) -> Iterator[AnalysisWorkflow]: + def _list_with_params( + self, *, per_page: int, **kwargs + ) -> Iterator[AnalysisWorkflow]: page_fetcher = functools.partial(self._fetch_page, additional_params=kwargs) - return self._paginator.paginate(page_fetcher=page_fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=page_fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def archive(self, uid: Union[UUID, str]) -> AnalysisWorkflow: """Archive an analysis workflow, hiding it from default listings.""" @@ -77,13 +83,17 @@ def restore(self, uid: Union[UUID, str]) -> AnalysisWorkflow: entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def update(self, - uid: Union[UUID, str], - *, - name: Optional[str] = None, - description: Optional[str] = None) -> AnalysisWorkflow: + def update( + self, + uid: Union[UUID, str], + *, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> AnalysisWorkflow: """Update the name and/or description of the analysis workflow.""" - aw_update = AnalysisWorkflowUpdatePayload(uid=uid, name=name, description=description) + aw_update = AnalysisWorkflowUpdatePayload( + uid=uid, name=name, description=description + 
) return super().update(aw_update) def rebuild(self, uid: Union[UUID, str]) -> AnalysisWorkflow: @@ -94,4 +104,6 @@ def rebuild(self, uid: Union[UUID, str]) -> AnalysisWorkflow: def delete(self, uid: Union[UUID, str]): """Analysis workflows cannot be deleted at this time.""" - raise NotImplementedError("Deleting Analysis Workflows is not currently supported.") + raise NotImplementedError( + "Deleting Analysis Workflows is not currently supported." + ) diff --git a/src/citrine/resources/api_error.py b/src/citrine/resources/api_error.py index cfbfc2d4c..f531471d7 100644 --- a/src/citrine/resources/api_error.py +++ b/src/citrine/resources/api_error.py @@ -16,7 +16,9 @@ class ApiError(Serializable["ApiError"]): code = properties.Optional(properties.Integer(), "code") message = properties.Optional(properties.String(), "message") - validation_errors = properties.List(properties.Object(ValidationError), "validation_errors") + validation_errors = properties.List( + properties.Object(ValidationError), "validation_errors" + ) def has_failure(self, failure_id: str) -> bool: """Checks if this error contains a ValidationError with specified failure ID.""" diff --git a/src/citrine/resources/attribute_templates.py b/src/citrine/resources/attribute_templates.py index f4a7aee8d..2bff54c1f 100644 --- a/src/citrine/resources/attribute_templates.py +++ b/src/citrine/resources/attribute_templates.py @@ -1,10 +1,13 @@ """Top-level class for all attribute template objects and collections thereof.""" + from abc import ABC from typing import TypeVar from citrine._serialization.properties import Optional as PropertyOptional from citrine._serialization.properties import String, Object -from gemd.entity.template.attribute_template import AttributeTemplate as GEMDAttributeTemplate +from gemd.entity.template.attribute_template import ( + AttributeTemplate as GEMDAttributeTemplate, +) from gemd.entity.bounds.base_bounds import BaseBounds from citrine.resources.templates import Template, TemplateCollection @@ -16,12 +19,14 @@ class AttributeTemplate(Template, GEMDAttributeTemplate, ABC): AttributeTemplate must be extended along with `Resource` """ - name = String('name') - description = PropertyOptional(String(), 'description') - bounds = Object(BaseBounds, 'bounds', override=True) + name = String("name") + description = PropertyOptional(String(), "description") + bounds = Object(BaseBounds, "bounds", override=True) -AttributeTemplateResourceType = TypeVar("AttributeTemplateResourceType", bound="AttributeTemplate") +AttributeTemplateResourceType = TypeVar( + "AttributeTemplateResourceType", bound="AttributeTemplate" +) class AttributeTemplateCollection(TemplateCollection[AttributeTemplateResourceType]): diff --git a/src/citrine/resources/audit_info.py b/src/citrine/resources/audit_info.py index 554b25ec5..7994a2957 100644 --- a/src/citrine/resources/audit_info.py +++ b/src/citrine/resources/audit_info.py @@ -6,13 +6,13 @@ class AuditInfo(Serializable, DictSerializable, typ="audit_info"): """Model that holds audit metadata. 
AuditInfo objects should not be created by the user.""" - created_by = properties.Optional(properties.UUID, 'created_by') + created_by = properties.Optional(properties.UUID, "created_by") """:Optional[UUID]: ID of the user who created the object""" - created_at = properties.Optional(properties.Datetime, 'created_at') + created_at = properties.Optional(properties.Datetime, "created_at") """:Optional[datetime]: Time, in ms since epoch, at which the object was created""" - updated_by = properties.Optional(properties.UUID, 'updated_by') + updated_by = properties.Optional(properties.UUID, "updated_by") """:Optional[UUID]: ID of the user who most recently updated the object""" - updated_at = properties.Optional(properties.Datetime, 'updated_at') + updated_at = properties.Optional(properties.Datetime, "updated_at") """:Optional[datetime]: Time, in ms since epoch, at which the object was most recently updated""" @@ -20,18 +20,20 @@ def __init__(self): pass # pragma: no cover def __repr__(self): - return 'Created by: {!r}\nCreated at: {!r}\nUpdated by: {!r}\nUpdated at: {!r}'.format( + return "Created by: {!r}\nCreated at: {!r}\nUpdated by: {!r}\nUpdated at: {!r}".format( self.created_by, self.created_at, self.updated_by, self.updated_at ) def __str__(self): - create_str = 'Created by user {} at time {}'.format( - self.created_by, self.created_at) + create_str = "Created by user {} at time {}".format( + self.created_by, self.created_at + ) if self.updated_by is not None or self.updated_at is not None: - update_str = '\nUpdated by user {} at time {}'.format( - self.updated_by, self.updated_at) + update_str = "\nUpdated by user {} at time {}".format( + self.updated_by, self.updated_at + ) else: - update_str = '' + update_str = "" return create_str + update_str def __eq__(self, other): diff --git a/src/citrine/resources/branch.py b/src/citrine/resources/branch.py index 4d9099304..f83c06ba4 100644 --- a/src/citrine/resources/branch.py +++ b/src/citrine/resources/branch.py @@ -7,61 +7,73 @@ from citrine._serialization import properties from citrine._session import Session from citrine.exceptions import NotFound -from citrine.resources.data_version_update import BranchDataUpdate, NextBranchVersionRequest +from citrine.resources.data_version_update import ( + BranchDataUpdate, + NextBranchVersionRequest, +) from citrine.resources.design_workflow import DesignWorkflowCollection -from citrine.resources.experiment_datasource import (ExperimentDataSourceCollection, - ExperimentDataSource) +from citrine.resources.experiment_datasource import ( + ExperimentDataSourceCollection, + ExperimentDataSource, +) LATEST_VER = "latest" # Refers to the most recently created branch version. -class Branch(Resource['Branch']): +class Branch(Resource["Branch"]): """ A project branch. A branch is a container for design workflows. 
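A minimal usage sketch of the Branch API may help while reading the hunks that follow; it assumes an authenticated `project` handle (obtained through the usual `Citrine(...)` session flow) and a placeholder root ID, both illustrative rather than part of this module:

root_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

# Fetch the latest version of the branch by its root ID.
branch = project.branches.get(root_id=root_id, version="latest")
print(branch.name, branch.version)

# Walk into the branch's design workflows; this raises AttributeError if the
# branch has no project reference, as the property below spells out.
for workflow in branch.design_workflows.list():
    print(workflow.name)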
""" - name = properties.String('data.name') - uid = properties.Optional(properties.UUID(), 'id') - archived = properties.Boolean('metadata.archived', serializable=False) - created_at = properties.Optional(properties.Datetime(), 'metadata.created.time', - serializable=False) - updated_at = properties.Optional(properties.Datetime(), 'metadata.updated.time', - serializable=False) + name = properties.String("data.name") + uid = properties.Optional(properties.UUID(), "id") + archived = properties.Boolean("metadata.archived", serializable=False) + created_at = properties.Optional( + properties.Datetime(), "metadata.created.time", serializable=False + ) + updated_at = properties.Optional( + properties.Datetime(), "metadata.updated.time", serializable=False + ) # added in v2 - root_id = properties.UUID('metadata.root_id', serializable=False) - version = properties.Integer('metadata.version', serializable=False) + root_id = properties.UUID("metadata.root_id", serializable=False) + version = properties.Integer("metadata.version", serializable=False) project_id: Optional[UUID] = None - def __init__(self, - name: str, - *, - session: Optional[Session] = None): + def __init__(self, name: str, *, session: Optional[Session] = None): self.name: str = name self.session: Session = session def __str__(self): - return f'' + return f"" @property def design_workflows(self) -> DesignWorkflowCollection: """Return a resource representing all workflows contained within this branch.""" - if getattr(self, 'project_id', None) is None: - raise AttributeError('Cannot initialize workflow without project reference!') - return DesignWorkflowCollection(project_id=self.project_id, - session=self.session, - branch_root_id=self.root_id, - branch_version=self.version) + if getattr(self, "project_id", None) is None: + raise AttributeError( + "Cannot initialize workflow without project reference!" + ) + return DesignWorkflowCollection( + project_id=self.project_id, + session=self.session, + branch_root_id=self.root_id, + branch_version=self.version, + ) @property def experiment_datasource(self) -> Optional[ExperimentDataSource]: """Return this branch's experiment data source, or None if one doesn't exist.""" - if getattr(self, 'project_id', None) is None: - raise AttributeError('Cannot retrieve datasource without project reference!') - erds = ExperimentDataSourceCollection(project_id=self.project_id, session=self.session) + if getattr(self, "project_id", None) is None: + raise AttributeError( + "Cannot retrieve datasource without project reference!" 
+ ) + erds = ExperimentDataSourceCollection( + project_id=self.project_id, session=self.session + ) branch_erds_iter = erds.list(branch_version_id=self.uid, version=LATEST_VER) return next(branch_erds_iter, None) @@ -74,11 +86,11 @@ def _post_dump(self, data: dict) -> dict: class BranchCollection(Collection[Branch]): """A collection of Branches.""" - _path_template = '/projects/{project_id}/branches' + _path_template = "/projects/{project_id}/branches" _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = Branch - _api_version = 'v2' + _api_version = "v2" def __init__(self, project_id: UUID, session: Session): self.project_id: UUID = project_id @@ -104,10 +116,12 @@ def build(self, data: dict) -> Branch: branch.project_id = self.project_id return branch - def get(self, - *, - root_id: Union[UUID, str], - version: Optional[Union[int, str]] = LATEST_VER) -> Branch: + def get( + self, + *, + root_id: Union[UUID, str], + version: Optional[Union[int, str]] = LATEST_VER, + ) -> Branch: """ Retrieve a branch by its root ID and, optionally, its version number. @@ -138,7 +152,7 @@ def get(self, message=f"Branch root '{root_id}', version {version} not found", method="GET", path=self._get_path(), - params=params + params=params, ) def get_by_version_id(self, *, version_id: Union[UUID, str]) -> Branch: @@ -220,14 +234,18 @@ def list_all(self, *, per_page: int = 20) -> Iterator[Branch]: def _list_with_params(self, *, per_page, **kwargs): fetcher = functools.partial(self._fetch_page, additional_params=kwargs) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) - - def archive(self, - *, - root_id: Union[UUID, str], - version: Optional[Union[int, str]] = LATEST_VER): + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) + + def archive( + self, + *, + root_id: Union[UUID, str], + version: Optional[Union[int, str]] = LATEST_VER, + ): """ Archive a branch. @@ -249,10 +267,12 @@ def archive(self, data = self.session.put_resource(url, {}, version=self._api_version) return self.build(data) - def restore(self, - *, - root_id: Union[UUID, str], - version: Optional[Union[int, str]] = LATEST_VER): + def restore( + self, + *, + root_id: Union[UUID, str], + version: Optional[Union[int, str]] = LATEST_VER, + ): """ Restore an archived branch. @@ -274,12 +294,14 @@ def restore(self, data = self.session.put_resource(url, {}, version=self._api_version) return self.build(data) - def update_data(self, - *, - root_id: Union[UUID, str], - version: Optional[Union[int, str]] = LATEST_VER, - use_existing: bool = True, - retrain_models: bool = False) -> Optional[Branch]: + def update_data( + self, + *, + root_id: Union[UUID, str], + version: Optional[Union[int, str]] = LATEST_VER, + use_existing: bool = True, + retrain_models: bool = False, + ) -> Optional[Branch]: """ Automatically advance the branch to the next version. 
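`update_data` is keyword-only, so a call site looks roughly like this (a sketch, reusing the assumed `project` and `root_id` handles from above):

# Advance the branch to the newest data version without retraining models.
updated_branch = project.branches.update_data(
    root_id=root_id,
    version="latest",   # LATEST_VER; an integer version number also works
    use_existing=True,  # reuse predictors that already match the new data
    retrain_models=False,
)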
@@ -319,17 +341,22 @@ def update_data(self, if use_existing: use_predictors = version_updates.predictors - branch_instructions = NextBranchVersionRequest(data_updates=version_updates.data_updates, - use_predictors=use_predictors) - branch = self.next_version(root_id=root_id, - branch_instructions=branch_instructions, - retrain_models=retrain_models) + branch_instructions = NextBranchVersionRequest( + data_updates=version_updates.data_updates, use_predictors=use_predictors + ) + branch = self.next_version( + root_id=root_id, + branch_instructions=branch_instructions, + retrain_models=retrain_models, + ) return branch - def data_updates(self, - *, - root_id: Union[UUID, str], - version: Optional[Union[int, str]] = LATEST_VER) -> BranchDataUpdate: + def data_updates( + self, + *, + root_id: Union[UUID, str], + version: Optional[Union[int, str]] = LATEST_VER, + ) -> BranchDataUpdate: """ Get data updates for a branch. @@ -356,11 +383,13 @@ def data_updates(self, data = self.session.get_resource(path, version=self._api_version) return BranchDataUpdate.build(data) - def next_version(self, - root_id: Union[UUID, str], - *, - branch_instructions: NextBranchVersionRequest, - retrain_models: bool = True): + def next_version( + self, + root_id: Union[UUID, str], + *, + branch_instructions: NextBranchVersionRequest, + retrain_models: bool = True, + ): """ Move a branch to the next version. @@ -388,9 +417,10 @@ def next_version(self, """ path = self._get_path(action="next-version-predictor") - data = self.session.post_resource(path, branch_instructions.dump(), - version=self._api_version, - params={ - 'root': str(root_id), - 'retrain_models': retrain_models}) + data = self.session.post_resource( + path, + branch_instructions.dump(), + version=self._api_version, + params={"root": str(root_id), "retrain_models": retrain_models}, + ) return self.build(data) diff --git a/src/citrine/resources/catalyst.py b/src/citrine/resources/catalyst.py index 9307e75f5..3d03d947c 100644 --- a/src/citrine/resources/catalyst.py +++ b/src/citrine/resources/catalyst.py @@ -9,8 +9,8 @@ class CatalystResource: """Encapsulates the ability to invoke Catalyst.""" - _path_template: str = '/catalyst' - _api_version = 'v1' + _path_template: str = "/catalyst" + _api_version = "v1" def __init__(self, session: Session): self.session: Session = session diff --git a/src/citrine/resources/condition_template.py b/src/citrine/resources/condition_template.py index af98f44bd..b12d2780b 100644 --- a/src/citrine/resources/condition_template.py +++ b/src/citrine/resources/condition_template.py @@ -1,17 +1,23 @@ """Resources that represent condition templates.""" + from typing import List, Dict, Optional, Type from citrine._rest.resource import GEMDResource -from citrine.resources.attribute_templates import AttributeTemplate, AttributeTemplateCollection +from citrine.resources.attribute_templates import ( + AttributeTemplate, + AttributeTemplateCollection, +) from gemd.entity.bounds.base_bounds import BaseBounds -from gemd.entity.template.condition_template import ConditionTemplate as GEMDConditionTemplate +from gemd.entity.template.condition_template import ( + ConditionTemplate as GEMDConditionTemplate, +) class ConditionTemplate( - GEMDResource['ConditionTemplate'], + GEMDResource["ConditionTemplate"], AttributeTemplate, GEMDConditionTemplate, - typ=GEMDConditionTemplate.typ + typ=GEMDConditionTemplate.typ, ): """ A condition template.
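A hedged construction sketch for the keyword-only constructor that follows; `RealBounds` comes from gemd, and the template name, units, and description are invented for illustration:

from gemd.entity.bounds import RealBounds
from citrine.resources.condition_template import ConditionTemplate

# `bounds` and the remaining arguments must be passed by keyword.
firing_temp = ConditionTemplate(
    "Firing temperature",
    bounds=RealBounds(lower_bound=0, upper_bound=1600, default_units="degC"),
    description="Peak kiln temperature during firing",
)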
@@ -37,29 +43,36 @@ class ConditionTemplate( _response_key = GEMDConditionTemplate.typ # 'condition_template' - def __init__(self, - name: str, - *, - bounds: BaseBounds, - uids: Optional[Dict[str, str]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None - ): + def __init__( + self, + name: str, + *, + bounds: BaseBounds, + uids: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): if uids is None: uids = dict() super(AttributeTemplate, self).__init__() - GEMDConditionTemplate.__init__(self, name=name, bounds=bounds, tags=tags, - uids=uids, description=description) + GEMDConditionTemplate.__init__( + self, + name=name, + bounds=bounds, + tags=tags, + uids=uids, + description=description, + ) def __str__(self): - return '<Condition template {!r}>'.format(self.name) + return "<Condition template {!r}>".format(self.name) class ConditionTemplateCollection(AttributeTemplateCollection[ConditionTemplate]): """A collection of condition templates.""" - _individual_key = 'condition_template' - _collection_key = 'condition_templates' + _individual_key = "condition_template" + _collection_key = "condition_templates" _resource = ConditionTemplate @classmethod diff --git a/src/citrine/resources/data_concepts.py b/src/citrine/resources/data_concepts.py index 344db7ddb..cce8f927d 100644 --- a/src/citrine/resources/data_concepts.py +++ b/src/citrine/resources/data_concepts.py @@ -1,4 +1,5 @@ """Top-level class for all data concepts objects and collections thereof.""" + import re from abc import abstractmethod, ABC from typing import TypeVar, Type, List, Union, Optional, Iterator, Iterable @@ -18,15 +19,20 @@ from citrine._serialization.properties import UUID as PropertyUUID from citrine._serialization.serializable import Serializable from citrine._session import Session -from citrine._utils.functions import _data_manager_deprecation_checks, format_escaped_url, \ - _pad_positional_args, replace_objects_with_links, scrub_none +from citrine._utils.functions import ( + _data_manager_deprecation_checks, + format_escaped_url, + _pad_positional_args, + replace_objects_with_links, + scrub_none, +) from citrine.exceptions import BadRequest from citrine.jobs.job import _poll_for_job_completion from citrine.resources.audit_info import AuditInfo from citrine.resources.response import Response -CITRINE_SCOPE = 'id' -CITRINE_TAG_PREFIX = 'citr_auto' +CITRINE_SCOPE = "id" +CITRINE_TAG_PREFIX = "citr_auto" class DataConceptsMeta(DictSerializableMeta): @@ -34,17 +40,16 @@ class DataConceptsMeta(DictSerializableMeta): def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) - resolved = next((b.typ for b in cls.__bases__ if getattr(b, "typ", None) is not None), - None) + resolved = next( + (b.typ for b in cls.__bases__ if getattr(b, "typ", None) is not None), None + ) if resolved is not None: cls._typ_stash = resolved cls.typ = String("type") class DataConcepts( - PolymorphicSerializable['DataConcepts'], - BaseEntity, - metaclass=DataConceptsMeta + PolymorphicSerializable["DataConcepts"], BaseEntity, metaclass=DataConceptsMeta ): """ An abstract data concepts object.
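DataConcepts deserialization dispatches on the GEMD "type" key (see `get_type` in the hunks below); a sketch, where the "process_spec" typ string is assumed to follow gemd's registered values:

from citrine.resources.data_concepts import DataConcepts

# class_mapping routes the serialized "type" string to the client class.
cls = DataConcepts.get_type({"type": "process_spec", "name": "Mixing"})
print(cls.__name__)  # expected: ProcessSpec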
@@ -59,8 +64,10 @@ class DataConcepts( """ """Properties inherited from GEMD Base Entity.""" - uids = PropertyOptional(Mapping(String('scope'), String('id')), 'uids', override=True) - tags = PropertyOptional(PropertyList(String()), 'tags', override=True) + uids = PropertyOptional( + Mapping(String("scope"), String("id")), "uids", override=True + ) + tags = PropertyOptional(PropertyList(String()), "tags", override=True) _type_key = "type" """str: key used to determine type of serialized object.""" @@ -122,7 +129,7 @@ def get_type(cls, data) -> Type[Serializable]: """ if isinstance(data, DictSerializable): data = data.as_dict() - return DictSerializable.class_mapping[data['type']] + return DictSerializable.class_mapping[data["type"]] @classmethod def get_collection_type(cls, data) -> "Type[DataConceptsCollection]": @@ -149,7 +156,7 @@ def get_collection_type(cls, data) -> "Type[DataConceptsCollection]": DataConcepts._make_collection_dict() if isinstance(data, DictSerializable): data = data.as_dict() - return DataConcepts.collection_dict[data['type']] + return DataConcepts.collection_dict[data["type"]] @staticmethod def _make_collection_dict(): @@ -168,18 +175,30 @@ def _make_collection_dict(): from citrine.resources.material_run import MaterialRunCollection from citrine.resources.measurement_run import MeasurementRunCollection from citrine.resources.process_run import ProcessRunCollection + _collection_list = [ - ConditionTemplateCollection, ParameterTemplateCollection, PropertyTemplateCollection, - MaterialTemplateCollection, MeasurementTemplateCollection, ProcessTemplateCollection, - IngredientSpecCollection, MaterialSpecCollection, MeasurementSpecCollection, - ProcessSpecCollection, IngredientRunCollection, MaterialRunCollection, - MeasurementRunCollection, ProcessRunCollection + ConditionTemplateCollection, + ParameterTemplateCollection, + PropertyTemplateCollection, + MaterialTemplateCollection, + MeasurementTemplateCollection, + ProcessTemplateCollection, + IngredientSpecCollection, + MaterialSpecCollection, + MeasurementSpecCollection, + ProcessSpecCollection, + IngredientRunCollection, + MaterialRunCollection, + MeasurementRunCollection, + ProcessRunCollection, ] for collection in _collection_list: DataConcepts.collection_dict[collection._individual_key] = collection -def _make_link_by_uid(gemd_object_rep: Union[str, UUID, BaseEntity, LinkByUID]) -> LinkByUID: +def _make_link_by_uid( + gemd_object_rep: Union[str, UUID, BaseEntity, LinkByUID], +) -> LinkByUID: if isinstance(gemd_object_rep, BaseEntity): return gemd_object_rep.to_link(CITRINE_SCOPE, allow_fallback=True) elif isinstance(gemd_object_rep, LinkByUID): @@ -189,11 +208,13 @@ def _make_link_by_uid(gemd_object_rep: Union[str, UUID, BaseEntity, LinkByUID]) scope = CITRINE_SCOPE return LinkByUID(scope, uid) else: - raise TypeError("Link can only be created from a GEMD object, LinkByUID, str, or UUID." - "Instead got {}.".format(gemd_object_rep)) + raise TypeError( + "Link can only be created from a GEMD object, LinkByUID, str, or UUID."
+ "Instead got {}.".format(gemd_object_rep) + ) -ResourceType = TypeVar('ResourceType', bound='DataConcepts') +ResourceType = TypeVar("ResourceType", bound="DataConcepts") class DataConceptsCollection(Collection[ResourceType], ABC): @@ -214,12 +235,14 @@ class DataConceptsCollection(Collection[ResourceType], ABC): """ - def __init__(self, - *args, - session: Session = None, - dataset_id: Optional[UUID] = None, - team_id: UUID = None, - project_id: Optional[UUID] = None): + def __init__( + self, + *args, + session: Session = None, + dataset_id: Optional[UUID] = None, + team_id: UUID = None, + project_id: Optional[UUID] = None, + ): # Handle positional arguments for backward compatibility args = _pad_positional_args(args, 3) self.project_id = project_id or args[0] @@ -232,7 +255,8 @@ def __init__(self, session=self.session, project_id=self.project_id, team_id=team_id, - obj_type="GEMD Objects") + obj_type="GEMD Objects", + ) @classmethod @abstractmethod @@ -242,15 +266,17 @@ def get_type(cls) -> Type[Serializable]: @property def _path_template(self): collection_key = self._collection_key.replace("_", "-") - return f'teams/{self.team_id}/datasets/{self.dataset_id}/{collection_key}' + return f"teams/{self.team_id}/datasets/{self.dataset_id}/{collection_key}" # After Data Manager deprecation, both can use the `teams/...` path. @property def _dataset_agnostic_path_template(self): if self.project_id is None: - return f'teams/{self.team_id}/{self._collection_key.replace("_", "-")}' + return f"teams/{self.team_id}/{self._collection_key.replace('_', '-')}" else: - return f'projects/{self.project_id}/{self._collection_key.replace("_", "-")}' + return ( + f"projects/{self.project_id}/{self._collection_key.replace('_', '-')}" + ) def build(self, data: dict) -> ResourceType: """ @@ -271,9 +297,9 @@ def build(self, data: dict) -> ResourceType: """ return self.get_type().build(data) - def list(self, *, - per_page: Optional[int] = 100, - forward: bool = True) -> Iterator[ResourceType]: + def list( + self, *, per_page: Optional[int] = 100, forward: bool = True + ) -> Iterator[ResourceType]: """ Get all visible elements of the collection. @@ -298,13 +324,14 @@ def list(self, *, """ params = {} if self.dataset_id is not None: - params['dataset_id'] = str(self.dataset_id) + params["dataset_id"] = str(self.dataset_id) raw_objects = self.session.cursor_paged_resource( self.session.get_resource, self._get_path(ignore_dataset=True), forward=forward, per_page=per_page, - params=params) + params=params, + ) return (self.build(raw) for raw in raw_objects) def register(self, model: ResourceType, *, dry_run=False): @@ -337,9 +364,11 @@ def register(self, model: ResourceType, *, dry_run=False): """ if self.dataset_id is None: - raise RuntimeError("Must specify a dataset in order to register a data model object.") + raise RuntimeError( + "Must specify a dataset in order to register a data model object." 
+ ) path = self._get_path() - params = {'dry_run': dry_run} + params = {"dry_run": dry_run} temp_scope = str(uuid4()) scope = temp_scope if dry_run else CITRINE_SCOPE @@ -349,34 +378,46 @@ def register(self, model: ResourceType, *, dry_run=False): data = self.session.post_resource(path, dumped_data, params=params) registered = self.build(data) - recursive_foreach(model, lambda x: x.uids.pop(temp_scope, None)) # Strip temp uids + recursive_foreach( + model, lambda x: x.uids.pop(temp_scope, None) + ) # Strip temp uids if not dry_run: # Platform may add a CITRINE_SCOPE uid and citr_auto tags; update locals model.uids.update({k: v for k, v in registered.uids.items()}) if registered.tags is not None: if model.tags is None: # This is somehow hit by nextgen-devkit tests model.tags = list() # pragma: no cover - model.tags.extend([tag for tag in registered.tags - if re.match(f"^{CITRINE_TAG_PREFIX}::", tag)]) + model.tags.extend( + [ + tag + for tag in registered.tags + if re.match(f"^{CITRINE_TAG_PREFIX}::", tag) + ] + ) else: # Remove the tags/uids the platform spuriously added # this might leave objects with just the temp ids, which we want to strip later if CITRINE_SCOPE not in model.uids: registered.uids.pop(CITRINE_SCOPE, None) if registered.tags is not None: - todo = [tag for tag in registered.tags - if re.match(f"^{CITRINE_TAG_PREFIX}::", tag)] + todo = [ + tag + for tag in registered.tags + if re.match(f"^{CITRINE_TAG_PREFIX}::", tag) + ] for tag in todo: # Covering this block would require dark art if tag not in model.tags: registered.tags.remove(tag) return registered - def register_all(self, - models: Iterable[ResourceType], - *, - dry_run: bool = False, - status_bar: bool = False, - include_nested: bool = False) -> List[ResourceType]: + def register_all( + self, + models: Iterable[ResourceType], + *, + dry_run: bool = False, + status_bar: bool = False, + include_nested: bool = False, + ) -> List[ResourceType]: """ Register multiple GEMD objects to each of their appropriate collections.
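A call-site sketch for the `register_all` reflowed above, assuming a `dataset` handle; the GEMD objects are illustrative:

from gemd.entity.object import ProcessSpec, MaterialSpec

spec = ProcessSpec(name="Extrude")
rod = MaterialSpec(name="Rod", process=spec)

# dry_run exercises server-side validation without persisting anything;
# include_nested also registers linked objects such as `spec`.
registered = dataset.gemd.register_all([rod], dry_run=True, include_nested=True)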
@@ -419,14 +460,15 @@ def register_all(self, """ # avoiding a circular import from citrine.resources.gemd_resource import GEMDResourceCollection - gemd_collection = GEMDResourceCollection(team_id=self.team_id, - dataset_id=self.dataset_id, - session=self.session) + + gemd_collection = GEMDResourceCollection( + team_id=self.team_id, dataset_id=self.dataset_id, session=self.session + ) return gemd_collection.register_all( models, dry_run=dry_run, status_bar=status_bar, - include_nested=include_nested + include_nested=include_nested, ) def update(self, model: ResourceType) -> ResourceType: @@ -444,15 +486,20 @@ def update(self, model: ResourceType) -> ResourceType: return self.register(model, dry_run=False) except BadRequest: # If register() cannot be used because an asynchronous check is required - return self.async_update(model, dry_run=False, - wait_for_response=True, return_model=True) - - def async_update(self, model: ResourceType, *, - dry_run: bool = False, - wait_for_response: bool = True, - timeout: float = 2 * 60, - polling_delay: float = 1.0, - return_model: bool = False) -> Optional[Union[UUID, ResourceType]]: + return self.async_update( + model, dry_run=False, wait_for_response=True, return_model=True + ) + + def async_update( + self, + model: ResourceType, + *, + dry_run: bool = False, + wait_for_response: bool = True, + timeout: float = 2 * 60, + polling_delay: float = 1.0, + return_model: bool = False, + ) -> Optional[Union[UUID, ResourceType]]: """ Update a particular element of the collection with data validation. @@ -499,22 +546,29 @@ def async_update(self, model: ResourceType, *, temp_scope = str(uuid4()) GEMDJson(scope=temp_scope).dumps(model) # This apparent no-op populates uids dumped_data = replace_objects_with_links(scrub_none(model.dump())) - recursive_foreach(model, lambda x: x.uids.pop(temp_scope, None)) # Strip temp uids + recursive_foreach( + model, lambda x: x.uids.pop(temp_scope, None) + ) # Strip temp uids scope = CITRINE_SCOPE - id = dumped_data['uids'][scope] + id = dumped_data["uids"][scope] if self.dataset_id is None: - raise RuntimeError("Must specify a dataset in order to update " - "a data model object with data validation.") + raise RuntimeError( + "Must specify a dataset in order to update " + "a data model object with data validation." + ) url = self._get_path(action=[scope, id, "async"]) - response_json = self.session.put_resource(url, dumped_data, params={'dry_run': dry_run}) + response_json = self.session.put_resource( + url, dumped_data, params={"dry_run": dry_run} + ) job_id = response_json["job_id"] if wait_for_response: - self.poll_async_update_job(job_id=job_id, timeout=timeout, - polling_delay=polling_delay) + self.poll_async_update_job( + job_id=job_id, timeout=timeout, polling_delay=polling_delay + ) # That worked, return nothing or return the object if return_model: @@ -525,8 +579,9 @@ def async_update(self, model: ResourceType, *, # TODO: use JobSubmissionResponse here instead return job_id - def poll_async_update_job(self, job_id: UUID, *, timeout: float = 2 * 60, - polling_delay: float = 1.0) -> None: + def poll_async_update_job( + self, job_id: UUID, *, timeout: float = 2 * 60, polling_delay: float = 1.0 + ) -> None: """ Poll for the result of the async_update call. 
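For the `async_update`/`poll_async_update_job` pair being reflowed here, a usage sketch (the `run` object and `dataset` handle are assumed to exist):

# Submit without blocking, then poll later with the same timeout knobs.
job_id = dataset.material_runs.async_update(run, wait_for_response=False)
dataset.material_runs.poll_async_update_job(
    job_id=job_id, timeout=2 * 60, polling_delay=1.0
)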
@@ -559,8 +614,10 @@ def poll_async_update_job(self, job_id: UUID, *, timeout: float = 2 * 60, _poll_for_job_completion( session=self.session, team_id=self.team_id, - job=job_id, timeout=timeout, - polling_delay=polling_delay) + job=job_id, + timeout=timeout, + polling_delay=polling_delay, + ) # That worked, nothing returned in this case return None @@ -581,12 +638,20 @@ def get(self, uid: Union[UUID, str, LinkByUID, BaseEntity]) -> ResourceType: """ link = _make_link_by_uid(uid) - path = self._get_path(ignore_dataset=self.dataset_id is None, action=[link.scope, link.id]) + path = self._get_path( + ignore_dataset=self.dataset_id is None, action=[link.scope, link.id] + ) data = self.session.get_resource(path) return self.build(data) - def list_by_name(self, name: str, *, exact: bool = False, - forward: bool = True, per_page: int = 100) -> Iterator[ResourceType]: + def list_by_name( + self, + name: str, + *, + exact: bool = False, + forward: bool = True, + per_page: int = 100, + ) -> Iterator[ResourceType]: """ Get all objects with specified name in this dataset. @@ -612,14 +677,15 @@ def list_by_name(self, name: str, *, exact: bool = False, """ if self.dataset_id is None: raise RuntimeError("Must specify a dataset to filter by name.") - params = {'dataset_id': str(self.dataset_id), 'name': name, 'exact': exact} + params = {"dataset_id": str(self.dataset_id), "name": name, "exact": exact} raw_objects = self.session.cursor_paged_resource( self.session.get_resource, # "Ignoring" dataset because it is in the query params (and required) self._get_path(ignore_dataset=True, action="filter-by-name"), forward=forward, per_page=per_page, - params=params) + params=params, + ) return (self.build(raw) for raw in raw_objects) def list_by_tag(self, tag: str, *, per_page: int = 100) -> Iterator[ResourceType]: @@ -648,17 +714,20 @@ def list_by_tag(self, tag: str, *, per_page: int = 100) -> Iterator[ResourceType Every object in this collection. """ - params = {'tags': [tag]} + params = {"tags": [tag]} if self.dataset_id is not None: - params['dataset_id'] = str(self.dataset_id) + params["dataset_id"] = str(self.dataset_id) raw_objects = self.session.cursor_paged_resource( self.session.get_resource, self._get_path(ignore_dataset=True), per_page=per_page, - params=params) + params=params, + ) return (self.build(raw) for raw in raw_objects) - def delete(self, uid: Union[UUID, str, LinkByUID, BaseEntity], *, dry_run: bool = False): + def delete( + self, uid: Union[UUID, str, LinkByUID, BaseEntity], *, dry_run: bool = False + ): """ Delete an element of the collection by its id. @@ -673,12 +742,17 @@ def delete(self, uid: Union[UUID, str, LinkByUID, BaseEntity], *, dry_run: bool """ link = _make_link_by_uid(uid) path = self._get_path(action=[link.scope, link.id]) - params = {'dry_run': dry_run} + params = {"dry_run": dry_run} self.session.delete_resource(path, params=params) return Response(status_code=200) # delete succeeded - def _get_relation(self, relation: str, uid: Union[UUID, str, LinkByUID, BaseEntity], - forward: bool = True, per_page: int = 100) -> Iterator[ResourceType]: + def _get_relation( + self, + relation: str, + uid: Union[UUID, str, LinkByUID, BaseEntity], + forward: bool = True, + per_page: int = 100, + ) -> Iterator[ResourceType]: """ Generic method for searching this collection by relation to another object. 
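A sketch of the lookup methods reflowed above (handles and names are illustrative):

# Exact-match name filtering requires a dataset-scoped collection.
for run in dataset.material_runs.list_by_name("Rod batch 7", exact=True):
    print(run.uids.get("id"))

# get() accepts a UUID, string, LinkByUID, or GEMD object; each is normalized
# through _make_link_by_uid.
run = dataset.material_runs.get("00000000-0000-0000-0000-000000000000")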
@@ -705,19 +779,21 @@ def _get_relation(self, relation: str, uid: Union[UUID, str, LinkByUID, BaseEnti """ params = {} if self.dataset_id is not None: - params['dataset_id'] = str(self.dataset_id) + params["dataset_id"] = str(self.dataset_id) link = _make_link_by_uid(uid) raw_objects = self.session.cursor_paged_resource( self.session.get_resource, - format_escaped_url('teams/{}/{}/{}/{}/{}', - self.team_id, - relation, - link.scope, - link.id, - self._collection_key.replace('_', '-') - ), + format_escaped_url( + "teams/{}/{}/{}/{}/{}", + self.team_id, + relation, + link.scope, + link.id, + self._collection_key.replace("_", "-"), + ), forward=forward, per_page=per_page, params=params, - version='v1') + version="v1", + ) return (self.build(raw) for raw in raw_objects) diff --git a/src/citrine/resources/data_objects.py b/src/citrine/resources/data_objects.py index c0b0839e2..a916d0b3a 100644 --- a/src/citrine/resources/data_objects.py +++ b/src/citrine/resources/data_objects.py @@ -1,4 +1,5 @@ """Top-level class for all data object (i.e., spec and run) objects and collections thereof.""" + from abc import ABC from typing import Dict, Union, Optional, Iterator, List, TypeVar from uuid import uuid4 @@ -6,7 +7,11 @@ from gemd.json import GEMDJson from gemd.util import recursive_foreach -from citrine._utils.functions import get_object_id, replace_objects_with_links, scrub_none +from citrine._utils.functions import ( + get_object_id, + replace_objects_with_links, + scrub_none, +) from citrine._serialization.properties import List as PropertyList from citrine._serialization.properties import String, Object from citrine._serialization.properties import Optional as PropertyOptional @@ -29,8 +34,10 @@ class DataObject(DataConcepts, BaseObject, ABC): DataObject must be extended along with `Resource` """ - notes = PropertyOptional(String(), 'notes') - file_links = PropertyOptional(PropertyList(Object(FileLink)), 'file_links', override=True) + notes = PropertyOptional(String(), "notes") + file_links = PropertyOptional( + PropertyList(Object(FileLink)), "file_links", override=True + ) DataObjectResourceType = TypeVar("DataObjectResourceType", bound="DataObject") @@ -40,9 +47,12 @@ class DataObjectCollection(DataConceptsCollection[DataObjectResourceType], ABC): """A collection of one kind of data object object.""" def list_by_attribute_bounds( - self, - attribute_bounds: Dict[Union[AttributeTemplate, LinkByUID], BaseBounds], *, - forward: bool = True, per_page: int = 100) -> Iterator[DataObject]: + self, + attribute_bounds: Dict[Union[AttributeTemplate, LinkByUID], BaseBounds], + *, + forward: bool = True, + per_page: int = 100, + ) -> Iterator[DataObject]: """ Get all objects in the collection with attributes within certain bounds. 
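The bounds search that follows takes a one-entry mapping from attribute template to bounds; a sketch, assuming `temp_template` is a previously registered ConditionTemplate:

from gemd.entity.bounds import RealBounds

# Exactly one template -> bounds pair is accepted, per the validation in
# _get_attribute_bounds_search_body below.
hits = dataset.measurement_runs.list_by_attribute_bounds(
    {temp_template: RealBounds(1200, 1400, "degC")}
)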
@@ -80,7 +90,7 @@ def list_by_attribute_bounds( body = self._get_attribute_bounds_search_body(attribute_bounds) params = {} if self.dataset_id is not None: - params['dataset_id'] = str(self.dataset_id) + params["dataset_id"] = str(self.dataset_id) raw_objects = self.session.cursor_paged_resource( self.session.post_resource, # "Ignoring" dataset because it is in the query params (and required) @@ -88,30 +98,36 @@ def list_by_attribute_bounds( json=body, forward=forward, per_page=per_page, - params=params) + params=params, + ) return (self.build(raw) for raw in raw_objects) @staticmethod def _get_attribute_bounds_search_body(attribute_bounds): if not isinstance(attribute_bounds, dict): - raise TypeError('attribute_bounds must be a dict mapping template to bounds; ' - 'got {}'.format(attribute_bounds)) + raise TypeError( + "attribute_bounds must be a dict mapping template to bounds; " + "got {}".format(attribute_bounds) + ) if len(attribute_bounds) != 1: - raise NotImplementedError('Currently, only searches with exactly one template ' - 'to bounds mapping are supported; got {}' - .format(attribute_bounds)) + raise NotImplementedError( + "Currently, only searches with exactly one template " + "to bounds mapping are supported; got {}".format(attribute_bounds) + ) return { - 'attribute_bounds': { + "attribute_bounds": { get_object_id(templ): bounds.as_dict() for templ, bounds in attribute_bounds.items() } } - def validate_templates(self, *, - model: DataObjectResourceType, - object_template: Optional[ObjectTemplateResourceType] = None, - ingredient_process_template: Optional[ProcessTemplate] = None)\ - -> List[ValidationError]: + def validate_templates( + self, + *, + model: DataObjectResourceType, + object_template: Optional[ObjectTemplateResourceType] = None, + ingredient_process_template: Optional[ProcessTemplate] = None, + ) -> List[ValidationError]: """ Validate a data object against its templates. 
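A call-site sketch for `validate_templates` (the run and template objects are assumed; all arguments are keyword-only):

# Returns a list of ValidationError objects rather than raising, so an empty
# list means the object conforms to its templates.
errors = dataset.measurement_runs.validate_templates(
    model=run, object_template=measurement_template
)
for error in errors:
    print(error)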
@@ -129,15 +145,19 @@ def validate_templates(self, *, temp_scope = str(uuid4()) GEMDJson(scope=temp_scope).dumps(model) # This apparent no-op populates uids dumped_data = replace_objects_with_links(scrub_none(model.dump())) - recursive_foreach(model, lambda x: x.uids.pop(temp_scope, None)) # Strip temp uids + recursive_foreach( + model, lambda x: x.uids.pop(temp_scope, None) + ) # Strip temp uids request_data = {"dataObject": dumped_data} if object_template is not None: - request_data["objectTemplate"] = \ - replace_objects_with_links(scrub_none(object_template.dump())) + request_data["objectTemplate"] = replace_objects_with_links( + scrub_none(object_template.dump()) + ) if ingredient_process_template is not None: - request_data["ingredientProcessTemplate"] = \ - replace_objects_with_links(scrub_none(ingredient_process_template.dump())) + request_data["ingredientProcessTemplate"] = replace_objects_with_links( + scrub_none(ingredient_process_template.dump()) + ) try: self.session.put_resource(path, request_data) return [] diff --git a/src/citrine/resources/data_version_update.py b/src/citrine/resources/data_version_update.py index 2f3876b92..802d39837 100644 --- a/src/citrine/resources/data_version_update.py +++ b/src/citrine/resources/data_version_update.py @@ -1,4 +1,5 @@ """Record to hold branch data version update information.""" + from typing import List from citrine._rest.resource import PredictorRef, Resource @@ -6,37 +7,33 @@ from citrine._serialization.serializable import Serializable -class DataVersionUpdate(Serializable['DataVersionUpdate']): +class DataVersionUpdate(Serializable["DataVersionUpdate"]): """Container for data updates.""" - current = properties.String('current') - latest = properties.String('latest') + current = properties.String("current") + latest = properties.String("latest") - def __init__(self, - *, - current: str, - latest: str): + def __init__(self, *, current: str, latest: str): self.current = current self.latest = latest - typ = properties.String('type', default='DataVersionUpdate') + typ = properties.String("type", default="DataVersionUpdate") -class BranchDataUpdate(Resource['BranchDataUpdate']): +class BranchDataUpdate(Resource["BranchDataUpdate"]): """Branch data updates with predictors using the versions indicated.""" data_updates = properties.List(properties.Object(DataVersionUpdate), "data_updates") predictors = properties.List(properties.Object(PredictorRef), "predictors") - def __init__(self, - *, - data_updates: List[DataVersionUpdate], - predictors: List[PredictorRef]): + def __init__( + self, *, data_updates: List[DataVersionUpdate], predictors: List[PredictorRef] + ): self.data_updates = data_updates self.predictors = predictors -class NextBranchVersionRequest(Resource['NextBranchVersionRequest']): +class NextBranchVersionRequest(Resource["NextBranchVersionRequest"]): """ Instructions for how the next version of a branch should handle its predictors. 
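`NextBranchVersionRequest` below is the payload that `update_data` assembled earlier; composing it by hand looks roughly like this, reusing the assumed `project` and `root_id`:

from citrine.resources.data_version_update import NextBranchVersionRequest

pending = project.branches.data_updates(root_id=root_id)
request = NextBranchVersionRequest(
    data_updates=pending.data_updates,
    use_predictors=pending.predictors,  # reuse the recommended predictor refs
)
project.branches.next_version(root_id, branch_instructions=request)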
@@ -49,9 +46,11 @@ class NextBranchVersionRequest(Resource['NextBranchVersionRequest']): data_updates = properties.List(properties.Object(DataVersionUpdate), "data_updates") use_predictors = properties.List(properties.Object(PredictorRef), "use_predictors") - def __init__(self, - *, - data_updates: List[DataVersionUpdate], - use_predictors: List[PredictorRef]): + def __init__( + self, + *, + data_updates: List[DataVersionUpdate], + use_predictors: List[PredictorRef], + ): self.data_updates = data_updates self.use_predictors = use_predictors diff --git a/src/citrine/resources/dataset.py b/src/citrine/resources/dataset.py index 9d4100b69..7d2e2a0af 100644 --- a/src/citrine/resources/dataset.py +++ b/src/citrine/resources/dataset.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of datasets.""" + from typing import List, Optional, Union, Tuple, Iterator, Iterable from uuid import UUID @@ -34,7 +35,7 @@ from citrine.resources.property_template import PropertyTemplateCollection -class Dataset(Resource['Dataset']): +class Dataset(Resource["Dataset"]): """ A collection of data objects. @@ -55,42 +56,51 @@ class Dataset(Resource['Dataset']): """ - _response_key = 'dataset' + _response_key = "dataset" _resource_type = ResourceTypeEnum.DATASET - uid = properties.Optional(properties.UUID(), 'id') + uid = properties.Optional(properties.UUID(), "id") """UUID: Unique uuid4 identifier of this dataset.""" - name = properties.String('name') - unique_name = properties.Optional(properties.String(), 'unique_name') - summary = properties.Optional(properties.String, 'summary') - description = properties.Optional(properties.String, 'description') - deleted = properties.Optional(properties.Boolean(), 'deleted') + name = properties.String("name") + unique_name = properties.Optional(properties.String(), "unique_name") + summary = properties.Optional(properties.String, "summary") + description = properties.Optional(properties.String, "description") + deleted = properties.Optional(properties.Boolean(), "deleted") """bool: Flag indicating whether or not this dataset has been deleted.""" - created_by = properties.Optional(properties.UUID(), 'created_by') + created_by = properties.Optional(properties.UUID(), "created_by") """UUID: ID of the user who created the dataset.""" - updated_by = properties.Optional(properties.UUID(), 'updated_by') + updated_by = properties.Optional(properties.UUID(), "updated_by") """UUID: ID of the user who last updated the dataset.""" - deleted_by = properties.Optional(properties.UUID(), 'deleted_by') + deleted_by = properties.Optional(properties.UUID(), "deleted_by") """UUID: ID of the user who deleted the dataset, if it is deleted.""" - create_time = properties.Optional(properties.Datetime(), 'create_time') + create_time = properties.Optional(properties.Datetime(), "create_time") """int: Time the dataset was created, in seconds since epoch.""" - update_time = properties.Optional(properties.Datetime(), 'update_time') + update_time = properties.Optional(properties.Datetime(), "update_time") """int: Time the dataset was most recently updated, in seconds since epoch.""" - delete_time = properties.Optional(properties.Datetime(), 'delete_time') + delete_time = properties.Optional(properties.Datetime(), "delete_time") """int: Time the dataset was deleted, in seconds since epoch, if it is deleted.""" - public = properties.Optional(properties.Boolean(), 'public') + public = properties.Optional(properties.Boolean(), "public") """bool: Flag indicating whether the dataset 
is publicly readable.""" - project_id = properties.Optional(properties.UUID(), 'project_id', - serializable=False, deserializable=False) + project_id = properties.Optional( + properties.UUID(), "project_id", serializable=False, deserializable=False + ) """project_id will be needed here until deprecation is complete. This class property will be removed post deprecation""" - team_id = properties.Optional(properties.UUID(), 'team_id', - serializable=False, deserializable=False) - session = properties.Optional(properties.Object(Session), 'session', - serializable=False, deserializable=False) - - def __init__(self, name: str, *, summary: Optional[str] = None, - description: Optional[str] = None, unique_name: Optional[str] = None): + team_id = properties.Optional( + properties.UUID(), "team_id", serializable=False, deserializable=False + ) + session = properties.Optional( + properties.Object(Session), "session", serializable=False, deserializable=False + ) + + def __init__( + self, + name: str, + *, + summary: Optional[str] = None, + description: Optional[str] = None, + unique_name: Optional[str] = None, + ): self.name: str = name self.summary: Optional[str] = summary self.description: Optional[str] = description @@ -112,120 +122,190 @@ def __init__(self, name: str, *, summary: Optional[str] = None, self.session = None def __str__(self): - return '<Dataset {!r}>'.format(self.name) + return "<Dataset {!r}>".format(self.name) @property def property_templates(self) -> PropertyTemplateCollection: """Return a resource representing all property templates in this dataset.""" - return PropertyTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return PropertyTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def condition_templates(self) -> ConditionTemplateCollection: """Return a resource representing all condition templates in this dataset.""" - return ConditionTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return ConditionTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def parameter_templates(self) -> ParameterTemplateCollection: """Return a resource representing all parameter templates in this dataset.""" - return ParameterTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return ParameterTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def material_templates(self) -> MaterialTemplateCollection: """Return a resource representing all material templates in this dataset.""" - return MaterialTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MaterialTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def measurement_templates(self) -> MeasurementTemplateCollection: """Return a resource representing all measurement templates in this dataset.""" - return MeasurementTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MeasurementTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def
process_templates(self) -> ProcessTemplateCollection: """Return a resource representing all process templates in this dataset.""" - return ProcessTemplateCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return ProcessTemplateCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def process_runs(self) -> ProcessRunCollection: """Return a resource representing all process runs in this dataset.""" - return ProcessRunCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return ProcessRunCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def measurement_runs(self) -> MeasurementRunCollection: """Return a resource representing all measurement runs in this dataset.""" - return MeasurementRunCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MeasurementRunCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def material_runs(self) -> MaterialRunCollection: """Return a resource representing all material runs in this dataset.""" - return MaterialRunCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MaterialRunCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def ingredient_runs(self) -> IngredientRunCollection: """Return a resource representing all ingredient runs in this dataset.""" - return IngredientRunCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return IngredientRunCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def process_specs(self) -> ProcessSpecCollection: """Return a resource representing all process specs in this dataset.""" - return ProcessSpecCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return ProcessSpecCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def measurement_specs(self) -> MeasurementSpecCollection: """Return a resource representing all measurement specs in this dataset.""" - return MeasurementSpecCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MeasurementSpecCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def material_specs(self) -> MaterialSpecCollection: """Return a resource representing all material specs in this dataset.""" - return MaterialSpecCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return MaterialSpecCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def ingredient_specs(self) -> IngredientSpecCollection: """Return a resource representing all ingredient specs in this dataset.""" - return IngredientSpecCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return IngredientSpecCollection( + team_id=self.team_id, + 
dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def gemd(self) -> GEMDResourceCollection: """Return a resource representing all GEMD objects/templates in this dataset.""" - return GEMDResourceCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return GEMDResourceCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def files(self) -> FileCollection: """Return a resource representing all files in the dataset.""" - return FileCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return FileCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) @property def ingestions(self) -> IngestionCollection: """Return a resource representing all ingestions in the dataset.""" - return IngestionCollection(team_id=self.team_id, dataset_id=self.uid, - session=self.session, project_id=self.project_id) + return IngestionCollection( + team_id=self.team_id, + dataset_id=self.uid, + session=self.session, + project_id=self.project_id, + ) def register(self, model: DataConcepts, *, dry_run=False) -> DataConcepts: """Register a data model object to the appropriate collection.""" return self.gemd._collection_for(model).register(model, dry_run=dry_run) - def register_all(self, - models: Iterable[DataConcepts], - *, - dry_run: bool = False, - status_bar: bool = False, - include_nested: bool = False) -> List[DataConcepts]: + def register_all( + self, + models: Iterable[DataConcepts], + *, + dry_run: bool = False, + status_bar: bool = False, + include_nested: bool = False, + ) -> List[DataConcepts]: """ Register multiple GEMD objects to each of their appropriate collections. @@ -266,7 +346,7 @@ def register_all(self, models, dry_run=dry_run, status_bar=status_bar, - include_nested=include_nested + include_nested=include_nested, ) def update(self, model: DataConcepts) -> DataConcepts: @@ -293,12 +373,12 @@ def delete(self, uid: Union[UUID, str, LinkByUID, DataConcepts], *, dry_run=Fals return collection.delete(uid=uid, dry_run=dry_run) def delete_contents( - self, - *, - prompt_to_confirm: bool = True, - remove_templates: bool = True, - timeout: float = 2 * 60, - polling_delay: float = 1.0 + self, + *, + prompt_to_confirm: bool = True, + remove_templates: bool = True, + timeout: float = 2 * 60, + polling_delay: float = 1.0, ): """ Delete all the GEMD objects from within a single Dataset. @@ -327,17 +407,21 @@ def delete_contents( deleted.
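For scripted use of `delete_contents`, the confirmation prompt shown below can be bypassed; a sketch, assuming a `dataset` handle:

# Non-interactive invocation: skip the Y/N prompt and keep templates.
dataset.delete_contents(
    prompt_to_confirm=False,
    remove_templates=False,
    timeout=2 * 60,
    polling_delay=1.0,
)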
""" - path = format_escaped_url('teams/{team_id}/datasets/{dataset_uid}/contents', - dataset_uid=self.uid, - team_id=self.team_id) + path = format_escaped_url( + "teams/{team_id}/datasets/{dataset_uid}/contents", + dataset_uid=self.uid, + team_id=self.team_id, + ) while prompt_to_confirm: - print(f"Confirm you want to delete the contents of " - f"Dataset {self.name} {self.uid} [Y/N]") + print( + f"Confirm you want to delete the contents of " + f"Dataset {self.name} {self.uid} [Y/N]" + ) user_response = input() - if user_response.lower() in {'y', 'yes'}: + if user_response.lower() in {"y", "yes"}: break # return to main flow - elif user_response.lower() in {'n', 'no'}: + elif user_response.lower() in {"n", "no"}: raise RuntimeError("delete_contents was invoked unintentionally") else: print(f'"{user_response}" is not a valid response') @@ -346,18 +430,20 @@ def delete_contents( response = self.session.delete_resource(path, params=params) job_id = response["job_id"] - return _poll_for_async_batch_delete_result(team_id=self.team_id, - session=self.session, - job_id=job_id, - timeout=timeout, - polling_delay=polling_delay) + return _poll_for_async_batch_delete_result( + team_id=self.team_id, + session=self.session, + job_id=job_id, + timeout=timeout, + polling_delay=polling_delay, + ) def gemd_batch_delete( - self, - id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], - *, - timeout: float = 2 * 60, - polling_delay: float = 1.0 + self, + id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], + *, + timeout: float = 2 * 60, + polling_delay: float = 1.0, ) -> List[Tuple[LinkByUID, ApiError]]: """ Remove a set of GEMD objects. @@ -401,9 +487,9 @@ def gemd_batch_delete( deleted. """ - return self.gemd.batch_delete(id_list=id_list, - timeout=timeout, - polling_delay=polling_delay) + return self.gemd.batch_delete( + id_list=id_list, timeout=timeout, polling_delay=polling_delay + ) class DatasetCollection(Collection[Dataset]): @@ -423,11 +509,13 @@ class DatasetCollection(Collection[Dataset]): _collection_key = None _resource = Dataset - def __init__(self, - *args, - session: Session = None, - team_id: UUID = None, - project_id: Optional[UUID] = None): + def __init__( + self, + *args, + session: Session = None, + team_id: UUID = None, + project_id: Optional[UUID] = None, + ): # Handle positional arguments for backward compatibility args = _pad_positional_args(args, 2) self.project_id = project_id or args[0] @@ -439,16 +527,17 @@ def __init__(self, session=self.session, project_id=self.project_id, team_id=team_id, - obj_type="Datasets") + obj_type="Datasets", + ) # After the Data Manager deprecation # this can be a Class Variable using the `teams/...` endpoint @property def _path_template(self): if self.project_id is None: - return f'teams/{self.team_id}/datasets' + return f"teams/{self.team_id}/datasets" else: - return f'projects/{self.project_id}/datasets' + return f"projects/{self.project_id}/datasets" def build(self, data: dict) -> Dataset: """ @@ -504,7 +593,6 @@ def register(self, model: Dataset) -> Dataset: # Leverage the create-or-update endpoint if we've got a unique name data = self.session.put_resource(path, scrub_none(dumped_dataset)) else: - if model.uid is None: # POST to create a new one if a UID is not assigned data = self.session.post_resource(path, scrub_none(dumped_dataset)) @@ -512,7 +600,8 @@ def register(self, model: Dataset) -> Dataset: else: # Otherwise PUT to update it data = self.session.put_resource( - self._get_path(model.uid), scrub_none(dumped_dataset)) + 
self._get_path(model.uid), scrub_none(dumped_dataset) + ) full_model = self.build(data) full_model.team_id = self.team_id @@ -550,6 +639,8 @@ def get_by_unique_name(self, unique_name: str) -> Dataset: if len(data) == 1: return self.build(data[0]) elif len(data) > 1: - raise RuntimeError("Received multiple results when requesting a unique dataset") + raise RuntimeError( + "Received multiple results when requesting a unique dataset" + ) else: raise NotFound(path) diff --git a/src/citrine/resources/delete.py b/src/citrine/resources/delete.py index dddb7b8d2..f9df688e8 100644 --- a/src/citrine/resources/delete.py +++ b/src/citrine/resources/delete.py @@ -13,12 +13,12 @@ def _async_gemd_batch_delete( - id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], - team_id: UUID, - session: Session, - dataset_id: Optional[UUID] = None, - timeout: float = 2 * 60, - polling_delay: float = 1.0 + id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], + team_id: UUID, + session: Session, + dataset_id: Optional[UUID] = None, + timeout: float = 2 * 60, + polling_delay: float = 1.0, ) -> List[Tuple[LinkByUID, ApiError]]: """ Shared implementation of Async GEMD Batch deletion. @@ -65,15 +65,16 @@ def _async_gemd_batch_delete( scoped_uids = [] for uid in id_list: # And now normalize to id/scope pairs link_by_uid = _make_link_by_uid(uid) - scoped_uids.append({'scope': link_by_uid.scope, 'id': link_by_uid.id}) + scoped_uids.append({"scope": link_by_uid.scope, "id": link_by_uid.id}) - body = {'ids': scoped_uids} + body = {"ids": scoped_uids} if dataset_id is not None: - body.update({'dataset_id': str(dataset_id)}) + body.update({"dataset_id": str(dataset_id)}) if team_id is not None: - path = format_escaped_url('/teams/{team_id}/gemd/async-batch-delete', - team_id=team_id) + path = format_escaped_url( + "/teams/{team_id}/gemd/async-batch-delete", team_id=team_id + ) else: raise TypeError("Missing one required argument: team_id") response = session.post_resource(path, body) @@ -85,15 +86,12 @@ def _async_gemd_batch_delete( session=session, job_id=job_id, timeout=timeout, - polling_delay=polling_delay) + polling_delay=polling_delay, + ) def _poll_for_async_batch_delete_result( - team_id: UUID, - session: Session, - job_id: str, - timeout: float, - polling_delay: float + team_id: UUID, session: Session, job_id: str, timeout: float, polling_delay: float ) -> List[Tuple[LinkByUID, ApiError]]: """ Poll for the result of an asynchronous batch delete (or a deletion of dataset contents). 
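The batch-delete helpers report failures instead of raising; a usage sketch with a placeholder ID:

# Failures come back as (LinkByUID, ApiError) pairs for partial-failure
# reporting.
failures = dataset.gemd_batch_delete(["00000000-0000-0000-0000-000000000000"])
for link, api_error in failures:
    print(link.scope, link.id, api_error.message)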
@@ -130,7 +128,10 @@ def _poll_for_async_batch_delete_result( team_id=team_id, job=job_id, timeout=timeout, - polling_delay=polling_delay) + polling_delay=polling_delay, + ) - return [(LinkByUID(f['id']['scope'], f['id']['id']), ApiError.build(f['cause'])) - for f in json.loads(response.output.get('failures', '[]'))] + return [ + (LinkByUID(f["id"]["scope"], f["id"]["id"]), ApiError.build(f["cause"])) + for f in json.loads(response.output.get("failures", "[]")) + ] diff --git a/src/citrine/resources/descriptors.py b/src/citrine/resources/descriptors.py index d34199ca1..41f1abc65 100644 --- a/src/citrine/resources/descriptors.py +++ b/src/citrine/resources/descriptors.py @@ -16,8 +16,12 @@ def __init__(self, project_id: UUID, session: Session): self.project_id = project_id self.session: Session = session - def from_predictor_responses(self, *, predictor: Union[GraphPredictor, PredictorNode], - inputs: List[Descriptor]) -> List[Descriptor]: + def from_predictor_responses( + self, + *, + predictor: Union[GraphPredictor, PredictorNode], + inputs: List[Descriptor], + ) -> List[Descriptor]: """ Get responses for a predictor, given an input space. @@ -42,14 +46,12 @@ def from_predictor_responses(self, *, predictor: Union[GraphPredictor, Predictor predictor_data = predictor.dump() response = self.session.post_resource( - path=format_escaped_url('/projects/{}/material-descriptors/predictor-responses', - self.project_id), - json={ - 'predictor': predictor_data, - 'inputs': [i.dump() for i in inputs] - } + path=format_escaped_url( + "/projects/{}/material-descriptors/predictor-responses", self.project_id + ), + json={"predictor": predictor_data, "inputs": [i.dump() for i in inputs]}, ) - return [Descriptor.build(r) for r in response['responses']] + return [Descriptor.build(r) for r in response["responses"]] def from_data_source(self, *, data_source: DataSource) -> List[Descriptor]: """ @@ -67,10 +69,9 @@ def from_data_source(self, *, data_source: DataSource) -> List[Descriptor]: """ response = self.session.post_resource( - path=format_escaped_url('/projects/{}/material-descriptors/from-data-source', - self.project_id), - json={ - 'data_source': data_source.dump() - } + path=format_escaped_url( + "/projects/{}/material-descriptors/from-data-source", self.project_id + ), + json={"data_source": data_source.dump()}, ) - return [Descriptor.build(r) for r in response['descriptors']] + return [Descriptor.build(r) for r in response["descriptors"]] diff --git a/src/citrine/resources/design_execution.py b/src/citrine/resources/design_execution.py index 7bf87fa77..433ea3eab 100644 --- a/src/citrine/resources/design_execution.py +++ b/src/citrine/resources/design_execution.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of design workflow executions.""" + from typing import Optional, Union, Iterator from uuid import UUID @@ -12,15 +13,14 @@ class DesignExecutionCollection(Collection["DesignExecution"]): """A collection of DesignExecutions.""" - _path_template = '/projects/{project_id}/design-workflows/{workflow_id}/executions' # noqa + _path_template = "/projects/{project_id}/design-workflows/{workflow_id}/executions" # noqa _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = executions.DesignExecution - def __init__(self, - project_id: UUID, - session: Session, - workflow_id: Optional[UUID] = None): + def __init__( + self, project_id: UUID, session: Session, workflow_id: Optional[UUID] = None + ): self.project_id: UUID = project_id 
self.session: Session = session self.workflow_id: UUID = workflow_id @@ -35,7 +35,7 @@ def build(self, data: dict) -> executions.DesignExecution: def trigger(self, execution_input: Score, *, max_candidates: Optional[int] = None): """Trigger a Design Workflow execution given a score and a maximum number of candidates.""" path = self._get_path() - json = {'score': execution_input.dump(), "max_candidates": max_candidates} + json = {"score": execution_input.dump(), "max_candidates": max_candidates} data = self.session.post_resource(path, json) return self.build(data) @@ -56,8 +56,7 @@ def archive(self, uid: Union[UUID, str]): Unique identifier of the execution to archive """ - raise NotImplementedError( - "Design Executions cannot be archived") + raise NotImplementedError("Design Executions cannot be archived") def restore(self, uid: UUID): """Restore an archived Design Workflow execution. @@ -68,8 +67,7 @@ def restore(self, uid: UUID): Unique identifier of the execution to restore """ - raise NotImplementedError( - "Design Executions cannot be restored") + raise NotImplementedError("Design Executions cannot be restored") def list(self, *, per_page: int = 100) -> Iterator[executions.DesignExecution]: """ @@ -91,11 +89,12 @@ def list(self, *, per_page: int = 100) -> Iterator[executions.DesignExecution]: Resources in this collection. """ - return self._paginator.paginate(page_fetcher=self._fetch_page, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=self._fetch_page, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def delete(self, uid: Union[UUID, str]) -> Response: """Design Workflow Executions cannot be deleted or archived.""" - raise NotImplementedError( - "Design Executions cannot be deleted") + raise NotImplementedError("Design Executions cannot be deleted") diff --git a/src/citrine/resources/design_space.py b/src/citrine/resources/design_space.py index fb1aa5f69..b849a2937 100644 --- a/src/citrine/resources/design_space.py +++ b/src/citrine/resources/design_space.py @@ -1,4 +1,5 @@ """Resources that represent collections of design spaces.""" + import warnings from functools import partial from typing import Iterable, Iterator, Optional, TypeVar, Union @@ -6,13 +7,19 @@ from citrine._utils.functions import format_escaped_url -from citrine.informatics.design_spaces import DataSourceDesignSpace, DefaultDesignSpaceMode, \ - DesignSpace, DesignSpaceSettings, EnumeratedDesignSpace, FormulationDesignSpace, \ - HierarchicalDesignSpace +from citrine.informatics.design_spaces import ( + DataSourceDesignSpace, + DefaultDesignSpaceMode, + DesignSpace, + DesignSpaceSettings, + EnumeratedDesignSpace, + FormulationDesignSpace, + HierarchicalDesignSpace, +) from citrine._rest.collection import Collection from citrine._session import Session -CreationType = TypeVar('CreationType', bound=DesignSpace) +CreationType = TypeVar("CreationType", bound=DesignSpace) class DesignSpaceCollection(Collection[DesignSpace]): @@ -25,11 +32,11 @@ class DesignSpaceCollection(Collection[DesignSpace]): """ - _api_version = 'v3' - _path_template = '/projects/{project_id}/design-spaces' + _api_version = "v3" + _path_template = "/projects/{project_id}/design-spaces" _individual_key = None _resource = DesignSpace - _collection_key = 'response' + _collection_key = "response" _enumerated_cell_limit = 128 * 2000 def __init__(self, project_id: UUID, session: Session): @@ -55,24 +62,32 @@ def _verify_write_request(self, 
design_space: DesignSpace): warnings as appropriate. """ if isinstance(design_space, EnumeratedDesignSpace): - warnings.warn("As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " - "ProductDesignSpace containing a DataSourceDesignSpace subspace. " - "Support for EnumeratedDesignSpace will be dropped in 4.0.", - DeprecationWarning) + warnings.warn( + "As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " + "ProductDesignSpace containing a DataSourceDesignSpace subspace. " + "Support for EnumeratedDesignSpace will be dropped in 4.0.", + DeprecationWarning, + ) width = len(design_space.descriptors) length = len(design_space.data) if width * length > self._enumerated_cell_limit: - msg = "EnumeratedDesignSpace only supports up to {} descriptor-values, " \ - "but {} were given. Please reduce the number of descriptors or candidates " \ - "in this EnumeratedDesignSpace" - raise ValueError(msg.format(self._enumerated_cell_limit, width * length)) + msg = ( + "EnumeratedDesignSpace only supports up to {} descriptor-values, " + "but {} were given. Please reduce the number of descriptors or candidates " + "in this EnumeratedDesignSpace" + ) + raise ValueError( + msg.format(self._enumerated_cell_limit, width * length) + ) elif isinstance(design_space, (DataSourceDesignSpace, FormulationDesignSpace)): typ = type(design_space).__name__ - warnings.warn(f"As of 3.27.0, saving a top-level {typ} is deprecated. Support " - "will be removed in 4.0. Wrap it in a ProductDesignSpace instead: " - f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", - DeprecationWarning) + warnings.warn( + f"As of 3.27.0, saving a top-level {typ} is deprecated. Support " + "will be removed in 4.0. Wrap it in a ProductDesignSpace instead: " + f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", + DeprecationWarning, + ) def _verify_read_request(self, design_space: DesignSpace): """Perform read-time validations of the design space. @@ -81,17 +96,21 @@ def _verify_read_request(self, design_space: DesignSpace): appropriate. """ if isinstance(design_space, EnumeratedDesignSpace): - warnings.warn("As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " - "ProductDesignSpace containing a DataSourceDesignSpace subspace. " - "Support for EnumeratedDesignSpace will be dropped in 4.0.", - DeprecationWarning) + warnings.warn( + "As of 3.27.0, EnumeratedDesignSpace is deprecated in favor of a " + "ProductDesignSpace containing a DataSourceDesignSpace subspace. " + "Support for EnumeratedDesignSpace will be dropped in 4.0.", + DeprecationWarning, + ) elif isinstance(design_space, (DataSourceDesignSpace, FormulationDesignSpace)): typ = type(design_space).__name__ - warnings.warn(f"As of 3.27.0, top-level {typ}s are deprecated. Any that remain when " - "SDK 4.0 are released will be wrapped in a ProductDesignSpace. You " - "can wrap it yourself to get rid of this warning now: " - f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", - DeprecationWarning) + warnings.warn( + f"As of 3.27.0, top-level {typ}s are deprecated. Any that remain when " + "SDK 4.0 are released will be wrapped in a ProductDesignSpace. 
You " + "can wrap it yourself to get rid of this warning now: " + f"ProductDesignSpace('name', 'description', subspaces=[{typ}(...)])", + DeprecationWarning, + ) def register(self, design_space: DesignSpace) -> DesignSpace: """Create a new design space.""" @@ -157,7 +176,9 @@ def get(self, uid: Union[UUID, str]) -> DesignSpace: self._verify_read_request(design_space) return design_space - def _build_collection_elements(self, collection: Iterable[dict]) -> Iterator[DesignSpace]: + def _build_collection_elements( + self, collection: Iterable[dict] + ) -> Iterator[DesignSpace]: """ For each element in the collection, build the appropriate resource type. @@ -182,9 +203,11 @@ def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None): filters["archived"] = archived fetcher = partial(self._fetch_page, additional_params=filters, version="v4") - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]: """List all design spaces.""" @@ -198,15 +221,17 @@ def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]: """List archived design spaces.""" return self._list_base(per_page=per_page, archived=True) - def create_default(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Optional[Union[int, str]] = None, - mode: DefaultDesignSpaceMode = DefaultDesignSpaceMode.ATTRIBUTE, - include_ingredient_fraction_constraints: bool = False, - include_label_fraction_constraints: bool = False, - include_label_count_constraints: bool = False, - include_parameter_constraints: bool = False) -> DesignSpace: + def create_default( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, + mode: DefaultDesignSpaceMode = DefaultDesignSpaceMode.ATTRIBUTE, + include_ingredient_fraction_constraints: bool = False, + include_label_fraction_constraints: bool = False, + include_label_count_constraints: bool = False, + include_parameter_constraints: bool = False, + ) -> DesignSpace: """Create a default design space for a predictor. 
This method will return an unregistered design space for all inputs @@ -254,7 +279,7 @@ def create_default(self, Default design space """ - path = f'projects/{self.project_id}/design-spaces/default' + path = f"projects/{self.project_id}/design-spaces/default" settings = DesignSpaceSettings( predictor_id=predictor_id, predictor_version=predictor_version, @@ -262,20 +287,22 @@ def create_default(self, include_ingredient_fraction_constraints=include_ingredient_fraction_constraints, include_label_fraction_constraints=include_label_fraction_constraints, include_label_count_constraints=include_label_count_constraints, - include_parameter_constraints=include_parameter_constraints + include_parameter_constraints=include_parameter_constraints, ) - data = self.session.post_resource(path, json=settings.dump(), version=self._api_version) + data = self.session.post_resource( + path, json=settings.dump(), version=self._api_version + ) ds = self.build(DesignSpace.wrap_instance(data["instance"])) ds._settings = settings return ds def convert_to_hierarchical( - self, - uid: Union[UUID, str], - *, - predictor_id: Union[UUID, str], - predictor_version: Optional[Union[int, str]] = None + self, + uid: Union[UUID, str], + *, + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, ) -> HierarchicalDesignSpace: """Convert an existing ProductDesignSpace into an equivalent HierarchicalDesignSpace. @@ -303,7 +330,7 @@ def convert_to_hierarchical( path = format_escaped_url( "projects/{project_id}/design-spaces/{design_space_id}/convert-hierarchical", project_id=self.project_id, - design_space_id=uid + design_space_id=uid, ) payload = { "predictor_id": str(predictor_id), @@ -311,7 +338,9 @@ def convert_to_hierarchical( if predictor_version: payload["predictor_version"] = predictor_version data = self.session.post_resource(path, json=payload, version=self._api_version) - return HierarchicalDesignSpace.build(DesignSpace.wrap_instance(data["instance"])) + return HierarchicalDesignSpace.build( + DesignSpace.wrap_instance(data["instance"]) + ) def delete(self, uid: Union[UUID, str]): """Design Spaces cannot be deleted at this time.""" diff --git a/src/citrine/resources/design_workflow.py b/src/citrine/resources/design_workflow.py index 4fffa021a..4f65e3c64 100644 --- a/src/citrine/resources/design_workflow.py +++ b/src/citrine/resources/design_workflow.py @@ -12,18 +12,20 @@ class DesignWorkflowCollection(Collection[DesignWorkflow]): """A collection of DesignWorkflows.""" - _path_template = '/projects/{project_id}/design-workflows' + _path_template = "/projects/{project_id}/design-workflows" _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = DesignWorkflow _api_version = "v2" - def __init__(self, - project_id: UUID, - session: Session, - *, - branch_root_id: Optional[UUID] = None, - branch_version: Optional[int] = None): + def __init__( + self, + project_id: UUID, + session: Session, + *, + branch_root_id: Optional[UUID] = None, + branch_version: Optional[int] = None, + ): self.project_id: UUID = project_id self.session: Session = session @@ -52,9 +54,11 @@ def register(self, model: DesignWorkflow) -> DesignWorkflow: if self.branch_root_id is None or self.branch_version is None: # There are a number of contexts in which hitting design workflow endpoints without # a branch ID is valid, so only this particular usage is disallowed. - msg = ('A design workflow must be created with a branch. 
Please use ' - 'branch.design_workflows.register() instead of ' - 'project.design_workflows.register().') + msg = ( + "A design workflow must be created with a branch. Please use " + "branch.design_workflows.register() instead of " + "project.design_workflows.register()." + ) raise RuntimeError(msg) else: # branch_root_id and branch_version are in the body of design workflow endpoints, so @@ -107,20 +111,28 @@ def update(self, model: DesignWorkflow) -> DesignWorkflow: """ if self.branch_root_id is not None or self.branch_version is not None: - if self.branch_root_id != model.branch_root_id or \ - self.branch_version != model.branch_version: - raise ValueError('To move a design workflow to another branch, please use ' - 'Project.design_workflows.update') + if ( + self.branch_root_id != model.branch_root_id + or self.branch_version != model.branch_version + ): + raise ValueError( + "To move a design workflow to another branch, please use " + "Project.design_workflows.update" + ) if model.branch_root_id is None or model.branch_version is None: - raise ValueError('Cannot update a design workflow unless its branch_root_id and ' - 'branch_version are set.') + raise ValueError( + "Cannot update a design workflow unless its branch_root_id and " + "branch_version are set." + ) # If executions have already been done, warn about future behavior change executions = model.design_executions.list() if next(executions, None) is not None: - raise RuntimeError("Cannot update a design workflow after candidate generation, " - "please register a new design workflow instead") + raise RuntimeError( + "Cannot update a design workflow after candidate generation, " + "please register a new design workflow instead" + ) return super().update(model) @@ -151,29 +163,37 @@ def restore(self, uid: Union[UUID, str]): def delete(self, uid: Union[UUID, str]) -> Response: """Design Workflows cannot be deleted; they can be archived instead.""" raise NotImplementedError( - "Design Workflows cannot be deleted; they can be archived instead.") + "Design Workflows cannot be deleted; they can be archived instead." 
+ ) def list_archived(self, *, per_page: int = 500) -> Iterable[DesignWorkflow]: """List archived Design Workflows.""" - fetcher = partial(self._fetch_page, additional_params={"filter": "archived eq 'true'"}) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) - - def _fetch_page(self, - path: Optional[str] = None, - fetch_func: Optional[Callable[..., dict]] = None, - page: Optional[int] = None, - per_page: Optional[int] = None, - json_body: Optional[dict] = None, - additional_params: Optional[dict] = None, - ) -> Tuple[Iterable[dict], str]: + fetcher = partial( + self._fetch_page, additional_params={"filter": "archived eq 'true'"} + ) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) + + def _fetch_page( + self, + path: Optional[str] = None, + fetch_func: Optional[Callable[..., dict]] = None, + page: Optional[int] = None, + per_page: Optional[int] = None, + json_body: Optional[dict] = None, + additional_params: Optional[dict] = None, + ) -> Tuple[Iterable[dict], str]: params = additional_params or {} params["branch_root_id"] = self.branch_root_id params["branch_version"] = self.branch_version - return super()._fetch_page(path=path, - fetch_func=fetch_func, - page=page, - per_page=per_page, - json_body=json_body, - additional_params=params) + return super()._fetch_page( + path=path, + fetch_func=fetch_func, + page=page, + per_page=per_page, + json_body=json_body, + additional_params=params, + ) diff --git a/src/citrine/resources/experiment_datasource.py b/src/citrine/resources/experiment_datasource.py index ab4a9e412..a513c54a5 100644 --- a/src/citrine/resources/experiment_datasource.py +++ b/src/citrine/resources/experiment_datasource.py @@ -12,24 +12,27 @@ from citrine.informatics.experiment_values import ExperimentValue -class CandidateExperimentSnapshot(Serializable['CandidateExperimentSnapshot']): +class CandidateExperimentSnapshot(Serializable["CandidateExperimentSnapshot"]): """The contents of a candidate experiment within an experiment data source.""" - uid = properties.UUID('experiment_id', serializable=False) + uid = properties.UUID("experiment_id", serializable=False) """:UUID: unique Citrine id of this experiment""" - candidate_id = properties.UUID('candidate_id', serializable=False) + candidate_id = properties.UUID("candidate_id", serializable=False) """:UUID: unique Citrine id of the candidate associated with this experiment""" - workflow_id = properties.UUID('workflow_id', serializable=False) + workflow_id = properties.UUID("workflow_id", serializable=False) """:UUID: unique Citrine id of the design workflow which produced the associated candidate""" - name = properties.String('name', serializable=False) + name = properties.String("name", serializable=False) """:str: name of the experiment""" - description = properties.Optional(properties.String, 'description', serializable=False) + description = properties.Optional( + properties.String, "description", serializable=False + ) """:Optional[str]: description of the experiment""" - updated_time = properties.Datetime('updated_time', serializable=False) + updated_time = properties.Datetime("updated_time", serializable=False) """:datetime: date and time at which the experiment was updated""" - overrides = properties.Mapping(properties.String, properties.Object(ExperimentValue), - 'overrides') + overrides = properties.Mapping( + properties.String, 
properties.Object(ExperimentValue), "overrides" + ) """:dict[str, ExperimentValue]: dictionary of candidate material variable overrides""" def __init__(self, *args, **kwargs): @@ -37,25 +40,30 @@ def __init__(self, *args, **kwargs): pass # pragma: no cover def _overrides_json(self) -> Dict[str, str]: - return {name: json.dumps(expt_value.value) for name, expt_value in self.overrides.items()} + return { + name: json.dumps(expt_value.value) + for name, expt_value in self.overrides.items() + } -class ExperimentDataSource(Serializable['ExperimentDataSource']): +class ExperimentDataSource(Serializable["ExperimentDataSource"]): """An experiment data source.""" - uid = properties.UUID('id', serializable=False) + uid = properties.UUID("id", serializable=False) """:UUID: unique Citrine id of this experiment data source""" - experiments = properties.List(properties.Object(CandidateExperimentSnapshot), - 'data.experiments', - serializable=False) + experiments = properties.List( + properties.Object(CandidateExperimentSnapshot), + "data.experiments", + serializable=False, + ) """:list[CandidateExperimentSnapshot]: list of experiment data in this data source""" - branch_root_id = properties.UUID('metadata.branch_root_id', serializable=False) + branch_root_id = properties.UUID("metadata.branch_root_id", serializable=False) """:UUID: unique Citrine id of the branch root this data source is associated with""" - version = properties.Integer('metadata.version', serializable=False) + version = properties.Integer("metadata.version", serializable=False) """:int: version of this data source""" - created_by = properties.UUID('metadata.created.user', serializable=False) + created_by = properties.UUID("metadata.created.user", serializable=False) """:UUID: id of the user who created this data source""" - create_time = properties.Datetime('metadata.created.time', serializable=False) + create_time = properties.Datetime("metadata.created.time", serializable=False) """:datetime: date and time at which this data source was created""" def __init__(self, *args, **kwargs): @@ -88,10 +96,10 @@ def read(self) -> str: class ExperimentDataSourceCollection(Collection[ExperimentDataSource]): """Represents the collection of all experiment data sources associated with a project.""" - _path_template = 'projects/{project_id}/candidate-experiment-datasources' + _path_template = "projects/{project_id}/candidate-experiment-datasources" _individual_key = None _resource = ExperimentDataSource - _collection_key = 'response' + _collection_key = "response" def __init__(self, project_id: UUID, session: Session): self.project_id = project_id @@ -104,10 +112,13 @@ def build(self, data: dict) -> ExperimentDataSource: result._session = self.session return result - def list(self, *, - per_page: int = 100, - branch_version_id: Optional[Union[UUID, str]] = None, - version: Optional[Union[int, str]] = None) -> Iterator[ExperimentDataSource]: + def list( + self, + *, + per_page: int = 100, + branch_version_id: Optional[Union[UUID, str]] = None, + version: Optional[Union[int, str]] = None, + ) -> Iterator[ExperimentDataSource]: """Paginate over the experiment data sources. 
Parameters @@ -134,9 +145,11 @@ def list(self, *, params["version"] = version fetcher = partial(self._fetch_page, additional_params=params) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def read(self, datasource: Union[ExperimentDataSource, UUID, str]): """Reads the provided experiment data source into a CSV. diff --git a/src/citrine/resources/file_link.py b/src/citrine/resources/file_link.py index e4e558d99..6d6fdf8d1 100644 --- a/src/citrine/resources/file_link.py +++ b/src/citrine/resources/file_link.py @@ -1,4 +1,5 @@ """A collection of FileLink objects.""" + import mimetypes import os from pathlib import Path @@ -14,8 +15,12 @@ from citrine._serialization import properties from citrine._serialization.serializable import Serializable from citrine._session import Session -from citrine._utils.functions import _data_manager_deprecation_checks, _pad_positional_args, \ - rewrite_s3_links_locally, write_file_locally +from citrine._utils.functions import ( + _data_manager_deprecation_checks, + _pad_positional_args, + rewrite_s3_links_locally, + write_file_locally, +) from citrine.resources.response import Response from gemd.entity.dict_serializable import DictSerializableMeta from gemd.entity.bounds.base_bounds import BaseBounds @@ -51,31 +56,32 @@ class _Uploader: """Holds the many parameters that are generated and used during file upload.""" def __init__(self): - self.bucket = '' - self.object_key = '' - self.upload_id = '' - self.region_name = '' - self.aws_access_key_id = '' - self.aws_secret_access_key = '' - self.aws_session_token = '' - self.s3_version = '' + self.bucket = "" + self.object_key = "" + self.upload_id = "" + self.region_name = "" + self.aws_access_key_id = "" + self.aws_secret_access_key = "" + self.aws_session_token = "" + self.s3_version = "" self.s3_endpoint_url = None self.s3_use_ssl = True - self.s3_addressing_style = 'auto' + self.s3_addressing_style = "auto" class CsvColumnInfo(Serializable): """The info for a CSV Column, contains the name, recommended and exact bounds.""" - name = properties.String('name') + name = properties.String("name") """:str: name of the column""" - bounds = properties.Object(BaseBounds, 'bounds') + bounds = properties.Object(BaseBounds, "bounds") """:BaseBounds: recommended bounds of the column (might include some padding)""" - exact_range_bounds = properties.Object(BaseBounds, 'exact_range_bounds') + exact_range_bounds = properties.Object(BaseBounds, "exact_range_bounds") """:BaseBounds: exact bounds of the column""" - def __init__(self, name: str, bounds: BaseBounds, - exact_range_bounds: BaseBounds): # pragma: no cover + def __init__( + self, name: str, bounds: BaseBounds, exact_range_bounds: BaseBounds + ): # pragma: no cover self.name = name self.bounds = bounds self.exact_range_bounds = exact_range_bounds @@ -86,7 +92,7 @@ class FileLinkMeta(DictSerializableMeta): def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) - cls.typ = properties.String('type', default="file_link", deserializable=False) + cls.typ = properties.String("type", default="file_link", deserializable=False) def _get_ids_from_url(url: str) -> Tuple[Optional[UUID], Optional[UUID]]: @@ -95,11 +101,15 @@ def _get_ids_from_url(url: str) -> Tuple[Optional[UUID], Optional[UUID]]: if len(parsed.query) > 0 or len(parsed.fragment) > 0: # 
Illegal modifiers return None, None - split_path = urlparse(url).path.split('/') - if len(split_path) >= 4 and split_path[-4] == 'files' and split_path[-2] == 'versions': + split_path = urlparse(url).path.split("/") + if ( + len(split_path) >= 4 + and split_path[-4] == "files" + and split_path[-2] == "versions" + ): file_id = split_path[-3] version_id = split_path[-1] - elif len(split_path) >= 2 and split_path[-2] == 'files': + elif len(split_path) >= 2 and split_path[-2] == "files": file_id = split_path[-1] version_id = None else: @@ -116,10 +126,7 @@ def _get_ids_from_url(url: str) -> Tuple[Optional[UUID], Optional[UUID]]: class FileLink( - GEMDResource['FileLink'], - GEMDFileLink, - metaclass=FileLinkMeta, - typ=GEMDFileLink.typ + GEMDResource["FileLink"], GEMDFileLink, metaclass=FileLinkMeta, typ=GEMDFileLink.typ ): """ Resource that stores the name and url of an external file. @@ -136,23 +143,29 @@ class FileLink( # NOTE: skipping the "metadata" field since it appears to be unused # NOTE: skipping the "versioned_url" field since it is redundant # NOTE: skipping the "unversioned_url" field since it is redundant - filename = properties.String('filename') - url = properties.String('url') - uid = properties.Optional(properties.UUID, 'id', serializable=False) + filename = properties.String("filename") + url = properties.String("url") + uid = properties.Optional(properties.UUID, "id", serializable=False) """UUID: Unique uuid4 identifier of this file; consistent across versions.""" - version = properties.Optional(properties.UUID, 'version', serializable=False) + version = properties.Optional(properties.UUID, "version", serializable=False) """UUID: Unique uuid4 identifier of this version of this file.""" - created_time = properties.Optional(properties.Datetime, 'created_time', serializable=False) + created_time = properties.Optional( + properties.Datetime, "created_time", serializable=False + ) """datetime: Time the file was created on platform.""" - created_by = properties.Optional(properties.UUID, 'created_by', serializable=False) + created_by = properties.Optional(properties.UUID, "created_by", serializable=False) """UUID: Unique uuid4 identifier of this User who loaded this file.""" - mime_type = properties.Optional(properties.String, 'mime_type', serializable=False) + mime_type = properties.Optional(properties.String, "mime_type", serializable=False) """str: Encoded string representing the type of the file (IETF RFC 2045).""" - size = properties.Optional(properties.Integer, 'size', serializable=False) + size = properties.Optional(properties.Integer, "size", serializable=False) """int: Size in bytes of the file.""" - description = properties.Optional(properties.String, 'description', serializable=False) + description = properties.Optional( + properties.String, "description", serializable=False + ) """str: A human-readable description of the file.""" - version_number = properties.Optional(properties.Integer, 'version_number', serializable=False) + version_number = properties.Optional( + properties.Integer, "version_number", serializable=False + ) """int: How many times this file has been uploaded; files are the "same" if they share a filename and dataset.""" @@ -179,23 +192,23 @@ def name(self): @classmethod def _pre_build(cls, data: dict) -> dict: """Run data modification before building.""" - if 'url' in data and 'id' not in data: - uid, version = _get_ids_from_url(data['url']) + if "url" in data and "id" not in data: + uid, version = _get_ids_from_url(data["url"]) if uid is not None: - 
data['id'] = str(uid)
+            data["id"] = str(uid)
         if version is not None:
-            data['version'] = str(version)
+            data["version"] = str(version)
         return data
 
     def __str__(self):
-        return f'<File link {self.filename}>'
+        return f"<File link {self.filename}>"
 
 
 class FileCollection(Collection[FileLink]):
     """Represents the collection of all file links associated with a dataset."""
 
-    _path_template = 'teams/{team_id}/datasets/{dataset_id}/files'
-    _collection_key = 'files'
+    _path_template = "teams/{team_id}/datasets/{dataset_id}/files"
+    _collection_key = "files"
     _resource = FileLink
 
     def __init__(
@@ -204,7 +217,7 @@ def __init__(
         session: Session = None,
         dataset_id: UUID = None,
         team_id: UUID = None,
-        project_id: Optional[UUID] = None
+        project_id: Optional[UUID] = None,
     ):
         args = _pad_positional_args(args, 3)
         self.project_id = project_id or args[0]
@@ -219,40 +232,46 @@ def __init__(
             session=self.session,
             project_id=self.project_id,
             team_id=team_id,
-            obj_type="File Links")
-
-    def _get_path(self,
-                  uid: Optional[Union[UUID, str]] = None,
-                  *,
-                  ignore_dataset: Optional[bool] = False,
-                  version: Union[str, UUID] = None,
-                  action: Union[str, Sequence[str]] = [],
-                  query_terms: Dict[str, str] = {},) -> str:
+            obj_type="File Links",
+        )
+
+    def _get_path(
+        self,
+        uid: Optional[Union[UUID, str]] = None,
+        *,
+        ignore_dataset: Optional[bool] = False,
+        version: Union[str, UUID] = None,
+        action: Union[str, Sequence[str]] = [],
+        query_terms: Dict[str, str] = {},
+    ) -> str:
         """Build the path for taking an action with a particular file version."""
         if version is not None:
-            action = ['versions', version] + ([action] if isinstance(action, str) else action)
+            action = ["versions", version] + (
+                [action] if isinstance(action, str) else action
+            )
         return super()._get_path(uid=uid, ignore_dataset=ignore_dataset, action=action)
 
-    def _get_path_from_file_link(self, file_link: FileLink,
-                                 *,
-                                 action: str = None) -> str:
+    def _get_path_from_file_link(
+        self, file_link: FileLink, *, action: str = None
+    ) -> str:
         """Build the platform path for taking an action with a particular file link."""
         if not self._is_on_platform_url(file_link.url) or file_link.uid is None:
             raise ValueError("FileLink did not contain a Citrine platform file URL.")
-        return self._get_path(uid=file_link.uid, version=file_link.version, action=action)
+        return self._get_path(
+            uid=file_link.uid, version=file_link.version, action=action
+        )
 
     def build(self, data: dict) -> FileLink:
         """Build an instance of FileLink."""
         # Use this chance to construct a URL from platform metadata
-        if 'url' not in data:
-            data['url'] = self._get_path(uid=data['id'], version=data['version'])
+        if "url" not in data:
+            data["url"] = self._get_path(uid=data["id"], version=data["version"])
         return FileLink.build(data)
 
-    def get(self,
-            uid: Union[UUID, str],
-            *,
-            version: Optional[Union[UUID, str, int]] = None) -> FileLink:
+    def get(
+        self, uid: Union[UUID, str], *, version: Optional[Union[UUID, str, int]] = None
+    ) -> FileLink:
         """
         Retrieve an on-platform FileLink from its filename or file uuid.
 
@@ -272,11 +291,15 @@ def get(self,
         """
         if not isinstance(uid, (str, UUID)):
-            raise TypeError(f"File Link can only be resolved from str or UUID."
-                            f"Instead got {type(uid)} {uid}.")
+            raise TypeError(
+                f"File Link can only be resolved from str or UUID."
+                f"Instead got {type(uid)} {uid}."
+            )
         if version is not None and not isinstance(version, (str, UUID, int)):
-            raise TypeError(f"Version can only be resolved from str, int or UUID."
- f"Instead got {type(uid)} {uid}.") + raise TypeError( + f"Version can only be resolved from str, int or UUID." + f"Instead got {type(uid)} {uid}." + ) if isinstance(uid, str): try: # Check if the uid string is actually a UUID @@ -298,18 +321,20 @@ def get(self, if isinstance(uid, str): # Assume it's the filename on platform; if version is None or isinstance(version, int): - file = self._search_by_file_name(dset_id=self.dataset_id, - file_name=uid, - file_version_number=version) + file = self._search_by_file_name( + dset_id=self.dataset_id, file_name=uid, file_version_number=version + ) else: # We did our type checks earlier; version is a UUID file = self._search_by_file_version_id(file_version_id=version) else: # We did our type checks earlier; uid is a UUID if isinstance(version, UUID): file = self._search_by_file_version_id(file_version_id=version) else: # We did our type checks earlier; version is an int or None - file = self._search_by_dataset_file_id(dataset_file_id=uid, - dset_id=self.dataset_id, - file_version_number=version) + file = self._search_by_dataset_file_id( + dataset_file_id=uid, + dset_id=self.dataset_id, + file_version_number=version, + ) return file @@ -370,12 +395,8 @@ def _make_upload_request(self, file_path: Path, dest_name: str): file_size = file_path.stat().st_size assert isinstance(file_size, int) upload_json = { - 'files': [ - { - 'file_name': dest_name, - 'mime_type': mime_type, - 'size': file_size - } + "files": [ + {"file_name": dest_name, "mime_type": mime_type, "size": file_size} ] } # POST request creates space in S3 for the file and returns AWS-related information @@ -385,29 +406,34 @@ def _make_upload_request(self, file_path: Path, dest_name: str): # Extract all relevant information from the upload request try: - - uploader.region_name = upload_request['s3_region'] - uploader.aws_access_key_id = upload_request['temporary_credentials']['access_key_id'] - uploader.aws_secret_access_key = \ - upload_request['temporary_credentials']['secret_access_key'] - uploader.aws_session_token = upload_request['temporary_credentials']['session_token'] - uploader.bucket = upload_request['s3_bucket'] - uploader.object_key = upload_request['uploads'][0]['s3_key'] - uploader.upload_id = upload_request['uploads'][0]['upload_id'] + uploader.region_name = upload_request["s3_region"] + uploader.aws_access_key_id = upload_request["temporary_credentials"][ + "access_key_id" + ] + uploader.aws_secret_access_key = upload_request["temporary_credentials"][ + "secret_access_key" + ] + uploader.aws_session_token = upload_request["temporary_credentials"][ + "session_token" + ] + uploader.bucket = upload_request["s3_bucket"] + uploader.object_key = upload_request["uploads"][0]["s3_key"] + uploader.upload_id = upload_request["uploads"][0]["upload_id"] uploader.s3_endpoint_url = self.session.s3_endpoint_url uploader.s3_use_ssl = self.session.s3_use_ssl uploader.s3_addressing_style = self.session.s3_addressing_style except KeyError: - raise RuntimeError("Upload initiation response is missing some fields: " - "{}".format(upload_request)) + raise RuntimeError( + "Upload initiation response is missing some fields: {}".format( + upload_request + ) + ) return uploader - def _search_by_file_name(self, - file_name: str, - dset_id: UUID, - file_version_number: Optional[int] = None - ) -> Optional[FileLink]: + def _search_by_file_name( + self, file_name: str, dset_id: UUID, file_version_number: Optional[int] = None + ) -> Optional[FileLink]: """ Make a request to the backend to search a file by 
name. @@ -432,22 +458,19 @@ def _search_by_file_name(self, path = self._get_path(action="search") search_json = { - 'fileSearchFilter': - { - 'type': SearchFileFilterTypeEnum.NAME_SEARCH.value, - 'datasetId': str(dset_id), - 'fileName': file_name, - 'fileVersionNumber': file_version_number - } + "fileSearchFilter": { + "type": SearchFileFilterTypeEnum.NAME_SEARCH.value, + "datasetId": str(dset_id), + "fileName": file_name, + "fileVersionNumber": file_version_number, + } } data = self.session.post_resource(path=path, json=search_json) - return self.build(data['files'][0]) + return self.build(data["files"][0]) - def _search_by_file_version_id(self, - file_version_id: UUID - ) -> Optional[FileLink]: + def _search_by_file_version_id(self, file_version_id: UUID) -> Optional[FileLink]: """ Make a request to the backend to search a file by file version id. @@ -465,21 +488,22 @@ def _search_by_file_version_id(self, path = self._get_path(action="search") search_json = { - 'fileSearchFilter': { - 'type': SearchFileFilterTypeEnum.VERSION_ID_SEARCH.value, - 'fileVersionUuid': str(file_version_id) + "fileSearchFilter": { + "type": SearchFileFilterTypeEnum.VERSION_ID_SEARCH.value, + "fileVersionUuid": str(file_version_id), } } data = self.session.post_resource(path=path, json=search_json) - return self.build(data['files'][0]) + return self.build(data["files"][0]) - def _search_by_dataset_file_id(self, - dataset_file_id: UUID, - dset_id: UUID, - file_version_number: Optional[int] = None - ) -> Optional[FileLink]: + def _search_by_dataset_file_id( + self, + dataset_file_id: UUID, + dset_id: UUID, + file_version_number: Optional[int] = None, + ) -> Optional[FileLink]: """ Make a request to the backend to search a file by dataset file id. @@ -504,17 +528,17 @@ def _search_by_dataset_file_id(self, path = self._get_path(action="search") search_json = { - 'fileSearchFilter': { - 'type': SearchFileFilterTypeEnum.DATASET_FILE_ID_SEARCH.value, - 'datasetId': str(dset_id), - 'datasetFileId': str(dataset_file_id), - 'fileVersionNumber': file_version_number + "fileSearchFilter": { + "type": SearchFileFilterTypeEnum.DATASET_FILE_ID_SEARCH.value, + "datasetId": str(dset_id), + "datasetFileId": str(dataset_file_id), + "fileVersionNumber": file_version_number, } } data = self.session.post_resource(path=path, json=search_json) - return self.build(data['files'][0]) + return self.build(data["files"][0]) @staticmethod def _mime_type(file_path: Path): @@ -542,20 +566,22 @@ def _upload_file(file_path: Path, uploader: _Uploader): """ additional_s3_opts = { - 'use_ssl': uploader.s3_use_ssl, - 'config': Config(s3={'addressing_style': uploader.s3_addressing_style}) + "use_ssl": uploader.s3_use_ssl, + "config": Config(s3={"addressing_style": uploader.s3_addressing_style}), } if uploader.s3_endpoint_url is not None: - additional_s3_opts['endpoint_url'] = uploader.s3_endpoint_url - - s3_client = boto3_client('s3', - region_name=uploader.region_name, - aws_access_key_id=uploader.aws_access_key_id, - aws_secret_access_key=uploader.aws_secret_access_key, - aws_session_token=uploader.aws_session_token, - **additional_s3_opts) - with file_path.open(mode='rb') as f: + additional_s3_opts["endpoint_url"] = uploader.s3_endpoint_url + + s3_client = boto3_client( + "s3", + region_name=uploader.region_name, + aws_access_key_id=uploader.aws_access_key_id, + aws_secret_access_key=uploader.aws_secret_access_key, + aws_session_token=uploader.aws_session_token, + **additional_s3_opts, + ) + with file_path.open(mode="rb") as f: try: # NOTE: This is 
only using the simple PUT logic, not the more sophisticated # multipart upload approach that is also available (providing parallel @@ -564,11 +590,14 @@ def _upload_file(file_path: Path, uploader: _Uploader): Bucket=uploader.bucket, Key=uploader.object_key, Body=f, - Metadata={"X-Citrine-Upload-Id": uploader.upload_id}) + Metadata={"X-Citrine-Upload-Id": uploader.upload_id}, + ) except ClientError as e: - raise RuntimeError(f"Upload of file {file_path} failed with the following " - f"exception: {e}") - uploader.s3_version = upload_response['VersionId'] + raise RuntimeError( + f"Upload of file {file_path} failed with the following " + f"exception: {e}" + ) + uploader.s3_version = upload_response["VersionId"] return uploader def _complete_upload(self, dest_name: str, uploader: _Uploader): @@ -589,19 +618,25 @@ def _complete_upload(self, dest_name: str, uploader: _Uploader): """ url = self._get_path(action=["uploads", uploader.upload_id, "complete"]) - complete_response = self.session.put_resource(path=url, - json={'s3_version': uploader.s3_version}) + complete_response = self.session.put_resource( + path=url, json={"s3_version": uploader.s3_version} + ) try: - file_id = complete_response['file_info']['file_id'] - version_id = complete_response['file_info']['version'] + file_id = complete_response["file_info"]["file_id"] + version_id = complete_response["file_info"]["version"] except KeyError: - raise RuntimeError("Upload completion response is missing some " - "fields: {}".format(complete_response)) + raise RuntimeError( + "Upload completion response is missing some fields: {}".format( + complete_response + ) + ) return self.build({"filename": dest_name, "id": file_id, "version": version_id}) - def download(self, *, file_link: Union[str, UUID, FileLink], local_path: Union[str, Path]): + def download( + self, *, file_link: Union[str, UUID, FileLink], local_path: Union[str, Path] + ): """ Download the file associated with a given FileLink to the local computer. @@ -649,36 +684,43 @@ def read(self, *, file_link: Union[str, UUID, FileLink]) -> bytes: if self._is_local_url(file_link.url): # Read the local file parsed_url = urlparse(file_link.url) - if parsed_url.netloc not in {'', '.', 'localhost'}: - raise ValueError("Non-local UNCs (e.g., Windows network paths) are not supported.") + if parsed_url.netloc not in {"", ".", "localhost"}: + raise ValueError( + "Non-local UNCs (e.g., Windows network paths) are not supported." + ) # Space should have been encoded as %20, but just in case it was a + - path = Path(url2pathname(parsed_url.path.replace('+', '%20'))) + path = Path(url2pathname(parsed_url.path.replace("+", "%20"))) return path.read_bytes() if self._is_external_url(file_link.url): # Pull it from where ever it lives final_url = file_link.url else: # The "/content-link" route returns a pre-signed url to download the file. 
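            # (Sketch of the flow below: fetch the content link, extract the
            # pre-signed read URL, rewrite it to a local S3 endpoint when one
            # is configured, then download the bytes.)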
- content_link = self._get_path_from_file_link(file_link, action='content-link') + content_link = self._get_path_from_file_link( + file_link, action="content-link" + ) content_link_response = self.session.get_resource(content_link) - pre_signed_url = content_link_response['pre_signed_read_link'] - final_url = rewrite_s3_links_locally(pre_signed_url, self.session.s3_endpoint_url) + pre_signed_url = content_link_response["pre_signed_read_link"] + final_url = rewrite_s3_links_locally( + pre_signed_url, self.session.s3_endpoint_url + ) download_response = requests.get(final_url) return download_response.content - def ingest(self, - files: Iterable[Union[FileLink, Path, str]], - *, - upload: bool = False, - raise_errors: bool = True, - build_table: bool = False, - delete_dataset_contents: bool = False, - delete_templates: bool = True, - timeout: float = None, - polling_delay: Optional[float] = None, - project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 - ) -> "IngestionStatus": # noqa: F821 + def ingest( + self, + files: Iterable[Union[FileLink, Path, str]], + *, + upload: bool = False, + raise_errors: bool = True, + build_table: bool = False, + delete_dataset_contents: bool = False, + delete_templates: bool = True, + timeout: float = None, + polling_delay: Optional[float] = None, + project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 + ) -> "IngestionStatus": # noqa: F821 """ [ALPHA] Ingest a set of CSVs and/or Excel Workbooks formatted per the gemd-ingest protocol. @@ -733,7 +775,7 @@ def ingest(self, warn( "Building a table with an implicit project is deprecated " "and will be removed in v4. Please pass a project explicitly.", - DeprecationWarning + DeprecationWarning, ) project = self.project_id @@ -750,7 +792,9 @@ def resolve_with_local(candidate: Union[FileLink, Path, str]) -> FileLink: elif isinstance(candidate, Path): raise TypeError("Path objects are only valid when upload=True.") - return self._resolve_file_link(candidate) # It was a FileLink or unresolvable string + return self._resolve_file_link( + candidate + ) # It was a FileLink or unresolvable string resolved = [] for file in files: @@ -766,17 +810,20 @@ def resolve_with_local(candidate: Union[FileLink, Path, str]) -> FileLink: for file_link in offplatform: path = Path(downloads) / file_link.filename self.download(file_link=file_link, local_path=path) - onplatform.append(self.upload(file_path=path, dest_name=file_link.filename)) + onplatform.append( + self.upload(file_path=path, dest_name=file_link.filename) + ) elif len(offplatform) > 0: - raise ValueError(f"All files must be on-platform to load them. " - f"The following are not: {offplatform}") + raise ValueError( + f"All files must be on-platform to load them. 
" + f"The following are not: {offplatform}" + ) - ingestion_collection = IngestionCollection(team_id=self.team_id, - dataset_id=self.dataset_id, - session=self.session) + ingestion_collection = IngestionCollection( + team_id=self.team_id, dataset_id=self.dataset_id, session=self.session + ) ingestion = ingestion_collection.build_from_file_links( - file_links=onplatform, - raise_errors=raise_errors + file_links=onplatform, raise_errors=raise_errors ) return ingestion.build_objects( build_table=build_table, @@ -784,7 +831,7 @@ def resolve_with_local(candidate: Union[FileLink, Path, str]) -> FileLink: delete_dataset_contents=delete_dataset_contents, delete_templates=delete_templates, timeout=timeout, - polling_delay=polling_delay + polling_delay=polling_delay, ) def delete(self, file_link: FileLink): @@ -798,23 +845,31 @@ def delete(self, file_link: FileLink): """ if not self._is_on_platform_url(file_link.url): - raise ValueError(f"Only platform resources can be deleted; passed URL {file_link.url}") + raise ValueError( + f"Only platform resources can be deleted; passed URL {file_link.url}" + ) file_id = _get_ids_from_url(file_link.url)[0] if file_id is None: - raise ValueError(f"URL was malformed for local resources; passed URL {file_link.url}") + raise ValueError( + f"URL was malformed for local resources; passed URL {file_link.url}" + ) data = self.session.delete_resource(self._get_path(file_id)) return Response(body=data) def _resolve_file_link(self, identifier: Union[str, UUID, FileLink]) -> FileLink: """Generate the FileLink object referenced by the passed argument.""" if isinstance(identifier, GEMDFileLink): - if not isinstance(identifier, FileLink): # Up-convert type with existing info + if not isinstance( + identifier, FileLink + ): # Up-convert type with existing info update = FileLink(filename=identifier.filename, url=identifier.url) if self._is_on_platform_url(update.url): if update.uid is None: - raise ValueError(f"URL was malformed for platform resources; " - f"passed URL {update.url}") + raise ValueError( + f"URL was malformed for platform resources; " + f"passed URL {update.url}" + ) else: # Validate that it's a real record update = self.get(uid=update.uid, version=update.version) if update.filename != identifier.filename: @@ -833,13 +888,15 @@ def _resolve_file_link(self, identifier: Union[str, UUID, FileLink]) -> FileLink else: # We got a file UID (and possibly a version UID) from a URL return self.get(uid=file_id, version=version_id) else: # Assume it's an absolute URL - filename = urlparse(identifier).path.split('/')[-1] + filename = urlparse(identifier).path.split("/")[-1] return FileLink(filename=filename, url=identifier) elif isinstance(identifier, UUID): # File UID return self.get(uid=identifier) else: - raise TypeError(f"File Link can only be resolved from str, or UUID." - f"Instead got {type(identifier)} {identifier}.") + raise TypeError( + f"File Link can only be resolved from str, or UUID." + f"Instead got {type(identifier)} {identifier}." 
+ ) def _is_external_url(self, url: str): """Check if the URL is absolute and not associated with this platform instance.""" diff --git a/src/citrine/resources/gemd_resource.py b/src/citrine/resources/gemd_resource.py index e72b89650..5a7045743 100644 --- a/src/citrine/resources/gemd_resource.py +++ b/src/citrine/resources/gemd_resource.py @@ -1,21 +1,35 @@ """Collection class for generic GEMD objects and templates.""" + import re from typing import Type, Union, List, Tuple, Iterable, Optional from uuid import UUID, uuid4 from gemd.entity.base_entity import BaseEntity from gemd.entity.link_by_uid import LinkByUID -from gemd.util import recursive_flatmap, recursive_foreach, set_uuids, \ - make_index, substitute_objects +from gemd.util import ( + recursive_flatmap, + recursive_foreach, + set_uuids, + make_index, + substitute_objects, +) from tqdm.auto import tqdm from citrine.resources.api_error import ApiError -from citrine.resources.data_concepts import DataConcepts, DataConceptsCollection, \ - CITRINE_SCOPE, CITRINE_TAG_PREFIX +from citrine.resources.data_concepts import ( + DataConcepts, + DataConceptsCollection, + CITRINE_SCOPE, + CITRINE_TAG_PREFIX, +) from citrine.resources.delete import _async_gemd_batch_delete from citrine._session import Session from citrine._utils.batcher import Batcher -from citrine._utils.functions import _pad_positional_args, replace_objects_with_links, scrub_none +from citrine._utils.functions import ( + _pad_positional_args, + replace_objects_with_links, + scrub_none, +) BATCH_SIZE = 50 @@ -24,7 +38,7 @@ class GEMDResourceCollection(DataConceptsCollection[DataConcepts]): """A collection of any kind of GEMD objects/templates.""" - _collection_key = 'storables' + _collection_key = "storables" def __init__( self, @@ -32,13 +46,15 @@ def __init__( dataset_id: UUID = None, session: Session = None, team_id: UUID = None, - project_id: Optional[UUID] = None + project_id: Optional[UUID] = None, ): - super().__init__(*args, - team_id=team_id, - dataset_id=dataset_id, - session=session, - project_id=project_id) + super().__init__( + *args, + team_id=team_id, + dataset_id=dataset_id, + session=session, + project_id=project_id, + ) args = _pad_positional_args(args, 3) self.project_id = project_id or args[0] self.dataset_id = dataset_id or args[1] @@ -52,7 +68,9 @@ def get_type(cls) -> Type[DataConcepts]: def _collection_for(self, model): collection = DataConcepts.get_collection_type(model) - return collection(team_id=self.team_id, dataset_id=self.dataset_id, session=self.session) + return collection( + team_id=self.team_id, dataset_id=self.dataset_id, session=self.session + ) def build(self, data: dict) -> DataConcepts: """ @@ -73,12 +91,14 @@ def build(self, data: dict) -> DataConcepts: """ return super().build(data) - def register_all(self, - models: Iterable[DataConcepts], - *, - dry_run=False, - status_bar=False, - include_nested=False) -> List[DataConcepts]: + def register_all( + self, + models: Iterable[DataConcepts], + *, + dry_run=False, + status_bar=False, + include_nested=False, + ) -> List[DataConcepts]: """ Register multiple GEMD objects to each of their appropriate collections. @@ -120,9 +140,11 @@ def register_all(self, """ if self.dataset_id is None: - raise RuntimeError("Must specify a dataset in order to register a data model object.") + raise RuntimeError( + "Must specify a dataset in order to register a data model object." 
+ ) path = self._get_path() - params = {'dry_run': dry_run} + params = {"dry_run": dry_run} if include_nested: models = recursive_flatmap(models, lambda o: [o], unidirectional=False) @@ -145,13 +167,13 @@ def register_all(self, iterator = batcher.batch(models, BATCH_SIZE) for batch in iterator: - objects = [replace_objects_with_links(scrub_none(model.dump())) for model in batch] + objects = [ + replace_objects_with_links(scrub_none(model.dump())) for model in batch + ] response_data = self.session.put_resource( - path + '/batch', - json={'objects': objects}, - params=params + path + "/batch", json={"objects": objects}, params=params ) - registered = [self.build(obj) for obj in response_data['objects']] + registered = [self.build(obj) for obj in response_data["objects"]] result_index.update(make_index(registered)) substitute_objects(registered, result_index, inplace=True) @@ -169,10 +191,15 @@ def register_all(self, result = result_index[obj.to_link()] if CITRINE_SCOPE not in obj.uids: citr_id = result.uids.pop(CITRINE_SCOPE, None) - result_index.pop(LinkByUID(scope=CITRINE_SCOPE, id=citr_id), None) + result_index.pop( + LinkByUID(scope=CITRINE_SCOPE, id=citr_id), None + ) if result.tags is not None: - todo = [tag for tag in result.tags - if re.match(f"^{CITRINE_TAG_PREFIX}::", tag)] + todo = [ + tag + for tag in result.tags + if re.match(f"^{CITRINE_TAG_PREFIX}::", tag) + ] for tag in todo: # Covering this block would require dark art if tag not in obj.tags: result.tags.remove(tag) @@ -180,16 +207,17 @@ def register_all(self, resources.extend(registered) if dry_run: # No-op if not dry-run - recursive_foreach(list(models) + list(resources), - lambda x: x.uids.pop(temp_scope, None)) # Strip temp uids + recursive_foreach( + list(models) + list(resources), lambda x: x.uids.pop(temp_scope, None) + ) # Strip temp uids return resources def batch_delete( - self, - id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], - *, - timeout: float = 2 * 60, - polling_delay: float = 1.0 + self, + id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], + *, + timeout: float = 2 * 60, + polling_delay: float = 1.0, ) -> List[Tuple[LinkByUID, ApiError]]: """ Remove a set of GEMD objects. @@ -231,4 +259,5 @@ def batch_delete( session=self.session, dataset_id=self.dataset_id, timeout=timeout, - polling_delay=polling_delay) + polling_delay=polling_delay, + ) diff --git a/src/citrine/resources/gemtables.py b/src/citrine/resources/gemtables.py index 71fb0def5..74650ccac 100644 --- a/src/citrine/resources/gemtables.py +++ b/src/citrine/resources/gemtables.py @@ -10,15 +10,19 @@ from citrine._serialization import properties from citrine._serialization.properties import UUID from citrine._session import Session -from citrine._utils.functions import format_escaped_url, _pad_positional_args, \ - rewrite_s3_links_locally, write_file_locally +from citrine._utils.functions import ( + format_escaped_url, + _pad_positional_args, + rewrite_s3_links_locally, + write_file_locally, +) from citrine.jobs.job import JobSubmissionResponse, _poll_for_job_completion from citrine.resources.table_config import TableConfig, TableConfigCollection logger = getLogger(__name__) -class GemTable(Resource['Table']): +class GemTable(Resource["Table"]): """A 2-dimensional projection of data. GEM Tables are the basic unit used to flatten and manipulate data objects. @@ -27,21 +31,25 @@ class GemTable(Resource['Table']): can be used to 'flatten' data objects into useful projections. 
""" - _response_key = 'table' + _response_key = "table" _resource_type = ResourceTypeEnum.TABLE - uid = properties.UUID('id') + uid = properties.UUID("id") """:UUID: unique Citrine id of this GEM Table""" - version = properties.Integer('version') + version = properties.Integer("version") """:int: Version number of the GEM Table. The first table built from a given config is version 1.""" - download_url = properties.String('signed_download_url') + download_url = properties.String("signed_download_url") """:str: URL pointing to the location of the GEM Table's contents. This is an expiring download link and is not unique.""" - _config = properties.Optional(properties.Object(TableConfig), "config", serializable=False) + _config = properties.Optional( + properties.Object(TableConfig), "config", serializable=False + ) _name = properties.Optional(properties.String, "name", serializable=False) - _description = properties.Optional(properties.String, "description", serializable=False) + _description = properties.Optional( + properties.String, "description", serializable=False + ) def __init__(self): self._team_id = None @@ -49,15 +57,17 @@ def __init__(self): self._session = None def __str__(self): - return ''.format(self.uid, self.version) + return "".format(self.uid, self.version) @property def config(self) -> TableConfig: """Configuration used to build the table.""" if self._config is None: - config_collection = TableConfigCollection(team_id=self._team_id, - project_id=self._project_id, - session=self._session) + config_collection = TableConfigCollection( + team_id=self._team_id, + project_id=self._project_id, + session=self._session, + ) self._config = config_collection.get_for_table(self) return self._config @@ -91,16 +101,18 @@ def _comparison_fields(self, entity: GemTable) -> Any: class GemTableCollection(Collection[GemTable]): """Represents the collection of all tables associated with a project.""" - _path_template = 'projects/{project_id}/display-tables' - _collection_key: str = 'tables' + _path_template = "projects/{project_id}/display-tables" + _collection_key: str = "tables" _paginator: Paginator = GemTableVersionPaginator() _resource = GemTable - def __init__(self, - *args, - team_id: UUID = None, - project_id: UUID = None, - session: Session = None): + def __init__( + self, + *args, + team_id: UUID = None, + project_id: UUID = None, + session: Session = None, + ): args = _pad_positional_args(args, 2) self.project_id = project_id or args[0] self.session: Session = session or args[1] @@ -123,10 +135,7 @@ def get(self, uid: Union[UUID, str], *, version: Optional[int] = None) -> GemTab newest_table = max(tables, key=lambda x: x.version or 0) return newest_table - def list_versions(self, - uid: UUID, - *, - per_page: int = 100) -> Iterable[GemTable]: + def list_versions(self, uid: UUID, *, per_page: int = 100) -> Iterable[GemTable]: """ List the versions of a table given a specific Table UID. @@ -137,11 +146,14 @@ def list_versions(self, :param per_page: The number of items to fetch per-page. :return: An iterable of the versions of the Tables (as Table objects). 
""" - def _fetch_versions(page: Optional[int], - per_page: int) -> Tuple[Iterable[dict], str]: - data = self.session.get_resource(self._get_path(uid), - params=self._page_params(page, per_page)) - return data[self._collection_key], data.get('next', "") + + def _fetch_versions( + page: Optional[int], per_page: int + ) -> Tuple[Iterable[dict], str]: + data = self.session.get_resource( + self._get_path(uid), params=self._page_params(page, per_page) + ) + return data[self._collection_key], data.get("next", "") def _build_versions(collection: Iterable[dict]) -> Iterable[GemTable]: for item in collection: @@ -149,12 +161,15 @@ def _build_versions(collection: Iterable[dict]) -> Iterable[GemTable]: return self._paginator.paginate( # Don't deduplicate on uid since uids are shared between versions - _fetch_versions, _build_versions, per_page, deduplicate=False) + _fetch_versions, + _build_versions, + per_page, + deduplicate=False, + ) - def list_by_config(self, - table_config_uid: UUID, - *, - per_page: int = 100) -> Iterable[GemTable]: + def list_by_config( + self, table_config_uid: UUID, *, per_page: int = 100 + ) -> Iterable[GemTable]: """ List the versions of a table associated with a given Table Config UID. @@ -165,18 +180,20 @@ def list_by_config(self, :param per_page: The number of items to fetch per-page. :return: An iterable of the versions of the Tables (as Table objects). """ - def _fetch_versions(page: Optional[int], - per_page: int) -> Tuple[Iterable[dict], str]: - path_params = {'table_config_uid_str': str(table_config_uid)} + + def _fetch_versions( + page: Optional[int], per_page: int + ) -> Tuple[Iterable[dict], str]: + path_params = {"table_config_uid_str": str(table_config_uid)} path_params.update(self.__dict__) path = format_escaped_url( - 'projects/{project_id}/table-configs/{table_config_uid_str}/gem-tables', - **path_params + "projects/{project_id}/table-configs/{table_config_uid_str}/gem-tables", + **path_params, ) data = self.session.get_resource( - path, - params=self._page_params(page, per_page)) - return data[self._collection_key], data.get('next', "") + path, params=self._page_params(page, per_page) + ) + return data[self._collection_key], data.get("next", "") def _build_versions(collection: Iterable[dict]) -> Iterable[GemTable]: for item in collection: @@ -184,10 +201,15 @@ def _build_versions(collection: Iterable[dict]) -> Iterable[GemTable]: return self._paginator.paginate( # Don't deduplicate on uid since uids are shared between versions - _fetch_versions, _build_versions, per_page, deduplicate=False) + _fetch_versions, + _build_versions, + per_page, + deduplicate=False, + ) - def initiate_build(self, config: Union[TableConfig, str, UUID], *, - version: Union[str, UUID] = None) -> JobSubmissionResponse: + def initiate_build( + self, config: Union[TableConfig, str, UUID], *, version: Union[str, UUID] = None + ) -> JobSubmissionResponse: """ Initiates tables build with provided config. @@ -210,20 +232,28 @@ def initiate_build(self, config: Union[TableConfig, str, UUID], *, """ if isinstance(config, TableConfig): if version is not None: - logger.warning('Ignoring version %s since config object was provided.', version) + logger.warning( + "Ignoring version %s since config object was provided.", version + ) if config.version_number is None: - raise ValueError('Cannot build table from config which has no version. ' - 'Try registering the config before building.') + raise ValueError( + "Cannot build table from config which has no version. 
" + "Try registering the config before building." + ) if config.uid is None: - raise ValueError('Cannot build table from config which has no uid. ' - 'Try registering the config before building.') + raise ValueError( + "Cannot build table from config which has no uid. " + "Try registering the config before building." + ) uid = config.uid version = config.version_number else: if version is None: - raise ValueError('Version must be specified when building by config uid.') + raise ValueError( + "Version must be specified when building by config uid." + ) uid = config - logger.info(f'Submitting table build for config {uid} version {version}...') + logger.info(f"Submitting table build for config {uid} version {version}...") path = format_escaped_url( "teams/{}/projects/{}/table-configs/{}/versions/{}/build", self.team_id, @@ -234,13 +264,14 @@ def initiate_build(self, config: Union[TableConfig, str, UUID], *, response = self.session.post_resource(path=path, json={}) submission = JobSubmissionResponse.build(response) logger.info( - f'Table build job submitted from config {uid} ' - f'version {version} with job ID {submission.job_id}' + f"Table build job submitted from config {uid} " + f"version {version} with job ID {submission.job_id}" ) return submission - def get_by_build_job(self, job: Union[JobSubmissionResponse, UUID], *, - timeout: float = 15 * 60) -> GemTable: + def get_by_build_job( + self, job: Union[JobSubmissionResponse, UUID], *, timeout: float = 15 * 60 + ) -> GemTable: """ Gets table by build job, waiting for it to complete if necessary. @@ -260,29 +291,33 @@ def get_by_build_job(self, job: Union[JobSubmissionResponse, UUID], *, """ status = _poll_for_job_completion( - session=self.session, - team_id=self.team_id, - job=job, - timeout=timeout) - - table_id = status.output['display_table_id'] - table_version = status.output['display_table_version'] - warning_blob = status.output.get('table_warnings') + session=self.session, team_id=self.team_id, job=job, timeout=timeout + ) + + table_id = status.output["display_table_id"] + table_version = status.output["display_table_version"] + warning_blob = status.output.get("table_warnings") warnings = json.loads(warning_blob) if warning_blob is not None else [] if warnings: - warn_lines = ['Table build completed with warnings:'] + warn_lines = ["Table build completed with warnings:"] for warning in warnings: - limited_results = warning.get('limited_results', []) + limited_results = warning.get("limited_results", []) warn_lines.extend(limited_results) - total_count = warning.get('total_count', 0) + total_count = warning.get("total_count", 0) if total_count > len(limited_results): - warn_lines.append(f'and {total_count - len(limited_results)} more similar.') - logger.warning('\n\t'.join(warn_lines)) + warn_lines.append( + f"and {total_count - len(limited_results)} more similar." + ) + logger.warning("\n\t".join(warn_lines)) return self.get(table_id, version=table_version) - def build_from_config(self, config: Union[TableConfig, str, UUID], *, - version: Union[str, int] = None, - timeout: float = 15 * 60) -> GemTable: + def build_from_config( + self, + config: Union[TableConfig, str, UUID], + *, + version: Union[str, int] = None, + timeout: float = 15 * 60, + ) -> GemTable: """ Builds table from table config, waiting for build job to complete. 
@@ -316,7 +351,7 @@ def build(self, data: dict) -> GemTable: def register(self, model: GemTable) -> GemTable: """Tables cannot be created at this time.""" - raise RuntimeError('Creating Tables is not supported at this time.') + raise RuntimeError("Creating Tables is not supported at this time.") def update(self, model: GemTable) -> GemTable: """Tables cannot be updated.""" @@ -353,7 +388,9 @@ def _read_raw(self, table: table_type) -> requests.Response: table = self.get(uid=table) data_location = table.download_url - data_location = rewrite_s3_links_locally(data_location, self.session.s3_endpoint_url) + data_location = rewrite_s3_links_locally( + data_location, self.session.s3_endpoint_url + ) return requests.get(data_location) def read_to_memory(self, table: table_type) -> str: diff --git a/src/citrine/resources/generative_design_execution.py b/src/citrine/resources/generative_design_execution.py index 8ec82a481..1c957e93c 100644 --- a/src/citrine/resources/generative_design_execution.py +++ b/src/citrine/resources/generative_design_execution.py @@ -1,10 +1,13 @@ """Resources that represent both individual and collections of design workflow executions.""" + from typing import Union, Iterator from uuid import UUID from citrine._rest.collection import Collection from citrine._session import Session -from citrine.informatics.executions.generative_design_execution import GenerativeDesignExecution +from citrine.informatics.executions.generative_design_execution import ( + GenerativeDesignExecution, +) from citrine.informatics.generative_design import GenerativeDesignInput from citrine.resources.response import Response @@ -12,9 +15,9 @@ class GenerativeDesignExecutionCollection(Collection["GenerativeDesignExecution"]): """A collection of GenerativeDesignExecutions.""" - _path_template = '/projects/{project_id}/generative-design/executions' + _path_template = "/projects/{project_id}/generative-design/executions" _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = GenerativeDesignExecution def __init__(self, project_id: UUID, session: Session): @@ -65,12 +68,12 @@ def list(self, *, per_page: int = 10) -> Iterator[GenerativeDesignExecution]: Resources in this collection. 
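Once built, a table's contents come back as raw CSV text via `read_to_memory` shown above (the S3 link rewriting keeps this working against local or test endpoints). A sketch, assuming `pandas` for parsing and the same placeholder `project` and `table` handles:

    import io
    import pandas as pd

    csv_contents = project.tables.read_to_memory(table)
    df = pd.read_csv(io.StringIO(csv_contents))
    print(df.shape)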
""" - return self._paginator.paginate(page_fetcher=self._fetch_page, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=self._fetch_page, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def delete(self, uid: Union[UUID, str]) -> Response: """Generative Design Executions cannot be deleted or archived.""" - raise NotImplementedError( - "Generative Design Executions cannot be deleted" - ) + raise NotImplementedError("Generative Design Executions cannot be deleted") diff --git a/src/citrine/resources/ingestion.py b/src/citrine/resources/ingestion.py index 916426db1..4449d87b2 100644 --- a/src/citrine/resources/ingestion.py +++ b/src/citrine/resources/ingestion.py @@ -9,9 +9,16 @@ from citrine._rest.resource import Resource from citrine._serialization import properties from citrine._session import Session -from citrine._utils.functions import _data_manager_deprecation_checks, _pad_positional_args +from citrine._utils.functions import ( + _data_manager_deprecation_checks, + _pad_positional_args, +) from citrine.exceptions import CitrineException, BadRequest -from citrine.jobs.job import JobSubmissionResponse, JobFailureError, _poll_for_job_completion +from citrine.jobs.job import ( + JobSubmissionResponse, + JobFailureError, + _poll_for_job_completion, +) from citrine.resources.api_error import ApiError, ValidationError from citrine.resources.file_link import FileLink @@ -65,29 +72,36 @@ class IngestionErrorLevel(BaseEnumeration): INFO = "info" -class IngestionErrorTrace(Resource['IngestionErrorTrace']): +class IngestionErrorTrace(Resource["IngestionErrorTrace"]): """[ALPHA] Detailed information about an ingestion issue.""" family = properties.Enumeration(IngestionErrorFamily, "family") error_type = properties.Enumeration(IngestionErrorType, "error_type") level = properties.Enumeration(IngestionErrorLevel, "level") msg = properties.String("msg") - dataset_file_id = properties.Optional(properties.UUID(), "dataset_file_id", default=None) - file_version_uuid = properties.Optional(properties.UUID(), "file_version_uuid", default=None) + dataset_file_id = properties.Optional( + properties.UUID(), "dataset_file_id", default=None + ) + file_version_uuid = properties.Optional( + properties.UUID(), "file_version_uuid", default=None + ) row_number = properties.Optional(properties.Integer(), "row_number", default=None) - column_number = properties.Optional(properties.Integer(), "column_number", default=None) - - def __init__(self, - msg, - level=IngestionErrorLevel.ERROR, - *, - family=IngestionErrorFamily.UNKNOWN, - error_type=IngestionErrorType.UNKNOWN_ERROR, - dataset_file_id=dataset_file_id.default, - file_version_uuid=file_version_uuid.default, - row_number=row_number.default, - column_number=column_number.default, - ): + column_number = properties.Optional( + properties.Integer(), "column_number", default=None + ) + + def __init__( + self, + msg, + level=IngestionErrorLevel.ERROR, + *, + family=IngestionErrorFamily.UNKNOWN, + error_type=IngestionErrorType.UNKNOWN_ERROR, + dataset_file_id=dataset_file_id.default, + file_version_uuid=file_version_uuid.default, + row_number=row_number.default, + column_number=column_number.default, + ): self.msg = msg self.level = level self.family = family @@ -109,25 +123,29 @@ def __str__(self): return f"{self!r}: {self.msg}" def __repr__(self): - coords = ", ".join([x for x in (self.column_number, self.row_number) if x is not None]) + coords = ", ".join( + [x for x 
in (self.column_number, self.row_number) if x is not None] + ) return f"<{self.level}: {self.error_type}{f' {coords}' if len(coords) else ''}>" class IngestionException(CitrineException): """[ALPHA] An exception that contains details of a failed ingestion.""" - uid = properties.Optional(properties.UUID(), 'ingestion_id', default=None) + uid = properties.Optional(properties.UUID(), "ingestion_id", default=None) """Optional[UUID]""" status = properties.Enumeration(IngestionStatusType, "status") errors = properties.List(properties.Object(IngestionErrorTrace), "errors") """List[IngestionErrorTrace]""" - def __init__(self, - *, - uid: Optional[UUID] = uid.default, - errors: Iterable[IngestionErrorTrace]): + def __init__( + self, + *, + uid: Optional[UUID] = uid.default, + errors: Iterable[IngestionErrorTrace], + ): errors_ = list(errors) - message = '; '.join(str(e) for e in errors_) + message = "; ".join(str(e) for e in errors_) super().__init__(message) self.uid = uid self.errors = errors_ @@ -141,27 +159,33 @@ def from_status(cls, source: "IngestionStatus") -> "IngestionException": def from_api_error(cls, source: ApiError) -> "IngestionException": """[ALPHA] Build an IngestionException from an ApiError.""" if len(source.validation_errors) > 0: - return cls(errors=[IngestionErrorTrace.from_validation_error(x) - for x in source.validation_errors]) + return cls( + errors=[ + IngestionErrorTrace.from_validation_error(x) + for x in source.validation_errors + ] + ) else: return cls(errors=[IngestionErrorTrace(msg=source.message)]) -class IngestionStatus(Resource['IngestionStatus']): +class IngestionStatus(Resource["IngestionStatus"]): """[ALPHA] An object that represents the outcome of an ingestion event.""" - uid = properties.Optional(properties.UUID(), 'ingestion_id', default=None) + uid = properties.Optional(properties.UUID(), "ingestion_id", default=None) """UUID""" status = properties.Enumeration(IngestionStatusType, "status") """IngestionStatusType""" errors = properties.List(properties.Object(IngestionErrorTrace), "errors") """List[IngestionErrorTrace]""" - def __init__(self, - *, - uid: Optional[UUID] = uid.default, - status: IngestionStatusType = IngestionStatusType.INGESTION_CREATED, - errors: Iterable[IngestionErrorTrace]): + def __init__( + self, + *, + uid: Optional[UUID] = uid.default, + status: IngestionStatusType = IngestionStatusType.INGESTION_CREATED, + errors: Iterable[IngestionErrorTrace], + ): self.uid = uid self.status = status self.errors = list(errors) @@ -177,7 +201,7 @@ def from_exception(cls, exception: IngestionException) -> "IngestionStatus": return cls(uid=exception.uid, errors=exception.errors) -class Ingestion(Resource['Ingestion']): +class Ingestion(Resource["Ingestion"]): """ [ALPHA] A job that uploads new information to the platform. 
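These error types are what callers actually see when an ingest goes sideways: an `IngestionStatus` carries zero or more `IngestionErrorTrace`s, and `IngestionException` is that same payload in raised form. A sketch of defensive handling, assuming an existing `Ingestion` object named `ingestion` (a placeholder):

    try:
        status = ingestion.build_objects()
    except IngestionException as e:
        for trace in e.errors:
            # Each trace pinpoints the failure: severity, type, and
            # (when known) the offending file, row, and column.
            print(trace.level, trace.error_type, trace.row_number, trace.msg)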
@@ -187,37 +211,46 @@ class Ingestion(Resource['Ingestion']): """ - uid = properties.UUID('ingestion_id') + uid = properties.UUID("ingestion_id") """UUID: Unique uuid4 identifier of this ingestion.""" - team_id = properties.Optional(properties.UUID, 'team_id', default=None) - _project_id = properties.Optional(properties.UUID, 'project_id', default=None) - dataset_id = properties.UUID('dataset_id') - session = properties.Object(Session, 'session', serializable=False) - raise_errors = properties.Optional(properties.Boolean(), 'raise_errors', default=True) + team_id = properties.Optional(properties.UUID, "team_id", default=None) + _project_id = properties.Optional(properties.UUID, "project_id", default=None) + dataset_id = properties.UUID("dataset_id") + session = properties.Object(Session, "session", serializable=False) + raise_errors = properties.Optional( + properties.Boolean(), "raise_errors", default=True + ) @property - @deprecated(deprecated_in='3.11.0', removed_in='4.0.0', - details="The project_id attribute is deprecated since " - "dataset access is now controlled through teams.") + @deprecated( + deprecated_in="3.11.0", + removed_in="4.0.0", + details="The project_id attribute is deprecated since " + "dataset access is now controlled through teams.", + ) def project_id(self) -> Optional[UUID]: """[DEPRECATED] The project ID associated with this ingest.""" return self._project_id @project_id.setter - @deprecated(deprecated_in='3.9.0', removed_in='4.0.0', - details="Use the project argument instead of setting the project_id attribute.") + @deprecated( + deprecated_in="3.9.0", + removed_in="4.0.0", + details="Use the project argument instead of setting the project_id attribute.", + ) def project_id(self, value: Optional[UUID]): self._project_id = value - def build_objects(self, - *, - build_table: bool = False, - project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 - delete_dataset_contents: bool = False, - delete_templates: bool = True, - timeout: float = None, - polling_delay: Optional[float] = None - ) -> IngestionStatus: + def build_objects( + self, + *, + build_table: bool = False, + project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 + delete_dataset_contents: bool = False, + delete_templates: bool = True, + timeout: float = None, + polling_delay: Optional[float] = None, + ) -> IngestionStatus: """ [ALPHA] Perform a complete ingestion operation, from start to finish. 
@@ -249,29 +282,35 @@ def build_objects(self, """ try: - job = self.build_objects_async(build_table=build_table, - project=project, - delete_dataset_contents=delete_dataset_contents, - delete_templates=delete_templates) + job = self.build_objects_async( + build_table=build_table, + project=project, + delete_dataset_contents=delete_dataset_contents, + delete_templates=delete_templates, + ) except IngestionException as e: if self.raise_errors: raise e else: return IngestionStatus.from_exception(e) - status = self.poll_for_job_completion(job, timeout=timeout, polling_delay=polling_delay) + status = self.poll_for_job_completion( + job, timeout=timeout, polling_delay=polling_delay + ) if self.raise_errors and not status.success: raise IngestionException.from_status(status) return status - def build_objects_async(self, - *, - build_table: bool = False, - project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 - delete_dataset_contents: bool = False, - delete_templates: bool = True) -> JobSubmissionResponse: + def build_objects_async( + self, + *, + build_table: bool = False, + project: Optional[Union["Project", UUID, str]] = None, # noqa: F821 + delete_dataset_contents: bool = False, + delete_templates: bool = True, + ) -> JobSubmissionResponse: """ [ALPHA] Begin an async ingestion operation. @@ -294,9 +333,10 @@ def build_objects_async(self, """ from citrine.resources.project import Project - collection = IngestionCollection(team_id=self.team_id, - dataset_id=self.dataset_id, - session=self.session) + + collection = IngestionCollection( + team_id=self.team_id, dataset_id=self.dataset_id, session=self.session + ) path = collection._get_path(uid=self.uid, action="gemd-objects-async") # Project resolution logic @@ -309,7 +349,7 @@ def build_objects_async(self, warn( "Building a table with an implicit project is deprecated " "and will be removed in v4. Please pass a project explicitly.", - DeprecationWarning + DeprecationWarning, ) project_id = self._project_id elif isinstance(project, Project): @@ -335,12 +375,13 @@ def build_objects_async(self, else: raise e - def poll_for_job_completion(self, - job: JobSubmissionResponse, - *, - timeout: Optional[float] = None, - polling_delay: Optional[float] = None - ) -> IngestionStatus: + def poll_for_job_completion( + self, + job: JobSubmissionResponse, + *, + timeout: Optional[float] = None, + polling_delay: Optional[float] = None, + ) -> IngestionStatus: """ [ALPHA] Repeatedly ask server if a job associated with this ingestion has completed. 
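End to end, the synchronous and asynchronous paths above look like this from the caller's side — a sketch, assuming the collection is reachable as `dataset.ingestions` (per citrine-python convention) and `file_link` is an already-uploaded `FileLink`; `project` is a placeholder:

    ingestion = dataset.ingestions.build_from_file_links([file_link])

    # Synchronous path: submit, poll, and raise on failure in one call.
    status = ingestion.build_objects(build_table=True, project=project)

    # Asynchronous path: submit now, poll later with an explicit timeout.
    job = ingestion.build_objects_async(project=project)
    status = ingestion.poll_for_job_completion(job, timeout=10 * 60)
    assert status.success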
@@ -373,15 +414,18 @@ def poll_for_job_completion(self,
             team_id=self.team_id,
             job=job,
             raise_errors=False,  # JobFailureError doesn't contain the error
-            **kwargs
+            **kwargs,
         )
-        if build_job_status.output is not None and "table_build_job_id" in build_job_status.output:
+        if (
+            build_job_status.output is not None
+            and "table_build_job_id" in build_job_status.output
+        ):
             _poll_for_job_completion(
                 session=self.session,
                 team_id=self.team_id,
                 job=build_job_status.output["table_build_job_id"],
                 raise_errors=False,  # JobFailureError doesn't contain the error
-                **kwargs
+                **kwargs,
             )
         return self.status()
@@ -395,9 +439,9 @@ def status(self) -> IngestionStatus:
             The result of the ingestion attempt

         """
-        collection = IngestionCollection(team_id=self.team_id,
-                                         dataset_id=self.dataset_id,
-                                         session=self.session)
+        collection = IngestionCollection(
+            team_id=self.team_id, dataset_id=self.dataset_id, session=self.session
+        )
         path = collection._get_path(uid=self.uid, action="status")
         return IngestionStatus.build(self.session.get_resource(path=path))
@@ -409,42 +453,46 @@ def __init__(self, errors: Iterable[IngestionErrorTrace]):
         self.errors = list(errors)
         self.raise_errors = False

-    def build_objects(self,
-                      *,
-                      build_table: bool = False,
-                      project: Optional[Union["Project", UUID, str]] = None,  # noqa: F821
-                      delete_dataset_contents: bool = False,
-                      delete_templates: bool = True,
-                      timeout: float = None,
-                      polling_delay: Optional[float] = None
-                      ) -> IngestionStatus:
+    def build_objects(
+        self,
+        *,
+        build_table: bool = False,
+        project: Optional[Union["Project", UUID, str]] = None,  # noqa: F821
+        delete_dataset_contents: bool = False,
+        delete_templates: bool = True,
+        timeout: float = None,
+        polling_delay: Optional[float] = None,
+    ) -> IngestionStatus:
         """[ALPHA] Satisfy the required interface for a failed ingestion."""
         return self.status()

-    def build_objects_async(self,
-                            *,
-                            build_table: bool = False,
-                            project: Optional[Union["Project", UUID, str]] = None,  # noqa: F821
-                            delete_dataset_contents: bool = False,
-                            delete_templates: bool = True) -> JobSubmissionResponse:
+    def build_objects_async(
+        self,
+        *,
+        build_table: bool = False,
+        project: Optional[Union["Project", UUID, str]] = None,  # noqa: F821
+        delete_dataset_contents: bool = False,
+        delete_templates: bool = True,
+    ) -> JobSubmissionResponse:
         """[ALPHA] Satisfy the required interface for a failed ingestion."""
         raise JobFailureError(
             message=f"Errors: {[e.msg for e in self.errors]}",
-            job_id=UUID('0' * 32),  # Nil UUID
-            failure_reasons=[e.msg for e in self.errors]
+            job_id=UUID("0" * 32),  # Nil UUID
+            failure_reasons=[e.msg for e in self.errors],
         )

-    def poll_for_job_completion(self,
-                                job: JobSubmissionResponse,
-                                *,
-                                timeout: Optional[float] = None,
-                                polling_delay: Optional[float] = None
-                                ) -> IngestionStatus:
+    def poll_for_job_completion(
+        self,
+        job: JobSubmissionResponse,
+        *,
+        timeout: Optional[float] = None,
+        polling_delay: Optional[float] = None,
+    ) -> IngestionStatus:
         """[ALPHA] Satisfy the required interface for a failed ingestion."""
         raise JobFailureError(
             message=f"Errors: {[e.msg for e in self.errors]}",
-            job_id=UUID('0' * 32),  # Nil UUID
-            failure_reasons=[e.msg for e in self.errors]
+            job_id=UUID("0" * 32),  # Nil UUID
+            failure_reasons=[e.msg for e in self.errors],
         )

     def status(self) -> IngestionStatus:
@@ -460,14 +508,16 @@ def status(self) -> IngestionStatus:
         if self.raise_errors:
             raise JobFailureError(
                 message=f"Ingestion creation failed: {self.errors}",
-                job_id=UUID('0' * 32),  # Nil UUID
-                failure_reasons=[str(x) for x in self.errors]
+                job_id=UUID("0" * 32),  # Nil UUID
+                failure_reasons=[str(x) for x in self.errors],
             )
         else:
-            return IngestionStatus.build({
-                "status": IngestionStatusType.INGESTION_CREATED,
-                "errors": self.errors,
-            })
+            return IngestionStatus.build(
+                {
+                    "status": IngestionStatusType.INGESTION_CREATED,
+                    "errors": self.errors,
+                }
+            )


 class IngestionCollection(Collection[Ingestion]):
@@ -493,7 +543,7 @@ def __init__(
         session: Session = None,
         team_id: UUID = None,
         dataset_id: UUID = None,
-        project_id: Optional[UUID] = None
+        project_id: Optional[UUID] = None,
     ):
         args = _pad_positional_args(args, 3)
         self.project_id = project_id or args[0]
@@ -508,21 +558,21 @@ def __init__(
             session=self.session,
             project_id=self.project_id,
             team_id=team_id,
-            obj_type="Ingestions")
+            obj_type="Ingestions",
+        )

     # After the Data Manager deprecation,
     # this can be a Class Variable using the `teams/...` endpoint
     @property
     def _path_template(self):
         if self.project_id is None:
-            return f'teams/{self.team_id}/ingestions'
+            return f"teams/{self.team_id}/ingestions"
         else:
-            return f'projects/{self.project_id}/ingestions'
+            return f"projects/{self.project_id}/ingestions"

-    def build_from_file_links(self,
-                              file_links: TypingCollection[FileLink],
-                              *,
-                              raise_errors: bool = True) -> Ingestion:
+    def build_from_file_links(
+        self, file_links: TypingCollection[FileLink], *, raise_errors: bool = True
+    ) -> Ingestion:
         """
         [ALPHA] Create an on-platform ingestion event based on the passed FileLink objects.
@@ -539,7 +589,9 @@ def build_from_file_links(self,
             raise ValueError("No files passed.")
         invalid_links = [f for f in file_links if f.uid is None]
         if len(invalid_links) != 0:
-            raise ValueError(f"{len(invalid_links)} File Links have no on-platform UID.")
+            raise ValueError(
+                f"{len(invalid_links)} File Links have no on-platform UID."
+            )

         req = {
             "dataset_id": str(self.dataset_id),
@@ -547,7 +599,7 @@ def build_from_file_links(self,
             "files": [
                 {"dataset_file_id": str(f.uid), "file_version_uuid": str(f.version)}
                 for f in file_links
-            ]
+            ],
         }

         try:
@@ -555,8 +607,10 @@ def build_from_file_links(self,
         except BadRequest as e:
             if e.api_error is not None:
                 if e.api_error.validation_errors:
-                    errors = [IngestionErrorTrace.from_validation_error(error)
-                              for error in e.api_error.validation_errors]
+                    errors = [
+                        IngestionErrorTrace.from_validation_error(error)
+                        for error in e.api_error.validation_errors
+                    ]
                 else:
                     errors = [IngestionErrorTrace(msg=e.api_error.message)]
                 if raise_errors:
@@ -565,10 +619,7 @@ def build_from_file_links(self,
                     return FailedIngestion(errors=errors)
             else:
                 raise e
-        return self.build({
-            **response,
-            "raise_errors": raise_errors
-        })
+        return self.build({**response, "raise_errors": raise_errors})

     def build(self, data: dict) -> Ingestion:
         """Build an instance of an Ingestion."""
diff --git a/src/citrine/resources/ingredient_run.py b/src/citrine/resources/ingredient_run.py
index 340ef69db..c783ee3a1 100644
--- a/src/citrine/resources/ingredient_run.py
+++ b/src/citrine/resources/ingredient_run.py
@@ -1,4 +1,5 @@
 """Resources that represent ingredient run data objects."""
+
 from typing import List, Dict, Optional, Type, Iterator, Union
 from uuid import UUID
@@ -17,10 +18,10 @@


 class IngredientRun(
-    GEMDResource['IngredientRun'],
+    GEMDResource["IngredientRun"],
     ObjectRun,
     GEMDIngredientRun,
-    typ=GEMDIngredientRun.typ
+    typ=GEMDIngredientRun.typ,
 ):
     """
     An ingredient run.
@@ -64,56 +65,69 @@ class IngredientRun(

     _response_key = GEMDIngredientRun.typ  # 'ingredient_run'

-    material = PropertyOptional(LinkOrElse(GEMDMaterialRun), 'material', override=True)
-    process = PropertyOptional(LinkOrElse(GEMDProcessRun),
-                               'process',
-                               override=True,
-                               use_init=True,
-                               )
-    mass_fraction = PropertyOptional(Object(ContinuousValue), 'mass_fraction')
-    volume_fraction = PropertyOptional(Object(ContinuousValue), 'volume_fraction')
-    number_fraction = PropertyOptional(Object(ContinuousValue), 'number_fraction')
-    absolute_quantity = PropertyOptional(Object(ContinuousValue), 'absolute_quantity')
-    spec = PropertyOptional(LinkOrElse(GEMDIngredientSpec), 'spec', override=True, use_init=True)
+    material = PropertyOptional(LinkOrElse(GEMDMaterialRun), "material", override=True)
+    process = PropertyOptional(
+        LinkOrElse(GEMDProcessRun),
+        "process",
+        override=True,
+        use_init=True,
+    )
+    mass_fraction = PropertyOptional(Object(ContinuousValue), "mass_fraction")
+    volume_fraction = PropertyOptional(Object(ContinuousValue), "volume_fraction")
+    number_fraction = PropertyOptional(Object(ContinuousValue), "number_fraction")
+    absolute_quantity = PropertyOptional(Object(ContinuousValue), "absolute_quantity")
+    spec = PropertyOptional(
+        LinkOrElse(GEMDIngredientSpec), "spec", override=True, use_init=True
+    )

     """ Intentionally private because they have some unusual dynamics """
-    _name = PropertyOptional(String(), 'name')
-    _labels = PropertyOptional(PropertyList(String()), 'labels')
-
-    def __init__(self,
-                 *,
-                 uids: Optional[Dict[str, str]] = None,
-                 tags: Optional[List[str]] = None,
-                 notes: Optional[str] = None,
-                 material: Optional[GEMDMaterialRun] = None,
-                 process: Optional[GEMDProcessRun] = None,
-                 mass_fraction: Optional[ContinuousValue] = None,
-                 volume_fraction: Optional[ContinuousValue] = None,
-                 number_fraction: Optional[ContinuousValue] = None,
-                 absolute_quantity: Optional[ContinuousValue] = None,
-                 spec: Optional[GEMDIngredientSpec] = None,
-                 file_links: Optional[List[FileLink]] = None):
+    _name = PropertyOptional(String(), "name")
+    _labels = PropertyOptional(PropertyList(String()), "labels")
+
+    def __init__(
+        self,
+        *,
+        uids: Optional[Dict[str, str]] = None,
+        tags: Optional[List[str]] = None,
+        notes: Optional[str] = None,
+        material: Optional[GEMDMaterialRun] = None,
+        process: Optional[GEMDProcessRun] = None,
+        mass_fraction: Optional[ContinuousValue] = None,
+        volume_fraction: Optional[ContinuousValue] = None,
+        number_fraction: Optional[ContinuousValue] = None,
+        absolute_quantity: Optional[ContinuousValue] = None,
+        spec: Optional[GEMDIngredientSpec] = None,
+        file_links: Optional[List[FileLink]] = None,
+    ):
         if uids is None:
             uids = dict()
         super(ObjectRun, self).__init__()
-        GEMDIngredientRun.__init__(self, uids=uids, tags=tags, notes=notes,
-                                   material=material, process=process,
-                                   mass_fraction=mass_fraction, volume_fraction=volume_fraction,
-                                   number_fraction=number_fraction,
-                                   absolute_quantity=absolute_quantity,
-                                   spec=spec, file_links=file_links)
+        GEMDIngredientRun.__init__(
+            self,
+            uids=uids,
+            tags=tags,
+            notes=notes,
+            material=material,
+            process=process,
+            mass_fraction=mass_fraction,
+            volume_fraction=volume_fraction,
+            number_fraction=number_fraction,
+            absolute_quantity=absolute_quantity,
+            spec=spec,
+            file_links=file_links,
+        )

     def __str__(self):
-        return '<Ingredient run {!r}>'.format(self.name)
+        return "<Ingredient run {!r}>".format(self.name)


 class IngredientRunCollection(ObjectRunCollection[IngredientRun]):
     """Represents the collection of all ingredient runs associated with a
dataset.""" - _individual_key = 'ingredient_run' - _collection_key = 'ingredient_runs' + _individual_key = "ingredient_run" + _collection_key = "ingredient_runs" _resource = IngredientRun @classmethod @@ -121,9 +135,9 @@ def get_type(cls) -> Type[IngredientRun]: """Return the resource type in the collection.""" return IngredientRun - def list_by_spec(self, - uid: Union[UUID, str, LinkByUID, GEMDIngredientSpec] - ) -> Iterator[IngredientRun]: + def list_by_spec( + self, uid: Union[UUID, str, LinkByUID, GEMDIngredientSpec] + ) -> Iterator[IngredientRun]: """ Get the ingredient runs using the specified ingredient spec. @@ -138,11 +152,11 @@ def list_by_spec(self, The ingredient runs using the specified ingredient spec. """ - return self._get_relation(relation='ingredient-specs', uid=uid) + return self._get_relation(relation="ingredient-specs", uid=uid) - def list_by_process(self, - uid: Union[UUID, str, LinkByUID, GEMDProcessRun] - ) -> Iterator[IngredientRun]: + def list_by_process( + self, uid: Union[UUID, str, LinkByUID, GEMDProcessRun] + ) -> Iterator[IngredientRun]: """ Get ingredients to a process. @@ -157,11 +171,11 @@ def list_by_process(self, The ingredients to the specified process. """ - return self._get_relation(relation='process-runs', uid=uid) + return self._get_relation(relation="process-runs", uid=uid) - def list_by_material(self, - uid: Union[UUID, str, LinkByUID, GEMDMaterialRun] - ) -> Iterator[IngredientRun]: + def list_by_material( + self, uid: Union[UUID, str, LinkByUID, GEMDMaterialRun] + ) -> Iterator[IngredientRun]: """ Get ingredients using the specified material. @@ -176,4 +190,4 @@ def list_by_material(self, The ingredients using the specified material """ - return self._get_relation(relation='material-runs', uid=uid) + return self._get_relation(relation="material-runs", uid=uid) diff --git a/src/citrine/resources/ingredient_spec.py b/src/citrine/resources/ingredient_spec.py index 9c9e31f20..ede333483 100644 --- a/src/citrine/resources/ingredient_spec.py +++ b/src/citrine/resources/ingredient_spec.py @@ -1,4 +1,5 @@ """Resources that represent ingredient spec data objects.""" + from typing import List, Dict, Optional, Type, Iterator, Union from uuid import UUID @@ -16,10 +17,10 @@ class IngredientSpec( - GEMDResource['IngredientSpec'], + GEMDResource["IngredientSpec"], ObjectSpec, GEMDIngredientSpec, - typ=GEMDIngredientSpec.typ + typ=GEMDIngredientSpec.typ, ): """ An ingredient specification. 
@@ -65,54 +66,72 @@ class IngredientSpec(

     _response_key = GEMDIngredientSpec.typ  # 'ingredient_spec'

-    material = PropertyOptional(LinkOrElse(GEMDMaterialSpec), 'material', override=True)
-    process = PropertyOptional(LinkOrElse(GEMDProcessSpec),
-                               'process',
-                               override=True,
-                               use_init=True)
-    mass_fraction = PropertyOptional(Object(ContinuousValue), 'mass_fraction', override=True)
-    volume_fraction = PropertyOptional(Object(ContinuousValue), 'volume_fraction', override=True)
-    number_fraction = PropertyOptional(Object(ContinuousValue), 'number_fraction', override=True)
-    absolute_quantity = PropertyOptional(Object(ContinuousValue),
-                                         'absolute_quantity',
-                                         override=True)
-    name = String('name', override=True, use_init=True)
-    labels = PropertyOptional(PropertyList(String()), 'labels', override=True, use_init=True)
-
-    def __init__(self,
-                 name: str,
-                 *,
-                 uids: Optional[Dict[str, str]] = None,
-                 tags: Optional[List[str]] = None,
-                 notes: Optional[str] = None,
-                 material: Optional[GEMDMaterialSpec] = None,
-                 process: Optional[GEMDProcessSpec] = None,
-                 mass_fraction: Optional[ContinuousValue] = None,
-                 volume_fraction: Optional[ContinuousValue] = None,
-                 number_fraction: Optional[ContinuousValue] = None,
-                 absolute_quantity: Optional[ContinuousValue] = None,
-                 labels: Optional[List[str]] = None,
-                 file_links: Optional[List[FileLink]] = None):
+    material = PropertyOptional(LinkOrElse(GEMDMaterialSpec), "material", override=True)
+    process = PropertyOptional(
+        LinkOrElse(GEMDProcessSpec), "process", override=True, use_init=True
+    )
+    mass_fraction = PropertyOptional(
+        Object(ContinuousValue), "mass_fraction", override=True
+    )
+    volume_fraction = PropertyOptional(
+        Object(ContinuousValue), "volume_fraction", override=True
+    )
+    number_fraction = PropertyOptional(
+        Object(ContinuousValue), "number_fraction", override=True
+    )
+    absolute_quantity = PropertyOptional(
+        Object(ContinuousValue), "absolute_quantity", override=True
+    )
+    name = String("name", override=True, use_init=True)
+    labels = PropertyOptional(
+        PropertyList(String()), "labels", override=True, use_init=True
+    )
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        uids: Optional[Dict[str, str]] = None,
+        tags: Optional[List[str]] = None,
+        notes: Optional[str] = None,
+        material: Optional[GEMDMaterialSpec] = None,
+        process: Optional[GEMDProcessSpec] = None,
+        mass_fraction: Optional[ContinuousValue] = None,
+        volume_fraction: Optional[ContinuousValue] = None,
+        number_fraction: Optional[ContinuousValue] = None,
+        absolute_quantity: Optional[ContinuousValue] = None,
+        labels: Optional[List[str]] = None,
+        file_links: Optional[List[FileLink]] = None,
+    ):
         if uids is None:
             uids = dict()
         super(ObjectSpec, self).__init__()
-        GEMDIngredientSpec.__init__(self, uids=uids, tags=tags, notes=notes,
-                                    material=material, process=process,
-                                    mass_fraction=mass_fraction, volume_fraction=volume_fraction,
-                                    number_fraction=number_fraction,
-                                    absolute_quantity=absolute_quantity, labels=labels,
-                                    name=name, file_links=file_links)
+        GEMDIngredientSpec.__init__(
+            self,
+            uids=uids,
+            tags=tags,
+            notes=notes,
+            material=material,
+            process=process,
+            mass_fraction=mass_fraction,
+            volume_fraction=volume_fraction,
+            number_fraction=number_fraction,
+            absolute_quantity=absolute_quantity,
+            labels=labels,
+            name=name,
+            file_links=file_links,
+        )

     def __str__(self):
-        return '<Ingredient spec {!r}>'.format(self.name)
+        return "<Ingredient spec {!r}>".format(self.name)


 class IngredientSpecCollection(ObjectSpecCollection[IngredientSpec]):
     """Represents the collection of all ingredient specs
associated with a dataset.""" - _individual_key = 'ingredient_spec' - _collection_key = 'ingredient_specs' + _individual_key = "ingredient_spec" + _collection_key = "ingredient_specs" _resource = IngredientSpec @classmethod @@ -120,9 +139,9 @@ def get_type(cls) -> Type[IngredientSpec]: """Return the resource type in the collection.""" return IngredientSpec - def list_by_process(self, - uid: Union[UUID, str, LinkByUID, GEMDProcessSpec] - ) -> Iterator[IngredientSpec]: + def list_by_process( + self, uid: Union[UUID, str, LinkByUID, GEMDProcessSpec] + ) -> Iterator[IngredientSpec]: """ Get ingredients to a process. @@ -137,11 +156,11 @@ def list_by_process(self, The ingredients to the specified process. """ - return self._get_relation(relation='process-specs', uid=uid) + return self._get_relation(relation="process-specs", uid=uid) - def list_by_material(self, - uid: Union[UUID, str, LinkByUID, GEMDMaterialSpec] - ) -> Iterator[IngredientSpec]: + def list_by_material( + self, uid: Union[UUID, str, LinkByUID, GEMDMaterialSpec] + ) -> Iterator[IngredientSpec]: """ Get ingredients using the specified material. @@ -156,4 +175,4 @@ def list_by_material(self, The ingredients using the specified material """ - return self._get_relation(relation='material-specs', uid=uid) + return self._get_relation(relation="material-specs", uid=uid) diff --git a/src/citrine/resources/material_run.py b/src/citrine/resources/material_run.py index fb2a62cc0..c5d8c24e2 100644 --- a/src/citrine/resources/material_run.py +++ b/src/citrine/resources/material_run.py @@ -1,4 +1,5 @@ """Resources that represent material run data objects.""" + from typing import List, Dict, Optional, Type, Iterator, Union from uuid import UUID @@ -14,15 +15,14 @@ from gemd.entity.link_by_uid import LinkByUID from gemd.entity.object.material_run import MaterialRun as GEMDMaterialRun from gemd.entity.object.material_spec import MaterialSpec as GEMDMaterialSpec -from gemd.entity.template.material_template import MaterialTemplate as GEMDMaterialTemplate +from gemd.entity.template.material_template import ( + MaterialTemplate as GEMDMaterialTemplate, +) from gemd.entity.object.process_run import ProcessRun as GEMDProcessRun class MaterialRun( - GEMDResource['MaterialRun'], - ObjectRun, - GEMDMaterialRun, - typ=GEMDMaterialRun.typ + GEMDResource["MaterialRun"], ObjectRun, GEMDMaterialRun, typ=GEMDMaterialRun.typ ): """ A material run. 
@@ -63,46 +63,59 @@ class MaterialRun(

     _response_key = GEMDMaterialRun.typ  # 'material_run'

-    name = String('name', override=True, use_init=True)
-    process = PropertyOptional(LinkOrElse(GEMDProcessRun),
-                               'process',
-                               override=True,
-                               use_init=True,)
-    sample_type = PropertyOptional(String, 'sample_type', override=True)
-    spec = PropertyOptional(LinkOrElse(GEMDMaterialSpec),
-                            'spec',
-                            override=True,
-                            use_init=True,)
-
-    def __init__(self,
-                 name: str,
-                 *,
-                 uids: Optional[Dict[str, str]] = None,
-                 tags: Optional[List[str]] = None,
-                 notes: Optional[str] = None,
-                 process: Optional[GEMDProcessRun] = None,
-                 sample_type: Optional[str] = "unknown",
-                 spec: Optional[GEMDMaterialSpec] = None,
-                 file_links: Optional[List[FileLink]] = None,
-                 default_labels: Optional[List[str]] = None):
+    name = String("name", override=True, use_init=True)
+    process = PropertyOptional(
+        LinkOrElse(GEMDProcessRun),
+        "process",
+        override=True,
+        use_init=True,
+    )
+    sample_type = PropertyOptional(String, "sample_type", override=True)
+    spec = PropertyOptional(
+        LinkOrElse(GEMDMaterialSpec),
+        "spec",
+        override=True,
+        use_init=True,
+    )
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        uids: Optional[Dict[str, str]] = None,
+        tags: Optional[List[str]] = None,
+        notes: Optional[str] = None,
+        process: Optional[GEMDProcessRun] = None,
+        sample_type: Optional[str] = "unknown",
+        spec: Optional[GEMDMaterialSpec] = None,
+        file_links: Optional[List[FileLink]] = None,
+        default_labels: Optional[List[str]] = None,
+    ):
         if uids is None:
             uids = dict()
         all_tags = _inject_default_label_tags(tags, default_labels)
         super(ObjectRun, self).__init__()
-        GEMDMaterialRun.__init__(self, name=name, uids=uids,
-                                 tags=all_tags, process=process,
-                                 sample_type=sample_type, spec=spec,
-                                 file_links=file_links, notes=notes)
+        GEMDMaterialRun.__init__(
+            self,
+            name=name,
+            uids=uids,
+            tags=all_tags,
+            process=process,
+            sample_type=sample_type,
+            spec=spec,
+            file_links=file_links,
+            notes=notes,
+        )

     def __str__(self):
-        return '<Material run {!r}>'.format(self.name)
+        return "<Material run {!r}>".format(self.name)


 class MaterialRunCollection(ObjectRunCollection[MaterialRun]):
     """Represents the collection of all material runs associated with a dataset."""

-    _individual_key = 'material_run'
-    _collection_key = 'material_runs'
+    _individual_key = "material_run"
+    _collection_key = "material_runs"
     _resource = MaterialRun

     @classmethod
@@ -134,18 +147,14 @@ def get_history(self, id: Union[str, UUID, LinkByUID, MaterialRun]) -> MaterialR
         link = _make_link_by_uid(id)
         path = format_escaped_url(
             "teams/{}/gemd/query/material-histories?filter_nonroot_materials=true",
-            self.team_id)
+            self.team_id,
+        )
         query = {
             "criteria": [
                 {
                     "datasets": [str(self.dataset_id)],
                     "type": "terminal_material_run_identifiers_criteria",
-                    "terminal_material_ids": [
-                        {
-                            "scope": link.scope,
-                            "id": link.id
-                        }
-                    ]
+                    "terminal_material_ids": [{"scope": link.scope, "id": link.id}],
                 }
             ]
         }
@@ -161,9 +170,9 @@ def get_history(self, id: Union[str, UUID, LinkByUID, MaterialRun]) -> MaterialR
         else:
             return None

-    def get_by_process(self,
-                       uid: Union[UUID, str, LinkByUID, GEMDProcessRun]
-                       ) -> Optional[MaterialRun]:
+    def get_by_process(
+        self, uid: Union[UUID, str, LinkByUID, GEMDProcessRun]
+    ) -> Optional[MaterialRun]:
         """
         Get output material of a process.
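`get_history` above is the heavy hitter: it issues a dataset-scoped material-histories query for the terminal run and returns it with its provenance links inflated. A sketch, assuming `dataset.material_runs` exposes this collection and `terminal_run_uid` is a placeholder:

    root = dataset.material_runs.get_history(terminal_run_uid)
    if root is not None:
        # Walk upstream through the inflated history.
        ingredients = root.process.ingredients if root.process else []
        print(root.name, len(ingredients))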
@@ -179,13 +188,12 @@ def get_by_process(self, """ return next( - self._get_relation(relation='process-runs', uid=uid, per_page=1), - None + self._get_relation(relation="process-runs", uid=uid, per_page=1), None ) - def list_by_spec(self, - uid: Union[UUID, str, LinkByUID, GEMDMaterialSpec] - ) -> Iterator[MaterialRun]: + def list_by_spec( + self, uid: Union[UUID, str, LinkByUID, GEMDMaterialSpec] + ) -> Iterator[MaterialRun]: """ Get the material runs using the specified material spec. @@ -200,11 +208,11 @@ def list_by_spec(self, The material runs using the specified material spec. """ - return self._get_relation('material-specs', uid=uid) + return self._get_relation("material-specs", uid=uid) - def list_by_template(self, - uid: Union[UUID, str, LinkByUID, GEMDMaterialTemplate] - ) -> Iterator[MaterialRun]: + def list_by_template( + self, uid: Union[UUID, str, LinkByUID, GEMDMaterialTemplate] + ) -> Iterator[MaterialRun]: """ Get the material runs using the specified material template. @@ -220,10 +228,9 @@ def list_by_template(self, """ spec_collection = MaterialSpecCollection( - team_id=self.team_id, - dataset_id=self.dataset_id, - session=self.session + team_id=self.team_id, dataset_id=self.dataset_id, session=self.session ) specs = spec_collection.list_by_template(uid=_make_link_by_uid(uid)) - return (run for runs in (self.list_by_spec(spec) for spec in specs) - for run in runs) + return ( + run for runs in (self.list_by_spec(spec) for spec in specs) for run in runs + ) diff --git a/src/citrine/resources/material_spec.py b/src/citrine/resources/material_spec.py index 9a8654cf9..9732d70d3 100644 --- a/src/citrine/resources/material_spec.py +++ b/src/citrine/resources/material_spec.py @@ -1,4 +1,5 @@ """Resources that represent material spec data objects.""" + from typing import List, Dict, Optional, Type, Iterator, Union from uuid import UUID @@ -13,14 +14,13 @@ from gemd.entity.link_by_uid import LinkByUID from gemd.entity.object.material_spec import MaterialSpec as GEMDMaterialSpec from gemd.entity.object.process_spec import ProcessSpec as GEMDProcessSpec -from gemd.entity.template.material_template import MaterialTemplate as GEMDMaterialTemplate +from gemd.entity.template.material_template import ( + MaterialTemplate as GEMDMaterialTemplate, +) class MaterialSpec( - GEMDResource['MaterialSpec'], - ObjectSpec, - GEMDMaterialSpec, - typ=GEMDMaterialSpec.typ + GEMDResource["MaterialSpec"], ObjectSpec, GEMDMaterialSpec, typ=GEMDMaterialSpec.typ ): """ A material specification. 
@@ -58,48 +58,61 @@ class MaterialSpec(

     _response_key = GEMDMaterialSpec.typ  # 'material_spec'

-    name = String('name', override=True, use_init=True)
-    process = PropertyOptional(LinkOrElse(GEMDProcessSpec),
-                               'process',
-                               override=True,
-                               use_init=True,
-                               )
-    properties = PropertyOptional(PropertyList(Object(PropertyAndConditions)),
-                                  'properties',
-                                  override=True)
-    template = PropertyOptional(LinkOrElse(GEMDMaterialTemplate),
-                                'template',
-                                override=True,
-                                use_init=True,)
-
-    def __init__(self,
-                 name: str,
-                 *,
-                 uids: Optional[Dict[str, str]] = None,
-                 tags: Optional[List[str]] = None,
-                 notes: Optional[str] = None,
-                 process: Optional[GEMDProcessSpec] = None,
-                 properties: Optional[List[PropertyAndConditions]] = None,
-                 template: Optional[GEMDMaterialTemplate] = None,
-                 file_links: Optional[List[FileLink]] = None,
-                 default_labels: Optional[List[str]] = None):
+    name = String("name", override=True, use_init=True)
+    process = PropertyOptional(
+        LinkOrElse(GEMDProcessSpec),
+        "process",
+        override=True,
+        use_init=True,
+    )
+    properties = PropertyOptional(
+        PropertyList(Object(PropertyAndConditions)), "properties", override=True
+    )
+    template = PropertyOptional(
+        LinkOrElse(GEMDMaterialTemplate),
+        "template",
+        override=True,
+        use_init=True,
+    )
+
+    def __init__(
+        self,
+        name: str,
+        *,
+        uids: Optional[Dict[str, str]] = None,
+        tags: Optional[List[str]] = None,
+        notes: Optional[str] = None,
+        process: Optional[GEMDProcessSpec] = None,
+        properties: Optional[List[PropertyAndConditions]] = None,
+        template: Optional[GEMDMaterialTemplate] = None,
+        file_links: Optional[List[FileLink]] = None,
+        default_labels: Optional[List[str]] = None,
+    ):
         if uids is None:
             uids = dict()
         all_tags = _inject_default_label_tags(tags, default_labels)
         super(ObjectSpec, self).__init__()
-        GEMDMaterialSpec.__init__(self, name=name, uids=uids,
-                                  tags=all_tags, process=process, properties=properties,
-                                  template=template, file_links=file_links, notes=notes)
+        GEMDMaterialSpec.__init__(
+            self,
+            name=name,
+            uids=uids,
+            tags=all_tags,
+            process=process,
+            properties=properties,
+            template=template,
+            file_links=file_links,
+            notes=notes,
+        )

     def __str__(self):
-        return '<Material spec {!r}>'.format(self.name)
+        return "<Material spec {!r}>".format(self.name)


 class MaterialSpecCollection(ObjectSpecCollection[MaterialSpec]):
     """Represents the collection of all material specs associated with a dataset."""

-    _individual_key = 'material_spec'
-    _collection_key = 'material_specs'
+    _individual_key = "material_spec"
+    _collection_key = "material_specs"
     _resource = MaterialSpec

     @classmethod
@@ -107,9 +120,9 @@ def get_type(cls) -> Type[MaterialSpec]:
         """Return the resource type in the collection."""
         return MaterialSpec

-    def list_by_template(self,
-                         uid: Union[UUID, str, LinkByUID, GEMDMaterialTemplate]
-                         ) -> Iterator[MaterialSpec]:
+    def list_by_template(
+        self, uid: Union[UUID, str, LinkByUID, GEMDMaterialTemplate]
+    ) -> Iterator[MaterialSpec]:
         """
         Get the material specs using the specified material template.
@@ -124,11 +137,11 @@ def list_by_template(self,
             The material specs using the specified material template.

         """
-        return self._get_relation('material-templates', uid=uid)
+        return self._get_relation("material-templates", uid=uid)

-    def get_by_process(self,
-                       uid: Union[UUID, str, LinkByUID, GEMDProcessSpec]
-                       ) -> Optional[MaterialSpec]:
+    def get_by_process(
+        self, uid: Union[UUID, str, LinkByUID, GEMDProcessSpec]
+    ) -> Optional[MaterialSpec]:
         """
         Get output material of a process.
@@ -144,10 +157,5 @@ def get_by_process(self, """ return next( - self._get_relation( - relation='process-specs', - uid=uid, - per_page=1 - ), - None + self._get_relation(relation="process-specs", uid=uid, per_page=1), None ) diff --git a/src/citrine/resources/material_template.py b/src/citrine/resources/material_template.py index ea90bff58..524d25aa6 100644 --- a/src/citrine/resources/material_template.py +++ b/src/citrine/resources/material_template.py @@ -1,4 +1,5 @@ """Resources that represent material templates.""" + from typing import List, Dict, Optional, Union, Sequence, Type from citrine._rest.resource import GEMDResource @@ -10,15 +11,19 @@ from citrine.resources.property_template import PropertyTemplate from gemd.entity.bounds.base_bounds import BaseBounds from gemd.entity.link_by_uid import LinkByUID -from gemd.entity.template.material_template import MaterialTemplate as GEMDMaterialTemplate -from gemd.entity.template.property_template import PropertyTemplate as GEMDPropertyTemplate +from gemd.entity.template.material_template import ( + MaterialTemplate as GEMDMaterialTemplate, +) +from gemd.entity.template.property_template import ( + PropertyTemplate as GEMDPropertyTemplate, +) class MaterialTemplate( - GEMDResource['MaterialTemplate'], + GEMDResource["MaterialTemplate"], ObjectTemplate, GEMDMaterialTemplate, - typ=GEMDMaterialTemplate.typ + typ=GEMDMaterialTemplate.typ, ): """ A material template. @@ -53,42 +58,63 @@ class MaterialTemplate( properties = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDPropertyTemplate), - SpecifiedMixedList([LinkOrElse(GEMDPropertyTemplate), - PropertyOptional(Object(BaseBounds))])] - ) - ), 'properties', override=True) + PropertyUnion( + [ + LinkOrElse(GEMDPropertyTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDPropertyTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) + ), + "properties", + override=True, + ) - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - properties: Optional[Sequence[Union[PropertyTemplate, - LinkByUID, - Sequence[Union[PropertyTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + properties: Optional[ + Sequence[ + Union[ + PropertyTemplate, + LinkByUID, + Sequence[Union[PropertyTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): # properties is a list, each element of which is a PropertyTemplate OR is a list with # 2 entries: [PropertyTemplate, BaseBounds]. Python typing is not expressive enough, so # the typing above is more general. 
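The comment above is the key to this constructor: each entry in `properties` is either a bare `PropertyTemplate` (or link), or a two-item `[PropertyTemplate, BaseBounds]` pair that narrows the template's bounds for this material. A sketch using placeholder names and units:

    from gemd.entity.bounds import RealBounds

    modulus = PropertyTemplate(name="elastic modulus", bounds=RealBounds(0, 1000, "GPa"))
    density = PropertyTemplate(name="density", bounds=RealBounds(0, 25, "g/cc"))
    template = MaterialTemplate(
        name="Alloy",
        properties=[
            modulus,                                  # bare template: its own bounds apply
            [density, RealBounds(2.5, 8.0, "g/cc")],  # pair: narrowed bounds here
        ],
    )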
         if uids is None:
             uids = dict()
         super(ObjectTemplate, self).__init__()
-        GEMDMaterialTemplate.__init__(self, name=name, properties=properties,
-                                      uids=uids, tags=tags,
-                                      description=description)
+        GEMDMaterialTemplate.__init__(
+            self,
+            name=name,
+            properties=properties,
+            uids=uids,
+            tags=tags,
+            description=description,
+        )

     def __str__(self):
-        return '<Material template {!r}>'.format(self.name)
+        return "<Material template {!r}>".format(self.name)


 class MaterialTemplateCollection(ObjectTemplateCollection[MaterialTemplate]):
     """A collection of material templates."""

-    _individual_key = 'material_template'
-    _collection_key = 'material_templates'
+    _individual_key = "material_template"
+    _collection_key = "material_templates"
     _resource = MaterialTemplate

     @classmethod
diff --git a/src/citrine/resources/measurement_run.py b/src/citrine/resources/measurement_run.py
index 02440dc9c..85e51efa8 100644
--- a/src/citrine/resources/measurement_run.py
+++ b/src/citrine/resources/measurement_run.py
@@ -1,4 +1,5 @@
 """Resources that represent measurement run data objects."""
+
 from typing import List, Dict, Optional, Type, Iterator, Union
 from uuid import UUID
@@ -19,10 +20,10 @@


 class MeasurementRun(
-    GEMDResource['MeasurementRun'],
+    GEMDResource["MeasurementRun"],
     ObjectRun,
     GEMDMeasurementRun,
-    typ=GEMDMeasurementRun.typ
+    typ=GEMDMeasurementRun.typ,
 ):
     """
     A measurement run.
@@ -62,49 +63,72 @@ class MeasurementRun(

     _response_key = GEMDMeasurementRun.typ  # 'measurement_run'

-    name = String('name', override=True, use_init=True)
-    conditions = PropertyOptional(PropertyList(Object(Condition)), 'conditions', override=True)
-    parameters = PropertyOptional(PropertyList(Object(Parameter)), 'parameters', override=True)
-    properties = PropertyOptional(PropertyList(Object(Property)), 'properties', override=True)
-    spec = PropertyOptional(LinkOrElse(GEMDMeasurementSpec), 'spec', override=True, use_init=True,)
-    material = PropertyOptional(LinkOrElse(GEMDMaterialRun),
-                                "material",
-                                override=True,
-                                use_init=True,
-                                )
+    name = String("name", override=True, use_init=True)
+    conditions = PropertyOptional(
+        PropertyList(Object(Condition)), "conditions", override=True
+    )
+    parameters = PropertyOptional(
+        PropertyList(Object(Parameter)), "parameters", override=True
+    )
+    properties = PropertyOptional(
+        PropertyList(Object(Property)), "properties", override=True
+    )
+    spec = PropertyOptional(
+        LinkOrElse(GEMDMeasurementSpec),
+        "spec",
+        override=True,
+        use_init=True,
+    )
+    material = PropertyOptional(
+        LinkOrElse(GEMDMaterialRun),
+        "material",
+        override=True,
+        use_init=True,
+    )
     source = PropertyOptional(Object(PerformedSource), "source", override=True)

-    def __init__(self,
-                 name: str,
-                 *,
-                 uids: Optional[Dict[str, str]] = None,
-                 tags: Optional[List[str]] = None,
-                 notes: Optional[str] = None,
-                 conditions: Optional[List[Condition]] = None,
-                 properties: Optional[List[Property]] = None,
-                 parameters: Optional[List[Parameter]] = None,
-                 spec: Optional[GEMDMeasurementSpec] = None,
-                 material: Optional[GEMDMaterialRun] = None,
-                 file_links: Optional[List[FileLink]] = None,
-                 source: Optional[PerformedSource] = None):
+    def __init__(
+        self,
+        name: str,
+        *,
+        uids: Optional[Dict[str, str]] = None,
+        tags: Optional[List[str]] = None,
+        notes: Optional[str] = None,
+        conditions: Optional[List[Condition]] = None,
+        properties: Optional[List[Property]] = None,
+        parameters: Optional[List[Parameter]] = None,
+        spec: Optional[GEMDMeasurementSpec] = None,
+        material: Optional[GEMDMaterialRun] = None,
+        file_links: Optional[List[FileLink]] = None,
+        source: Optional[PerformedSource] = None,
+    ):
         if uids is None:
             uids = dict()
         super(ObjectRun, self).__init__()
-        GEMDMeasurementRun.__init__(self, name=name, uids=uids,
-                                    material=material,
-                                    tags=tags, conditions=conditions, properties=properties,
-                                    parameters=parameters, spec=spec,
-                                    file_links=file_links, notes=notes, source=source)
+        GEMDMeasurementRun.__init__(
+            self,
+            name=name,
+            uids=uids,
+            material=material,
+            tags=tags,
+            conditions=conditions,
+            properties=properties,
+            parameters=parameters,
+            spec=spec,
+            file_links=file_links,
+            notes=notes,
+            source=source,
+        )

     def __str__(self):
-        return '<Measurement run {!r}>'.format(self.name)
+        return "<Measurement run {!r}>".format(self.name)


 class MeasurementRunCollection(ObjectRunCollection[MeasurementRun]):
     """Represents the collection of all measurement runs associated with a dataset."""

-    _individual_key = 'measurement_run'
-    _collection_key = 'measurement_runs'
+    _individual_key = "measurement_run"
+    _collection_key = "measurement_runs"
     _resource = MeasurementRun

     @classmethod
@@ -112,9 +136,9 @@ def get_type(cls) -> Type[MeasurementRun]:
         """Return the resource type in the collection."""
         return MeasurementRun

-    def list_by_spec(self,
-                     uid: Union[UUID, str, LinkByUID, GEMDMeasurementSpec]
-                     ) -> Iterator[MeasurementRun]:
+    def list_by_spec(
+        self, uid: Union[UUID, str, LinkByUID, GEMDMeasurementSpec]
+    ) -> Iterator[MeasurementRun]:
         """
         Get the measurement runs using the specified measurement spec.
@@ -129,11 +153,11 @@ def list_by_spec(self,
             The measurement runs using the specified measurement spec.

         """
-        return self._get_relation('measurement-specs', uid=uid)
+        return self._get_relation("measurement-specs", uid=uid)

-    def list_by_material(self,
-                         uid: Union[UUID, str, LinkByUID, GEMDMaterialRun]
-                         ) -> Iterator[MeasurementRun]:
+    def list_by_material(
+        self, uid: Union[UUID, str, LinkByUID, GEMDMaterialRun]
+    ) -> Iterator[MeasurementRun]:
         """
         Get measurements of the specified material.
@@ -148,4 +172,4 @@ def list_by_material(self,
             The measurements of the specified material

         """
-        return self._get_relation(relation='material-runs', uid=uid)
+        return self._get_relation(relation="material-runs", uid=uid)
diff --git a/src/citrine/resources/measurement_spec.py b/src/citrine/resources/measurement_spec.py
index 484026243..5d46d9cf8 100644
--- a/src/citrine/resources/measurement_spec.py
+++ b/src/citrine/resources/measurement_spec.py
@@ -1,4 +1,5 @@
 """Resources that represent measurement spec data objects."""
+
 from typing import List, Dict, Optional, Type, Union, Iterator
 from uuid import UUID
@@ -12,15 +13,16 @@
 from gemd.entity.file_link import FileLink
 from gemd.entity.link_by_uid import LinkByUID
 from gemd.entity.object.measurement_spec import MeasurementSpec as GEMDMeasurementSpec
-from gemd.entity.template.measurement_template import \
-    MeasurementTemplate as GEMDMeasurementTemplate
+from gemd.entity.template.measurement_template import (
+    MeasurementTemplate as GEMDMeasurementTemplate,
+)


 class MeasurementSpec(
-    GEMDResource['MeasurementSpec'],
+    GEMDResource["MeasurementSpec"],
     ObjectSpec,
     GEMDMeasurementSpec,
-    typ=GEMDMeasurementSpec.typ
+    typ=GEMDMeasurementSpec.typ,
 ):
     """
     A measurement specification.
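As with ingredients, measurement runs hang off both their spec and their subject material, and the collection mirrors that with the two relation queries above. A sketch, assuming `dataset.measurement_runs` exposes this collection and the uid is a placeholder:

    for meas in dataset.measurement_runs.list_by_material(material_run_uid):
        print(meas.name, [prop.name for prop in (meas.properties or [])])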
@@ -53,41 +55,56 @@ class MeasurementSpec( _response_key = GEMDMeasurementSpec.typ # 'measurement_spec' - name = String('name', override=True, use_init=True) - conditions = PropertyOptional(PropertyList(Object(Condition)), 'conditions', override=True) - parameters = PropertyOptional(PropertyList(Object(Parameter)), 'parameters', override=True) - template = PropertyOptional(LinkOrElse(GEMDMeasurementTemplate), - 'template', - override=True, - use_init=True, - ) - - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - tags: Optional[List[str]] = None, - notes: Optional[str] = None, - conditions: Optional[List[Condition]] = None, - parameters: Optional[List[Parameter]] = None, - template: Optional[GEMDMeasurementTemplate] = None, - file_links: Optional[List[FileLink]] = None): + name = String("name", override=True, use_init=True) + conditions = PropertyOptional( + PropertyList(Object(Condition)), "conditions", override=True + ) + parameters = PropertyOptional( + PropertyList(Object(Parameter)), "parameters", override=True + ) + template = PropertyOptional( + LinkOrElse(GEMDMeasurementTemplate), + "template", + override=True, + use_init=True, + ) + + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + tags: Optional[List[str]] = None, + notes: Optional[str] = None, + conditions: Optional[List[Condition]] = None, + parameters: Optional[List[Parameter]] = None, + template: Optional[GEMDMeasurementTemplate] = None, + file_links: Optional[List[FileLink]] = None, + ): if uids is None: uids = dict() super(ObjectSpec, self).__init__() - GEMDMeasurementSpec.__init__(self, name=name, uids=uids, - tags=tags, conditions=conditions, parameters=parameters, - template=template, file_links=file_links, notes=notes) + GEMDMeasurementSpec.__init__( + self, + name=name, + uids=uids, + tags=tags, + conditions=conditions, + parameters=parameters, + template=template, + file_links=file_links, + notes=notes, + ) def __str__(self): - return '<Measurement spec {!r}>'.format(self.name) + return "<Measurement spec {!r}>".format(self.name) class MeasurementSpecCollection(ObjectSpecCollection[MeasurementSpec]): """Represents the collection of all measurement specs associated with a dataset.""" - _individual_key = 'measurement_spec' - _collection_key = 'measurement_specs' + _individual_key = "measurement_spec" + _collection_key = "measurement_specs" _resource = MeasurementSpec @classmethod @@ -95,9 +112,9 @@ def get_type(cls) -> Type[MeasurementSpec]: """Return the resource type in the collection.""" return MeasurementSpec - def list_by_template(self, - uid: Union[UUID, str, LinkByUID, GEMDMeasurementTemplate] - ) -> Iterator[MeasurementSpec]: + def list_by_template( + self, uid: Union[UUID, str, LinkByUID, GEMDMeasurementTemplate] + ) -> Iterator[MeasurementSpec]: """ Get the measurement specs using the specified measurement template. @@ -113,4 +130,4 @@ def list_by_template(self, The measurement specs using the specified measurement template.
""" - return self._get_relation('measurement-templates', uid=uid) + return self._get_relation("measurement-templates", uid=uid) diff --git a/src/citrine/resources/measurement_template.py b/src/citrine/resources/measurement_template.py index e6825f3cd..a913429a5 100644 --- a/src/citrine/resources/measurement_template.py +++ b/src/citrine/resources/measurement_template.py @@ -1,4 +1,5 @@ """Resources that represent measurement templates.""" + from typing import List, Dict, Optional, Union, Sequence, Type from citrine._rest.resource import GEMDResource @@ -12,18 +13,25 @@ from citrine.resources.property_template import PropertyTemplate from gemd.entity.bounds.base_bounds import BaseBounds from gemd.entity.link_by_uid import LinkByUID -from gemd.entity.template.measurement_template \ - import MeasurementTemplate as GEMDMeasurementTemplate -from gemd.entity.template.condition_template import ConditionTemplate as GEMDConditionTemplate -from gemd.entity.template.parameter_template import ParameterTemplate as GEMDParameterTemplate -from gemd.entity.template.property_template import PropertyTemplate as GEMDPropertyTemplate +from gemd.entity.template.measurement_template import ( + MeasurementTemplate as GEMDMeasurementTemplate, +) +from gemd.entity.template.condition_template import ( + ConditionTemplate as GEMDConditionTemplate, +) +from gemd.entity.template.parameter_template import ( + ParameterTemplate as GEMDParameterTemplate, +) +from gemd.entity.template.property_template import ( + PropertyTemplate as GEMDPropertyTemplate, +) class MeasurementTemplate( - GEMDResource['MeasurementTemplate'], + GEMDResource["MeasurementTemplate"], ObjectTemplate, GEMDMeasurementTemplate, - typ=GEMDMeasurementTemplate.typ + typ=GEMDMeasurementTemplate.typ, ): """ A measurement template. 
@@ -68,72 +76,114 @@ class MeasurementTemplate( properties = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDPropertyTemplate), - SpecifiedMixedList([LinkOrElse(GEMDPropertyTemplate), - PropertyOptional(Object(BaseBounds))])] - ) + PropertyUnion( + [ + LinkOrElse(GEMDPropertyTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDPropertyTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) ), - 'properties', - override=True + "properties", + override=True, ) conditions = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDConditionTemplate), - SpecifiedMixedList([LinkOrElse(GEMDConditionTemplate), - PropertyOptional(Object(BaseBounds))])] - ) + PropertyUnion( + [ + LinkOrElse(GEMDConditionTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDConditionTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) ), - 'conditions', - override=True + "conditions", + override=True, ) parameters = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDParameterTemplate), - SpecifiedMixedList([LinkOrElse(GEMDParameterTemplate), - PropertyOptional(Object(BaseBounds))])] - ) + PropertyUnion( + [ + LinkOrElse(GEMDParameterTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDParameterTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) ), - 'parameters', - override=True + "parameters", + override=True, ) - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - properties: Optional[Sequence[Union[PropertyTemplate, - LinkByUID, - Sequence[Union[PropertyTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - conditions: Optional[Sequence[Union[ConditionTemplate, - LinkByUID, - Sequence[Union[ConditionTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - parameters: Optional[Sequence[Union[ParameterTemplate, - LinkByUID, - Sequence[Union[ParameterTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + properties: Optional[ + Sequence[ + Union[ + PropertyTemplate, + LinkByUID, + Sequence[Union[PropertyTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + conditions: Optional[ + Sequence[ + Union[ + ConditionTemplate, + LinkByUID, + Sequence[Union[ConditionTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + parameters: Optional[ + Sequence[ + Union[ + ParameterTemplate, + LinkByUID, + Sequence[Union[ParameterTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): if uids is None: uids = dict() super(ObjectTemplate, self).__init__() - GEMDMeasurementTemplate.__init__(self, name=name, properties=properties, - conditions=conditions, parameters=parameters, tags=tags, - uids=uids, description=description) + GEMDMeasurementTemplate.__init__( + self, + name=name, + properties=properties, + conditions=conditions, + parameters=parameters, + tags=tags, + uids=uids, + description=description, + ) def __str__(self): - return '<Measurement template {!r}>'.format(self.name) + return "<Measurement template {!r}>".format(self.name) class MeasurementTemplateCollection(ObjectTemplateCollection[MeasurementTemplate]): """A collection of measurement templates.""" - _individual_key = 'measurement_template' - _collection_key = 'measurement_templates' + _individual_key = "measurement_template" + _collection_key = "measurement_templates" _resource = MeasurementTemplate @classmethod diff --git
a/src/citrine/resources/object_runs.py b/src/citrine/resources/object_runs.py index 79c90b84d..f728475e5 100644 --- a/src/citrine/resources/object_runs.py +++ b/src/citrine/resources/object_runs.py @@ -1,4 +1,5 @@ """Top-level class for all object run objects and collections thereof.""" + from abc import ABC from typing import TypeVar diff --git a/src/citrine/resources/object_specs.py b/src/citrine/resources/object_specs.py index 893492355..a01ad9b82 100644 --- a/src/citrine/resources/object_specs.py +++ b/src/citrine/resources/object_specs.py @@ -1,4 +1,5 @@ """Top-level class for all object spec objects and collections thereof.""" + from abc import ABC from typing import TypeVar diff --git a/src/citrine/resources/object_templates.py b/src/citrine/resources/object_templates.py index 00836c9a2..fed3fa530 100644 --- a/src/citrine/resources/object_templates.py +++ b/src/citrine/resources/object_templates.py @@ -1,4 +1,5 @@ """Top-level class for all object template objects and collections thereof.""" + from abc import ABC from typing import TypeVar @@ -15,11 +16,13 @@ class ObjectTemplate(Template, GEMDTemplate, ABC): ObjectTemplate must be extended along with `Resource` """ - name = String('name') - description = PropertyOptional(String(), 'description') + name = String("name") + description = PropertyOptional(String(), "description") -ObjectTemplateResourceType = TypeVar("ObjectTemplateResourceType", bound="ObjectTemplate") +ObjectTemplateResourceType = TypeVar( + "ObjectTemplateResourceType", bound="ObjectTemplate" +) class ObjectTemplateCollection(TemplateCollection[ObjectTemplateResourceType], ABC): diff --git a/src/citrine/resources/parameter_template.py b/src/citrine/resources/parameter_template.py index 3227307fe..ccf6e48b2 100644 --- a/src/citrine/resources/parameter_template.py +++ b/src/citrine/resources/parameter_template.py @@ -1,17 +1,23 @@ """Resources that represent parameter templates.""" + from typing import List, Dict, Optional, Type from citrine._rest.resource import GEMDResource -from citrine.resources.attribute_templates import AttributeTemplate, AttributeTemplateCollection +from citrine.resources.attribute_templates import ( + AttributeTemplate, + AttributeTemplateCollection, +) from gemd.entity.bounds.base_bounds import BaseBounds -from gemd.entity.template.parameter_template import ParameterTemplate as GEMDParameterTemplate +from gemd.entity.template.parameter_template import ( + ParameterTemplate as GEMDParameterTemplate, +) class ParameterTemplate( - GEMDResource['ParameterTemplate'], + GEMDResource["ParameterTemplate"], AttributeTemplate, GEMDParameterTemplate, - typ=GEMDParameterTemplate.typ + typ=GEMDParameterTemplate.typ, ): """ A parameter template. 
@@ -37,28 +43,36 @@ class ParameterTemplate( _response_key = GEMDParameterTemplate.typ # 'parameter_template' - def __init__(self, - name: str, - *, - bounds: BaseBounds, - uids: Optional[Dict[str, str]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + bounds: BaseBounds, + uids: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): if uids is None: uids = dict() super(AttributeTemplate, self).__init__() - GEMDParameterTemplate.__init__(self, name=name, bounds=bounds, tags=tags, - uids=uids, description=description) + GEMDParameterTemplate.__init__( + self, + name=name, + bounds=bounds, + tags=tags, + uids=uids, + description=description, + ) def __str__(self): - return '<Parameter template {!r}>'.format(self.name) + return "<Parameter template {!r}>".format(self.name) class ParameterTemplateCollection(AttributeTemplateCollection[ParameterTemplate]): """A collection of parameter templates.""" - _individual_key = 'parameter_template' - _collection_key = 'parameter_templates' + _individual_key = "parameter_template" + _collection_key = "parameter_templates" _resource = ParameterTemplate @classmethod diff --git a/src/citrine/resources/predictor.py b/src/citrine/resources/predictor.py index 08b69aaf0..5c1071d5e 100644 --- a/src/citrine/resources/predictor.py +++ b/src/citrine/resources/predictor.py @@ -1,4 +1,5 @@ """Resources that represent collections of predictors.""" + from functools import partial from typing import Any, Iterable, Optional, Union, List from uuid import UUID @@ -24,17 +25,23 @@ class AsyncDefaultPredictor(Resource["AsyncDefaultPredictor"]): """Return type for async default predictor generation and retrieval.""" - uid = properties.UUID('id', serializable=False) + uid = properties.UUID("id", serializable=False) """:UUID: Citrine Platform unique identifier for this task.""" - predictor = properties.Optional(properties.Object(GraphPredictor), 'data', serializable=False) + predictor = properties.Optional( + properties.Object(GraphPredictor), "data", serializable=False + ) """:Optional[GraphPredictor]:""" - status = properties.String('metadata.status', serializable=False) + status = properties.String("metadata.status", serializable=False) """:str: short description of the resource's status""" - status_detail = properties.List(properties.Object(StatusDetail), 'metadata.status_detail', - default=[], serializable=False) + status_detail = properties.List( + properties.Object(StatusDetail), + "metadata.status_detail", + default=[], + serializable=False, + ) """:List[StatusDetail]: a list of structured status info, containing the message and level""" @classmethod @@ -53,9 +60,9 @@ class AutoConfigureMode(BaseEnumeration): * INFER auto-detects the GEM table and predictor type """ - PLAIN = 'PLAIN' - FORMULATION = 'FORMULATION' - INFER = 'INFER' + PLAIN = "PLAIN" + FORMULATION = "FORMULATION" + INFER = "INFER" class _PredictorVersionPaginator(Paginator): @@ -70,11 +77,11 @@ def paginate(self, *args, **kwargs) -> Iterable[GraphPredictor]: class _PredictorVersionCollection(Collection[GraphPredictor]): - _api_version = 'v3' - _path_template = '/projects/{project_id}/predictors/{uid}/versions' _individual_key = None _resource = GraphPredictor - _collection_key = 'response' + _api_version = "v3" + _path_template = "/projects/{project_id}/predictors/{uid}/versions" + _collection_key = "response" _paginator: Paginator = _PredictorVersionPaginator() _SPECIAL_VERSIONS = [LATEST_VER, MOST_RECENT_VER] @@ -83,17
+90,22 @@ def __init__(self, project_id: UUID, session: Session): self.project_id = project_id self.session: Session = session - def _construct_path(self, - uid: Union[UUID, str], - version: Optional[Union[int, str]] = None, - action: str = None) -> str: + def _construct_path( + self, + uid: Union[UUID, str], + version: Optional[Union[int, str]] = None, + action: str = None, + ) -> str: path = self._path_template.format(project_id=self.project_id, uid=str(uid)) if version is not None: version_str = str(version) - if version_str not in self._SPECIAL_VERSIONS \ - and (not version_str.isdecimal() or int(version_str) <= 0): - raise ValueError("A predictor version must either be a positive integer, " - f"\"{LATEST_VER}\", or \"{MOST_RECENT_VER}\".") + if version_str not in self._SPECIAL_VERSIONS and ( + not version_str.isdecimal() or int(version_str) <= 0 + ): + raise ValueError( + "A predictor version must either be a positive integer, " + f'"{LATEST_VER}", or "{MOST_RECENT_VER}".' + ) path += f"/{version_str}" path += f"/{action}" if action else "" @@ -102,7 +114,7 @@ def _construct_path(self, def _page_fetcher(self, *, uid: Union[UUID, str], **additional_params): fetcher_params = { "path": self._construct_path(uid), - "additional_params": additional_params + "additional_params": additional_params, } return partial(self._fetch_page, **fetcher_params) @@ -113,90 +125,87 @@ def build(self, data: dict) -> GraphPredictor: predictor._project_id = self.project_id return predictor - def get(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + def get( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> GraphPredictor: path = self._construct_path(uid, version) entity = self.session.get_resource(path, version=self._api_version) return self.build(entity) def get_featurized_training_data( - self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER ) -> List[HierarchicalDesignMaterial]: version_path = self._construct_path(uid, version) full_path = f"{version_path}/featurized-training-data" payload = self.session.get_resource(full_path, version=self._api_version) return [HierarchicalDesignMaterial.build(x) for x in payload] - def list(self, - uid: Union[UUID, str], - *, - per_page: int = 100) -> Iterable[GraphPredictor]: + def list( + self, uid: Union[UUID, str], *, per_page: int = 100 + ) -> Iterable[GraphPredictor]: """List non-archived versions of the given predictor.""" page_fetcher = self._page_fetcher(uid=uid) - return self._paginator.paginate(page_fetcher=page_fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) - - def list_archived(self, - uid: Union[UUID, str], - *, - per_page: int = 20) -> Iterable[GraphPredictor]: + return self._paginator.paginate( + page_fetcher=page_fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) + + def list_archived( + self, uid: Union[UUID, str], *, per_page: int = 20 + ) -> Iterable[GraphPredictor]: """List archived versions of the given predictor.""" page_fetcher = self._page_fetcher(uid=uid, filter="archived eq 'true'") - return self._paginator.paginate(page_fetcher=page_fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) - - def archive(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + return self._paginator.paginate( + 
page_fetcher=page_fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) + + def archive( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> GraphPredictor: url = self._construct_path(uid, version, "archive") entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def restore(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + def restore( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> GraphPredictor: url = self._construct_path(uid, version, "restore") entity = self.session.put_resource(url, {}, version=self._api_version) return self.build(entity) - def is_stale(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> bool: + def is_stale( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> bool: path = self._construct_path(uid, version, "is-stale") response = self.session.get_resource(path, version=self._api_version) return response["is_stale"] - def retrain_stale(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + def retrain_stale( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> GraphPredictor: path = self._construct_path(uid, version, "retrain-stale") entity = self.session.put_resource(path, {}, version=self._api_version) return self.build(entity) - def rename(self, - uid: Union[UUID, str], - *, - version: Union[int, str], - name: Optional[str] = None, - description: Optional[str] = None - ) -> GraphPredictor: + def rename( + self, + uid: Union[UUID, str], + *, + version: Union[int, str], + name: Optional[str] = None, + description: Optional[str] = None, + ) -> GraphPredictor: path = self._construct_path(uid, version, "rename") json = {"name": name, "description": description} entity = self.session.put_resource(path, json, version=self._api_version) return self.build(entity) - def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): + def delete( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ): """Predictor versions cannot be deleted at this time.""" msg = "Predictor versions cannot be deleted. Use 'archive_version' instead." raise NotImplementedError(msg) @@ -212,11 +221,11 @@ class PredictorCollection(Collection[GraphPredictor]): """ - _api_version = 'v3' - _path_template = '/projects/{project_id}/predictors' + _api_version = "v3" + _path_template = "/projects/{project_id}/predictors" _individual_key = None _resource = GraphPredictor - _collection_key = 'response' + _collection_key = "response" def __init__(self, project_id: UUID, session: Session): self.project_id = project_id @@ -230,23 +239,21 @@ def build(self, data: dict) -> GraphPredictor: predictor._project_id = self.project_id return predictor - def get(self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + def get( + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER + ) -> GraphPredictor: """Get a predictor by ID and (optionally) version. If version is omitted, the most recent version will be retrieved. """ if uid is None: - raise ValueError("Cannot get when uid=None. Are you using a registered resource?") + raise ValueError( + "Cannot get when uid=None. Are you using a registered resource?" 
+ ) return self._versions_collection.get(uid=uid, version=version) def get_featurized_training_data( - self, - uid: Union[UUID, str], - *, - version: Union[int, str] = MOST_RECENT_VER + self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER ) -> List[HierarchicalDesignMaterial]: """Retrieve a list of featurized materials for a trained predictor. @@ -266,9 +273,13 @@ A list of featurized materials, formatted as design materials """ - return self._versions_collection.get_featurized_training_data(uid=uid, version=version) + return self._versions_collection.get_featurized_training_data( + uid=uid, version=version + ) - def register(self, predictor: GraphPredictor, *, train: bool = True) -> GraphPredictor: + def register( + self, predictor: GraphPredictor, *, train: bool = True + ) -> GraphPredictor: """Register and optionally train a Predictor. This predictor will be version 1, and its `draft` flag will be `True`. If train is True and @@ -282,7 +293,9 @@ else: return self.train(created_predictor.uid) - def update(self, predictor: GraphPredictor, *, train: bool = True) -> GraphPredictor: + def update( + self, predictor: GraphPredictor, *, train: bool = True + ) -> GraphPredictor: """Update and optionally train a Predictor. If the predictor is a draft, this will overwrite its contents. If it's not a draft, a new @@ -307,23 +320,19 @@ def train(self, uid: Union[UUID, str]) -> GraphPredictor: """ path = self._get_path(uid, action="train") params = {"create_version": True} - entity = self.session.put_resource(path, {}, params=params, version=self._api_version) + entity = self.session.put_resource( + path, {}, params=params, version=self._api_version + ) return self.build(entity) def archive_version( - self, - uid: Union[UUID, str], - *, - version: Union[int, str] + self, uid: Union[UUID, str], *, version: Union[int, str] ) -> GraphPredictor: """Archive a predictor version.""" return self._versions_collection.archive(uid, version=version) def restore_version( - self, - uid: Union[UUID, str], - *, - version: Union[int, str] + self, uid: Union[UUID, str], *, version: Union[int, str] ) -> GraphPredictor: """Restore a predictor version.""" return self._versions_collection.restore(uid, version=version) @@ -355,29 +364,35 @@ def root_is_archived(self, uid: Union[UUID, str]) -> bool: Unique identifier of the predictor to check. """ uid = str(uid) - return any(uid == str(archived_pred.uid) for archived_pred in self.list_archived()) + return any( + uid == str(archived_pred.uid) for archived_pred in self.list_archived() + ) def archive(self, uid: Union[UUID, str]): """[UNSUPPORTED] Use archive_root or archive_version instead.""" - raise NotImplementedError("The archive() method is no longer supported. You most likely " - "want archive_root(), or possibly archive_version().") + raise NotImplementedError( + "The archive() method is no longer supported. You most likely " + "want archive_root(), or possibly archive_version()." + ) def restore(self, uid: Union[UUID, str]): """[UNSUPPORTED] Use restore_root or restore_version instead.""" - raise NotImplementedError("The restore() method is no longer supported. You most likely " - "want restore_root(), or possibly restore_version().") + raise NotImplementedError( + "The restore() method is no longer supported. You most likely " + "want restore_root(), or possibly restore_version()."
+ ) def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None): filters = {} if archived is not None: filters["archived"] = archived - fetcher = partial(self._fetch_page, - additional_params=filters, - version="v4") - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + fetcher = partial(self._fetch_page, additional_params=filters, version="v4") + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def list_all(self, *, per_page: int = 20) -> Iterable[GraphPredictor]: """List the most recent version of all predictors.""" @@ -391,17 +406,15 @@ def list_archived(self, *, per_page: int = 20) -> Iterable[GraphPredictor]: """List the most recent version of all archived predictors.""" return self._list_base(per_page=per_page, archived=True) - def list_versions(self, - uid: Union[UUID, str] = None, - *, - per_page: int = 100) -> Iterable[GraphPredictor]: + def list_versions( + self, uid: Union[UUID, str] = None, *, per_page: int = 100 + ) -> Iterable[GraphPredictor]: """List all non-archived versions of the given Predictor.""" return self._versions_collection.list(uid, per_page=per_page) - def list_archived_versions(self, - uid: Union[UUID, str] = None, - *, - per_page: int = 20) -> Iterable[GraphPredictor]: + def list_archived_versions( + self, uid: Union[UUID, str] = None, *, per_page: int = 20 + ) -> Iterable[GraphPredictor]: """List all archived versions of the given Predictor.""" return self._versions_collection.list_archived(uid, per_page=per_page) @@ -435,11 +448,13 @@ def check_for_update(self, uid: Union[UUID, str]) -> Optional[GraphPredictor]: else: return None - def create_default(self, - *, - training_data: DataSource, - pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, - prefer_valid: bool = True) -> GraphPredictor: + def create_default( + self, + *, + training_data: DataSource, + pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, + prefer_valid: bool = True, + ) -> GraphPredictor: """Create a default predictor for some training data. This method will return an unregistered predictor generated by inspecting the @@ -478,16 +493,20 @@ def create_default(self, Automatically configured predictor for the training data """ - payload = PredictorCollection._create_default_payload(training_data, pattern, prefer_valid) + payload = PredictorCollection._create_default_payload( + training_data, pattern, prefer_valid + ) path = self._get_path(action="default") data = self.session.post_resource(path, json=payload, version=self._api_version) return self.build(GraphPredictor.wrap_instance(data["instance"])) - def create_default_async(self, - *, - training_data: DataSource, - pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, - prefer_valid: bool = True) -> AsyncDefaultPredictor: + def create_default_async( + self, + *, + training_data: DataSource, + pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, + prefer_valid: bool = True, + ) -> AsyncDefaultPredictor: """Similar to PredictorCollection.create_default, except asynchronous. This begins a long-running task to generate the predictor. The returned object contains an @@ -519,20 +538,27 @@ def create_default_async(self, Information on the long-running default predictor generation task. 
""" - payload = PredictorCollection._create_default_payload(training_data, pattern, prefer_valid) + payload = PredictorCollection._create_default_payload( + training_data, pattern, prefer_valid + ) path = self._get_path(action="default-async") data = self.session.post_resource(path, json=payload, version=self._api_version) return AsyncDefaultPredictor.build(data) @staticmethod - def _create_default_payload(training_data: DataSource, - pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, - prefer_valid: bool = True) -> dict: + def _create_default_payload( + training_data: DataSource, + pattern: Union[str, AutoConfigureMode] = AutoConfigureMode.INFER, + prefer_valid: bool = True, + ) -> dict: # Continue handling string pattern inputs pattern = AutoConfigureMode.from_str(pattern, exception=True) - return {"data_source": training_data.dump(), "pattern": pattern, - "prefer_valid": prefer_valid} + return { + "data_source": training_data.dump(), + "pattern": pattern, + "prefer_valid": prefer_valid, + } def get_default_async(self, *, task_id: Union[UUID, str]) -> AsyncDefaultPredictor: """Get the current async default predictor generation result. @@ -554,7 +580,9 @@ def is_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> bool: """ return self._versions_collection.is_stale(uid, version=version) - def retrain_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> GraphPredictor: + def retrain_stale( + self, uid: Union[UUID, str], *, version: Union[int, str] + ) -> GraphPredictor: """Begins retraining a stale predictor. This can only be used on a stale predictor, which is when it's in the READY state, but the @@ -563,12 +591,14 @@ def retrain_stale(self, uid: Union[UUID, str], *, version: Union[int, str]) -> G """ return self._versions_collection.retrain_stale(uid, version=version) - def rename(self, - uid: Union[UUID, str], - *, - version: Union[int, str], - name: Optional[str] = None, - description: Optional[str] = None) -> GraphPredictor: + def rename( + self, + uid: Union[UUID, str], + *, + version: Union[int, str], + name: Optional[str] = None, + description: Optional[str] = None, + ) -> GraphPredictor: """Rename an existing predictor. Both the name and description can be changed. This does not trigger retraining. 
diff --git a/src/citrine/resources/predictor_evaluation.py b/src/citrine/resources/predictor_evaluation.py index 20156bd49..c31ff3911 100644 --- a/src/citrine/resources/predictor_evaluation.py +++ b/src/citrine/resources/predictor_evaluation.py @@ -2,8 +2,11 @@ from typing import Iterable, Iterator, List, Optional, Union from uuid import UUID -from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation, \ - PredictorEvaluationRequest, PredictorEvaluatorsResponse +from citrine.informatics.executions.predictor_evaluation import ( + PredictorEvaluation, + PredictorEvaluationRequest, + PredictorEvaluatorsResponse, +) from citrine.informatics.predictor_evaluator import PredictorEvaluator from citrine.informatics.predictors import GraphPredictor from citrine.resources.predictor import LATEST_VER as LATEST_PRED_VER @@ -22,11 +25,11 @@ class PredictorEvaluationCollection(Collection[PredictorEvaluation]): """ - _api_version = 'v1' - _path_template = '/projects/{project_id}/predictor-evaluations' + _api_version = "v1" + _path_template = "/projects/{project_id}/predictor-evaluations" _individual_key = None _resource = PredictorEvaluation - _collection_key = 'response' + _collection_key = "response" def __init__(self, project_id: UUID, session: Session): self.project_id = project_id @@ -39,13 +42,14 @@ def build(self, data: dict) -> PredictorEvaluation: evaluation.project_id = self.project_id return evaluation - def _list_base(self, - *, - per_page: int = 100, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None, - archived: Optional[bool] = None - ) -> Iterator[PredictorEvaluation]: + def _list_base( + self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + archived: Optional[bool] = None, + ) -> Iterator[PredictorEvaluation]: params = {"archived": archived} if predictor_id is not None: params["predictor_id"] = str(predictor_id) @@ -53,44 +57,55 @@ def _list_base(self, params["predictor_version"] = predictor_version fetcher = partial(self._fetch_page, additional_params=params) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) - - def list_all(self, - *, - per_page: int = 100, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None - ) -> Iterable[PredictorEvaluation]: + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) + + def list_all( + self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + ) -> Iterable[PredictorEvaluation]: """List all predictor evaluations.""" - return self._list_base(per_page=per_page, - predictor_id=predictor_id, - predictor_version=predictor_version) - - def list(self, - *, - per_page: int = 100, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None - ) -> Iterable[PredictorEvaluation]: + return self._list_base( + per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version, + ) + + def list( + self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + ) -> Iterable[PredictorEvaluation]: """List non-archived predictor evaluations.""" - return self._list_base(per_page=per_page, - predictor_id=predictor_id, - 
predictor_version=predictor_version, - archived=False) - - def list_archived(self, - *, - per_page: int = 100, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None - ) -> Iterable[PredictorEvaluation]: + return self._list_base( + per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version, + archived=False, + ) + + def list_archived( + self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + ) -> Iterable[PredictorEvaluation]: """List archived predictor evaluations.""" - return self._list_base(per_page=per_page, - predictor_id=predictor_id, - predictor_version=predictor_version, - archived=True) + return self._list_base( + per_page=per_page, + predictor_id=predictor_id, + predictor_version=predictor_version, + archived=True, + ) def archive(self, uid: Union[UUID, str]): """Archive an evaluation.""" @@ -112,14 +127,17 @@ def default_from_config(self, config: GraphPredictor) -> List[PredictorEvaluator """ path = self._get_path(action="default-from-config") payload = config.dump()["instance"] - result = self.session.post_resource(path, json=payload, version=self._api_version) + result = self.session.post_resource( + path, json=payload, version=self._api_version + ) return PredictorEvaluatorsResponse.build(result).evaluators - def default(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Union[int, str] = LATEST_PRED_VER - ) -> List[PredictorEvaluator]: + def default( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Union[int, str] = LATEST_PRED_VER, + ) -> List[PredictorEvaluator]: """Retrieve the default evaluators for a stored predictor. The current default evaluators perform 5-fold, 3-trial cross-validation on all valid @@ -149,14 +167,18 @@ def default(self, """ # noqa: E501,W505 path = self._get_path(action="default") payload = PredictorRef(uid=predictor_id, version=predictor_version).dump() - result = self.session.post_resource(path, json=payload, version=self._api_version) + result = self.session.post_resource( + path, json=payload, version=self._api_version + ) return PredictorEvaluatorsResponse.build(result).evaluators - def trigger(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Union[int, str] = LATEST_PRED_VER, - evaluators: List[PredictorEvaluator]) -> PredictorEvaluation: + def trigger( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Union[int, str] = LATEST_PRED_VER, + evaluators: List[PredictorEvaluator], + ) -> PredictorEvaluation: """Evaluate a predictor using the provided evaluators. Parameters @@ -174,17 +196,20 @@ def trigger(self, """ path = self._get_path("trigger") - payload = PredictorEvaluationRequest(evaluators=evaluators, - predictor_id=predictor_id, - predictor_version=predictor_version).dump() + payload = PredictorEvaluationRequest( + evaluators=evaluators, + predictor_id=predictor_id, + predictor_version=predictor_version, + ).dump() result = self.session.post_resource(path, payload, version=self._api_version) return self.build(result) - def trigger_default(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Union[int, str] = LATEST_PRED_VER - ) -> PredictorEvaluation: + def trigger_default( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Union[int, str] = LATEST_PRED_VER, + ) -> PredictorEvaluation: """Evaluate a predictor using the default evaluators. 
See :func:`~citrine.resources.PredictorCollection.default` for details on the evaluators. @@ -203,7 +228,9 @@ def trigger_default(self, """ # noqa: E501,W505 path = self._get_path("trigger-default") payload = PredictorRef(uid=predictor_id, version=predictor_version).dump() - result = self.session.post_resource(path, json=payload, version=self._api_version) + result = self.session.post_resource( + path, json=payload, version=self._api_version + ) return self.build(result) def register(self, model: PredictorEvaluation) -> PredictorEvaluation: diff --git a/src/citrine/resources/predictor_evaluation_execution.py b/src/citrine/resources/predictor_evaluation_execution.py index 50b392708..8dd4f5855 100644 --- a/src/citrine/resources/predictor_evaluation_execution.py +++ b/src/citrine/resources/predictor_evaluation_execution.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of predictor evaluation executions.""" + from deprecation import deprecated from functools import partial from typing import Optional, Union, Iterator @@ -12,39 +13,52 @@ from citrine.resources.response import Response -class PredictorEvaluationExecutionCollection(Collection["PredictorEvaluationExecution"]): +class PredictorEvaluationExecutionCollection( + Collection["PredictorEvaluationExecution"] +): """A collection of PredictorEvaluationExecutions.""" - _path_template = '/projects/{project_id}/predictor-evaluation-executions' # noqa + _path_template = "/projects/{project_id}/predictor-evaluation-executions" # noqa _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = predictor_evaluation_execution.PredictorEvaluationExecution - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Predictor evaluation workflows are being eliminated in favor of directly" - "evaluating predictors. Please use Project.predictor_evaluations instead.") - def __init__(self, - project_id: UUID, - session: Session, - workflow_id: Optional[UUID] = None): + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Predictor evaluation workflows are being eliminated in favor of directly" + "evaluating predictors. 
Please use Project.predictor_evaluations instead.", + ) + def __init__( + self, project_id: UUID, session: Session, workflow_id: Optional[UUID] = None + ): self.project_id: UUID = project_id self.session: Session = session self.workflow_id: Optional[UUID] = workflow_id - def build(self, data: dict) -> predictor_evaluation_execution.PredictorEvaluationExecution: + def build( + self, data: dict + ) -> predictor_evaluation_execution.PredictorEvaluationExecution: """Build an individual PredictorEvaluationExecution.""" - execution = predictor_evaluation_execution.PredictorEvaluationExecution.build(data) + execution = predictor_evaluation_execution.PredictorEvaluationExecution.build( + data + ) execution._session = self.session execution.project_id = self.project_id return execution - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluationCollection.trigger instead.") - def trigger(self, - predictor_id: UUID, - *, - predictor_version: Optional[Union[int, str]] = None, - random_state: Optional[int] = None): + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluationCollection.trigger instead.", + ) + def trigger( + self, + predictor_id: UUID, + *, + predictor_version: Optional[Union[int, str]] = None, + random_state: Optional[int] = None, + ): """Trigger a predictor evaluation execution against a predictor. Parameters @@ -58,14 +72,16 @@ def trigger(self, """ if self.workflow_id is None: - msg = "Cannot trigger a predictor evaluation execution without knowing the " \ - "predictor evaluation workflow. Use workflow.executions.trigger instead of " \ - "project.predictor_evaluation_executions.trigger" + msg = ( + "Cannot trigger a predictor evaluation execution without knowing the " + "predictor evaluation workflow. 
Use workflow.executions.trigger instead of " + "project.predictor_evaluation_executions.trigger" + ) raise RuntimeError(msg) path = format_escaped_url( - '/projects/{project_id}/predictor-evaluation-workflows/{workflow_id}/executions', + "/projects/{project_id}/predictor-evaluation-workflows/{workflow_id}/executions", project_id=self.project_id, - workflow_id=self.workflow_id + workflow_id=self.workflow_id, ) params = dict() @@ -73,26 +89,29 @@ def trigger(self, params["random_state"] = random_state payload = PredictorRef(predictor_id, predictor_version).dump() - data = self.session.post_resource(path, payload, params=params, version='v2') + data = self.session.post_resource(path, payload, params=params, version="v2") return self.build(data) @deprecated(deprecated_in="3.23.0", removed_in="4.0.0") - def register(self, - model: predictor_evaluation_execution.PredictorEvaluationExecution - ) -> predictor_evaluation_execution.PredictorEvaluationExecution: + def register( + self, model: predictor_evaluation_execution.PredictorEvaluationExecution + ) -> predictor_evaluation_execution.PredictorEvaluationExecution: """Cannot register an execution.""" raise NotImplementedError("Cannot register a PredictorEvaluationExecution.") @deprecated(deprecated_in="3.23.0", removed_in="4.0.0") - def update(self, - model: predictor_evaluation_execution.PredictorEvaluationExecution - ) -> predictor_evaluation_execution.PredictorEvaluationExecution: + def update( + self, model: predictor_evaluation_execution.PredictorEvaluationExecution + ) -> predictor_evaluation_execution.PredictorEvaluationExecution: """Cannot update an execution.""" raise NotImplementedError("Cannot update a PredictorEvaluationExecution.") - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluation.archive") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluation.archive", + ) def archive(self, uid: Union[UUID, str]): """Archive a predictor evaluation execution. @@ -102,10 +121,13 @@ def archive(self, uid: Union[UUID, str]): Unique identifier of the execution to archive """ - self._put_resource_ref('archive', uid) + self._put_resource_ref("archive", uid) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluation.restore") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluation.restore", + ) def restore(self, uid: Union[UUID, str]): """Restore an archived predictor evaluation execution. @@ -115,16 +137,20 @@ def restore(self, uid: Union[UUID, str]): Unique identifier of the execution to restore """ - self._put_resource_ref('restore', uid) - - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluation.list") - def list(self, - *, - per_page: int = 100, - predictor_id: Optional[UUID] = None, - predictor_version: Optional[Union[int, str]] = None - ) -> Iterator[predictor_evaluation_execution.PredictorEvaluationExecution]: + self._put_resource_ref("restore", uid) + + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluation.list", + ) + def list( + self, + *, + per_page: int = 100, + predictor_id: Optional[UUID] = None, + predictor_version: Optional[Union[int, str]] = None, + ) -> Iterator[predictor_evaluation_execution.PredictorEvaluationExecution]: """ Paginate over the elements of the collection. 
@@ -154,19 +180,26 @@ def list(self, params["workflow_id"] = str(self.workflow_id) fetcher = partial(self._fetch_page, additional_params=params) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) @deprecated(deprecated_in="3.23.0", removed_in="4.0.0") def delete(self, uid: Union[UUID, str]) -> Response: """Predictor Evaluation Executions cannot be deleted; they can be archived instead.""" raise NotImplementedError( - "Predictor Evaluation Executions cannot be deleted; they can be archived instead.") + "Predictor Evaluation Executions cannot be deleted; they can be archived instead." + ) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluation.get") - def get(self, - uid: Union[UUID, str]) -> predictor_evaluation_execution.PredictorEvaluationExecution: + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluation.get", + ) + def get( + self, uid: Union[UUID, str] + ) -> predictor_evaluation_execution.PredictorEvaluationExecution: """Get a particular element of the collection.""" return super().get(uid) diff --git a/src/citrine/resources/predictor_evaluation_workflow.py b/src/citrine/resources/predictor_evaluation_workflow.py index e41e20ebc..c1ba5cb8f 100644 --- a/src/citrine/resources/predictor_evaluation_workflow.py +++ b/src/citrine/resources/predictor_evaluation_workflow.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of workflow executions.""" + from deprecation import deprecated from typing import Iterator, Optional, Union from uuid import UUID @@ -12,14 +13,17 @@ class PredictorEvaluationWorkflowCollection(Collection[PredictorEvaluationWorkflow]): """A collection of PredictorEvaluationWorkflows.""" - _path_template = '/projects/{project_id}/predictor-evaluation-workflows' + _path_template = "/projects/{project_id}/predictor-evaluation-workflows" _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = PredictorEvaluationWorkflow - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Predictor evaluation workflows are being eliminated in favor of directly" - "evaluating predictors. Please use Project.predictor_evaluations instead.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Predictor evaluation workflows are being eliminated in favor of directly" + "evaluating predictors. Please use Project.predictor_evaluations instead.", + ) def __init__(self, project_id: UUID, session: Session): self.project_id: UUID = project_id self.session: Session = session @@ -31,8 +35,11 @@ def build(self, data: dict) -> PredictorEvaluationWorkflow: workflow.project_id = self.project_id return workflow - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) def archive(self, uid: Union[UUID, str]): """Archive a predictor evaluation workflow. 
@@ -42,10 +49,13 @@ def archive(self, uid: Union[UUID, str]): Unique identifier of the workflow to archive """ - return self._put_resource_ref('archive', uid) + return self._put_resource_ref("archive", uid) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) def restore(self, uid: Union[UUID, str] = None): """Restore an archived predictor evaluation workflow. @@ -55,22 +65,24 @@ def restore(self, uid: Union[UUID, str] = None): Unique identifier of the workflow to restore """ - return self._put_resource_ref('restore', uid) + return self._put_resource_ref("restore", uid) @deprecated(deprecated_in="3.23.0", removed_in="4.0.0") def delete(self, uid: Union[UUID, str]) -> Response: """Predictor Evaluation Workflows cannot be deleted; they can be archived instead.""" raise NotImplementedError( - "Predictor Evaluation Workflows cannot be deleted; they can be archived instead.") - - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations.trigger_default instead. It doesn't store" - " a workflow, but it triggers an evaluation with the default evaluators.") - def create_default(self, - *, - predictor_id: UUID, - predictor_version: Optional[Union[int, str]] = None) \ - -> PredictorEvaluationWorkflow: + "Predictor Evaluation Workflows cannot be deleted; they can be archived instead." + ) + + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations.trigger_default instead. It doesn't store" + " a workflow, but it triggers an evaluation with the default evaluators.", + ) + def create_default( + self, *, predictor_id: UUID, predictor_version: Optional[Union[int, str]] = None + ) -> PredictorEvaluationWorkflow: """Create a default predictor evaluation workflow for a predictor and execute it. 
The current default predictor evaluation workflow performs 5-fold, 1-trial cross-validation @@ -101,33 +113,47 @@ def create_default(self, Default workflow """ # noqa: E501,W505 - url = self._get_path('default') - payload = {'predictor_id': str(predictor_id)} + url = self._get_path("default") + payload = {"predictor_id": str(predictor_id)} if predictor_version: - payload['predictor_version'] = predictor_version + payload["predictor_version"] = predictor_version data = self.session.post_resource(url, payload) return self.build(data) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") - def register(self, model: PredictorEvaluationWorkflow) -> PredictorEvaluationWorkflow: + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) + def register( + self, model: PredictorEvaluationWorkflow + ) -> PredictorEvaluationWorkflow: """Create a new element of the collection by registering an existing resource.""" return super().register(model) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) def list(self, *, per_page: int = 100) -> Iterator[PredictorEvaluationWorkflow]: """Paginate over the elements of the collection.""" return super().list(per_page=per_page) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) def update(self, model: PredictorEvaluationWorkflow) -> PredictorEvaluationWorkflow: """Update a particular element of the collection.""" return super().update(model) - @deprecated(deprecated_in="3.23.0", removed_in="4.0.0", - details="Please use PredictorEvaluations instead, which doesn't store workflows.") + @deprecated( + deprecated_in="3.23.0", + removed_in="4.0.0", + details="Please use PredictorEvaluations instead, which doesn't store workflows.", + ) def get(self, uid: Union[UUID, str]) -> PredictorEvaluationWorkflow: """Get a particular element of the collection.""" return super().get(uid) diff --git a/src/citrine/resources/process_run.py b/src/citrine/resources/process_run.py index 7a960ff56..8f457fa80 100644 --- a/src/citrine/resources/process_run.py +++ b/src/citrine/resources/process_run.py @@ -1,4 +1,5 @@ """Resources that represent process run data objects.""" + from typing import List, Dict, Optional, Type, Union, Iterator from uuid import UUID @@ -16,7 +17,9 @@ from gemd.entity.source.performed_source import PerformedSource -class ProcessRun(GEMDResource['ProcessRun'], ObjectRun, GEMDProcessRun, typ=GEMDProcessRun.typ): +class ProcessRun( + GEMDResource["ProcessRun"], ObjectRun, GEMDProcessRun, typ=GEMDProcessRun.typ +): """ A process run. 
@@ -51,39 +54,59 @@ class ProcessRun(GEMDResource['ProcessRun'], ObjectRun, GEMDProcessRun, typ=GEMD _response_key = GEMDProcessRun.typ # 'process_run' - name = String('name', override=True, use_init=True) - conditions = PropertyOptional(PropertyList(Object(Condition)), 'conditions', override=True) - parameters = PropertyOptional(PropertyList(Object(Parameter)), 'parameters', override=True) - spec = PropertyOptional(LinkOrElse(GEMDProcessSpec), 'spec', override=True, use_init=True,) + name = String("name", override=True, use_init=True) + conditions = PropertyOptional( + PropertyList(Object(Condition)), "conditions", override=True + ) + parameters = PropertyOptional( + PropertyList(Object(Parameter)), "parameters", override=True + ) + spec = PropertyOptional( + LinkOrElse(GEMDProcessSpec), + "spec", + override=True, + use_init=True, + ) source = PropertyOptional(Object(PerformedSource), "source", override=True) - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - tags: Optional[List[str]] = None, - notes: Optional[str] = None, - conditions: Optional[List[Condition]] = None, - parameters: Optional[List[Parameter]] = None, - spec: Optional[GEMDProcessSpec] = None, - file_links: Optional[List[FileLink]] = None, - source: Optional[PerformedSource] = None): + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + tags: Optional[List[str]] = None, + notes: Optional[str] = None, + conditions: Optional[List[Condition]] = None, + parameters: Optional[List[Parameter]] = None, + spec: Optional[GEMDProcessSpec] = None, + file_links: Optional[List[FileLink]] = None, + source: Optional[PerformedSource] = None, + ): if uids is None: uids = dict() super(ObjectRun, self).__init__() - GEMDProcessRun.__init__(self, name=name, uids=uids, - tags=tags, conditions=conditions, parameters=parameters, - spec=spec, file_links=file_links, notes=notes, source=source) + GEMDProcessRun.__init__( + self, + name=name, + uids=uids, + tags=tags, + conditions=conditions, + parameters=parameters, + spec=spec, + file_links=file_links, + notes=notes, + source=source, + ) def __str__(self): - return '<Process run {!r}>'.format(self.name) + return "<Process run {!r}>".format(self.name) class ProcessRunCollection(ObjectRunCollection[ProcessRun]): """Represents the collection of all process runs associated with a dataset.""" - _individual_key = 'process_run' - _collection_key = 'process_runs' + _individual_key = "process_run" + _collection_key = "process_runs" _resource = ProcessRun @classmethod @@ -91,9 +114,9 @@ def get_type(cls) -> Type[ProcessRun]: """Return the resource type in the collection.""" return ProcessRun - def list_by_spec(self, - uid: Union[UUID, str, LinkByUID, GEMDProcessSpec] - ) -> Iterator[ProcessRun]: + def list_by_spec( + self, uid: Union[UUID, str, LinkByUID, GEMDProcessSpec] + ) -> Iterator[ProcessRun]: """ Get the process runs using the specified process spec. @@ -108,4 +131,4 @@ def list_by_spec, The process runs using the specified process spec.
""" - return self._get_relation('process-specs', uid=uid) + return self._get_relation("process-specs", uid=uid) diff --git a/src/citrine/resources/process_spec.py b/src/citrine/resources/process_spec.py index 46f4de458..fb86fbae4 100644 --- a/src/citrine/resources/process_spec.py +++ b/src/citrine/resources/process_spec.py @@ -1,4 +1,5 @@ """Resources that represent process spec objects.""" + from typing import Optional, Dict, List, Type, Union, Iterator from uuid import UUID @@ -16,10 +17,7 @@ class ProcessSpec( - GEMDResource['ProcessSpec'], - ObjectSpec, - GEMDProcessSpec, - typ=GEMDProcessSpec.typ + GEMDResource["ProcessSpec"], ObjectSpec, GEMDProcessSpec, typ=GEMDProcessSpec.typ ): """ A process specification. @@ -53,41 +51,56 @@ class ProcessSpec( _response_key = GEMDProcessSpec.typ # 'process_spec' - name = String('name', override=True, use_init=True) - conditions = PropertyOptional(PropertyList(Object(Condition)), 'conditions', override=True) - parameters = PropertyOptional(PropertyList(Object(Parameter)), 'parameters', override=True) - template = PropertyOptional(LinkOrElse(GEMDProcessTemplate), - 'template', override=True, - use_init=True, - ) - - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - tags: Optional[List[str]] = None, - notes: Optional[str] = None, - conditions: Optional[List[Condition]] = None, - parameters: Optional[List[Parameter]] = None, - template: Optional[GEMDProcessTemplate] = None, - file_links: Optional[List[FileLink]] = None - ): + name = String("name", override=True, use_init=True) + conditions = PropertyOptional( + PropertyList(Object(Condition)), "conditions", override=True + ) + parameters = PropertyOptional( + PropertyList(Object(Parameter)), "parameters", override=True + ) + template = PropertyOptional( + LinkOrElse(GEMDProcessTemplate), + "template", + override=True, + use_init=True, + ) + + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + tags: Optional[List[str]] = None, + notes: Optional[str] = None, + conditions: Optional[List[Condition]] = None, + parameters: Optional[List[Parameter]] = None, + template: Optional[GEMDProcessTemplate] = None, + file_links: Optional[List[FileLink]] = None, + ): if uids is None: uids = dict() super(ObjectSpec, self).__init__() - GEMDProcessSpec.__init__(self, name=name, uids=uids, - tags=tags, conditions=conditions, parameters=parameters, - template=template, file_links=file_links, notes=notes) + GEMDProcessSpec.__init__( + self, + name=name, + uids=uids, + tags=tags, + conditions=conditions, + parameters=parameters, + template=template, + file_links=file_links, + notes=notes, + ) def __str__(self): - return ''.format(self.name) + return "".format(self.name) class ProcessSpecCollection(ObjectSpecCollection[ProcessSpec]): """Represents the collection of all process specs associated with a dataset.""" - _individual_key = 'process_spec' - _collection_key = 'process_specs' + _individual_key = "process_spec" + _collection_key = "process_specs" _resource = ProcessSpec @classmethod @@ -95,9 +108,9 @@ def get_type(cls) -> Type[ProcessSpec]: """Return the resource type in the collection.""" return ProcessSpec - def list_by_template(self, - uid: Union[UUID, str, LinkByUID, GEMDProcessTemplate] - ) -> Iterator[ProcessSpec]: + def list_by_template( + self, uid: Union[UUID, str, LinkByUID, GEMDProcessTemplate] + ) -> Iterator[ProcessSpec]: """ Get the process specs using the specified process template. 
@@ -112,4 +125,4 @@ def list_by_template(self, The process specs using the specified process template """ - return self._get_relation('process-templates', uid=uid) + return self._get_relation("process-templates", uid=uid) diff --git a/src/citrine/resources/process_template.py b/src/citrine/resources/process_template.py index a53b692f3..8f2ccf97c 100644 --- a/src/citrine/resources/process_template.py +++ b/src/citrine/resources/process_template.py @@ -1,27 +1,36 @@ """Resources that represent process templates.""" + from typing import List, Dict, Optional, Union, Sequence, Type from citrine._rest.resource import GEMDResource from citrine._serialization.properties import List as PropertyList from citrine._serialization.properties import Optional as PropertyOptional from citrine._serialization.properties import Union as PropertyUnion -from citrine._serialization.properties import String, Object, SpecifiedMixedList, \ - LinkOrElse +from citrine._serialization.properties import ( + String, + Object, + SpecifiedMixedList, + LinkOrElse, +) from citrine.resources.condition_template import ConditionTemplate from citrine.resources.object_templates import ObjectTemplate, ObjectTemplateCollection from citrine.resources.parameter_template import ParameterTemplate from gemd.entity.bounds.base_bounds import BaseBounds from gemd.entity.link_by_uid import LinkByUID from gemd.entity.template.process_template import ProcessTemplate as GEMDProcessTemplate -from gemd.entity.template.condition_template import ConditionTemplate as GEMDConditionTemplate -from gemd.entity.template.parameter_template import ParameterTemplate as GEMDParameterTemplate +from gemd.entity.template.condition_template import ( + ConditionTemplate as GEMDConditionTemplate, +) +from gemd.entity.template.parameter_template import ( + ParameterTemplate as GEMDParameterTemplate, +) class ProcessTemplate( - GEMDResource['ProcessTemplate'], + GEMDResource["ProcessTemplate"], ObjectTemplate, GEMDProcessTemplate, - typ=GEMDProcessTemplate.typ + typ=GEMDProcessTemplate.typ, ): """ A process template. 
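For context on the `SpecifiedMixedList` shape handled in the hunk below: each entry of `conditions` or `parameters` may be a template, a link, or a [template, bounds] pair that narrows the template's bounds for this process. A minimal construction sketch (names and values hypothetical):

from gemd.entity.bounds.real_bounds import RealBounds
from citrine.resources.condition_template import ConditionTemplate
from citrine.resources.process_template import ProcessTemplate

temperature = ConditionTemplate("temperature", bounds=RealBounds(0, 1000, "kelvin"))
firing = ProcessTemplate(
    "firing",
    # A [template, bounds] pair narrows the allowed bounds for this process.
    conditions=[[temperature, RealBounds(300, 900, "kelvin")]],
    allowed_names=["clay", "glaze"],
)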
@@ -61,62 +70,97 @@ class ProcessTemplate( conditions = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDConditionTemplate), - SpecifiedMixedList([LinkOrElse(GEMDConditionTemplate), - PropertyOptional(Object(BaseBounds))])] - ) + PropertyUnion( + [ + LinkOrElse(GEMDConditionTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDConditionTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) ), - 'conditions', - override=True + "conditions", + override=True, ) parameters = PropertyOptional( PropertyList( - PropertyUnion([LinkOrElse(GEMDParameterTemplate), - SpecifiedMixedList([LinkOrElse(GEMDParameterTemplate), - PropertyOptional(Object(BaseBounds))])] - ) + PropertyUnion( + [ + LinkOrElse(GEMDParameterTemplate), + SpecifiedMixedList( + [ + LinkOrElse(GEMDParameterTemplate), + PropertyOptional(Object(BaseBounds)), + ] + ), + ] + ) ), - 'parameters', - override=True + "parameters", + override=True, + ) + allowed_labels = PropertyOptional( + PropertyList(String()), "allowed_labels", override=True + ) + allowed_names = PropertyOptional( + PropertyList(String()), "allowed_names", override=True ) - allowed_labels = PropertyOptional(PropertyList(String()), 'allowed_labels', override=True) - allowed_names = PropertyOptional(PropertyList(String()), 'allowed_names', override=True) - def __init__(self, - name: str, - *, - uids: Optional[Dict[str, str]] = None, - conditions: Optional[Sequence[Union[ConditionTemplate, - LinkByUID, - Sequence[Union[ConditionTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - parameters: Optional[Sequence[Union[ParameterTemplate, - LinkByUID, - Sequence[Union[ParameterTemplate, LinkByUID, - Optional[BaseBounds]]] - ]]] = None, - allowed_labels: Optional[List[str]] = None, - allowed_names: Optional[List[str]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + uids: Optional[Dict[str, str]] = None, + conditions: Optional[ + Sequence[ + Union[ + ConditionTemplate, + LinkByUID, + Sequence[Union[ConditionTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + parameters: Optional[ + Sequence[ + Union[ + ParameterTemplate, + LinkByUID, + Sequence[Union[ParameterTemplate, LinkByUID, Optional[BaseBounds]]], + ] + ] + ] = None, + allowed_labels: Optional[List[str]] = None, + allowed_names: Optional[List[str]] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): if uids is None: uids = dict() super(ObjectTemplate, self).__init__() - GEMDProcessTemplate.__init__(self, name=name, uids=uids, - conditions=conditions, parameters=parameters, tags=tags, - description=description, allowed_labels=allowed_labels, - allowed_names=allowed_names) + GEMDProcessTemplate.__init__( + self, + name=name, + uids=uids, + conditions=conditions, + parameters=parameters, + tags=tags, + description=description, + allowed_labels=allowed_labels, + allowed_names=allowed_names, + ) def __str__(self): - return '<Process template {!r}>'.format(self.name) + return "<Process template {!r}>".format(self.name) class ProcessTemplateCollection(ObjectTemplateCollection[ProcessTemplate]): """A collection of process templates.""" - _individual_key = 'process_template' - _collection_key = 'process_templates' + _individual_key = "process_template" + _collection_key = "process_templates" _resource = ProcessTemplate @classmethod diff --git a/src/citrine/resources/project.py index 4e9777f1f..5c66a98ea 100644 --- a/src/citrine/resources/project.py +++
b/src/citrine/resources/project.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of projects.""" + from deprecation import deprecated from functools import partial from typing import Optional, Dict, List, Union, Iterable, Tuple, Iterator @@ -37,19 +38,22 @@ from citrine.resources.process_spec import ProcessSpecCollection from citrine.resources.process_template import ProcessTemplateCollection from citrine.resources.predictor import PredictorCollection -from citrine.resources.predictor_evaluation_execution import \ - PredictorEvaluationExecutionCollection -from citrine.resources.predictor_evaluation_workflow import \ - PredictorEvaluationWorkflowCollection +from citrine.resources.predictor_evaluation_execution import ( + PredictorEvaluationExecutionCollection, +) +from citrine.resources.predictor_evaluation_workflow import ( + PredictorEvaluationWorkflowCollection, +) from citrine.resources.predictor_evaluation import PredictorEvaluationCollection -from citrine.resources.generative_design_execution import \ - GenerativeDesignExecutionCollection +from citrine.resources.generative_design_execution import ( + GenerativeDesignExecutionCollection, +) from citrine.resources.project_member import ProjectMember from citrine.resources.response import Response from citrine.resources.table_config import TableConfigCollection -class Project(Resource['Project']): +class Project(Resource["Project"]): """ A Citrine Project. @@ -67,27 +71,29 @@ class Project(Resource['Project']): """ - _response_key = 'project' + _response_key = "project" _resource_type = ResourceTypeEnum.PROJECT - name = properties.String('name') - description = properties.Optional(properties.String(), 'description') - uid = properties.Optional(properties.UUID(), 'id') + name = properties.String("name") + description = properties.Optional(properties.String(), "description") + uid = properties.Optional(properties.UUID(), "id") """UUID: Unique uuid4 identifier of this project.""" - status = properties.Optional(properties.String(), 'status') + status = properties.Optional(properties.String(), "status") """str: Status of the project.""" - created_at = properties.Optional(properties.Datetime(), 'created_at') + created_at = properties.Optional(properties.Datetime(), "created_at") """int: Time the project was created, in seconds since epoch.""" - archived = properties.Optional(properties.Boolean, 'archived') + archived = properties.Optional(properties.Boolean, "archived") """bool: Whether the project is archived.""" _team_id = properties.Optional(properties.UUID, "team.id", serializable=False) - def __init__(self, - name: str, - *, - description: Optional[str] = None, - session: Optional[Session] = None, - team_id: Optional[UUID] = None): + def __init__( + self, + name: str, + *, + description: Optional[str] = None, + session: Optional[Session] = None, + team_id: Optional[UUID] = None, + ): self.name: str = name self.description: Optional[str] = description self.session: Session = session @@ -97,18 +103,17 @@ def _post_dump(self, data: dict) -> dict: return {key: value for key, value in data.items() if value is not None} def __str__(self): - return '<Project {!r}>'.format(self.name) + return "<Project {!r}>".format(self.name) def _path(self): - return format_escaped_url('/projects/{project_id}', project_id=self.uid) + return format_escaped_url("/projects/{project_id}", project_id=self.uid) @property def team_id(self): """Returns the Team's id-scoped UUID.""" if self._team_id is None: self._team_id = self.get_team_id_from_project_id( -
session=self.session, - project_id=self.uid + session=self.session, project_id=self.uid ) return self._team_id @@ -119,8 +124,8 @@ def team_id(self, value: Optional[UUID]): @classmethod def get_team_id_from_project_id(cls, session: Session, project_id: UUID): """Returns the UUID of the Team that owns the project with the provided project_id.""" - response = session.get_resource(path=f'projects/{project_id}', version="v3") - return response['project']['team']['id'] + response = session.get_resource(path=f"projects/{project_id}", version="v3") + return response["project"]["team"]["id"] @property def branches(self) -> BranchCollection: @@ -145,12 +150,16 @@ def descriptors(self) -> DescriptorMethods: @property def predictor_evaluation_workflows(self) -> PredictorEvaluationWorkflowCollection: """Return a collection representing all visible predictor evaluation workflows.""" - return PredictorEvaluationWorkflowCollection(project_id=self.uid, session=self.session) + return PredictorEvaluationWorkflowCollection( + project_id=self.uid, session=self.session + ) @property def predictor_evaluation_executions(self) -> PredictorEvaluationExecutionCollection: """Return a collection representing all visible predictor evaluation executions.""" - return PredictorEvaluationExecutionCollection(project_id=self.uid, session=self.session) + return PredictorEvaluationExecutionCollection( + project_id=self.uid, session=self.session + ) @property def predictor_evaluations(self) -> PredictorEvaluationCollection: @@ -165,176 +174,260 @@ def design_workflows(self) -> DesignWorkflowCollection: @property def generative_design_executions(self) -> GenerativeDesignExecutionCollection: """Return a collection representing all visible generative design executions.""" - return GenerativeDesignExecutionCollection(project_id=self.uid, session=self.session) + return GenerativeDesignExecutionCollection( + project_id=self.uid, session=self.session + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.datasets' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.datasets' instead.'", + ) def datasets(self) -> DatasetCollection: """Return a resource representing all visible datasets.""" - return DatasetCollection(team_id=self.team_id, project_id=self.uid, session=self.session) + return DatasetCollection( + team_id=self.team_id, project_id=self.uid, session=self.session + ) @property def tables(self) -> GemTableCollection: """Return a resource representing all visible Tables.""" - return GemTableCollection(team_id=self.team_id, project_id=self.uid, session=self.session) + return GemTableCollection( + team_id=self.team_id, project_id=self.uid, session=self.session + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.property_templates' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.property_templates' instead.'", + ) def property_templates(self) -> PropertyTemplateCollection: """Return a resource representing all property templates in this dataset.""" - return PropertyTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return PropertyTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.condition_templates' instead.'") + @deprecated( 
+ deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.condition_templates' instead.'", + ) def condition_templates(self) -> ConditionTemplateCollection: """Return a resource representing all condition templates in this dataset.""" - return ConditionTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return ConditionTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.parameter_templates' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.parameter_templates' instead.'", + ) def parameter_templates(self) -> ParameterTemplateCollection: """Return a resource representing all parameter templates in this dataset.""" - return ParameterTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return ParameterTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.material_templates' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.material_templates' instead.'", + ) def material_templates(self) -> MaterialTemplateCollection: """Return a resource representing all material templates in this dataset.""" - return MaterialTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MaterialTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.measurement_templates' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.measurement_templates' instead.'", + ) def measurement_templates(self) -> MeasurementTemplateCollection: """Return a resource representing all measurement templates in this dataset.""" - return MeasurementTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MeasurementTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.process_templates' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.process_templates' instead.'", + ) def process_templates(self) -> ProcessTemplateCollection: """Return a resource representing all process templates in this dataset.""" - return ProcessTemplateCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return ProcessTemplateCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.process_runs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.process_runs' instead.'", + ) def process_runs(self) -> ProcessRunCollection: """Return a resource representing all process runs in this dataset.""" - return ProcessRunCollection(project_id=self.uid, - dataset_id=None, - 
session=self.session, - team_id=self.team_id) + return ProcessRunCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.measurement_runs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.measurement_runs' instead.'", + ) def measurement_runs(self) -> MeasurementRunCollection: """Return a resource representing all measurement runs in this dataset.""" - return MeasurementRunCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MeasurementRunCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.material_runs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.material_runs' instead.'", + ) def material_runs(self) -> MaterialRunCollection: """Return a resource representing all material runs in this dataset.""" - return MaterialRunCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MaterialRunCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.ingredient_runs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.ingredient_runs' instead.'", + ) def ingredient_runs(self) -> IngredientRunCollection: """Return a resource representing all ingredient runs in this dataset.""" - return IngredientRunCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return IngredientRunCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.process_specs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.process_specs' instead.'", + ) def process_specs(self) -> ProcessSpecCollection: """Return a resource representing all process specs in this dataset.""" - return ProcessSpecCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return ProcessSpecCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.measurement_specs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.measurement_specs' instead.'", + ) def measurement_specs(self) -> MeasurementSpecCollection: """Return a resource representing all measurement specs in this dataset.""" - return MeasurementSpecCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MeasurementSpecCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.material_specs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.material_specs' instead.'", + ) def material_specs(self) 
-> MaterialSpecCollection: """Return a resource representing all material specs in this dataset.""" - return MaterialSpecCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return MaterialSpecCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.ingredient_specs' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.ingredient_specs' instead.'", + ) def ingredient_specs(self) -> IngredientSpecCollection: """Return a resource representing all ingredient specs in this dataset.""" - return IngredientSpecCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return IngredientSpecCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.gemd' instead.'") + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.gemd' instead.'", + ) def gemd(self) -> GEMDResourceCollection: """Return a resource representing all GEMD objects/templates in this dataset.""" - return GEMDResourceCollection(project_id=self.uid, - dataset_id=None, - session=self.session, - team_id=self.team_id) + return GEMDResourceCollection( + project_id=self.uid, + dataset_id=None, + session=self.session, + team_id=self.team_id, + ) @property def table_configs(self) -> TableConfigCollection: """Return a resource representing all Table Configs in the project.""" - return TableConfigCollection(team_id=self.team_id, - project_id=self.uid, - session=self.session) + return TableConfigCollection( + team_id=self.team_id, project_id=self.uid, session=self.session + ) def publish(self, *, resource: Resource): """ @@ -358,15 +451,18 @@ def publish(self, *, resource: Resource): resource_access = resource.access_control_dict() resource_type = resource_access["type"] if resource_type == ResourceTypeEnum.DATASET: - warn("Newly created datasets belong to a team, making this is unncessary. If it was " - "created before 3.4.0, publish will work as before. Calling publish on datasets " - "will be disabled in 4.0.0, at which time all datasets will be automatically " - "published.", - DeprecationWarning) + warn( + "Newly created datasets belong to a team, making this unnecessary. If it was " + "created before 3.4.0, publish will work as before. Calling publish on datasets " + "will be disabled in 4.0.0, at which time all datasets will be automatically " + "published.", + DeprecationWarning, + ) self.session.checked_post( f"{self._path()}/published-resources/{resource_type}/batch-publish", - version='v3', - json={'ids': [resource_access["id"]]}) + version="v3", + json={"ids": [resource_access["id"]]}, + ) return True def un_publish(self, *, resource: Resource): """ @@ -387,15 +483,18 @@ def un_publish(self, *, resource: Resource): resource_access = resource.access_control_dict() resource_type = resource_access["type"] if resource_type == ResourceTypeEnum.DATASET: - warn("Newly created datasets belong to a team, making un_publish a no-op. If it was " - "created before 3.4.0, un_publish will work as before.
Calling un_publish on " - "datasets will be disabled in 4.0.0, at which time all datasets will be " - "automatically published.", - DeprecationWarning) + warn( + "Newly created datasets belong to a team, making un_publish a no-op. If it was " + "created before 3.4.0, un_publish will work as before. Calling un_publish on " + "datasets will be disabled in 4.0.0, at which time all datasets will be " + "automatically published.", + DeprecationWarning, + ) self.session.checked_post( f"{self._path()}/published-resources/{resource_type}/batch-un-publish", - version='v3', - json={'ids': [resource_access["id"]]}) + version="v3", + json={"ids": [resource_access["id"]]}, + ) return True def pull_in_resource(self, *, resource: Resource): @@ -416,16 +515,19 @@ def pull_in_resource(self, *, resource: Resource): resource_access = resource.access_control_dict() resource_type = resource_access["type"] if resource_type == ResourceTypeEnum.DATASET: - warn("Newly created datasets belong to a team, making pull_in_resource a no-op. If it " - "was created before 3.4.0, pull_in_resource will work as before. Calling " - "pull_in_resource on datasets will be disabled in 4.0.0, at which time all " - "datasets will be automatically published.", - DeprecationWarning) - base_url = f'/teams/{self.team_id}{self._path()}' + warn( + "Newly created datasets belong to a team, making pull_in_resource a no-op. If it " + "was created before 3.4.0, pull_in_resource will work as before. Calling " + "pull_in_resource on datasets will be disabled in 4.0.0, at which time all " + "datasets will be automatically published.", + DeprecationWarning, + ) + base_url = f"/teams/{self.team_id}{self._path()}" self.session.checked_post( - f'{base_url}/outside-resources/{resource_type}/batch-pull-in', - version='v3', - json={'ids': [resource_access["id"]]}) + f"{base_url}/outside-resources/{resource_type}/batch-pull-in", + version="v3", + json={"ids": [resource_access["id"]]}, + ) return True def owned_dataset_ids(self) -> List[str]: @@ -441,13 +543,13 @@ def owned_dataset_ids(self) -> List[str]: warn( "Datasets are no be longer owned by Projects. To find the Datasets owned by your " "Team, use Team.owned_dataset_ids().", - DeprecationWarning + DeprecationWarning, ) query_params = {"userId": "", "domain": self._path(), "action": "WRITE"} - response = self.session.get_resource("/DATASET/authorized-ids", - params=query_params, - version="v3") - return response['ids'] + response = self.session.get_resource( + "/DATASET/authorized-ids", params=query_params, version="v3" + ) + return response["ids"] def list_members(self) -> Union[List[ProjectMember], List["TeamMember"]]: # noqa: F821 """ @@ -467,13 +569,18 @@ def list_members(self) -> Union[List[ProjectMember], List["TeamMember"]]: # noq parent_team = team_collection.get(self.team_id) return parent_team.list_members() - @deprecated(deprecated_in="3.4.0", removed_in="4.0.0", - details="Please use 'Team.gemd_batch_delete' instead.'") - def gemd_batch_delete(self, - id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], - *, - timeout: float = 2 * 60, - polling_delay: float = 1.0) -> List[Tuple[LinkByUID, ApiError]]: + @deprecated( + deprecated_in="3.4.0", + removed_in="4.0.0", + details="Please use 'Team.gemd_batch_delete' instead.'", + ) + def gemd_batch_delete( + self, + id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], + *, + timeout: float = 2 * 60, + polling_delay: float = 1.0, + ) -> List[Tuple[LinkByUID, ApiError]]: """ Remove a set of GEMD objects. 
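A hedged sketch of the (deprecated) batch-delete wrapper whose signature is reformatted above; it forwards to `_async_gemd_batch_delete` and returns `(link, error)` pairs for objects that could not be deleted. The `project` handle and object UUID are hypothetical:

from uuid import UUID

failures = project.gemd_batch_delete(
    [UUID("dddddddd-0000-0000-0000-000000000000")],  # hypothetical object UUID
    timeout=120,
)
for link, error in failures:
    print(f"could not delete {link.id}: {error}")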
@@ -507,12 +614,14 @@ def gemd_batch_delete(self, deleted. """ - return _async_gemd_batch_delete(id_list=id_list, - team_id=self.team_id, - session=self.session, - dataset_id=None, - timeout=timeout, - polling_delay=polling_delay) + return _async_gemd_batch_delete( + id_list=id_list, + team_id=self.team_id, + session=self.session, + dataset_id=None, + timeout=timeout, + polling_delay=polling_delay, + ) class ProjectCollection(Collection[Project]): @@ -529,13 +638,14 @@ class ProjectCollection(Collection[Project]): @property def _path_template(self): if self.team_id is None: - return '/projects' + return "/projects" else: - return '/teams/{team_id}/projects' - _individual_key = 'project' - _collection_key = 'projects' + return "/teams/{team_id}/projects" + + _individual_key = "project" + _collection_key = "projects" _resource = Project - _api_version = 'v3' + _api_version = "v3" def __init__(self, session: Session, *, team_id: Optional[UUID] = None): self.session = session @@ -596,8 +706,10 @@ def register(self, name: str, *, description: Optional[str] = None) -> Project: """ if self.team_id is None: - raise NotImplementedError("Cannot register a project without a team ID. " - "Use team.projects.register.") + raise NotImplementedError( + "Cannot register a project without a team ID. " + "Use team.projects.register." + ) project = Project(name, description=description) return super().register(project) @@ -607,10 +719,14 @@ def _list_base(self, *, per_page: int = 1000, archived: Optional[bool] = None): if archived is not None: filters["archived"] = str(archived).lower() - fetcher = partial(self._fetch_page, additional_params=filters, version=self._api_version) - return self._paginator.paginate(page_fetcher=fetcher, - collection_builder=self._build_collection_elements, - per_page=per_page) + fetcher = partial( + self._fetch_page, additional_params=filters, version=self._api_version + ) + return self._paginator.paginate( + page_fetcher=fetcher, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def list(self, *, per_page: int = 1000) -> Iterator[Project]: """ @@ -711,24 +827,25 @@ def search_all(self, search_params: Optional[Dict]) -> Iterable[Dict]: """ collections = [] - query_params = {'userId': ""} + query_params = {"userId": ""} - json = {} if search_params is None else {'search_params': search_params} + json = {} if search_params is None else {"search_params": search_params} - data = self.session.post_resource(self._get_path(action="search"), - params=query_params, - json=json, - version=self._api_version) + data = self.session.post_resource( + self._get_path(action="search"), + params=query_params, + json=json, + version=self._api_version, + ) if self._collection_key is not None: collections = data[self._collection_key] return collections - def search(self, - *, - search_params: Optional[dict] = None, - per_page: int = 1000) -> Iterable[Project]: + def search( + self, *, search_params: Optional[dict] = None, per_page: int = 1000 + ) -> Iterable[Project]: """ Search for projects matching the desired name or description. 
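A usage sketch of the search call reformatted above; the shape of `search_params` is assumed from the surrounding code, which wraps it as `{"search_params": ...}` before POSTing. The `team` handle and filter values are hypothetical:

matches = team.projects.search(
    search_params={"name": {"value": "Polymer", "search_method": "SUBSTRING"}}
)
for project in matches:
    print(project.name)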
@@ -790,7 +907,9 @@ def archive(self, uid: Union[UUID, str]) -> Response: # Only the team-agnostic project archive is implemented if self.team_id is None: path = self._get_path(uid, action="archive") - return self.session.post_resource(path, version=self._api_version, json=None) + return self.session.post_resource( + path, version=self._api_version, json=None + ) else: return ProjectCollection(session=self.session).archive(uid) @@ -799,7 +918,9 @@ def restore(self, uid: Union[UUID, str]) -> Response: # Only the team-agnostic project restore is implemented if self.team_id is None: path = self._get_path(uid, action="restore") - return self.session.post_resource(path, version=self._api_version, json=None) + return self.session.post_resource( + path, version=self._api_version, json=None + ) else: return ProjectCollection(session=self.session).restore(uid) diff --git a/src/citrine/resources/project_member.py b/src/citrine/resources/project_member.py index 06b6e0066..166cacc21 100644 --- a/src/citrine/resources/project_member.py +++ b/src/citrine/resources/project_member.py @@ -5,17 +5,20 @@ class ProjectMember: """A Member of a Project.""" - def __init__(self, - *, - user: User, - project: 'Project', # noqa: F821 - role: ROLES): + def __init__( + self, + *, + user: User, + project: "Project", # noqa: F821 + role: ROLES, + ): self.user: User = user # To avoid circular dependency, use forward-reference for type definition # https://www.python.org/dev/peps/pep-0484/#forward-references - self.project: 'Project' = project # noqa: F821 + self.project: "Project" = project # noqa: F821 self.role: ROLES = role def __str__(self): - return '<ProjectMember {!r} is {} of {!r}>'\ - .format(self.user.screen_name, self.role, self.project.name) + return "<ProjectMember {!r} is {} of {!r}>".format( + self.user.screen_name, self.role, self.project.name + ) diff --git a/src/citrine/resources/property_template.py b/src/citrine/resources/property_template.py index 1391669c2..14540e935 100644 --- a/src/citrine/resources/property_template.py +++ b/src/citrine/resources/property_template.py @@ -1,17 +1,23 @@ """Resources that represent property templates.""" + from typing import List, Dict, Optional, Type from citrine._rest.resource import GEMDResource -from citrine.resources.attribute_templates import AttributeTemplate, AttributeTemplateCollection +from citrine.resources.attribute_templates import ( + AttributeTemplate, + AttributeTemplateCollection, +) from gemd.entity.bounds.base_bounds import BaseBounds -from gemd.entity.template.property_template import PropertyTemplate as GEMDPropertyTemplate +from gemd.entity.template.property_template import ( + PropertyTemplate as GEMDPropertyTemplate, +) class PropertyTemplate( - GEMDResource['PropertyTemplate'], + GEMDResource["PropertyTemplate"], AttributeTemplate, GEMDPropertyTemplate, - typ=GEMDPropertyTemplate.typ + typ=GEMDPropertyTemplate.typ, ): """ A property template.
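A minimal construction sketch for the class below (illustrative; the name, bounds, and units are hypothetical). `bounds` is required and keyword-only, mirroring the reformatted `__init__`:

from gemd.entity.bounds.real_bounds import RealBounds
from citrine.resources.property_template import PropertyTemplate

density = PropertyTemplate("density", bounds=RealBounds(0, 25, "g/cm^3"))
print(density)  # <Property template 'density'>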
@@ -37,28 +43,36 @@ class PropertyTemplate( _response_key = GEMDPropertyTemplate.typ # 'property_template' - def __init__(self, - name: str, - *, - bounds: BaseBounds, - uids: Optional[Dict[str, str]] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None): + def __init__( + self, + name: str, + *, + bounds: BaseBounds, + uids: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ): if uids is None: uids = dict() super(AttributeTemplate, self).__init__() - GEMDPropertyTemplate.__init__(self, name=name, bounds=bounds, tags=tags, - uids=uids, description=description) + GEMDPropertyTemplate.__init__( + self, + name=name, + bounds=bounds, + tags=tags, + uids=uids, + description=description, + ) def __str__(self): - return '<Property template {!r}>'.format(self.name) + return "<Property template {!r}>".format(self.name) class PropertyTemplateCollection(AttributeTemplateCollection[PropertyTemplate]): """A collection of property templates.""" - _individual_key = 'property_template' - _collection_key = 'property_templates' + _individual_key = "property_template" + _collection_key = "property_templates" _resource = PropertyTemplate @classmethod diff --git a/src/citrine/resources/report.py b/src/citrine/resources/report.py index 7a0a7efa8..6dffe214c 100644 --- a/src/citrine/resources/report.py +++ b/src/citrine/resources/report.py @@ -1,4 +1,5 @@ """A resource that represents a single module report.""" + from typing import Optional, Union from uuid import UUID @@ -8,7 +9,7 @@ from citrine.informatics.reports import Report -class ReportResource(Resource['ReportResource']): +class ReportResource(Resource["ReportResource"]): """Defines a resource for fetching reports from a module. Parameters ---------- @@ -18,24 +19,30 @@ class ReportResource(Resource['ReportResource']): """ - _path_template = '/projects/{project_id}/predictors/{predictor_id}/versions/{version}/report' - _api_version = 'v3' + _path_template = ( + "/projects/{project_id}/predictors/{predictor_id}/versions/{version}/report" + ) + _api_version = "v3" def __init__(self, project_id: UUID, session: Session): self.project_id = project_id self.session = session - def get(self, - *, - predictor_id: Union[UUID, str], - predictor_version: Optional[Union[int, str]] = None) -> Report: + def get( + self, + *, + predictor_id: Union[UUID, str], + predictor_version: Optional[Union[int, str]] = None, + ) -> Report: """Gets a single report keyed on the predictor ID and (optionally) version.""" version = predictor_version or "most_recent" - url_path = format_escaped_url(self._path_template, - project_id=self.project_id, - predictor_id=str(predictor_id), - version=version) + url_path = format_escaped_url( + self._path_template, + project_id=self.project_id, + predictor_id=str(predictor_id), + version=version, + ) data = self.session.get_resource(url_path, version=self._api_version) report = Report.build(data) diff --git a/src/citrine/resources/response.py b/src/citrine/resources/response.py index b0bb25ea5..8aeaac6cf 100644 --- a/src/citrine/resources/response.py +++ b/src/citrine/resources/response.py @@ -22,7 +22,7 @@ def _get_body_string(self): return "No body available" def __repr__(self): - return f'Response({self._get_status_string()!r}, {self._get_body_string()!r})' + return f"Response({self._get_status_string()!r}, {self._get_body_string()!r})" def __str__(self): - return f'<Response {self._get_status_string()}>' + return f"<Response {self._get_status_string()}>" diff --git a/src/citrine/resources/sample_design_space_execution.py b/src/citrine/resources/sample_design_space_execution.py index
9c5d0f953..3f31436b7 100644 --- a/src/citrine/resources/sample_design_space_execution.py +++ b/src/citrine/resources/sample_design_space_execution.py @@ -1,10 +1,13 @@ """Resources that represent both individual and collections of sample design space executions.""" + from typing import Union, Iterator from uuid import UUID from citrine._rest.collection import Collection from citrine._session import Session -from citrine.informatics.executions.sample_design_space_execution import SampleDesignSpaceExecution +from citrine.informatics.executions.sample_design_space_execution import ( + SampleDesignSpaceExecution, +) from citrine.informatics.design_spaces.sample_design_space import SampleDesignSpaceInput from citrine.resources.response import Response @@ -12,10 +15,10 @@ class SampleDesignSpaceExecutionCollection(Collection["SampleDesignSpaceExecution"]): """A collection of SampleDesignSpaceExecutions.""" - _api_version = 'v3' - _path_template = '/projects/{project_id}/design-spaces/{design_space_id}/sample' + _api_version = "v3" + _path_template = "/projects/{project_id}/design-spaces/{design_space_id}/sample" _individual_key = None - _collection_key = 'response' + _collection_key = "response" _resource = SampleDesignSpaceExecution def __init__(self, project_id: UUID, design_space_id: UUID, session: Session): @@ -48,9 +51,11 @@ def update(self, model: SampleDesignSpaceExecution) -> SampleDesignSpaceExecutio """Cannot update an execution.""" raise NotImplementedError("Cannot update a SampleDesignSpaceExecution.") - def list(self, *, - per_page: int = 10, - ) -> Iterator[SampleDesignSpaceExecution]: + def list( + self, + *, + per_page: int = 10, + ) -> Iterator[SampleDesignSpaceExecution]: """ Paginate over the elements of the collection. @@ -70,12 +75,12 @@ def list(self, *, Resources in this collection. 
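A direct-construction sketch using the `__init__` and `list` shown above (illustrative; ordinarily this collection would be reached from a design space resource, and the `project` and `design_space` handles here are assumed):

from citrine.resources.sample_design_space_execution import (
    SampleDesignSpaceExecutionCollection,
)

executions = SampleDesignSpaceExecutionCollection(
    project_id=project.uid,
    design_space_id=design_space.uid,
    session=project.session,
)
for execution in executions.list(per_page=10):
    print(execution.uid)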
""" - return self._paginator.paginate(page_fetcher=self._fetch_page, - collection_builder=self._build_collection_elements, - per_page=per_page) + return self._paginator.paginate( + page_fetcher=self._fetch_page, + collection_builder=self._build_collection_elements, + per_page=per_page, + ) def delete(self, uid: Union[UUID, str]) -> Response: """Sample Design Space Executions cannot be deleted or archived.""" - raise NotImplementedError( - "Sample Design Space Executions cannot be deleted" - ) + raise NotImplementedError("Sample Design Space Executions cannot be deleted") diff --git a/src/citrine/resources/status_detail.py b/src/citrine/resources/status_detail.py index c5a1c4bed..fc7b368b4 100644 --- a/src/citrine/resources/status_detail.py +++ b/src/citrine/resources/status_detail.py @@ -6,7 +6,7 @@ from gemd.enumeration.base_enumeration import BaseEnumeration -StatusDetailType = TypeVar('StatusDetailType', bound='StatusDetail') +StatusDetailType = TypeVar("StatusDetailType", bound="StatusDetail") class StatusLevelEnum(BaseEnumeration): diff --git a/src/citrine/resources/table_config.py b/src/citrine/resources/table_config.py index 1421ef82e..498339a7e 100644 --- a/src/citrine/resources/table_config.py +++ b/src/citrine/resources/table_config.py @@ -17,17 +17,28 @@ from citrine.resources.data_concepts import CITRINE_SCOPE, _make_link_by_uid from citrine.resources.process_template import ProcessTemplate from citrine.gemd_queries.gemd_query import GemdQuery -from citrine.gemtables.columns import Column, MeanColumn, IdentityColumn, OriginalUnitsColumn, \ - ConcatColumn +from citrine.gemtables.columns import ( + Column, + MeanColumn, + IdentityColumn, + OriginalUnitsColumn, + ConcatColumn, +) from citrine.gemtables.rows import Row from citrine.gemtables.variables import ( - Variable, IngredientIdentifierByProcessTemplateAndName, IngredientQuantityByProcessAndName, - IngredientQuantityDimension, IngredientIdentifierInOutput, IngredientQuantityInOutput, - IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput + Variable, + IngredientIdentifierByProcessTemplateAndName, + IngredientQuantityByProcessAndName, + IngredientQuantityDimension, + IngredientIdentifierInOutput, + IngredientQuantityInOutput, + IngredientLabelsSetByProcessAndName, + IngredientLabelsSetInOutput, ) from typing import TYPE_CHECKING -if TYPE_CHECKING: # pragma: no cover + +if TYPE_CHECKING: # pragma: no cover from citrine.resources.project import Project from citrine.resources.team import Team @@ -100,12 +111,12 @@ def _get_dups(lst: List) -> List: # Hmmn, this looks like a potentially costly operation?! return [x for x in lst if lst.count(x) > 1] - config_uid = properties.Optional(properties.UUID(), 'definition_id') + config_uid = properties.Optional(properties.UUID(), "definition_id") """:Optional[UUID]: Unique ID of the table config, independent of its version.""" - version_number = properties.Optional(properties.Integer, 'version_number') + version_number = properties.Optional(properties.Integer, "version_number") """:Optional[int]: The version of the table config, starting from 1. 
It increases every time the table config is updated.""" - version_uid = properties.Optional(properties.UUID(), 'id') + version_uid = properties.Optional(properties.UUID(), "id") """:Optional[UUID]: Unique ID that specifies one version of one table config.""" name = properties.String("name") @@ -119,15 +130,18 @@ def _get_dups(lst: List) -> List: properties.Enumeration(TableFromGemdQueryAlgorithm), "generation_algorithm" ) - def __init__(self, name: str, - *, - description: str, - datasets: List[UUID], - variables: List[Variable], - rows: List[Row], - columns: List[Column], - gemd_query: GemdQuery = None, - generation_algorithm: Optional[TableFromGemdQueryAlgorithm] = None): + def __init__( + self, + name: str, + *, + description: str, + datasets: List[UUID], + variables: List[Variable], + rows: List[Row], + columns: List[Column], + gemd_query: GemdQuery = None, + generation_algorithm: Optional[TableFromGemdQueryAlgorithm] = None, + ): self.name = name self.description = description self.datasets = datasets @@ -143,18 +157,26 @@ def __init__(self, name: str, names = [x.name for x in variables] dup_names = self._get_dups(names) if len(dup_names) > 0: - raise ValueError("Multiple variables defined these names," - " which much be unique: {}".format(dup_names)) + raise ValueError( + "Multiple variables defined these names," + " which must be unique: {}".format(dup_names) + ) headers = [x.headers for x in variables] dup_headers = self._get_dups(headers) if len(dup_headers) > 0: - raise ValueError("Multiple variables defined these headers," - " which much be unique: {}".format(dup_headers)) + raise ValueError( + "Multiple variables defined these headers," + " which must be unique: {}".format(dup_headers) + ) - missing_variables = [x.data_source for x in columns if x.data_source not in names] + missing_variables = [ + x.data_source for x in columns if x.data_source not in names + ] if len(missing_variables) > 0: - raise ValueError("The data_source of the columns must match one of the variable names," - " but {} were missing".format(missing_variables)) + raise ValueError( + "The data_source of the columns must match one of the variable names," + " but {} were missing".format(missing_variables) + ) @property def uid(self) -> UUID: @@ -166,12 +188,14 @@ def uid(self, new_uid: Union[str, UUID]) -> None: """Set the unique ID of the table config, independent of its version.""" self.config_uid = new_uid - def add_columns(self, *, - variable: Variable, - columns: List[Column], - name: Optional[str] = None, - description: Optional[str] = None - ) -> 'TableConfig': + def add_columns( + self, + *, + variable: Variable, + columns: List[Column], + name: Optional[str] = None, + description: Optional[str] = None, + ) -> "TableConfig": """Add a variable and one or more columns to this TableConfig (out-of-place).
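A hedged sketch of the out-of-place update reformatted above (the `config` handle, variable type, and link UUID are assumed): the returned config is a copy, the variable name must not already be in use, and every column's `data_source` must reference it.

from gemd.entity.link_by_uid import LinkByUID
from citrine.gemtables.variables import AttributeByTemplate
from citrine.gemtables.columns import MeanColumn

density = AttributeByTemplate(
    name="density",
    headers=["Product", "Density"],
    template=LinkByUID(scope="id", id="eeeeeeee-0000-0000-0000-000000000000"),  # hypothetical
)
config = config.add_columns(
    variable=density,
    columns=[MeanColumn(data_source="density")],
)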
This method checks that the variable name is not already in use and that the columns @@ -191,12 +215,17 @@ def add_columns(self, *, """ if variable.name in [x.name for x in self.variables]: - raise ValueError("The variable name {} is already used".format(variable.name)) + raise ValueError( + "The variable name {} is already used".format(variable.name) + ) mismatched_data_source = [x for x in columns if x.data_source != variable.name] if len(mismatched_data_source): - raise ValueError("Column.data_source must be {} but found {}" - .format(variable.name, mismatched_data_source)) + raise ValueError( + "Column.data_source must be {} but found {}".format( + variable.name, mismatched_data_source + ) + ) new_config = TableConfig( name=name or self.name, @@ -204,21 +233,23 @@ def add_columns(self, *, datasets=copy(self.datasets), rows=copy(self.rows), variables=copy(self.variables) + [variable], - columns=copy(self.columns) + columns + columns=copy(self.columns) + columns, ) new_config.version_number = copy(self.version_number) new_config.config_uid = copy(self.config_uid) new_config.version_uid = copy(self.version_uid) return new_config - def add_all_ingredients(self, *, - process_template: Union[LinkByUID, ProcessTemplate, str, UUID], - project: 'Project' = None, - team: 'Team' = None, - quantity_dimension: IngredientQuantityDimension, - scope: str = CITRINE_SCOPE, - unit: Optional[str] = None - ): + def add_all_ingredients( + self, + *, + process_template: Union[LinkByUID, ProcessTemplate, str, UUID], + project: "Project" = None, + team: "Team" = None, + quantity_dimension: IngredientQuantityDimension, + scope: str = CITRINE_SCOPE, + unit: Optional[str] = None, + ): """Add variables and columns for all of the possible ingredients in a process. For each allowed ingredient name in the process template there is a column for the id of @@ -240,9 +271,11 @@ def add_all_ingredients(self, *, """ if project is not None: - warn("Adding ingredients to a table config through a project is deprecated as of " - "3.4.0, and will be removed in 4.0.0. Please use a team instead.", - DeprecationWarning) + warn( + "Adding ingredients to a table config through a project is deprecated as of " + "3.4.0, and will be removed in 4.0.0. 
Please use a team instead.", + DeprecationWarning, + ) principal = project elif team is not None: principal = team @@ -253,38 +286,47 @@ def add_all_ingredients(self, *, IngredientQuantityDimension.ABSOLUTE: "absolute quantity", IngredientQuantityDimension.MASS: "mass fraction", IngredientQuantityDimension.VOLUME: "volume fraction", - IngredientQuantityDimension.NUMBER: "number fraction" + IngredientQuantityDimension.NUMBER: "number fraction", } link = _make_link_by_uid(process_template) process: ProcessTemplate = principal.process_templates.get(uid=link) if not process.allowed_names: raise RuntimeError( - "Cannot add ingredients for process template \'{}\' because it has no defined " - "ingredients (allowed_names is not defined).".format(process.name)) + "Cannot add ingredients for process template '{}' because it has no defined " + "ingredients (allowed_names is not defined).".format(process.name) + ) new_variables = [] new_columns = [] for name in process.allowed_names: identifier_variable = IngredientIdentifierByProcessTemplateAndName( - name='_'.join([process.name, name, str(hash(link.id + name + scope))]), + name="_".join([process.name, name, str(hash(link.id + name + scope))]), headers=[process.name, name, scope], process_template=link, ingredient_name=name, - scope=scope + scope=scope, ) quantity_variable = IngredientQuantityByProcessAndName( - name='_'.join([process.name, name, str(hash( - link.id + name + dimension_display[quantity_dimension]))]), + name="_".join( + [ + process.name, + name, + str( + hash(link.id + name + dimension_display[quantity_dimension]) + ), + ] + ), headers=[process.name, name, dimension_display[quantity_dimension]], process_template=link, ingredient_name=name, quantity_dimension=quantity_dimension, - unit=unit + unit=unit, ) label_variable = IngredientLabelsSetByProcessAndName( - name='_'.join([process.name, name, str(hash( - link.id + name + 'Labels'))]), - headers=[process.name, name, 'Labels'], + name="_".join( + [process.name, name, str(hash(link.id + name + "Labels"))] + ), + headers=[process.name, name, "Labels"], process_template=link, ingredient_name=name, ) @@ -295,13 +337,15 @@ def add_all_ingredients(self, *, new_variables.append(quantity_variable) new_columns.append(MeanColumn(data_source=quantity_variable.name)) if quantity_dimension == IngredientQuantityDimension.ABSOLUTE: - new_columns.append(OriginalUnitsColumn(data_source=quantity_variable.name)) + new_columns.append( + OriginalUnitsColumn(data_source=quantity_variable.name) + ) if label_variable.name not in [var.name for var in self.variables]: new_variables.append(label_variable) new_columns.append( ConcatColumn( data_source=label_variable.name, - subcolumn=IdentityColumn(data_source=label_variable.name) + subcolumn=IdentityColumn(data_source=label_variable.name), ) ) @@ -311,21 +355,23 @@ def add_all_ingredients(self, *, datasets=copy(self.datasets), rows=copy(self.rows), variables=copy(self.variables) + new_variables, - columns=copy(self.columns) + new_columns + columns=copy(self.columns) + new_columns, ) new_config.version_number = copy(self.version_number) new_config.config_uid = copy(self.config_uid) new_config.version_uid = copy(self.version_uid) return new_config - def add_all_ingredients_in_output(self, *, - process_templates: List[LinkByUID], - project: 'Project' = None, - team: 'Team' = None, - quantity_dimension: IngredientQuantityDimension, - scope: str = CITRINE_SCOPE, - unit: Optional[str] = None - ): + def add_all_ingredients_in_output( + self, + *, + process_templates: 
List[LinkByUID], + project: "Project" = None, + team: "Team" = None, + quantity_dimension: IngredientQuantityDimension, + scope: str = CITRINE_SCOPE, + unit: Optional[str] = None, + ): """Add variables and columns for all possible ingredients in a list of processes. For each allowed ingredient name in the union of all passed process templates there is a @@ -350,9 +396,11 @@ def add_all_ingredients_in_output(self, *, """ if project is not None: - warn("Adding ingredients to a table config through a project is deprecated as of " - "3.4.0, and will be removed in 4.0.0. Please use a team instead.", - DeprecationWarning) + warn( + "Adding ingredients to a table config through a project is deprecated as of " + "3.4.0, and will be removed in 4.0.0. Please use a team instead.", + DeprecationWarning, + ) principal = project elif team is not None: principal = team @@ -363,40 +411,46 @@ def add_all_ingredients_in_output(self, *, IngredientQuantityDimension.ABSOLUTE: "absolute quantity", IngredientQuantityDimension.MASS: "mass fraction", IngredientQuantityDimension.VOLUME: "volume fraction", - IngredientQuantityDimension.NUMBER: "number fraction" + IngredientQuantityDimension.NUMBER: "number fraction", } union_allowed_names = [] for process_template_link in process_templates: - process: ProcessTemplate = principal.process_templates.get(process_template_link) + process: ProcessTemplate = principal.process_templates.get( + process_template_link + ) if not process.allowed_names: raise RuntimeError( f"Cannot add ingredients for process template '{process.name}' " "because it has no defined ingredients (allowed_names is not defined)" ) else: - union_allowed_names = list(set(union_allowed_names) | set(process.allowed_names)) + union_allowed_names = list( + set(union_allowed_names) | set(process.allowed_names) + ) new_variables = [] new_columns = [] for name in union_allowed_names: identifier_variable = IngredientIdentifierInOutput( - name='_'.join([name, str(hash(name + scope))]), + name="_".join([name, str(hash(name + scope))]), headers=[name, scope], process_templates=process_templates, ingredient_name=name, - scope=scope + scope=scope, ) quantity_variable = IngredientQuantityInOutput( - name='_'.join([name, str(hash(name + dimension_display[quantity_dimension]))]), + name="_".join( + [name, str(hash(name + dimension_display[quantity_dimension]))] + ), headers=[name, dimension_display[quantity_dimension]], process_templates=process_templates, ingredient_name=name, quantity_dimension=quantity_dimension, - unit=unit + unit=unit, ) label_variable = IngredientLabelsSetInOutput( - name='_'.join([name, str(hash(name + 'Labels'))]), - headers=[name, 'Labels'], + name="_".join([name, str(hash(name + "Labels"))]), + headers=[name, "Labels"], process_templates=process_templates, ingredient_name=name, ) @@ -407,13 +461,15 @@ def add_all_ingredients_in_output(self, *, new_variables.append(quantity_variable) new_columns.append(MeanColumn(data_source=quantity_variable.name)) if quantity_dimension == IngredientQuantityDimension.ABSOLUTE: - new_columns.append(OriginalUnitsColumn(data_source=quantity_variable.name)) + new_columns.append( + OriginalUnitsColumn(data_source=quantity_variable.name) + ) if label_variable.name not in [var.name for var in self.variables]: new_variables.append(label_variable) new_columns.append( ConcatColumn( data_source=label_variable.name, - subcolumn=IdentityColumn(data_source=label_variable.name) + subcolumn=IdentityColumn(data_source=label_variable.name), ) ) @@ -423,7 +479,7 @@ def 
add_all_ingredients_in_output(self, *, datasets=copy(self.datasets), rows=copy(self.rows), variables=copy(self.variables) + new_variables, - columns=copy(self.columns) + new_columns + columns=copy(self.columns) + new_columns, ) new_config.version_number = copy(self.version_number) new_config.config_uid = copy(self.config_uid) @@ -435,15 +491,17 @@ class TableConfigCollection(Collection[TableConfig]): """Represents the collection of all Table Configs associated with a project.""" # FIXME (DML): use newly named properties when they're available - _path_template = 'projects/{project_id}/ara-definitions' - _collection_key = 'definitions' + _path_template = "projects/{project_id}/ara-definitions" + _collection_key = "definitions" _resource = TableConfig # NOTE: This isn't actually an 'individual key' - both parts (version and # definition) are necessary _individual_key = None - def __init__(self, *args, team_id: UUID, project_id: UUID = None, session: Session = None): + def __init__( + self, *args, team_id: UUID, project_id: UUID = None, session: Session = None + ): args = _pad_positional_args(args, 2) self.project_id = project_id or args[0] self.session: Session = session or args[1] @@ -460,16 +518,20 @@ def get(self, uid: Union[UUID, str], *, version: Optional[int] = None): """ if uid is None: - raise ValueError("Cannot get when uid=None. Are you using a registered resource?") + raise ValueError( + "Cannot get when uid=None. Are you using a registered resource?" + ) if version is not None: path = self._get_path(uid, action=["versions", version]) data = self.session.get_resource(path) else: path = self._get_path(uid) data = self.session.get_resource(path) - version_numbers = [version_data['version_number'] for version_data in data['versions']] + version_numbers = [ + version_data["version_number"] for version_data in data["versions"] + ] index = version_numbers.index(max(version_numbers)) - data['version'] = data['versions'][index] + data["version"] = data["versions"][index] return self.build(data) def get_for_table(self, table: "GemTable") -> TableConfig: # noqa: F821 @@ -489,29 +551,33 @@ def get_for_table(self, table: "GemTable") -> TableConfig: # noqa: F821 """ # the route to fetch the config is built off the display table route tree path = format_escaped_url( - 'projects/{}/display-tables/{}/versions/{}/definition', - self.project_id, table.uid, table.version) + "projects/{}/display-tables/{}/versions/{}/definition", + self.project_id, + table.uid, + table.version, + ) data = self.session.get_resource(path) return self.build(data) def build(self, data: dict) -> TableConfig: """Build an individual Table Config from a dictionary.""" - version_data = data['version'] - table_config = TableConfig.build(version_data['ara_definition']) - table_config.version_number = version_data['version_number'] - table_config.version_uid = version_data['id'] - table_config.config_uid = data['definition']['id'] + version_data = data["version"] + table_config = TableConfig.build(version_data["ara_definition"]) + table_config.version_number = version_data["version_number"] + table_config.version_uid = version_data["id"] + table_config.config_uid = data["definition"]["id"] table_config.team_id = self.team_id table_config.project_id = self.project_id table_config.session = self.session return table_config def default_for_material( - self, *, - material: Union[MaterialRun, LinkByUID, str, UUID], - name: str, - description: str = None, - algorithm: Optional[TableBuildAlgorithm] = None + self, + *, + material: 
Union[MaterialRun, LinkByUID, str, UUID], + name: str, + description: str = None, + algorithm: Optional[TableBuildAlgorithm] = None, ) -> Tuple[TableConfig, List[Tuple[Variable, Column]]]: """ Build best-guess default table config for provided terminal material's history. @@ -547,33 +613,33 @@ def default_for_material( """ link = _make_link_by_uid(material) params = { - 'id': link.id, - 'scope': link.scope, - 'name': name, + "id": link.id, + "scope": link.scope, + "name": name, } if description is not None: - params['description'] = description + params["description"] = description if algorithm is not None: if isinstance(algorithm, TableBuildAlgorithm): - params['algorithm'] = algorithm.value + params["algorithm"] = algorithm.value else: # Not per spec, but be forgiving - params['algorithm'] = str(algorithm) + params["algorithm"] = str(algorithm) data = self.session.get_resource( - format_escaped_url('teams/{}/table-configs/default', self.team_id), + format_escaped_url("teams/{}/table-configs/default", self.team_id), params=params, ) - config = TableConfig.build(data['config']) - ambiguous = [(Variable.build(v), Column.build(c)) for v, c in data['ambiguous']] + config = TableConfig.build(data["config"]) + ambiguous = [(Variable.build(v), Column.build(c)) for v, c in data["ambiguous"]] return config, ambiguous def from_query( - self, - gemd_query: GemdQuery, - *, - name: str = None, - description: str = None, - algorithm: Optional[TableFromGemdQueryAlgorithm] = None, - register_config: bool = False + self, + gemd_query: GemdQuery, + *, + name: str = None, + description: str = None, + algorithm: Optional[TableFromGemdQueryAlgorithm] = None, + register_config: bool = False, ) -> Tuple[TableConfig, List[Tuple[Variable, Column]]]: """ Build a TableConfig based on the results of a database query. @@ -600,35 +666,33 @@ def from_query( """ if name is None: - collection = DatasetCollection( - session=self.session, - team_id=self.team_id + collection = DatasetCollection(session=self.session, team_id=self.team_id) + name = ( + f"Automatic Table for Dataset: " + f"{', '.join([collection.get(x).name for x in gemd_query.datasets])}" ) - name = (f"Automatic Table for Dataset: " - f"{', '.join([collection.get(x).name for x in gemd_query.datasets])}") params = {"name": name} if description is not None: - params['description'] = description + params["description"] = description if algorithm is not None: - params['algorithm'] = algorithm + params["algorithm"] = algorithm data = self.session.post_resource( - format_escaped_url('teams/{}/table-configs/from-query', self.team_id), + format_escaped_url("teams/{}/table-configs/from-query", self.team_id), params=params, - json=gemd_query.dump() + json=gemd_query.dump(), ) - config = TableConfig.build(data['config']) - ambiguous = [(Variable.build(v), Column.build(c)) for v, c in data['ambiguous']] + config = TableConfig.build(data["config"]) + ambiguous = [(Variable.build(v), Column.build(c)) for v, c in data["ambiguous"]] if register_config: return self.register(config), ambiguous else: return config, ambiguous - def preview(self, *, - table_config: TableConfig, - preview_materials: List[LinkByUID] = None - ) -> dict: + def preview( + self, *, table_config: TableConfig, preview_materials: List[LinkByUID] = None + ) -> dict: """Preview a Table Config on an explicit set of terminal materials. 
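For orientation (an aside, not part of the diff): the config-builder methods reformatted above are typically driven as in the sketch below. This is a minimal illustration only; `session` stands in for an authenticated citrine Session, the UUIDs are placeholders, and `TableBuildAlgorithm.SINGLE_ROW` is an assumed member name.

from uuid import UUID

from citrine.resources.table_config import TableBuildAlgorithm, TableConfigCollection

session = ...  # stand-in for an authenticated citrine Session

configs = TableConfigCollection(
    team_id=UUID("00000000-0000-0000-0000-000000000001"),      # placeholder
    project_id=UUID("00000000-0000-0000-0000-000000000002"),   # placeholder
    session=session,
)

# Best-guess config for one terminal material; `ambiguous` holds the
# (variable, column) pairs the builder could not choose definitively.
config, ambiguous = configs.default_for_material(
    material="00000000-0000-0000-0000-000000000003",  # str/UUID/LinkByUID/MaterialRun
    name="Default table",
    algorithm=TableBuildAlgorithm.SINGLE_ROW,  # assumed enum member
)
for variable, column in ambiguous:
    print(variable.name, type(column).__name__)

configs.register(config)  # persist the config, per register() further below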
Parameters @@ -639,13 +703,10 @@ def preview(self, *, List of links to the material runs to use as terminal materials in the preview """ - path = format_escaped_url( - "teams/{}/ara-definitions/preview", - self.team_id - ) + path = format_escaped_url("teams/{}/ara-definitions/preview", self.team_id) body = { "definition": table_config.dump(), - "rows": [x.as_dict() for x in preview_materials] + "rows": [x.as_dict() for x in preview_materials], } return self.session.post_resource(path, body) @@ -678,7 +739,9 @@ def register(self, table_config: TableConfig) -> TableConfig: # 1) The validation requirements are the same for updating and registering an # TableConfig # 2) This prevents users from accidentally registering duplicate Table Configs - data = self.session.put_resource(self._get_path(table_config.config_uid), body) + data = self.session.put_resource( + self._get_path(table_config.config_uid), body + ) data = data[self._individual_key] if self._individual_key else data return self.build(data) @@ -695,10 +758,12 @@ def update(self, table_config: TableConfig) -> TableConfig: :return: The updated Table Config with updated metadata """ if table_config.config_uid is None: - raise ValueError("Cannot update Table Config without a config_uid." - " Please either use register() to initially register this" - " Table Config or retrieve the registered details before calling" - " update()") + raise ValueError( + "Cannot update Table Config without a config_uid." + " Please either use register() to initially register this" + " Table Config or retrieve the registered details before calling" + " update()" + ) return self.register(table_config) def delete(self, uid: Union[UUID, str]): diff --git a/src/citrine/resources/team.py b/src/citrine/resources/team.py index 275babc59..28528089a 100644 --- a/src/citrine/resources/team.py +++ b/src/citrine/resources/team.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of teams.""" + from typing import List, Optional, Tuple, Union from uuid import UUID @@ -42,18 +43,21 @@ class TeamMember: """A Member of a Team.""" - def __init__(self, - *, - user: User, - team: 'Team', # noqa: F821 - actions: List[TEAM_ACTIONS]): + def __init__( + self, + *, + user: User, + team: "Team", # noqa: F821 + actions: List[TEAM_ACTIONS], + ): self.user = user - self.team: 'Team' = team # noqa: F821 + self.team: "Team" = team # noqa: F821 self.actions: List[TEAM_ACTIONS] = actions def __str__(self): - return '' \ - .format(self.user.screen_name, self.actions, self.team.name) + return "".format( + self.user.screen_name, self.actions, self.team.name + ) class TeamResourceIDs: @@ -75,22 +79,23 @@ class TeamResourceIDs: _api_version = "v3" - def __init__(self, - session: Session, - team_id: Union[str, UUID], - resource_type: str) -> None: + def __init__( + self, session: Session, team_id: Union[str, UUID], resource_type: str + ) -> None: self.session = session self.team_id = team_id self.resource_type = resource_type def _path(self) -> str: - return format_escaped_url(f'/teams/{self.team_id}') + return format_escaped_url(f"/teams/{self.team_id}") def _list_ids(self, action: str) -> List[str]: query_params = {"domain": self._path(), "action": action} - return self.session.get_resource(f"/{self.resource_type}/authorized-ids", - params=query_params, - version=self._api_version)['ids'] + return self.session.get_resource( + f"/{self.resource_type}/authorized-ids", + params=query_params, + version=self._api_version, + )["ids"] def list_readable(self): """ @@ -129,7 
+134,7 @@ def list_shareable(self): return self._list_ids(action=SHARE) -class Team(Resource['Team']): +class Team(Resource["Team"]): """ A Citrine Team. @@ -146,33 +151,31 @@ class Team(Resource['Team']): """ - _response_key = 'team' + _response_key = "team" _resource_type = ResourceTypeEnum.TEAM _api_version = "v3" - name = properties.String('name') + name = properties.String("name") """str: Name of the Team""" - description = properties.Optional(properties.String(), 'description') + description = properties.Optional(properties.String(), "description") """str: Description of the Team""" - uid = properties.Optional(properties.UUID(), 'id') + uid = properties.Optional(properties.UUID(), "id") """UUID: Unique uuid4 identifier of this team.""" - created_at = properties.Optional(properties.Datetime(), 'created_at') + created_at = properties.Optional(properties.Datetime(), "created_at") """int: Time the team was created, in seconds since epoch.""" - def __init__(self, - name: str, - *, - description: str = "", - session: Optional[Session] = None): + def __init__( + self, name: str, *, description: str = "", session: Optional[Session] = None + ): self.name: str = name self.description: str = description self.session: Session = session def __str__(self): - return ''.format(self.name) + return "".format(self.name) def _path(self): - return format_escaped_url('/teams/{team_id}', team_id=self.uid) + return format_escaped_url("/teams/{team_id}", team_id=self.uid) def list_members(self) -> List[TeamMember]: """ @@ -186,9 +189,14 @@ def list_members(self) -> List[TeamMember]: The members of the current team """ - response = self.session.get_resource(self._path() + "/users", version=self._api_version) + response = self.session.get_resource( + self._path() + "/users", version=self._api_version + ) members = response["users"] - return [TeamMember(user=User.build(m), team=self, actions=m["actions"]) for m in members] + return [ + TeamMember(user=User.build(m), team=self, actions=m["actions"]) + for m in members + ] def get_member(self, user_id: Union[str, UUID, User]) -> TeamMember: """ @@ -209,7 +217,7 @@ def get_member(self, user_id: Union[str, UUID, User]) -> TeamMember: """ if isinstance(user_id, User): user_id = user_id.uid - path = self._path() + format_escaped_url('/users/{user_id}', user_id=user_id) + path = self._path() + format_escaped_url("/users/{user_id}", user_id=user_id) member = self.session.get_resource(path=path, version=self._api_version)["user"] return TeamMember(user=User.build(member), team=self, actions=member["actions"]) @@ -245,14 +253,19 @@ def remove_user(self, user_id: Union[str, UUID, User]) -> bool: """ if isinstance(user_id, User): user_id = user_id.uid - self.session.checked_post(self._path() + "/users/batch-remove", - json={"ids": [str(user_id)]}, version=self._api_version) + self.session.checked_post( + self._path() + "/users/batch-remove", + json={"ids": [str(user_id)]}, + version=self._api_version, + ) return True # note: only get here if checked_post doesn't raise error - def add_user(self, - user_id: Union[str, UUID, User], - *, - actions: Optional[List[TEAM_ACTIONS]] = None) -> bool: + def add_user( + self, + user_id: Union[str, UUID, User], + *, + actions: Optional[List[TEAM_ACTIONS]] = None, + ) -> bool: """ Add a User to a Team. 
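As an aside, the membership hunks above and below compose into the workflow sketched here. `team` is assumed to be an existing Team bound to an authenticated session, the UUID is a placeholder, and the import path for the action constants is an assumption based on their unqualified use in this module.

from citrine.resources.team import READ, SHARE  # assumed: constants referenced in this diff

user_uid = "00000000-0000-0000-0000-000000000004"  # placeholder

team.add_user(user_uid)                           # defaults to [READ], per add_user below
team.update_user_action(user_uid, actions=[READ, SHARE])  # overwrites, does not append
member = team.get_member(user_uid)
print(member.actions)
team.remove_user(user_uid)                        # POSTs a batch-remove of one id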
@@ -283,10 +296,9 @@ def add_user(self, actions = [READ] return self.update_user_action(user_id, actions=actions) - def update_user_action(self, - user_id: Union[str, UUID, User], - *, - actions: List[TEAM_ACTIONS]) -> bool: + def update_user_action( + self, user_id: Union[str, UUID, User], *, actions: List[TEAM_ACTIONS] + ) -> bool: """ Overwrites a User's action permissions in the Team. @@ -308,14 +320,16 @@ def update_user_action(self, """ if isinstance(user_id, User): user_id = user_id.uid - self.session.checked_put(self._path() + "/users", version=self._api_version, - json={'id': str(user_id), "actions": actions}) + self.session.checked_put( + self._path() + "/users", + version=self._api_version, + json={"id": str(user_id), "actions": actions}, + ) return True - def share(self, - *, - resource: Resource, - target_team_id: Union[str, UUID, "Team"]) -> bool: + def share( + self, *, resource: Resource, target_team_id: Union[str, UUID, "Team"] + ) -> bool: """ Share a resource with another team. @@ -340,13 +354,16 @@ def share(self, payload = { "resource_type": resource_access["type"], "resource_id": resource_access["id"], - "target_team_id": str(target_team_id) + "target_team_id": str(target_team_id), } - self.session.checked_post(self._path() + "/shared-resources", - version=self._api_version, json=payload) + self.session.checked_post( + self._path() + "/shared-resources", version=self._api_version, json=payload + ) return True - def un_share(self, *, resource: Resource, target_team_id: Union[str, UUID, "Team"]) -> bool: + def un_share( + self, *, resource: Resource, target_team_id: Union[str, UUID, "Team"] + ) -> bool: """ Revoke the share of a particular resource to a secondary team. @@ -372,7 +389,7 @@ def un_share(self, *, resource: Resource, target_team_id: Union[str, UUID, "Team self.session.checked_delete( self._path() + f"/shared-resources/{resource_type}/{resource_id}", version=self._api_version, - json={"target_team_id": str(target_team_id)} + json={"target_team_id": str(target_team_id)}, ) return True @@ -387,10 +404,10 @@ def owned_dataset_ids(self) -> List[str]: """ query_params = {"userId": "", "domain": self._path(), "action": "WRITE"} - response = self.session.get_resource("/DATASET/authorized-ids", - params=query_params, - version="v3") - return response['ids'] + response = self.session.get_resource( + "/DATASET/authorized-ids", params=query_params, version="v3" + ) + return response["ids"] @property def projects(self) -> ProjectCollection: @@ -405,9 +422,11 @@ def analyses(self) -> AnalysisWorkflowCollection: @property def dataset_ids(self) -> TeamResourceIDs: """Return a TeamResourceIDs instance for listing published dataset IDs.""" - return TeamResourceIDs(session=self.session, - team_id=self.uid, - resource_type=ResourceTypeEnum.DATASET.value) + return TeamResourceIDs( + session=self.session, + team_id=self.uid, + resource_type=ResourceTypeEnum.DATASET.value, + ) @property def datasets(self) -> DatasetCollection: @@ -417,106 +436,142 @@ def datasets(self) -> DatasetCollection: @property def module_ids(self) -> TeamResourceIDs: """Return a TeamResourceIDs instance for listing published module IDs.""" - return TeamResourceIDs(session=self.session, - team_id=self.uid, - resource_type=ResourceTypeEnum.MODULE.value) + return TeamResourceIDs( + session=self.session, + team_id=self.uid, + resource_type=ResourceTypeEnum.MODULE.value, + ) @property def table_ids(self) -> TeamResourceIDs: """Return a TeamResourceIDs instance for listing published table IDs.""" - return 
TeamResourceIDs(session=self.session, - team_id=self.uid, - resource_type=ResourceTypeEnum.TABLE.value) + return TeamResourceIDs( + session=self.session, + team_id=self.uid, + resource_type=ResourceTypeEnum.TABLE.value, + ) @property def table_definition_ids(self) -> TeamResourceIDs: """Return a TeamResourceIDs instance for listing published table definition IDs.""" - return TeamResourceIDs(session=self.session, - team_id=self.uid, - resource_type=ResourceTypeEnum.TABLE_DEFINITION.value) + return TeamResourceIDs( + session=self.session, + team_id=self.uid, + resource_type=ResourceTypeEnum.TABLE_DEFINITION.value, + ) @property def property_templates(self) -> PropertyTemplateCollection: """Return a resource representing all property templates in this dataset.""" - return PropertyTemplateCollection(team_id=self.uid, dataset_id=None, session=self.session) + return PropertyTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def condition_templates(self) -> ConditionTemplateCollection: """Return a resource representing all condition templates in this dataset.""" - return ConditionTemplateCollection(team_id=self.uid, dataset_id=None, session=self.session) + return ConditionTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def parameter_templates(self) -> ParameterTemplateCollection: """Return a resource representing all parameter templates in this dataset.""" - return ParameterTemplateCollection(team_id=self.uid, dataset_id=None, session=self.session) + return ParameterTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def material_templates(self) -> MaterialTemplateCollection: """Return a resource representing all material templates in this dataset.""" - return MaterialTemplateCollection(team_id=self.uid, dataset_id=None, session=self.session) + return MaterialTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def measurement_templates(self) -> MeasurementTemplateCollection: """Return a resource representing all measurement templates in this dataset.""" - return MeasurementTemplateCollection(team_id=self.uid, - dataset_id=None, - session=self.session) + return MeasurementTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def process_templates(self) -> ProcessTemplateCollection: """Return a resource representing all process templates in this dataset.""" - return ProcessTemplateCollection(team_id=self.uid, dataset_id=None, session=self.session) + return ProcessTemplateCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def process_runs(self) -> ProcessRunCollection: """Return a resource representing all process runs in this dataset.""" - return ProcessRunCollection(team_id=self.uid, dataset_id=None, session=self.session) + return ProcessRunCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def measurement_runs(self) -> MeasurementRunCollection: """Return a resource representing all measurement runs in this dataset.""" - return MeasurementRunCollection(team_id=self.uid, dataset_id=None, session=self.session) + return MeasurementRunCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def material_runs(self) -> MaterialRunCollection: """Return a resource representing all material runs in this dataset.""" - return MaterialRunCollection(team_id=self.uid, dataset_id=None, session=self.session) + return 
MaterialRunCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def ingredient_runs(self) -> IngredientRunCollection: """Return a resource representing all ingredient runs in this dataset.""" - return IngredientRunCollection(team_id=self.uid, dataset_id=None, session=self.session) + return IngredientRunCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def process_specs(self) -> ProcessSpecCollection: """Return a resource representing all process specs in this dataset.""" - return ProcessSpecCollection(team_id=self.uid, dataset_id=None, session=self.session) + return ProcessSpecCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def measurement_specs(self) -> MeasurementSpecCollection: """Return a resource representing all measurement specs in this dataset.""" - return MeasurementSpecCollection(team_id=self.uid, dataset_id=None, session=self.session) + return MeasurementSpecCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def material_specs(self) -> MaterialSpecCollection: """Return a resource representing all material specs in this dataset.""" - return MaterialSpecCollection(team_id=self.uid, dataset_id=None, session=self.session) + return MaterialSpecCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def ingredient_specs(self) -> IngredientSpecCollection: """Return a resource representing all ingredient specs in this dataset.""" - return IngredientSpecCollection(team_id=self.uid, dataset_id=None, session=self.session) + return IngredientSpecCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) @property def gemd(self) -> GEMDResourceCollection: """Return a resource representing all GEMD objects/templates in this dataset.""" - return GEMDResourceCollection(team_id=self.uid, dataset_id=None, session=self.session) + return GEMDResourceCollection( + team_id=self.uid, dataset_id=None, session=self.session + ) - def gemd_batch_delete(self, - id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], - *, - timeout: float = 2 * 60, - polling_delay: float = 1.0) -> List[Tuple[LinkByUID, ApiError]]: + def gemd_batch_delete( + self, + id_list: List[Union[LinkByUID, UUID, str, BaseEntity]], + *, + timeout: float = 2 * 60, + polling_delay: float = 1.0, + ) -> List[Tuple[LinkByUID, ApiError]]: """ Remove a set of GEMD objects. @@ -550,12 +605,14 @@ def gemd_batch_delete(self, deleted. 
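A brief sketch of driving the batch delete declared above (an aside; the id is a placeholder). Per the signature and return type shown, failures come back as (link, error) pairs rather than raising.

from gemd.entity.link_by_uid import LinkByUID

failed = team.gemd_batch_delete(
    [LinkByUID(scope="id", id="00000000-0000-0000-0000-000000000005")],
    timeout=120,
    polling_delay=1.0,
)
for link, api_error in failed:
    print(f"not deleted: {link.id} ({api_error})")  # e.g. object still referenced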
""" - return _async_gemd_batch_delete(id_list=id_list, - team_id=self.uid, - session=self.session, - dataset_id=None, - timeout=timeout, - polling_delay=polling_delay) + return _async_gemd_batch_delete( + id_list=id_list, + team_id=self.uid, + session=self.session, + dataset_id=None, + timeout=timeout, + polling_delay=polling_delay, + ) class TeamCollection(AdminCollection[Team]): @@ -569,9 +626,9 @@ class TeamCollection(AdminCollection[Team]): """ - _path_template = '/teams' - _individual_key = 'team' - _collection_key = 'teams' + _path_template = "/teams" + _individual_key = "team" + _collection_key = "teams" _resource = Team _api_version = "v3" @@ -609,7 +666,9 @@ def update(self, team: Team) -> Team: """ url = self._get_path(team.uid) - updated = self.session.patch_resource(url, team.dump(), version=self._api_version) + updated = self.session.patch_resource( + url, team.dump(), version=self._api_version + ) data = updated[self._individual_key] return self.build(data) diff --git a/src/citrine/resources/templates.py b/src/citrine/resources/templates.py index 4faf81223..4eaafc1ec 100644 --- a/src/citrine/resources/templates.py +++ b/src/citrine/resources/templates.py @@ -1,4 +1,5 @@ """Top-level class for all template objects and collections thereof.""" + from abc import ABC from typing import TypeVar diff --git a/src/citrine/resources/user.py b/src/citrine/resources/user.py index 819736c33..8a0147259 100644 --- a/src/citrine/resources/user.py +++ b/src/citrine/resources/user.py @@ -1,4 +1,5 @@ """Resources that represent both individual and collections of users.""" + from typing import Optional from citrine._rest.admin_collection import AdminCollection @@ -7,7 +8,7 @@ from citrine._session import Session -class User(Resource['User']): +class User(Resource["User"]): """ A Citrine User. 
@@ -29,18 +30,15 @@ class User(Resource['User']): _resource_type = ResourceTypeEnum.USER _session: Optional[Session] = None - uid = properties.Optional(properties.UUID, 'id') - screen_name = properties.String('screen_name') - position = properties.Optional(properties.String(), 'position') - email = properties.String('email') - is_admin = properties.Boolean('is_admin') - - def __init__(self, - *, - screen_name: str, - email: str, - position: Optional[str], - is_admin: bool): + uid = properties.Optional(properties.UUID, "id") + screen_name = properties.String("screen_name") + position = properties.Optional(properties.String(), "position") + email = properties.String("email") + is_admin = properties.Boolean("is_admin") + + def __init__( + self, *, screen_name: str, email: str, position: Optional[str], is_admin: bool + ): self.email: str = email self.position: Optional[str] = position self.screen_name: str = screen_name @@ -52,7 +50,7 @@ def is_internal(self) -> bool: return self.email.split("@")[-1] == "citrine.io" def __str__(self): - return ''.format(self.screen_name) + return "".format(self.screen_name) def get(self): """Retrieve a specific user from the database.""" @@ -62,9 +60,9 @@ def get(self): class UserCollection(AdminCollection[User]): """Represents the collection of all users.""" - _path_template = '/users' - _collection_key = 'users' - _individual_key = 'user' + _path_template = "/users" + _collection_key = "users" + _individual_key = "user" _resource = User def __init__(self, session: Session): @@ -72,7 +70,7 @@ def __init__(self, session: Session): def me(self): """Get information about the current user.""" - data = self.session.get_resource(self._path_template + '/me') + data = self.session.get_resource(self._path_template + "/me") return self.build(data) def build(self, data): @@ -94,15 +92,15 @@ def build(self, data): user._session = self.session return user - def register(self, - *, - screen_name: str, - email: str, - position: str, - is_admin: bool) -> User: + def register( + self, *, screen_name: str, email: str, position: str, is_admin: bool + ) -> User: """Register a User.""" - return super().register(User( - screen_name=screen_name, - email=email, - position=position, - is_admin=is_admin)) + return super().register( + User( + screen_name=screen_name, + email=email, + position=position, + is_admin=is_admin, + ) + ) diff --git a/src/citrine/seeding/find_or_create.py b/src/citrine/seeding/find_or_create.py index b8ff71a69..97e4713e3 100644 --- a/src/citrine/seeding/find_or_create.py +++ b/src/citrine/seeding/find_or_create.py @@ -10,7 +10,7 @@ from citrine._rest.collection import CreationType, Collection logger = getLogger(__name__) -T = TypeVar('T') +T = TypeVar("T") def find_collection(*, collection: Collection[T], name: str) -> Optional[T]: @@ -24,33 +24,33 @@ def find_collection(*, collection: Collection[T], name: str) -> Optional[T]: # try to use search if it is available # call list() to collapse the iterator, otherwise the NotFound # won't show up until collection_list is used - collection_list = list(collection.search(search_params={ - "name": { - "value": name, - "search_method": "EXACT" - } - })) + collection_list = list( + collection.search( + search_params={"name": {"value": name, "search_method": "EXACT"}} + ) + ) except (NotFound, NotImplementedError): # Search must not be available yet or any more collection_list = collection.list() else: collection_list = collection.list() - matching_resources = [resource for resource in collection_list if resource.name == 
name] + matching_resources = [ + resource for resource in collection_list if resource.name == name + ] if len(matching_resources) > 1: raise ValueError("Found multiple collections with name '{}'".format(name)) if len(matching_resources) == 1: result = matching_resources.pop() - logger.info('Found existing: {}'.format(result)) + logger.info("Found existing: {}".format(result)) return result else: return None -def get_by_name_or_create(*, - collection: Collection[T], - name: str, - default_provider: Callable[..., T]) -> T: +def get_by_name_or_create( + *, collection: Collection[T], name: str, default_provider: Callable[..., T] +) -> T: """ Tries to find a collection by its name (returns first hit). @@ -60,7 +60,9 @@ def get_by_name_or_create(*, if found: return found else: - logger.info('Failed to find resource with name {}, creating one instead.'.format(name)) + logger.info( + "Failed to find resource with name {}, creating one instead.".format(name) + ) return default_provider() @@ -77,34 +79,39 @@ def get_by_name_or_raise_error(*, collection: Collection[T], name: str) -> T: raise ValueError("Did not find resource with the given name: {}".format(name)) -def find_or_create_project(*, - project_collection: ProjectCollection, - project_name: str, - raise_error: bool = False) -> Project: +def find_or_create_project( + *, + project_collection: ProjectCollection, + project_name: str, + raise_error: bool = False, +) -> Project: """ Tries to find a project by name (returns first hit). If not found, creates a new project with the given name """ if project_collection.team_id is None: - raise NotImplementedError("Collection must have a team ID, such as when retrieved with " - "find_or_create_team.") + raise NotImplementedError( + "Collection must have a team ID, such as when retrieved with " + "find_or_create_team." + ) if raise_error: - project = get_by_name_or_raise_error(collection=project_collection, name=project_name) + project = get_by_name_or_raise_error( + collection=project_collection, name=project_name + ) else: project = get_by_name_or_create( collection=project_collection, name=project_name, - default_provider=lambda: project_collection.register(project_name) + default_provider=lambda: project_collection.register(project_name), ) return project -def find_or_create_team(*, - team_collection: TeamCollection, - team_name: str, - raise_error: bool = False) -> Team: +def find_or_create_team( + *, team_collection: TeamCollection, team_name: str, raise_error: bool = False +) -> Team: """ Tries to find a team by name (returns first hit). @@ -116,36 +123,40 @@ def find_or_create_team(*, team = get_by_name_or_create( collection=team_collection, name=team_name, - default_provider=lambda: team_collection.register(team_name) + default_provider=lambda: team_collection.register(team_name), ) return team -def find_or_create_dataset(*, - dataset_collection: DatasetCollection, - dataset_name: str, - raise_error: bool = False) -> Dataset: +def find_or_create_dataset( + *, + dataset_collection: DatasetCollection, + dataset_name: str, + raise_error: bool = False, +) -> Dataset: """ Tries to find a dataset by name (returns first hit). 
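Taken together, the seeding helpers in this file chain as in the sketch below (an illustrative aside; `citrine` stands for an authenticated client). The team-scoped collections satisfy the team_id requirement enforced by find_or_create_project above.

from citrine.seeding.find_or_create import (
    find_or_create_dataset,
    find_or_create_project,
    find_or_create_team,
)

team = find_or_create_team(team_collection=citrine.teams, team_name="Alloys R&D")
project = find_or_create_project(
    project_collection=team.projects,   # team-scoped, so team_id is set as required
    project_name="Phase 1",
)
dataset = find_or_create_dataset(
    dataset_collection=team.datasets,
    dataset_name="seed data",
)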
If not found, creates a new dataset with the given name """ if raise_error: - dataset = get_by_name_or_raise_error(collection=dataset_collection, name=dataset_name) + dataset = get_by_name_or_raise_error( + collection=dataset_collection, name=dataset_name + ) else: dataset = get_by_name_or_create( collection=dataset_collection, name=dataset_name, default_provider=lambda: dataset_collection.register( Dataset(dataset_name, summary="seed summ.", description="seed desc.") - ) + ), ) return dataset -def create_or_update(*, - collection: Collection[CreationType], - resource: CreationType) -> CreationType: +def create_or_update( + *, collection: Collection[CreationType], resource: CreationType +) -> CreationType: """ Update a resource of a given name belonging to a collection. diff --git a/tests/_serialization/_data.py b/tests/_serialization/_data.py index 132231eb3..aa53f369d 100644 --- a/tests/_serialization/_data.py +++ b/tests/_serialization/_data.py @@ -7,12 +7,20 @@ (properties.Integer, 5, 5), (properties.Float, 3.0, 3.0), (properties.Raw, 1234, 1234), - (properties.String, 'foo', 'foo'), + (properties.String, "foo", "foo"), (properties.Boolean, True, True), (properties.Boolean, False, False), - (properties.UUID, uuid.UUID('284e6cec-dd05-4f8e-9a94-4abb298bde82'), '284e6cec-dd05-4f8e-9a94-4abb298bde82'), + ( + properties.UUID, + uuid.UUID("284e6cec-dd05-4f8e-9a94-4abb298bde82"), + "284e6cec-dd05-4f8e-9a94-4abb298bde82", + ), (properties.Datetime, arrow.get(269815509154).datetime, 269815509154), - (properties.Datetime, arrow.get('2019-07-19T10:46:08+00:00').datetime, 1563533168000), + ( + properties.Datetime, + arrow.get("2019-07-19T10:46:08+00:00").datetime, + 1563533168000, + ), ] @@ -21,13 +29,14 @@ (properties.Float, object()), (properties.String, 1), (properties.Boolean, 3), - (properties.Boolean, 'False'), - (properties.UUID, '284e6cec'), + (properties.Boolean, "False"), + (properties.UUID, "284e6cec"), ] class DummyProperty(properties.Property): """This is a concrete sublcass that does not overwrite __str__ for base Property testing""" + @property def underlying_types(self): return None @@ -44,12 +53,12 @@ def _deserialize(self, value): VALID_STRINGS = [ - (DummyProperty, 'hi', ""), - (properties.Raw, 'hi', ""), - (properties.Integer, 'foo', ""), - (properties.Float, 'bar', ""), - (properties.String, 'foobar', ""), - (properties.Boolean, 'what', ""), + (DummyProperty, "hi", ""), + (properties.Raw, "hi", ""), + (properties.Integer, "foo", ""), + (properties.Float, "bar", ""), + (properties.String, "foobar", ""), + (properties.Boolean, "what", ""), ] INVALID_INSTANCES = [ @@ -57,35 +66,38 @@ def _deserialize(self, value): (properties.Integer, "1"), (properties.Integer, complex(1, 2)), (properties.Integer, True), - (properties.Integer, 'asdf'), + (properties.Integer, "asdf"), (properties.Float, complex(1, 2)), (properties.Float, True), - (properties.Float, 'asdf'), + (properties.Float, "asdf"), (properties.String, 1), (properties.String, dict()), (properties.Boolean, 1), (properties.Boolean, 1.0), - (properties.Boolean, 'asdf'), + (properties.Boolean, "asdf"), (properties.UUID, str(uuid.uuid4())), # string(uuid) != uuid (properties.UUID, 1.0), - (properties.Datetime, '2019-07-19T10:46:08.949682+00:00'), # str(datetime) != datetime - (properties.LinkOrElse, object()) + ( + properties.Datetime, + "2019-07-19T10:46:08.949682+00:00", + ), # str(datetime) != datetime + (properties.LinkOrElse, object()), ] INVALID_SERIALIZED_INSTANCES = [ - (properties.Integer, '1.0'), + (properties.Integer, 
"1.0"), (properties.Integer, str(complex(1, 2))), (properties.Integer, True), - (properties.Integer, 'asdf'), + (properties.Integer, "asdf"), (properties.Integer, 14.4), - (properties.Float, str(complex(1,2))), + (properties.Float, str(complex(1, 2))), (properties.Float, True), - (properties.Float, 'asdf'), + (properties.Float, "asdf"), (properties.String, 1), (properties.String, dict()), (properties.Boolean, 1), (properties.Boolean, 1.0), - (properties.Boolean, 'asdf'), - (properties.UUID, 'wrong-number-of-chars'), - (properties.Datetime, '2019-07-19T35:46:08.949682+99:99'), # nonsense time + (properties.Boolean, "asdf"), + (properties.UUID, "wrong-number-of-chars"), + (properties.Datetime, "2019-07-19T35:46:08.949682+99:99"), # nonsense time ] diff --git a/tests/_serialization/_utils.py b/tests/_serialization/_utils.py index d8fe3142f..221402a44 100644 --- a/tests/_serialization/_utils.py +++ b/tests/_serialization/_utils.py @@ -3,13 +3,21 @@ from citrine._serialization import properties -def make_class_with_property(prop_type: Type[properties.Property], field_name: str, field_path: Optional[str] = None): +def make_class_with_property( + prop_type: Type[properties.Property], + field_name: str, + field_path: Optional[str] = None, +): class SampleObject: def __init__(self, field_value: Any): setattr(self, field_name, field_value) def __eq__(self, other): return getattr(self, field_name) == getattr(other, field_name) - setattr(SampleObject, field_name, - prop_type(serialization_path=field_name if field_path is None else field_path)) + + setattr( + SampleObject, + field_name, + prop_type(serialization_path=field_name if field_path is None else field_path), + ) return SampleObject diff --git a/tests/_serialization/test_container_properties.py b/tests/_serialization/test_container_properties.py index 660def198..c9f648804 100644 --- a/tests/_serialization/test_container_properties.py +++ b/tests/_serialization/test_container_properties.py @@ -8,7 +8,7 @@ from gemd.entity.link_by_uid import LinkByUID -@pytest.mark.parametrize('sub_prop,sub_value,sub_serialized', VALID_SERIALIZATIONS) +@pytest.mark.parametrize("sub_prop,sub_value,sub_serialized", VALID_SERIALIZATIONS) def test_list_property_serde(sub_prop, sub_value, sub_serialized): prop = properties.List(sub_prop) value = [sub_value for _ in range(5)] @@ -17,17 +17,17 @@ def test_list_property_serde(sub_prop, sub_value, sub_serialized): assert prop.serialize(value) == serialized -@pytest.mark.parametrize('sub_prop,sub_value,sub_serialized', VALID_SERIALIZATIONS) +@pytest.mark.parametrize("sub_prop,sub_value,sub_serialized", VALID_SERIALIZATIONS) def test_object_property_serde(sub_prop, sub_value, sub_serialized): - klass = make_class_with_property(sub_prop, 'some_property_name') + klass = make_class_with_property(sub_prop, "some_property_name") prop = properties.Object(klass) instance = klass(sub_value) - serialized = {'some_property_name': sub_serialized} + serialized = {"some_property_name": sub_serialized} assert prop.deserialize(serialized) == instance assert prop.serialize(instance) == serialized -@pytest.mark.parametrize('sub_prop,sub_value,sub_serialized', VALID_SERIALIZATIONS) +@pytest.mark.parametrize("sub_prop,sub_value,sub_serialized", VALID_SERIALIZATIONS) def test_optional_property(sub_prop, sub_value, sub_serialized): prop = properties.Optional(sub_prop) assert prop.deserialize(sub_serialized) == sub_value @@ -36,9 +36,13 @@ def test_optional_property(sub_prop, sub_value, sub_serialized): assert prop.serialize(None) is None 
-@pytest.mark.parametrize('key_type,key_value,key_serialized', VALID_SERIALIZATIONS) -@pytest.mark.parametrize('value_type,value_value,value_serialized', VALID_SERIALIZATIONS) -def test_mapping_property(key_type, value_type, key_value, value_value, key_serialized, value_serialized): +@pytest.mark.parametrize("key_type,key_value,key_serialized", VALID_SERIALIZATIONS) +@pytest.mark.parametrize( + "value_type,value_value,value_serialized", VALID_SERIALIZATIONS +) +def test_mapping_property( + key_type, value_type, key_value, value_value, key_serialized, value_serialized +): prop = properties.Mapping(key_type, value_type) value = {key_value: value_value} serialized = {key_serialized: value_serialized} @@ -46,20 +50,28 @@ def test_mapping_property(key_type, value_type, key_value, value_value, key_seri assert prop.serialize(value) == serialized -@pytest.mark.parametrize('key_type,key_value,key_serialized', VALID_SERIALIZATIONS) -@pytest.mark.parametrize('value_type,value_value,value_serialized', VALID_SERIALIZATIONS) -def test_mapping_property_list_of_pairs(key_type, value_type, key_value, value_value, key_serialized, value_serialized): - prop = properties.Mapping(key_type, value_type, ser_as_list_of_pairs = True) +@pytest.mark.parametrize("key_type,key_value,key_serialized", VALID_SERIALIZATIONS) +@pytest.mark.parametrize( + "value_type,value_value,value_serialized", VALID_SERIALIZATIONS +) +def test_mapping_property_list_of_pairs( + key_type, value_type, key_value, value_value, key_serialized, value_serialized +): + prop = properties.Mapping(key_type, value_type, ser_as_list_of_pairs=True) value = {key_value: value_value} - serialized = [(key_serialized, value_serialized),] + serialized = [ + (key_serialized, value_serialized), + ] assert prop.deserialize(serialized) == value unittest.TestCase().assertCountEqual(prop.serialize(value), serialized) def test_mapping_property_list_of_pairs_multiple(): - prop = properties.Mapping(properties.String, properties.Integer, ser_as_list_of_pairs = True) - value = {'foo': 1, 'bar': 2} - serialized = [('foo', 1), ('bar', 2)] + prop = properties.Mapping( + properties.String, properties.Integer, ser_as_list_of_pairs=True + ) + value = {"foo": 1, "bar": 2} + serialized = [("foo", 1), ("bar", 2)] assert prop.deserialize(serialized) == value unittest.TestCase().assertCountEqual(prop.serialize(value), serialized) @@ -69,8 +81,12 @@ class DummyDescriptor(object): dummy_list = properties.List(properties.Float, "dummy_list") dummy_set = properties.Set(type(properties.Float()), "dummy_map") link_or_else = properties.LinkOrElse(serialization_path="link_or_else") - map_collection_key = properties.Mapping(properties.Optional(properties.String), properties.Integer, "map_collection_key") - specified_mixed_list = properties.SpecifiedMixedList([properties.Integer(default=100)], "specified_mixed_list") + map_collection_key = properties.Mapping( + properties.Optional(properties.String), properties.Integer, "map_collection_key" + ) + specified_mixed_list = properties.SpecifiedMixedList( + [properties.Integer(default=100)], "specified_mixed_list" + ) def test_collection_setters(): @@ -78,8 +94,12 @@ def test_collection_setters(): dummy_descriptor.dummy_map = {1: "1"} dummy_descriptor.dummy_set = {1} dummy_descriptor.dummy_list = [1, 2] - dummy_descriptor.map_collection_key = {'foo': 1, 'bar': 2} - dummy_descriptor.link_or_else = {'type': LinkByUID.typ, "scope": "templates", "id": "density"} + dummy_descriptor.map_collection_key = {"foo": 1, "bar": 2} + 
dummy_descriptor.link_or_else = { + "type": LinkByUID.typ, + "scope": "templates", + "id": "density", + } dummy_descriptor.specified_mixed_list = [1] assert 1 in dummy_descriptor.specified_mixed_list @@ -92,6 +112,6 @@ def test_collection_setters(): dummy_descriptor.specified_mixed_list = [1, 2] assert 1.0 in dummy_descriptor.dummy_map - assert 'foo' in dummy_descriptor.map_collection_key + assert "foo" in dummy_descriptor.map_collection_key assert 1.0 in dummy_descriptor.dummy_set assert 1.0 in dummy_descriptor.dummy_list diff --git a/tests/_serialization/test_object_serialization.py b/tests/_serialization/test_object_serialization.py index 638e75a66..0e7f1c0e1 100644 --- a/tests/_serialization/test_object_serialization.py +++ b/tests/_serialization/test_object_serialization.py @@ -9,17 +9,21 @@ class UnserializableClass: """A dummy class that has no clear serialization or deserialization method.""" + def __init__(self, foo): self.foo = foo class SampleClass(Serializable): """A class to stress the deser scheme's ability to handle objects.""" - prop_string = String('prop_string.string', default='default') - prop_value = Object(BaseValue, 'prop_value') - prop_object = Optional(Object(UnserializableClass), 'prop_object') - def __init__(self, prop_string: str, prop_value: BaseValue, prop_object: Any = None): + prop_string = String("prop_string.string", default="default") + prop_value = Object(BaseValue, "prop_value") + prop_object = Optional(Object(UnserializableClass), "prop_object") + + def __init__( + self, prop_string: str, prop_value: BaseValue, prop_object: Any = None + ): self.prop_string = prop_string self.prop_value = prop_value self.prop_object = prop_object @@ -27,7 +31,7 @@ def __init__(self, prop_string: str, prop_value: BaseValue, prop_object: Any = N def test_gemd_object_serde(): """Test that an unspecified gemd object can be serialized and deserialized.""" - good_obj = SampleClass("Can be serialized", NominalReal(17, '')) + good_obj = SampleClass("Can be serialized", NominalReal(17, "")) copy = SampleClass.build(good_obj.dump()) assert copy.prop_value == good_obj.prop_value assert copy.prop_string == good_obj.prop_string @@ -35,37 +39,40 @@ def test_gemd_object_serde(): def test_default_nested_serde(): """Test that defaults work in nested dictionaries.""" - good_obj = SampleClass("Can be serialized", NominalReal(17, '')) + good_obj = SampleClass("Can be serialized", NominalReal(17, "")) data = good_obj.dump() # If 'prop_string.string' is a non-string, that's an error - data['prop_string']['string'] = 0 + data["prop_string"]["string"] = 0 with pytest.raises(ValueError): SampleClass.build(data) # If data['prop_string'] is an empty dictionary, then the default is used - data['prop_string'] = dict() - assert SampleClass.build(data).prop_string == 'default' + data["prop_string"] = dict() + assert SampleClass.build(data).prop_string == "default" # If `data` does not even have a 'prop_string' key, then the default is used - del data['prop_string'] - assert SampleClass.build(data).prop_string == 'default' + del data["prop_string"] + assert SampleClass.build(data).prop_string == "default" def test_bad_object_serde(): """Test that a 'mystery' object cannot be serialized.""" - bad_obj = SampleClass("Cannot be serialized", NominalReal(34, ''), UnserializableClass(1)) + bad_obj = SampleClass( + "Cannot be serialized", NominalReal(34, ""), UnserializableClass(1) + ) with pytest.raises(AttributeError): bad_obj.dump() def test_object_str_representation(): - assert "" == 
str(Object(NominalReal, 'foo')) + assert "" == str(Object(NominalReal, "foo")) def test_override_configurations(): """Check that weird override cases get caught.""" - class OverrideTestClass(Serializable['OverrideTestClass']): + + class OverrideTestClass(Serializable["OverrideTestClass"]): overridden_value = String("overridden_value", override=True) overridden_option = Optional(String(), "overridden_option", override=True) @@ -98,7 +105,7 @@ def initable(self): def required(self): return self._required - class OverrideTestClass(Serializable['OverrideTestClass'], BaseTestClass): + class OverrideTestClass(Serializable["OverrideTestClass"], BaseTestClass): no_key = Optional(String(), "no_key", override=True) initable = Optional(String(), "initable", override=True, use_init=True) required = String("required", override=True, use_init=True) @@ -135,11 +142,11 @@ def required(self, value): raise TypeError("magic_value") self._required = value - class BadClass(Serializable['BadClass'], TestClass): + class BadClass(Serializable["BadClass"], TestClass): required = String("required", override=True) optional = Optional(String(), "optional", use_init=True) - class GoodClass(Serializable['BadClass'], TestClass): + class GoodClass(Serializable["BadClass"], TestClass): required = Optional(String(), "required", override=True, use_init=True) optional = Optional(String(), "optional", use_init=True) diff --git a/tests/_serialization/test_resource.py b/tests/_serialization/test_resource.py index 78bb4a1af..e57d162b3 100644 --- a/tests/_serialization/test_resource.py +++ b/tests/_serialization/test_resource.py @@ -12,4 +12,4 @@ def test_module_ref_serialization(): ref_data = ref.dump() # Then - assert ref_data['module_uid'] == str(m_uid) + assert ref_data["module_uid"] == str(m_uid) diff --git a/tests/_serialization/test_simple_properties.py b/tests/_serialization/test_simple_properties.py index a5d916642..a7f185eab 100644 --- a/tests/_serialization/test_simple_properties.py +++ b/tests/_serialization/test_simple_properties.py @@ -19,9 +19,13 @@ Optional, String, Union, - UUID + UUID, +) +from citrine.informatics.predictor_evaluation_metrics import ( + PredictorEvaluationMetric, + RMSE, + CoverageProbability, ) -from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric, RMSE, CoverageProbability from citrine.resources.dataset import Dataset from ._data import ( VALID_SERIALIZATIONS, @@ -32,65 +36,65 @@ ) -@pytest.mark.parametrize('prop_type,value,serialized', VALID_SERIALIZATIONS) +@pytest.mark.parametrize("prop_type,value,serialized", VALID_SERIALIZATIONS) def test_simple_property_serde(prop_type, value, serialized): prop = prop_type() assert prop.deserialize(serialized) == value assert prop.serialize(value) == serialized -@pytest.mark.parametrize('prop_type,value', INVALID_INSTANCES) +@pytest.mark.parametrize("prop_type,value", INVALID_INSTANCES) def test_invalid_property_serialization(prop_type, value): prop = prop_type() with pytest.raises(Exception): prop.serialize(value) -@pytest.mark.parametrize('prop_type,serialized', INVALID_SERIALIZED_INSTANCES) +@pytest.mark.parametrize("prop_type,serialized", INVALID_SERIALIZED_INSTANCES) def test_invalid_property_deserialization(prop_type, serialized): prop = prop_type() with pytest.raises(Exception): prop.deserialize(serialized) -@pytest.mark.parametrize('prop_type,serialized', INVALID_DESERIALIZATION_TYPES) +@pytest.mark.parametrize("prop_type,serialized", INVALID_DESERIALIZATION_TYPES) def test_invalid_deserialization_type(prop_type, 
serialized): prop = prop_type() with pytest.raises(ValueError): prop.deserialize(serialized) -@pytest.mark.parametrize('prop_type,serialized', INVALID_DESERIALIZATION_TYPES) +@pytest.mark.parametrize("prop_type,serialized", INVALID_DESERIALIZATION_TYPES) def test_invalid_deserialization_type_with_base_class(prop_type, serialized): class BaseTest: pass prop = prop_type() - prop.serialization_path = 'ser_path' + prop.serialization_path = "ser_path" with pytest.raises(ValueError) as excinfo: prop.deserialize(serialized, base_class=BaseTest().__class__) # Check that the exception includes the calling class name and argument if not isinstance(prop, UUID): - assert 'BaseTest:ser_path' in str(excinfo.value) + assert "BaseTest:ser_path" in str(excinfo.value) -@pytest.mark.parametrize('prop_type,serialized', INVALID_DESERIALIZATION_TYPES) +@pytest.mark.parametrize("prop_type,serialized", INVALID_DESERIALIZATION_TYPES) def test_invalid_deserialization_type_with_dataset(prop_type, serialized): # Supplying a Daatset instance as the base_class should include it's # name in the exception value string (UUIDs are a special case) dset = Dataset(name="dset", summary="test dataset", description="description") prop = prop_type() - prop.serialization_path = 'ser_path' + prop.serialization_path = "ser_path" with pytest.raises(ValueError) as excinfo: prop.deserialize(serialized, base_class=dset.__class__) if not isinstance(prop, UUID): - assert 'Dataset:ser_path' in str(excinfo.value) + assert "Dataset:ser_path" in str(excinfo.value) -@pytest.mark.parametrize('prop_type,path,expected', VALID_STRINGS) +@pytest.mark.parametrize("prop_type,path,expected", VALID_STRINGS) def test_valid_property_deserialization(prop_type, path, expected): assert expected == str(prop_type(path)) @@ -101,19 +105,19 @@ def test_serialize_to_dict_error(): def test_valid_serialize_to_dict(): - assert {'my_foo': 100} == Integer('my_foo').serialize_to_dict({}, 100) + assert {"my_foo": 100} == Integer("my_foo").serialize_to_dict({}, 100) def test_serialize_dot_value_to_dict(): - assert {'my': {'foo': 100}} == Integer('my.foo').serialize_to_dict({}, 100) + assert {"my": {"foo": 100}} == Integer("my.foo").serialize_to_dict({}, 100) def test_set_int_property_from_string(): class Foo: - bar = Integer('bar') + bar = Integer("bar") f = Foo() - f.bar = '12' + f.bar = "12" assert 12 == f.bar @@ -134,7 +138,9 @@ def test_float_cannot_deserialize_bool(): def test_deserialize_string_datetime(): - assert arrow.get('2019-07-19T10:46:08+00:00').datetime == Datetime().deserialize('2019-07-19T10:46:08+00:00') + assert arrow.get("2019-07-19T10:46:08+00:00").datetime == Datetime().deserialize( + "2019-07-19T10:46:08+00:00" + ) def test_datetime_cannot_deserialize_float(): @@ -149,14 +155,14 @@ def test_mixed_list_requires_property_list(): def test_deserialize_mixed_list(): ml = SpecifiedMixedList([Integer, String]) - assert [1, '2'] == ml.deserialize([1, '2']) + assert [1, "2"] == ml.deserialize([1, "2"]) assert [1, None] == ml.deserialize([1]) def test_mixed_list_cannot_deserialize_larger_lists(): ml = SpecifiedMixedList([Integer]) with pytest.raises(ValueError): - ml.deserialize([1, '2']) + ml.deserialize([1, "2"]) with pytest.raises(ValueError): ml.deserialize([1, 2]) @@ -164,7 +170,7 @@ def test_mixed_list_cannot_deserialize_larger_lists(): def test_mixed_list_cannot_serialize_larger_lists(): ml = SpecifiedMixedList([Integer]) with pytest.raises(ValueError): - ml.serialize([1, '2']) + ml.serialize([1, "2"]) with pytest.raises(ValueError): ml.serialize([1, 
2]) @@ -231,7 +237,7 @@ class Foo: obj = Object(Foo) with pytest.raises(AttributeError): - obj.deserialize({'key': 'value'}) + obj.deserialize({"key": "value"}) def test_linkorelse_deserialize_requires_serializable(): @@ -243,28 +249,32 @@ def test_linkorelse_deserialize_requires_serializable(): def test_linkorelse_deserialize_requires_scope_and_id(): loe = LinkOrElse() with pytest.raises(ValueError, match=r"missing.+required"): - loe.deserialize({'type': LinkByUID.typ}) + loe.deserialize({"type": LinkByUID.typ}) def test_linkorelse_raises_deep_errors(): loe = LinkOrElse() with pytest.raises(TypeError): - loe.deserialize({ - 'type': ProcessSpec.typ, - 'name': 'Badly structured', - 'conditions': [{'type': Condition.typ, "value": 'invalid structure'}], - }) + loe.deserialize( + { + "type": ProcessSpec.typ, + "name": "Badly structured", + "conditions": [{"type": Condition.typ, "value": "invalid structure"}], + } + ) def test_linkorelse_deserialize(): loe = LinkOrElse() - lbu = loe.deserialize({'type': LinkByUID.typ, 'scope': 'foo', 'id': str(uuid.uuid4())}) + lbu = loe.deserialize( + {"type": LinkByUID.typ, "scope": "foo", "id": str(uuid.uuid4())} + ) assert isinstance(lbu, LinkByUID) def test_optional_repr(): opt = Optional(String) - assert '] None>' == str(opt) + assert "] None>" == str(opt) def test_set_serialize_sortable(): diff --git a/tests/_serialization/test_taurus_interop.py b/tests/_serialization/test_taurus_interop.py index bb5e02d0d..d14dc7817 100644 --- a/tests/_serialization/test_taurus_interop.py +++ b/tests/_serialization/test_taurus_interop.py @@ -16,12 +16,11 @@ def test_flatten(): bounds = CategoricalBounds(categories=["foo", "bar"]) template = ProcessTemplate( - "spam", - conditions=[(ConditionTemplate(name="eggs", bounds=bounds), bounds)] + "spam", conditions=[(ConditionTemplate(name="eggs", bounds=bounds), bounds)] ) spec = ProcessSpec(name="spec", template=template) - flat = flatten(spec, scope='testing') + flat = flatten(spec, scope="testing") assert len(flat) == 3, "Expected 3 flattened objects" diff --git a/tests/_util/source_mod.py b/tests/_util/source_mod.py index 3651138b7..aaecd1781 100644 --- a/tests/_util/source_mod.py +++ b/tests/_util/source_mod.py @@ -1,2 +1,2 @@ -class ExampleClass(): - pass \ No newline at end of file +class ExampleClass: + pass diff --git a/tests/_util/test_batcher.py b/tests/_util/test_batcher.py index 1018bd2c0..7c5e5ea8c 100644 --- a/tests/_util/test_batcher.py +++ b/tests/_util/test_batcher.py @@ -1,13 +1,28 @@ import pytest - -from citrine._utils.batcher import Batcher - from gemd.demo.cake import make_cake from gemd.entity.link_by_uid import LinkByUID -from gemd.entity.object import * -from gemd.entity.template import * +from gemd.entity.object import ( + IngredientRun, + IngredientSpec, + MaterialRun, + MaterialSpec, + MeasurementRun, + MeasurementSpec, + ProcessRun, + ProcessSpec, +) +from gemd.entity.template import ( + ConditionTemplate, + MaterialTemplate, + MeasurementTemplate, + ParameterTemplate, + ProcessTemplate, + PropertyTemplate, +) from gemd.util import flatten, writable_sort_order +from citrine._utils.batcher import Batcher + def test_by_type(): """Test type batching.""" @@ -16,8 +31,9 @@ def test_by_type(): first = batcher.batch(flatten(cake), batch_size=10) assert all(len(batch) <= 10 for batch in first), "A batch was too long" for i in range(len(first) - 1): - assert max(writable_sort_order(x) for x in first[i]) \ - <= min(writable_sort_order(x) for x in first[i+1]), "Load order violated" + assert 
max(writable_sort_order(x) for x in first[i]) <= min( + writable_sort_order(x) for x in first[i + 1] + ), "Load order violated" assert len(flatten(cake)) == len({y for x in first for y in x}), "Object missing" assert len(flatten(cake)) == len([y for x in first for y in x]), "Object repeated" @@ -30,7 +46,7 @@ def test_by_type(): with pytest.raises(ValueError): bad = [ ProcessSpec(name="One", uids={"bad": "id"}), - ProcessSpec(name="Two", uids={"bad": "id"}) + ProcessSpec(name="Two", uids={"bad": "id"}), ] batcher.batch(bad, batch_size=10) @@ -77,16 +93,20 @@ def test_by_dependency(): elif isinstance(obj, ProcessRun): assert obj.spec in derefs, "Spec wasn't in batch" for x in obj.parameters: - assert(x.template in derefs), "Referenced parameter wasn't in batch" + assert x.template in derefs, "Referenced parameter wasn't in batch" for x in obj.conditions: - assert(x.template in derefs), "Referenced condition wasn't in batch" + assert x.template in derefs, "Referenced condition wasn't in batch" elif isinstance(obj, MaterialSpec): assert obj.template in derefs, "Template wasn't in batch" assert obj.process in derefs, "Process wasn't in batch" for x in obj.properties: - assert x.property.template in derefs, "Referenced property wasn't in batch" + assert x.property.template in derefs, ( + "Referenced property wasn't in batch" + ) for y in x.conditions: - assert y.template in derefs, "Referenced condition wasn't in batch" + assert y.template in derefs, ( + "Referenced condition wasn't in batch" + ) elif isinstance(obj, MaterialRun): assert obj.spec in derefs, "Spec wasn't in batch" assert obj.process in derefs, "Process wasn't in batch" @@ -112,7 +132,9 @@ def test_by_dependency(): assert x.template in derefs, "Referenced condition wasn't in batch" for x in obj.properties: assert x.template in derefs, "Referenced property wasn't in batch" - elif isinstance(obj, (PropertyTemplate, ConditionTemplate, ParameterTemplate)): + elif isinstance( + obj, (PropertyTemplate, ConditionTemplate, ParameterTemplate) + ): pass # These objects don't reference other objects else: pytest.fail(f"Unhandled type in batch: {type(obj)}") diff --git a/tests/_util/test_functions.py b/tests/_util/test_functions.py index 803c99484..1e5a4a649 100644 --- a/tests/_util/test_functions.py +++ b/tests/_util/test_functions.py @@ -6,63 +6,68 @@ from gemd.entity.bounds.real_bounds import RealBounds from gemd.entity.link_by_uid import LinkByUID -from citrine._utils.functions import get_object_id, validate_type, object_to_link_by_uid, \ - rewrite_s3_links_locally, write_file_locally, migrate_deprecated_argument, format_escaped_url, \ - MigratedClassMeta, generate_shared_meta +from citrine._utils.functions import ( + get_object_id, + validate_type, + object_to_link_by_uid, + rewrite_s3_links_locally, + write_file_locally, + migrate_deprecated_argument, + format_escaped_url, + MigratedClassMeta, + generate_shared_meta, +) from gemd.entity.attribute.property import Property from citrine.resources.condition_template import ConditionTemplate def test_get_object_id_from_base_attribute(): with pytest.raises(ValueError): - get_object_id(Property('some property')) + get_object_id(Property("some property")) def test_get_object_id_from_data_concepts(): uid = str(uuid.uuid4()) template = ConditionTemplate( - name='test', - bounds=RealBounds(0.0, 1.0, ''), - uids={'id': uid} + name="test", bounds=RealBounds(0.0, 1.0, ""), uids={"id": uid} ) assert uid == get_object_id(template) def test_get_object_id_from_data_concepts_id_is_none(): - template = 
ConditionTemplate( - name='test', - bounds=RealBounds(0.0, 1.0, '') - ) + template = ConditionTemplate(name="test", bounds=RealBounds(0.0, 1.0, "")) with pytest.raises(ValueError): - template.uids = {'id': None} + template.uids = {"id": None} def test_get_object_id_link_by_uid_bad_scope(): with pytest.raises(ValueError): - get_object_id(LinkByUID('bad_scope', '123')) + get_object_id(LinkByUID("bad_scope", "123")) def test_get_object_id_wrong_type(): with pytest.raises(TypeError): - get_object_id('no id here') + get_object_id("no id here") def test_validate_type_wrong_type(): with pytest.raises(Exception): - validate_type({'type': 'int'}, 'foo') + validate_type({"type": "int"}, "foo") def test_validate_type_set_type(): - assert {'type': 'int'} == validate_type({}, 'int') + assert {"type": "int"} == validate_type({}, "int") def test_object_to_link_by_uid_missing_uids(): - assert {'foo': 'bar'} == object_to_link_by_uid({'foo': 'bar'}) + assert {"foo": "bar"} == object_to_link_by_uid({"foo": "bar"}) def test_rewrite_s3_links_locally(): - assert "http://localhost:9566" == rewrite_s3_links_locally("http://localstack:4566", "http://localhost:9566") + assert "http://localhost:9566" == rewrite_s3_links_locally( + "http://localstack:4566", "http://localhost:9566" + ) def test_write_file_locally(tmpdir): @@ -101,16 +106,19 @@ def test_migrated_class(): with warnings.catch_warnings(): warnings.simplefilter("error") - class MigratedProperty(Property, - deprecated_in="1.2.3", - removed_in="2.0.0", - metaclass=generate_shared_meta(Property)): + class MigratedProperty( + Property, + deprecated_in="1.2.3", + removed_in="2.0.0", + metaclass=generate_shared_meta(Property), + ): pass with pytest.deprecated_call(): MigratedProperty(name="I'm a property!") with pytest.deprecated_call(): + class DerivedProperty(MigratedProperty): pass @@ -132,13 +140,15 @@ class IndependentProperty(Property): assert isinstance(Property("Property Name"), MigratedProperty) with pytest.raises(TypeError, match="deprecated_in"): + class NoVersionInfo(Property, metaclass=generate_shared_meta(Property)): pass with pytest.raises(TypeError, match="precisely"): - class NoParent(deprecated_in="1.2.3", - removed_in="2.0.0", - metaclass=MigratedClassMeta): + + class NoParent( + deprecated_in="1.2.3", removed_in="2.0.0", metaclass=MigratedClassMeta + ): pass assert generate_shared_meta(dict) is MigratedClassMeta @@ -156,10 +166,9 @@ def test_recursive_subtype_recovery(): class Simple(abc.ABC): pass - class MigratedProperty(Simple, - deprecated_in="1.2.3", - removed_in="2.0.0", - metaclass=MigratedClassMeta): + class MigratedProperty( + Simple, deprecated_in="1.2.3", removed_in="2.0.0", metaclass=MigratedClassMeta + ): pass assert not issubclass(dict, Simple) @@ -173,7 +182,9 @@ def test_migrate_deprecated_argument(): with pytest.warns(DeprecationWarning): with pytest.raises(ValueError): # ValueError if both arguments are specified - migrate_deprecated_argument("something", "new name", "something else", "old name") + migrate_deprecated_argument( + "something", "new name", "something else", "old name" + ) # Return the value if the new argument is specified assert migrate_deprecated_argument(14, "new name", None, "old name") == 14 @@ -186,10 +197,12 @@ def test_migrate_deprecated_argument(): def test_format_escaped_url(): - url = format_escaped_url('http://base.com/{}/{}/{word1}/{word2}', 1, '&', word1='fine', word2='+/?#') - assert 'http://base.com/' in url - assert 'fine' in url - assert '1' in url - for c in '&' + '+?#': + url = 
format_escaped_url( + "http://base.com/{}/{}/{word1}/{word2}", 1, "&", word1="fine", word2="+/?#" + ) + assert "http://base.com/" in url + assert "fine" in url + assert "1" in url + for c in "&" + "+?#": assert c not in url - assert 6 == sum(c == '/' for c in url) + assert 6 == sum(c == "/" for c in url) diff --git a/tests/_util/test_replace_object_with_link.py b/tests/_util/test_replace_object_with_link.py index 09f34059e..fed2221d7 100644 --- a/tests/_util/test_replace_object_with_link.py +++ b/tests/_util/test_replace_object_with_link.py @@ -1,54 +1,47 @@ """Tests of the functions that replace objects with Links.""" + from citrine._utils.functions import replace_objects_with_links def test_simple_replacement(): """A top-level object should turn into a link-by-uid.""" json = dict( - key='value', - object=dict( - type='material_run', - uids={'my_id': '1', 'id': '17'} - ) + key="value", object=dict(type="material_run", uids={"my_id": "1", "id": "17"}) ) replaced_json = replace_objects_with_links(json) - assert replaced_json == {'key': 'value', - 'object': {'type': 'link_by_uid', 'scope': 'id', 'id': '17'}} + assert replaced_json == { + "key": "value", + "object": {"type": "link_by_uid", "scope": "id", "id": "17"}, + } def test_nested_replacement(): """A list of objects should turn into a list of link-by-uids.""" json = dict( - object=[dict(type='material_run', uids={'my_id': '1'}), - dict(type='material_run', uids={'my_id': '2'})] + object=[ + dict(type="material_run", uids={"my_id": "1"}), + dict(type="material_run", uids={"my_id": "2"}), + ] ) replaced_json = replace_objects_with_links(json) - assert replaced_json == {'object': [{'type': 'link_by_uid', 'scope': 'my_id', 'id': '1'}, - {'type': 'link_by_uid', 'scope': 'my_id', 'id': '2'}]} + assert replaced_json == { + "object": [ + {"type": "link_by_uid", "scope": "my_id", "id": "1"}, + {"type": "link_by_uid", "scope": "my_id", "id": "2"}, + ] + } def test_failed_replacement(): """An object that does not have a type and a uids dictionary should not be replaced.""" - json = dict(object=dict( - some_field='material_run', - uids={'my_id': '1', 'id': '17'} - )) + json = dict(object=dict(some_field="material_run", uids={"my_id": "1", "id": "17"})) assert json == replace_objects_with_links(json) # no type field - json = dict(object=dict( - type='material_run', - uids='a uid string' - )) + json = dict(object=dict(type="material_run", uids="a uid string")) assert json == replace_objects_with_links(json) # uids is not a dictionary - json = dict(object=dict( - type='material_run', - some_field={'my_id': '1', 'id': '17'} - )) + json = dict(object=dict(type="material_run", some_field={"my_id": "1", "id": "17"})) assert json == replace_objects_with_links(json) # no uids field - json = dict(object=dict( - type='material_run', - uids={} - )) + json = dict(object=dict(type="material_run", uids={})) assert json == replace_objects_with_links(json) # uids is an empty dictionary diff --git a/tests/_util/test_scrub_none.py b/tests/_util/test_scrub_none.py index b693f3426..9c5ff148b 100644 --- a/tests/_util/test_scrub_none.py +++ b/tests/_util/test_scrub_none.py @@ -1,42 +1,28 @@ """Tests of the method that removes None values from object dictionaries.""" + from citrine._utils.functions import scrub_none def test_scrub_none(): """Test that scrub_none() when applied to some examples yields expected results.""" - json = dict( - key1=1, - key2=None - ) + json = dict(key1=1, key2=None) scrub_none(json) assert json == dict(key1=1) json = dict( - key1=dict( - 
key11='foo', - key12=None - ), - key2=[ - dict(key21=None, key22=17), - dict(key23=None), - dict(key24=34, key25=51) - ] + key1=dict(key11="foo", key12=None), + key2=[dict(key21=None, key22=17), dict(key23=None), dict(key24=34, key25=51)], ) scrub_none(json) assert json == dict( - key1=dict(key11='foo'), - key2=[dict(key22=17), dict(), dict(key24=34, key25=51)] + key1=dict(key11="foo"), key2=[dict(key22=17), dict(), dict(key24=34, key25=51)] ) json = dict( - key1=1, - key2=[None, 'foo', None, None, 'bar', None], - key3=[None, None, None] + key1=1, key2=[None, "foo", None, None, "bar", None], key3=[None, None, None] ) scrub_none(json) # None should not be removed from lists assert json == dict( - key1=1, - key2=[None, 'foo', None, None, 'bar', None], - key3=[None, None, None] + key1=1, key2=[None, "foo", None, None, "bar", None], key3=[None, None, None] ) diff --git a/tests/_util/test_template_util.py b/tests/_util/test_template_util.py index c0d869d78..e2446aa7b 100644 --- a/tests/_util/test_template_util.py +++ b/tests/_util/test_template_util.py @@ -1,138 +1,130 @@ -from citrine._utils.template_util import make_attribute_table -from gemd.entity.object import * -from gemd.entity.attribute import * -from gemd.entity.value import * +from gemd.entity.attribute import ( + Condition, + Parameter, + Property, + PropertyAndConditions, +) from gemd.entity.link_by_uid import LinkByUID +from gemd.entity.object import ( + IngredientSpec, + MaterialSpec, + MeasurementRun, + MeasurementSpec, + ProcessRun, + ProcessSpec, +) +from gemd.entity.value import ( + InChI, + NominalCategorical, + NominalReal, + NormalReal, +) + +from citrine._utils.template_util import make_attribute_table + def _make_list_of_gems(): faux_gems = [ ProcessSpec( - name = "hello world", - parameters = [ - Parameter( - name = "param 1", - value = NominalReal(nominal=4.2, units="g") - ), + name="hello world", + parameters=[ + Parameter(name="param 1", value=NominalReal(nominal=4.2, units="g")), + Parameter(name="param 2", value=NominalCategorical(category="foo")), Parameter( - name = "param 2", - value = NominalCategorical(category="foo") + name="attr 1", + value=InChI( + inchi="InChI=1S/C8H10N4O2/c1-10-4-9-6-5(10)7(13)12(3)8(14)11(6)2/h4H,1-3H3" + ), ), - Parameter( - name = "attr 1", - value = InChI(inchi="InChI=1S/C8H10N4O2/c1-10-4-9-6-5(10)7(13)12(3)8(14)11(6)2/h4H,1-3H3") - ) ], - conditions = [ - Condition( - name = "cond 1", - value = NormalReal(mean=4, std=0.5, units="") - ) - ] + conditions=[ + Condition(name="cond 1", value=NormalReal(mean=4, std=0.5, units="")) + ], ), IngredientSpec( - name = "I shouldn't be a row", - material=LinkByUID(scope = "faux", id = "abcde"), - process=LinkByUID(scope = "foo", id = "bar") + name="I shouldn't be a row", + material=LinkByUID(scope="faux", id="abcde"), + process=LinkByUID(scope="foo", id="bar"), ), ProcessRun( - name = "process 1", - spec = ProcessSpec( - name = "nestled Spec", + name="process 1", + spec=ProcessSpec( + name="nestled Spec", conditions=[ Condition( - name = "cond 1", - value = NormalReal(mean=6, std=0.3, units="") + name="cond 1", value=NormalReal(mean=6, std=0.3, units="") ), - ] + ], ), - parameters = [ + parameters=[ Parameter( - name = "param 1", - value = NormalReal(mean=4.2, std = 0.1, units="g") + name="param 1", value=NormalReal(mean=4.2, std=0.1, units="g") ), - Parameter( - name = "param 3", - value = NominalCategorical(category="bar") - ) + Parameter(name="param 3", value=NominalCategorical(category="bar")), ], - conditions = [ - Condition( - name = "cond 
1", - value = NormalReal(mean=4, std=0.5, units="") - ), - Condition( - name = "cond 2", - value = NominalCategorical(category="hi") - ), + conditions=[ + Condition(name="cond 1", value=NormalReal(mean=4, std=0.5, units="")), + Condition(name="cond 2", value=NominalCategorical(category="hi")), Condition( - name = "attr 1", - value = InChI(inchi="InChI=1S/C34H34N4O4.Fe/c1-7-21-17(3)25-13-26-19(5)23(9-11-33(39)40)31(37-26)16-32-24(10-12-34(41)42)20(6)28(38-32)15-30-22(8-2)18(4)27(36-30)14-29(21)35-25;/h7-8,13-16H,1-2,9-12H2,3-6H3,(H4,35,36,37,38,39,40,41,42);/q;+2/p-2") + name="attr 1", + value=InChI( + inchi="InChI=1S/C34H34N4O4.Fe/c1-7-21-17(3)25-13-26-19(5)23(9-11-33(39)40)31(37-26)16-32-24(10-12-34(41)42)20(6)28(38-32)15-30-22(8-2)18(4)27(36-30)14-29(21)35-25;/h7-8,13-16H,1-2,9-12H2,3-6H3,(H4,35,36,37,38,39,40,41,42);/q;+2/p-2" + ), ), - ] + ], ), MaterialSpec( - name = "material 1", - process = LinkByUID(scope = "faux 2", id = "id2"), + name="material 1", + process=LinkByUID(scope="faux 2", id="id2"), properties=[ PropertyAndConditions( property=Property( - name = "prop 1", - value = NormalReal(mean=100, std=10, units="g/cm**3") + name="prop 1", + value=NormalReal(mean=100, std=10, units="g/cm**3"), ), conditions=[ Condition( - name = "cond 2", - value = NominalCategorical(category="hi") + name="cond 2", value=NominalCategorical(category="hi") ) - ] + ], ), PropertyAndConditions( property=Property( - name = "prop 2", - value = NominalReal(nominal=33, units="1/lb") + name="prop 2", value=NominalReal(nominal=33, units="1/lb") ), conditions=[ Condition( - name = "cond 3", - value = NominalCategorical(category="citrine") + name="cond 3", value=NominalCategorical(category="citrine") ) - ] + ], ), - ] + ], ), MeasurementSpec( - name = "meas spec 1", - parameters = [ - Parameter( - name = "param 1", - value = NominalReal(nominal=2.2, units="kg") - ), - Parameter( - name = "param 2", - value = NominalCategorical(category="bar") - ) + name="meas spec 1", + parameters=[ + Parameter(name="param 1", value=NominalReal(nominal=2.2, units="kg")), + Parameter(name="param 2", value=NominalCategorical(category="bar")), ], ), MeasurementRun( - name = "meas run 1", - spec = LinkByUID(scope="another fake scope", id = "another fake id"), + name="meas run 1", + spec=LinkByUID(scope="another fake scope", id="another fake id"), properties=[ - Property( - name = "prop 1", - value=NominalReal(nominal=4.1, units="") - ) - ] - ) + Property(name="prop 1", value=NominalReal(nominal=4.1, units="")) + ], + ), ] return faux_gems + def test_attribute_alignment(): """Tests the make_attribute_table() method on a list of GEMD objects including nestled objects, confirming the expected values are being returned in the correct locations """ info_dict = make_attribute_table(_make_list_of_gems()) - assert(isinstance(info_dict, list)) - assert(isinstance(info_dict[0], dict)) + assert isinstance(info_dict, list) + assert isinstance(info_dict[0], dict) assert isinstance(info_dict[0]["PARAMETER: param 1"], NominalReal) assert isinstance(info_dict[1]["PARAMETER: param 1"], NormalReal) assert isinstance(info_dict[4]["PARAMETER: param 1"], NominalReal) diff --git a/tests/conftest.py b/tests/conftest.py index 39d2266af..fe33e7830 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,34 +6,29 @@ from citrine.informatics.predictors import AutoMLEstimator from citrine.resources.status_detail import StatusDetail, StatusLevelEnum -from tests.utils.factories import (PredictorEntityDataFactory, PredictorDataDataFactory, - 
PredictorMetadataDataFactory, StatusDataFactory) +from tests.utils.factories import ( + PredictorDataDataFactory, + PredictorEntityDataFactory, + PredictorMetadataDataFactory, + StatusDataFactory, +) def build_predictor_entity(instance, status_name="READY", status_detail=[]): user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( name=instance.get("name"), description=instance.get("description"), - instance=instance + instance=instance, ), metadata=dict( - status=dict( - name=status_name, - detail=status_detail - ), - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ) - ) + status=dict(name=status_name, detail=status_detail), + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + ), ) @@ -41,78 +36,70 @@ def build_predictor_entity(instance, status_name="READY", status_detail=[]): def valid_product_design_space_data(): """Produce valid product design space data.""" from citrine.informatics.descriptors import FormulationDescriptor + user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( - name='my design space', - description='does some things', + name="my design space", + description="does some things", instance=dict( - type='ProductDesignSpace', - name='my design space', - description='does some things', + type="ProductDesignSpace", + name="my design space", + description="does some things", subspaces=[ dict( - type='FormulationDesignSpace', - name='first subspace', - description='', + type="FormulationDesignSpace", + name="first subspace", + description="", formulation_descriptor=FormulationDescriptor.hierarchical().dump(), - ingredients=['foo'], - labels={'bar': ['foo']}, + ingredients=["foo"], + labels={"bar": ["foo"]}, constraints=[], - resolution=0.1 + resolution=0.1, ), dict( - type='FormulationDesignSpace', - name='second subspace', - description='formulates some things', + type="FormulationDesignSpace", + name="second subspace", + description="formulates some things", formulation_descriptor=FormulationDescriptor.hierarchical().dump(), - ingredients=['baz'], + ingredients=["baz"], labels={}, constraints=[], - resolution=0.1 - ) + resolution=0.1, + ), ], dimensions=[ dict( - type='ContinuousDimension', + type="ContinuousDimension", descriptor=dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ), lower_bound=6.0, - upper_bound=7.0 + upper_bound=7.0, ), dict( - type='EnumeratedDimension', + type="EnumeratedDimension", descriptor=dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - list=['red'] - ) - ] - ) + list=["red"], + ), + ], + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @@ -120,58 +107,43 @@ def valid_product_design_space_data(): def valid_enumerated_design_space_data(): """Produce valid enumerated design space data.""" user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( - name='my enumerated 
design space', - description='enumerates some things', + name="my enumerated design space", + description="enumerates some things", instance=dict( - type='EnumeratedDesignSpace', - name='my enumerated design space', - description='enumerates some things', + type="EnumeratedDesignSpace", + name="my enumerated design space", + description="enumerates some things", descriptors=[ dict( - type='Real', - descriptor_key='x', - units='', + type="Real", + descriptor_key="x", + units="", lower_bound=1.0, upper_bound=2.0, ), dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - dict( - type='Inorganic', - descriptor_key='formula' - ) + dict(type="Inorganic", descriptor_key="formula"), ], data=[ - dict(x='1', color='red', formula='C44H54Si2'), - dict(x='2.0', color='green', formula='V2O3') - ] - ) + dict(x="1", color="red", formula="C44H54Si2"), + dict(x="2.0", color="green", formula="V2O3"), + ], + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + archived=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @@ -180,129 +152,105 @@ def valid_formulation_design_space_data(): """Produce valid formulation design space data.""" from citrine.informatics.constraints import IngredientCountConstraint from citrine.informatics.descriptors import FormulationDescriptor + descriptor = FormulationDescriptor.hierarchical() - constraint = IngredientCountConstraint(formulation_descriptor=descriptor, min=0, max=1) + constraint = IngredientCountConstraint( + formulation_descriptor=descriptor, min=0, max=1 + ) user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=dict( - name='formulation design space', - description='formulates some things', + name="formulation design space", + description="formulates some things", instance=dict( - type='FormulationDesignSpace', - name='formulation design space', - description='formulates some things', + type="FormulationDesignSpace", + name="formulation design space", + description="formulates some things", formulation_descriptor=descriptor.dump(), - ingredients=['foo'], - labels={'bar': ['foo']}, + ingredients=["foo"], + labels={"bar": ["foo"]}, constraints=[constraint.dump()], - resolution=0.1 - ) + resolution=0.1, + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + archived=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @pytest.fixture def valid_hierarchical_design_space_data( - valid_material_node_definition_data, - valid_gem_data_source_dict + valid_material_node_definition_data, valid_gem_data_source_dict ): """Produce valid hierarchical design space data.""" import copy - name = 'hierarchical design space' - description = 'does things but in levels' + + name = "hierarchical design space" + description = "does things but in levels" user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return 
dict( id=str(uuid.uuid4()), data=dict( name=name, description=description, instance=dict( - type='HierarchicalDesignSpace', + type="HierarchicalDesignSpace", name=name, description=description, root=copy.deepcopy(valid_material_node_definition_data), subspaces=[copy.deepcopy(valid_material_node_definition_data)], - data_sources=[valid_gem_data_source_dict] - ) + data_sources=[valid_gem_data_source_dict], + ), ), metadata=dict( - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ), - archived=dict( - user=user, - time=time - ), - status=dict( - name='VALIDATING', - detail=[] - ) - ) + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + archived=dict(user=user, time=time), + status=dict(name="VALIDATING", detail=[]), + ), ) @pytest.fixture def valid_material_node_definition_data(valid_formulation_design_space_data): return dict( - identifier=dict( - id=f"Material Node-{uuid.uuid4()}", - scope="Custom Scope" - ), + identifier=dict(id=f"Material Node-{uuid.uuid4()}", scope="Custom Scope"), attributes=[ dict( - type='ContinuousDimension', + type="ContinuousDimension", descriptor=dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ), lower_bound=6.0, - upper_bound=7.0 + upper_bound=7.0, ), dict( - type='EnumeratedDimension', + type="EnumeratedDimension", descriptor=dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - list=['red'] - ) + list=["red"], + ), ], formulation=valid_formulation_design_space_data["data"]["instance"], template=dict( material_template=str(uuid.uuid4()), process_template=str(uuid.uuid4()), ), - display_name="Material Node" + display_name="Material Node", ) @@ -310,8 +258,8 @@ def valid_material_node_definition_data(valid_formulation_design_space_data): def valid_gem_data_source_dict(): return { "type": "hosted_table_data_source", - "table_id": 'e5c51369-8e71-4ec6-b027-1f92bdc14762', - "table_version": 2 + "table_id": "e5c51369-8e71-4ec6-b027-1f92bdc14762", + "table_version": 2, } @@ -319,40 +267,44 @@ def valid_gem_data_source_dict(): def valid_auto_ml_predictor_data(valid_gem_data_source_dict): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="") return dict( - type='AutoML', - name='AutoML predictor', - description='Predicts z from input x', + type="AutoML", + name="AutoML predictor", + description="Predicts z from input x", inputs=[x.dump()], outputs=[z.dump()], estimators=[AutoMLEstimator.RANDOM_FOREST.value], - training_data=[] + training_data=[], ) @pytest.fixture def valid_graph_predictor_data( - valid_simple_mixture_predictor_data, - valid_label_fractions_predictor_data, - valid_expression_predictor_data, - valid_mean_property_predictor_data, - valid_auto_ml_predictor_data + valid_simple_mixture_predictor_data, + valid_label_fractions_predictor_data, + valid_expression_predictor_data, + valid_mean_property_predictor_data, + valid_auto_ml_predictor_data, ): """Produce valid data used for tests.""" from citrine.informatics.data_sources import GemTableDataSource + instance = dict( - name='Graph predictor', - description='description', + name="Graph predictor", + description="description", 
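+        # Each entry below is one sub-predictor fixture payload composed into the graph.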
predictors=[ valid_simple_mixture_predictor_data, valid_label_fractions_predictor_data, valid_expression_predictor_data, valid_mean_property_predictor_data, - valid_auto_ml_predictor_data + valid_auto_ml_predictor_data, + ], + training_data=[ + GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump() ], - training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=0).dump()] ) return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance)) @@ -361,11 +313,11 @@ def valid_graph_predictor_data( def valid_graph_predictor_data_empty(): """Another predictor valid data used for tests.""" instance = dict( - type='Graph', - name='Empty Graph predictor', - description='description', + type="Graph", + name="Empty Graph predictor", + description="description", predictors=[], - training_data=[] + training_data=[], ) return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance)) @@ -374,17 +326,20 @@ def valid_graph_predictor_data_empty(): def valid_deprecated_expression_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor - shear_modulus = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa') + + shear_modulus = RealDescriptor( + "Property~Shear modulus", lower_bound=0, upper_bound=100, units="GPa" + ) return dict( - type='Expression', - name='Expression predictor', - description='Computes shear modulus from Youngs modulus and Poissons ratio', - expression='Y / (2 * (1 + v))', + type="Expression", + name="Expression predictor", + description="Computes shear modulus from Youngs modulus and Poissons ratio", + expression="Y / (2 * (1 + v))", output=shear_modulus.dump(), aliases={ - 'Y': "Property~Young's modulus", - 'v': "Property~Poisson's ratio", - } + "Y": "Property~Young's modulus", + "v": "Property~Poisson's ratio", + }, ) @@ -392,19 +347,26 @@ def valid_deprecated_expression_predictor_data(): def valid_expression_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor - shear_modulus = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa') - youngs_modulus = RealDescriptor('Property~Young\'s modulus', lower_bound=0, upper_bound=100, units='GPa') - poissons_ratio = RealDescriptor('Property~Poisson\'s ratio', lower_bound=-1, upper_bound=0.5, units='') + + shear_modulus = RealDescriptor( + "Property~Shear modulus", lower_bound=0, upper_bound=100, units="GPa" + ) + youngs_modulus = RealDescriptor( + "Property~Young's modulus", lower_bound=0, upper_bound=100, units="GPa" + ) + poissons_ratio = RealDescriptor( + "Property~Poisson's ratio", lower_bound=-1, upper_bound=0.5, units="" + ) return dict( - type='AnalyticExpression', - name='Expression predictor', - description='Computes shear modulus from Youngs modulus and Poissons ratio', - expression='Y / (2 * (1 + v))', + type="AnalyticExpression", + name="Expression predictor", + description="Computes shear modulus from Youngs modulus and Poissons ratio", + expression="Y / (2 * (1 + v))", output=shear_modulus.dump(), aliases={ - 'Y': youngs_modulus.dump(), - 'v': poissons_ratio.dump(), - } + "Y": youngs_modulus.dump(), + "v": poissons_ratio.dump(), + }, ) @@ -412,40 +374,39 @@ def valid_expression_predictor_data(): def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metrics): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + x = 
RealDescriptor("x", lower_bound=0, upper_bound=1, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=101, units="") return dict( - id='7c2dda5d-675a-41b6-829c-e485163f0e43', - module_id='31c7f311-6f3d-4a93-9387-94cc877f170c', - status='OK', - create_time='2020-04-23T15:46:26Z', - update_time='2020-04-23T15:46:26Z', + id="7c2dda5d-675a-41b6-829c-e485163f0e43", + module_id="31c7f311-6f3d-4a93-9387-94cc877f170c", + status="OK", + create_time="2020-04-23T15:46:26Z", + update_time="2020-04-23T15:46:26Z", report=dict( models=[ dict( - name='GeneralLoloModel_1', - type='ML Model', + name="GeneralLoloModel_1", + type="ML Model", inputs=[x.key], outputs=[y.key], - display_name='ML Model', + display_name="ML Model", model_settings=[ dict( - name='Algorithm', - value='Ensemble of non-linear estimators', + name="Algorithm", + value="Ensemble of non-linear estimators", children=[ - dict(name='Number of estimators', value=64, children=[]), - dict(name='Leaf model', value='Mean', children=[]), - dict(name='Use jackknife', value=True, children=[]) - ] + dict( + name="Number of estimators", value=64, children=[] + ), + dict(name="Leaf model", value="Mean", children=[]), + dict(name="Use jackknife", value=True, children=[]), + ], ) ], feature_importances=[ - dict( - response_key='y', - importances=dict(x=1.00), - top_features=5 - ) + dict(response_key="y", importances=dict(x=1.00), top_features=5) ], selection_summary=dict( n_folds=4, @@ -453,48 +414,56 @@ def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metr dict( model_settings=[ dict( - name='Algorithm', - value='Ensemble of non-linear estimators', + name="Algorithm", + value="Ensemble of non-linear estimators", children=[ - dict(name='Number of estimators', value=64, children=[]), - dict(name='Leaf model', value='Mean', children=[]), - dict(name='Use jackknife', value=True, children=[]) - ] + dict( + name="Number of estimators", + value=64, + children=[], + ), + dict( + name="Leaf model", + value="Mean", + children=[], + ), + dict( + name="Use jackknife", + value=True, + children=[], + ), + ], ) ], response_results=dict( response_name=dict( metrics=dict( predicted_vs_actual=example_categorical_pva_metrics, - f1=example_f1_metrics + f1=example_f1_metrics, ) ) - ) + ), ) - ] + ], ), - predictor_configuration_name="Predict y from x with ML" + predictor_configuration_name="Predict y from x with ML", ), dict( - name='GeneralLosslessModel_2', - type='Analytic Model', + name="GeneralLosslessModel_2", + type="Analytic Model", inputs=[x.key, y.key], outputs=[z.key], - display_name='GeneralLosslessModel_2', + display_name="GeneralLosslessModel_2", model_settings=[ - dict( - name="Expression", - value="(z) <- (x + y)", - children=[] - ) + dict(name="Expression", value="(z) <- (x + y)", children=[]) ], feature_importances=[], predictor_configuration_name="Expression for z", - predictor_configuration_uid="249bf32c-6f3d-4a93-9387-94cc877f170c" - ) + predictor_configuration_uid="249bf32c-6f3d-4a93-9387-94cc877f170c", + ), ], - descriptors=[x.dump(), y.dump(), z.dump()] - ) + descriptors=[x.dump(), y.dump(), z.dump()], + ), ) @@ -502,18 +471,23 @@ def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metr def valid_ing_formulation_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import RealDescriptor + return dict( - type='IngredientsToSimpleMixture', - name='Ingredients to formulation predictor', - 
description='Constructs mixtures from ingredients', + type="IngredientsToSimpleMixture", + name="Ingredients to formulation predictor", + description="Constructs mixtures from ingredients", id_to_quantity={ - 'water': RealDescriptor('water quantity', lower_bound=0, upper_bound=1, units="").dump(), - 'salt': RealDescriptor('salt quantity', lower_bound=0, upper_bound=1, units="").dump() + "water": RealDescriptor( + "water quantity", lower_bound=0, upper_bound=1, units="" + ).dump(), + "salt": RealDescriptor( + "salt quantity", lower_bound=0, upper_bound=1, units="" + ).dump(), }, labels={ - 'solvent': ['water'], - 'solute': ['salt'], - } + "solvent": ["water"], + "solute": ["salt"], + }, ) @@ -521,17 +495,18 @@ def valid_ing_formulation_predictor_data(): def valid_generalized_mean_property_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + formulation_descriptor = FormulationDescriptor.hierarchical() return dict( - type='GeneralizedMeanProperty', - name='Mean property predictor', - description='Computes mean ingredient properties', + type="GeneralizedMeanProperty", + name="Mean property predictor", + description="Computes mean ingredient properties", input=formulation_descriptor.dump(), - properties=['density'], + properties=["density"], p=2, impute_properties=True, - default_properties={'density': 1.0}, - label='solvent' + default_properties={"density": 1.0}, + label="solvent", ) @@ -539,19 +514,22 @@ def valid_generalized_mean_property_predictor_data(): def valid_mean_property_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor + formulation_descriptor = FormulationDescriptor.flat() - density = RealDescriptor(key='density', lower_bound=0, upper_bound=100, units='g/cm^3') + density = RealDescriptor( + key="density", lower_bound=0, upper_bound=100, units="g/cm^3" + ) return dict( - type='MeanProperty', - name='Mean property predictor', - description='Computes mean ingredient properties', + type="MeanProperty", + name="Mean property predictor", + description="Computes mean ingredient properties", input=formulation_descriptor.dump(), properties=[density.dump()], p=2.0, impute_properties=True, - default_properties={'density': 1.0}, - label='solvent', - training_data=[] + default_properties={"density": 1.0}, + label="solvent", + training_data=[], ) @@ -559,12 +537,13 @@ def valid_mean_property_predictor_data(): def valid_label_fractions_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + return dict( - type='LabelFractions', - name='Label fractions predictor', - description='Computes relative proportions of labeled ingredients', + type="LabelFractions", + name="Label fractions predictor", + description="Computes relative proportions of labeled ingredients", input=FormulationDescriptor.hierarchical().dump(), - labels=['solvent'] + labels=["solvent"], ) @@ -572,19 +551,20 @@ def valid_label_fractions_predictor_data(): def valid_ingredient_fractions_predictor_data(): """Produce valid data used for tests.""" from citrine.informatics.descriptors import FormulationDescriptor + return dict( - type='IngredientFractions', - name='Ingredient fractions predictor', - description='Computes ingredient fractions', + type="IngredientFractions", + name="Ingredient fractions predictor", + description="Computes ingredient fractions", 
         input=FormulationDescriptor.hierarchical().dump(),
-        ingredients=['Blue dye', 'Red dye']
+        ingredients=["Blue dye", "Red dye"],
     )
 
 
 @pytest.fixture
 def valid_data_source_design_space_dict(valid_gem_data_source_dict):
     user = str(uuid.uuid4())
-    time = '2020-04-23T15:46:26Z'
+    time = "2020-04-23T15:46:26Z"
     return dict(
         id=str(uuid.uuid4()),
         data=dict(
@@ -594,23 +574,14 @@
                 type="DataSourceDesignSpace",
                 name="Example valid data source design space",
                 description="Example valid data source design space based on a GEM Table Data Source.",
-                data_source=valid_gem_data_source_dict
-            )
+                data_source=valid_gem_data_source_dict,
+            ),
         ),
         metadata=dict(
-            created=dict(
-                user=user,
-                time=time
-            ),
-            updated=dict(
-                user=user,
-                time=time
-            ),
-            status=dict(
-                name='VALIDATING',
-                detail=[]
-            )
-        )
+            created=dict(user=user, time=time),
+            updated=dict(user=user, time=time),
+            status=dict(name="VALIDATING", detail=[]),
+        ),
     )
 
 
@@ -618,15 +589,16 @@
 def invalid_predictor_node_data():
     """Produce invalid valid data used for tests."""
     from citrine.informatics.descriptors import RealDescriptor
+
     x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
     y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
     z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
     return dict(
-        type='invalid',
-        name='my predictor',
-        description='does some things',
+        type="invalid",
+        name="my predictor",
+        description="does some things",
         inputs=[x.dump(), y.dump()],
-        output=z.dump()
+        output=z.dump(),
     )
 
 
@@ -634,23 +606,24 @@
 def invalid_graph_predictor_data():
     """Produce valid data used for tests."""
     from citrine.informatics.descriptors import RealDescriptor
+
     x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
     y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
-    z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="")
+
     instance = dict(
-        type='invalid',
-        name='my predictor',
-        description='does some things badly',
+        type="invalid",
+        name="my predictor",
+        description="does some things badly",
         predictors=[x.dump(), y.dump()],
     )
     detail = [
-        StatusDetail(level=StatusLevelEnum.WARNING, msg='Something is wrong'),
-        StatusDetail(level="Error", msg='Very wrong')
+        StatusDetail(level=StatusLevelEnum.WARNING, msg="Something is wrong"),
+        StatusDetail(level="Error", msg="Very wrong"),
     ]
-    status = StatusDataFactory(name='INVALID', detail=detail)
+    status = StatusDataFactory(name="INVALID", detail=detail)
     return PredictorEntityDataFactory(
         data=PredictorDataDataFactory(instance=instance),
-        meatadata=PredictorMetadataDataFactory(status=status)
+        metadata=PredictorMetadataDataFactory(status=status),
     )
 
 
@@ -658,10 +631,10 @@
 def valid_simple_mixture_predictor_data():
     """Produce valid data used for tests."""
     return dict(
-        type='SimpleMixture',
-        name='Simple mixture predictor',
-        description='simple mixture description',
-        training_data=[]
+        type="SimpleMixture",
+        name="Simple mixture predictor",
+        description="simple mixture description",
+        training_data=[],
     )
 
 
@@ -674,10 +647,8 @@
         "responses": ["salt?", "saltiness"],
         "n_folds": 6,
         "n_trials": 8,
-        "metrics": [
-            {"type": "PVA"}, {"type": "RMSE"}, {"type": "F1"}
-        ],
-        "ignore_when_grouping": ["temperature"]
+        "metrics": [{"type": "PVA"}, {"type": "RMSE"}, {"type": "F1"}],
+        
"ignore_when_grouping": ["temperature"], } @@ -689,24 +660,18 @@ def example_holdout_evaluator_dict(valid_gem_data_source_dict): "description": "", "responses": ["sweetness"], "data_source": valid_gem_data_source_dict, - "metrics": [{"type": "RMSE"}] + "metrics": [{"type": "RMSE"}], } + @pytest.fixture() def example_rmse_metrics(): - return { - "type": "RealMetricValue", - "mean": 0.4, - "standard_error": 0.12 - } + return {"type": "RealMetricValue", "mean": 0.4, "standard_error": 0.12} @pytest.fixture def example_f1_metrics(): - return { - "type": "RealMetricValue", - "mean": 0.3 - } + return {"type": "RealMetricValue", "mean": 0.3} @pytest.fixture @@ -722,15 +687,15 @@ def example_real_pva_metrics(): "predicted": { "type": "RealMetricValue", "mean": 1.0, - "standard_error": 0.12 + "standard_error": 0.12, }, "actual": { "type": "RealMetricValue", "mean": 1.2, - "standard_error": 0.0 - } + "standard_error": 0.0, + }, } - ] + ], } @@ -744,20 +709,21 @@ def example_categorical_pva_metrics(): "identifiers": ["Foo", "Bar"], "trial": 1, "fold": 3, - "predicted": { - "salt": 0.3, - "not salt": 0.7 - }, - "actual": { - "not salt": 1.0 - } + "predicted": {"salt": 0.3, "not salt": 0.7}, + "actual": {"not salt": 1.0}, } - ] + ], } @pytest.fixture() -def example_cv_result_dict(example_cv_evaluator_dict, example_rmse_metrics, example_categorical_pva_metrics, example_f1_metrics, example_real_pva_metrics): +def example_cv_result_dict( + example_cv_evaluator_dict, + example_rmse_metrics, + example_categorical_pva_metrics, + example_f1_metrics, + example_real_pva_metrics, +): return { "type": "CrossValidationResult", "evaluator": example_cv_evaluator_dict, @@ -765,16 +731,16 @@ def example_cv_result_dict(example_cv_evaluator_dict, example_rmse_metrics, exam "salt?": { "metrics": { "predicted_vs_actual": example_categorical_pva_metrics, - "f1": example_f1_metrics + "f1": example_f1_metrics, } }, "saltiness": { "metrics": { "predicted_vs_actual": example_real_pva_metrics, - "rmse": example_rmse_metrics + "rmse": example_rmse_metrics, } - } - } + }, + }, } @@ -783,13 +749,7 @@ def example_holdout_result_dict(example_holdout_evaluator_dict, example_rmse_met return { "type": "HoldoutSetResult", "evaluator": example_holdout_evaluator_dict, - "response_results": { - "sweetness": { - "metrics": { - "rmse": example_rmse_metrics - } - } - } + "response_results": {"sweetness": {"metrics": {"rmse": example_rmse_metrics}}}, } @@ -802,8 +762,8 @@ def sample_design_space_execution_dict(generic_entity): "status": { "major": ret.get("status"), "minor": ret.get("status_description"), - "detail": ret.get("status_detail") - } + "detail": ret.get("status_detail"), + }, } ) return ret @@ -812,30 +772,31 @@ def sample_design_space_execution_dict(generic_entity): @pytest.fixture() def example_design_material(): return { - 'vars': { - 'Temperature': {'type': 'R', 'm': 475.8, 's': 0}, - 'Flour': {'type': 'C', 'cp': {'flour': 100.0}}, - 'Water': {'type': 'M', 'q': {'water': 72.5}, 'l': {}}, - 'Salt': {'type': 'F', 'f': 'NaCl'}, - 'Yeast': {'type': 'S', 's': 'O1C=2C=C(C=3SC=C4C=CNC43)CC2C=5C=CC=6C=CNC6C15'} + "vars": { + "Temperature": {"type": "R", "m": 475.8, "s": 0}, + "Flour": {"type": "C", "cp": {"flour": 100.0}}, + "Water": {"type": "M", "q": {"water": 72.5}, "l": {}}, + "Salt": {"type": "F", "f": "NaCl"}, + "Yeast": { + "type": "S", + "s": "O1C=2C=C(C=3SC=C4C=CNC43)CC2C=5C=CC=6C=CNC6C15", + }, + }, + "identifiers": { + "id": str(uuid.uuid4()), + "identifiers": [], + "material_template": str(uuid.uuid4()), + 
"process_template": str(uuid.uuid4()), }, - 'identifiers': { - 'id': str(uuid.uuid4()), - 'identifiers': [], - 'material_template': str(uuid.uuid4()), - 'process_template': str(uuid.uuid4()) - } } @pytest.fixture() def example_hierarchical_design_material(example_design_material): return { - 'terminal': example_design_material, - 'sub_materials': [example_design_material], - 'mixtures': { - str(uuid.uuid4()): {'q': {'A': 0.5, 'B': 0.5}, 'l': {}} - } + "terminal": example_design_material, + "sub_materials": [example_design_material], + "mixtures": {str(uuid.uuid4()): {"q": {"A": 0.5, "B": 0.5}, "l": {}}}, } @@ -844,48 +805,53 @@ def example_hierarchical_candidates(example_hierarchical_design_material): return { "page": 2, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "primary_score": 0, - "rank": 1, - "material": example_hierarchical_design_material, - "name": "Example candidate", - "hidden": True, - "comments": [ - { - "message": "a message", - "created": { - "user": str(uuid.uuid4()), - "time": '2025-02-20T10:46:26Z' + "response": [ + { + "id": str(uuid.uuid4()), + "primary_score": 0, + "rank": 1, + "material": example_hierarchical_design_material, + "name": "Example candidate", + "hidden": True, + "comments": [ + { + "message": "a message", + "created": { + "user": str(uuid.uuid4()), + "time": "2025-02-20T10:46:26Z", + }, } - } - ] - }] + ], + } + ], } + @pytest.fixture() def example_candidates(example_design_material): return { "page": 2, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "material_id": str(uuid.uuid4()), - "identifiers": [], - "primary_score": 0, - "material": example_design_material, - "name": "Example candidate", - "hidden": True, - "comments": [ - { - "message": "a message", - "created": { - "user": str(uuid.uuid4()), - "time": '2025-02-20T10:46:26Z' + "response": [ + { + "id": str(uuid.uuid4()), + "material_id": str(uuid.uuid4()), + "identifiers": [], + "primary_score": 0, + "material": example_design_material, + "name": "Example candidate", + "hidden": True, + "comments": [ + { + "message": "a message", + "created": { + "user": str(uuid.uuid4()), + "time": "2025-02-20T10:46:26Z", + }, } - } - ] - }] + ], + } + ], } @@ -893,15 +859,16 @@ def example_candidates(example_design_material): def example_sample_design_space_response(example_hierarchical_design_material): return { "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "execution_id": str(uuid.uuid4()), - "material": example_hierarchical_design_material - }] + "response": [ + { + "id": str(uuid.uuid4()), + "execution_id": str(uuid.uuid4()), + "material": example_hierarchical_design_material, + } + ], } - @pytest.fixture def generic_entity(): user = str(uuid.uuid4()) @@ -912,8 +879,8 @@ def generic_entity(): "status_detail": [{"level": "Info", "msg": "System processing"}], "experimental": False, "experimental_reasons": [], - "create_time": '2020-04-23T15:46:26Z', - "update_time": '2020-04-23T15:46:26Z', + "create_time": "2020-04-23T15:46:26Z", + "update_time": "2020-04-23T15:46:26Z", "created_by": user, "updated_by": user, } @@ -922,31 +889,35 @@ def generic_entity(): @pytest.fixture def predictor_evaluation_execution_dict(generic_entity): ret = deepcopy(generic_entity) - ret.update({ - "workflow_id": str(uuid.uuid4()), - "predictor_id": str(uuid.uuid4()), - "predictor_version": random.randint(1, 10), - "evaluator_names": ["Example evaluator"] - }) + ret.update( + { + "workflow_id": str(uuid.uuid4()), + "predictor_id": str(uuid.uuid4()), + "predictor_version": random.randint(1, 
10), + "evaluator_names": ["Example evaluator"], + } + ) return ret @pytest.fixture def design_execution_dict(generic_entity): ret = generic_entity.copy() - ret.update({ - "workflow_id": str(uuid.uuid4()), - "version_number": 2, - "score": { - "type": "MLI", - "baselines": [], - "constraints": [], - "objectives": [], - "name": "score", - "description": "" - }, - "descriptors": [] - }) + ret.update( + { + "workflow_id": str(uuid.uuid4()), + "version_number": 2, + "score": { + "type": "MLI", + "baselines": [], + "constraints": [], + "objectives": [], + "name": "score", + "description": "", + }, + "descriptors": [], + } + ) return ret @@ -961,26 +932,31 @@ def example_generation_results(): return { "page": 1, "per_page": 4, - "response": [{ - "id": str(uuid.uuid4()), - "execution_id": str(uuid.uuid4()), - "result": { - "seed": "CCCCO", - "mutated": "CCCN", - "fingerprint_similarity": 0.41, - "fingerprint_type": "ECFP4", + "response": [ + { + "id": str(uuid.uuid4()), + "execution_id": str(uuid.uuid4()), + "result": { + "seed": "CCCCO", + "mutated": "CCCN", + "fingerprint_similarity": 0.41, + "fingerprint_type": "ECFP4", + }, } - }] + ], } - @pytest.fixture -def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict, example_holdout_evaluator_dict): +def predictor_evaluation_workflow_dict( + generic_entity, example_cv_evaluator_dict, example_holdout_evaluator_dict +): ret = deepcopy(generic_entity) - ret.update({ - "name": "Example PEW", - "description": "Example PEW for testing", - "evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict] - }) + ret.update( + { + "name": "Example PEW", + "description": "Example PEW for testing", + "evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict], + } + ) return ret diff --git a/tests/gemd_query/test_gemd_query.py b/tests/gemd_query/test_gemd_query.py index b302da482..7183a19fb 100644 --- a/tests/gemd_query/test_gemd_query.py +++ b/tests/gemd_query/test_gemd_query.py @@ -13,14 +13,14 @@ def test_gemd_query_version(): assert GemdQuery.build(valid) is not None invalid = GemdQueryDataFactory() - invalid['schema_version'] = 2 + invalid["schema_version"] = 2 with pytest.raises(ValueError): GemdQuery.build(invalid) def test_criteria_rebuild(): value_filter = AllRealFilter() - value_filter.unit = 'm' + value_filter.unit = "m" value_filter.lower = 0 value_filter.upper = 1 @@ -31,15 +31,27 @@ def test_criteria_rebuild(): query = GemdQuery() query.criteria.append(crit) query.datasets.add(uuid4()) - query.object_types = {'material_run'} + query.object_types = {"material_run"} query_copy = GemdQuery.build(query.dump()) assert len(query.criteria) == len(query_copy.criteria) - assert query.criteria[0].property_templates_filter == query_copy.criteria[0].property_templates_filter - assert query.criteria[0].value_type_filter.unit == query_copy.criteria[0].value_type_filter.unit - assert query.criteria[0].value_type_filter.lower == query_copy.criteria[0].value_type_filter.lower - assert query.criteria[0].value_type_filter.upper == query_copy.criteria[0].value_type_filter.upper + assert ( + query.criteria[0].property_templates_filter + == query_copy.criteria[0].property_templates_filter + ) + assert ( + query.criteria[0].value_type_filter.unit + == query_copy.criteria[0].value_type_filter.unit + ) + assert ( + query.criteria[0].value_type_filter.lower + == query_copy.criteria[0].value_type_filter.lower + ) + assert ( + query.criteria[0].value_type_filter.upper + == query_copy.criteria[0].value_type_filter.upper + ) 
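+    # The remaining top-level fields should survive the dump/build round trip.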
assert query.datasets == query_copy.datasets assert query.object_types == query_copy.object_types assert query.schema_version == query_copy.schema_version diff --git a/tests/gemtable/test_columns.py b/tests/gemtable/test_columns.py index 4dbb268d3..c9e10983a 100644 --- a/tests/gemtable/test_columns.py +++ b/tests/gemtable/test_columns.py @@ -1,25 +1,53 @@ """Tests for citrine.informatics.columns.""" + import pytest -from citrine.gemtables.columns import * +from citrine.gemtables.columns import ( + ChemicalDisplayFormat, + Column, + ComponentQuantityColumn, + CompositionSortOrder, + ConcatColumn, + FlatCompositionColumn, + IdentityColumn, + MeanColumn, + MolecularStructureColumn, + MostLikelyCategoryColumn, + MostLikelyProbabilityColumn, + NthBiggestComponentNameColumn, + NthBiggestComponentQuantityColumn, + OriginalUnitsColumn, + QuantileColumn, + StdDevColumn, +) from citrine.gemtables.variables import TerminalMaterialInfo -@pytest.fixture(params=[ - IdentityColumn(data_source="terminal name"), - MeanColumn(data_source="density", target_units="g/cm^3"), - StdDevColumn(data_source="density", target_units="g/cm^3"), - QuantileColumn(data_source="density", quantile=0.95), - OriginalUnitsColumn(data_source="density"), - MostLikelyCategoryColumn(data_source="color"), - MostLikelyProbabilityColumn(data_source="color"), - FlatCompositionColumn(data_source="formula", sort_order=CompositionSortOrder.QUANTITY), - ComponentQuantityColumn(data_source="formula", component_name="Si", normalize=True), - NthBiggestComponentNameColumn(data_source="formula", n=1), - NthBiggestComponentQuantityColumn(data_source="formula", n=2), - MolecularStructureColumn(data_source="molecule", format=ChemicalDisplayFormat.SMILES), - ConcatColumn(data_source="labels", subcolumn=IdentityColumn(data_source="terminal name")) -]) +@pytest.fixture( + params=[ + IdentityColumn(data_source="terminal name"), + MeanColumn(data_source="density", target_units="g/cm^3"), + StdDevColumn(data_source="density", target_units="g/cm^3"), + QuantileColumn(data_source="density", quantile=0.95), + OriginalUnitsColumn(data_source="density"), + MostLikelyCategoryColumn(data_source="color"), + MostLikelyProbabilityColumn(data_source="color"), + FlatCompositionColumn( + data_source="formula", sort_order=CompositionSortOrder.QUANTITY + ), + ComponentQuantityColumn( + data_source="formula", component_name="Si", normalize=True + ), + NthBiggestComponentNameColumn(data_source="formula", n=1), + NthBiggestComponentQuantityColumn(data_source="formula", n=2), + MolecularStructureColumn( + data_source="molecule", format=ChemicalDisplayFormat.SMILES + ), + ConcatColumn( + data_source="labels", subcolumn=IdentityColumn(data_source="terminal name") + ), + ] +) def column(request): return request.param @@ -46,10 +74,9 @@ def test_invalid_deser(): def test_data_source_args(): terminal_name = "terminal name" - var = TerminalMaterialInfo(name=terminal_name, - headers=[terminal_name], - field='NAME' - ) + var = TerminalMaterialInfo( + name=terminal_name, headers=[terminal_name], field="NAME" + ) IdentityColumn(data_source=terminal_name) IdentityColumn(data_source=var) with pytest.raises(TypeError): diff --git a/tests/gemtable/test_rows.py b/tests/gemtable/test_rows.py index 47006415d..ac51df7c3 100644 --- a/tests/gemtable/test_rows.py +++ b/tests/gemtable/test_rows.py @@ -1,22 +1,28 @@ """Tests for citrine.informatics.rows.""" + import pytest from citrine.gemtables.rows import MaterialRunByTemplate, Row from gemd.entity.link_by_uid import LinkByUID 
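+# Parametrize over rows with and without tag filters to cover both serialization paths.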
-@pytest.fixture(params=[ - MaterialRunByTemplate(templates=[ - LinkByUID(scope="templates", id="density"), LinkByUID(scope="templates", id="ingredients") - ]), - MaterialRunByTemplate(templates=[ - LinkByUID(scope="templates", id="density"), LinkByUID(scope="templates", id="ingredients") - ], - tags=[ - "foo::bar", "some::tag" - ] - ), -]) +@pytest.fixture( + params=[ + MaterialRunByTemplate( + templates=[ + LinkByUID(scope="templates", id="density"), + LinkByUID(scope="templates", id="ingredients"), + ] + ), + MaterialRunByTemplate( + templates=[ + LinkByUID(scope="templates", id="density"), + LinkByUID(scope="templates", id="ingredients"), + ], + tags=["foo::bar", "some::tag"], + ), + ] +) def row(request): return request.param diff --git a/tests/gemtable/test_variables.py b/tests/gemtable/test_variables.py index beb28cac6..11881d16a 100644 --- a/tests/gemtable/test_variables.py +++ b/tests/gemtable/test_variables.py @@ -1,32 +1,171 @@ """Tests for citrine.informatics.variables.""" + import pytest from gemd.entity.bounds.real_bounds import RealBounds - -from citrine.gemtables.variables import * from gemd.entity.link_by_uid import LinkByUID +from citrine.gemtables.variables import ( + XOR, + AttributeByTemplate, + AttributeByTemplateAfterProcessTemplate, + AttributeByTemplateAndObjectTemplate, + AttributeInOutput, + IngredientIdentifierByProcessTemplateAndName, + IngredientIdentifierInOutput, + IngredientLabelByProcessAndName, + IngredientLabelsSetByProcessAndName, + IngredientLabelsSetInOutput, + IngredientQuantityByProcessAndName, + IngredientQuantityDimension, + IngredientQuantityInOutput, + LocalAttribute, + LocalAttributeAndObject, + LocalIngredientIdentifier, + LocalIngredientLabelsSet, + LocalIngredientQuantity, + TerminalMaterialIdentifier, + TerminalMaterialInfo, + Variable, +) + -@pytest.fixture(params=[ - TerminalMaterialInfo(name="terminal name", headers=["Root", "Name"], field="name"), - XOR(name="terminal name or sample_type", headers=["Root", "Info"], variables=[TerminalMaterialInfo(name="terminal name", headers=["Root", "Name"], field="name"), TerminalMaterialInfo(name="terminal name", headers=["Root", "Sample Type"], field="sample_type")]), - AttributeByTemplate(name="density", headers=["density"], template=LinkByUID(scope="templates", id="density"), attribute_constraints=[[LinkByUID(scope="templates", id="density"), RealBounds(0, 100, "g/cm**3")]]), - AttributeByTemplateAfterProcessTemplate(name="density", headers=["density"], attribute_template=LinkByUID(scope="template", id="density"), process_template=LinkByUID(scope="template", id="process")), - AttributeByTemplateAndObjectTemplate(name="density", headers=["density"], attribute_template=LinkByUID(scope="template", id="density"), object_template=LinkByUID(scope="template", id="object")), - AttributeInOutput(name="density", headers=["density"], attribute_template=LinkByUID(scope="template", id="density"), process_templates=[LinkByUID(scope="template", id="object")]), - LocalAttribute(name="density", headers=["density"], template=LinkByUID(scope="templates", id="density"), attribute_constraints=[[LinkByUID(scope="templates", id="density"), RealBounds(0, 100, "g/cm**3")]]), - LocalAttributeAndObject(name="density", headers=["density"], template=LinkByUID(scope="templates", id="density"), object_template=LinkByUID(scope="templates", id="object"), attribute_constraints=[[LinkByUID(scope="templates", id="density"), RealBounds(0, 100, "g/cm**3")]]), - IngredientIdentifierByProcessTemplateAndName(name="ingredient id", 
headers=["density"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", scope="scope"), - IngredientIdentifierInOutput(name="ingredient id", headers=["ingredient id"], ingredient_name="ingredient", process_templates=[LinkByUID(scope="template", id="object")], scope="scope"), - LocalIngredientIdentifier(name="ingredient id", headers=["ingredient id"], ingredient_name="ingredient", scope="scope"), - IngredientLabelByProcessAndName(name="ingredient label", headers=["label"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", label="label"), - IngredientLabelsSetByProcessAndName(name="ingredient label", headers=["label"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient"), - IngredientLabelsSetInOutput(name="ingredient label", headers=["label"], process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient"), - LocalIngredientLabelsSet(name="ingredient label", headers=["label"], ingredient_name="ingredient"), - IngredientQuantityByProcessAndName(name="ingredient quantity dimension", headers=["quantity"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.ABSOLUTE, unit='kg'), - IngredientQuantityInOutput(name="ingredient quantity", headers=["ingredient quantity"], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.MASS, process_templates=[LinkByUID(scope="template", id="object")]), - LocalIngredientQuantity(name="ingredient quantity", headers=["ingredient quantity"], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.MASS), - TerminalMaterialIdentifier(name="terminal id", headers=["id"], scope="scope") -]) +@pytest.fixture( + params=[ + TerminalMaterialInfo( + name="terminal name", headers=["Root", "Name"], field="name" + ), + XOR( + name="terminal name or sample_type", + headers=["Root", "Info"], + variables=[ + TerminalMaterialInfo( + name="terminal name", headers=["Root", "Name"], field="name" + ), + TerminalMaterialInfo( + name="terminal name", + headers=["Root", "Sample Type"], + field="sample_type", + ), + ], + ), + AttributeByTemplate( + name="density", + headers=["density"], + template=LinkByUID(scope="templates", id="density"), + attribute_constraints=[ + [ + LinkByUID(scope="templates", id="density"), + RealBounds(0, 100, "g/cm**3"), + ] + ], + ), + AttributeByTemplateAfterProcessTemplate( + name="density", + headers=["density"], + attribute_template=LinkByUID(scope="template", id="density"), + process_template=LinkByUID(scope="template", id="process"), + ), + AttributeByTemplateAndObjectTemplate( + name="density", + headers=["density"], + attribute_template=LinkByUID(scope="template", id="density"), + object_template=LinkByUID(scope="template", id="object"), + ), + AttributeInOutput( + name="density", + headers=["density"], + attribute_template=LinkByUID(scope="template", id="density"), + process_templates=[LinkByUID(scope="template", id="object")], + ), + LocalAttribute( + name="density", + headers=["density"], + template=LinkByUID(scope="templates", id="density"), + attribute_constraints=[ + [ + LinkByUID(scope="templates", id="density"), + RealBounds(0, 100, "g/cm**3"), + ] + ], + ), + LocalAttributeAndObject( + name="density", + headers=["density"], + template=LinkByUID(scope="templates", id="density"), + object_template=LinkByUID(scope="templates", id="object"), + attribute_constraints=[ + [ + 
LinkByUID(scope="templates", id="density"), + RealBounds(0, 100, "g/cm**3"), + ] + ], + ), + IngredientIdentifierByProcessTemplateAndName( + name="ingredient id", + headers=["density"], + process_template=LinkByUID(scope="template", id="process"), + ingredient_name="ingredient", + scope="scope", + ), + IngredientIdentifierInOutput( + name="ingredient id", + headers=["ingredient id"], + ingredient_name="ingredient", + process_templates=[LinkByUID(scope="template", id="object")], + scope="scope", + ), + LocalIngredientIdentifier( + name="ingredient id", + headers=["ingredient id"], + ingredient_name="ingredient", + scope="scope", + ), + IngredientLabelByProcessAndName( + name="ingredient label", + headers=["label"], + process_template=LinkByUID(scope="template", id="process"), + ingredient_name="ingredient", + label="label", + ), + IngredientLabelsSetByProcessAndName( + name="ingredient label", + headers=["label"], + process_template=LinkByUID(scope="template", id="process"), + ingredient_name="ingredient", + ), + IngredientLabelsSetInOutput( + name="ingredient label", + headers=["label"], + process_templates=[LinkByUID(scope="template", id="process")], + ingredient_name="ingredient", + ), + LocalIngredientLabelsSet( + name="ingredient label", headers=["label"], ingredient_name="ingredient" + ), + IngredientQuantityByProcessAndName( + name="ingredient quantity dimension", + headers=["quantity"], + process_template=LinkByUID(scope="template", id="process"), + ingredient_name="ingredient", + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, + unit="kg", + ), + IngredientQuantityInOutput( + name="ingredient quantity", + headers=["ingredient quantity"], + ingredient_name="ingredient", + quantity_dimension=IngredientQuantityDimension.MASS, + process_templates=[LinkByUID(scope="template", id="object")], + ), + LocalIngredientQuantity( + name="ingredient quantity", + headers=["ingredient quantity"], + ingredient_name="ingredient", + quantity_dimension=IngredientQuantityDimension.MASS, + ), + TerminalMaterialIdentifier(name="terminal id", headers=["id"], scope="scope"), + ] +) def variable(request): return request.param @@ -57,7 +196,7 @@ def test_quantity_dimension_serializes_to_string(): headers=["quantity"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.NUMBER + quantity_dimension=IngredientQuantityDimension.NUMBER, ) variable_data = variable.dump() assert variable_data["quantity_dimension"] == "number" @@ -69,7 +208,7 @@ def test_absolute_units(): headers=["quantity"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.NUMBER + quantity_dimension=IngredientQuantityDimension.NUMBER, ) IngredientQuantityByProcessAndName( name="This should be fine, too", @@ -77,7 +216,7 @@ def test_absolute_units(): process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg' + unit="kg", ) with pytest.raises(ValueError): IngredientQuantityByProcessAndName( @@ -85,7 +224,7 @@ def test_absolute_units(): headers=["quantity"], process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", - quantity_dimension="bunk" + quantity_dimension="bunk", ) with pytest.raises(ValueError): IngredientQuantityByProcessAndName( @@ -93,7 +232,7 @@ def test_absolute_units(): headers=["quantity"], 
process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.ABSOLUTE + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, ) with pytest.raises(ValueError): IngredientQuantityByProcessAndName( @@ -102,7 +241,7 @@ def test_absolute_units(): process_template=LinkByUID(scope="template", id="process"), ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.NUMBER, - unit='kg' + unit="kg", ) # And again, for IngredientQuantityInOutput @@ -111,7 +250,7 @@ def test_absolute_units(): headers=["quantity"], process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.NUMBER + quantity_dimension=IngredientQuantityDimension.NUMBER, ) IngredientQuantityInOutput( name="This should be fine, too", @@ -119,7 +258,7 @@ def test_absolute_units(): process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg' + unit="kg", ) with pytest.raises(ValueError): IngredientQuantityInOutput( @@ -127,7 +266,7 @@ def test_absolute_units(): headers=["quantity"], process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient", - quantity_dimension="bunk" + quantity_dimension="bunk", ) with pytest.raises(ValueError): IngredientQuantityInOutput( @@ -135,7 +274,7 @@ def test_absolute_units(): headers=["quantity"], process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.ABSOLUTE + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, ) with pytest.raises(ValueError): IngredientQuantityInOutput( @@ -144,7 +283,7 @@ def test_absolute_units(): process_templates=[LinkByUID(scope="template", id="process")], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.NUMBER, - unit='kg' + unit="kg", ) # And again, for LocalIngredientQuantity @@ -152,28 +291,28 @@ def test_absolute_units(): name="This should be fine", headers=["quantity"], ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.NUMBER + quantity_dimension=IngredientQuantityDimension.NUMBER, ) LocalIngredientQuantity( name="This should be fine, too", headers=["quantity"], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg' + unit="kg", ) with pytest.raises(ValueError): LocalIngredientQuantity( name="Invalid quantity dimension as string", headers=["quantity"], ingredient_name="ingredient", - quantity_dimension="bunk" + quantity_dimension="bunk", ) with pytest.raises(ValueError): LocalIngredientQuantity( name="This needs units", headers=["quantity"], ingredient_name="ingredient", - quantity_dimension=IngredientQuantityDimension.ABSOLUTE + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, ) with pytest.raises(ValueError): LocalIngredientQuantity( @@ -181,5 +320,5 @@ def test_absolute_units(): headers=["quantity"], ingredient_name="ingredient", quantity_dimension=IngredientQuantityDimension.NUMBER, - unit='kg' + unit="kg", ) diff --git a/tests/informatics/test_constraints.py b/tests/informatics/test_constraints.py index 78fb32f69..76052d7ab 100644 --- a/tests/informatics/test_constraints.py +++ b/tests/informatics/test_constraints.py @@ -1,7 +1,16 @@ """Tests for citrine.informatics.constraints.""" + import pytest -from citrine.informatics.constraints import * +from 
citrine.informatics.constraints import ( + AcceptableCategoriesConstraint, + IngredientCountConstraint, + IngredientFractionConstraint, + IngredientRatioConstraint, + IntegerRangeConstraint, + LabelFractionConstraint, + ScalarRangeConstraint, +) from citrine.informatics.descriptors import FormulationDescriptor formulation_descriptor = FormulationDescriptor.hierarchical() @@ -11,10 +20,7 @@ def scalar_range_constraint() -> ScalarRangeConstraint: """Build a ScalarRangeConstraint.""" return ScalarRangeConstraint( - descriptor_key='z', - lower_bound=1.0, - upper_bound=10.0, - lower_inclusive=False + descriptor_key="z", lower_bound=1.0, upper_bound=10.0, lower_inclusive=False ) @@ -22,9 +28,7 @@ def scalar_range_constraint() -> ScalarRangeConstraint: def integer_range_constraint() -> IntegerRangeConstraint: """Build an IntegerRangeConstraint.""" return IntegerRangeConstraint( - descriptor_key='integer', - lower_bound=1, - upper_bound=10 + descriptor_key="integer", lower_bound=1, upper_bound=10 ) @@ -32,8 +36,7 @@ def integer_range_constraint() -> IntegerRangeConstraint: def categorical_constraint() -> AcceptableCategoriesConstraint: """Build a CategoricalConstraint.""" return AcceptableCategoriesConstraint( - descriptor_key='x', - acceptable_categories=['y', 'z'] + descriptor_key="x", acceptable_categories=["y", "z"] ) @@ -42,10 +45,10 @@ def ingredient_fraction_constraint() -> IngredientFractionConstraint: """Build an IngredientFractionConstraint.""" return IngredientFractionConstraint( formulation_descriptor=formulation_descriptor, - ingredient='foo', + ingredient="foo", min=0.0, max=1.0, - is_required=False + is_required=False, ) @@ -53,10 +56,7 @@ def ingredient_fraction_constraint() -> IngredientFractionConstraint: def ingredient_count_constraint() -> IngredientCountConstraint: """Build an IngredientCountConstraint.""" return IngredientCountConstraint( - formulation_descriptor=formulation_descriptor, - min=0, - max=1, - label='bar' + formulation_descriptor=formulation_descriptor, min=0, max=1, label="bar" ) @@ -65,10 +65,10 @@ def label_fraction_constraint() -> LabelFractionConstraint: """Build a LabelFractionConstraint.""" return LabelFractionConstraint( formulation_descriptor=formulation_descriptor, - label='bar', + label="bar", min=0.0, max=1.0, - is_required=False + is_required=False, ) @@ -82,13 +82,13 @@ def ingredient_ratio_constraint() -> IngredientRatioConstraint: ingredient=("foo", 1.0), label=("foolabel", 0.5), basis_ingredients=["baz", "bat"], - basis_labels=["bazlabel", "batlabel"] + basis_labels=["bazlabel", "batlabel"], ) def test_scalar_range_initialization(scalar_range_constraint): """Make sure the correct fields go to the correct places.""" - assert scalar_range_constraint.descriptor_key == 'z' + assert scalar_range_constraint.descriptor_key == "z" assert scalar_range_constraint.lower_bound == 1.0 assert scalar_range_constraint.upper_bound == 10.0 assert not scalar_range_constraint.lower_inclusive @@ -97,22 +97,24 @@ def test_scalar_range_initialization(scalar_range_constraint): def test_integer_range_initialization(integer_range_constraint): """Make sure the correct fields go to the correct places.""" - assert integer_range_constraint.descriptor_key == 'integer' + assert integer_range_constraint.descriptor_key == "integer" assert integer_range_constraint.lower_bound == 1 assert integer_range_constraint.upper_bound == 10 def test_categorical_initialization(categorical_constraint): """Make sure the correct fields go to the correct places.""" - assert 
categorical_constraint.descriptor_key == 'x' - assert categorical_constraint.acceptable_categories == ['y', 'z'] + assert categorical_constraint.descriptor_key == "x" + assert categorical_constraint.acceptable_categories == ["y", "z"] assert "Acceptable" in str(categorical_constraint) def test_ingredient_fraction_initialization(ingredient_fraction_constraint): """Make sure the correct fields go to the correct places.""" - assert ingredient_fraction_constraint.formulation_descriptor == formulation_descriptor - assert ingredient_fraction_constraint.ingredient == 'foo' + assert ( + ingredient_fraction_constraint.formulation_descriptor == formulation_descriptor + ) + assert ingredient_fraction_constraint.ingredient == "foo" assert ingredient_fraction_constraint.min == 0.0 assert ingredient_fraction_constraint.max == 1.0 assert not ingredient_fraction_constraint.is_required @@ -123,13 +125,13 @@ def test_ingredient_count_initialization(ingredient_count_constraint): assert ingredient_count_constraint.formulation_descriptor == formulation_descriptor assert ingredient_count_constraint.min == 0 assert ingredient_count_constraint.max == 1 - assert ingredient_count_constraint.label == 'bar' + assert ingredient_count_constraint.label == "bar" def test_label_fraction_initialization(label_fraction_constraint): """Make sure the correct fields go to the correct places.""" assert label_fraction_constraint.formulation_descriptor == formulation_descriptor - assert label_fraction_constraint.label == 'bar' + assert label_fraction_constraint.label == "bar" assert label_fraction_constraint.min == 0.0 assert label_fraction_constraint.max == 1.0 assert not label_fraction_constraint.is_required @@ -150,7 +152,7 @@ def test_ingredient_ratio_interaction(ingredient_ratio_constraint): with pytest.raises(ValueError): ingredient_ratio_constraint.ingredient = ("foo", 2, "bar", 4) with pytest.raises(ValueError): - ingredient_ratio_constraint.ingredient = ("foo", ) + ingredient_ratio_constraint.ingredient = ("foo",) with pytest.raises(TypeError): ingredient_ratio_constraint.ingredient = ("foo", "yup") with pytest.raises(ValueError): @@ -167,7 +169,7 @@ def test_ingredient_ratio_interaction(ingredient_ratio_constraint): with pytest.raises(ValueError): ingredient_ratio_constraint.label = ("foolabel", 2, "barlabel", 4) with pytest.raises(ValueError): - ingredient_ratio_constraint.label = ("foolabel", ) + ingredient_ratio_constraint.label = ("foolabel",) with pytest.raises(TypeError): ingredient_ratio_constraint.label = ("foolabel", "yup") with pytest.raises(ValueError): @@ -215,8 +217,14 @@ def test_range_defaults(): assert ScalarRangeConstraint(descriptor_key="x").lower_inclusive is True assert ScalarRangeConstraint(descriptor_key="x").upper_inclusive is True - assert ScalarRangeConstraint(descriptor_key="x", upper_inclusive=False).upper_inclusive is False - assert ScalarRangeConstraint(descriptor_key="x", lower_inclusive=False).lower_inclusive is False + assert ( + ScalarRangeConstraint(descriptor_key="x", upper_inclusive=False).upper_inclusive + is False + ) + assert ( + ScalarRangeConstraint(descriptor_key="x", lower_inclusive=False).lower_inclusive + is False + ) assert ScalarRangeConstraint(descriptor_key="x", lower_bound=0).lower_bound == 0.0 assert ScalarRangeConstraint(descriptor_key="x", upper_bound=0).upper_bound == 0.0 diff --git a/tests/informatics/test_data_source.py b/tests/informatics/test_data_source.py index b1b4e2a06..860cd5d08 100644 --- a/tests/informatics/test_data_source.py +++ 
b/tests/informatics/test_data_source.py @@ -1,10 +1,15 @@ """Tests for citrine.informatics.descriptors.""" + import uuid import pytest from citrine.informatics.data_sources import ( - DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource, SnapshotDataSource + DataSource, + CSVDataSource, + ExperimentDataSourceRef, + GemTableDataSource, + SnapshotDataSource, ) from citrine.informatics.descriptors import RealDescriptor from citrine.resources.file_link import FileLink @@ -12,12 +17,15 @@ from tests.utils.factories import GemTableDataFactory -@pytest.fixture(params=[ - GemTableDataSource(table_id=uuid.uuid4(), table_version=1), - GemTableDataSource(table_id=uuid.uuid4(), table_version="2"), - ExperimentDataSourceRef(datasource_id=uuid.uuid4()), - SnapshotDataSource(snapshot_id=uuid.uuid4()) -]) + +@pytest.fixture( + params=[ + GemTableDataSource(table_id=uuid.uuid4(), table_version=1), + GemTableDataSource(table_id=uuid.uuid4(), table_version="2"), + ExperimentDataSourceRef(datasource_id=uuid.uuid4()), + SnapshotDataSource(snapshot_id=uuid.uuid4()), + ] +) def data_source(request): return request.param @@ -43,7 +51,10 @@ def test_invalid_deser(): def test_data_source_id(data_source): - assert data_source == DataSource.from_data_source_id(data_source.to_data_source_id()) + assert data_source == DataSource.from_data_source_id( + data_source.to_data_source_id() + ) + def test_from_gem_table(): table = GemTable.build(GemTableDataFactory()) @@ -51,6 +62,7 @@ def test_from_gem_table(): assert data_source.table_id == table.uid assert data_source.table_version == table.version + def test_invalid_data_source_id(): with pytest.raises(ValueError): DataSource.from_data_source_id(f"Undefined::{uuid.uuid4()}") @@ -58,23 +70,34 @@ def test_invalid_data_source_id(): def test_deser_from_parent_deprecated(): with pytest.deprecated_call(): - data_source = CSVDataSource(file_link=FileLink("foo.spam", "http://example.com"), - column_definitions={"spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="")}, - identifiers=["identifier"]) + data_source = CSVDataSource( + file_link=FileLink("foo.spam", "http://example.com"), + column_definitions={ + "spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="") + }, + identifiers=["identifier"], + ) # Serialize and deserialize the descriptors, making sure they are round-trip serializable data = data_source.dump() data_source_deserialized = DataSource.build(data) assert data_source == data_source_deserialized + def test_data_source_id_deprecated(): with pytest.deprecated_call(): - data_source = CSVDataSource(file_link=FileLink("foo.spam", "http://example.com"), - column_definitions={"spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="")}, - identifiers=["identifier"]) - + data_source = CSVDataSource( + file_link=FileLink("foo.spam", "http://example.com"), + column_definitions={ + "spam": RealDescriptor("eggs", lower_bound=0, upper_bound=1.0, units="") + }, + identifiers=["identifier"], + ) + # TODO: There's no obvious way to recover the column_definitions & identifiers from the ID with pytest.deprecated_call(): with pytest.warns(UserWarning): - transformed = DataSource.from_data_source_id(data_source.to_data_source_id()) + transformed = DataSource.from_data_source_id( + data_source.to_data_source_id() + ) assert transformed.file_link == data_source.file_link diff --git a/tests/informatics/test_descriptors.py b/tests/informatics/test_descriptors.py index 8c1744756..a23bdde1b 100644 --- 
a/tests/informatics/test_descriptors.py +++ b/tests/informatics/test_descriptors.py @@ -1,20 +1,32 @@ """Tests for citrine.informatics.descriptors.""" + import json import pytest -from citrine.informatics.descriptors import * - - -@pytest.fixture(params=[ - RealDescriptor('alpha', lower_bound=0, upper_bound=100, units=""), - IntegerDescriptor('count', lower_bound=0, upper_bound=100), - ChemicalFormulaDescriptor('formula'), - MolecularStructureDescriptor("organic"), - CategoricalDescriptor("my categorical", categories=["a", "b"]), - CategoricalDescriptor("categorical", categories=["*"]), - FormulationDescriptor.hierarchical() -]) +from citrine.informatics.descriptors import ( + CategoricalDescriptor, + ChemicalFormulaDescriptor, + Descriptor, + FormulationDescriptor, + FormulationKey, + IntegerDescriptor, + MolecularStructureDescriptor, + RealDescriptor, +) + + +@pytest.fixture( + params=[ + RealDescriptor("alpha", lower_bound=0, upper_bound=100, units=""), + IntegerDescriptor("count", lower_bound=0, upper_bound=100), + ChemicalFormulaDescriptor("formula"), + MolecularStructureDescriptor("organic"), + CategoricalDescriptor("my categorical", categories=["a", "b"]), + CategoricalDescriptor("categorical", categories=["*"]), + FormulationDescriptor.hierarchical(), + ] +) def descriptor(request): return request.param @@ -27,7 +39,6 @@ def test_deser_from_parent(descriptor): def test_equals(descriptor): - assert descriptor._equals(descriptor, descriptor.__dict__.keys()) # attributes missing from the descriptor should raise an exception diff --git a/tests/informatics/test_design_candidate.py b/tests/informatics/test_design_candidate.py index 57124c814..bb935be45 100644 --- a/tests/informatics/test_design_candidate.py +++ b/tests/informatics/test_design_candidate.py @@ -1,4 +1,5 @@ """Tests for citrine.informatics.design_candidate.""" + from citrine.informatics.design_candidate import DesignVariable diff --git a/tests/informatics/test_design_spaces.py b/tests/informatics/test_design_spaces.py index 9f92188f7..acfa5ce1b 100644 --- a/tests/informatics/test_design_spaces.py +++ b/tests/informatics/test_design_spaces.py @@ -1,38 +1,59 @@ """Tests for citrine.informatics.design_spaces.""" + import uuid import pytest +from citrine.informatics.templates import TemplateLink from citrine.informatics.constraints import IngredientCountConstraint from citrine.informatics.data_sources import DataSource, GemTableDataSource -from citrine.informatics.descriptors import FormulationDescriptor, RealDescriptor, \ - CategoricalDescriptor, IntegerDescriptor -from citrine.informatics.design_spaces import * -from citrine.informatics.dimensions import ContinuousDimension, EnumeratedDimension, \ - IntegerDimension +from citrine.informatics.descriptors import ( + CategoricalDescriptor, + FormulationDescriptor, + IntegerDescriptor, + RealDescriptor, +) +from citrine.informatics.design_candidate import MaterialNodeDefinition +from citrine.informatics.design_spaces import ( + DataSourceDesignSpace, + DesignSpace, + EnumeratedDesignSpace, + FormulationDesignSpace, + HierarchicalDesignSpace, + ProductDesignSpace, +) +from citrine.informatics.dimensions import ( + ContinuousDimension, + EnumeratedDimension, + IntegerDimension, +) @pytest.fixture def product_design_space() -> ProductDesignSpace: """Build a ProductDesignSpace for testing.""" - alpha = RealDescriptor('alpha', lower_bound=0, upper_bound=100, units="") - beta = IntegerDescriptor('beta', lower_bound=0, upper_bound=100) - gamma = CategoricalDescriptor('gamma', 
categories=['a', 'b', 'c']) + alpha = RealDescriptor("alpha", lower_bound=0, upper_bound=100, units="") + beta = IntegerDescriptor("beta", lower_bound=0, upper_bound=100) + gamma = CategoricalDescriptor("gamma", categories=["a", "b", "c"]) dimensions = [ ContinuousDimension(alpha, lower_bound=0, upper_bound=10), IntegerDimension(beta, lower_bound=0, upper_bound=10), - EnumeratedDimension(gamma, values=['a', 'c']) + EnumeratedDimension(gamma, values=["a", "c"]), ] - return ProductDesignSpace(name='my design space', description='does some things', dimensions=dimensions) + return ProductDesignSpace( + name="my design space", description="does some things", dimensions=dimensions + ) @pytest.fixture def enumerated_design_space() -> EnumeratedDesignSpace: """Build an EnumeratedDesignSpace for testing.""" - x = RealDescriptor('x', lower_bound=0.0, upper_bound=1.0, units='') - color = CategoricalDescriptor('color', categories=['r', 'g', 'b']) - data = [dict(x='0', color='r'), dict(x='1.0', color='b')] - return EnumeratedDesignSpace('enumerated', description='desc', descriptors=[x, color], data=data) + x = RealDescriptor("x", lower_bound=0.0, upper_bound=1.0, units="") + color = CategoricalDescriptor("color", categories=["r", "g", "b"]) + data = [dict(x="0", color="r"), dict(x="1.0", color="b")] + return EnumeratedDesignSpace( + "enumerated", description="desc", descriptors=[x, color], data=data + ) @pytest.fixture @@ -46,7 +67,7 @@ def formulation_design_space() -> FormulationDesignSpace: labels={"canine": {"dog"}, "feline": {"cat"}}, constraints={ IngredientCountConstraint(formulation_descriptor=desc, min=1, max=2) - } + }, ) @@ -57,25 +78,23 @@ def hierarchical_design_space(material_node_definition) -> HierarchicalDesignSpa description="Does things in levels", root=material_node_definition, subspaces=[material_node_definition], - data_sources=[ - GemTableDataSource(table_id=uuid.uuid4(), table_version=2) - ] + data_sources=[GemTableDataSource(table_id=uuid.uuid4(), table_version=2)], ) @pytest.fixture def material_node_definition(formulation_design_space) -> MaterialNodeDefinition: - temp = RealDescriptor('temperature', lower_bound=0.0, upper_bound=1.0, units='') + temp = RealDescriptor("temperature", lower_bound=0.0, upper_bound=1.0, units="") temp_dimension = ContinuousDimension(temp, lower_bound=0.1, upper_bound=0.9) - color = CategoricalDescriptor('color', categories={'r', 'g', 'b'}) - color_dimension = EnumeratedDimension(color, values=['g', 'b']) + color = CategoricalDescriptor("color", categories={"r", "g", "b"}) + color_dimension = EnumeratedDimension(color, values=["g", "b"]) link = TemplateLink( material_template=uuid.uuid4(), process_template=uuid.uuid4(), material_template_name="Material Template Name", - process_template_name="Process Template Name" + process_template_name="Process Template Name", ) return MaterialNodeDefinition( @@ -84,28 +103,31 @@ def material_node_definition(formulation_design_space) -> MaterialNodeDefinition formulation_subspace=formulation_design_space, template_link=link, attributes=[temp_dimension, color_dimension], - display_name="Special Material" + display_name="Special Material", ) def test_product_initialization(product_design_space): """Make sure the correct fields go to the correct places.""" - assert product_design_space.name == 'my design space' - assert product_design_space.description == 'does some things' + assert product_design_space.name == "my design space" + assert product_design_space.description == "does some things" assert 
len(product_design_space.dimensions) == 3 - assert product_design_space.dimensions[0].descriptor.key == 'alpha' - assert product_design_space.dimensions[1].descriptor.key == 'beta' - assert product_design_space.dimensions[2].descriptor.key == 'gamma' + assert product_design_space.dimensions[0].descriptor.key == "alpha" + assert product_design_space.dimensions[1].descriptor.key == "beta" + assert product_design_space.dimensions[2].descriptor.key == "gamma" def test_enumerated_initialization(enumerated_design_space): """Make sure the correct fields go to the correct places.""" - assert enumerated_design_space.name == 'enumerated' - assert enumerated_design_space.description == 'desc' + assert enumerated_design_space.name == "enumerated" + assert enumerated_design_space.description == "desc" assert len(enumerated_design_space.descriptors) == 2 - assert enumerated_design_space.descriptors[0].key == 'x' - assert enumerated_design_space.descriptors[1].key == 'color' - assert enumerated_design_space.data == [{'x': '0', 'color': 'r'}, {'x': '1.0', 'color': 'b'}] + assert enumerated_design_space.descriptors[0].key == "x" + assert enumerated_design_space.descriptors[1].key == "color" + assert enumerated_design_space.data == [ + {"x": "0", "color": "r"}, + {"x": "1.0", "color": "b"}, + ] def test_hierarchical_initialization(hierarchical_design_space): @@ -124,16 +146,20 @@ def test_data_source_build(valid_data_source_design_space_dict): ds = DesignSpace.build(valid_data_source_design_space_dict) assert ds.name == valid_data_source_design_space_dict["data"]["instance"]["name"] assert ds.description == valid_data_source_design_space_dict["data"]["description"] - assert ds.data_source == DataSource.build(valid_data_source_design_space_dict["data"]["instance"]["data_source"]) + assert ds.data_source == DataSource.build( + valid_data_source_design_space_dict["data"]["instance"]["data_source"] + ) assert str(ds) == f"<DataSourceDesignSpace '{ds.name}'>" def test_data_source_initialization(valid_data_source_design_space_dict): data = valid_data_source_design_space_dict["data"] data_source = DataSource.build(data["instance"]["data_source"]) - ds = DataSourceDesignSpace(name=data["instance"]["name"], - description=data["description"], - data_source=data_source) + ds = DataSourceDesignSpace( + name=data["instance"]["name"], + description=data["description"], + data_source=data_source, + ) assert ds.name == data["instance"]["name"] assert ds.description == data["description"] assert ds.data_source.dump() == data["instance"]["data_source"] diff --git a/tests/informatics/test_dimensions.py b/tests/informatics/test_dimensions.py index 1159f959f..314672c74 100644 --- a/tests/informatics/test_dimensions.py +++ b/tests/informatics/test_dimensions.py @@ -1,36 +1,43 @@ """Tests for citrine.informatics.dimensions.""" + import pytest -from citrine.informatics.descriptors import RealDescriptor, CategoricalDescriptor, \ - IntegerDescriptor -from citrine.informatics.dimensions import ContinuousDimension, EnumeratedDimension, \ - IntegerDimension +from citrine.informatics.descriptors import ( + RealDescriptor, + CategoricalDescriptor, + IntegerDescriptor, +) +from citrine.informatics.dimensions import ( + ContinuousDimension, + EnumeratedDimension, + IntegerDimension, +) @pytest.fixture def continuous_dimension() -> ContinuousDimension: """Build a ContinuousDimension.""" - alpha = RealDescriptor('alpha', lower_bound=0, upper_bound=100, units="") + alpha = RealDescriptor("alpha", lower_bound=0, upper_bound=100, units="") return ContinuousDimension(alpha,
lower_bound=3, upper_bound=33) @pytest.fixture def enumerated_dimension() -> EnumeratedDimension: """Build an EnumeratedDimension.""" - color = CategoricalDescriptor('color', categories={'red', 'green', 'blue'}) - return EnumeratedDimension(color, values=['red', 'red', 'blue']) + color = CategoricalDescriptor("color", categories={"red", "green", "blue"}) + return EnumeratedDimension(color, values=["red", "red", "blue"]) def test_continuous_initialization(continuous_dimension): """Make sure the correct fields go to the correct places.""" - assert continuous_dimension.descriptor.key == 'alpha' + assert continuous_dimension.descriptor.key == "alpha" assert continuous_dimension.lower_bound == 3 assert continuous_dimension.upper_bound == 33 def test_continuous_bounds(): """Test bounds are assigned correctly, even when bounds are == 0""" - beta = RealDescriptor('beta', lower_bound=-10, upper_bound=10, units="") + beta = RealDescriptor("beta", lower_bound=-10, upper_bound=10, units="") lower_none = ContinuousDimension(beta, upper_bound=0) assert lower_none.lower_bound == -10 assert lower_none.upper_bound == 0 @@ -42,7 +49,7 @@ def test_integer_bounds(): """Test bounds are assigned correctly, even when bounds are == 0""" - beta = IntegerDescriptor('beta', lower_bound=-10, upper_bound=10) + beta = IntegerDescriptor("beta", lower_bound=-10, upper_bound=10) lower_none = IntegerDimension(beta, upper_bound=0) assert lower_none.lower_bound == -10 assert lower_none.upper_bound == 0 @@ -54,6 +61,6 @@ def test_enumerated_initialization(enumerated_dimension): """Make sure the correct fields go to the correct places.""" - assert enumerated_dimension.descriptor.key == 'color' - assert enumerated_dimension.descriptor.categories == {'red', 'green', 'blue'} - assert enumerated_dimension.values == ['red', 'red', 'blue'] + assert enumerated_dimension.descriptor.key == "color" + assert enumerated_dimension.descriptor.categories == {"red", "green", "blue"} + assert enumerated_dimension.values == ["red", "red", "blue"] diff --git a/tests/informatics/test_experiment_values.py b/tests/informatics/test_experiment_values.py index fc2887a8e..fbade488e 100644 --- a/tests/informatics/test_experiment_values.py +++ b/tests/informatics/test_experiment_values.py @@ -1,24 +1,28 @@ -import uuid - import pytest -from citrine.informatics.experiment_values import ExperimentValue, \ - RealExperimentValue, \ - IntegerExperimentValue, \ - CategoricalExperimentValue, \ - MixtureExperimentValue, \ - ChemicalFormulaExperimentValue, \ - MolecularStructureExperimentValue - - -@pytest.fixture(params=[ - CategoricalExperimentValue("categorical"), - ChemicalFormulaExperimentValue("(Ca)1(O)3(Si)1"), - IntegerExperimentValue(7), - MixtureExperimentValue({"ingredient1": 0.3, "ingredient2": 0.7}), - MolecularStructureExperimentValue("CC1(CC(CC(N1)(C)C)NCCCCCCNC2CC(NC(C2)(C)C)(C)C)C.C1COCCN1C2=NC(=NC(=N2)Cl)Cl"), - RealExperimentValue(3.5) -]) +from citrine.informatics.experiment_values import ( + ExperimentValue, + RealExperimentValue, + IntegerExperimentValue, + CategoricalExperimentValue, + MixtureExperimentValue, + ChemicalFormulaExperimentValue, + MolecularStructureExperimentValue, +) + + +@pytest.fixture( + params=[ + CategoricalExperimentValue("categorical"), + ChemicalFormulaExperimentValue("(Ca)1(O)3(Si)1"), + IntegerExperimentValue(7), + MixtureExperimentValue({"ingredient1": 0.3, "ingredient2": 0.7}), + MolecularStructureExperimentValue( + "CC1(CC(CC(N1)(C)C)NCCCCCCNC2CC(NC(C2)(C)C)(C)C)C.C1COCCN1C2=NC(=NC(=N2)Cl)Cl" + ), + RealExperimentValue(3.5), + ] +) def experiment_value(request): return request.param diff --git a/tests/informatics/test_informatics.py b/tests/informatics/test_informatics.py index 07f8af2d6..67b88d3d8 100644 --- a/tests/informatics/test_informatics.py +++ b/tests/informatics/test_informatics.py @@ -1,53 +1,95 @@ import pytest from citrine.informatics.descriptors import FormulationDescriptor, FormulationKey -from citrine.informatics.constraints import ScalarRangeConstraint, AcceptableCategoriesConstraint, \ - IngredientCountConstraint, IngredientFractionConstraint, IngredientRatioConstraint, \ - LabelFractionConstraint, IntegerRangeConstraint -from citrine.informatics.design_spaces import ProductDesignSpace, EnumeratedDesignSpace, FormulationDesignSpace +from citrine.informatics.constraints import ( + ScalarRangeConstraint, + AcceptableCategoriesConstraint, + IngredientCountConstraint, + IngredientFractionConstraint, + IngredientRatioConstraint, + LabelFractionConstraint, + IntegerRangeConstraint, +) +from citrine.informatics.design_spaces import ( + ProductDesignSpace, + EnumeratedDesignSpace, + FormulationDesignSpace, +) from citrine.informatics.objectives import ScalarMaxObjective, ScalarMinObjective from citrine.informatics.scores import LIScore, EIScore, EVScore informatics_string_data = [ - (IngredientCountConstraint( - formulation_descriptor=FormulationDescriptor.hierarchical(), - min=0, max=1 - ), f"<IngredientCountConstraint '{FormulationKey.HIERARCHICAL.value}'>"), - (IngredientFractionConstraint( - formulation_descriptor=FormulationDescriptor.hierarchical(), - ingredient='y', - min=0, - max=1 - ), f"<IngredientFractionConstraint '{FormulationKey.HIERARCHICAL.value}'>"), - (LabelFractionConstraint( - formulation_descriptor=FormulationDescriptor.hierarchical(), - label='y', - min=0, - max=1 - ), f"<LabelFractionConstraint '{FormulationKey.HIERARCHICAL.value}'>"), - (ScalarRangeConstraint(descriptor_key='z'), "<ScalarRangeConstraint 'z'>"), - (IntegerRangeConstraint(descriptor_key='w'), "<IntegerRangeConstraint 'w'>"), - (AcceptableCategoriesConstraint(descriptor_key='x', acceptable_categories=[]), "<AcceptableCategoriesConstraint 'x'>"), - (IngredientRatioConstraint(formulation_descriptor=FormulationDescriptor('Flat Formulation'), min=0.0, max=1.0, ingredient=("x", 1.5), label=("x'", 0.5), basis_ingredients=["y", "z"], basis_labels=["y'", "z'"]), "<IngredientRatioConstraint 'Flat Formulation'>"), - (ProductDesignSpace(name='my design space', description='does some things'), - "<ProductDesignSpace 'my design space'>"), - (EnumeratedDesignSpace('enumerated', description='desc', descriptors=[], data=[]), "<EnumeratedDesignSpace 'enumerated'>"), - (FormulationDesignSpace( - name='Formulation', - description='desc', - formulation_descriptor=FormulationDescriptor.hierarchical(), - ingredients={'y'}, - constraints=set(), - labels={} - ), "<FormulationDesignSpace 'Formulation'>"), - (ScalarMaxObjective('z'), "<ScalarMaxObjective 'z'>"), - (ScalarMinObjective('z'), "<ScalarMinObjective 'z'>"), + ( + IngredientCountConstraint( + formulation_descriptor=FormulationDescriptor.hierarchical(), min=0, max=1 + ), + f"<IngredientCountConstraint '{FormulationKey.HIERARCHICAL.value}'>", + ), + ( + IngredientFractionConstraint( + formulation_descriptor=FormulationDescriptor.hierarchical(), + ingredient="y", + min=0, + max=1, + ), + f"<IngredientFractionConstraint '{FormulationKey.HIERARCHICAL.value}'>", + ), + ( + LabelFractionConstraint( + formulation_descriptor=FormulationDescriptor.hierarchical(), + label="y", + min=0, + max=1, + ), + f"<LabelFractionConstraint '{FormulationKey.HIERARCHICAL.value}'>", + ), + (ScalarRangeConstraint(descriptor_key="z"), "<ScalarRangeConstraint 'z'>"), + (IntegerRangeConstraint(descriptor_key="w"), "<IntegerRangeConstraint 'w'>"), + ( + AcceptableCategoriesConstraint(descriptor_key="x", acceptable_categories=[]), + "<AcceptableCategoriesConstraint 'x'>", + ), + ( + IngredientRatioConstraint( + formulation_descriptor=FormulationDescriptor("Flat Formulation"), + min=0.0, + max=1.0, + ingredient=("x", 1.5), + label=("x'", 0.5), + basis_ingredients=["y", "z"], + basis_labels=["y'", "z'"], + ), + "<IngredientRatioConstraint 'Flat Formulation'>", + ), + ( + ProductDesignSpace(name="my design space", description="does some things"), + "<ProductDesignSpace 'my design space'>", + ), + ( + EnumeratedDesignSpace( + "enumerated", description="desc", descriptors=[], data=[] + ), + "<EnumeratedDesignSpace 'enumerated'>", + ), + ( + FormulationDesignSpace( + name="Formulation", + description="desc", + formulation_descriptor=FormulationDescriptor.hierarchical(), + ingredients={"y"}, + constraints=set(), + labels={}, + ), + "<FormulationDesignSpace 'Formulation'>", + ), + (ScalarMaxObjective("z"), "<ScalarMaxObjective 'z'>"), + (ScalarMinObjective("z"), "<ScalarMinObjective 'z'>"), (LIScore(objectives=[], baselines=[]), "<LIScore>"), (EIScore(objectives=[], baselines=[], constraints=[]), "<EIScore>"), (EVScore(objectives=[], constraints=[]), "<EVScore>"), ] -@pytest.mark.parametrize('obj,repr', informatics_string_data) +@pytest.mark.parametrize("obj,repr", informatics_string_data) def test_str_representation(obj, repr): assert str(obj) == repr diff --git a/tests/informatics/test_objectives.py b/tests/informatics/test_objectives.py index 56a61a0e8..87f8085f2 100644 --- a/tests/informatics/test_objectives.py +++ b/tests/informatics/test_objectives.py @@ -1,4 +1,5 @@ """Tests for citrine.informatics.objectives.""" + import pytest from citrine.informatics.objectives import ScalarMaxObjective, ScalarMinObjective diff --git a/tests/informatics/test_predictor_evaluation_metrics.py b/tests/informatics/test_predictor_evaluation_metrics.py index 0e024f4d0..10663d5c6 100644 --- a/tests/informatics/test_predictor_evaluation_metrics.py +++ b/tests/informatics/test_predictor_evaluation_metrics.py @@ -1,21 +1,39 @@ """Tests for citrine.informatics.descriptors.""" + import json import logging import pytest -from citrine.informatics.predictor_evaluation_metrics import * - - -@pytest.fixture(params=[ - (RMSE(), "rmse", "RMSE"), - (RSquared(), "R^2", "R^2"), - (NDME(), "ndme", "NDME"), - (StandardRMSE(), "standardized_rmse", "Standardized RMSE"), - (PVA(), "predicted_vs_actual", "Predicted vs Actual"), - (F1(), "f1", "F1 Score"), - (AreaUnderROC(), "area_under_roc", "Area Under the ROC"), - (CoverageProbability(coverage_level=0.123), "coverage_probability_0.123", "Coverage Probability (0.123)") -]) + +from citrine.informatics.predictor_evaluation_metrics import ( + F1, + NDME, + PVA, + RMSE, + AreaUnderROC, + CoverageProbability, + PredictorEvaluationMetric, + RSquared, + StandardRMSE, +) + + +@pytest.fixture( + params=[ + (RMSE(), "rmse", "RMSE"), + (RSquared(), "R^2", "R^2"), + (NDME(), "ndme", "NDME"), + (StandardRMSE(), "standardized_rmse", "Standardized RMSE"), + (PVA(), "predicted_vs_actual", "Predicted vs Actual"), + (F1(), "f1", "F1 Score"), + (AreaUnderROC(), "area_under_roc", "Area Under the ROC"), + ( + CoverageProbability(coverage_level=0.123), + "coverage_probability_0.123", + "Coverage Probability (0.123)", + ), + ] +) def metric(request): return request.param diff --git a/tests/informatics/test_predictor_evaluation_result.py b/tests/informatics/test_predictor_evaluation_result.py index 0f12c522d..959b6f311 100644 --- a/tests/informatics/test_predictor_evaluation_result.py +++ b/tests/informatics/test_predictor_evaluation_result.py @@ -1,11 +1,19 @@ """Tests for citrine.informatics.descriptors.""" + import json + import pytest -from citrine.informatics.predictor_evaluation_metrics import * -from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult, \ - PredictedVsActualRealPoint, \ - PredictedVsActualCategoricalPoint -from citrine.informatics.predictor_evaluator import CrossValidationEvaluator, HoldoutSetEvaluator + +from citrine.informatics.predictor_evaluation_metrics import F1, PVA, RMSE +from citrine.informatics.predictor_evaluation_result
import ( + PredictedVsActualCategoricalPoint, + PredictedVsActualRealPoint, + PredictorEvaluationResult, +) +from citrine.informatics.predictor_evaluator import ( + CrossValidationEvaluator, + HoldoutSetEvaluator, +) @pytest.fixture @@ -30,12 +38,16 @@ def test_indexing(example_cv_result, example_holdout_result): def test_cv_serde(example_cv_result, example_cv_result_dict): - round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_cv_result_dict))) + round_trip = PredictorEvaluationResult.build( + json.loads(json.dumps(example_cv_result_dict)) + ) assert example_cv_result.evaluator == round_trip.evaluator def test_holdout_serde(example_holdout_result, example_holdout_result_dict): - round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_holdout_result_dict))) + round_trip = PredictorEvaluationResult.build( + json.loads(json.dumps(example_holdout_result_dict)) + ) assert example_holdout_result.evaluator == round_trip.evaluator @@ -43,23 +55,30 @@ def test_ev_evaluator(example_cv_result, example_cv_evaluator_dict): args = example_cv_evaluator_dict del args["type"] expected = CrossValidationEvaluator(**args) - assert expected.responses == set(example_cv_evaluator_dict['responses']) + assert expected.responses == set(example_cv_evaluator_dict["responses"]) assert example_cv_result.evaluator == expected - assert example_cv_result.evaluator != 0 # make sure eq does something for mismatched classes + assert ( + example_cv_result.evaluator != 0 + ) # make sure eq does something for mismatched classes def test_holdout_set_evaluator(example_holdout_result, example_holdout_evaluator_dict): args = example_holdout_evaluator_dict del args["type"] expected = HoldoutSetEvaluator(**args) - assert expected.responses == set(example_holdout_evaluator_dict['responses']) + assert expected.responses == set(example_holdout_evaluator_dict["responses"]) assert example_holdout_result.evaluator == expected - assert example_holdout_result.evaluator != 0 # make sure eq does something for mismatched classes + assert ( + example_holdout_result.evaluator != 0 + ) # make sure eq does something for mismatched classes def test_check_rmse(example_cv_result, example_rmse_metrics): assert example_cv_result["saltiness"]["rmse"].mean == example_rmse_metrics["mean"] - assert example_cv_result["saltiness"][RMSE()].standard_error == example_rmse_metrics["standard_error"] + assert ( + example_cv_result["saltiness"][RMSE()].standard_error + == example_rmse_metrics["standard_error"] + ) # check eq method does something assert example_cv_result["saltiness"][RMSE()] != 0 with pytest.raises(TypeError): @@ -69,12 +88,24 @@ def test_check_rmse(example_cv_result, example_rmse_metrics): def test_real_pva(example_cv_result, example_real_pva_metrics): args = example_real_pva_metrics["value"][0] expected = PredictedVsActualRealPoint.build(args) - assert example_cv_result["saltiness"]["predicted_vs_actual"][0].predicted == expected.predicted - assert next(iter(example_cv_result["saltiness"]["predicted_vs_actual"])).actual == expected.actual + assert ( + example_cv_result["saltiness"]["predicted_vs_actual"][0].predicted + == expected.predicted + ) + assert ( + next(iter(example_cv_result["saltiness"]["predicted_vs_actual"])).actual + == expected.actual + ) def test_categorical_pva(example_cv_result, example_categorical_pva_metrics): args = example_categorical_pva_metrics["value"][0] expected = PredictedVsActualCategoricalPoint.build(args) - assert example_cv_result["salt?"]["predicted_vs_actual"][0].predicted 
== expected.predicted - assert next(iter(example_cv_result["salt?"]["predicted_vs_actual"])).actual == expected.actual + assert ( + example_cv_result["salt?"]["predicted_vs_actual"][0].predicted + == expected.predicted + ) + assert ( + next(iter(example_cv_result["salt?"]["predicted_vs_actual"])).actual + == expected.actual + ) diff --git a/tests/informatics/test_predictor_evaluations.py b/tests/informatics/test_predictor_evaluations.py index 1ac41b9ee..598e0d219 100644 --- a/tests/informatics/test_predictor_evaluations.py +++ b/tests/informatics/test_predictor_evaluations.py @@ -2,7 +2,11 @@ import pytest -from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation, PredictorEvaluationRequest, PredictorEvaluatorsResponse +from citrine.informatics.executions.predictor_evaluation import ( + PredictorEvaluation, + PredictorEvaluationRequest, + PredictorEvaluatorsResponse, +) from citrine.informatics.predictor_evaluator import CrossValidationEvaluator from citrine.informatics.predictor_evaluation_metrics import NDME from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult @@ -12,7 +16,14 @@ @pytest.fixture def cross_validation_evaluator(): - yield CrossValidationEvaluator("foo", description="desc", responses={"dk"}, n_folds=2, n_trials=5, metrics={NDME()}) + yield CrossValidationEvaluator( + "foo", + description="desc", + responses={"dk"}, + n_folds=2, + n_trials=5, + metrics={NDME()}, + ) @pytest.fixture @@ -27,7 +38,11 @@ def predictor_evaluators_response(cross_validation_evaluator): @pytest.fixture def predictor_evaluation_request(cross_validation_evaluator, predictor_ref): - yield PredictorEvaluationRequest(evaluators=[cross_validation_evaluator], predictor_id=predictor_ref.uid, predictor_version=predictor_ref.version) + yield PredictorEvaluationRequest( + evaluators=[cross_validation_evaluator], + predictor_id=predictor_ref.uid, + predictor_version=predictor_ref.version, + ) @pytest.fixture @@ -37,27 +52,33 @@ def predictor_evaluation(cross_validation_evaluator, predictor_ref): evaluation.evaluators = [cross_validation_evaluator] evaluation.predictor_id = predictor_ref.uid evaluation.predictor_version = predictor_ref.version - evaluation.status = 'SUCCEEDED' - evaluation.status_description = 'COMPLETED' + evaluation.status = "SUCCEEDED" + evaluation.status_description = "COMPLETED" yield evaluation -def test_predictor_evaluator_response(predictor_evaluators_response, cross_validation_evaluator): +def test_predictor_evaluator_response( + predictor_evaluators_response, cross_validation_evaluator +): assert predictor_evaluators_response.evaluators == [cross_validation_evaluator] -def test_predictor_evaluator_request(predictor_evaluation_request, cross_validation_evaluator, predictor_ref): +def test_predictor_evaluator_request( + predictor_evaluation_request, cross_validation_evaluator, predictor_ref +): assert predictor_evaluation_request.evaluators == [cross_validation_evaluator] assert predictor_evaluation_request.predictor.dump() == predictor_ref.dump() -def test_predictor_evaluation(predictor_evaluation, cross_validation_evaluator, predictor_ref): +def test_predictor_evaluation( + predictor_evaluation, cross_validation_evaluator, predictor_ref +): assert predictor_evaluation.evaluators == [cross_validation_evaluator] assert predictor_evaluation.evaluator_names == [cross_validation_evaluator.name] assert predictor_evaluation.predictor_id == predictor_ref.uid assert predictor_evaluation.predictor_version == predictor_ref.version - 
assert predictor_evaluation.status == 'SUCCEEDED' - assert predictor_evaluation.status_description == 'COMPLETED' + assert predictor_evaluation.status == "SUCCEEDED" + assert predictor_evaluation.status_description == "COMPLETED" assert predictor_evaluation.status_detail == [] @@ -71,12 +92,15 @@ def test_results(predictor_evaluation, example_cv_result_dict): results = predictor_evaluation["Example Evaluator"] expected_call = FakeCall( - method='GET', - path=f'/projects/{predictor_evaluation.project_id}/predictor-evaluations/{predictor_evaluation.uid}/results', - params={"evaluator_name": "Example Evaluator"} + method="GET", + path=f"/projects/{predictor_evaluation.project_id}/predictor-evaluations/{predictor_evaluation.uid}/results", + params={"evaluator_name": "Example Evaluator"}, ) assert session.last_call == expected_call - assert results.evaluator == PredictorEvaluationResult.build(example_cv_result_dict).evaluator + assert ( + results.evaluator + == PredictorEvaluationResult.build(example_cv_result_dict).evaluator + ) def test_results_invalid_type(predictor_evaluation): diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index 6eff41b5f..653c7b191 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -1,66 +1,89 @@ """Tests for citrine.informatics.predictors.""" + +import uuid + import mock import pytest -import uuid -from random import random from citrine.informatics.data_sources import GemTableDataSource -from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \ - MolecularStructureDescriptor, FormulationDescriptor, ChemicalFormulaDescriptor, \ - CategoricalDescriptor, FormulationKey -from citrine.informatics.predictors import * +from citrine.informatics.descriptors import ( + CategoricalDescriptor, + ChemicalFormulaDescriptor, + FormulationDescriptor, + FormulationKey, + IntegerDescriptor, + MolecularStructureDescriptor, + RealDescriptor, +) +from citrine.informatics.design_candidate import DesignMaterial +from citrine.informatics.predictors import ( + AttributeAccumulationPredictor, + AutoMLEstimator, + AutoMLPredictor, + ChemicalFormulaFeaturizer, + ExpressionPredictor, + GraphPredictor, + IngredientFractionsPredictor, + IngredientsToFormulationPredictor, + LabelFractionsPredictor, + MeanPropertyPredictor, + MolecularStructureFeaturizer, + SimpleMixturePredictor, +) from citrine.informatics.predictors.single_predict_request import SinglePredictRequest from citrine.informatics.predictors.single_prediction import SinglePrediction -from citrine.informatics.design_candidate import DesignMaterial - from tests.utils.factories import FeatureEffectsResponseFactory from tests.utils.session import FakeCall, FakeSession - w = IntegerDescriptor("w", lower_bound=0, upper_bound=100) x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") z = RealDescriptor("z", lower_bound=0, upper_bound=100, units="") -density = RealDescriptor('density', lower_bound=0, upper_bound=100, units='g/cm^3') -shear_modulus = RealDescriptor('Property~Shear modulus', lower_bound=0, upper_bound=100, units='GPa') -youngs_modulus = RealDescriptor('Property~Young\'s modulus', lower_bound=0, upper_bound=100, units='GPa') -poissons_ratio = RealDescriptor('Property~Poisson\'s ratio', lower_bound=-1, upper_bound=0.5, units='') -chain_type = CategoricalDescriptor('Chain Type', categories={'Gaussian Coil', 'Rigid Rod', 'Worm-like'}) +density = 
RealDescriptor("density", lower_bound=0, upper_bound=100, units="g/cm^3") +shear_modulus = RealDescriptor( + "Property~Shear modulus", lower_bound=0, upper_bound=100, units="GPa" +) +youngs_modulus = RealDescriptor( + "Property~Young's modulus", lower_bound=0, upper_bound=100, units="GPa" +) +poissons_ratio = RealDescriptor( + "Property~Poisson's ratio", lower_bound=-1, upper_bound=0.5, units="" +) +chain_type = CategoricalDescriptor( + "Chain Type", categories={"Gaussian Coil", "Rigid Rod", "Worm-like"} +) flat_formulation = FormulationDescriptor.flat() -water_quantity = RealDescriptor('water quantity', lower_bound=0, upper_bound=1, units="") -salt_quantity = RealDescriptor('salt quantity', lower_bound=0, upper_bound=1, units="") -data_source = GemTableDataSource(table_id=uuid.UUID('e5c51369-8e71-4ec6-b027-1f92bdc14762'), table_version=0) -formulation_data_source = GemTableDataSource(table_id=uuid.UUID('6894a181-81d2-4304-9dfa-a6c5b114d8bc'), table_version=0) +water_quantity = RealDescriptor( + "water quantity", lower_bound=0, upper_bound=1, units="" +) +salt_quantity = RealDescriptor("salt quantity", lower_bound=0, upper_bound=1, units="") +data_source = GemTableDataSource( + table_id=uuid.UUID("e5c51369-8e71-4ec6-b027-1f92bdc14762"), table_version=0 +) +formulation_data_source = GemTableDataSource( + table_id=uuid.UUID("6894a181-81d2-4304-9dfa-a6c5b114d8bc"), table_version=0 +) def build_predictor_data(instance): return dict( name=instance.get("name"), description=instance.get("description"), - instance=instance + instance=instance, ) def build_predictor_entity(data): user = str(uuid.uuid4()) - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return dict( id=str(uuid.uuid4()), data=data, metadata=dict( - status=dict( - name='READY', - info=[] - ), - created=dict( - user=user, - time=time - ), - updated=dict( - user=user, - time=time - ) - ) + status=dict(name="READY", info=[]), + created=dict(user=user, time=time), + updated=dict(user=user, time=time), + ), ) @@ -71,7 +94,7 @@ def molecule_featurizer() -> MolecularStructureFeaturizer: description="description", input_descriptor=MolecularStructureDescriptor("SMILES"), features=["all"], - excludes=["standard"] + excludes=["standard"], ) @@ -83,37 +106,37 @@ def chemical_featurizer() -> ChemicalFormulaFeaturizer: input_descriptor=ChemicalFormulaDescriptor("formula"), features=["standard"], excludes=None, - powers=[1, 2] + powers=[1, 2], ) @pytest.fixture def auto_ml() -> AutoMLPredictor: return AutoMLPredictor( - name='AutoML Predictor', - description='Predicts z from inputs w and x', + name="AutoML Predictor", + description="Predicts z from inputs w and x", inputs=[w, x], - outputs=[z] + outputs=[z], ) @pytest.fixture def auto_ml_no_outputs() -> AutoMLPredictor: return AutoMLPredictor( - name='AutoML Predictor', - description='Predicts z from inputs w and x', + name="AutoML Predictor", + description="Predicts z from inputs w and x", inputs=[w, x], - outputs=[] + outputs=[], ) @pytest.fixture def auto_ml_multiple_outputs() -> AutoMLPredictor: return AutoMLPredictor( - name='AutoML Predictor', - description='Predicts z from inputs w and x', + name="AutoML Predictor", + description="Predicts z from inputs w and x", inputs=[w, x], - outputs=[z, y] + outputs=[z, y], ) @@ -121,10 +144,10 @@ def auto_ml_multiple_outputs() -> AutoMLPredictor: def graph_predictor(molecule_featurizer, auto_ml) -> GraphPredictor: """Build a GraphPredictor for testing.""" return GraphPredictor( - name='Graph predictor', - description='description', + 
name="Graph predictor", + description="description", predictors=[molecule_featurizer, auto_ml], - training_data=[data_source, formulation_data_source] + training_data=[data_source, formulation_data_source], ) @@ -132,30 +155,22 @@ def graph_predictor(molecule_featurizer, auto_ml) -> GraphPredictor: def expression_predictor() -> ExpressionPredictor: """Build an ExpressionPredictor for testing.""" return ExpressionPredictor( - name='Expression predictor', - description='Computes shear modulus from Youngs modulus and Poissons ratio', - expression='Y / (2 * (1 + v))', + name="Expression predictor", + description="Computes shear modulus from Youngs modulus and Poissons ratio", + expression="Y / (2 * (1 + v))", output=shear_modulus, - aliases={ - 'Y': youngs_modulus, - 'v': poissons_ratio - }) + aliases={"Y": youngs_modulus, "v": poissons_ratio}, + ) @pytest.fixture def ing_to_formulation_predictor() -> IngredientsToFormulationPredictor: """Build an IngredientsToFormulationPredictor for testing.""" return IngredientsToFormulationPredictor( - name='Ingredients to formulation predictor', - description='Constructs a mixture from ingredient quantities', - id_to_quantity={ - 'water': water_quantity, - 'salt': salt_quantity - }, - labels={ - 'solvent': {'water'}, - 'solute': {'salt'} - } + name="Ingredients to formulation predictor", + description="Constructs a mixture from ingredient quantities", + id_to_quantity={"water": water_quantity, "salt": salt_quantity}, + labels={"solvent": {"water"}, "solute": {"salt"}}, ) @@ -163,14 +178,14 @@ def ing_to_formulation_predictor() -> IngredientsToFormulationPredictor: def mean_property_predictor() -> MeanPropertyPredictor: """Build a mean property predictor for testing.""" return MeanPropertyPredictor( - name='Mean property predictor', - description='Computes mean ingredient properties', + name="Mean property predictor", + description="Computes mean ingredient properties", input_descriptor=flat_formulation, properties=[density, chain_type], p=2.5, impute_properties=True, - default_properties={'density': 1.0, 'Chain Type': 'Gaussian Coil'}, - label='solvent' + default_properties={"density": 1.0, "Chain Type": "Gaussian Coil"}, + label="solvent", ) @@ -178,8 +193,8 @@ def mean_property_predictor() -> MeanPropertyPredictor: def simple_mixture_predictor() -> SimpleMixturePredictor: """Build a simple mixture predictor for testing.""" return SimpleMixturePredictor( - name='Simple mixture predictor', - description='Computes mean ingredient properties', + name="Simple mixture predictor", + description="Computes mean ingredient properties", ) @@ -187,10 +202,10 @@ def simple_mixture_predictor() -> SimpleMixturePredictor: def label_fractions_predictor() -> LabelFractionsPredictor: """Build a label fractions predictor for testing""" return LabelFractionsPredictor( - name='Label fractions predictor', - description='Compute relative proportions of labeled ingredients', + name="Label fractions predictor", + description="Compute relative proportions of labeled ingredients", input_descriptor=flat_formulation, - labels={'solvent'} + labels={"solvent"}, ) @@ -198,10 +213,10 @@ def label_fractions_predictor() -> LabelFractionsPredictor: def ingredient_fractions_predictor() -> IngredientFractionsPredictor: """Build a Ingredient Fractions predictor for testing.""" return IngredientFractionsPredictor( - name='Ingredient fractions predictor', - description='Computes total ingredient fractions', + name="Ingredient fractions predictor", + description="Computes total ingredient 
fractions", input_descriptor=flat_formulation, - ingredients={"Green Paste", "Blue Paste"} + ingredients={"Green Paste", "Blue Paste"}, ) @@ -211,7 +226,7 @@ def attribute_accumulation_predictor() -> AttributeAccumulationPredictor: name="Attribute accumulation predictor", description="Aid training", attributes=[x, y], - sequential=True + sequential=True, ) @@ -221,51 +236,58 @@ def test_simple_report(graph_predictor): # without a project or session, this should error assert graph_predictor.report is None session = mock.Mock() - session.get_resource.return_value = dict(status='OK', report=dict(descriptors=[], models=[]), uid=str(uuid.uuid4())) + session.get_resource.return_value = dict( + status="OK", report=dict(descriptors=[], models=[]), uid=str(uuid.uuid4()) + ) graph_predictor._session = session graph_predictor._project_id = uuid.uuid4() graph_predictor.uid = uuid.uuid4() graph_predictor.version = 2 assert graph_predictor.report is not None assert session.get_resource.call_count == 1 - assert graph_predictor.report.status == 'OK' + assert graph_predictor.report.status == "OK" def test_graph_initialization(graph_predictor): """Make sure the correct fields go to the correct places for a graph predictor.""" - assert graph_predictor.name == 'Graph predictor' - assert graph_predictor.description == 'description' + assert graph_predictor.name == "Graph predictor" + assert graph_predictor.description == "description" assert len(graph_predictor.predictors) == 2 assert graph_predictor.training_data == [data_source, formulation_data_source] - assert str(graph_predictor) == '' + assert str(graph_predictor) == "" def test_expression_initialization(expression_predictor): """Make sure the correct fields go to the correct places for an expression predictor.""" - assert expression_predictor.name == 'Expression predictor' - assert expression_predictor.output.key == 'Property~Shear modulus' - assert expression_predictor.expression == 'Y / (2 * (1 + v))' - assert expression_predictor.aliases == {'Y': youngs_modulus, 'v': poissons_ratio} - assert str(expression_predictor) == '' + assert expression_predictor.name == "Expression predictor" + assert expression_predictor.output.key == "Property~Shear modulus" + assert expression_predictor.expression == "Y / (2 * (1 + v))" + assert expression_predictor.aliases == {"Y": youngs_modulus, "v": poissons_ratio} + assert str(expression_predictor) == "" def test_molecule_featurizer(molecule_featurizer): assert molecule_featurizer.name == "Molecule featurizer" assert molecule_featurizer.description == "description" - assert molecule_featurizer.input_descriptor == MolecularStructureDescriptor("SMILES") + assert molecule_featurizer.input_descriptor == MolecularStructureDescriptor( + "SMILES" + ) assert molecule_featurizer.features == ["all"] assert molecule_featurizer.excludes == ["standard"] - assert str(molecule_featurizer) == "" + assert ( + str(molecule_featurizer) + == "" + ) assert molecule_featurizer.dump() == { - 'name': 'Molecule featurizer', - 'description': 'description', - 'descriptor': {'descriptor_key': 'SMILES', 'type': 'Organic'}, - 'features': ['all'], - 'excludes': ['standard'], - 'type': 'MoleculeFeaturizer' - } + "name": "Molecule featurizer", + "description": "description", + "descriptor": {"descriptor_key": "SMILES", "type": "Organic"}, + "features": ["all"], + "excludes": ["standard"], + "type": "MoleculeFeaturizer", + } def test_chemical_featurizer(chemical_featurizer): @@ -279,18 +301,20 @@ def test_chemical_featurizer(chemical_featurizer): with 
pytest.warns(PendingDeprecationWarning): assert chemical_featurizer.powers_as_float == [1.0, 2.0] - assert str(chemical_featurizer) == "" + assert ( + str(chemical_featurizer) == "" + ) assert chemical_featurizer.dump() == { - 'name': 'Chemical featurizer', - 'description': 'description', - 'input': ChemicalFormulaDescriptor("formula").dump(), - 'features': ['standard'], - 'excludes': [], - 'powers': [1, 2], - 'type': 'ChemicalFormulaFeaturizer' + "name": "Chemical featurizer", + "description": "description", + "input": ChemicalFormulaDescriptor("formula").dump(), + "features": ["standard"], + "excludes": [], + "powers": [1, 2], + "type": "ChemicalFormulaFeaturizer", } - + chemical_featurizer.powers = [0.5, -1] with pytest.warns(PendingDeprecationWarning): assert chemical_featurizer.powers_as_float == [0.5, -1.0] @@ -302,93 +326,96 @@ def test_auto_ml(auto_ml): assert auto_ml.name == "AutoML Predictor" assert auto_ml.description == "Predicts z from inputs w and x" assert auto_ml.inputs == [w, x] - assert auto_ml.dump()['outputs'] == [z.dump()] + assert auto_ml.dump()["outputs"] == [z.dump()] assert str(auto_ml) == "" built = AutoMLPredictor.build(auto_ml.dump()) assert built.outputs == [z] - assert built.dump()['outputs'] == [z.dump()] + assert built.dump()["outputs"] == [z.dump()] def test_auto_ml_no_outputs(auto_ml_no_outputs): assert auto_ml_no_outputs.outputs == [] - assert auto_ml_no_outputs.dump()['outputs'] == [] + assert auto_ml_no_outputs.dump()["outputs"] == [] built = AutoMLPredictor.build(auto_ml_no_outputs.dump()) assert built.outputs == [] - assert built.dump()['outputs'] == [] + assert built.dump()["outputs"] == [] def test_auto_ml_estimators(): # Check an empty set is coerced to RF default empty_aml = AutoMLPredictor( - name="", - description="", - inputs=[x], - outputs=[y], - estimators={} + name="", description="", inputs=[x], outputs=[y], estimators={} ) assert empty_aml.estimators == {AutoMLEstimator.RANDOM_FOREST} # Check passing invalid strings leads to an error with pytest.raises(ValueError): AutoMLPredictor( - name="", - description="", - inputs=[x], - outputs=[y], - estimators={"pancakes"} + name="", description="", inputs=[x], outputs=[y], estimators={"pancakes"} ) def test_auto_ml_multiple_outputs(auto_ml_multiple_outputs): assert auto_ml_multiple_outputs.outputs == [z, y] - assert auto_ml_multiple_outputs.dump()['outputs'] == [z.dump(), y.dump()] + assert auto_ml_multiple_outputs.dump()["outputs"] == [z.dump(), y.dump()] built = AutoMLPredictor.build(auto_ml_multiple_outputs.dump()) assert built.outputs == [z, y] - assert built.dump()['outputs'] == [z.dump(), y.dump()] + assert built.dump()["outputs"] == [z.dump(), y.dump()] def test_auto_ml_deprecated_training_data(auto_ml): with pytest.deprecated_call(): pred = AutoMLPredictor( - name='AutoML Predictor', - description='Predicts z from inputs w and x', + name="AutoML Predictor", + description="Predicts z from inputs w and x", inputs=auto_ml.inputs, outputs=auto_ml.outputs, - training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)], ) new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] with pytest.deprecated_call(): pred.training_data = new_training_data - + with pytest.deprecated_call(): assert pred.training_data == new_training_data def test_ing_to_formulation_initialization(ing_to_formulation_predictor): """Make sure the correct fields go to the correct places for an ingredients to 
formulation predictor.""" - assert ing_to_formulation_predictor.name == 'Ingredients to formulation predictor' + assert ing_to_formulation_predictor.name == "Ingredients to formulation predictor" assert ing_to_formulation_predictor.output.key == FormulationKey.HIERARCHICAL.value - assert ing_to_formulation_predictor.id_to_quantity == {'water': water_quantity, 'salt': salt_quantity} - assert ing_to_formulation_predictor.labels == {'solvent': {'water'}, 'solute': {'salt'}} - expected_str = f'' + assert ing_to_formulation_predictor.id_to_quantity == { + "water": water_quantity, + "salt": salt_quantity, + } + assert ing_to_formulation_predictor.labels == { + "solvent": {"water"}, + "solute": {"salt"}, + } + expected_str = ( + f"" + ) assert str(ing_to_formulation_predictor) == expected_str def test_mean_property_initialization(mean_property_predictor): """Make sure the correct fields go to the correct places for a mean property predictor.""" - assert mean_property_predictor.name == 'Mean property predictor' + assert mean_property_predictor.name == "Mean property predictor" assert mean_property_predictor.input_descriptor.key == FormulationKey.FLAT.value assert mean_property_predictor.properties == [density, chain_type] assert mean_property_predictor.p == 2.5 - assert mean_property_predictor.impute_properties == True - assert mean_property_predictor.default_properties == {'density': 1.0, 'Chain Type': 'Gaussian Coil'} - assert mean_property_predictor.label == 'solvent' - expected_str = '' + assert mean_property_predictor.impute_properties is True + assert mean_property_predictor.default_properties == { + "density": 1.0, + "Chain Type": "Gaussian Coil", + } + assert mean_property_predictor.label == "solvent" + expected_str = "" assert str(mean_property_predictor) == expected_str @@ -406,84 +433,97 @@ def test_mean_property_round_robin(mean_property_predictor): def test_mean_property_training_data_deprecated(mean_property_predictor): with pytest.deprecated_call(): pred = MeanPropertyPredictor( - name='Mean property predictor', - description='Computes mean ingredient properties', + name="Mean property predictor", + description="Computes mean ingredient properties", input_descriptor=mean_property_predictor.input_descriptor, properties=mean_property_predictor.properties, p=2.5, impute_properties=True, default_properties=mean_property_predictor.default_properties, label=mean_property_predictor.label, - training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)], ) - + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] with pytest.deprecated_call(): pred.training_data = new_training_data - + with pytest.deprecated_call(): assert pred.training_data == new_training_data def test_label_fractions_property_initialization(label_fractions_predictor): """Make sure the correct fields go to the correct places for a label fraction predictor.""" - assert label_fractions_predictor.name == 'Label fractions predictor' + assert label_fractions_predictor.name == "Label fractions predictor" assert label_fractions_predictor.input_descriptor.key == FormulationKey.FLAT.value - assert label_fractions_predictor.labels == {'solvent'} - expected_str = '' + assert label_fractions_predictor.labels == {"solvent"} + expected_str = "" assert str(label_fractions_predictor) == expected_str def test_simple_mixture_predictor_initialization(simple_mixture_predictor): """Make sure the correct fields go to the correct 
places for a simple mixture predictor.""" - assert simple_mixture_predictor.name == 'Simple mixture predictor' - assert simple_mixture_predictor.input_descriptor.key == FormulationKey.HIERARCHICAL.value + assert simple_mixture_predictor.name == "Simple mixture predictor" + assert ( + simple_mixture_predictor.input_descriptor.key + == FormulationKey.HIERARCHICAL.value + ) assert simple_mixture_predictor.output_descriptor.key == FormulationKey.FLAT.value - expected_str = '<SimpleMixturePredictor \'Simple mixture predictor\'>' + expected_str = "<SimpleMixturePredictor 'Simple mixture predictor'>" assert str(simple_mixture_predictor) == expected_str def test_simplex_mixture_training_data_deprecated(): with pytest.deprecated_call(): pred = SimpleMixturePredictor( - name='Simple mixture predictor', - description='Computes mean ingredient properties', - training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + name="Simple mixture predictor", + description="Computes mean ingredient properties", + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)], ) - + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] with pytest.deprecated_call(): pred.training_data = new_training_data - + with pytest.deprecated_call(): assert pred.training_data == new_training_data def test_ingredient_fractions_property_initialization(ingredient_fractions_predictor): """Make sure the correct fields go to the correct places for an ingredient fractions predictor.""" - assert ingredient_fractions_predictor.name == 'Ingredient fractions predictor' - assert ingredient_fractions_predictor.input_descriptor.key == FormulationKey.FLAT.value + assert ingredient_fractions_predictor.name == "Ingredient fractions predictor" + assert ( + ingredient_fractions_predictor.input_descriptor.key == FormulationKey.FLAT.value + ) assert ingredient_fractions_predictor.ingredients == {"Green Paste", "Blue Paste"} - expected_str = '<IngredientFractionsPredictor \'Ingredient fractions predictor\'>' + expected_str = "<IngredientFractionsPredictor 'Ingredient fractions predictor'>" assert str(ingredient_fractions_predictor) == expected_str -def test_attribute_accumulation_property_initialization(attribute_accumulation_predictor): +def test_attribute_accumulation_property_initialization( + attribute_accumulation_predictor, +): """Make sure the correct fields go to the correct places for an attribute accumulation predictor.""" - assert attribute_accumulation_predictor.name == 'Attribute accumulation predictor' + assert attribute_accumulation_predictor.name == "Attribute accumulation predictor" assert attribute_accumulation_predictor.attributes == [x, y] assert attribute_accumulation_predictor.sequential is True - expected_str = '<AttributeAccumulationPredictor \'Attribute accumulation predictor\'>' + expected_str = "<AttributeAccumulationPredictor 'Attribute accumulation predictor'>" assert str(attribute_accumulation_predictor) == expected_str def test_status(graph_predictor, valid_graph_predictor_data): """Ensure we can check the status of predictor validation.""" # A locally built predictor should be "False" for all status checks - assert not graph_predictor.in_progress() and not graph_predictor.failed() and not graph_predictor.succeeded() + assert ( + not graph_predictor.in_progress() + and not graph_predictor.failed() + and not graph_predictor.succeeded() + ) # A deserialized predictor should have the correct status predictor = GraphPredictor.build(valid_graph_predictor_data) - assert predictor.succeeded() and not predictor.in_progress() and not predictor.failed() + assert ( + predictor.succeeded() and not predictor.in_progress() and not predictor.failed() + ) def test_single_predict(graph_predictor): @@ -493,13 +533,11 @@ def test_single_predict(graph_predictor): graph_predictor.uid = uuid.uuid4() graph_predictor.version = 2 material_data = 
{ - 'vars': { - 'X': {'m': 1.1, 's': 0.1, 'type': 'R'}, - 'Y': {'m': 2.2, 's': 0.2, 'type': 'R'} + "vars": { + "X": {"m": 1.1, "s": 0.1, "type": "R"}, + "Y": {"m": 2.2, "s": 0.2, "type": "R"}, }, - 'identifiers': { - 'id': str(uuid.uuid4()) - } + "identifiers": {"id": str(uuid.uuid4())}, } material = DesignMaterial.build(material_data) request = SinglePredictRequest(uuid.uuid4(), list(), material) @@ -517,31 +555,37 @@ def test_feature_effects(graph_predictor): session = FakeSession() session.set_response(feature_effects_response) - + graph_predictor._session = session graph_predictor._project_id = uuid.uuid4() fe = graph_predictor.feature_effects - expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \ - f"/versions/{graph_predictor.version}/shapley/query" - assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] + expected_path = ( + f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + + f"/versions/{graph_predictor.version}/shapley/query" + ) + assert session.calls == [FakeCall(method="POST", path=expected_path, json={})] assert fe.as_dict == feature_effects_as_dict def test_feature_effects_in_progress(graph_predictor): - feature_effects_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None) + feature_effects_response = FeatureEffectsResponseFactory( + metadata__status="INPROGRESS", result=None + ) session = FakeSession() session.set_response(feature_effects_response) - + graph_predictor._session = session graph_predictor._project_id = uuid.uuid4() fe = graph_predictor.feature_effects - expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \ - f"/versions/{graph_predictor.version}/shapley/query" - assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] + expected_path = ( + f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + + f"/versions/{graph_predictor.version}/shapley/query" + ) + assert session.calls == [FakeCall(method="POST", path=expected_path, json={})] assert fe.outputs is None assert fe.as_dict == {} diff --git a/tests/informatics/test_reports.py b/tests/informatics/test_reports.py index 6b031c796..49ed97af5 100644 --- a/tests/informatics/test_reports.py +++ b/tests/informatics/test_reports.py @@ -1,6 +1,6 @@ """Tests reports initialization.""" -from citrine.informatics.reports import PredictorReport, ModelSummary, FeatureImportanceReport, Report -from citrine.informatics.descriptors import RealDescriptor + +from citrine.informatics.reports import PredictorReport, Report def test_status(valid_predictor_report_data): @@ -13,7 +13,9 @@ def test_selection_summary(valid_predictor_report_data): """Ensure that we can iterate selection summary results as expected.""" report = PredictorReport.build(valid_predictor_report_data) selection_summaries = [ - s.selection_summary for s in report.model_summaries if s.selection_summary is not None + s.selection_summary + for s in report.model_summaries + if s.selection_summary is not None ] assert len(selection_summaries) > 0 diff --git a/tests/informatics/test_scores.py b/tests/informatics/test_scores.py index 7a6cdc8b8..46f7020b7 100644 --- a/tests/informatics/test_scores.py +++ b/tests/informatics/test_scores.py @@ -1,4 +1,5 @@ """Tests for citrine.informatics.scores.""" + import pytest from citrine.informatics.constraints import ScalarRangeConstraint @@ -10,12 +11,7 @@ def li_score() -> LIScore: """Build an LIScore.""" return LIScore( 
- objectives=[ - ScalarMaxObjective( - descriptor_key="z" - ) - ], - baselines=[10.0] + objectives=[ScalarMaxObjective(descriptor_key="z")], baselines=[10.0] ) @@ -23,13 +19,11 @@ def li_score() -> LIScore: def ei_score() -> EIScore: """Build an EIScore.""" return EIScore( - objectives=[ - ScalarMaxObjective( - descriptor_key="x" - ) - ], + objectives=[ScalarMaxObjective(descriptor_key="x")], baselines=[1.0], - constraints=[ScalarRangeConstraint(descriptor_key='y', lower_bound=0.0, upper_bound=1.0)] + constraints=[ + ScalarRangeConstraint(descriptor_key="y", lower_bound=0.0, upper_bound=1.0) + ], ) @@ -37,34 +31,32 @@ def ei_score() -> EIScore: def ev_score() -> EVScore: """Build an MEVScore.""" return EVScore( - objectives=[ - ScalarMaxObjective( - descriptor_key="x" - ) + objectives=[ScalarMaxObjective(descriptor_key="x")], + constraints=[ + ScalarRangeConstraint(descriptor_key="y", lower_bound=0.0, upper_bound=1.0) ], - constraints=[ScalarRangeConstraint(descriptor_key='y', lower_bound=0.0, upper_bound=1.0)] ) def test_li_initialization(li_score): """Make sure the correct fields go to the correct places.""" assert isinstance(li_score.objectives[0], ScalarMaxObjective) - assert li_score.objectives[0].descriptor_key == 'z' + assert li_score.objectives[0].descriptor_key == "z" assert li_score.baselines == [10.0] assert li_score.constraints == [] def test_ei_initialization(ei_score): """Make sure the correct fields go to the correct places.""" - assert ei_score.objectives[0].descriptor_key == 'x' + assert ei_score.objectives[0].descriptor_key == "x" assert ei_score.baselines == [1.0] assert isinstance(ei_score.constraints[0], ScalarRangeConstraint) - assert ei_score.constraints[0].descriptor_key == 'y' + assert ei_score.constraints[0].descriptor_key == "y" def test_ev_initialization(ev_score): """Make sure the correct fields go to the correct places.""" - assert ev_score.objectives[0].descriptor_key == 'x' + assert ev_score.objectives[0].descriptor_key == "x" assert isinstance(ev_score.constraints[0], ScalarRangeConstraint) - assert ev_score.constraints[0].descriptor_key == 'y' + assert ev_score.constraints[0].descriptor_key == "y" assert "EVScore" in str(ev_score) diff --git a/tests/informatics/test_workflows.py b/tests/informatics/test_workflows.py index 7d5fcdcb1..c007451ee 100644 --- a/tests/informatics/test_workflows.py +++ b/tests/informatics/test_workflows.py @@ -1,11 +1,18 @@ """Tests for citrine.informatics.workflows.""" -from multiprocessing.reduction import register + from uuid import uuid4, UUID import pytest -from citrine.informatics.design_candidate import DesignMaterial, DesignCandidate, ChemicalFormula, \ - MeanAndStd, TopCategories, Mixture, MolecularStructure +from citrine.informatics.design_candidate import ( + DesignMaterial, + DesignCandidate, + ChemicalFormula, + MeanAndStd, + TopCategories, + Mixture, + MolecularStructure, +) from citrine.informatics.executions import DesignExecution from citrine.informatics.predict_request import PredictRequest from citrine.informatics.workflows import DesignWorkflow @@ -20,10 +27,12 @@ def branch_data(): return BranchDataFactory() + @pytest.fixture def session() -> FakeSession: return FakeSession() + @pytest.fixture def collection(session, branch_data) -> DesignWorkflowCollection: session.set_response(branch_data) @@ -49,13 +58,14 @@ def execution_collection(session) -> DesignExecutionCollection: def design_workflow(collection) -> DesignWorkflow: return collection.build(DesignWorkflowDataFactory(register=True)) + @pytest.fixture def 
design_execution(execution_collection, design_execution_dict) -> DesignExecution: return execution_collection.build(design_execution_dict) def test_d_workflow_str(design_workflow): - assert str(design_workflow) == f'<DesignWorkflow \'{design_workflow.name}\'>' + assert str(design_workflow) == f"<DesignWorkflow '{design_workflow.name}'>" def test_workflow_executions_with_project(design_workflow): @@ -64,9 +74,7 @@ def test_workflow_executions_without_project(): workflow = DesignWorkflow( - name="workflow", - design_space_id=uuid4(), - predictor_id=uuid4() + name="workflow", design_space_id=uuid4(), predictor_id=uuid4() ) with pytest.raises(AttributeError): workflow.design_executions @@ -74,17 +82,11 @@ def test_design_material(): values = { - "RealValue": MeanAndStd(mean=1.4,std=.3), - "Cat": TopCategories(probabilities={ - "Red": 0.85, - "Blue": 0.15 - }), - "Mixture": Mixture(quantities={ - "Water": 0.5, - "Active": 0.5 - }), + "RealValue": MeanAndStd(mean=1.4, std=0.3), + "Cat": TopCategories(probabilities={"Red": 0.85, "Blue": 0.15}), + "Mixture": Mixture(quantities={"Water": 0.5, "Active": 0.5}), "Formula": ChemicalFormula(formula="NaCl"), - "Organic": MolecularStructure(smiles="CCO") + "Organic": MolecularStructure(smiles="CCO"), } material = DesignMaterial(values=values) assert material.values == values @@ -93,30 +95,32 @@ def test_predict(design_workflow, design_execution, example_candidates): session = design_execution._session - candidate = DesignCandidate.build(example_candidates['response'][0]) + candidate = DesignCandidate.build(example_candidates["response"][0]) material_id = UUID("9953cc63-5d53-4d0a-884a-a9cff3b7de18") - predict_req = PredictRequest(material_id=material_id, - material=candidate.material, - created_from_id=candidate.uid, - identifiers=candidate.identifiers) + predict_req = PredictRequest( + material_id=material_id, + material=candidate.material, + created_from_id=candidate.uid, + identifiers=candidate.identifiers, + ) session.set_response(candidate.dump()) predict_response = design_execution.predict(predict_request=predict_req) assert session.num_calls == 1 expected_call = FakeCall( - method='POST', + method="POST", path=f"/projects/{design_execution.project_id}/design-workflows/{design_execution.workflow_id}" - + f"/executions/{design_execution.uid}/predict", + + f"/executions/{design_execution.uid}/predict", json={ - 'material_id': str(material_id), - 'identifiers': [], - 'material': candidate.material.dump(), - 'created_from_id': str(candidate.uid), - 'random_seed': None + "material_id": str(material_id), + "identifiers": [], + "material": candidate.material.dump(), + "created_from_id": str(candidate.uid), + "random_seed": None, }, - version="v1" + version="v1", ) assert expected_call == session.last_call diff --git a/tests/informatics/workflows/test_predictor_evaluation_workflow.py b/tests/informatics/workflows/test_predictor_evaluation_workflow.py index eee00a1cc..40d872bb4 100644 --- a/tests/informatics/workflows/test_predictor_evaluation_workflow.py +++ b/tests/informatics/workflows/test_predictor_evaluation_workflow.py @@ -2,7 +2,11 @@ import uuid from citrine.informatics.data_sources import GemTableDataSource -from citrine.informatics.predictor_evaluator import HoldoutSetEvaluator, CrossValidationEvaluator, PredictorEvaluator +from citrine.informatics.predictor_evaluator import ( + HoldoutSetEvaluator, + CrossValidationEvaluator, + PredictorEvaluator, +) from citrine.informatics.workflows import 
PredictorEvaluationWorkflow @@ -10,11 +14,11 @@ def pew(): data_source = GemTableDataSource(table_id=uuid.uuid4(), table_version=3) evaluator1 = CrossValidationEvaluator(name="test CV", responses={"foo"}) - evaluator2 = HoldoutSetEvaluator(name="test holdout", responses={"foo"}, data_source=data_source) + evaluator2 = HoldoutSetEvaluator( + name="test holdout", responses={"foo"}, data_source=data_source + ) pew = PredictorEvaluationWorkflow( - name="Test", - description="TestWorkflow", - evaluators=[evaluator1, evaluator2] + name="Test", description="TestWorkflow", evaluators=[evaluator1, evaluator2] ) return pew @@ -23,8 +27,12 @@ def test_round_robin(pew): dumped = pew.dump() assert dumped["name"] == "Test" assert dumped["description"] == "TestWorkflow" - assert PredictorEvaluator.build(dumped["evaluators"][0]).name == pew.evaluators[0].name - assert PredictorEvaluator.build(dumped["evaluators"][1]).name == pew.evaluators[1].name + assert ( + PredictorEvaluator.build(dumped["evaluators"][0]).name == pew.evaluators[0].name + ) + assert ( + PredictorEvaluator.build(dumped["evaluators"][1]).name == pew.evaluators[1].name + ) def test_print(pew): diff --git a/tests/jobs/test_deprecations.py b/tests/jobs/test_deprecations.py index af0b870ee..3919f090c 100644 --- a/tests/jobs/test_deprecations.py +++ b/tests/jobs/test_deprecations.py @@ -4,12 +4,15 @@ from tests.utils.factories import TaskNodeDataFactory, JobStatusResponseDataFactory + def test_status_response_status(): - status_response = JobStatusResponse.build(JobStatusResponseDataFactory(failure=True)) + status_response = JobStatusResponse.build( + JobStatusResponseDataFactory(failure=True) + ) assert status_response.status == JobStatus.FAILURE with pytest.deprecated_call(): - status_response.status = 'Failed' + status_response.status = "Failed" with warnings.catch_warnings(): warnings.simplefilter("error") assert not isinstance(status_response.status, JobStatus) @@ -25,12 +28,13 @@ def test_status_response_status(): status_response.status = JobStatus.SUCCESS assert status_response.status == JobStatus.SUCCESS + def test_task_node_status(): status_response = TaskNode.build(TaskNodeDataFactory(failure=True)) assert status_response.status == JobStatus.FAILURE with pytest.deprecated_call(): - status_response.status = 'Failed' + status_response.status = "Failed" assert not isinstance(status_response.status, JobStatus) with warnings.catch_warnings(): diff --git a/tests/jobs/test_waiting.py b/tests/jobs/test_waiting.py index 62ce79f0d..f71c0a38c 100644 --- a/tests/jobs/test_waiting.py +++ b/tests/jobs/test_waiting.py @@ -1,4 +1,5 @@ """Tests waiting utilities""" + from datetime import datetime import io import mock @@ -11,12 +12,12 @@ wait_for_asynchronous_object, wait_while_executing, wait_while_validating, - ConditionTimeoutError + ConditionTimeoutError, ) from citrine.resources.status_detail import StatusDetail -@mock.patch('time.sleep', return_value=None) +@mock.patch("time.sleep", return_value=None) def test_wait_while_validating(sleep_mock): captured_output = io.StringIO() sys.stdout = captured_output @@ -24,7 +25,9 @@ def test_wait_while_validating(sleep_mock): collection = mock.Mock() module = mock.Mock() statuses = mock.PropertyMock(side_effect=["VALIDATING", "VALID", "VALID"]) - status_detail = mock.PropertyMock(return_value=[StatusDetail(msg="The predictor is now validated.", level="Info")]) + status_detail = mock.PropertyMock( + return_value=[StatusDetail(msg="The predictor is now validated.", level="Info")] + ) in_progress = 
mock.PropertyMock(side_effect=[True, False, False]) type(module).status = statuses type(module).status_detail = status_detail @@ -33,15 +36,16 @@ def test_wait_while_validating(sleep_mock): wait_while_validating(collection=collection, module=module, print_status_info=True) - assert("Status = VALID" in captured_output.getvalue()) - assert("The predictor is now validated." in captured_output.getvalue()) + assert "Status = VALID" in captured_output.getvalue() + assert "The predictor is now validated." in captured_output.getvalue() + -@mock.patch('time.time') -@mock.patch('time.sleep', return_value=None) +@mock.patch("time.time") +@mock.patch("time.sleep", return_value=None) def test_wait_while_validating_timeout(sleep_mock, time_mock): time_mock.side_effect = [ time.mktime(datetime(2020, 10, 30).timetuple()), - time.mktime(datetime(2020, 10, 31).timetuple()) + time.mktime(datetime(2020, 10, 31).timetuple()), ] collection = mock.Mock() @@ -54,7 +58,8 @@ def test_wait_while_validating_timeout(sleep_mock, time_mock): with pytest.raises(ConditionTimeoutError): wait_while_validating(collection=collection, module=module, timeout=1.0) -@mock.patch('time.sleep', return_value=None) + +@mock.patch("time.sleep", return_value=None) def test_wait_while_executing(sleep_mock): captured_output = io.StringIO() sys.stdout = captured_output @@ -62,24 +67,28 @@ def test_wait_while_executing(sleep_mock): collection = mock.Mock() workflow_execution = mock.Mock(spec=DesignExecution) statuses = mock.PropertyMock(side_effect=["INPROGRESS", "SUCCEEDED", "SUCCEEDED"]) - status_detail = mock.PropertyMock(return_value=[StatusDetail(msg="Execution is complete.", level="Info")]) + status_detail = mock.PropertyMock( + return_value=[StatusDetail(msg="Execution is complete.", level="Info")] + ) in_progress = mock.PropertyMock(side_effect=[True, False, False]) type(workflow_execution).status = statuses type(workflow_execution).status_detail = status_detail workflow_execution.in_progress = in_progress collection.get.return_value = workflow_execution - wait_while_executing(collection=collection, - execution=workflow_execution, print_status_info=True) + wait_while_executing( + collection=collection, execution=workflow_execution, print_status_info=True + ) + + assert "SUCCEEDED" in captured_output.getvalue() - assert("SUCCEEDED" in captured_output.getvalue()) -@mock.patch('time.time') -@mock.patch('time.sleep', return_value=None) +@mock.patch("time.time") +@mock.patch("time.sleep", return_value=None) def test_wait_for_asynchronous_object(sleep_mock, time_mock): time_mock.side_effect = [ time.mktime(datetime(2021, 8, 1).timetuple()), - time.mktime(datetime(2021, 8, 2).timetuple()) + time.mktime(datetime(2021, 8, 2).timetuple()), ] resource = mock.Mock() @@ -87,7 +96,10 @@ def test_wait_for_asynchronous_object(sleep_mock, time_mock): type(resource).uid = mock.PropertyMock(return_value=123456) with pytest.raises(ConditionTimeoutError) as exception: - wait_for_asynchronous_object(collection=collection, resource=resource, timeout=1.0) + wait_for_asynchronous_object( + collection=collection, resource=resource, timeout=1.0 + ) - assert str(exception.value) == ("Timeout of 1.0 seconds reached, " - "but task 123456 is still in progress") + assert str(exception.value) == ( + "Timeout of 1.0 seconds reached, but task 123456 is still in progress" + ) diff --git a/tests/resources/test_analysis_workflow.py b/tests/resources/test_analysis_workflow.py index 8f234412b..16986bf2a 100644 --- a/tests/resources/test_analysis_workflow.py +++ 
b/tests/resources/test_analysis_workflow.py @@ -16,8 +16,8 @@ def paging_response(*items): def _assert_user_timestamp_equals_dict(user, time, ut_dict): - assert str(user) == ut_dict['user'] - assert time == datetime.fromtimestamp(ut_dict['time'] / 1000, tz=timezone.utc) + assert str(user) == ut_dict["user"] + assert time == datetime.fromtimestamp(ut_dict["time"] / 1000, tz=timezone.utc) def _assert_aw_plot_equals_dict(plot, plot_dict): @@ -26,34 +26,40 @@ def _assert_aw_plot_equals_dict(plot, plot_dict): def _assert_aw_equals_dict(aw, aw_dict): - assert str(aw.uid) == aw_dict['id'] - assert aw.name == aw_dict['data']['name'] - assert aw.description == aw_dict['data']['description'] - snapshot_id_dict = aw_dict['data'].get('snapshot_id') + assert str(aw.uid) == aw_dict["id"] + assert aw.name == aw_dict["data"]["name"] + assert aw.description == aw_dict["data"]["description"] + snapshot_id_dict = aw_dict["data"].get("snapshot_id") if snapshot_id_dict: assert str(aw.snapshot_id) == snapshot_id_dict else: assert aw.snapshot_id is None - _assert_user_timestamp_equals_dict(aw.created_by, aw.create_time, aw_dict['metadata']['created']) - _assert_user_timestamp_equals_dict(aw.updated_by, aw.update_time, aw_dict['metadata']['updated']) - - aw_dict_latest_build = aw_dict['metadata'].get('latest_build') or {} + _assert_user_timestamp_equals_dict( + aw.created_by, aw.create_time, aw_dict["metadata"]["created"] + ) + _assert_user_timestamp_equals_dict( + aw.updated_by, aw.update_time, aw_dict["metadata"]["updated"] + ) + + aw_dict_latest_build = aw_dict["metadata"].get("latest_build") or {} if aw_dict_latest_build: - assert aw.latest_build.status == aw_dict_latest_build['status'] - assert aw.latest_build.failures == aw_dict_latest_build['failure_reason'] - assert aw.status == aw_dict_latest_build['status'] + assert aw.latest_build.status == aw_dict_latest_build["status"] + assert aw.latest_build.failures == aw_dict_latest_build["failure_reason"] + assert aw.status == aw_dict_latest_build["status"] else: assert aw.latest_build is None - aw_dict_archived = aw_dict['metadata'].get('archived') or {} + aw_dict_archived = aw_dict["metadata"].get("archived") or {} if aw_dict_archived: - _assert_user_timestamp_equals_dict(aw.archived_by, aw.archive_time, aw_dict_archived) + _assert_user_timestamp_equals_dict( + aw.archived_by, aw.archive_time, aw_dict_archived + ) else: assert aw.archived_by is None assert aw.archive_time is None - for plot, plot_dict in zip(aw._plots, aw_dict['data'].get('plots')): + for plot, plot_dict in zip(aw._plots, aw_dict["data"].get("plots")): _assert_aw_plot_equals_dict(plot, plot_dict) @@ -61,76 +67,110 @@ def _assert_aw_equals_dict(aw, aw_dict): def session(): return FakeSession() + @pytest.fixture def team_id(): return uuid.uuid4() + @pytest.fixture def collection(session, team_id): return AnalysisWorkflowCollection(session, team_id=team_id) + @pytest.fixture def base_path(team_id): - return f'/teams/{team_id}/analysis-workflows' + return f"/teams/{team_id}/analysis-workflows" def test_register(session, collection, base_path): aw_data = AnalysisWorkflowEntityDataFactory(data__plot_count=3) session.set_response(aw_data) - aw_module = AnalysisWorkflow(**aw_data['data']) + aw_module = AnalysisWorkflow(**aw_data["data"]) aw = collection.register(aw_module) expected_payload = { - **aw_data['data'], - "plots": [plot['data'] for plot in aw_data['data']['plots']] + **aw_data["data"], + "plots": [plot["data"] for plot in aw_data["data"]["plots"]], } - assert session.calls == 
[FakeCall(method='POST', path=base_path, json=expected_payload)] + assert session.calls == [ + FakeCall(method="POST", path=base_path, json=expected_payload) + ] _assert_aw_equals_dict(aw, aw_data) def test_get(session, collection, base_path): aw_data = AnalysisWorkflowEntityDataFactory() session.set_response(aw_data) - - aw = collection.get(aw_data['id']) - assert session.calls == [FakeCall(method='GET', path=f'{base_path}/{aw_data["id"]}')] + aw = collection.get(aw_data["id"]) + + assert session.calls == [ + FakeCall(method="GET", path=f"{base_path}/{aw_data['id']}") + ] _assert_aw_equals_dict(aw, aw_data) def test_list_all(session, collection, base_path): - aw_data = [AnalysisWorkflowEntityDataFactory(metadata__is_archived=random.choice((True, False))) for _ in range(5)] + aw_data = [ + AnalysisWorkflowEntityDataFactory( + metadata__is_archived=random.choice((True, False)) + ) + for _ in range(5) + ] session.set_response(paging_response(*aw_data)) - + aws = list(collection.list_all()) - expected_call = FakeCall(method='GET', path=base_path, params={'page': 1, 'per_page': 20, 'include_archived': True}) + expected_call = FakeCall( + method="GET", + path=base_path, + params={"page": 1, "per_page": 20, "include_archived": True}, + ) assert session.calls == [expected_call] assert len(aws) == len(aw_data) def test_list_archived(session, collection, base_path): - aw_data = [AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) for _ in range(3)] + aw_data = [ + AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) for _ in range(3) + ] session.set_response(paging_response(*aw_data)) - + aws = list(collection.list_archived(per_page=50)) - expected_call = FakeCall(method='GET', path=base_path, params={'page': 1, 'per_page': 50, 'filter': "archived eq 'true'"}) + expected_call = FakeCall( + method="GET", + path=base_path, + params={"page": 1, "per_page": 50, "filter": "archived eq 'true'"}, + ) assert session.calls == [expected_call] assert len(aws) == len(aw_data) def test_list(session, collection, base_path): - aw_data = [AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) for _ in range(3)] - session.set_responses(paging_response(*aw_data[0:2]), paging_response(*aw_data[2:4])) - + aw_data = [ + AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) for _ in range(3) + ] + session.set_responses( + paging_response(*aw_data[0:2]), paging_response(*aw_data[2:4]) + ) + aws = list(collection.list(per_page=2)) expected_calls = [ - FakeCall(method='GET', path=base_path, params={'page': 1, 'per_page': 2, 'filter': "archived eq 'false'"}), - FakeCall(method='GET', path=base_path, params={'page': 2, 'per_page': 2, 'filter': "archived eq 'false'"}) + FakeCall( + method="GET", + path=base_path, + params={"page": 1, "per_page": 2, "filter": "archived eq 'false'"}, + ), + FakeCall( + method="GET", + path=base_path, + params={"page": 2, "per_page": 2, "filter": "archived eq 'false'"}, + ), ] assert session.calls == expected_calls assert len(aws) == len(aw_data) @@ -139,43 +179,55 @@ def test_list(session, collection, base_path): def test_archive(session, collection, base_path): aw_data = AnalysisWorkflowEntityDataFactory(metadata__is_archived=True) session.set_response(aw_data) - - aw = collection.archive(aw_data['id']) - assert session.calls == [FakeCall(method='PUT', path=f'{base_path}/{aw_data["id"]}/archive', json={})] + aw = collection.archive(aw_data["id"]) + + assert session.calls == [ + FakeCall(method="PUT", path=f"{base_path}/{aw_data['id']}/archive", json={}) + 
] _assert_aw_equals_dict(aw, aw_data) def test_restore(session, collection, base_path): aw_data = AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) session.set_response(aw_data) - - aw = collection.restore(aw_data['id']) - assert session.calls == [FakeCall(method='PUT', path=f'{base_path}/{aw_data["id"]}/restore', json={})] + aw = collection.restore(aw_data["id"]) + + assert session.calls == [ + FakeCall(method="PUT", path=f"{base_path}/{aw_data['id']}/restore", json={}) + ] _assert_aw_equals_dict(aw, aw_data) def test_update(session, collection, base_path): aw_data = AnalysisWorkflowEntityDataFactory(metadata__is_archived=False) session.set_response(aw_data) - - name, description = aw_data['data']['name'], aw_data['data']['description'] - - aw = collection.update(aw_data['id'], name=name, description=description) + + name, description = aw_data["data"]["name"], aw_data["data"]["description"] + + aw = collection.update(aw_data["id"], name=name, description=description) expected_payload = {"name": name, "description": description} - assert session.calls == [FakeCall(method='PUT', path=f'{base_path}/{aw_data["id"]}', json=expected_payload)] + assert session.calls == [ + FakeCall( + method="PUT", path=f"{base_path}/{aw_data['id']}", json=expected_payload + ) + ] _assert_aw_equals_dict(aw, aw_data) def test_rebuild(session, collection, base_path): - aw_data = AnalysisWorkflowEntityDataFactory(data__has_snapshot=True, metadata__has_build=True) + aw_data = AnalysisWorkflowEntityDataFactory( + data__has_snapshot=True, metadata__has_build=True + ) session.set_response(aw_data) - - aw = collection.rebuild(aw_data['id']) - assert session.calls == [FakeCall(method='PUT', path=f'{base_path}/{aw_data["id"]}/query/rerun', json={})] + aw = collection.rebuild(aw_data["id"]) + + assert session.calls == [ + FakeCall(method="PUT", path=f"{base_path}/{aw_data['id']}/query/rerun", json={}) + ] _assert_aw_equals_dict(aw, aw_data) diff --git a/tests/resources/test_api_error.py b/tests/resources/test_api_error.py index 6be413fc5..006fec197 100644 --- a/tests/resources/test_api_error.py +++ b/tests/resources/test_api_error.py @@ -1,51 +1,48 @@ import pytest -from citrine.resources.api_error import ApiError, ValidationError +from citrine.resources.api_error import ApiError def test_has_failure(): - error = ApiError.build({ - "code": 400, - "message": "you messed up", - "validation_errors": [ - {"failure_message": 'failure 1', "failure_id": 'fail.one'}, - {"failure_message": 'failure 2', "failure_id": 'fail.two'}, - {"failure_message": 'vague failure'}, - ] - }) - assert error.has_failure('fail.one') - assert error.has_failure('fail.two') - assert not error.has_failure('not.present') + error = ApiError.build( + { + "code": 400, + "message": "you messed up", + "validation_errors": [ + {"failure_message": "failure 1", "failure_id": "fail.one"}, + {"failure_message": "failure 2", "failure_id": "fail.two"}, + {"failure_message": "vague failure"}, + ], + } + ) + assert error.has_failure("fail.one") + assert error.has_failure("fail.two") + assert not error.has_failure("not.present") with pytest.raises(ValueError): error.has_failure(None) with pytest.raises(ValueError): - error.has_failure('') + error.has_failure("") def test_deserialization(): - msg = 'ya failed' + msg = "ya failed" missing_id = { - 'code': 400, - 'message': 'an error', - 'validation_errors': [ + "code": 400, + "message": "an error", + "validation_errors": [ { - 'failure_message': msg, + "failure_message": msg, } - ] + ], } error = 
ApiError.build(missing_id) assert error.validation_errors[0].failure_message == msg with_id = { - 'code': 400, - 'message': 'an error', - 'validation_errors': [ - { - 'failure_message': msg, - 'failure_id': 'foo.id' - } - ] + "code": 400, + "message": "an error", + "validation_errors": [{"failure_message": msg, "failure_id": "foo.id"}], } error_with_id = ApiError.build(with_id) - assert error_with_id.validation_errors[0].failure_id == 'foo.id' + assert error_with_id.validation_errors[0].failure_id == "foo.id" diff --git a/tests/resources/test_audit_info.py b/tests/resources/test_audit_info.py index 8db1bba9f..e26ac31b1 100644 --- a/tests/resources/test_audit_info.py +++ b/tests/resources/test_audit_info.py @@ -1,19 +1,21 @@ from uuid import uuid4 -from datetime import datetime from citrine.resources.audit_info import AuditInfo def test_audit_info_str(): - audit_info_full = AuditInfo.build({ - "created_by": str(uuid4()), - "created_at": 1559933807392, - "updated_by": str(uuid4()), - "updated_at": 1559933807392 - }) - audit_info_part = AuditInfo.build({ - "created_by": str(uuid4()), - "created_at": 1559933807392 - }) - assert 'Updated by' in str(audit_info_full) and 'Created by' in str(audit_info_full) - assert 'Updated by' not in str(audit_info_part) and 'Created by' in str(audit_info_part) + audit_info_full = AuditInfo.build( + { + "created_by": str(uuid4()), + "created_at": 1559933807392, + "updated_by": str(uuid4()), + "updated_at": 1559933807392, + } + ) + audit_info_part = AuditInfo.build( + {"created_by": str(uuid4()), "created_at": 1559933807392} + ) + assert "Updated by" in str(audit_info_full) and "Created by" in str(audit_info_full) + assert "Updated by" not in str(audit_info_part) and "Created by" in str( + audit_info_part + ) diff --git a/tests/resources/test_branch.py b/tests/resources/test_branch.py index c0feea86b..171f53633 100644 --- a/tests/resources/test_branch.py +++ b/tests/resources/test_branch.py @@ -6,13 +6,20 @@ from citrine._rest.resource import PredictorRef from citrine.exceptions import NotFound -from citrine.resources.data_version_update import NextBranchVersionRequest, DataVersionUpdate, BranchDataUpdate from citrine.resources.branch import Branch, BranchCollection -from tests.utils.factories import BranchDataFactory, BranchRootDataFactory, \ CandidateExperimentSnapshotDataFactory, ExperimentDataSourceDataFactory, \ BranchDataFieldFactory, BranchMetadataFieldFactory, BranchDataUpdateFactory -from tests.utils.session import FakeSession, FakeCall, FakePaginatedSession - +from citrine.resources.data_version_update import ( + BranchDataUpdate, + DataVersionUpdate, + NextBranchVersionRequest, +) +from tests.utils.factories import ( + BranchDataFactory, + BranchDataFieldFactory, + BranchDataUpdateFactory, + BranchMetadataFieldFactory, + ExperimentDataSourceDataFactory, +) +from tests.utils.session import FakeCall, FakePaginatedSession, FakeSession LATEST_VER = "latest" @@ -29,10 +36,7 @@ def paginated_session() -> FakePaginatedSession: @pytest.fixture def collection(session) -> BranchCollection: - return BranchCollection( - project_id=uuid.uuid4(), - session=session - ) + return BranchCollection(project_id=uuid.uuid4(), session=session) @pytest.fixture @@ -43,7 +47,7 @@ def branch_path(collection) -> str: def test_str(): name = "Test Branch name" branch = Branch(name=name) - assert str(branch) == f'<Branch \'{name}\'>' + assert str(branch) == f"<Branch '{name}'>" def test_branch_build(collection): @@ -57,18 +61,16 @@ def test_branch_build(collection): def test_branch_register(session, collection, 
branch_path): # Given - root_id = str(uuid.uuid4()) - name = 'branch-name' + + name = "branch-name" now = datetime.now(tz.UTC).replace(microsecond=0) now_ms = int(now.timestamp() * 1000) # ms since epoch - branch_data = BranchDataFactory(data=BranchDataFieldFactory(name=name), - metadata=BranchMetadataFieldFactory( - created={ - 'time': now_ms - }, - updated={ - 'time': now_ms - })) + branch_data = BranchDataFactory( + data=BranchDataFieldFactory(name=name), + metadata=BranchMetadataFieldFactory( + created={"time": now_ms}, updated={"time": now_ms} + ), + ) session.set_response(branch_data) # When @@ -77,12 +79,7 @@ def test_branch_register(session, collection, branch_path): # Then assert session.num_calls == 1 expected_call = FakeCall( - method='POST', - path=branch_path, - json={ - 'name': name - }, - version="v2" + method="POST", path=branch_path, json={"name": name}, version="v2" ) assert expected_call == session.last_call @@ -95,16 +92,20 @@ def test_branch_register(session, collection, branch_path): def test_branch_get(session, collection, branch_path): # Given branch_data = BranchDataFactory() - root_id = branch_data['metadata']['root_id'] - version = branch_data['metadata']['version'] + root_id = branch_data["metadata"]["root_id"] + version = branch_data["metadata"]["version"] session.set_response({"response": [branch_data]}) # When - branch = collection.get(root_id=root_id, version=version) + collection.get(root_id=root_id, version=version) # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='GET', path=branch_path, params={'page': 1, 'per_page': 1, 'root': root_id, 'version': version}) + assert session.last_call == FakeCall( + method="GET", + path=branch_path, + params={"page": 1, "per_page": 1, "root": root_id, "version": version}, + ) def test_branch_get_not_found(session, collection, branch_path): @@ -119,29 +120,35 @@ def test_branch_get_not_found(session, collection, branch_path): def test_branch_get_by_version_id(session, collection, branch_path): # Given branch_data = BranchDataFactory() - version_id = branch_data['id'] + version_id = branch_data["id"] session.set_response(branch_data) # When - branch = collection.get_by_version_id(version_id=version_id) + collection.get_by_version_id(version_id=version_id) # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='GET', path=f"{branch_path}/{version_id}") + assert session.last_call == FakeCall( + method="GET", path=f"{branch_path}/{version_id}" + ) def test_branch_list(session, collection, branch_path): # Given branch_count = 5 branches_data = BranchDataFactory.create_batch(branch_count) - session.set_response({'response': branches_data}) + session.set_response({"response": branches_data}) # When branches = list(collection.list()) # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='GET', path=branch_path, params={'archived': False, 'page': 1, 'per_page': 20}) + assert session.last_call == FakeCall( + method="GET", + path=branch_path, + params={"archived": False, "page": 1, "per_page": 20}, + ) assert len(branches) == branch_count @@ -149,14 +156,16 @@ def test_branch_list_all(session, collection, branch_path): # Given branch_count = 5 branches_data = BranchDataFactory.create_batch(branch_count) - session.set_response({'response': branches_data}) + session.set_response({"response": branches_data}) # When - branches = list(collection.list_all()) + collection.list_all() # Then assert session.num_calls == 1 - assert session.last_call == 
FakeCall(method='GET', path=branch_path, params={'per_page': 20, 'page': 1}) + assert session.last_call == FakeCall( + method="GET", path=branch_path, params={"per_page": 20, "page": 1} + ) def test_branch_delete(session, collection, branch_path): @@ -164,11 +173,13 @@ def test_branch_delete(session, collection, branch_path): branch_id = uuid.uuid4() # When - response = collection.delete(branch_id) + collection.delete(branch_id) # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='DELETE', path=f'{branch_path}/{branch_id}') + assert session.last_call == FakeCall( + method="DELETE", path=f"{branch_path}/{branch_id}" + ) def test_branch_update(session, collection, branch_path): @@ -182,15 +193,13 @@ def test_branch_update(session, collection, branch_path): # Then assert session.num_calls == 1 expected_call = FakeCall( - method='PUT', - path=f'{branch_path}/{branch_data["id"]}', - json={ - 'name': branch_data['data']['name'] - }, - version='v2' + method="PUT", + path=f"{branch_path}/{branch_data['id']}", + json={"name": branch_data["data"]["name"]}, + version="v2", ) assert session.last_call == expected_call - assert updated_branch.name == branch_data['data']['name'] + assert updated_branch.name == branch_data["data"]["name"] def test_branch_get_design_workflows(collection): @@ -215,12 +224,15 @@ def test_branch_get_design_workflows_no_project_id(session): def test_branch_archive(session, collection, branch_path): # Given branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(archived=True)) - branch_id = branch_data['id'] - root_id = branch_data['metadata']['root_id'] - version = branch_data['metadata']['version'] + branch_id = branch_data["id"] + root_id = branch_data["metadata"]["root_id"] + version = branch_data["metadata"]["version"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_id), 'version': version + "page": 1, + "per_page": 1, + "root": str(root_id), + "version": version, } session.set_responses(branch_data_get_resp, branch_data) @@ -228,10 +240,10 @@ def test_branch_archive(session, collection, branch_path): archived_branch = collection.archive(root_id=root_id, version=version) # Then - expected_path = f'{branch_path}/{branch_id}/archive' + expected_path = f"{branch_path}/{branch_id}/archive" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='PUT', path=expected_path, json={}) + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="PUT", path=expected_path, json={}), ] assert archived_branch.archived is True @@ -239,11 +251,14 @@ def test_branch_archive(session, collection, branch_path): def test_archive_version_omitted(session, collection, branch_path): # Given branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(archived=True)) - branch_id = branch_data['id'] - root_id = branch_data['metadata']['root_id'] + branch_id = branch_data["id"] + root_id = branch_data["metadata"]["root_id"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_id), 'version': LATEST_VER + "page": 1, + "per_page": 1, + "root": str(root_id), + "version": LATEST_VER, } session.set_responses(branch_data_get_resp, branch_data) @@ -251,10 +266,10 @@ def test_archive_version_omitted(session, collection, branch_path): archived_branch = collection.archive(root_id=root_id) # Then - expected_path = 
f'{branch_path}/{branch_id}/archive' + expected_path = f"{branch_path}/{branch_id}/archive" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='PUT', path=expected_path, json={}) + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="PUT", path=expected_path, json={}), ] assert archived_branch.archived is True @@ -262,12 +277,15 @@ def test_archive_version_omitted(session, collection, branch_path): def test_branch_restore(session, collection, branch_path): # Given branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(archived=False)) - branch_id = branch_data['id'] - root_id = branch_data['metadata']['root_id'] - version = branch_data['metadata']['version'] + branch_id = branch_data["id"] + root_id = branch_data["metadata"]["root_id"] + version = branch_data["metadata"]["version"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_id), 'version': version + "page": 1, + "per_page": 1, + "root": str(root_id), + "version": version, } session.set_responses(branch_data_get_resp, branch_data) @@ -275,10 +293,10 @@ def test_branch_restore(session, collection, branch_path): restored_branch = collection.restore(root_id=root_id, version=version) # Then - expected_path = f'{branch_path}/{branch_id}/restore' + expected_path = f"{branch_path}/{branch_id}/restore" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='PUT', path=expected_path, json={}) + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="PUT", path=expected_path, json={}), ] assert restored_branch.archived is False @@ -286,11 +304,14 @@ def test_branch_restore(session, collection, branch_path): def test_restore_version_omitted(session, collection, branch_path): # Given branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(archived=False)) - branch_id = branch_data['id'] - root_id = branch_data['metadata']['root_id'] + branch_id = branch_data["id"] + root_id = branch_data["metadata"]["root_id"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_id), 'version': LATEST_VER + "page": 1, + "per_page": 1, + "root": str(root_id), + "version": LATEST_VER, } session.set_responses(branch_data_get_resp, branch_data) @@ -298,10 +319,10 @@ def test_restore_version_omitted(session, collection, branch_path): restored_branch = collection.restore(root_id=root_id) # Then - expected_path = f'{branch_path}/{branch_id}/restore' + expected_path = f"{branch_path}/{branch_id}/restore" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='PUT', path=expected_path, json={}) + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="PUT", path=expected_path, json={}), ] assert restored_branch.archived is False @@ -310,114 +331,156 @@ def test_branch_list_archived(session, collection, branch_path): # Given branch_count = 5 branches_data = BranchDataFactory.create_batch(branch_count) - session.set_response({'response': branches_data}) + session.set_response({"response": branches_data}) # When - branches = list(collection.list_archived()) + collection.list_archived() # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='GET', path=branch_path, 
params={'archived': True, 'per_page': 20, 'page': 1}) + assert session.last_call == FakeCall( + method="GET", + path=branch_path, + params={"archived": True, "per_page": 20, "page": 1}, + ) # Needed for coverage checks def test_branch_data_update_inits(): - data_updates = [DataVersionUpdate(current="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1", - latest="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2")] + data_updates = [ + DataVersionUpdate( + current="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1", + latest="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2", + ) + ] predictors = [PredictorRef("aa971886-d17c-43b4-b602-5af7b44fcd5a", 2)] branch_update = BranchDataUpdate(data_updates=data_updates, predictors=predictors) - assert branch_update.data_updates[0].current == "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1" + assert ( + branch_update.data_updates[0].current + == "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1" + ) def test_branch_data_updates(session, collection, branch_path): # Given branch_data = BranchDataFactory() - root_branch_id = branch_data['metadata']['root_id'] - branch_id = branch_data['id'] + root_branch_id = branch_data["metadata"]["root_id"] + branch_id = branch_data["id"] expected_data_updates = BranchDataUpdateFactory() branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_branch_id), 'version': branch_data['metadata']['version'] + "page": 1, + "per_page": 1, + "root": str(root_branch_id), + "version": branch_data["metadata"]["version"], } session.set_responses(branch_data_get_resp, expected_data_updates) # When - actual_data_updates = collection.data_updates(root_id=root_branch_id, version=branch_data['metadata']['version']) + actual_data_updates = collection.data_updates( + root_id=root_branch_id, version=branch_data["metadata"]["version"] + ) # Then - expected_path = f'{branch_path}/{branch_id}/data-version-updates-predictor' + expected_path = f"{branch_path}/{branch_id}/data-version-updates-predictor" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=expected_path, version='v2') + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="GET", path=expected_path, version="v2"), ] - assert expected_data_updates['data_updates'][0]['current'] == actual_data_updates.data_updates[0].current - assert expected_data_updates['data_updates'][0]['latest'] == actual_data_updates.data_updates[0].latest - assert expected_data_updates['predictors'][0]['predictor_id'] == str(actual_data_updates.predictors[0].uid) + assert ( + expected_data_updates["data_updates"][0]["current"] + == actual_data_updates.data_updates[0].current + ) + assert ( + expected_data_updates["data_updates"][0]["latest"] + == actual_data_updates.data_updates[0].latest + ) + assert expected_data_updates["predictors"][0]["predictor_id"] == str( + actual_data_updates.predictors[0].uid + ) def test_data_updates_version_omitted(session, collection, branch_path): # Given branch_data = BranchDataFactory() - root_branch_id = branch_data['metadata']['root_id'] - branch_id = branch_data['id'] + root_branch_id = branch_data["metadata"]["root_id"] + branch_id = branch_data["id"] expected_data_updates = BranchDataUpdateFactory() branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_branch_id), 'version': branch_data['metadata']['version'] + "page": 1, + "per_page": 1, + 
"root": str(root_branch_id), + "version": branch_data["metadata"]["version"], } session.set_responses(branch_data_get_resp, expected_data_updates) # When - actual_data_updates = collection.data_updates(root_id=root_branch_id, version=branch_data['metadata']['version']) + actual_data_updates = collection.data_updates( + root_id=root_branch_id, version=branch_data["metadata"]["version"] + ) # Then - expected_path = f'{branch_path}/{branch_id}/data-version-updates-predictor' + expected_path = f"{branch_path}/{branch_id}/data-version-updates-predictor" assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=expected_path, version='v2') + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="GET", path=expected_path, version="v2"), ] - assert expected_data_updates['data_updates'][0]['current'] == actual_data_updates.data_updates[0].current - assert expected_data_updates['data_updates'][0]['latest'] == actual_data_updates.data_updates[0].latest - assert expected_data_updates['predictors'][0]['predictor_id'] == str(actual_data_updates.predictors[0].uid) - - + assert ( + expected_data_updates["data_updates"][0]["current"] + == actual_data_updates.data_updates[0].current + ) + assert ( + expected_data_updates["data_updates"][0]["latest"] + == actual_data_updates.data_updates[0].latest + ) + assert expected_data_updates["predictors"][0]["predictor_id"] == str( + actual_data_updates.predictors[0].uid + ) def test_branch_next_version(session, collection, branch_path): # Given branch_data = BranchDataFactory() - root_branch_id = branch_data['metadata']['root_id'] + root_branch_id = branch_data["metadata"]["root_id"] session.set_response(branch_data) - data_updates = [DataVersionUpdate(current="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1", - latest="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2")] + data_updates = [ + DataVersionUpdate( + current="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1", + latest="gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2", + ) + ] predictors = [PredictorRef("aa971886-d17c-43b4-b602-5af7b44fcd5a", 2)] req = NextBranchVersionRequest(data_updates=data_updates, use_predictors=predictors) # When - branchv2 = collection.next_version(root_id=root_branch_id, branch_instructions=req, retrain_models=False) + branchv2 = collection.next_version( + root_id=root_branch_id, branch_instructions=req, retrain_models=False + ) # Then - expected_path = f'{branch_path}/next-version-predictor' - expected_call = FakeCall(method='POST', - path=expected_path, - params={'root': str(root_branch_id), - 'retrain_models': False}, - json={ - 'data_updates': [ - { - 'current': 'gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1', - 'latest': 'gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2', - 'type': 'DataVersionUpdate' - } - ], - 'use_predictors': [ - { - 'predictor_id': 'aa971886-d17c-43b4-b602-5af7b44fcd5a', - 'predictor_version': 2 - } - ] - }, - version='v2') + expected_path = f"{branch_path}/next-version-predictor" + expected_call = FakeCall( + method="POST", + path=expected_path, + params={"root": str(root_branch_id), "retrain_models": False}, + json={ + "data_updates": [ + { + "current": "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::1", + "latest": "gemd::16f91e7e-0214-4866-8d7f-a4d5c2125d2b::2", + "type": "DataVersionUpdate", + } + ], + "use_predictors": [ + { + "predictor_id": "aa971886-d17c-43b4-b602-5af7b44fcd5a", + "predictor_version": 2, + } + ], + }, + version="v2", + ) assert 
session.num_calls == 1 assert session.last_call == expected_call assert str(branchv2.root_id) == root_branch_id @@ -429,43 +492,57 @@ def test_branch_data_updates_normal(session, collection, branch_path): root_branch_id = branch_data["metadata"]["root_id"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_branch_id), 'version': branch_data['metadata']['version'] + "page": 1, + "per_page": 1, + "root": str(root_branch_id), + "version": branch_data["metadata"]["version"], } session.set_response(branch_data_get_resp) - branch = collection.get(root_id=root_branch_id, version=branch_data['metadata']['version']) + branch = collection.get( + root_id=root_branch_id, version=branch_data["metadata"]["version"] + ) data_updates = BranchDataUpdateFactory() - v2branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(root_id=root_branch_id)) + v2branch_data = BranchDataFactory( + metadata=BranchMetadataFieldFactory(root_id=root_branch_id) + ) session.set_responses(branch_data_get_resp, data_updates, v2branch_data) v2branch = collection.update_data(root_id=branch.root_id, version=branch.version) # Then - next_version_call = FakeCall(method='POST', - path=f'{branch_path}/next-version-predictor', - params={'root': str(root_branch_id), 'retrain_models': False}, - json={ - 'data_updates': [ - { - 'current': data_updates['data_updates'][0]['current'], - 'latest': data_updates['data_updates'][0]['latest'], - 'type': 'DataVersionUpdate' - } - ], - 'use_predictors': [ - { - 'predictor_id': data_updates['predictors'][0]['predictor_id'], - 'predictor_version': data_updates['predictors'][0]['predictor_version'] - } - ] - }, - version='v2') + next_version_call = FakeCall( + method="POST", + path=f"{branch_path}/next-version-predictor", + params={"root": str(root_branch_id), "retrain_models": False}, + json={ + "data_updates": [ + { + "current": data_updates["data_updates"][0]["current"], + "latest": data_updates["data_updates"][0]["latest"], + "type": "DataVersionUpdate", + } + ], + "use_predictors": [ + { + "predictor_id": data_updates["predictors"][0]["predictor_id"], + "predictor_version": data_updates["predictors"][0][ + "predictor_version" + ], + } + ], + }, + version="v2", + ) assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=f'{branch_path}/{branch_data["id"]}/data-version-updates-predictor'), - next_version_call + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall( + method="GET", + path=f"{branch_path}/{branch_data['id']}/data-version-updates-predictor", + ), + next_version_call, ] assert str(v2branch.root_id) == root_branch_id @@ -473,41 +550,57 @@ def test_branch_data_updates_normal(session, collection, branch_path): def test_branch_data_updates_latest(session, collection, branch_path): # Given branch_data = BranchDataFactory() - root_branch_id = branch_data['metadata']['root_id'] + root_branch_id = branch_data["metadata"]["root_id"] branch_data_get_resp = {"response": [branch_data]} branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(root_branch_id), 'version': branch_data['metadata']['version'] + "page": 1, + "per_page": 1, + "root": str(root_branch_id), + "version": branch_data["metadata"]["version"], } session.set_response(branch_data_get_resp) 
- branch = collection.get(root_id=root_branch_id, version=branch_data['metadata']['version']) + branch = collection.get( + root_id=root_branch_id, version=branch_data["metadata"]["version"] + ) data_updates = BranchDataUpdateFactory() - v2branch_data = BranchDataFactory(metadata=BranchMetadataFieldFactory(root_id=root_branch_id)) + v2branch_data = BranchDataFactory( + metadata=BranchMetadataFieldFactory(root_id=root_branch_id) + ) session.set_responses(branch_data_get_resp, data_updates, v2branch_data) - v2branch = collection.update_data(root_id=branch.root_id, version=branch.version, use_existing=False, retrain_models=True) + v2branch = collection.update_data( + root_id=branch.root_id, + version=branch.version, + use_existing=False, + retrain_models=True, + ) # Then - next_version_call = FakeCall(method='POST', - path=f'{branch_path}/next-version-predictor', - params={'root': str(root_branch_id), - 'retrain_models': True}, - json={ - 'data_updates': [ - { - 'current': data_updates['data_updates'][0]['current'], - 'latest': data_updates['data_updates'][0]['latest'], - 'type': 'DataVersionUpdate' - } - ], - 'use_predictors': [] - }, - version='v2') + next_version_call = FakeCall( + method="POST", + path=f"{branch_path}/next-version-predictor", + params={"root": str(root_branch_id), "retrain_models": True}, + json={ + "data_updates": [ + { + "current": data_updates["data_updates"][0]["current"], + "latest": data_updates["data_updates"][0]["latest"], + "type": "DataVersionUpdate", + } + ], + "use_predictors": [], + }, + version="v2", + ) assert session.calls == [ - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=branch_path, params=branch_data_get_params), - FakeCall(method='GET', path=f'{branch_path}/{branch_data["id"]}/data-version-updates-predictor'), - next_version_call + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall(method="GET", path=branch_path, params=branch_data_get_params), + FakeCall( + method="GET", + path=f"{branch_path}/{branch_data['id']}/data-version-updates-predictor", + ), + next_version_call, ] assert str(v2branch.root_id) == root_branch_id @@ -518,7 +611,10 @@ def test_branch_data_updates_nochange(session, collection, branch_path): branch_data_get_resp = {"response": [branch_data]} session.set_response(branch_data_get_resp) - branch = collection.get(root_id=branch_data['metadata']['root_id'], version=branch_data['metadata']['version']) + branch = collection.get( + root_id=branch_data["metadata"]["root_id"], + version=branch_data["metadata"]["version"], + ) data_updates = BranchDataUpdateFactory(data_updates=[], predictors=[]) session.set_responses(branch_data_get_resp, data_updates) @@ -529,30 +625,48 @@ def test_branch_data_updates_nochange(session, collection, branch_path): def test_experiment_datasource(session, collection): # Given - erds_path = f'projects/{collection.project_id}/candidate-experiment-datasources' + erds_path = f"projects/{collection.project_id}/candidate-experiment-datasources" erds = ExperimentDataSourceDataFactory() branch = collection.build(BranchDataFactory()) - session.set_response({'response': [erds]}) + session.set_response({"response": [erds]}) # When / Then assert branch.experiment_datasource is not None assert session.calls == [ - FakeCall(method='GET', path=erds_path, params={'branch': str(branch.uid), 'version': LATEST_VER, 'per_page': 100, 'page': 1}) + FakeCall( + method="GET", + path=erds_path, + params={ + "branch": str(branch.uid), + "version": 
LATEST_VER, + "per_page": 100, + "page": 1, + }, + ) ] def test_no_experiment_datasource(session, collection): # Given - erds_path = f'projects/{collection.project_id}/candidate-experiment-datasources' + erds_path = f"projects/{collection.project_id}/candidate-experiment-datasources" branch = collection.build(BranchDataFactory()) - session.set_response({'response': []}) + session.set_response({"response": []}) # When / Then assert branch.experiment_datasource is None assert session.calls == [ - FakeCall(method='GET', path=erds_path, params={'branch': str(branch.uid), 'version': LATEST_VER, 'per_page': 100, 'page': 1}) + FakeCall( + method="GET", + path=erds_path, + params={ + "branch": str(branch.uid), + "version": LATEST_VER, + "per_page": 100, + "page": 1, + }, + ) ] diff --git a/tests/resources/test_catalyst.py b/tests/resources/test_catalyst.py index 854d58405..fdc04425d 100644 --- a/tests/resources/test_catalyst.py +++ b/tests/resources/test_catalyst.py @@ -3,13 +3,13 @@ import pytest from citrine.resources.catalyst import CatalystResource -from citrine.resources.user import User -from citrine.informatics.catalyst.assistant import (AssistantResponse, - AssistantResponseMessage, - AssistantResponseConfig, - AssistantResponseUnsupported, - AssistantResponseInputErrors, - AssistantResponseExecError) +from citrine.informatics.catalyst.assistant import ( + AssistantResponseMessage, + AssistantResponseConfig, + AssistantResponseUnsupported, + AssistantResponseInputErrors, + AssistantResponseExecError, +) from citrine.informatics.catalyst.insights import InsightsResponse from citrine.informatics.predictors.graph_predictor import GraphPredictor from tests.utils.factories import UserDataFactory @@ -19,129 +19,122 @@ @pytest.fixture def assistant_message_data(): return { - "type": "message", - "data": { - "message": "We found the following available variables that may be relevant:\n * AtomicPolarizability for MolecularStructure" - } + "type": "message", + "data": { + "message": "We found the following available variables that may be relevant:\n * AtomicPolarizability for MolecularStructure" + }, } @pytest.fixture def assistant_config_data(): - return { - "type": "modified-config", - "data": { - "config": { - "type": "Graph", - "name": "Graph Model for 6 outputs", - "description": "Default Graph Model generated from data inspection.", - "predictors": [ - { - "type": "MeanProperty", - "name": "Mean properties for all ingredients", - "description": "Mean ingredient properties for all atomic ingredients. 
Missing property data is imputed from the training set.", - "input": { - "type": "Formulation", - "descriptor_key": "Flat Formulation" - }, - "properties": [ - { - "type": "Real", - "descriptor_key": "AtomicPolarizability for MolecularStructure", - "units": "", - "lower_bound": 0, - "upper_bound": 1000000000 - }, - { - "type": "Real", - "descriptor_key": "Density", - "units": "gram / centimeter ** 3", - "lower_bound": 0, - "upper_bound": 100 - } - ], - "p": 1, - "impute_properties": True, - "training_data": [], - "default_properties": {}, - "label": None - }, - { - "name": "", - "description": "", - "expression": "MixTime*Temperature", - "output": { - "descriptor_key": "MixTime_Temperature", - "lower_bound": -1.7976931348623157e+308, - "upper_bound": 1.7976931348623157e+308, - "units": "", - "type": "Real" - }, - "aliases": { - "MixTime": { - "descriptor_key": "Mix~Time", - "lower_bound": 0.0, - "upper_bound": 10000.0, - "units": "second", - "type": "Real" - }, - "Temperature": { - "descriptor_key": "Mix~Temperature", - "lower_bound": 0.0, - "upper_bound": 1000.0000000000001, - "units": "degree_Celsius", - "type": "Real" - } - }, - "type": "AnalyticExpression" - } - ] - } - } - } + return { + "type": "modified-config", + "data": { + "config": { + "type": "Graph", + "name": "Graph Model for 6 outputs", + "description": "Default Graph Model generated from data inspection.", + "predictors": [ + { + "type": "MeanProperty", + "name": "Mean properties for all ingredients", + "description": "Mean ingredient properties for all atomic ingredients. Missing property data is imputed from the training set.", + "input": { + "type": "Formulation", + "descriptor_key": "Flat Formulation", + }, + "properties": [ + { + "type": "Real", + "descriptor_key": "AtomicPolarizability for MolecularStructure", + "units": "", + "lower_bound": 0, + "upper_bound": 1000000000, + }, + { + "type": "Real", + "descriptor_key": "Density", + "units": "gram / centimeter ** 3", + "lower_bound": 0, + "upper_bound": 100, + }, + ], + "p": 1, + "impute_properties": True, + "training_data": [], + "default_properties": {}, + "label": None, + }, + { + "name": "", + "description": "", + "expression": "MixTime*Temperature", + "output": { + "descriptor_key": "MixTime_Temperature", + "lower_bound": -1.7976931348623157e308, + "upper_bound": 1.7976931348623157e308, + "units": "", + "type": "Real", + }, + "aliases": { + "MixTime": { + "descriptor_key": "Mix~Time", + "lower_bound": 0.0, + "upper_bound": 10000.0, + "units": "second", + "type": "Real", + }, + "Temperature": { + "descriptor_key": "Mix~Temperature", + "lower_bound": 0.0, + "upper_bound": 1000.0000000000001, + "units": "degree_Celsius", + "type": "Real", + }, + }, + "type": "AnalyticExpression", + }, + ], + } + }, + } + @pytest.fixture def assistant_unsupported_data(): return { - "type": "unsupported", - "data": { - "message": "Sorry, adding a featurizer is not currently supported. Please try again." - } + "type": "unsupported", + "data": { + "message": "Sorry, adding a featurizer is not currently supported. Please try again." 
+ }, } @pytest.fixture def assistant_input_error_data(): return { - "type": "input-error", - "data": { - "request_dict": { - "question": "Is polarizability being considered?", - "config": "hello", - "language_model": "gpt-4-16k" + "type": "input-error", + "data": { + "request_dict": { + "question": "Is polarizability being considered?", + "config": "hello", + "language_model": "gpt-4-16k", + }, + "errors": [ + {"field": "config", "error": "Input should be a valid dictionary"}, + { + "field": "language_model", + "error": "Input should be 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4' or 'gpt-4-32k'", + }, + ], }, - "errors": [ - { - "field": "config", - "error": "Input should be a valid dictionary" - }, - { - "field": "language_model", - "error": "Input should be 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4' or 'gpt-4-32k'" - } - ] - } } @pytest.fixture def assistant_exec_error_data(): - return { - "type": "exec-error", - "data": { - "error": "An internal error occurred." - } - } + return {"type": "exec-error", "data": {"error": "An internal error occurred."}} @pytest.fixture @@ -190,14 +183,23 @@ def test_assistant_external_user(session, catalyst, external_user_data): catalyst.assistant("Test query", predictor=assistant_predictor) -def test_assistant_invalid_response(session, catalyst, internal_user_data, assistant_message_data, assistant_predictor): +def test_assistant_invalid_response( + session, catalyst, internal_user_data, assistant_message_data, assistant_predictor +): session.set_responses(internal_user_data, {**assistant_message_data, "type": "foo"}) with pytest.raises(ValueError): catalyst.assistant("Test query", predictor=assistant_predictor) -def test_assistant_message(session, catalyst, internal_user_data, assistant_message_data, assistant_predictor, assistant_predictor_data): +def test_assistant_message( + session, + catalyst, + internal_user_data, + assistant_message_data, + assistant_predictor, + assistant_predictor_data, +): session.set_responses(internal_user_data, assistant_message_data) query = "Test query" @@ -207,11 +209,13 @@ def test_assistant_message(session, catalyst, internal_user_data, assistant_mess "question": query, "config": assistant_predictor_data["data"]["instance"], "temperature": 0.0, - "language_model": "gpt-4" + "language_model": "gpt-4", } expected_calls = [ FakeCall(method="GET", path="/users/me"), - FakeCall(method="POST", path="/catalyst/assistant", json=expected_assistant_request) + FakeCall( + method="POST", path="/catalyst/assistant", json=expected_assistant_request + ), ] assert isinstance(resp, AssistantResponseMessage) @@ -219,7 +223,14 @@ def test_assistant_message(session, catalyst, internal_user_data, assistant_mess assert resp.message == assistant_message_data["data"]["message"] -def test_assistant_config(session, catalyst, internal_user_data, assistant_config_data, assistant_predictor, assistant_predictor_data): +def test_assistant_config( + session, + catalyst, + internal_user_data, + assistant_config_data, + assistant_predictor, + assistant_predictor_data, +): assistant_config_data_orig = deepcopy(assistant_config_data) session.set_responses(internal_user_data, assistant_config_data) @@ -231,19 +242,33 @@ def test_assistant_config(session, catalyst, internal_user_data, assistant_confi "question": query, "config": assistant_predictor_data["data"]["instance"], "temperature": 0.0, - "language_model": "gpt-4" + "language_model": "gpt-4", } expected_calls = [ FakeCall(method="GET", path="/users/me"), - FakeCall(method="POST", 
path="/catalyst/assistant", json=expected_assistant_request) + FakeCall( + method="POST", path="/catalyst/assistant", json=expected_assistant_request + ), ] assert isinstance(resp, AssistantResponseConfig) assert session.calls == expected_calls - assert resp.predictor.dump() == GraphPredictor.build(GraphPredictor.wrap_instance(assistant_config_data_orig["data"]["config"])).dump() - - -def test_assistant_unsupported(session, catalyst, internal_user_data, assistant_unsupported_data, assistant_predictor, assistant_predictor_data): + assert ( + resp.predictor.dump() + == GraphPredictor.build( + GraphPredictor.wrap_instance(assistant_config_data_orig["data"]["config"]) + ).dump() + ) + + +def test_assistant_unsupported( + session, + catalyst, + internal_user_data, + assistant_unsupported_data, + assistant_predictor, + assistant_predictor_data, +): session.set_responses(internal_user_data, assistant_unsupported_data) query = "Test query" @@ -253,11 +278,13 @@ def test_assistant_unsupported(session, catalyst, internal_user_data, assistant_ "question": query, "config": assistant_predictor_data["data"]["instance"], "temperature": 0.0, - "language_model": "gpt-4" + "language_model": "gpt-4", } expected_calls = [ FakeCall(method="GET", path="/users/me"), - FakeCall(method="POST", path="/catalyst/assistant", json=expected_assistant_request) + FakeCall( + method="POST", path="/catalyst/assistant", json=expected_assistant_request + ), ] assert isinstance(resp, AssistantResponseUnsupported) @@ -265,7 +292,14 @@ def test_assistant_unsupported(session, catalyst, internal_user_data, assistant_ assert resp.message == assistant_unsupported_data["data"]["message"] -def test_assistant_input_error(session, catalyst, internal_user_data, assistant_input_error_data, assistant_predictor, assistant_predictor_data): +def test_assistant_input_error( + session, + catalyst, + internal_user_data, + assistant_input_error_data, + assistant_predictor, + assistant_predictor_data, +): session.set_responses(internal_user_data, assistant_input_error_data) query = "Test query" @@ -275,11 +309,13 @@ def test_assistant_input_error(session, catalyst, internal_user_data, assistant_ "question": query, "config": assistant_predictor_data["data"]["instance"], "temperature": 0.0, - "language_model": "gpt-4" + "language_model": "gpt-4", } expected_calls = [ FakeCall(method="GET", path="/users/me"), - FakeCall(method="POST", path="/catalyst/assistant", json=expected_assistant_request) + FakeCall( + method="POST", path="/catalyst/assistant", json=expected_assistant_request + ), ] assert isinstance(resp, AssistantResponseInputErrors) @@ -287,7 +323,14 @@ def test_assistant_input_error(session, catalyst, internal_user_data, assistant_ assert resp.dump()["data"]["errors"] == assistant_input_error_data["data"]["errors"] -def test_assistant_exec_error(session, catalyst, internal_user_data, assistant_exec_error_data, assistant_predictor, assistant_predictor_data): +def test_assistant_exec_error( + session, + catalyst, + internal_user_data, + assistant_exec_error_data, + assistant_predictor, + assistant_predictor_data, +): session.set_responses(internal_user_data, assistant_exec_error_data) query = "Test query" @@ -297,11 +340,13 @@ def test_assistant_exec_error(session, catalyst, internal_user_data, assistant_e "question": query, "config": assistant_predictor_data["data"]["instance"], "temperature": 0.0, - "language_model": "gpt-4" + "language_model": "gpt-4", } expected_calls = [ FakeCall(method="GET", path="/users/me"), - FakeCall(method="POST", 
path="/catalyst/assistant", json=expected_assistant_request) + FakeCall( + method="POST", path="/catalyst/assistant", json=expected_assistant_request + ), ] assert isinstance(resp, AssistantResponseExecError) diff --git a/tests/resources/test_data_concepts.py b/tests/resources/test_data_concepts.py index d9c359550..f60e8618a 100644 --- a/tests/resources/test_data_concepts.py +++ b/tests/resources/test_data_concepts.py @@ -3,22 +3,26 @@ import pytest -from gemd.entity.dict_serializable import DictSerializable -from gemd.entity.template import ProcessTemplate as GEMDTemplate from gemd.entity.link_by_uid import LinkByUID from citrine.resources.audit_info import AuditInfo -from citrine.resources.data_concepts import DataConcepts, _make_link_by_uid, CITRINE_SCOPE, DataConceptsCollection +from citrine.resources.data_concepts import ( + DataConcepts, + _make_link_by_uid, + CITRINE_SCOPE, +) from citrine.resources.process_run import ProcessRun from citrine.resources.process_spec import ProcessSpec, ProcessSpecCollection from tests.utils.session import FakeCall, FakeSession -def run_noop_gemd_relation_search_test(search_for, search_with, collection, search_fn, per_page=100): +def run_noop_gemd_relation_search_test( + search_for, search_with, collection, search_fn, per_page=100 +): """Test that relation searches hit the correct endpoint.""" - collection.session.set_response({'contents': []}) - test_id = 'foo-id' - test_scope = 'foo-scope' + collection.session.set_response({"contents": []}) + test_id = "foo-id" + test_scope = "foo-scope" result = search_fn(LinkByUID(id=test_id, scope=test_scope)) if isinstance(result, Iterator): # evaluate iterator to make calls happen @@ -26,57 +30,77 @@ def run_noop_gemd_relation_search_test(search_for, search_with, collection, sear assert collection.session.num_calls == 1 assert collection.session.last_call == FakeCall( method="GET", - path="teams/{}/{}/{}/{}/{}".format(collection.team_id, search_with, test_scope, test_id, search_for), - params={"dataset_id": str(collection.dataset_id), "forward": True, "ascending": True, "per_page": per_page} + path="teams/{}/{}/{}/{}/{}".format( + collection.team_id, search_with, test_scope, test_id, search_for + ), + params={ + "dataset_id": str(collection.dataset_id), + "forward": True, + "ascending": True, + "per_page": per_page, + }, ) + def test_deprication_of_positional_arguments(): session = FakeSession() - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) with pytest.deprecated_call(): ProcessSpecCollection(uuid4(), uuid4(), session) with pytest.raises(TypeError): ProcessSpecCollection(project_id=uuid4(), dataset_id=uuid4(), session=None) + def test_assign_audit_info(): """Test that audit_info can be injected with build but not set""" - assert ProcessSpec("Spec with no audit info").audit_info is None, \ + assert ProcessSpec("Spec with no audit info").audit_info is None, ( "Audit info should be None by default" + ) - audit_info_dict = {'created_by': str(uuid4()), 'created_at': 1560033807392} + audit_info_dict = {"created_by": str(uuid4()), "created_at": 1560033807392} audit_info_obj = AuditInfo.build(audit_info_dict) - sample_object = ProcessSpec.build({ - 'type': 'process_spec', - 'name': "A process spec", - "audit_info": audit_info_dict - }) - assert sample_object.audit_info == audit_info_obj, "Audit info should 
be built from a dict" + sample_object = ProcessSpec.build( + { + "type": "process_spec", + "name": "A process spec", + "audit_info": audit_info_dict, + } + ) + assert sample_object.audit_info == audit_info_obj, ( + "Audit info should be built from a dict" + ) - another_object = ProcessSpec.build({ - 'type': 'process_spec', 'name': "A process spec", "audit_info": audit_info_obj - }) - assert another_object.audit_info == audit_info_obj, "Audit info should be built from an obj" + another_object = ProcessSpec.build( + {"type": "process_spec", "name": "A process spec", "audit_info": audit_info_obj} + ) + assert another_object.audit_info == audit_info_obj, ( + "Audit info should be built from an obj" + ) with pytest.raises(AttributeError, match=r"can't set attribute|has no setter"): sample_object.audit_info = None with pytest.raises(ValueError, match=r"is not one of valid types.*audit_info"): - ProcessSpec.build({ - 'type': 'process_spec', - 'name': "A process spec", - "audit_info": "Created by me, yesterday" - }) + ProcessSpec.build( + { + "type": "process_spec", + "name": "A process spec", + "audit_info": "Created by me, yesterday", + } + ) def test_make_link_by_uid(): """Test that _make_link_by_uid convenience method works.""" uid = uuid4() expected_link = LinkByUID(scope=CITRINE_SCOPE, id=str(uid)) - spec = ProcessSpec("spec", uids={"custom scope": "custom id", CITRINE_SCOPE: str(uid)}) + spec = ProcessSpec( + "spec", uids={"custom scope": "custom id", CITRINE_SCOPE: str(uid)} + ) assert _make_link_by_uid(spec) == expected_link assert _make_link_by_uid(expected_link) == expected_link assert _make_link_by_uid(uid) == expected_link @@ -84,7 +108,9 @@ def test_make_link_by_uid(): # If there's no Citrine ID, use an available ID no_citrine_id = ProcessSpec("spec", uids={"custom scope": "custom id"}) - assert _make_link_by_uid(no_citrine_id) == LinkByUID(scope="custom scope", id="custom id") + assert _make_link_by_uid(no_citrine_id) == LinkByUID( + scope="custom scope", id="custom id" + ) with pytest.raises(ValueError): _make_link_by_uid(ProcessSpec("spec")) # no ids diff --git a/tests/resources/test_dataset.py b/tests/resources/test_dataset.py index 83dad4842..955a4df9f 100644 --- a/tests/resources/test_dataset.py +++ b/tests/resources/test_dataset.py @@ -1,31 +1,49 @@ from collections import defaultdict -from os.path import basename from uuid import UUID, uuid4 import pytest +from gemd.demo.cake import get_demo_scope, get_template_scope, make_cake from gemd.entity.bounds.integer_bounds import IntegerBounds -from gemd.demo.cake import make_cake, get_demo_scope, get_template_scope -from gemd.util import recursive_flatmap, flatten +from gemd.util import flatten, recursive_flatmap from citrine.exceptions import NotFound -from citrine.resources.condition_template import ConditionTemplateCollection, ConditionTemplate +from citrine.resources.condition_template import ( + ConditionTemplate, + ConditionTemplateCollection, +) from citrine.resources.dataset import DatasetCollection -from citrine.resources.gemd_resource import GEMDResourceCollection -from citrine.resources.material_run import MaterialRunCollection, MaterialRun -from citrine.resources.material_spec import MaterialSpecCollection, MaterialSpec -from citrine.resources.material_template import MaterialTemplateCollection, MaterialTemplate -from citrine.resources.measurement_run import MeasurementRunCollection, MeasurementRun -from citrine.resources.measurement_spec import MeasurementSpec, MeasurementSpecCollection -from 
citrine.resources.measurement_template import MeasurementTemplate, \ - MeasurementTemplateCollection -from citrine.resources.parameter_template import ParameterTemplateCollection, ParameterTemplate -from citrine.resources.process_run import ProcessRunCollection, ProcessRun -from citrine.resources.process_spec import ProcessSpecCollection, ProcessSpec -from citrine.resources.process_template import ProcessTemplateCollection, ProcessTemplate -from citrine.resources.property_template import PropertyTemplateCollection, PropertyTemplate -from tests.utils.factories import DatasetDataFactory, DatasetFactory from citrine.resources.delete import _async_gemd_batch_delete -from tests.utils.session import FakeSession, FakePaginatedSession, FakeCall +from citrine.resources.material_run import MaterialRun, MaterialRunCollection +from citrine.resources.material_spec import MaterialSpec, MaterialSpecCollection +from citrine.resources.material_template import ( + MaterialTemplate, + MaterialTemplateCollection, +) +from citrine.resources.measurement_run import MeasurementRun, MeasurementRunCollection +from citrine.resources.measurement_spec import ( + MeasurementSpec, + MeasurementSpecCollection, +) +from citrine.resources.measurement_template import ( + MeasurementTemplate, + MeasurementTemplateCollection, +) +from citrine.resources.parameter_template import ( + ParameterTemplate, + ParameterTemplateCollection, +) +from citrine.resources.process_run import ProcessRun, ProcessRunCollection +from citrine.resources.process_spec import ProcessSpec, ProcessSpecCollection +from citrine.resources.process_template import ( + ProcessTemplate, + ProcessTemplateCollection, +) +from citrine.resources.property_template import ( + PropertyTemplate, + PropertyTemplateCollection, +) +from tests.utils.factories import DatasetDataFactory, DatasetFactory +from tests.utils.session import FakeCall, FakePaginatedSession, FakeSession @pytest.fixture @@ -41,51 +59,55 @@ def paginated_session() -> FakePaginatedSession: @pytest.fixture def collection(session) -> DatasetCollection: return DatasetCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=session ) @pytest.fixture def paginated_collection(paginated_session) -> DatasetCollection: return DatasetCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=paginated_session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=paginated_session ) @pytest.fixture(scope="function") def dataset(): - dataset = DatasetFactory(name='Test Dataset') - dataset.team_id = UUID('6b608f78-e341-422c-8076-35adc8828545') + dataset = DatasetFactory(name="Test Dataset") + dataset.team_id = UUID("6b608f78-e341-422c-8076-35adc8828545") dataset.uid = UUID("503d7bf6-8e2d-4d29-88af-257af0d4fe4a") dataset.session = FakeSession() return dataset + def test_deprecation_of_positional_arguments(session): - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) with pytest.deprecated_call(): - dset = DatasetCollection(uuid4(), session) + DatasetCollection(uuid4(), session) with pytest.raises(TypeError): - dset = DatasetCollection(project_id=uuid4(), session=None) + DatasetCollection(project_id=uuid4(), session=None) + def test_register_dataset(collection, session): # Given - name = 
'Test Dataset' - summary = 'testing summary' - description = 'testing description' - session.set_response(DatasetDataFactory(name=name, summary=summary, description=description)) + name = "Test Dataset" + summary = "testing summary" + description = "testing description" + session.set_response( + DatasetDataFactory(name=name, summary=summary, description=description) + ) # When - dataset = collection.register(DatasetFactory(name=name, summary=summary, description=description)) + dataset = collection.register( + DatasetFactory(name=name, summary=summary, description=description) + ) expected_call = FakeCall( - method='POST', - path='teams/{}/datasets'.format(collection.team_id), - json={'name': name, 'summary': summary, 'description': description} + method="POST", + path="teams/{}/datasets".format(collection.team_id), + json={"name": name, "summary": summary, "description": description}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -94,20 +116,33 @@ def test_register_dataset(collection, session): def test_register_dataset_with_idempotent_put(collection, session): # Given - name = 'Test Dataset' - summary = 'testing summary' - description = 'testing description' - unique_name = 'foo' - session.set_response(DatasetDataFactory(name=name, summary=summary, description=description, unique_name=unique_name)) + name = "Test Dataset" + summary = "testing summary" + description = "testing description" + unique_name = "foo" + session.set_response( + DatasetDataFactory( + name=name, summary=summary, description=description, unique_name=unique_name + ) + ) # When session.use_idempotent_dataset_put = True - dataset = collection.register(DatasetFactory(name=name, summary=summary, description=description, unique_name=unique_name)) + dataset = collection.register( + DatasetFactory( + name=name, summary=summary, description=description, unique_name=unique_name + ) + ) expected_call = FakeCall( - method='PUT', - path='teams/{}/datasets'.format(collection.team_id), - json={'name': name, 'summary': summary, 'description': description, 'unique_name': unique_name} + method="PUT", + path="teams/{}/datasets".format(collection.team_id), + json={ + "name": name, + "summary": summary, + "description": description, + "unique_name": unique_name, + }, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -116,24 +151,29 @@ def test_register_dataset_with_idempotent_put(collection, session): def test_register_dataset_with_existing_id(collection, session): # Given - name = 'Test Dataset' - summary = 'testing summary' - description = 'testing description' - session.set_response(DatasetDataFactory(name=name, summary=summary, description=description)) + name = "Test Dataset" + summary = "testing summary" + description = "testing description" + session.set_response( + DatasetDataFactory(name=name, summary=summary, description=description) + ) # When - dataset = DatasetFactory(name=name, summary=summary, - description=description) + dataset = DatasetFactory(name=name, summary=summary, description=description) - ds_uid = UUID('cafebeef-e341-422c-8076-35adc8828545') + ds_uid = UUID("cafebeef-e341-422c-8076-35adc8828545") dataset.uid = ds_uid dataset = collection.register(dataset) expected_call = FakeCall( - method='PUT', - path='teams/{}/datasets/{}'.format(collection.team_id, ds_uid), - json={'name': name, 'summary': summary, 'description': description, - 'id': str(ds_uid)} + method="PUT", + path="teams/{}/datasets/{}".format(collection.team_id, ds_uid), + json={ + "name": 
name, + "summary": summary, + "description": description, + "id": str(ds_uid), + }, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -142,7 +182,7 @@ def test_register_dataset_with_existing_id(collection, session): def test_get_by_unique_name_with_single_result(collection, session): # Given - name = 'Test Dataset' + name = "Test Dataset" unique_name = "foo" session.set_response([DatasetDataFactory(name=name, unique_name=unique_name)]) @@ -151,8 +191,8 @@ def test_get_by_unique_name_with_single_result(collection, session): # Then expected_call = FakeCall( - method='GET', - path='teams/{}/datasets?unique_name={}'.format(collection.team_id, unique_name) + method="GET", + path="teams/{}/datasets?unique_name={}".format(collection.team_id, unique_name), ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -176,7 +216,6 @@ def test_get_by_unique_name_no_unique_name_present(collection, session): def test_get_by_unique_name_multiple_results(collection, session): - # This really shouldn't happen # Given @@ -197,15 +236,21 @@ def test_list_datasets(paginated_collection, paginated_session): # Then assert 3 == paginated_session.num_calls - expected_first_call = FakeCall(method='GET', path='teams/{}/datasets'.format(paginated_collection.team_id), - params={'per_page': 20, 'page': 1}) - expected_last_call = FakeCall(method='GET', path='teams/{}/datasets'.format(paginated_collection.team_id), - params={'page': 3, 'per_page': 20}) + expected_first_call = FakeCall( + method="GET", + path="teams/{}/datasets".format(paginated_collection.team_id), + params={"per_page": 20, "page": 1}, + ) + expected_last_call = FakeCall( + method="GET", + path="teams/{}/datasets".format(paginated_collection.team_id), + params={"page": 3, "per_page": 20}, + ) assert expected_first_call == paginated_session.calls[0] assert expected_last_call == paginated_session.last_call assert 50 == len(datasets) - expected_uids = [d['id'] for d in datasets_data] + expected_uids = [d["id"] for d in datasets_data] dataset_ids = [str(d.uid) for d in datasets] assert dataset_ids == expected_uids @@ -223,15 +268,21 @@ def test_list_datasets_infinite_loop_detect(paginated_collection, paginated_sess # Then assert 2 == paginated_session.num_calls # duplicate UID detected on the second call - expected_first_call = FakeCall(method='GET', path='teams/{}/datasets'.format(paginated_collection.team_id), - params={'per_page': batch_size, 'page': 1}) - expected_last_call = FakeCall(method='GET', path='teams/{}/datasets'.format(paginated_collection.team_id), - params={'page': 2, 'per_page': batch_size}) + expected_first_call = FakeCall( + method="GET", + path="teams/{}/datasets".format(paginated_collection.team_id), + params={"per_page": batch_size, "page": 1}, + ) + expected_last_call = FakeCall( + method="GET", + path="teams/{}/datasets".format(paginated_collection.team_id), + params={"page": 2, "per_page": batch_size}, + ) assert expected_first_call == paginated_session.calls[0] assert expected_last_call == paginated_session.last_call assert len(datasets) == batch_size - expected_uids = [d['id'] for d in datasets_data[0:batch_size]] + expected_uids = [d["id"] for d in datasets_data[0:batch_size]] dataset_ids = [str(d.uid) for d in datasets] assert dataset_ids == expected_uids @@ -245,8 +296,9 @@ def test_delete_dataset(collection, session, dataset): # Then assert 1 == session.num_calls - expected_call = FakeCall(method='DELETE', path='teams/{}/datasets/{}'.format( - collection.team_id, uid)) + 
expected_call = FakeCall( + method="DELETE", path="teams/{}/datasets/{}".format(collection.team_id, uid) + ) assert expected_call == session.last_call @@ -335,12 +387,16 @@ def test_gemd_posts(dataset): MeasurementSpecCollection: MeasurementSpec("foo"), MeasurementRunCollection: MeasurementRun("foo"), PropertyTemplateCollection: PropertyTemplate("bar", bounds=IntegerBounds(0, 1)), - ParameterTemplateCollection: ParameterTemplate("bar", bounds=IntegerBounds(0, 1)), - ConditionTemplateCollection: ConditionTemplate("bar", bounds=IntegerBounds(0, 1)) + ParameterTemplateCollection: ParameterTemplate( + "bar", bounds=IntegerBounds(0, 1) + ), + ConditionTemplateCollection: ConditionTemplate( + "bar", bounds=IntegerBounds(0, 1) + ), } for collection, obj in expected.items(): - obj.name = 'This is my name' + obj.name = "This is my name" # Register the objects assert len(obj.uids) == 0 @@ -351,7 +407,7 @@ def test_gemd_posts(dataset): assert pair[1] == registered.uids[pair[0]] # Update the objects - registered.name = 'Name change!' + registered.name = "Name change!" updated = dataset.update(registered) assert registered.name == updated.name assert len(updated.uids) == 1 @@ -378,7 +434,9 @@ def test_gemd_posts(dataset): assert pair not in seen_ids # All ids are different seen_ids.add(pair) - after = dataset.register_all(before, status_bar=True) # Exercise the status_bar path + after = dataset.register_all( + before, status_bar=True + ) # Exercise the status_bar path assert len(before) == len(after) for obj in after: for pair in obj.uids.items(): @@ -389,7 +447,9 @@ def test_register_all_nested(dataset): cake = make_cake() after = dataset.register_all([cake], include_nested=True) assert cake in after - assert len(after) == len(recursive_flatmap(cake, lambda o: [o], unidirectional=False)) + assert len(after) == len( + recursive_flatmap(cake, lambda o: [o], unidirectional=False) + ) def test_register_all_iterable(dataset): @@ -406,7 +466,9 @@ def test_register_all_iterable(dataset): for c in cake_set: scope = get_demo_scope() if get_demo_scope() in c.uids else get_template_scope() assert c.to_link(scope) in dry_dict, f"Results didn't contain {c.typ} {c.name}" - assert all(c == d for d in dry_dict[c.to_link(scope)]), f"Not all matched {c.typ} {c.name}" + assert all(c == d for d in dry_dict[c.to_link(scope)]), ( + f"Not all matched {c.typ} {c.name}" + ) del dry_dict[c.to_link(scope)] assert len(dry_dict) == 0, f"{len(dry_dict)} unmatched objects" @@ -418,13 +480,19 @@ def test_register_all_iterable(dataset): for c in cake_set: scope = get_demo_scope() if get_demo_scope() in c.uids else get_template_scope() assert c.to_link(scope) in wet_dict, f"Results didn't contain {c.typ} {c.name}" - assert all(c == w for w in wet_dict[c.to_link(scope)]), f"Not all matched {c.typ} {c.name}" + assert all(c == w for w in wet_dict[c.to_link(scope)]), ( + f"Not all matched {c.typ} {c.name}" + ) del wet_dict[c.to_link(scope)] assert len(wet_dict) == 0, f"{len(wet_dict)} unmatched objects" + def test_batch_delete_malformed(session): with pytest.raises(TypeError): - _async_gemd_batch_delete(id_list=[uuid4()], session=session, team_id=None, dataset_id=None) + _async_gemd_batch_delete( + id_list=[uuid4()], session=session, team_id=None, dataset_id=None + ) + def test_gemd_batch_delete(dataset): """Pass through to GEMDResourceCollection working.""" @@ -435,26 +503,25 @@ def test_gemd_batch_delete(dataset): @pytest.mark.parametrize("prompt_to_confirm", [None, False]) @pytest.mark.parametrize("remove_templates", [False, True]) def 
test_delete_contents(dataset, prompt_to_confirm, remove_templates): - - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} failed_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [], - 'output': { + "job_type": "batch_delete", + "status": "Success", + "tasks": [], + "output": { # Keep in mind this is a stringified JSON value. Eww. - 'failures': '[]' - } + "failures": "[]" + }, } session = dataset.session session.set_responses(job_resp, failed_job_resp) # When - del_resp = dataset.delete_contents(prompt_to_confirm=prompt_to_confirm, remove_templates=remove_templates) + del_resp = dataset.delete_contents( + prompt_to_confirm=prompt_to_confirm, remove_templates=remove_templates + ) # Then assert len(del_resp) == 0 @@ -462,36 +529,29 @@ def test_delete_contents(dataset, prompt_to_confirm, remove_templates): # Ensure we made the expected delete call path = f"teams/{dataset.team_id}/datasets/{dataset.uid}/contents" params = {"remove_templates": remove_templates} - expected_call = FakeCall( - method='DELETE', - path=path, - params=params - ) + expected_call = FakeCall(method="DELETE", path=path, params=params) assert len(session.calls) == 2 assert session.calls[0] == expected_call def test_delete_contents_ok(dataset, monkeypatch): - - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} failed_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [], - 'output': { + "job_type": "batch_delete", + "status": "Success", + "tasks": [], + "output": { # Keep in mind this is a stringified JSON value. Eww. - 'failures': '[]' - } + "failures": "[]" + }, } session = dataset.session session.set_responses(job_resp, failed_job_resp) - user_responses = iter(['bad user response', 'Y']) - monkeypatch.setattr('builtins.input', lambda: next(user_responses)) + user_responses = iter(["bad user response", "Y"]) + monkeypatch.setattr("builtins.input", lambda: next(user_responses)) # When del_resp = dataset.delete_contents(prompt_to_confirm=True) @@ -501,16 +561,16 @@ def test_delete_contents_ok(dataset, monkeypatch): # Ensure we made the expected delete call expected_call = FakeCall( - method='DELETE', - path='teams/{}/datasets/{}/contents'.format(dataset.team_id, dataset.uid), - params={"remove_templates": True} + method="DELETE", + path="teams/{}/datasets/{}/contents".format(dataset.team_id, dataset.uid), + params={"remove_templates": True}, ) assert len(session.calls) == 2 assert session.calls[0] == expected_call def test_delete_contents_abort(dataset, monkeypatch): - user_responses = iter(['N']) - monkeypatch.setattr('builtins.input', lambda: next(user_responses)) + user_responses = iter(["N"]) + monkeypatch.setattr("builtins.input", lambda: next(user_responses)) with pytest.raises(RuntimeError): dataset.delete_contents(prompt_to_confirm=True) diff --git a/tests/resources/test_default_labels.py b/tests/resources/test_default_labels.py index b6d8267b0..357088da9 100644 --- a/tests/resources/test_default_labels.py +++ b/tests/resources/test_default_labels.py @@ -2,6 +2,7 @@ from citrine.resources._default_labels import _inject_default_label_tags + @pytest.mark.parametrize( "original_tags, default_labels, expected", [ diff --git a/tests/resources/test_descriptors.py b/tests/resources/test_descriptors.py index b289025d4..7c89aa289 100644 --- a/tests/resources/test_descriptors.py +++ b/tests/resources/test_descriptors.py @@ -1,5 +1,3 @@ -import pytest - from uuid import uuid4 from citrine.informatics.data_sources import GemTableDataSource 
@@ -11,22 +9,22 @@ def test_from_predictor_responses(): session = FakeSession() - col = 'smiles' + col = "smiles" response_json = { - 'responses': [ # shortened sample response + "responses": [ # shortened sample response { - 'type': 'Real', - 'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col), - 'units': '', - 'lower_bound': 0, - 'upper_bound': 1000000000 + "type": "Real", + "descriptor_key": "khs.sNH3 KierHallSmarts for {}".format(col), + "units": "", + "lower_bound": 0, + "upper_bound": 1000000000, }, { - 'type': 'Real', - 'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col), - 'units': '', - 'lower_bound': 0, - 'upper_bound': 1000000000 + "type": "Real", + "descriptor_key": "khs.dsN KierHallSmarts for {}".format(col), + "units": "", + "lower_bound": 0, + "upper_bound": 1000000000, }, ] } @@ -37,71 +35,86 @@ def test_from_predictor_responses(): description="description", input_descriptor=MolecularStructureDescriptor(col), features=["all"], - excludes=["standard"] + excludes=["standard"], + ) + results = descriptors.from_predictor_responses( + predictor=featurizer, inputs=[MolecularStructureDescriptor(col)] ) - results = descriptors.from_predictor_responses(predictor=featurizer, inputs=[MolecularStructureDescriptor(col)]) assert results == [ RealDescriptor( - key=r['descriptor_key'], - lower_bound=r['lower_bound'], - upper_bound=r['upper_bound'], - units=r['units'] - ) for r in response_json['responses'] + key=r["descriptor_key"], + lower_bound=r["lower_bound"], + upper_bound=r["upper_bound"], + units=r["units"], + ) + for r in response_json["responses"] ] - assert session.last_call.path == '/projects/{}/material-descriptors/predictor-responses'\ - .format(descriptors.project_id) - assert session.last_call.method == 'POST' + assert ( + session.last_call.path + == "/projects/{}/material-descriptors/predictor-responses".format( + descriptors.project_id + ) + ) + assert session.last_call.method == "POST" graph = GraphPredictor( - name="Graph", - description="Contains a featurizer", - predictors=[featurizer] + name="Graph", description="Contains a featurizer", predictors=[featurizer] + ) + graph_results = descriptors.from_predictor_responses( + predictor=graph, inputs=[MolecularStructureDescriptor(col)] ) - graph_results = descriptors.from_predictor_responses(predictor=graph, inputs=[MolecularStructureDescriptor(col)]) assert graph_results == [ RealDescriptor( - key=r['descriptor_key'], - lower_bound=r['lower_bound'], - upper_bound=r['upper_bound'], - units=r['units'] - ) for r in response_json['responses'] + key=r["descriptor_key"], + lower_bound=r["lower_bound"], + upper_bound=r["upper_bound"], + units=r["units"], + ) + for r in response_json["responses"] ] def test_from_data_source(): session = FakeSession() - col = 'smiles' + col = "smiles" response_json = { - 'descriptors': [ # shortened sample response + "descriptors": [ # shortened sample response { - 'type': 'Real', - 'descriptor_key': 'khs.sNH3 KierHallSmarts for {}'.format(col), - 'units': '', - 'lower_bound': 0, - 'upper_bound': 1000000000 + "type": "Real", + "descriptor_key": "khs.sNH3 KierHallSmarts for {}".format(col), + "units": "", + "lower_bound": 0, + "upper_bound": 1000000000, }, { - 'type': 'Real', - 'descriptor_key': 'khs.dsN KierHallSmarts for {}'.format(col), - 'units': '', - 'lower_bound': 0, - 'upper_bound': 1000000000 + "type": "Real", + "descriptor_key": "khs.dsN KierHallSmarts for {}".format(col), + "units": "", + "lower_bound": 0, + "upper_bound": 1000000000, }, ] } 
session.set_response(response_json) descriptors = DescriptorMethods(uuid4(), session) - data_source = GemTableDataSource(table_id='43357a66-3644-4959-8115-77b2630aca45', table_version=123) + data_source = GemTableDataSource( + table_id="43357a66-3644-4959-8115-77b2630aca45", table_version=123 + ) results = descriptors.from_data_source(data_source=data_source) assert results == [ RealDescriptor( - key=r['descriptor_key'], - lower_bound=r['lower_bound'], - upper_bound=r['upper_bound'], - units=r['units'] - ) for r in response_json['descriptors'] + key=r["descriptor_key"], + lower_bound=r["lower_bound"], + upper_bound=r["upper_bound"], + units=r["units"], + ) + for r in response_json["descriptors"] ] - assert session.last_call.path == '/projects/{}/material-descriptors/from-data-source'\ - .format(descriptors.project_id) - assert session.last_call.method == 'POST' + assert ( + session.last_call.path + == "/projects/{}/material-descriptors/from-data-source".format( + descriptors.project_id + ) + ) + assert session.last_call.method == "POST" diff --git a/tests/resources/test_design_executions.py b/tests/resources/test_design_executions.py index 4dc786212..4e0dc10ec 100644 --- a/tests/resources/test_design_executions.py +++ b/tests/resources/test_design_executions.py @@ -24,7 +24,9 @@ def collection(session) -> DesignExecutionCollection: @pytest.fixture -def workflow_execution(collection: DesignExecutionCollection, design_execution_dict) -> DesignExecution: +def workflow_execution( + collection: DesignExecutionCollection, design_execution_dict +) -> DesignExecution: return collection.build(design_execution_dict) @@ -59,11 +61,15 @@ def test_build_new_execution(collection, design_execution_dict): assert execution.project_id == collection.project_id assert execution.workflow_id == collection.workflow_id assert execution._session == collection.session - assert execution.in_progress() and not execution.succeeded() and not execution.failed() + assert ( + execution.in_progress() and not execution.succeeded() and not execution.failed() + ) assert execution.status_detail -def test_trigger_workflow_execution(collection: DesignExecutionCollection, design_execution_dict, session): +def test_trigger_workflow_execution( + collection: DesignExecutionCollection, design_execution_dict, session +): # Given session.set_response(design_execution_dict) score = MLIScoreFactory() @@ -74,18 +80,20 @@ def test_trigger_workflow_execution(collection: DesignExecutionCollection, desig # Then assert str(actual_execution.uid) == design_execution_dict["id"] - expected_path = '/projects/{}/design-workflows/{}/executions'.format( + expected_path = "/projects/{}/design-workflows/{}/executions".format( collection.project_id, collection.workflow_id, ) assert session.last_call == FakeCall( - method='POST', + method="POST", path=expected_path, - json={'score': score.dump(), 'max_candidates': max_candidates} + json={"score": score.dump(), "max_candidates": max_candidates}, ) -def test_workflow_execution_results(workflow_execution: DesignExecution, session, example_candidates): +def test_workflow_execution_results( + workflow_execution: DesignExecution, session, example_candidates +): # Given session.set_response(example_candidates) @@ -93,15 +101,19 @@ def test_workflow_execution_results(workflow_execution: DesignExecution, session list(workflow_execution.candidates(per_page=4)) # Then - expected_path = '/projects/{}/design-workflows/{}/executions/{}/candidates'.format( + expected_path = 
"/projects/{}/design-workflows/{}/executions/{}/candidates".format( workflow_execution.project_id, workflow_execution.workflow_id, workflow_execution.uid, ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 4, 'page': 1}) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"per_page": 4, "page": 1} + ) -def test_workflow_execution_hierarchical_results(workflow_execution: DesignExecution, session, example_hierarchical_candidates): +def test_workflow_execution_hierarchical_results( + workflow_execution: DesignExecution, session, example_hierarchical_candidates +): # Given session.set_response(example_hierarchical_candidates) @@ -109,22 +121,28 @@ def test_workflow_execution_hierarchical_results(workflow_execution: DesignExecu list(workflow_execution.hierarchical_candidates(per_page=4)) # Then - expected_path = '/projects/{}/design-workflows/{}/executions/{}/candidate-histories'.format( - workflow_execution.project_id, - workflow_execution.workflow_id, - workflow_execution.uid, + expected_path = ( + "/projects/{}/design-workflows/{}/executions/{}/candidate-histories".format( + workflow_execution.project_id, + workflow_execution.workflow_id, + workflow_execution.uid, + ) + ) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"per_page": 4, "page": 1} ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 4, 'page': 1}) -def test_workflow_execution_results_pinned(workflow_execution: DesignExecution, session, example_candidates): +def test_workflow_execution_results_pinned( + workflow_execution: DesignExecution, session, example_candidates +): # Given pinned_by = uuid.uuid4() pinned_time = datetime.now() example_candidates_pinned = deepcopy(example_candidates) example_candidates_pinned["response"][0]["pinned"] = { "user": pinned_by, - "time": pinned_time + "time": pinned_time, } session.set_response(example_candidates_pinned) @@ -132,12 +150,14 @@ def test_workflow_execution_results_pinned(workflow_execution: DesignExecution, candidates = list(workflow_execution.candidates(per_page=4)) # Then - expected_path = '/projects/{}/design-workflows/{}/executions/{}/candidates'.format( + expected_path = "/projects/{}/design-workflows/{}/executions/{}/candidates".format( workflow_execution.project_id, workflow_execution.workflow_id, workflow_execution.uid, ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 4, 'page': 1}) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"per_page": 4, "page": 1} + ) assert candidates[0].pinned_by == pinned_by assert candidates[0].pinned_time == pinned_time @@ -147,11 +167,11 @@ def test_list(collection: DesignExecutionCollection, session): lst = list(collection.list(per_page=4)) assert len(lst) == 0 - expected_path = '/projects/{}/design-workflows/{}/executions'.format(collection.project_id, collection.workflow_id) + expected_path = "/projects/{}/design-workflows/{}/executions".format( + collection.project_id, collection.workflow_id + ) assert session.last_call == FakeCall( - method='GET', - path=expected_path, - params={"per_page": 4, 'page': 1} + method="GET", path=expected_path, params={"per_page": 4, "page": 1} ) diff --git a/tests/resources/test_design_space.py b/tests/resources/test_design_space.py index f56f3ea85..e31c076ba 100644 --- a/tests/resources/test_design_space.py +++ b/tests/resources/test_design_space.py @@ -3,42 +3,38 @@ from copy 
import deepcopy from datetime import datetime, timezone -import mock import pytest -from citrine.exceptions import ModuleRegistrationFailedException, NotFound -from citrine.informatics.descriptors import RealDescriptor, FormulationKey -from citrine.informatics.design_spaces import DefaultDesignSpaceMode, DesignSpace, \ - DesignSpaceSettings, EnumeratedDesignSpace, HierarchicalDesignSpace, ProductDesignSpace +from citrine.informatics.descriptors import FormulationKey, RealDescriptor +from citrine.informatics.design_spaces import ( + DefaultDesignSpaceMode, + DesignSpace, + DesignSpaceSettings, + EnumeratedDesignSpace, + HierarchicalDesignSpace, + ProductDesignSpace, +) from citrine.resources.design_space import DesignSpaceCollection -from citrine.resources.status_detail import StatusDetail, StatusLevelEnum from tests.utils.session import FakeCall, FakeSession + def _ds_dict_to_response(ds_dict, status="CREATED"): - time = '2020-04-23T15:46:26Z' + time = "2020-04-23T15:46:26Z" return { "id": str(uuid.uuid4()), "data": { "name": ds_dict["name"], "description": ds_dict["description"], - "instance": ds_dict + "instance": ds_dict, }, "metadata": { - "created": { - "user": str(uuid.uuid4()), - "time": time - }, - "updated": { - "user": str(uuid.uuid4()), - "time": time - }, - "status": { - "name": status, - "detail": [] - } - } + "created": {"user": str(uuid.uuid4()), "time": time}, + "updated": {"user": str(uuid.uuid4()), "time": time}, + "status": {"name": status, "detail": []}, + }, } + def _ds_to_response(ds, status="CREATED"): return _ds_dict_to_response(ds.dump()["instance"], status) @@ -59,34 +55,49 @@ def test_design_space_build(valid_product_design_space_data): # Then assert str(design_space.uid) == design_space_id - assert design_space.name == valid_product_design_space_data["data"]["instance"]["name"] - assert design_space.dimensions[0].descriptor.key == valid_product_design_space_data["data"]["instance"]["dimensions"][0]["descriptor"]["descriptor_key"] + assert ( + design_space.name == valid_product_design_space_data["data"]["instance"]["name"] + ) + assert ( + design_space.dimensions[0].descriptor.key + == valid_product_design_space_data["data"]["instance"]["dimensions"][0][ + "descriptor" + ]["descriptor_key"] + ) def test_design_space_build_with_status_detail(valid_product_design_space_data): # Given collection = DesignSpaceCollection(uuid.uuid4(), None) - status_detail_data = {("Info", "info_msg"), ("Warning", "warning msg"), ("Error", "error msg")} + status_detail_data = { + ("Info", "info_msg"), + ("Warning", "warning msg"), + ("Error", "error msg"), + } data = deepcopy(valid_product_design_space_data) - data["metadata"]["status"]["detail"] = [{"level": level, "msg": msg} for level, msg in status_detail_data] + data["metadata"]["status"]["detail"] = [ + {"level": level, "msg": msg} for level, msg in status_detail_data + ] # When design_space = collection.build(data) # Then - status_detail_tuples = {(detail.level, detail.msg) for detail in design_space.status_detail} + status_detail_tuples = { + (detail.level, detail.msg) for detail in design_space.status_detail + } assert status_detail_tuples == status_detail_data def test_formulation_build(valid_formulation_design_space_data): pc = DesignSpaceCollection(uuid.uuid4(), None) design_space = pc.build(valid_formulation_design_space_data) - assert design_space.name == 'formulation design space' - assert design_space.description == 'formulates some things' + assert design_space.name == "formulation design space" + assert 
design_space.description == "formulates some things" assert design_space.formulation_descriptor.key == FormulationKey.HIERARCHICAL.value - assert design_space.ingredients == {'foo'} - assert design_space.labels == {'bar': {'foo'}} + assert design_space.ingredients == {"foo"} + assert design_space.labels == {"bar": {"foo"}} assert len(design_space.constraints) == 1 assert design_space.resolution == 0.1 @@ -94,8 +105,8 @@ def test_formulation_build(valid_formulation_design_space_data): def test_hierarchical_build(valid_hierarchical_design_space_data): dc = DesignSpaceCollection(uuid.uuid4(), None) hds = dc.build(valid_hierarchical_design_space_data) - assert hds.name == 'hierarchical design space' - assert hds.description == 'does things but in levels' + assert hds.name == "hierarchical design space" + assert hds.description == "does things but in levels" assert hds.root.formulation_subspace is not None assert hds.root.template_link is not None assert hds.root.display_name is not None @@ -113,17 +124,16 @@ def test_convert_to_hierarchical(valid_hierarchical_design_space_data): ds_id = uuid.uuid4() predictor_id = uuid.uuid4() - dc.convert_to_hierarchical(uid=ds_id, predictor_id=predictor_id, predictor_version=2) + dc.convert_to_hierarchical( + uid=ds_id, predictor_id=predictor_id, predictor_version=2 + ) - expected_payload = { - "predictor_id": str(predictor_id), - "predictor_version": 2 - } + expected_payload = {"predictor_id": str(predictor_id), "predictor_version": 2} expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{dc.project_id}/design-spaces/{ds_id}/convert-hierarchical", json=expected_payload, - version="v3" + version="v3", ) assert session.num_calls == 1 @@ -135,22 +145,25 @@ def test_design_space_limits(): # Given session = FakeSession() collection = DesignSpaceCollection(uuid.uuid4(), session) - - descriptors = [RealDescriptor(f"R-{i}", lower_bound=0, upper_bound=1, units="") for i in range(128)] + + descriptors = [ + RealDescriptor(f"R-{i}", lower_bound=0, upper_bound=1, units="") + for i in range(128) + ] descriptor_values = {f"R-{i}": str(random.random()) for i in range(128)} just_right = EnumeratedDesignSpace( "just right", description="just right desc", descriptors=descriptors, - data=[descriptor_values] * 2000 + data=[descriptor_values] * 2000, ) too_big = EnumeratedDesignSpace( "too big", description="too big desc", descriptors=just_right.descriptors, - data=[descriptor_values] * 2001 + data=[descriptor_values] * 2001, ) # create mock post response by setting the status. 
@@ -163,7 +176,7 @@ def test_design_space_limits(): "basic", description="basic desc", descriptors=[dummy_desc], - data=[{dummy_desc.key: descriptor_values[dummy_desc.key]}] + data=[{dummy_desc.key: descriptor_values[dummy_desc.key]}], ) mock_response = _ds_to_response(dummy_resp, status="READY") session.responses.append(mock_response) @@ -190,12 +203,9 @@ def test_design_space_limits(): def test_create_default(predictor_version, valid_product_design_space): session = FakeSession() session.set_response(valid_product_design_space.dump()) - + predictor_id = uuid.uuid4() - collection = DesignSpaceCollection( - project_id=uuid.uuid4(), - session=session - ) + collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) expected_payload = DesignSpaceSettings( predictor_id=predictor_id, @@ -204,37 +214,43 @@ def test_create_default(predictor_version, valid_product_design_space): include_label_fraction_constraints=False, include_label_count_constraints=False, include_parameter_constraints=False, - mode=DefaultDesignSpaceMode.ATTRIBUTE + mode=DefaultDesignSpaceMode.ATTRIBUTE, ).dump() expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{collection.project_id}/design-spaces/default", json=expected_payload, - version="v3" + version="v3", ) - default_design_space = collection.create_default(predictor_id=predictor_id, predictor_version=predictor_version) + default_design_space = collection.create_default( + predictor_id=predictor_id, predictor_version=predictor_version + ) assert session.num_calls == 1 assert session.last_call == expected_call - - expected_response = {**valid_product_design_space.dump(), "settings": expected_payload} + + expected_response = { + **valid_product_design_space.dump(), + "settings": expected_payload, + } assert default_design_space.dump() == expected_response @pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) -def test_create_default_hierarchical(predictor_version, valid_hierarchical_design_space_data): - valid_hierarchical_design_space = HierarchicalDesignSpace.build(valid_hierarchical_design_space_data) +def test_create_default_hierarchical( + predictor_version, valid_hierarchical_design_space_data +): + valid_hierarchical_design_space = HierarchicalDesignSpace.build( + valid_hierarchical_design_space_data + ) session = FakeSession() session.set_response(valid_hierarchical_design_space.dump()) - + predictor_id = uuid.uuid4() - collection = DesignSpaceCollection( - project_id=uuid.uuid4(), - session=session - ) + collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) expected_payload = DesignSpaceSettings( predictor_id=predictor_id, @@ -243,26 +259,29 @@ def test_create_default_hierarchical(predictor_version, valid_hierarchical_desig include_label_fraction_constraints=False, include_label_count_constraints=False, include_parameter_constraints=False, - mode=DefaultDesignSpaceMode.HIERARCHICAL + mode=DefaultDesignSpaceMode.HIERARCHICAL, ).dump() expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{collection.project_id}/design-spaces/default", json=expected_payload, - version="v3" + version="v3", ) default_design_space = collection.create_default( predictor_id=predictor_id, predictor_version=predictor_version, - mode=DefaultDesignSpaceMode.HIERARCHICAL + mode=DefaultDesignSpaceMode.HIERARCHICAL, ) assert session.num_calls == 1 assert session.last_call == expected_call - - expected_response = {**valid_hierarchical_design_space.dump(), "settings": expected_payload} + + 
expected_response = { + **valid_hierarchical_design_space.dump(), + "settings": expected_payload, + } assert default_design_space.dump() == expected_response @@ -270,18 +289,20 @@ def test_create_default_hierarchical(predictor_version, valid_hierarchical_desig @pytest.mark.parametrize("label_fractions", (True, False)) @pytest.mark.parametrize("label_count", (True, False)) @pytest.mark.parametrize("parameters", (True, False)) -def test_create_default_with_config(valid_product_design_space, ingredient_fractions, - label_fractions, label_count, parameters): +def test_create_default_with_config( + valid_product_design_space, + ingredient_fractions, + label_fractions, + label_count, + parameters, +): session = FakeSession() session.set_response(valid_product_design_space.dump()) - + predictor_id = uuid.uuid4() predictor_version = random.randint(1, 10) - collection = DesignSpaceCollection( - project_id=uuid.uuid4(), - session=session - ) - + collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) + expected_payload = DesignSpaceSettings( predictor_id=predictor_id, predictor_version=predictor_version, @@ -289,14 +310,14 @@ def test_create_default_with_config(valid_product_design_space, ingredient_fract include_label_fraction_constraints=label_fractions, include_label_count_constraints=label_count, include_parameter_constraints=parameters, - mode=DefaultDesignSpaceMode.ATTRIBUTE + mode=DefaultDesignSpaceMode.ATTRIBUTE, ).dump() expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{collection.project_id}/design-spaces/default", json=expected_payload, - version="v3" + version="v3", ) default_design_space = collection.create_default( @@ -305,68 +326,104 @@ def test_create_default_with_config(valid_product_design_space, ingredient_fract include_ingredient_fraction_constraints=ingredient_fractions, include_label_fraction_constraints=label_fractions, include_label_count_constraints=label_count, - include_parameter_constraints=parameters + include_parameter_constraints=parameters, ) assert session.num_calls == 1 assert session.last_call == expected_call - - expected_response = {**valid_product_design_space.dump(), "settings": expected_payload} + + expected_response = { + **valid_product_design_space.dump(), + "settings": expected_payload, + } assert default_design_space.dump() == expected_response -def test_list_design_spaces(valid_product_design_space_data, valid_hierarchical_design_space_data): +def test_list_design_spaces( + valid_product_design_space_data, valid_hierarchical_design_space_data +): # Given session = FakeSession() collection = DesignSpaceCollection(uuid.uuid4(), session) - session.set_response({ - 'response': [valid_product_design_space_data, valid_hierarchical_design_space_data] - }) + session.set_response( + { + "response": [ + valid_product_design_space_data, + valid_hierarchical_design_space_data, + ] + } + ) # When design_spaces = list(collection.list(per_page=20)) # Then - expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id), - params={'per_page': 20, 'page': 1, 'archived': False}, version="v4") + expected_call = FakeCall( + method="GET", + path="/projects/{}/design-spaces".format(collection.project_id), + params={"per_page": 20, "page": 1, "archived": False}, + version="v4", + ) assert 1 == session.num_calls, session.calls assert expected_call == session.calls[0] assert len(design_spaces) == 2 -def test_list_all_design_spaces(valid_product_design_space_data, 
valid_hierarchical_design_space_data): +def test_list_all_design_spaces( + valid_product_design_space_data, valid_hierarchical_design_space_data +): # Given session = FakeSession() collection = DesignSpaceCollection(uuid.uuid4(), session) - session.set_response({ - 'response': [valid_product_design_space_data, valid_hierarchical_design_space_data] - }) + session.set_response( + { + "response": [ + valid_product_design_space_data, + valid_hierarchical_design_space_data, + ] + } + ) # When design_spaces = list(collection.list_all(per_page=25)) # Then - expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id), - params={'per_page': 25, 'page': 1}, version="v4") + expected_call = FakeCall( + method="GET", + path="/projects/{}/design-spaces".format(collection.project_id), + params={"per_page": 25, "page": 1}, + version="v4", + ) assert 1 == session.num_calls, session.calls assert expected_call == session.calls[0] assert len(design_spaces) == 2 -def test_list_archived_design_spaces(valid_product_design_space_data, valid_hierarchical_design_space_data): +def test_list_archived_design_spaces( + valid_product_design_space_data, valid_hierarchical_design_space_data +): # Given session = FakeSession() collection = DesignSpaceCollection(uuid.uuid4(), session) - session.set_response({ - 'response': [valid_product_design_space_data, valid_hierarchical_design_space_data] - }) + session.set_response( + { + "response": [ + valid_product_design_space_data, + valid_hierarchical_design_space_data, + ] + } + ) # When design_spaces = list(collection.list_archived(per_page=25)) # Then - expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id), - params={'per_page': 25, 'page': 1, 'archived': True}, version="v4") + expected_call = FakeCall( + method="GET", + path="/projects/{}/design-spaces".format(collection.project_id), + params={"per_page": 25, "page": 1, "archived": True}, + version="v4", + ) assert 1 == session.num_calls, session.calls assert expected_call == session.calls[0] assert len(design_spaces) == 2 @@ -386,7 +443,7 @@ def test_archive(valid_product_design_space_data): assert archived_design_space.is_archived assert session.calls == [ - FakeCall(method='PUT', path=f"{base_path}/{ds_id}/archive", json={}), + FakeCall(method="PUT", path=f"{base_path}/{ds_id}/archive", json={}), ] @@ -405,7 +462,7 @@ def test_restore(valid_product_design_space_data): assert not restored_design_space.is_archived assert session.calls == [ - FakeCall(method='PUT', path=f"{base_path}/{ds_id}/restore", json={}), + FakeCall(method="PUT", path=f"{base_path}/{ds_id}/restore", json={}), ] @@ -421,36 +478,36 @@ def test_get_none(): def test_failed_register(valid_product_design_space_data): response_data = deepcopy(valid_product_design_space_data) - response_data['metadata']['status']['name'] = 'INVALID' + response_data["metadata"]["status"]["name"] = "INVALID" session = FakeSession() session.set_response(response_data) dsc = DesignSpaceCollection(uuid.uuid4(), session) ds = dsc.build(deepcopy(valid_product_design_space_data)) - + retval = dsc.register(ds) - + base_path = f"/projects/{dsc.project_id}/design-spaces" assert session.calls == [ - FakeCall(method='POST', path=base_path, json=ds.dump()), + FakeCall(method="POST", path=base_path, json=ds.dump()), ] assert retval.dump() == ds.dump() def test_failed_update(valid_product_design_space_data): response_data = deepcopy(valid_product_design_space_data) - 
response_data['metadata']['status']['name'] = 'INVALID' + response_data["metadata"]["status"]["name"] = "INVALID" session = FakeSession() session.set_response(response_data) dsc = DesignSpaceCollection(uuid.uuid4(), session) ds = dsc.build(deepcopy(valid_product_design_space_data)) - + retval = dsc.update(ds) - + base_path = f"/projects/{dsc.project_id}/design-spaces" assert session.calls == [ - FakeCall(method='PUT', path=f'{base_path}/{ds.uid}', json=ds.dump()), + FakeCall(method="PUT", path=f"{base_path}/{ds.uid}", json=ds.dump()), ] assert retval.dump() == ds.dump() @@ -475,9 +532,9 @@ def test_carrying_settings_from_create_default(valid_product_design_space): default_design_space = collection.create_default( predictor_id=predictor_id, predictor_version=predictor_version, - include_label_count_constraints=True + include_label_count_constraints=True, ) - registered = collection.register(default_design_space) + collection.register(default_design_space) expected_settings = DesignSpaceSettings( predictor_id=predictor_id, @@ -486,15 +543,18 @@ def test_carrying_settings_from_create_default(valid_product_design_space): include_label_fraction_constraints=False, include_label_count_constraints=True, include_parameter_constraints=False, - mode=DefaultDesignSpaceMode.ATTRIBUTE + mode=DefaultDesignSpaceMode.ATTRIBUTE, ) - expected_payload = {**valid_product_design_space.dump(), "settings": expected_settings.dump()} + expected_payload = { + **valid_product_design_space.dump(), + "settings": expected_settings.dump(), + } expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{collection.project_id}/design-spaces", json=expected_payload, - version="v3" + version="v3", ) assert session.num_calls == 3 @@ -506,7 +566,7 @@ def test_carrying_settings_from_get(valid_product_design_space): predictor_version = 4 session = FakeSession() - + expected_settings = DesignSpaceSettings( predictor_id=predictor_id, predictor_version=predictor_version, @@ -515,7 +575,7 @@ def test_carrying_settings_from_get(valid_product_design_space): include_label_fraction_constraints=False, include_label_count_constraints=False, include_parameter_constraints=True, - mode=DefaultDesignSpaceMode.ATTRIBUTE + mode=DefaultDesignSpaceMode.ATTRIBUTE, ) ds_resp = _ds_to_response(valid_product_design_space) @@ -525,15 +585,18 @@ def test_carrying_settings_from_get(valid_product_design_space): collection = DesignSpaceCollection(project_id=uuid.uuid4(), session=session) retrieved = collection.get(uuid.uuid4()) - registered = collection.register(retrieved) + collection.register(retrieved) - expected_payload = {**valid_product_design_space.dump(), "settings": expected_settings.dump()} + expected_payload = { + **valid_product_design_space.dump(), + "settings": expected_settings.dump(), + } expected_call = FakeCall( - method='POST', + method="POST", path=f"projects/{collection.project_id}/design-spaces", json=expected_payload, - version="v3" + version="v3", ) assert session.num_calls == 3 @@ -557,7 +620,10 @@ def test_locked(valid_product_design_space_data): lock_timestamp = int(lock_time.timestamp()) * 1000 response_data = deepcopy(valid_product_design_space_data) - response_data['metadata']['locked'] = {'user': str(lock_user), 'time': lock_timestamp} + response_data["metadata"]["locked"] = { + "user": str(lock_user), + "time": lock_timestamp, + } session.set_response(response_data) @@ -568,9 +634,14 @@ def test_locked(valid_product_design_space_data): assert ds.lock_time == lock_time 
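# The lock metadata round-trips through epoch milliseconds, which is why
# test_locked builds its timestamp with int(...) * 1000. A standard-library
# sketch of that conversion (the values are illustrative):
lock_time = datetime(2020, 4, 23, 15, 46, 26, tzinfo=timezone.utc)
millis = int(lock_time.timestamp()) * 1000  # serialize: whole seconds -> milliseconds
assert datetime.fromtimestamp(millis / 1000, tz=timezone.utc) == lock_time  # exact round trip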
-@pytest.mark.parametrize("ds_data_fixture_name", ("valid_formulation_design_space_data", - "valid_enumerated_design_space_data", - "valid_data_source_design_space_dict")) +@pytest.mark.parametrize( + "ds_data_fixture_name", + ( + "valid_formulation_design_space_data", + "valid_enumerated_design_space_data", + "valid_data_source_design_space_dict", + ), +) def test_deprecated_top_level_design_spaces(request, ds_data_fixture_name): ds_data = request.getfixturevalue(ds_data_fixture_name) diff --git a/tests/resources/test_design_workflows.py b/tests/resources/test_design_workflows.py index 077c7346c..65a5ac06d 100644 --- a/tests/resources/test_design_workflows.py +++ b/tests/resources/test_design_workflows.py @@ -7,16 +7,20 @@ from citrine.informatics.workflows import DesignWorkflow from citrine.resources.design_workflow import DesignWorkflowCollection from tests.utils.factories import ( - BranchDataFactory, DesignWorkflowDataFactory, TableDataSourceFactory + BranchDataFactory, + DesignWorkflowDataFactory, + TableDataSourceFactory, ) from tests.utils.session import FakeSession, FakeCall PARTIAL_DW_ARGS = ( ("data_source_id", lambda: TableDataSourceFactory().to_data_source_id()), ("predictor_id", lambda: str(uuid.uuid4())), - ("design_space_id", lambda: str(uuid.uuid4())) + ("design_space_id", lambda: str(uuid.uuid4())), +) +OPTIONAL_ARGS = PARTIAL_DW_ARGS + ( + ("predictor_version", lambda: random.randint(1, 10)), ) -OPTIONAL_ARGS = PARTIAL_DW_ARGS + (("predictor_version", lambda: random.randint(1, 10)),) @pytest.fixture @@ -42,8 +46,8 @@ def collection(branch_data, collection_without_branch) -> DesignWorkflowCollecti return DesignWorkflowCollection( project_id=collection_without_branch.project_id, session=collection_without_branch.session, - branch_root_id=uuid.UUID(branch_data['metadata']['root_id']), - branch_version=branch_data['metadata']['version'], + branch_root_id=uuid.UUID(branch_data["metadata"]["root_id"]), + branch_version=branch_data["metadata"]["version"], ) @@ -55,14 +59,18 @@ def workflow(collection, branch_data) -> DesignWorkflow: def all_combination_lengths(vals, maxlen=None): maxlen = maxlen or len(vals) - return [args for k in range(0, maxlen + 1) for args in itertools.combinations(vals, k)] + return [ + args for k in range(0, maxlen + 1) for args in itertools.combinations(vals, k) + ] + def workflow_path(collection, workflow=None): - path = f'/projects/{collection.project_id}/design-workflows' + path = f"/projects/{collection.project_id}/design-workflows" if workflow: - path = f'{path}/{workflow.uid}' + path = f"{path}/{workflow.uid}" return path + def assert_workflow(actual, expected, *, include_branch=False): assert actual.name == expected.name assert actual.description == expected.description @@ -79,7 +87,7 @@ def assert_workflow(actual, expected, *, include_branch=False): def test_basic_methods(workflow, collection): - assert 'DesignWorkflow' in str(workflow) + assert "DesignWorkflow" in str(workflow) assert workflow.design_executions.project_id == workflow.project_id @@ -90,7 +98,7 @@ def test_register(session, branch_data, collection, optional_args): workflow_data = DesignWorkflowDataFactory(**kw_args, branch=branch_data) # Given - post_dict = {k: v for k, v in workflow_data.items() if k != 'status_description'} + post_dict = {k: v for k, v in workflow_data.items() if k != "status_description"} session.set_responses(workflow_data) # When @@ -98,7 +106,9 @@ def test_register(session, branch_data, collection, optional_args): new_workflow = 
collection.register(old_workflow) # Then - assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)] + assert session.calls == [ + FakeCall(method="POST", path=workflow_path(collection), json=post_dict) + ] assert new_workflow.branch_root_id == collection.branch_root_id assert new_workflow.branch_version == collection.branch_version @@ -110,18 +120,24 @@ def test_register_conflicting_branches(session, branch_data, workflow, collectio old_branch_root_id = uuid.uuid4() workflow.branch_root_id = old_branch_root_id assert workflow.branch_root_id != collection.branch_root_id - + new_branch_root_id = str(branch_data["metadata"]["root_id"]) new_branch_version = branch_data["metadata"]["version"] - post_dict = {**workflow.dump(), "branch_root_id": new_branch_root_id, "branch_version": new_branch_version} - session.set_responses({**post_dict, 'status_description': 'status'}) + post_dict = { + **workflow.dump(), + "branch_root_id": new_branch_root_id, + "branch_version": new_branch_version, + } + session.set_responses({**post_dict, "status_description": "status"}) # When new_workflow = collection.register(workflow) # Then - assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)] + assert session.calls == [ + FakeCall(method="POST", path=workflow_path(collection), json=post_dict) + ] assert workflow.branch_root_id == old_branch_root_id assert new_workflow.branch_root_id == collection.branch_root_id @@ -136,14 +152,22 @@ def test_register_partial_workflow_without_branch(session, collection_without_br def test_archive(workflow, collection): collection.archive(workflow.uid) - expected_path = '/projects/{}/design-workflows/{}/archive'.format(collection.project_id, workflow.uid) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={}) + expected_path = "/projects/{}/design-workflows/{}/archive".format( + collection.project_id, workflow.uid + ) + assert collection.session.last_call == FakeCall( + method="PUT", path=expected_path, json={} + ) def test_restore(workflow, collection): collection.restore(workflow.uid) - expected_path = '/projects/{}/design-workflows/{}/restore'.format(collection.project_id, workflow.uid) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={}) + expected_path = "/projects/{}/design-workflows/{}/restore".format( + collection.project_id, workflow.uid + ) + assert collection.session.last_call == FakeCall( + method="PUT", path=expected_path, json={} + ) def test_delete(collection): @@ -152,20 +176,26 @@ def test_delete(collection): def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollection): - branch_root_id = uuid.UUID(branch_data['metadata']['root_id']) - branch_version = branch_data['metadata']['version'] + branch_root_id = uuid.UUID(branch_data["metadata"]["root_id"]) + branch_version = branch_data["metadata"]["version"] collection.session.set_responses({"response": []}) lst = list(collection.list_archived(per_page=10)) assert len(lst) == 0 - expected_path = '/projects/{}/design-workflows'.format(collection.project_id) + expected_path = "/projects/{}/design-workflows".format(collection.project_id) assert collection.session.last_call == FakeCall( - method='GET', + method="GET", path=expected_path, - params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch_root_id': branch_root_id, 'branch_version': branch_version}, - json=None + params={ + "page": 1, + "per_page": 10, + "filter": 
"archived eq 'true'", + "branch_root_id": branch_root_id, + "branch_version": branch_version, + }, + json=None, ) @@ -184,27 +214,36 @@ def test_update(session, branch_data, workflow, collection_without_branch): post_dict = workflow.dump() session.set_responses( {"per_page": 1, "next": "", "response": []}, - {**post_dict, 'status_description': 'status'}, + {**post_dict, "status_description": "status"}, ) # When new_workflow = collection_without_branch.update(workflow) # Then - executions_path = f'/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions' + executions_path = f"/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions" assert session.calls == [ - FakeCall(method='GET', path=executions_path, params={'page': 1, 'per_page': 100}), - FakeCall(method='PUT', path=workflow_path(collection_without_branch, workflow), json=post_dict), + FakeCall( + method="GET", path=executions_path, params={"page": 1, "per_page": 100} + ), + FakeCall( + method="PUT", + path=workflow_path(collection_without_branch, workflow), + json=post_dict, + ), ] assert_workflow(new_workflow, workflow) -def test_update_failure_with_existing_execution(session, branch_data, workflow, collection_without_branch, design_execution_dict): +def test_update_failure_with_existing_execution( + session, branch_data, workflow, collection_without_branch, design_execution_dict +): workflow.branch_root_id = uuid.uuid4() post_dict = workflow.dump() session.set_responses( {"per_page": 1, "next": "", "response": [design_execution_dict]}, - {**post_dict, 'status_description': 'status'}) + {**post_dict, "status_description": "status"}, + ) with pytest.raises(RuntimeError): collection_without_branch.update(workflow) @@ -219,7 +258,9 @@ def test_update_with_mismatched_branch_root_ids(session, workflow, collection): collection.update(workflow) -def test_update_model_missing_branch_root_id(session, workflow, collection_without_branch): +def test_update_model_missing_branch_root_id( + session, workflow, collection_without_branch +): # Given workflow.branch_root_id = None @@ -228,7 +269,9 @@ def test_update_model_missing_branch_root_id(session, workflow, collection_witho collection_without_branch.update(workflow) -def test_update_model_missing_branch_version(session, workflow, collection_without_branch): +def test_update_model_missing_branch_version( + session, workflow, collection_without_branch +): # Given workflow.branch_version = None @@ -244,6 +287,7 @@ def test_update_branch_not_found(collection, workflow): with pytest.raises(ValueError): collection.update(workflow) + def test_data_source_id(workflow): original_id = workflow.data_source_id assert workflow.data_source.to_data_source_id() == original_id diff --git a/tests/resources/test_experiment_datasource.py b/tests/resources/test_experiment_datasource.py index b7db06011..023abe31b 100644 --- a/tests/resources/test_experiment_datasource.py +++ b/tests/resources/test_experiment_datasource.py @@ -5,12 +5,12 @@ import pytest -from citrine.resources.experiment_datasource import ExperimentDataSource, ExperimentDataSourceCollection -from tests.utils.factories import ExperimentDataSourceDataFactory -from tests.utils.session import ( - FakeCall, - FakeSession +from citrine.resources.experiment_datasource import ( + ExperimentDataSource, + ExperimentDataSourceCollection, ) +from tests.utils.factories import ExperimentDataSourceDataFactory +from tests.utils.session import FakeCall, FakeSession LATEST_VER = "latest" @@ -28,11 
+28,13 @@ def collection(session) -> ExperimentDataSourceCollection: @pytest.fixture def erds_base_path(collection): - return f'projects/{collection.project_id}/candidate-experiment-datasources' + return f"projects/{collection.project_id}/candidate-experiment-datasources" def assert_erds_csv(erds_csv, erds_dict): - for row, expt in zip(csv.DictReader(io.StringIO(erds_csv)), erds_dict["data"]["experiments"]): + for row, expt in zip( + csv.DictReader(io.StringIO(erds_csv)), erds_dict["data"]["experiments"] + ): for variable, actual_value_raw in row.items(): assert expt["overrides"][variable]["value"] == json.loads(actual_value_raw) @@ -46,18 +48,28 @@ def test_build(collection): assert actual_erds.version == erds_dict["metadata"]["version"] assert str(actual_erds.created_by) == erds_dict["metadata"]["created"]["user"] # TODO: It'd be better to actually invoke the Datetime._serialize method - assert int(actual_erds.create_time.timestamp() * 1000 + 0.0001) == erds_dict["metadata"]["created"]["time"] - - for actual_experiment, erds_experiment in zip(actual_erds.experiments, erds_dict["data"]["experiments"]): + assert ( + int(actual_erds.create_time.timestamp() * 1000 + 0.0001) + == erds_dict["metadata"]["created"]["time"] + ) + + for actual_experiment, erds_experiment in zip( + actual_erds.experiments, erds_dict["data"]["experiments"] + ): assert str(actual_experiment.uid) == erds_experiment["experiment_id"] assert str(actual_experiment.candidate_id) == erds_experiment["candidate_id"] assert str(actual_experiment.workflow_id) == erds_experiment["workflow_id"] assert actual_experiment.name == erds_experiment["name"] assert actual_experiment.description == erds_experiment["description"] # TODO: It'd be better to actually invoke the Datetime._serialize method - assert int(actual_experiment.updated_time.timestamp() * 1000 + 0.0001) == erds_experiment["updated_time"] - - for actual_override, erds_override in zip(actual_experiment.overrides.items(), erds_experiment["overrides"].items()): + assert ( + int(actual_experiment.updated_time.timestamp() * 1000 + 0.0001) + == erds_experiment["updated_time"] + ) + + for actual_override, erds_override in zip( + actual_experiment.overrides.items(), erds_experiment["overrides"].items() + ): actual_override_key, actual_override_value = actual_override erds_override_key, erds_override_value = erds_override assert actual_override_key == erds_override_key @@ -78,12 +90,44 @@ def test_list(session, collection, erds_base_path): list(collection.list(branch_version_id=version_id, version=LATEST_VER)) assert session.calls == [ - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, 'page': 1}), - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, "branch": str(version_id), 'page': 1}), - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, "version": 4, 'page': 1}), - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, "version": LATEST_VER, 'page': 1}), - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, "branch": str(version_id), "version": 12, 'page': 1}), - FakeCall(method='GET', path=erds_base_path, params={'per_page': 100, "branch": str(version_id), "version": LATEST_VER, 'page': 1}) + FakeCall( + method="GET", path=erds_base_path, params={"per_page": 100, "page": 1} + ), + FakeCall( + method="GET", + path=erds_base_path, + params={"per_page": 100, "branch": str(version_id), "page": 1}, + ), + FakeCall( + method="GET", + path=erds_base_path, + params={"per_page": 100, "version": 4, 
"page": 1}, + ), + FakeCall( + method="GET", + path=erds_base_path, + params={"per_page": 100, "version": LATEST_VER, "page": 1}, + ), + FakeCall( + method="GET", + path=erds_base_path, + params={ + "per_page": 100, + "branch": str(version_id), + "version": 12, + "page": 1, + }, + ), + FakeCall( + method="GET", + path=erds_base_path, + params={ + "per_page": 100, + "branch": str(version_id), + "version": LATEST_VER, + "page": 1, + }, + ), ] @@ -96,7 +140,7 @@ def test_read_and_retrieve(session, collection, erds_base_path): erds_csv = collection.read(erds_id) - assert session.calls == [FakeCall(method='GET', path=erds_path)] + assert session.calls == [FakeCall(method="GET", path=erds_path)] assert_erds_csv(erds_csv, erds_dict) diff --git a/tests/resources/test_file_link.py b/tests/resources/test_file_link.py index 338559e97..232386e23 100644 --- a/tests/resources/test_file_link.py +++ b/tests/resources/test_file_link.py @@ -9,16 +9,30 @@ from botocore.exceptions import ClientError from citrine.resources.api_error import ValidationError -from citrine.resources.file_link import FileCollection, FileLink, GEMDFileLink, _Uploader, \ - _get_ids_from_url +from citrine.resources.file_link import ( + FileCollection, + FileLink, + GEMDFileLink, + _Uploader, + _get_ids_from_url, +) from citrine.resources.ingestion import Ingestion, IngestionCollection from citrine.exceptions import NotFound from tests.utils.factories import ( - FileLinkDataFactory, _UploaderFactory, JobStatusResponseDataFactory, - IngestionStatusResponseDataFactory, IngestFilesResponseDataFactory, JobSubmissionResponseDataFactory + FileLinkDataFactory, + _UploaderFactory, + JobStatusResponseDataFactory, + IngestionStatusResponseDataFactory, + IngestFilesResponseDataFactory, + JobSubmissionResponseDataFactory, +) +from tests.utils.session import ( + FakeSession, + FakeS3Client, + FakeCall, + FakeRequestResponseApiError, ) -from tests.utils.session import FakeSession, FakeS3Client, FakeCall, FakeRequestResponseApiError @pytest.fixture @@ -28,29 +42,25 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> FileCollection: - return FileCollection( - team_id=uuid4(), - dataset_id=uuid4(), - session=session - ) + return FileCollection(team_id=uuid4(), dataset_id=uuid4(), session=session) @pytest.fixture def valid_data() -> dict: - return FileLinkDataFactory(url='www.citrine.io', filename='materials.txt') + return FileLinkDataFactory(url="www.citrine.io", filename="materials.txt") @pytest.mark.parametrize( ("filename", "mimetype"), [ pytest.param( - "asdf.xlsx", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "asdf.xlsx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", marks=pytest.mark.xfail( - platform.system() == "Windows", - reason="windows-latest test servers omit xlsx from their registry", - strict=True - ) + platform.system() == "Windows", + reason="windows-latest test servers omit xlsx from their registry", + strict=True, + ), ), ("asdf.xls", "application/vnd.ms-excel"), ("asdf.XLS", "application/vnd.ms-excel"), @@ -81,13 +91,13 @@ def test_name_alias(valid_data): def test_string_representation(valid_data): """Test the string representation.""" - assert str(FileLink.build(valid_data)) == '' + assert str(FileLink.build(valid_data)) == "" def test_from_path(): """Test the string representation.""" - path = Path.cwd() / 'some' / 'path' / 'with' / 'file.txt' - assert FileLink.from_path(path).filename == 'file.txt' + path = Path.cwd() / "some" / "path" / "with" / 
"file.txt" + assert FileLink.from_path(path).filename == "file.txt" assert FileLink.from_path(str(path)).url == path.as_uri() assert FileCollection._is_local_url(FileLink.from_path(path).url) @@ -97,9 +107,10 @@ def uploader() -> _Uploader: """An _Uploader object with all of its fields filled in.""" return _UploaderFactory() + def test_deprecation_of_positional_arguments(session): - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) with pytest.deprecated_call(): _ = FileCollection(uuid4(), uuid4(), session) @@ -108,6 +119,7 @@ def test_deprecation_of_positional_arguments(session): with pytest.raises(TypeError): _ = FileCollection(project_id=uuid4(), dataset_id=None, session=session) + def test_delete(collection: FileCollection, session): """Test that deletion calls the expected endpoint and checks the url structure.""" # Given @@ -120,21 +132,20 @@ def test_delete(collection: FileCollection, session): # Then assert 1 == session.num_calls - expected_call = FakeCall( - method='DELETE', - path=collection._get_path(file_id) - ) + expected_call = FakeCall(method="DELETE", path=collection._get_path(file_id)) assert expected_call == session.last_call # A URL that does not follow the files/{id}/versions/{id} format is invalid - for chunk in (f'{file_id}', f'{file_id}/{version_id}'): - invalid_url = f'{collection._get_path}/{chunk}' + for chunk in (f"{file_id}", f"{file_id}/{version_id}"): + invalid_url = f"{collection._get_path}/{chunk}" invalid_file_link = collection.build(FileLinkDataFactory(url=invalid_url)) with pytest.raises(ValueError): collection.delete(invalid_file_link) # A remote URL is invalid - ext_invalid_url = f'http://www.citrine.io/develop/files/{file_id}/versions/{version_id}' + ext_invalid_url = ( + f"http://www.citrine.io/develop/files/{file_id}/versions/{version_id}" + ) ext_invalid_file_link = collection.build(FileLinkDataFactory(url=ext_invalid_url)) with pytest.raises(ValueError): collection.delete(ext_invalid_file_link) @@ -142,37 +153,34 @@ def test_delete(collection: FileCollection, session): def test_upload(collection: FileCollection, session, tmpdir, monkeypatch): """Test signaling that an upload has completed and the creation of a FileLink object.""" - monkeypatch.setattr(Session, 'client', lambda *args, **kwargs: FakeS3Client({'VersionId': '42'})) + monkeypatch.setattr( + Session, "client", lambda *args, **kwargs: FakeS3Client({"VersionId": "42"}) + ) # It would be good to test these, but the values assigned are not accessible dest_names = { - 'foo.txt': 'text/plain', - 'foo.TXT': 'text/plain', # Capitalization in extension is fine - 'foo.bar': 'application/octet-stream' # No match == generic binary + "foo.txt": "text/plain", + "foo.TXT": "text/plain", # Capitalization in extension is fine + "foo.bar": "application/octet-stream", # No match == generic binary } file_id = str(uuid4()) version = str(uuid4()) # This is the dictionary structure we expect from the upload completion request - file_info_response = { - 'file_info': { - 'file_id': file_id, - 'version': version - } - } + file_info_response = {"file_info": {"file_id": file_id, "version": version}} uploads_response = { - 's3_region': 'us-east-1', - 's3_bucket': 'temp-bucket', - 'temporary_credentials': { - 'access_key_id': '1234', - 'secret_access_key': 'abbb8777', - 'session_token': 'hefheuhuhhu83772333', 
+ "s3_region": "us-east-1", + "s3_bucket": "temp-bucket", + "temporary_credentials": { + "access_key_id": "1234", + "secret_access_key": "abbb8777", + "session_token": "hefheuhuhhu83772333", }, - 'uploads': [ + "uploads": [ { - 's3_key': '66377378', - 'upload_id': '111', + "s3_key": "66377378", + "upload_id": "111", } - ] + ], } for dest_name in dest_names: @@ -182,8 +190,9 @@ def test_upload(collection: FileCollection, session, tmpdir, monkeypatch): session.set_responses(uploads_response, file_info_response) file_link = collection.upload(file_path=tmp_path) - url = 'teams/{}/datasets/{}/files/{}/versions/{}'\ - .format(collection.team_id, collection.dataset_id, file_id, version) + url = "teams/{}/datasets/{}/files/{}/versions/{}".format( + collection.team_id, collection.dataset_id, file_id, version + ) assert file_link.dump() == FileLink(dest_name, url=url).dump() assert session.num_calls == 2 * len(dest_names) @@ -191,30 +200,25 @@ def test_upload(collection: FileCollection, session, tmpdir, monkeypatch): def test_upload_missing_file(collection: FileCollection): with pytest.raises(ValueError): - collection.upload(file_path='this-file-does-not-exist.xls') + collection.upload(file_path="this-file-does-not-exist.xls") def test_upload_request(collection: FileCollection, session, uploader, tmpdir): """Test that an upload request response contains all required fields.""" - filename = 'foo.txt' + filename = "foo.txt" tmppath = Path(tmpdir) / filename tmppath.write_text("Arbitrary text") # This is the dictionary structure we expect from the upload request upload_request_response = { - 's3_region': uploader.region_name, - 's3_bucket': uploader.bucket, - 'temporary_credentials': { - 'access_key_id': uploader.aws_access_key_id, - 'secret_access_key': uploader.aws_secret_access_key, - 'session_token': uploader.aws_session_token, + "s3_region": uploader.region_name, + "s3_bucket": uploader.bucket, + "temporary_credentials": { + "access_key_id": uploader.aws_access_key_id, + "secret_access_key": uploader.aws_secret_access_key, + "session_token": uploader.aws_session_token, }, - 'uploads': [ - { - 's3_key': uploader.object_key, - 'upload_id': uploader.upload_id - } - ] + "uploads": [{"s3_key": uploader.object_key, "upload_id": uploader.upload_id}], } session.set_response(upload_request_response) new_uploader = collection._make_upload_request(tmppath, filename) @@ -232,38 +236,35 @@ def test_upload_request(collection: FileCollection, session, uploader, tmpdir): assert new_uploader.s3_addressing_style == uploader.s3_addressing_style # Using a request response that is missing a field throws a RuntimeError - del upload_request_response['s3_bucket'] + del upload_request_response["s3_bucket"] with pytest.raises(RuntimeError): collection._make_upload_request(tmppath, filename) -def test_upload_request_s3_overrides(collection: FileCollection, session, uploader, tmpdir): +def test_upload_request_s3_overrides( + collection: FileCollection, session, uploader, tmpdir +): """Test that an upload request response contains all required fields.""" - filename = 'foo.txt' + filename = "foo.txt" tmppath = Path(tmpdir) / filename tmppath.write_text("Arbitrary text") # This is the dictionary structure we expect from the upload request upload_request_response = { - 's3_region': uploader.region_name, - 's3_bucket': uploader.bucket, - 'temporary_credentials': { - 'access_key_id': uploader.aws_access_key_id, - 'secret_access_key': uploader.aws_secret_access_key, - 'session_token': uploader.aws_session_token, + "s3_region": 
uploader.region_name, + "s3_bucket": uploader.bucket, + "temporary_credentials": { + "access_key_id": uploader.aws_access_key_id, + "secret_access_key": uploader.aws_secret_access_key, + "session_token": uploader.aws_session_token, }, - 'uploads': [ - { - 's3_key': uploader.object_key, - 'upload_id': uploader.upload_id - } - ] + "uploads": [{"s3_key": uploader.object_key, "upload_id": uploader.upload_id}], } session.set_response(upload_request_response) # Override the s3 endpoint settings in the session, ensure they make it to the upload - endpoint = 'http://foo.bar' - addressing_style = 'path' + endpoint = "http://foo.bar" + addressing_style = "path" use_ssl = False session.s3_endpoint_url = endpoint session.s3_addressing_style = addressing_style @@ -275,30 +276,34 @@ def test_upload_request_s3_overrides(collection: FileCollection, session, upload assert new_uploader.s3_addressing_style == addressing_style -def test_upload_file(collection: FileCollection, session, uploader, tmpdir, monkeypatch): +def test_upload_file( + collection: FileCollection, session, uploader, tmpdir, monkeypatch +): """Test that uploading a file returns the version ID.""" - filename = 'foo.txt' + filename = "foo.txt" tmppath = Path(tmpdir) / filename tmppath.write_text("Arbitrary text") # A successful file upload sets uploader.s3_version - new_version = '3' + new_version = "3" with monkeypatch.context() as m: - client = FakeS3Client({'VersionId': new_version}) - m.setattr(Session, 'client', lambda *args, **kwargs: client) + client = FakeS3Client({"VersionId": new_version}) + m.setattr(Session, "client", lambda *args, **kwargs: client) new_uploader = collection._upload_file(tmppath, uploader) assert new_uploader.s3_version == new_version # If the client throws a ClientError when attempting to upload, throw a RuntimeError with monkeypatch.context() as m: - client = FakeS3Client(ClientError(error_response={}, operation_name='put'), raises=True) - m.setattr(Session, 'client', lambda *args, **kwargs: client) + client = FakeS3Client( + ClientError(error_response={}, operation_name="put"), raises=True + ) + m.setattr(Session, "client", lambda *args, **kwargs: client) with pytest.raises(RuntimeError): collection._upload_file(tmppath, uploader) - s3_addressing_style = 'path' - s3_endpoint_url = 'http://foo.bar' + s3_addressing_style = "path" + s3_endpoint_url = "http://foo.bar" s3_use_ssl = False uploader.s3_addressing_style = s3_addressing_style @@ -310,26 +315,24 @@ def test_upload_file(collection: FileCollection, session, uploader, tmpdir, monk def _stash_kwargs(*_, **kwargs): stashed_kwargs.update(kwargs) - return FakeS3Client({'VersionId': '71'}) + return FakeS3Client({"VersionId": "71"}) - m.setattr(Session, 'client', _stash_kwargs) + m.setattr(Session, "client", _stash_kwargs) collection._upload_file(tmppath, uploader) - assert stashed_kwargs['config'].s3['addressing_style'] is s3_addressing_style - assert stashed_kwargs['endpoint_url'] is s3_endpoint_url - assert stashed_kwargs['use_ssl'] is s3_use_ssl + assert stashed_kwargs["config"].s3["addressing_style"] is s3_addressing_style + assert stashed_kwargs["endpoint_url"] is s3_endpoint_url + assert stashed_kwargs["use_ssl"] is s3_use_ssl def test_upload_missing_version(collection: FileCollection, session, uploader): - dest_name = 'foo.txt' - file_id = '12345' - version = '14' + dest_name = "foo.txt" + file_id = "12345" + version = "14" bad_complete_response = { - 'file_info': { - 'file_id': file_id - }, - 'version': version # 'version' is supposed to go inside 
'file_info' + "file_info": {"file_id": file_id}, + "version": version, # 'version' is supposed to go inside 'file_info' } with pytest.raises(RuntimeError): session.set_response(bad_complete_response) @@ -340,31 +343,28 @@ def test_list_file_links(collection: FileCollection, session, valid_data): """Test that all files in a dataset can be turned into FileLink and listed.""" file_id = str(uuid4()) version = str(uuid4()) - filename = 'materials.txt' + filename = "materials.txt" # The actual response contains more fields, but these are the only ones we use. returned_data = { - 'id': file_id, - 'version': version, - 'filename': filename, + "id": file_id, + "version": version, + "filename": filename, } - returned_data["unversioned_url"] = f"http://test.domain.net:8002/api/v1/files/{returned_data['id']}" - returned_data["versioned_url"] = f"http://test.domain.net:8002/api/v1/files/{returned_data['id']}" \ - f"/versions/{returned_data['version']}" - session.set_response({ - 'files': [returned_data] - }) + returned_data["unversioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{returned_data['id']}" + ) + returned_data["versioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{returned_data['id']}" + f"/versions/{returned_data['version']}" + ) + session.set_response({"files": [returned_data]}) files_iterator = collection.list(per_page=15) files = [file for file in files_iterator] assert session.num_calls == 1 expected_call = FakeCall( - method='GET', - path=collection._get_path(), - params={ - 'per_page': 15, - 'page': 1 - } + method="GET", path=collection._get_path(), params={"per_page": 15, "page": 1} ) assert expected_call == session.last_call assert len(files) == 1 @@ -383,16 +383,24 @@ def test_file_download(collection: FileCollection, session, tmpdir): it does not exist, make a call to get the pre-signed URL, and another to download. 
""" # Given - filename = 'diagram.pdf' + filename = "diagram.pdf" file_uid = str(uuid4()) version_uid = str(uuid4()) url = f"teams/{collection.team_id}/datasets/{collection.dataset_id}/files/{file_uid}/versions/{version_uid}" - file = FileLink.build(FileLinkDataFactory(url=url, filename=filename, id=file_uid, version=version_uid)) - pre_signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary - session.set_response({ - 'pre_signed_read_link': pre_signed_url, - }) - target_dir = str(tmpdir) + 'some/new/directory/' + file = FileLink.build( + FileLinkDataFactory( + url=url, filename=filename, id=file_uid, version=version_uid + ) + ) + pre_signed_url = ( + "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary + ) + session.set_response( + { + "pre_signed_read_link": pre_signed_url, + } + ) + target_dir = str(tmpdir) + "some/new/directory/" target_file = target_dir + filename def _checked_write(path, content): @@ -403,28 +411,25 @@ def _checked_write(path, content): # When assert mock_get.call_count == 1 - expected_call = FakeCall( - method='GET', - path=url + '/content-link' - ) + expected_call = FakeCall(method="GET", path=url + "/content-link") assert expected_call == session.last_call - _checked_write(target_dir, 'content') - assert Path(target_file).read_text() == 'content' + _checked_write(target_dir, "content") + assert Path(target_file).read_text() == "content" # Now the directory exists - _checked_write(Path(target_dir), 'other content') - assert Path(target_file).read_text() == 'other content' + _checked_write(Path(target_dir), "other content") + assert Path(target_file).read_text() == "other content" # Give it the filename instead - _checked_write(target_file, 'more content') - assert Path(target_file).read_text() == 'more content' + _checked_write(target_file, "more content") + assert Path(target_file).read_text() == "more content" # And as a Path - _checked_write(target_file, 'love that content') - assert Path(target_file).read_text() == 'love that content' + _checked_write(target_file, "love that content") + assert Path(target_file).read_text() == "love that content" - bad_url = f"bin/uuid3/versions/uuid4" + bad_url = "bin/uuid3/versions/uuid4" bad_file = FileLink.build(FileLinkDataFactory(url=bad_url, filename=filename)) with pytest.raises(ValueError, match="Citrine"): collection.download(file_link=bad_file, local_path=target_dir) @@ -436,62 +441,68 @@ def test_read(collection: FileCollection, session, tmp_path): """ # Given - filename = 'diagram.pdf' + filename = "diagram.pdf" file_uid = str(uuid4()) version_uid = str(uuid4()) url = f"teams/{collection.team_id}/datasets/{collection.dataset_id}/files/{file_uid}/versions/{version_uid}" - file = FileLink.build(FileLinkDataFactory(url=url, filename=filename, id=file_uid, version=version_uid)) - pre_signed_url = "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary - session.set_response({ - 'pre_signed_read_link': pre_signed_url, - }) + file = FileLink.build( + FileLinkDataFactory( + url=url, filename=filename, id=file_uid, version=version_uid + ) + ) + pre_signed_url = ( + "http://files.citrine.io/secret-codes/jiifema987pjfsda" # arbitrary + ) + session.set_response( + { + "pre_signed_read_link": pre_signed_url, + } + ) with requests_mock.mock() as mock_get: mock_get.get(pre_signed_url, text="lorem ipsum") # When io = collection.read(file_link=file) - assert io.decode('UTF-8') == 'lorem ipsum' + assert io.decode("UTF-8") == "lorem ipsum" # When assert mock_get.call_count == 1 - 
expected_call = FakeCall( - method='GET', - path=url + '/content-link' - ) + expected_call = FakeCall(method="GET", path=url + "/content-link") assert expected_call == session.last_call - bad_url = f"bin/uuid3/versions/uuid4" + bad_url = "bin/uuid3/versions/uuid4" bad_file = FileLink.build(FileLinkDataFactory(url=bad_url, filename=filename)) with pytest.raises(ValueError, match="Citrine"): collection.read(file_link=bad_file) # Test with files.list endpoint-like object - filelink = collection.build({"id": str(uuid4()), - "version": str(uuid4()), - "filename": filename, - "type": FileLink.typ}) + filelink = collection.build( + { + "id": str(uuid4()), + "version": str(uuid4()), + "filename": filename, + "type": FileLink.typ, + } + ) pre_signed_url_2 = "http://files.citrine.io/secret-codes/2222222222222" # arbitrary - session.set_response({'pre_signed_read_link': pre_signed_url_2}) + session.set_response({"pre_signed_read_link": pre_signed_url_2}) with requests_mock.mock() as mock_get: mock_get.get(pre_signed_url_2, text="quite lovely") # When io = collection.read(file_link=filelink) - assert io.decode('UTF-8') == 'quite lovely' + assert io.decode("UTF-8") == "quite lovely" # When assert mock_get.call_count == 1 - expected_call_2 = FakeCall( - method='GET', - path=filelink.url + '/content-link' - ) + expected_call_2 = FakeCall(method="GET", path=filelink.url + "/content-link") assert expected_call_2 == session.last_call # Test the local read behaves with requests_mock.mock() as mock_get: - local = tmp_path / 'test.txt' + local = tmp_path / "test.txt" content = "This is content" local.write_text(content) # When io = collection.read(file_link=FileLink.from_path(local)) - assert io.decode('UTF-8') == content + assert io.decode("UTF-8") == content # When assert mock_get.call_count == 0 @@ -502,16 +513,16 @@ def test_external_file_read(collection: FileCollection, session): """ # Given - filename = 'spreadsheet.xlsx' + filename = "spreadsheet.xlsx" url = "http://customer.com/data-lake/files/123/versions/456" file = FileLink.build(FileLinkDataFactory(url=url, filename=filename)) with requests_mock.mock() as mock_get: - mock_get.get(url, text='010111011') + mock_get.get(url, text="010111011") # When io = collection.read(file_link=file) - assert io.decode('UTF-8') == '010111011' + assert io.decode("UTF-8") == "010111011" # When assert mock_get.call_count == 1 @@ -525,13 +536,13 @@ def test_external_file_download(collection: FileCollection, session, tmpdir): it does not exist, and make a single call to download. 
""" # Given - filename = 'spreadsheet.xlsx' + filename = "spreadsheet.xlsx" url = "http://customer.com/data-lake/files/123/versions/456" file = FileLink.build(FileLinkDataFactory(url=url, filename=filename)) - local_path = Path(tmpdir) / 'test_external_file_download/new_name.xlsx' + local_path = Path(tmpdir) / "test_external_file_download/new_name.xlsx" with requests_mock.mock() as mock_get: - mock_get.get(url, text='010111011') + mock_get.get(url, text="010111011") # When collection.download(file_link=file, local_path=local_path) @@ -539,24 +550,30 @@ def test_external_file_download(collection: FileCollection, session, tmpdir): # When assert mock_get.call_count == 1 - assert local_path.read_text() == '010111011' + assert local_path.read_text() == "010111011" def test_ingest(collection: FileCollection, session): """Test the on-platform ingest route.""" - good_file1 = collection.build({"filename": "good.csv", "id": str(uuid4()), "version": str(uuid4())}) - good_file2 = collection.build({"filename": "also.csv", "id": str(uuid4()), "version": str(uuid4())}) + good_file1 = collection.build( + {"filename": "good.csv", "id": str(uuid4()), "version": str(uuid4())} + ) + good_file2 = collection.build( + {"filename": "also.csv", "id": str(uuid4()), "version": str(uuid4())} + ) bad_file = FileLink(filename="bad.csv", url="http://files.com/input.csv") ingest_files_resp = IngestFilesResponseDataFactory() job_id_resp = JobSubmissionResponseDataFactory() job_status_resp = JobStatusResponseDataFactory( - job_id=job_id_resp['job_id'], - job_type='create-gemd-objects', + job_id=job_id_resp["job_id"], + job_type="create-gemd-objects", ) ingest_status_resp = IngestionStatusResponseDataFactory() - session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp) + session.set_responses( + ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp + ) collection.ingest([good_file1, good_file2]) with pytest.raises(ValueError, match=bad_file.url): @@ -568,8 +585,12 @@ def test_ingest(collection: FileCollection, session): with pytest.raises(ValueError): collection.ingest([good_file1], build_table=True) - session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp) - coll_with_project_id = FileCollection(team_id=uuid4(), dataset_id=uuid4(), session=session) + session.set_responses( + ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp + ) + coll_with_project_id = FileCollection( + team_id=uuid4(), dataset_id=uuid4(), session=session + ) coll_with_project_id.project_id = uuid4() with pytest.deprecated_call(): coll_with_project_id.ingest([good_file1], build_table=True) @@ -578,13 +599,13 @@ def test_ingest(collection: FileCollection, session): def test_ingest_with_upload(collection, monkeypatch, tmp_path, session): """Test more advanced workflows, patching to avoid unnecessary complexity.""" - platform_file = FileLink(url='relative/path', filename='file.txt') + platform_file = FileLink(url="relative/path", filename="file.txt") platform_file.uid = uuid4() - external_file = FileLink(url='http://citrine.io/other.txt', filename='other.txt') - local_file = tmp_path / 'file.csv' + external_file = FileLink(url="http://citrine.io/other.txt", filename="other.txt") + local_file = tmp_path / "file.csv" local_file.write_text("a,b,c\n1,2,3") local_file_link = FileLink(filename=local_file.name, url=local_file.as_uri()) - local_none = tmp_path / 'not_here.csv' + local_none = tmp_path / "not_here.csv" def _mock_download(self, *, file_link, local_path): 
assert file_link == external_file or file_link == local_file_link @@ -593,31 +614,36 @@ def _mock_download(self, *, file_link, local_path): def _mock_upload(self, *, file_path, dest_name=None): uploads.add(dest_name) - return FileLink(url='relative/path', filename=file_path.name) - - def _mock_build_from_file_links(self: IngestionCollection, - file_links: Collection[FileLink], - *, - raise_errors: bool = True - ): + return FileLink(url="relative/path", filename=file_path.name) + + def _mock_build_from_file_links( + self: IngestionCollection, + file_links: Collection[FileLink], + *, + raise_errors: bool = True, + ): assert len(file_links) == 3 assert platform_file in file_links assert external_file not in file_links assert local_file not in file_links - return Ingestion.build({ - "ingestion_id": uuid4(), - "team_id": self.team_id, - "dataset_id": self.dataset_id, - "session": self.session, - "raise_errors": raise_errors, - }) + return Ingestion.build( + { + "ingestion_id": uuid4(), + "team_id": self.team_id, + "dataset_id": self.dataset_id, + "session": self.session, + "raise_errors": raise_errors, + } + ) def _mock_build_objects(self, **_): pass monkeypatch.setattr(FileCollection, "download", _mock_download) monkeypatch.setattr(FileCollection, "upload", _mock_upload) - monkeypatch.setattr(IngestionCollection, "build_from_file_links", _mock_build_from_file_links) + monkeypatch.setattr( + IngestionCollection, "build_from_file_links", _mock_build_from_file_links + ) monkeypatch.setattr(Ingestion, "build_objects", _mock_build_objects) collection.ingest([platform_file, external_file, local_file], upload=True) @@ -640,48 +666,50 @@ def test_resolve_file_link(collection: FileCollection, session): # The actual response contains more fields, but these are the only ones we use. 
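    # file1 below also carries three versions, so the lookups later in this
    # test can resolve by id, version UUID, filename, and relative URL.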
raw_files = [ { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file0.txt', - 'version_number': 1 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file0.txt", + "version_number": 1, }, { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file1.txt', - 'version_number': 3 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file1.txt", + "version_number": 3, }, { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file2.txt', - 'version_number': 1 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file2.txt", + "version_number": 1, }, ] file1_versions = [raw_files[1].copy() for _ in range(3)] - file1_versions[0]['version'] = str(uuid4()) - file1_versions[0]['version_number'] = 1 - file1_versions[2]['version'] = str(uuid4()) - file1_versions[2]['version_number'] = 2 + file1_versions[0]["version"] = str(uuid4()) + file1_versions[0]["version_number"] = 1 + file1_versions[2]["version"] = str(uuid4()) + file1_versions[2]["version_number"] = 2 for raw in raw_files: - raw['unversioned_url'] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}" - raw['versioned_url'] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}/versions/{raw['version']}" + raw["unversioned_url"] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}" + raw["versioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{raw['id']}/versions/{raw['version']}" + ) for f1 in file1_versions: - f1['unversioned_url'] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}" - f1['versioned_url'] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}/versions/{f1['version']}" + f1["unversioned_url"] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}" + f1["versioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{f1['id']}/versions/{f1['version']}" + ) - session.set_response({ - 'files': raw_files - }) + session.set_response({"files": raw_files}) file1 = collection.build(raw_files[1]) - assert collection._resolve_file_link(file1) == file1, "Resolving a FileLink is a no-op" + assert collection._resolve_file_link(file1) == file1, ( + "Resolving a FileLink is a no-op" + ) assert session.num_calls == 0, "No-op still hit server" - session.set_response({ - 'files': [raw_files[1]] - }) + session.set_response({"files": [raw_files[1]]}) unresolved = GEMDFileLink(filename=file1.filename, url=file1.url) assert collection._resolve_file_link(unresolved) == file1, "FileLink didn't resolve" @@ -698,32 +726,36 @@ def test_resolve_file_link(collection: FileCollection, session): collection._resolve_file_link(unresolved) assert session.num_calls == 2 - assert collection._resolve_file_link(UUID(raw_files[1]['id'])) == file1, "UUID didn't resolve" + assert collection._resolve_file_link(UUID(raw_files[1]["id"])) == file1, ( + "UUID didn't resolve" + ) assert session.num_calls == 3 - session.set_response({ - 'files': [raw_files[1]] - }) - assert collection._resolve_file_link(raw_files[1]['id']) == file1, "String UUID didn't resolve" + session.set_response({"files": [raw_files[1]]}) + assert collection._resolve_file_link(raw_files[1]["id"]) == file1, ( + "String UUID didn't resolve" + ) assert session.num_calls == 4 - assert collection._resolve_file_link(raw_files[1]['version']) == file1, "Version UUID didn't resolve" + assert collection._resolve_file_link(raw_files[1]["version"]) == file1, ( + "Version UUID didn't resolve" + ) assert session.num_calls == 5 abs_link = "https://wwww.website.web/web.pdf" assert 
collection._resolve_file_link(abs_link).filename == "web.pdf" assert collection._resolve_file_link(abs_link).url == abs_link - session.set_response({ - 'files': [raw_files[1]] - }) - assert collection._resolve_file_link(file1.url) == file1, "Relative path didn't resolve" + session.set_response({"files": [raw_files[1]]}) + assert collection._resolve_file_link(file1.url) == file1, ( + "Relative path didn't resolve" + ) assert session.num_calls == 6 - session.set_response({ - 'files': [raw_files[1]] - }) - assert collection._resolve_file_link(file1.filename) == file1, "Filename didn't resolve" + session.set_response({"files": [raw_files[1]]}) + assert collection._resolve_file_link(file1.filename) == file1, ( + "Filename didn't resolve" + ) assert session.num_calls == 7 with pytest.raises(TypeError): @@ -762,68 +794,83 @@ def test_get_ids_from_url(collection: FileCollection): def test_get(collection: FileCollection, session): raw_files = [ { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file0.txt', - 'version_number': 1 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file0.txt", + "version_number": 1, }, { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file1.txt', - 'version_number': 3 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file1.txt", + "version_number": 3, }, { - 'id': str(uuid4()), - 'version': str(uuid4()), - 'filename': 'file2.txt', - 'version_number': 1 + "id": str(uuid4()), + "version": str(uuid4()), + "filename": "file2.txt", + "version_number": 1, }, ] file1_versions = [raw_files[1].copy() for _ in range(3)] - file1_versions[0]['version'] = str(uuid4()) - file1_versions[0]['version_number'] = 1 - file1_versions[2]['version'] = str(uuid4()) - file1_versions[2]['version_number'] = 2 + file1_versions[0]["version"] = str(uuid4()) + file1_versions[0]["version_number"] = 1 + file1_versions[2]["version"] = str(uuid4()) + file1_versions[2]["version_number"] = 2 for raw in raw_files: - raw['unversioned_url'] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}" - raw['versioned_url'] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}/versions/{raw['version']}" + raw["unversioned_url"] = f"http://test.domain.net:8002/api/v1/files/{raw['id']}" + raw["versioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{raw['id']}/versions/{raw['version']}" + ) for f1 in file1_versions: - f1['unversioned_url'] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}" - f1['versioned_url'] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}/versions/{f1['version']}" + f1["unversioned_url"] = f"http://test.domain.net:8002/api/v1/files/{f1['id']}" + f1["versioned_url"] = ( + f"http://test.domain.net:8002/api/v1/files/{f1['id']}/versions/{f1['version']}" + ) file0 = collection.build(raw_files[0]) file1 = collection.build(raw_files[1]) - session.set_response({ - 'files': [raw_files[1]] - }) - assert collection.get(uid=raw_files[1]['id'], version=raw_files[1]['version']) == file1 + session.set_response({"files": [raw_files[1]]}) + assert ( + collection.get(uid=raw_files[1]["id"], version=raw_files[1]["version"]) == file1 + ) - session.set_response({ - 'files': [raw_files[0]] - }) - assert collection.get(uid=raw_files[0]['id'], version=raw_files[0]['version_number']) == file0 + session.set_response({"files": [raw_files[0]]}) + assert ( + collection.get(uid=raw_files[0]["id"], version=raw_files[0]["version_number"]) + == file0 + ) - session.set_response({ - 'files': [raw_files[1]] - }) - assert 
collection.get(uid=raw_files[1]['filename'], version=raw_files[1]['version_number']) == file1 + session.set_response({"files": [raw_files[1]]}) + assert ( + collection.get( + uid=raw_files[1]["filename"], version=raw_files[1]["version_number"] + ) + == file1 + ) - session.set_response({ - 'files': [raw_files[1]] - }) - assert collection.get(uid=raw_files[1]['filename'], version=raw_files[1]['version']) == file1 + session.set_response({"files": [raw_files[1]]}) + assert ( + collection.get(uid=raw_files[1]["filename"], version=raw_files[1]["version"]) + == file1 + ) - validation_error = ValidationError.build({"failure_message": "file not found", "failure_id": "failure_id"}) + validation_error = ValidationError.build( + {"failure_message": "file not found", "failure_id": "failure_id"} + ) session.set_response( - NotFound("path", FakeRequestResponseApiError(400, "Not found", [validation_error])) + NotFound( + "path", FakeRequestResponseApiError(400, "Not found", [validation_error]) + ) ) with pytest.raises(NotFound): - collection.get(uid=raw_files[1]['filename'], version=4) + collection.get(uid=raw_files[1]["filename"], version=4) def test_exceptions(collection: FileCollection, session): - file_link = FileLink(url="http://customer.com/data-lake/files/123/versions/456", filename="456") + file_link = FileLink( + url="http://customer.com/data-lake/files/123/versions/456", filename="456" + ) with pytest.raises(ValueError): collection._get_path_from_file_link(file_link) @@ -836,12 +883,16 @@ def test_exceptions(collection: FileCollection, session): with pytest.raises(ValueError): collection.get(uid=uuid4(), version="Words!") - validation_error = ValidationError.build({"failure_message": "file not found", "failure_id": "failure_id"}) + validation_error = ValidationError.build( + {"failure_message": "file not found", "failure_id": "failure_id"} + ) session.set_response( - NotFound("path", FakeRequestResponseApiError(400, "Not found", [validation_error])) + NotFound( + "path", FakeRequestResponseApiError(400, "Not found", [validation_error]) + ) ) with pytest.raises(NotFound): collection.get(uid="name") with pytest.raises(ValueError, match="Windows"): - collection.read(file_link=FileLink('File', 'file://remote/network/file.txt')) + collection.read(file_link=FileLink("File", "file://remote/network/file.txt")) diff --git a/tests/resources/test_gem_table.py b/tests/resources/test_gem_table.py index d71656fb7..98381ba8a 100644 --- a/tests/resources/test_gem_table.py +++ b/tests/resources/test_gem_table.py @@ -20,36 +20,36 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> GemTableCollection: return GemTableCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - project_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + session=session, ) def test_deprecated_create_collection(session): with pytest.raises(TypeError): return GemTableCollection( - project_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=session ) with pytest.raises(TypeError): return GemTableCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=session ) with pytest.raises(TypeError): return GemTableCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - 
project_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), ) @pytest.fixture def table(): def _table(download_url: str) -> GemTable: - return GemTable.build(GemTableDataFactory(signed_download_url=download_url, version=2)) + return GemTable.build( + GemTableDataFactory(signed_download_url=download_url, version=2) + ) return _table @@ -60,13 +60,15 @@ def test_get_table_metadata(collection, session): session.set_response(gem_table) # When - retrieved_table: GemTable = collection.get(gem_table["id"], version=gem_table["version"]) + retrieved_table: GemTable = collection.get( + gem_table["id"], version=gem_table["version"] + ) # Then assert 1 == session.num_calls expect_call = FakeCall( method="GET", - path=f"projects/{collection.project_id}/display-tables/{gem_table['id']}/versions/{gem_table['version']}" + path=f"projects/{collection.project_id}/display-tables/{gem_table['id']}/versions/{gem_table['version']}", ) assert session.last_call == expect_call assert str(retrieved_table.uid) == gem_table["id"] @@ -85,15 +87,19 @@ def test_get_table_metadata(collection, session): assert retrieved_table.version == version_number # Given - config = TableConfig(name="foo", description="bar", datasets=[], variables=[], rows=[], columns=[]) - session.set_response({ - "version": { - "ara_definition": config.dump(), - "version_number": config.version_number, - "id": config.config_uid, - }, - "definition": {"id": uuid4()} - }) + config = TableConfig( + name="foo", description="bar", datasets=[], variables=[], rows=[], columns=[] + ) + session.set_response( + { + "version": { + "ara_definition": config.dump(), + "version_number": config.version_number, + "id": config.config_uid, + }, + "definition": {"id": uuid4()}, + } + ) # Then assert retrieved_table.config.name == config.name @@ -101,7 +107,7 @@ def test_get_table_metadata(collection, session): assert retrieved_table.description == config.description expect_call = FakeCall( method="GET", - path=f"projects/{collection.project_id}/display-tables/{retrieved_table.uid}/versions/{retrieved_table.version}/definition" + path=f"projects/{collection.project_id}/display-tables/{retrieved_table.uid}/versions/{retrieved_table.version}/definition", ) assert session.last_call == expect_call @@ -125,7 +131,7 @@ def test_list_table_versions(collection, session): session.set_response(table_versions) # When - results = list(collection.list_versions(table_versions['tables'][0]['id'])) + results = list(collection.list_versions(table_versions["tables"][0]["id"])) # Then assert len(results) == 3 @@ -140,7 +146,7 @@ def test_list_by_config(collection, session): # When # NOTE: list_by_config returns slightly more info in this call, but it's a superset of # a typical Table, and parsed identically in citrine-python. 
- results = list(collection.list_by_config(table_versions['tables'][0]['id'])) + results = list(collection.list_by_config(table_versions["tables"][0]["id"])) # Then assert len(results) == 3 @@ -178,29 +184,31 @@ def test_build_from_config(collection: GemTableCollection, session): config_uid = uuid4() config_version = 2 config = TableConfig( - name='foo', - description='bar', - columns=[], - rows=[], - variables=[], - datasets=[] + name="foo", description="bar", columns=[], rows=[], variables=[], datasets=[] ) config.config_uid = config_uid config.version_number = config_version expected_table_data = GemTableDataFactory() session.set_responses( - {'job_id': '12345678-1234-1234-1234-123456789ccc'}, - {'job_type': 'foo', 'status': 'In Progress', 'tasks': []}, - {'job_type': 'foo', 'status': 'Success', 'tasks': [], 'output': { - 'display_table_id': expected_table_data['id'], - 'display_table_version': str(expected_table_data['version']), - 'table_warnings': json.dumps([ - {'limited_results': ['foo', 'bar'], 'total_count': 3}, - ]) - }}, + {"job_id": "12345678-1234-1234-1234-123456789ccc"}, + {"job_type": "foo", "status": "In Progress", "tasks": []}, + { + "job_type": "foo", + "status": "Success", + "tasks": [], + "output": { + "display_table_id": expected_table_data["id"], + "display_table_version": str(expected_table_data["version"]), + "table_warnings": json.dumps( + [ + {"limited_results": ["foo", "bar"], "total_count": 3}, + ] + ), + }, + }, expected_table_data, ) - gem_table = collection.build_from_config(config, version='ignored') + gem_table = collection.build_from_config(config, version="ignored") assert isinstance(gem_table, GemTable) assert session.num_calls == 4 @@ -209,12 +217,7 @@ def test_build_from_config_failures(collection: GemTableCollection, session): with pytest.raises(ValueError): collection.build_from_config(uuid4()) config = TableConfig( - name='foo', - description='bar', - columns=[], - rows=[], - variables=[], - datasets=[] + name="foo", description="bar", columns=[], rows=[], variables=[], datasets=[] ) config.definition_uid = uuid4() with pytest.raises(ValueError): @@ -225,16 +228,26 @@ def test_build_from_config_failures(collection: GemTableCollection, session): collection.build_from_config(config) config.config_uid = uuid4() session.set_responses( - {'job_id': '12345678-1234-1234-1234-123456789ccc'}, - {'job_type': 'foo', 'status': 'Failure', 'tasks': [ - {'task_type': 'foo', 'id': 'foo', 'status': 'Failure', 'failure_reason': 'because', 'dependencies': []} - ]}, + {"job_id": "12345678-1234-1234-1234-123456789ccc"}, + { + "job_type": "foo", + "status": "Failure", + "tasks": [ + { + "task_type": "foo", + "id": "foo", + "status": "Failure", + "failure_reason": "because", + "dependencies": [], + } + ], + }, ) with pytest.raises(JobFailureError): collection.build_from_config(uuid4(), version=1) session.set_responses( - {'job_id': '12345678-1234-1234-1234-123456789ccc'}, - {'job_type': 'foo', 'status': 'In Progress', 'tasks': []}, + {"job_id": "12345678-1234-1234-1234-123456789ccc"}, + {"job_type": "foo", "status": "In Progress", "tasks": []}, ) with pytest.raises(PollingTimeoutError): collection.build_from_config(config, timeout=0) @@ -245,44 +258,44 @@ def test_read_table_from_collection(mock_write_files_locally, collection, table) # When with requests_mock.mock() as mock_get: remote_url = "http://otherhost:4566/anywhere" - mock_get.get(remote_url, text='stuff') + mock_get.get(remote_url, text="stuff") collection.read(table=table(remote_url), local_path="table.pdf") 
assert mock_get.call_count == 1 assert mock_write_files_locally.call_count == 1 - assert mock_write_files_locally.call_args == call(b'stuff', "table.pdf") + assert mock_write_files_locally.call_args == call(b"stuff", "table.pdf") with requests_mock.mock() as mock_get: # When localstack_url = "http://localstack:4566/anywhere" - mock_get.get(localstack_url, text='stuff') + mock_get.get(localstack_url, text="stuff") collection.read(table=table(localstack_url), local_path="table2.pdf") assert mock_get.call_count == 1 assert mock_write_files_locally.call_count == 2 - assert mock_write_files_locally.call_args == call(b'stuff', "table2.pdf") + assert mock_write_files_locally.call_args == call(b"stuff", "table2.pdf") with requests_mock.mock() as mock_get: # When localstack_url = "http://localstack:4566/anywhere" override_url = "https://fakestack:1337" collection.session.s3_endpoint_url = override_url - mock_get.get(override_url + "/anywhere", text='stuff') + mock_get.get(override_url + "/anywhere", text="stuff") collection.read(table=table(localstack_url), local_path="table3.pdf") assert mock_get.call_count == 1 assert mock_write_files_locally.call_count == 3 - assert mock_write_files_locally.call_args == call(b'stuff', "table3.pdf") + assert mock_write_files_locally.call_args == call(b"stuff", "table3.pdf") with requests_mock.mock() as mock_get: # When localstack_url = "http://localstack:4566/anywhere" override_url = "https://fakestack:1337" collection.session.s3_endpoint_url = override_url - mock_get.get(override_url + "/anywhere", text='stuff') + mock_get.get(override_url + "/anywhere", text="stuff") this_table = table(localstack_url) collection.session.set_response({"tables": [this_table.dump()]}) collection.read(table=this_table.uid, local_path="table4.pdf") assert mock_get.call_count == 1 assert mock_write_files_locally.call_count == 4 - assert mock_write_files_locally.call_args == call(b'stuff', "table4.pdf") + assert mock_write_files_locally.call_args == call(b"stuff", "table4.pdf") def test_read_table_into_memory_from_collection(table, session, collection): @@ -302,7 +315,4 @@ def test_gem_table_entity_dict(): table = GemTable.build(GemTableDataFactory()) entity = table.access_control_dict() - assert entity == { - 'id': str(table.uid), - 'type': 'TABLE' - } + assert entity == {"id": str(table.uid), "type": "TABLE"} diff --git a/tests/resources/test_gemd_resource.py b/tests/resources/test_gemd_resource.py index 780aa30c9..8b15cdfd7 100644 --- a/tests/resources/test_gemd_resource.py +++ b/tests/resources/test_gemd_resource.py @@ -1,39 +1,56 @@ import random -from uuid import uuid4, UUID from os.path import basename +from uuid import UUID, uuid4 import pytest - +from gemd.entity.attribute import Condition, Parameter, Property, PropertyAndConditions from gemd.entity.bounds.integer_bounds import IntegerBounds -from gemd.entity.attribute import Property, Condition, Parameter, PropertyAndConditions -from gemd.entity.value import NominalInteger from gemd.entity.link_by_uid import LinkByUID -from gemd.entity.object.material_spec import MaterialSpec as GemdMaterialSpec +from gemd.entity.object.ingredient_run import IngredientRun as GemdIngredientRun +from gemd.entity.object.ingredient_spec import IngredientSpec as GemdIngredientSpec from gemd.entity.object.material_run import MaterialRun as GemdMaterialRun -from gemd.entity.object.process_spec import ProcessSpec as GemdProcessSpec -from gemd.entity.object.process_run import ProcessRun as GemdProcessRun -from gemd.entity.object.measurement_spec 
import MeasurementSpec as GemdMeasurementSpec +from gemd.entity.object.material_spec import MaterialSpec as GemdMaterialSpec from gemd.entity.object.measurement_run import MeasurementRun as GemdMeasurementRun -from gemd.entity.object.ingredient_spec import IngredientSpec as GemdIngredientSpec -from gemd.entity.object.ingredient_run import IngredientRun as GemdIngredientRun -from gemd.entity.template.material_template import MaterialTemplate as GemdMaterialTemplate +from gemd.entity.object.measurement_spec import MeasurementSpec as GemdMeasurementSpec +from gemd.entity.object.process_run import ProcessRun as GemdProcessRun +from gemd.entity.object.process_spec import ProcessSpec as GemdProcessSpec +from gemd.entity.template.condition_template import ( + ConditionTemplate as GemdConditionTemplate, +) +from gemd.entity.template.material_template import ( + MaterialTemplate as GemdMaterialTemplate, +) +from gemd.entity.template.measurement_template import ( + MeasurementTemplate as GemdMeasurementTemplate, +) +from gemd.entity.template.parameter_template import ( + ParameterTemplate as GemdParameterTemplate, +) from gemd.entity.template.process_template import ProcessTemplate as GemdProcessTemplate -from gemd.entity.template.measurement_template import MeasurementTemplate as GemdMeasurementTemplate -from gemd.entity.template.condition_template import ConditionTemplate as GemdConditionTemplate -from gemd.entity.template.parameter_template import ParameterTemplate as GemdParameterTemplate -from gemd.entity.template.property_template import PropertyTemplate as GemdPropertyTemplate +from gemd.entity.template.property_template import ( + PropertyTemplate as GemdPropertyTemplate, +) +from gemd.entity.value import NominalInteger -from citrine.exceptions import PollingTimeoutError, JobFailureError -from citrine.resources.api_error import ApiError, ValidationError +from citrine._utils.functions import format_escaped_url +from citrine.exceptions import JobFailureError, PollingTimeoutError +from citrine.resources.api_error import ApiError from citrine.resources.audit_info import AuditInfo from citrine.resources.condition_template import ConditionTemplate -from citrine.resources.data_concepts import DataConcepts, CITRINE_SCOPE, CITRINE_TAG_PREFIX +from citrine.resources.data_concepts import ( + CITRINE_SCOPE, + CITRINE_TAG_PREFIX, + DataConcepts, +) from citrine.resources.gemd_resource import GEMDResourceCollection from citrine.resources.ingredient_run import IngredientRun from citrine.resources.ingredient_spec import IngredientSpec from citrine.resources.material_run import MaterialRun -from citrine.resources.material_spec import MaterialSpecCollection, MaterialSpec -from citrine.resources.material_template import MaterialTemplateCollection, MaterialTemplate +from citrine.resources.material_spec import MaterialSpec, MaterialSpecCollection +from citrine.resources.material_template import ( + MaterialTemplate, + MaterialTemplateCollection, +) from citrine.resources.measurement_run import MeasurementRun from citrine.resources.measurement_spec import MeasurementSpec from citrine.resources.measurement_template import MeasurementTemplate @@ -42,11 +59,12 @@ from citrine.resources.process_spec import ProcessSpec from citrine.resources.process_template import ProcessTemplate from citrine.resources.property_template import PropertyTemplate -from citrine._utils.functions import format_escaped_url - -from tests.utils.factories import MaterialRunDataFactory, MaterialSpecDataFactory -from tests.utils.factories import 
JobSubmissionResponseDataFactory -from tests.utils.session import FakeSession, FakeCall +from tests.utils.factories import ( + JobSubmissionResponseDataFactory, + MaterialRunDataFactory, + MaterialSpecDataFactory, +) +from tests.utils.session import FakeCall, FakeSession @pytest.fixture @@ -56,24 +74,23 @@ def session() -> FakeSession: @pytest.fixture def gemd_collection(session) -> GEMDResourceCollection: - return GEMDResourceCollection( - team_id=uuid4(), - dataset_id=uuid4(), - session=session - ) + return GEMDResourceCollection(team_id=uuid4(), dataset_id=uuid4(), session=session) + def test_invalid_collection_construction(): with pytest.raises(TypeError): return GEMDResourceCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session + ) + def test_deprecation_of_positional_arguments(session): - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) with pytest.deprecated_call(): - fcol = GEMDResourceCollection(uuid4(), uuid4(), session) + GEMDResourceCollection(uuid4(), uuid4(), session) + def sample_gems(nsamples, **kwargs): factories = [MaterialRunDataFactory, MaterialSpecDataFactory] @@ -87,9 +104,7 @@ def test_get_type(gemd_collection): def test_list(gemd_collection, session): # Given samples = sample_gems(20) - session.set_response({ - 'contents': samples - }) + session.set_response({"contents": samples}) # When gems = list(gemd_collection.list()) @@ -97,52 +112,65 @@ def test_list(gemd_collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path=format_escaped_url('teams/{}/storables', gemd_collection.team_id, gemd_collection.dataset_id), + method="GET", + path=format_escaped_url( + "teams/{}/storables", gemd_collection.team_id, gemd_collection.dataset_id + ), params={ - 'dataset_id': str(gemd_collection.dataset_id), - 'forward': True, - 'ascending': True, - 'per_page': 100 - } + "dataset_id": str(gemd_collection.dataset_id), + "forward": True, + "ascending": True, + "per_page": 100, + }, ) assert expected_call == session.last_call assert len(samples) == len(gems) for i in range(len(gems)): - assert samples[i]['uids']['id'] == gems[i].uids['id'] + assert samples[i]["uids"]["id"] == gems[i].uids["id"] def test_register(gemd_collection): """Check that register routes to the correct collections""" targets = [ MaterialTemplate("foo"), - MaterialSpec("foo", - properties=[PropertyAndConditions( - property=Property("prop", value=NominalInteger(1)), - conditions=[Condition("cond", value=NominalInteger(1))], - )] - ), + MaterialSpec( + "foo", + properties=[ + PropertyAndConditions( + property=Property("prop", value=NominalInteger(1)), + conditions=[Condition("cond", value=NominalInteger(1))], + ) + ], + ), MaterialRun("foo"), ProcessTemplate("foo"), - ProcessSpec("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), - ProcessRun("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), + ProcessSpec( + "foo", + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), + ProcessRun( + "foo", + conditions=[Condition("cond", 
value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), MeasurementTemplate("foo"), - MeasurementSpec("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), - MeasurementRun("foo", - properties=[Property("prop", value=NominalInteger(1))], - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), + MeasurementSpec( + "foo", + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), + MeasurementRun( + "foo", + properties=[Property("prop", value=NominalInteger(1))], + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), IngredientSpec("foo"), IngredientRun(), PropertyTemplate("bar", bounds=IntegerBounds(0, 1)), ParameterTemplate("bar", bounds=IntegerBounds(0, 1)), - ConditionTemplate("bar", bounds=IntegerBounds(0, 1)) + ConditionTemplate("bar", bounds=IntegerBounds(0, 1)), ] for obj in targets: @@ -152,7 +180,9 @@ def test_register(gemd_collection): registered = gemd_collection.register(obj, dry_run=False) assert len(obj.uids) == 1 assert len(registered.uids) == 1 - assert basename(gemd_collection.session.calls[-1].path) == basename(gemd_collection._path_template) + assert basename(gemd_collection.session.calls[-1].path) == basename( + gemd_collection._path_template + ) for pair in obj.uids.items(): assert pair[1] == registered.uids[pair[0]] @@ -161,33 +191,44 @@ def test_gemd_register(gemd_collection): """Check that register routes to the correct collections""" targets = [ GemdMaterialTemplate("foo"), - GemdMaterialSpec("foo", - properties=[PropertyAndConditions( - property=Property("prop", value=NominalInteger(1)), - conditions=[Condition("cond", value=NominalInteger(1))], - )] - ), + GemdMaterialSpec( + "foo", + properties=[ + PropertyAndConditions( + property=Property("prop", value=NominalInteger(1)), + conditions=[Condition("cond", value=NominalInteger(1))], + ) + ], + ), GemdMaterialRun("foo"), GemdProcessTemplate("foo"), - GemdProcessSpec("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), - GemdProcessRun("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), + GemdProcessSpec( + "foo", + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), + GemdProcessRun( + "foo", + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), GemdMeasurementTemplate("foo"), - GemdMeasurementSpec("foo", - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), - GemdMeasurementRun("foo", - properties=[Property("prop", value=NominalInteger(1))], - conditions=[Condition("cond", value=NominalInteger(1))], - parameters=[Parameter("para", value=NominalInteger(1))]), + GemdMeasurementSpec( + "foo", + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), + GemdMeasurementRun( + "foo", + properties=[Property("prop", value=NominalInteger(1))], + conditions=[Condition("cond", value=NominalInteger(1))], + parameters=[Parameter("para", value=NominalInteger(1))], + ), GemdIngredientSpec("foo"), GemdIngredientRun(), 
GemdPropertyTemplate("bar", bounds=IntegerBounds(0, 1)), GemdParameterTemplate("bar", bounds=IntegerBounds(0, 1)), - GemdConditionTemplate("bar", bounds=IntegerBounds(0, 1)) + GemdConditionTemplate("bar", bounds=IntegerBounds(0, 1)), ] for obj in targets: @@ -197,7 +238,9 @@ def test_gemd_register(gemd_collection): registered = gemd_collection.register(obj, dry_run=False) assert len(obj.uids) == 1 assert len(registered.uids) == 1 - assert basename(gemd_collection.session.calls[-1].path) == basename(gemd_collection._path_template) + assert basename(gemd_collection.session.calls[-1].path) == basename( + gemd_collection._path_template + ) for pair in obj.uids.items(): assert pair[1] == registered.uids[pair[0]] @@ -205,21 +248,17 @@ def test_gemd_register(gemd_collection): def test_register_no_mutate(gemd_collection): """Check that register routes to the correct collections""" expected = { - MaterialTemplateCollection: MaterialTemplate("foo", - uids={'scope1': 'A', - 'scope2': 'B' - } - ), - MaterialSpecCollection: MaterialSpec("foo", - uids={'id': str(uuid4())} - ), + MaterialTemplateCollection: MaterialTemplate( + "foo", uids={"scope1": "A", "scope2": "B"} + ), + MaterialSpecCollection: MaterialSpec("foo", uids={"id": str(uuid4())}), } for specific_collection, obj in expected.items(): len_before = len(obj.uids) registered = gemd_collection.register(obj) assert len(obj.uids) == len_before for pair in registered.uids.items(): - assert pair[1] == obj.uids.get(pair[0], 'No such key') + assert pair[1] == obj.uids.get(pair[0], "No such key") def test_register_all(gemd_collection): @@ -228,39 +267,67 @@ def test_register_all(gemd_collection): property_template = PropertyTemplate("bar", bounds=bounds) parameter_template = ParameterTemplate("bar", bounds=bounds) condition_template = ConditionTemplate("bar", bounds=bounds) - foo_process_template = ProcessTemplate("foo", - conditions=[[condition_template, bounds]], - parameters=[[parameter_template, bounds]]) + foo_process_template = ProcessTemplate( + "foo", + conditions=[[condition_template, bounds]], + parameters=[[parameter_template, bounds]], + ) foo_process_spec = ProcessSpec("foo", template=foo_process_template) foo_process_run = ProcessRun("foo", spec=foo_process_spec) - foo_material_template = MaterialTemplate("foo", properties=[[property_template, bounds]]) - foo_material_spec = MaterialSpec("foo", template=foo_material_template, process=foo_process_spec) - foo_material_run = MaterialRun("foo", spec=foo_material_spec, process=foo_process_run) - foo_measurement_template = MeasurementTemplate("foo", - conditions=[[condition_template, bounds]], - parameters=[[parameter_template, bounds]], - properties=[[property_template, bounds]]) + foo_material_template = MaterialTemplate( + "foo", properties=[[property_template, bounds]] + ) + foo_material_spec = MaterialSpec( + "foo", template=foo_material_template, process=foo_process_spec + ) + foo_material_run = MaterialRun( + "foo", spec=foo_material_spec, process=foo_process_run + ) + foo_measurement_template = MeasurementTemplate( + "foo", + conditions=[[condition_template, bounds]], + parameters=[[parameter_template, bounds]], + properties=[[property_template, bounds]], + ) foo_measurement_spec = MeasurementSpec("foo", template=foo_measurement_template) - foo_measurement_run = MeasurementRun("foo", spec=foo_measurement_spec, material=foo_material_run) + foo_measurement_run = MeasurementRun( + "foo", spec=foo_measurement_spec, material=foo_material_run + ) - baz_process_template = 
ProcessTemplate("baz", - conditions=[[condition_template, bounds]], - parameters=[[parameter_template, bounds]]) + baz_process_template = ProcessTemplate( + "baz", + conditions=[[condition_template, bounds]], + parameters=[[parameter_template, bounds]], + ) baz_process_spec = ProcessSpec("baz", template=baz_process_template) baz_process_run = ProcessRun("baz", spec=baz_process_spec) - baz_material_template = MaterialTemplate("baz", properties=[[property_template, bounds]]) - baz_material_spec = MaterialSpec("baz", template=baz_material_template, process=baz_process_spec) - baz_material_run = MaterialRun("baz", spec=baz_material_spec, process=baz_process_run) - baz_measurement_template = MeasurementTemplate("baz", - conditions=[[condition_template, bounds]], - parameters=[[parameter_template, bounds]], - properties=[[property_template, bounds]]) + baz_material_template = MaterialTemplate( + "baz", properties=[[property_template, bounds]] + ) + baz_material_spec = MaterialSpec( + "baz", template=baz_material_template, process=baz_process_spec + ) + baz_material_run = MaterialRun( + "baz", spec=baz_material_spec, process=baz_process_run + ) + baz_measurement_template = MeasurementTemplate( + "baz", + conditions=[[condition_template, bounds]], + parameters=[[parameter_template, bounds]], + properties=[[property_template, bounds]], + ) baz_measurement_spec = MeasurementSpec("baz", template=baz_measurement_template) - baz_measurement_run = MeasurementRun("baz", spec=baz_measurement_spec, material=baz_material_run) + baz_measurement_run = MeasurementRun( + "baz", spec=baz_measurement_spec, material=baz_material_run + ) - foo_baz_ingredient_spec = IngredientSpec("foo", material=foo_material_spec, process=baz_process_spec) - foo_baz_ingredient_run = IngredientRun(spec=foo_baz_ingredient_spec, material=foo_material_run, process=baz_process_run) + foo_baz_ingredient_spec = IngredientSpec( + "foo", material=foo_material_spec, process=baz_process_spec + ) + foo_baz_ingredient_run = IngredientRun( + spec=foo_baz_ingredient_spec, material=foo_material_run, process=baz_process_run + ) expected = [ foo_baz_ingredient_run, @@ -274,7 +341,6 @@ def test_register_all(gemd_collection): foo_process_run, foo_process_spec, foo_process_template, - baz_measurement_run, baz_measurement_spec, baz_measurement_template, @@ -284,10 +350,9 @@ def test_register_all(gemd_collection): baz_process_run, baz_process_spec, baz_process_template, - property_template, parameter_template, - condition_template + condition_template, ] for obj in expected: @@ -320,7 +385,12 @@ def test_register_all(gemd_collection): def test_register_all_dry_run(gemd_collection): """Verify expected behavior around batching. 
Note we cannot actually test dependencies."""
-    from gemd.demo.cake import make_cake_templates, make_cake_spec, make_cake, change_scope
+    from gemd.demo.cake import (
+        change_scope,
+        make_cake,
+        make_cake_spec,
+        make_cake_templates,
+    )
     from gemd.util import flatten
 
     change_scope("pr-688")
@@ -355,7 +425,9 @@ def test_register_all_object_update(gemd_collection):
     process = GemdProcessSpec("process")
     material = GemdMaterialSpec("material", process=process)
 
-    registered_process, registered_material = gemd_collection.register_all([process, material])
+    registered_process, registered_material = gemd_collection.register_all(
+        [process, material]
+    )
 
     assert process.uids == registered_process.uids
     assert material.uids == registered_material.uids
@@ -376,14 +448,22 @@ def test_delete(gemd_collection, session):
 
     for obj in targets:
         for dry_run in True, False:
-            session.set_response(obj.dump())  # Delete calls get, must return object data internally
+            session.set_response(
+                obj.dump()
+            )  # Delete calls get, must return object data internally
             gemd_collection.delete(obj, dry_run=dry_run)
-            assert gemd_collection.session.calls[-1].path.split("/")[-3] == basename(gemd_collection._path_template)
+            assert gemd_collection.session.calls[-1].path.split("/")[-3] == basename(
+                gemd_collection._path_template
+            )
 
             # And again, with uids
-            session.set_response(obj.dump())  # Delete calls get, must return object data internally
+            session.set_response(
+                obj.dump()
+            )  # Delete calls get, must return object data internally
             gemd_collection.delete(obj.uid, dry_run=dry_run)
-            assert gemd_collection.session.calls[-1].path.split("/")[-3] == basename(gemd_collection._path_template)
+            assert gemd_collection.session.calls[-1].path.split("/")[-3] == basename(
+                gemd_collection._path_template
+            )
 
 
 def test_update(gemd_collection):
@@ -393,20 +473,19 @@ def test_update(gemd_collection):
     template.description = "updated description"
     template_updated = gemd_collection.update(template)
     assert template_updated == template
-    assert gemd_collection.session.calls[0].path == gemd_collection.session.calls[1].path
+    assert (
+        gemd_collection.session.calls[0].path == gemd_collection.session.calls[1].path
+    )
 
 
 def test_async_update(gemd_collection, session):
     """Check that async update appropriately returns None on success."""
-    obj = ProcessTemplate(
-        "foo",
-        uids={'id': str(uuid4())}
-    )
+    obj = ProcessTemplate("foo", uids={"id": str(uuid4())})
     fake_job_status_resp = {
-        'job_type': 'some_typ',
-        'status': 'Success',
-        'tasks': [],
-        'output': {}
+        "job_type": "some_typ",
+        "status": "Success",
+        "tasks": [],
+        "output": {},
     }
 
     session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp)
@@ -418,10 +497,7 @@ def test_async_update_and_no_dataset_id(gemd_collection, session):
     """Ensure async_update requires a dataset id"""
-    obj = ProcessTemplate(
-        "foo",
-        uids={'id': str(uuid4())}
-    )
+    obj = ProcessTemplate("foo", uids={"id": str(uuid4())})
     session.set_response(JobSubmissionResponseDataFactory())
 
     gemd_collection.dataset_id = None
@@ -433,36 +509,29 @@ def test_async_update_timeout(gemd_collection, session):
     """Ensure the proper exception is thrown on a timeout error"""
-    obj = ProcessTemplate(
-        "foo",
-        uids={'id': str(uuid4())}
-    )
+    obj = ProcessTemplate("foo", uids={"id": str(uuid4())})
     fake_job_status_resp = {
-        'job_type': 'some_typ',
-        'status': 'Pending',
-        'tasks': [],
-        'output': {} 
+ "job_type": "some_typ", + "status": "Pending", + "tasks": [], + "output": {}, } session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) with pytest.raises(PollingTimeoutError): - gemd_collection.async_update(obj, wait_for_response=True, - timeout=-1.0) + gemd_collection.async_update(obj, wait_for_response=True, timeout=-1.0) def test_async_update_and_wait(gemd_collection, session): """Check that async_update parses the response when waiting""" - obj = ProcessTemplate( - "foo", - uids={'id': str(uuid4())} - ) + obj = ProcessTemplate("foo", uids={"id": str(uuid4())}) fake_job_status_resp = { - 'job_type': 'some_typ', - 'status': 'Success', - 'tasks': [], - 'output': {} + "job_type": "some_typ", + "status": "Success", + "tasks": [], + "output": {}, } session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) @@ -474,15 +543,12 @@ def test_async_update_and_wait(gemd_collection, session): def test_async_update_and_wait_failure(gemd_collection, session): """Check that async_update parses the failure correctly""" - obj = ProcessTemplate( - "foo", - uids={'id': str(uuid4())} - ) + obj = ProcessTemplate("foo", uids={"id": str(uuid4())}) fake_job_status_resp = { - 'job_type': 'some_typ', - 'status': 'Failure', - 'tasks': [], - 'output': {} + "job_type": "some_typ", + "status": "Failure", + "tasks": [], + "output": {}, } session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp) @@ -494,10 +560,7 @@ def test_async_update_and_wait_failure(gemd_collection, session): def test_async_update_with_no_wait(gemd_collection, session): """Check that async_update parses the response when not waiting""" - obj = ProcessTemplate( - "foo", - uids={'id': str(uuid4())} - ) + obj = ProcessTemplate("foo", uids={"id": str(uuid4())}) session.set_response(JobSubmissionResponseDataFactory()) job_id = gemd_collection.async_update(obj, wait_for_response=False) @@ -505,44 +568,41 @@ def test_async_update_with_no_wait(gemd_collection, session): def test_batch_delete(gemd_collection, session): - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} import json - failures_escaped_json = json.dumps([ - { - "id": { - 'scope': 'somescope', - 'id': 'abcd-1234' - }, - 'cause': { - "code": 400, - "message": "", - "validation_errors": [ - { - "failure_message": "fail msg", - "failure_id": "identifier.coreid.missing" - } - ] + + failures_escaped_json = json.dumps( + [ + { + "id": {"scope": "somescope", "id": "abcd-1234"}, + "cause": { + "code": 400, + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + }, } - } - ]) + ] + ) failed_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [], - 'output': { - 'failures': failures_escaped_json - } + "job_type": "batch_delete", + "status": "Success", + "tasks": [], + "output": {"failures": failures_escaped_json}, } session.set_responses(job_resp, failed_job_resp) # When - del_resp = gemd_collection.batch_delete([UUID( - '16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = gemd_collection.batch_delete( + [UUID("16fd2706-8baf-433b-82eb-8c7fada847da")] + ) # Then assert 2 == session.num_calls @@ -550,13 +610,20 @@ def test_batch_delete(gemd_collection, session): assert len(del_resp) == 1 first_failure = del_resp[0] - expected_api_error = ApiError.build({ - "code": "400", - "message": "", - "validation_errors": [{"failure_message": "fail msg", "failure_id": "identifier.coreid.missing"}] - }) + expected_api_error = 
ApiError.build( + { + "code": "400", + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + } + ) - assert first_failure[0] == LinkByUID('somescope', 'abcd-1234') + assert first_failure[0] == LinkByUID("somescope", "abcd-1234") assert first_failure[1].dump() == expected_api_error.dump() @@ -569,20 +636,33 @@ def test_type_passthrough(gemd_collection, session): """Verify objects that are not directly referenced by objects (e.g., a tuple of Templates) don't get type information stripped.""" # Generate some metadata metadata = { - 'dataset': str(uuid4()), - 'audit_info': AuditInfo.build({"created_by": str(uuid4()), - "created_at": 1559933807392 - }), - "tags": [f"{CITRINE_TAG_PREFIX}::added"] + "dataset": str(uuid4()), + "audit_info": AuditInfo.build( + {"created_by": str(uuid4()), "created_at": 1559933807392} + ), + "tags": [f"{CITRINE_TAG_PREFIX}::added"], } # Set up the Condition Templates low_tmpl, high_tmpl = [ - ConditionTemplate('condition low', uids={CITRINE_SCOPE: str(uuid4())}, bounds=IntegerBounds(1, 10)), - ConditionTemplate('condition high', uids={CITRINE_SCOPE: str(uuid4())}, bounds=IntegerBounds(11, 20)), + ConditionTemplate( + "condition low", + uids={CITRINE_SCOPE: str(uuid4())}, + bounds=IntegerBounds(1, 10), + ), + ConditionTemplate( + "condition high", + uids={CITRINE_SCOPE: str(uuid4())}, + bounds=IntegerBounds(11, 20), + ), ] - session.set_response({"objects": [dict(low_tmpl.dump(), **metadata), - dict(high_tmpl.dump(), **metadata), - ]}) + session.set_response( + { + "objects": [ + dict(low_tmpl.dump(), **metadata), + dict(high_tmpl.dump(), **metadata), + ] + } + ) low_tmpl, high_tmpl = gemd_collection.register_all([low_tmpl, high_tmpl]) assert low_tmpl.dataset is not None assert low_tmpl.audit_info is not None @@ -590,10 +670,12 @@ def test_type_passthrough(gemd_collection, session): assert high_tmpl.audit_info is not None ptempl = ProcessTemplate( - 'my template', + "my template", uids={CITRINE_SCOPE: str(uuid4())}, - conditions=[(low_tmpl, IntegerBounds(2, 4)), (high_tmpl, IntegerBounds(12, 15))], - + conditions=[ + (low_tmpl, IntegerBounds(2, 4)), + (high_tmpl, IntegerBounds(12, 15)), + ], ) session.set_response(dict(ptempl.dump(), **metadata)) ptempl = gemd_collection.register(ptempl) @@ -602,37 +684,36 @@ def test_type_passthrough(gemd_collection, session): arr = [ ProcessSpec( - 'foo', + "foo", uids={CITRINE_SCOPE: str(uuid4())}, template=ptempl, conditions=[ - Condition(name='low', value=NominalInteger(3), template=low_tmpl), - Condition(name='high', value=NominalInteger(13), template=high_tmpl), - ] + Condition(name="low", value=NominalInteger(3), template=low_tmpl), + Condition(name="high", value=NominalInteger(13), template=high_tmpl), + ], ), ProcessSpec( - 'bar', + "bar", uids={CITRINE_SCOPE: str(uuid4())}, template=ptempl, conditions=[ - Condition(name='high', value=NominalInteger(14), template=high_tmpl), - ] + Condition(name="high", value=NominalInteger(14), template=high_tmpl), + ], ), - ProcessSpec('baz', uids={CITRINE_SCOPE: str(uuid4())}), + ProcessSpec("baz", uids={CITRINE_SCOPE: str(uuid4())}), ] session.set_response({"objects": [dict(x.dump(), **metadata) for x in arr]}) pspecs = gemd_collection.register_all(arr) - assert [s.name for s in pspecs] == ['foo', 'bar', 'baz'] + assert [s.name for s in pspecs] == ["foo", "bar", "baz"] assert pspecs == arr def test_tag_magic(gemd_collection, session): auto_tag = f"{CITRINE_TAG_PREFIX}::added" - additions = {"tags": ["tag", 
auto_tag], - "uids": {CITRINE_SCOPE: str(uuid4()), - "original": "id" - } - } + additions = { + "tags": ["tag", auto_tag], + "uids": {CITRINE_SCOPE: str(uuid4()), "original": "id"}, + } obj1 = ProcessSpec("one", tags=["tag"], uids={"original": "id"}) session.set_response(dict(obj1.dump(), **additions)) diff --git a/tests/resources/test_generative_design_execution.py b/tests/resources/test_generative_design_execution.py index 0a5327d83..9e656649c 100644 --- a/tests/resources/test_generative_design_execution.py +++ b/tests/resources/test_generative_design_execution.py @@ -2,8 +2,12 @@ import uuid from citrine.informatics.generative_design import GenerativeDesignInput -from citrine.informatics.executions.generative_design_execution import GenerativeDesignExecution -from citrine.resources.generative_design_execution import GenerativeDesignExecutionCollection +from citrine.informatics.executions.generative_design_execution import ( + GenerativeDesignExecution, +) +from citrine.resources.generative_design_execution import ( + GenerativeDesignExecutionCollection, +) from citrine.informatics.generative_design import FingerprintType, StructureExclusion from tests.utils.session import FakeSession, FakeCall @@ -22,7 +26,9 @@ def collection(session) -> GenerativeDesignExecutionCollection: @pytest.fixture -def generative_design_execution(collection: GenerativeDesignExecutionCollection, generative_design_execution_dict) -> GenerativeDesignExecution: +def generative_design_execution( + collection: GenerativeDesignExecutionCollection, generative_design_execution_dict +) -> GenerativeDesignExecution: return collection.build(generative_design_execution_dict) @@ -37,16 +43,24 @@ def test_basic_methods(generative_design_execution, collection): def test_build_new_execution(collection, generative_design_execution_dict): - execution: GenerativeDesignExecution = collection.build(generative_design_execution_dict) + execution: GenerativeDesignExecution = collection.build( + generative_design_execution_dict + ) assert str(execution.uid) == generative_design_execution_dict["id"] assert execution.project_id == collection.project_id assert execution._session == collection.session - assert execution.in_progress() and not execution.succeeded() and not execution.failed() + assert ( + execution.in_progress() and not execution.succeeded() and not execution.failed() + ) assert execution.status_detail -def test_trigger_execution(collection: GenerativeDesignExecutionCollection, generative_design_execution_dict, session): +def test_trigger_execution( + collection: GenerativeDesignExecutionCollection, + generative_design_execution_dict, + session, +): # Given session.set_response(generative_design_execution_dict) design_execution_input = GenerativeDesignInput( @@ -63,26 +77,31 @@ def test_trigger_execution(collection: GenerativeDesignExecutionCollection, gene # Then assert str(actual_execution.uid) == generative_design_execution_dict["id"] - expected_path = '/projects/{}/generative-design/executions'.format( + expected_path = "/projects/{}/generative-design/executions".format( collection.project_id, ) assert session.last_call == FakeCall( - method='POST', + method="POST", path=expected_path, json={ - 'seeds': design_execution_input.seeds, - 'fingerprint_type': design_execution_input.fingerprint_type.value, - 'min_fingerprint_similarity': design_execution_input.min_fingerprint_similarity, - 'mutation_per_seed': design_execution_input.mutation_per_seed, - 'structure_exclusions': [ - exclusion.value for exclusion in 
design_execution_input.structure_exclusions + "seeds": design_execution_input.seeds, + "fingerprint_type": design_execution_input.fingerprint_type.value, + "min_fingerprint_similarity": design_execution_input.min_fingerprint_similarity, + "mutation_per_seed": design_execution_input.mutation_per_seed, + "structure_exclusions": [ + exclusion.value + for exclusion in design_execution_input.structure_exclusions ], - 'min_substructure_counts': design_execution_input.min_substructure_counts, - } + "min_substructure_counts": design_execution_input.min_substructure_counts, + }, ) -def test_generative_design_execution_results(generative_design_execution: GenerativeDesignExecution, session, example_generation_results): +def test_generative_design_execution_results( + generative_design_execution: GenerativeDesignExecution, + session, + example_generation_results, +): # Given session.set_response(example_generation_results) @@ -90,28 +109,34 @@ def test_generative_design_execution_results(generative_design_execution: Genera list(generative_design_execution.results(per_page=4)) # Then - expected_path = '/projects/{}/generative-design/executions/{}/results'.format( + expected_path = "/projects/{}/generative-design/executions/{}/results".format( generative_design_execution.project_id, generative_design_execution.uid, ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 4, "page": 1}) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"per_page": 4, "page": 1} + ) -def test_generative_design_execution_result(generative_design_execution: GenerativeDesignExecution, session, example_generation_results): +def test_generative_design_execution_result( + generative_design_execution: GenerativeDesignExecution, + session, + example_generation_results, +): # Given session.set_response(example_generation_results["response"][0]) # When - result_id=example_generation_results["response"][0]["id"] + result_id = example_generation_results["response"][0]["id"] generative_design_execution.result(result_id=result_id) # Then - expected_path = '/projects/{}/generative-design/executions/{}/results/{}'.format( + expected_path = "/projects/{}/generative-design/executions/{}/results/{}".format( generative_design_execution.project_id, generative_design_execution.uid, result_id, ) - assert session.last_call == FakeCall(method='GET', path=expected_path) + assert session.last_call == FakeCall(method="GET", path=expected_path) def test_list(collection: GenerativeDesignExecutionCollection, session): @@ -119,11 +144,11 @@ def test_list(collection: GenerativeDesignExecutionCollection, session): lst = list(collection.list(per_page=4)) assert len(lst) == 0 - expected_path = '/projects/{}/generative-design/executions'.format(collection.project_id) + expected_path = "/projects/{}/generative-design/executions".format( + collection.project_id + ) assert session.last_call == FakeCall( - method='GET', - path=expected_path, - params={"page": 1, "per_page": 4} + method="GET", path=expected_path, params={"page": 1, "per_page": 4} ) diff --git a/tests/resources/test_ingestion.py b/tests/resources/test_ingestion.py index 5b3f791f0..69cb0da9c 100644 --- a/tests/resources/test_ingestion.py +++ b/tests/resources/test_ingestion.py @@ -7,15 +7,24 @@ from citrine.resources.dataset import Dataset from citrine.resources.file_link import FileLink from citrine.resources.ingestion import ( - Ingestion, IngestionCollection, IngestionStatus, IngestionStatusType, IngestionException, - 
IngestionErrorTrace, IngestionErrorType, IngestionErrorFamily, IngestionErrorLevel + Ingestion, + IngestionCollection, + IngestionStatus, + IngestionStatusType, + IngestionException, + IngestionErrorTrace, + IngestionErrorType, + IngestionErrorFamily, + IngestionErrorLevel, ) from citrine.jobs.job import JobSubmissionResponse, JobStatusResponse, JobFailureError from citrine.resources.project import Project from tests.utils.factories import ( - DatasetFactory, IngestionStatusResponseDataFactory, JobSubmissionResponseDataFactory, - JobStatusResponseDataFactory + DatasetFactory, + IngestionStatusResponseDataFactory, + JobSubmissionResponseDataFactory, + JobStatusResponseDataFactory, ) from tests.utils.session import FakeCall, FakeSession, FakeRequestResponseApiError @@ -27,7 +36,7 @@ def session() -> FakeSession: @pytest.fixture def dataset(session: Session): - dataset = DatasetFactory(name='Test Dataset') + dataset = DatasetFactory(name="Test Dataset") dataset.team_id = uuid4() dataset.uid = uuid4() dataset.session = session @@ -37,7 +46,7 @@ def dataset(session: Session): @pytest.fixture def deprecated_dataset(session: Session): - deprecated_dataset = DatasetFactory(name='Test Dataset') + deprecated_dataset = DatasetFactory(name="Test Dataset") deprecated_dataset.uid = uuid4() deprecated_dataset.session = session deprecated_dataset.project_id = uuid4() @@ -59,36 +68,37 @@ def file_link(dataset: Dataset) -> FileLink: @pytest.fixture def ingest(collection) -> Ingestion: - return collection.build({ - "ingestion_id": uuid4(), - "team_id": collection.team_id, - "dataset_id": collection.dataset_id - }) + return collection.build( + { + "ingestion_id": uuid4(), + "team_id": collection.team_id, + "dataset_id": collection.dataset_id, + } + ) @pytest.fixture def operation() -> JobSubmissionResponse: - return JobSubmissionResponse.build({ - "job_id": uuid4() - }) + return JobSubmissionResponse.build({"job_id": uuid4()}) @pytest.fixture def status() -> IngestionStatus: - return IngestionStatus.build({ - "status": IngestionStatusType.INGESTION_CREATED, - "errors": [] - }) + return IngestionStatus.build( + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + ) def test_create_deprecated_collection(session, deprecated_dataset): - check_project = {'project': {'team': {'id': str(uuid4())}}} + check_project = {"project": {"team": {"id": str(uuid4())}}} session.set_response(check_project) with pytest.deprecated_call(): ingestions = deprecated_dataset.ingestions - assert session.calls == [FakeCall(method="GET", path=f'projects/{ingestions.project_id}')] - assert ingestions._path_template == f'projects/{ingestions.project_id}/ingestions' + assert session.calls == [ + FakeCall(method="GET", path=f"projects/{ingestions.project_id}") + ] + assert ingestions._path_template == f"projects/{ingestions.project_id}/ingestions" def test_not_implementeds(collection): @@ -107,8 +117,8 @@ def test_not_implementeds(collection): def test_deprecation_of_positional_arguments(session): - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) with pytest.deprecated_call(): IngestionCollection(uuid4(), uuid4(), session) @@ -126,14 +136,15 @@ def test_poll_for_job_completion_signature(ingest, operation, status, monkeypatc outer_raise_errors = None def _mock_poll_for_job_completion( - session, - team_id, - job, - *, - 
project_id=None, - timeout=-1.0, - polling_delay=-2.0, - raise_errors=True): + session, + team_id, + job, + *, + project_id=None, + timeout=-1.0, + polling_delay=-2.0, + raise_errors=True, + ): nonlocal outer_timeout nonlocal outer_polling_delay nonlocal outer_raise_errors @@ -146,7 +157,10 @@ def _mock_poll_for_job_completion( def _mock_status(self) -> IngestionStatus: return status - monkeypatch.setattr("citrine.resources.ingestion._poll_for_job_completion", _mock_poll_for_job_completion) + monkeypatch.setattr( + "citrine.resources.ingestion._poll_for_job_completion", + _mock_poll_for_job_completion, + ) monkeypatch.setattr(Ingestion, "status", _mock_status) ingest.poll_for_job_completion(operation) @@ -161,19 +175,23 @@ def _mock_status(self) -> IngestionStatus: def test_processing_exceptions(session, ingest, monkeypatch): - def _mock_poll_for_job_completion(**_): return JobStatusResponse.build(JobStatusResponseDataFactory()) # This is mocked equivalently for all tests - monkeypatch.setattr("citrine.resources.ingestion._poll_for_job_completion", _mock_poll_for_job_completion) - validation_error = ValidationError.build({"failure_message": "you failed", "failure_id": "failure_id"}) + monkeypatch.setattr( + "citrine.resources.ingestion._poll_for_job_completion", + _mock_poll_for_job_completion, + ) + validation_error = ValidationError.build( + {"failure_message": "you failed", "failure_id": "failure_id"} + ) # Raise exceptions, but it worked ingest.raise_errors = True session.set_responses( {"job_id": str(uuid4())}, - {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []}, ) result = ingest.build_objects() assert result.success @@ -182,8 +200,10 @@ def _mock_poll_for_job_completion(**_): # Raise exceptions, and build_objects_async returned errors ingest.raise_errors = True session.set_responses( - BadRequest("path", FakeRequestResponseApiError(400, "Bad Request", [validation_error])), - {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + BadRequest( + "path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]) + ), + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []}, ) with pytest.raises(IngestionException, match="you failed"): ingest.build_objects() @@ -192,7 +212,7 @@ def _mock_poll_for_job_completion(**_): ingest.raise_errors = True session.set_responses( BadRequest("path", FakeRequestResponseApiError(400, "This has no details", [])), - {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []}, ) with pytest.raises(IngestionException, match="no details"): ingest.build_objects() @@ -201,7 +221,7 @@ def _mock_poll_for_job_completion(**_): ingest.raise_errors = True session.set_responses( BadRequest("path", FakeRequestResponseApiError(500, "This was internal", [])), - {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []}, ) with pytest.raises(IngestionException, match="internal"): ingest.build_objects() @@ -210,11 +230,17 @@ def _mock_poll_for_job_completion(**_): ingest.raise_errors = True session.set_responses( {"job_id": str(uuid4())}, - {"status": IngestionStatusType.INGESTION_CREATED, - "errors": [{"msg": "Bad things!", - "level": IngestionErrorLevel.ERROR, - "family": IngestionErrorFamily.STRUCTURE, - "error_type": IngestionErrorType.INVALID_DUPLICATE_NAME}]} + { + "status": IngestionStatusType.INGESTION_CREATED, + 
"errors": [ + { + "msg": "Bad things!", + "level": IngestionErrorLevel.ERROR, + "family": IngestionErrorFamily.STRUCTURE, + "error_type": IngestionErrorType.INVALID_DUPLICATE_NAME, + } + ], + }, ) with pytest.raises(IngestionException, match="Bad things"): ingest.build_objects() @@ -223,7 +249,7 @@ def _mock_poll_for_job_completion(**_): ingest.raise_errors = False session.set_responses( {"job_id": str(uuid4())}, - {"status": IngestionStatusType.INGESTION_CREATED, "errors": []} + {"status": IngestionStatusType.INGESTION_CREATED, "errors": []}, ) result = ingest.build_objects() assert result.success @@ -231,23 +257,33 @@ def _mock_poll_for_job_completion(**_): # Suppress exceptions, and build_objects_async returned errors ingest.raise_errors = False session.set_responses( - BadRequest("path", FakeRequestResponseApiError(400, "Bad Request", [validation_error])), - {"status": IngestionStatusType.INGESTION_CREATED, - "errors": [{"msg": validation_error.failure_message, - "level": IngestionErrorLevel.ERROR, - "family": IngestionErrorFamily.DATA, - "error_type": IngestionErrorType.INVALID_DUPLICATE_NAME}]} + BadRequest( + "path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]) + ), + { + "status": IngestionStatusType.INGESTION_CREATED, + "errors": [ + { + "msg": validation_error.failure_message, + "level": IngestionErrorLevel.ERROR, + "family": IngestionErrorFamily.DATA, + "error_type": IngestionErrorType.INVALID_DUPLICATE_NAME, + } + ], + }, ) result = ingest.build_objects() assert not result.success - assert any('you failed' in str(e) for e in result.errors) + assert any("you failed" in str(e) for e in result.errors) # Suppress exceptions, and build_objects_async returned errors ingest.raise_errors = False session.set_responses( BadRequest("No API error, so it's thrown", None), - {"status": IngestionStatusType.INGESTION_CREATED, - "errors": [IngestionErrorTrace(validation_error.failure_message).dump()]} + { + "status": IngestionStatusType.INGESTION_CREATED, + "errors": [IngestionErrorTrace(validation_error.failure_message).dump()], + }, ) with pytest.raises(BadRequest): ingest.build_objects() @@ -256,19 +292,23 @@ def _mock_poll_for_job_completion(**_): ingest.raise_errors = False session.set_responses( {"job_id": str(uuid4())}, - {"status": IngestionStatusType.INGESTION_CREATED, - "errors": [IngestionErrorTrace("Sad").dump()] * 3} + { + "status": IngestionStatusType.INGESTION_CREATED, + "errors": [IngestionErrorTrace("Sad").dump()] * 3, + }, ) result = ingest.build_objects() assert not result.success - assert any('Sad' in e.msg for e in result.errors) + assert any("Sad" in e.msg for e in result.errors) -def test_ingestion_with_table_build(session: FakeSession, - ingest: Ingestion, - dataset: Dataset, - deprecated_dataset: Dataset, - file_link: FileLink): +def test_ingestion_with_table_build( + session: FakeSession, + ingest: Ingestion, + dataset: Dataset, + deprecated_dataset: Dataset, + file_link: FileLink, +): # build_objects_async will always approve, if we get that far session.set_responses(JobSubmissionResponseDataFactory()) @@ -299,10 +339,10 @@ def test_ingestion_with_table_build(session: FakeSession, # full build_objects full_build_job = JobSubmissionResponseDataFactory() output = { - 'ingestion_id': str(ingest.uid), - 'gemd_table_config_version': '1', - 'table_build_job_id': str(uuid4()), - 'gemd_table_config_id': str(uuid4()) + "ingestion_id": str(ingest.uid), + "gemd_table_config_version": "1", + "table_build_job_id": str(uuid4()), + "gemd_table_config_id": 
str(uuid4()), } session.set_responses( full_build_job, @@ -311,36 +351,46 @@ def test_ingestion_with_table_build(session: FakeSession, output=output, ), JobStatusResponseDataFactory(), - IngestionStatusResponseDataFactory() + IngestionStatusResponseDataFactory(), ) status = ingest.build_objects(build_table=True, project=str(project_uuid)) assert status.success -def test_ingestion_flow(session: FakeSession, - ingest: Ingestion, - collection: IngestionCollection, - file_link: FileLink, - monkeypatch): +def test_ingestion_flow( + session: FakeSession, + ingest: Ingestion, + collection: IngestionCollection, + file_link: FileLink, + monkeypatch, +): validation_error = ValidationError.build({"failure_message": "I've failed"}) with pytest.raises(ValueError, match="No files"): collection.build_from_file_links([]) with pytest.raises(ValueError, match="UID"): - collection.build_from_file_links([FileLink(filename="mine.txt", url="http:/external.com")]) + collection.build_from_file_links( + [FileLink(filename="mine.txt", url="http:/external.com")] + ) session.set_response(ingest.dump()) assert collection.build_from_file_links([file_link]).uid == ingest.uid - session.set_response(BadRequest("path", FakeRequestResponseApiError(400, "Sad face", []))) + session.set_response( + BadRequest("path", FakeRequestResponseApiError(400, "Sad face", [])) + ) with pytest.raises(IngestionException, match="Sad face"): collection.build_from_file_links([file_link], raise_errors=True) session.set_response(BadRequest("Generic Failure", None)) with pytest.raises(BadRequest): assert collection.build_from_file_links([file_link], raise_errors=False) - session.set_response(BadRequest("path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]))) + session.set_response( + BadRequest( + "path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]) + ) + ) failed = collection.build_from_file_links([file_link], raise_errors=False) def _raise_exception(): @@ -348,7 +398,7 @@ def _raise_exception(): with monkeypatch.context() as m: # There should be no calls given a failed ingest object - m.setattr(Session, 'request', _raise_exception) + m.setattr(Session, "request", _raise_exception) assert not failed.status().success assert not failed.build_objects().success with pytest.raises(JobFailureError): @@ -367,12 +417,14 @@ def _raise_exception(): JobSubmissionResponseDataFactory(), JobStatusResponseDataFactory(), IngestionStatusResponseDataFactory( - errors=[{ - "family": IngestionErrorFamily.DATA, - "error_type": IngestionErrorType.MISSING_RAW_FOR_INGREDIENT, - "level": IngestionErrorLevel.ERROR, - "msg": "Missing ingredient: \"myristic (14:0)\" (Note ingredient IDs are case sensitive)" - }] + errors=[ + { + "family": IngestionErrorFamily.DATA, + "error_type": IngestionErrorType.MISSING_RAW_FOR_INGREDIENT, + "level": IngestionErrorLevel.ERROR, + "msg": 'Missing ingredient: "myristic (14:0)" (Note ingredient IDs are case sensitive)', + } + ] ), ) with pytest.raises(IngestionException, match="Missing ingredient"): diff --git a/tests/resources/test_ingredient_run.py b/tests/resources/test_ingredient_run.py index ae817c898..9465d7f0f 100644 --- a/tests/resources/test_ingredient_run.py +++ b/tests/resources/test_ingredient_run.py @@ -15,30 +15,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> IngredientRunCollection: return IngredientRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session, - 
team_id=UUID('6b608f78-e341-422c-8076-35adc8828000') + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): IngredientRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session, - project_id=UUID(project_id) + project_id=UUID(project_id), ) - assert session.calls == [FakeCall(method="GET", path=f'projects/{project_id}')] + assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_spec(collection: IngredientRunCollection): run_noop_gemd_relation_search_test( - search_for='ingredient-runs', - search_with='ingredient-specs', + search_for="ingredient-runs", + search_with="ingredient-specs", collection=collection, search_fn=collection.list_by_spec, ) @@ -46,8 +48,8 @@ def test_list_by_spec(collection: IngredientRunCollection): def test_list_by_material(collection: IngredientRunCollection): run_noop_gemd_relation_search_test( - search_for='ingredient-runs', - search_with='material-runs', + search_for="ingredient-runs", + search_with="material-runs", collection=collection, search_fn=collection.list_by_material, ) @@ -55,8 +57,8 @@ def test_list_by_material(collection: IngredientRunCollection): def test_list_by_process(collection: IngredientRunCollection): run_noop_gemd_relation_search_test( - search_for='ingredient-runs', - search_with='process-runs', + search_for="ingredient-runs", + search_with="process-runs", collection=collection, search_fn=collection.list_by_process, ) @@ -69,14 +71,10 @@ def test_equals(): from gemd.entity.value import NominalReal gemd_obj = GEMDIngredientRun( - mass_fraction=NominalReal(1.0, ""), - notes="I have notes", - tags=["tag!"] + mass_fraction=NominalReal(1.0, ""), notes="I have notes", tags=["tag!"] ) citrine_obj = CitrineIngredientRun( - mass_fraction=NominalReal(1.0, ""), - notes="I have notes", - tags=["tag!"] + mass_fraction=NominalReal(1.0, ""), notes="I have notes", tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_ingredient_spec.py b/tests/resources/test_ingredient_spec.py index 52a4b4e21..99fa653ca 100644 --- a/tests/resources/test_ingredient_spec.py +++ b/tests/resources/test_ingredient_spec.py @@ -19,28 +19,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> IngredientSpecCollection: return IngredientSpecCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + session=session, + ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): IngredientSpecCollection( 
project_id=UUID(project_id), - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + session=session, + ) - assert session.calls == [FakeCall(method="GET", path=f'projects/{project_id}')] + assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_material(collection: IngredientSpecCollection): run_noop_gemd_relation_search_test( - search_for='ingredient-specs', - search_with='material-specs', + search_for="ingredient-specs", + search_with="material-specs", collection=collection, search_fn=collection.list_by_material, ) @@ -48,8 +52,8 @@ def test_list_by_material(collection: IngredientSpecCollection): def test_list_by_process(collection: IngredientSpecCollection): run_noop_gemd_relation_search_test( - search_for='ingredient-specs', - search_with='process-specs', + search_for="ingredient-specs", + search_with="process-specs", collection=collection, search_fn=collection.list_by_process, ) @@ -62,14 +66,14 @@ def test_equals(): labels=["nice", "words"], mass_fraction=NominalReal(1.0, ""), notes="I have notes", - tags=["tag!"] + tags=["tag!"], ) citrine_obj = CitrineIngredientSpec( name="My Name", labels=["nice", "words"], mass_fraction=NominalReal(1.0, ""), notes="I have notes", - tags=["tag!"] + tags=["tag!"], ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_material_run.py b/tests/resources/test_material_run.py index 1c24e4bcd..fb3ae680b 100644 --- a/tests/resources/test_material_run.py +++ b/tests/resources/test_material_run.py @@ -1,18 +1,8 @@ -from uuid import UUID import json +from uuid import UUID import pytest -from citrine._session import Session -from citrine._utils.functions import scrub_none -from citrine.exceptions import BadRequest -from citrine.resources.api_error import ValidationError -from citrine.resources.data_concepts import CITRINE_SCOPE -from citrine.resources.material_run import MaterialRunCollection -from citrine.resources.material_run import MaterialRun as CitrineRun -from citrine.resources.material_run import _inject_default_label_tags -from citrine.resources.gemd_resource import GEMDResourceCollection - -from gemd.demo.cake import make_cake, change_scope +from gemd.demo.cake import change_scope, make_cake from gemd.entity.bounds.integer_bounds import IntegerBounds from gemd.entity.link_by_uid import LinkByUID from gemd.entity.object.material_run import MaterialRun as GEMDRun @@ -20,11 +10,30 @@ from gemd.json import GEMDJson from gemd.util import flatten +from citrine._session import Session +from citrine._utils.functions import scrub_none +from citrine.exceptions import BadRequest +from citrine.resources.api_error import ValidationError +from citrine.resources.data_concepts import CITRINE_SCOPE +from citrine.resources.gemd_resource import GEMDResourceCollection +from citrine.resources.material_run import MaterialRun as CitrineRun +from citrine.resources.material_run import MaterialRunCollection from tests.resources.test_data_concepts import run_noop_gemd_relation_search_test -from tests.utils.factories import MaterialRunFactory, MaterialRunDataFactory, LinkByUIDFactory, \ - MaterialTemplateFactory, MaterialSpecDataFactory, ProcessTemplateFactory -from tests.utils.session import FakeSession, FakeCall, make_fake_cursor_request_function, FakeRequestResponseApiError, \ - FakeRequestResponse +from tests.utils.factories import ( + LinkByUIDFactory, + 
MaterialRunDataFactory, + MaterialRunFactory, + MaterialSpecDataFactory, + MaterialTemplateFactory, + ProcessTemplateFactory, +) +from tests.utils.session import ( + FakeCall, + FakeRequestResponse, + FakeRequestResponseApiError, + FakeSession, + make_fake_cursor_request_function, +) @pytest.fixture @@ -35,29 +44,34 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> MaterialRunCollection: return MaterialRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session, - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000')) + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + ) + def test_deprecated_collection_construction(session): with pytest.deprecated_call(): - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) - mr = MaterialRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), + MaterialRunCollection( + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session, - project_id=UUID('6b608f78-e341-422c-8076-35adc8828545')) + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + ) + def test_invalid_collection_construction(): with pytest.raises(TypeError): - mr = MaterialRunCollection(dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + MaterialRunCollection( + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), session=session + ) def test_register_material_run(collection, session): # Given - session.set_response(MaterialRunDataFactory(name='Test MR 123')) + session.set_response(MaterialRunDataFactory(name="Test MR 123")) material_run = MaterialRunFactory() # When @@ -68,21 +82,33 @@ def test_register_material_run(collection, session): def test_register_all(collection, session): - runs = [MaterialRunFactory(name='1'), MaterialRunFactory(name='2'), MaterialRunFactory(name='3')] - session.set_response({'objects': [r.dump() for r in runs]}) + runs = [ + MaterialRunFactory(name="1"), + MaterialRunFactory(name="2"), + MaterialRunFactory(name="3"), + ] + session.set_response({"objects": [r.dump() for r in runs]}) registered = collection.register_all(runs) assert [r.name for r in runs] == [r.name for r in registered] assert len(session.calls) == 1 - assert session.calls[0].method == 'PUT' - assert GEMDResourceCollection(team_id = collection.team_id, dataset_id = collection.dataset_id, session = collection.session)._get_path() \ - in session.calls[0].path + assert session.calls[0].method == "PUT" + assert ( + GEMDResourceCollection( + team_id=collection.team_id, + dataset_id=collection.dataset_id, + session=collection.session, + )._get_path() + in session.calls[0].path + ) with pytest.raises(RuntimeError): - MaterialRunCollection(team_id=collection.team_id, dataset_id=None, session=session).register_all([]) + MaterialRunCollection( + team_id=collection.team_id, dataset_id=None, session=session + ).register_all([]) def test_dry_run_register_material_run(collection, session): # Given - session.set_response(MaterialRunDataFactory(name='Test MR 123')) + session.set_response(MaterialRunDataFactory(name="Test MR 123")) material_run = MaterialRunFactory() # When @@ -90,14 +116,16 @@ def test_dry_run_register_material_run(collection, session): # Then assert "" == str(registered) - assert session.last_call.params == 
{'dry_run': True} + assert session.last_call.params == {"dry_run": True} def test_nomutate_gemd(collection, session): """When registering a GEMD object, the object should not change (aside from auto ids)""" # Given - session.set_response(MaterialRunDataFactory(name='Test MR mutation')) - before, after = (GEMDRun(name='Main', uids={'nomutate': 'please'}) for _ in range(2)) + session.set_response(MaterialRunDataFactory(name="Test MR mutation")) + before, after = ( + GEMDRun(name="Main", uids={"nomutate": "please"}) for _ in range(2) + ) # When registered = collection.register(after) @@ -113,10 +141,14 @@ def test_get_history(collection, session): # Given cake = make_cake() cake_json = json.loads(GEMDJson(scope=CITRINE_SCOPE).dumps(cake)) - root_link = LinkByUID.build(cake_json.pop('object')) - root_obj = next(o for o in cake_json['context'] if root_link.id == o['uids'].get(root_link.scope)) - cake_json['roots'] = [root_obj] - cake_json['context'].remove(root_obj) + root_link = LinkByUID.build(cake_json.pop("object")) + root_obj = next( + o + for o in cake_json["context"] + if root_link.id == o["uids"].get(root_link.scope) + ) + cake_json["roots"] = [root_obj] + cake_json["context"].remove(root_obj) session.set_response([cake_json]) @@ -126,17 +158,19 @@ def test_get_history(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true', + method="POST", + path=f"teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true", json={ - 'criteria': [ + "criteria": [ { - 'datasets': [str(collection.dataset_id)], - 'type': 'terminal_material_run_identifiers_criteria', - 'terminal_material_ids': [{'scope': root_link.scope, 'id': root_link.id}] + "datasets": [str(collection.dataset_id)], + "type": "terminal_material_run_identifiers_criteria", + "terminal_material_ids": [ + {"scope": root_link.scope, "id": root_link.id} + ], } ] - } + }, ) assert expected_call == session.last_call assert run == cake @@ -154,17 +188,19 @@ def test_get_history_no_histories(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true', + method="POST", + path=f"teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true", json={ - 'criteria': [ + "criteria": [ { - 'datasets': [str(collection.dataset_id)], - 'type': 'terminal_material_run_identifiers_criteria', - 'terminal_material_ids': [{'scope': CITRINE_SCOPE, 'id': str(root_id)}] + "datasets": [str(collection.dataset_id)], + "type": "terminal_material_run_identifiers_criteria", + "terminal_material_ids": [ + {"scope": CITRINE_SCOPE, "id": str(root_id)} + ], } ] - } + }, ) assert expected_call == session.last_call assert run is None @@ -174,10 +210,14 @@ def test_get_history_no_roots(collection, session): # Given cake = make_cake() cake_json = json.loads(GEMDJson(scope=CITRINE_SCOPE).dumps(cake)) - root_link = LinkByUID.build(cake_json.pop('object')) - root_obj = next(o for o in cake_json['context'] if root_link.id == o['uids'].get(root_link.scope)) - cake_json['roots'] = [] - cake_json['context'].remove(root_obj) + root_link = LinkByUID.build(cake_json.pop("object")) + root_obj = next( + o + for o in cake_json["context"] + if root_link.id == o["uids"].get(root_link.scope) + ) + cake_json["roots"] = [] + cake_json["context"].remove(root_obj) 
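+    # Editor's note: with "roots" emptied and the root object dropped from
+    # "context", the mocked payload models a history query that matched no
+    # terminal material, so the lookup below is expected to come back None.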
session.set_response([cake_json]) @@ -187,17 +227,19 @@ def test_get_history_no_roots(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true', + method="POST", + path=f"teams/{collection.team_id}/gemd/query/material-histories?filter_nonroot_materials=true", json={ - 'criteria': [ + "criteria": [ { - 'datasets': [str(collection.dataset_id)], - 'type': 'terminal_material_run_identifiers_criteria', - 'terminal_material_ids': [{'scope': root_link.scope, 'id': root_link.id}] + "datasets": [str(collection.dataset_id)], + "type": "terminal_material_run_identifiers_criteria", + "terminal_material_ids": [ + {"scope": root_link.scope, "id": root_link.id} + ], } ] - } + }, ) assert expected_call == session.last_call assert run is None @@ -205,8 +247,8 @@ def test_get_history_no_roots(collection, session): def test_get_material_run(collection, session): # Given - run_data = MaterialRunDataFactory(name='Cake 2') - mr_id = run_data['uids']['id'] + run_data = MaterialRunDataFactory(name="Cake 2") + mr_id = run_data["uids"]["id"] session.set_response(run_data) # When @@ -215,18 +257,19 @@ def test_get_material_run(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='teams/{}/datasets/{}/material-runs/id/{}'.format(collection.team_id, collection.dataset_id, mr_id) + method="GET", + path="teams/{}/datasets/{}/material-runs/id/{}".format( + collection.team_id, collection.dataset_id, mr_id + ), ) assert expected_call == session.last_call - assert 'Cake 2' == run.name + assert "Cake 2" == run.name + def test_list_material_runs(collection, session): # Given sample_run = MaterialRunDataFactory() - session.set_response({ - 'contents': [sample_run] - }) + session.set_response({"contents": [sample_run]}) # When runs = list(collection.list()) @@ -235,18 +278,20 @@ def test_list_material_runs(collection, session): assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='teams/{}/material-runs'.format(collection.team_id, collection.dataset_id), + method="GET", + path="teams/{}/material-runs".format( + collection.team_id, + ), params={ - 'dataset_id': str(collection.dataset_id), - 'forward': True, - 'ascending': True, - 'per_page': 100 - } + "dataset_id": str(collection.dataset_id), + "forward": True, + "ascending": True, + "per_page": 100, + }, ) assert expected_call == session.last_call assert 1 == len(runs) - assert sample_run['uids'] == runs[0].uids + assert sample_run["uids"] == runs[0].uids def test_cursor_paginated_searches(collection, session): @@ -254,39 +299,45 @@ def test_cursor_paginated_searches(collection, session): Tests that search methods using cursor-pagination are hooked up correctly. There is no real search logic tested here. 
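+
+    The fake request function below serves all_runs in cursor-sized pages,
+    so each list_* call should page through all 20 runs regardless of the
+    per_page value.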
""" - all_runs = [ - MaterialRunDataFactory(name="foo_{}".format(i)) for i in range(20) - ] + all_runs = [MaterialRunDataFactory(name="foo_{}".format(i)) for i in range(20)] fake_request = make_fake_cursor_request_function(all_runs) # pretty shady, need to add these methods to the fake session to test their # interactions with the actual search methods - setattr(session, 'get_resource', fake_request) - setattr(session, 'post_resource', fake_request) - setattr(session, 'cursor_paged_resource', Session.cursor_paged_resource) + setattr(session, "get_resource", fake_request) + setattr(session, "post_resource", fake_request) + setattr(session, "cursor_paged_resource", Session.cursor_paged_resource) - assert len(list(collection.list_by_name('unused', per_page=2))) == len(all_runs) + assert len(list(collection.list_by_name("unused", per_page=2))) == len(all_runs) assert len(list(collection.list(per_page=2))) == len(all_runs) - assert len(list(collection.list_by_tag('unused', per_page=2))) == len(all_runs) - assert len(list(collection.list_by_attribute_bounds( - {LinkByUIDFactory(): IntegerBounds(1, 5)}, per_page=2))) == len(all_runs) + assert len(list(collection.list_by_tag("unused", per_page=2))) == len(all_runs) + assert len( + list( + collection.list_by_attribute_bounds( + {LinkByUIDFactory(): IntegerBounds(1, 5)}, per_page=2 + ) + ) + ) == len(all_runs) # invalid inputs with pytest.raises(TypeError): collection.list_by_attribute_bounds([1, 5], per_page=2) with pytest.raises(NotImplementedError): - collection.list_by_attribute_bounds({ - LinkByUIDFactory(): IntegerBounds(1, 5), - LinkByUIDFactory(): IntegerBounds(1, 5), - }, per_page=2) + collection.list_by_attribute_bounds( + { + LinkByUIDFactory(): IntegerBounds(1, 5), + LinkByUIDFactory(): IntegerBounds(1, 5), + }, + per_page=2, + ) with pytest.raises(RuntimeError): collection.dataset_id = None - collection.list_by_name('unused', per_page=2) + collection.list_by_name("unused", per_page=2) def test_delete_material_run(collection, session): # Given - material_run_uid = '2d3a782f-aee7-41db-853c-36bf4bff0626' - material_run_scope = 'id' + material_run_uid = "2d3a782f-aee7-41db-853c-36bf4bff0626" + material_run_scope = "id" # When collection.delete(material_run_uid) @@ -294,22 +345,22 @@ def test_delete_material_run(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='DELETE', - path='teams/{}/datasets/{}/material-runs/{}/{}'.format( + method="DELETE", + path="teams/{}/datasets/{}/material-runs/{}/{}".format( collection.team_id, collection.dataset_id, material_run_scope, - material_run_uid + material_run_uid, ), - params={'dry_run': False} + params={"dry_run": False}, ) assert expected_call == session.last_call def test_dry_run_delete_material_run(collection, session): # Given - material_run_uid = '2d3a782f-aee7-41db-853c-36bf4bff0626' - material_run_scope = 'id' + material_run_uid = "2d3a782f-aee7-41db-853c-36bf4bff0626" + material_run_scope = "id" # When collection.delete(material_run_uid, dry_run=True) @@ -317,14 +368,14 @@ def test_dry_run_delete_material_run(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='DELETE', - path='teams/{}/datasets/{}/material-runs/{}/{}'.format( + method="DELETE", + path="teams/{}/datasets/{}/material-runs/{}/{}".format( collection.team_id, collection.dataset_id, material_run_scope, - material_run_uid + material_run_uid, ), - params={'dry_run': True} + params={"dry_run": True}, ) assert expected_call == session.last_call @@ 
-342,8 +393,8 @@ def test_material_run_can_get_with_no_id(collection, session): # Given collection.dataset_id = None - run_data = MaterialRunDataFactory(name='Cake 2') - mr_id = run_data['uids']['id'] + run_data = MaterialRunDataFactory(name="Cake 2") + mr_id = run_data["uids"]["id"] session.set_response(run_data) # When @@ -352,17 +403,17 @@ def test_material_run_can_get_with_no_id(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='teams/{}/material-runs/id/{}'.format(collection.team_id, mr_id) + method="GET", + path="teams/{}/material-runs/id/{}".format(collection.team_id, mr_id), ) assert expected_call == session.last_call - assert 'Cake 2' == run.name + assert "Cake 2" == run.name def test_get_by_process(collection): run_noop_gemd_relation_search_test( - search_for='material-runs', - search_with='process-runs', + search_for="material-runs", + search_with="process-runs", collection=collection, search_fn=collection.get_by_process, per_page=1, @@ -371,8 +422,8 @@ def test_get_by_process(collection): def test_list_by_spec(collection): run_noop_gemd_relation_search_test( - search_for='material-runs', - search_with='material-specs', + search_for="material-runs", + search_with="material-specs", collection=collection, search_fn=collection.list_by_spec, ) @@ -397,7 +448,8 @@ def test_validate_templates_successful_minimal_params(collection, session): expected_call = FakeCall( method="PUT", path="teams/{}/material-runs/validate-templates".format(team_id), - json={"dataObject": scrub_none(run.dump())}) + json={"dataObject": scrub_none(run.dump())}, + ) assert session.last_call == expected_call assert errors == [] @@ -416,16 +468,23 @@ def test_validate_templates_successful_all_params(collection, session): # When session.set_response("") - errors = collection.validate_templates(model=run, object_template=template, ingredient_process_template=unused_process_template) + errors = collection.validate_templates( + model=run, + object_template=template, + ingredient_process_template=unused_process_template, + ) # Then assert 1 == session.num_calls expected_call = FakeCall( method="PUT", path="teams/{}/material-runs/validate-templates".format(team_id), - json={"dataObject": scrub_none(run.dump()), - "objectTemplate": scrub_none(template.dump()), - "ingredientProcessTemplate": scrub_none(unused_process_template.dump())}) + json={ + "dataObject": scrub_none(run.dump()), + "objectTemplate": scrub_none(template.dump()), + "ingredientProcessTemplate": scrub_none(unused_process_template.dump()), + }, + ) assert session.last_call == expected_call assert errors == [] @@ -439,8 +498,14 @@ def test_validate_templates_errors(collection, session): run = MaterialRunFactory(name="") # When - validation_error = ValidationError.build({"failure_message": "you failed", "failure_id": "failure_id"}) - session.set_response(BadRequest("path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]))) + validation_error = ValidationError.build( + {"failure_message": "you failed", "failure_id": "failure_id"} + ) + session.set_response( + BadRequest( + "path", FakeRequestResponseApiError(400, "Bad Request", [validation_error]) + ) + ) errors = collection.validate_templates(model=run) # Then @@ -448,7 +513,8 @@ def test_validate_templates_errors(collection, session): expected_call = FakeCall( method="PUT", path="teams/{}/material-runs/validate-templates".format(team_id), - json={"dataObject": scrub_none(run.dump())}) + json={"dataObject": scrub_none(run.dump())}, 
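+        # The json payload must mirror the PUT body exactly; FakeCall equality
+        # appears to compare method, path, params, and json.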
+ ) assert session.last_call == expected_call assert len(errors) == 1 assert errors[0].dump() == validation_error.dump() @@ -475,7 +541,11 @@ def test_validate_templates_unrelated_400_with_api_error(collection, session): run = MaterialRunFactory() # When - session.set_response(BadRequest("path", FakeRequestResponseApiError(400, "I am not a validation error", []))) + session.set_response( + BadRequest( + "path", FakeRequestResponseApiError(400, "I am not a validation error", []) + ) + ) with pytest.raises(BadRequest): collection.validate_templates(model=run) @@ -486,40 +556,41 @@ def test_list_by_template(collection, session): """ # Given material_template = MaterialTemplateFactory() - test_scope = 'id' + test_scope = "id" template_id = material_template.uids[test_scope] sample_spec1 = MaterialSpecDataFactory(template=material_template) sample_spec2 = MaterialSpecDataFactory(template=material_template) - key = 'contents' + key = "contents" sample_run1_1 = MaterialRunDataFactory(spec=sample_spec1) sample_run2_1 = MaterialRunDataFactory(spec=sample_spec2) sample_run1_2 = MaterialRunDataFactory(spec=sample_spec1) sample_run2_2 = MaterialRunDataFactory(spec=sample_spec2) - session.set_responses({key: [sample_spec1, sample_spec2]}, {key: [sample_run1_1, sample_run1_2]}, - {key: [sample_run2_1, sample_run2_2]}) + session.set_responses( + {key: [sample_spec1, sample_spec2]}, + {key: [sample_run1_1, sample_run1_2]}, + {key: [sample_run2_1, sample_run2_2]}, + ) # When runs = [run for run in collection.list_by_template(template_id)] # Then assert 3 == session.num_calls - assert runs == [collection.build(run) for run in [sample_run1_1, sample_run1_2, sample_run2_1, sample_run2_2]] + assert runs == [ + collection.build(run) + for run in [sample_run1_1, sample_run1_2, sample_run2_1, sample_run2_2] + ] def test_equals(): """Test basic equality. 
Complex relationships are tested in test_material_run.test_deep_equals()."""
-    from citrine.resources.material_run import MaterialRun as CitrineMaterialRun
     from gemd.entity.object import MaterialRun as GEMDMaterialRun
 
-    gemd_obj = GEMDMaterialRun(
-        name="My Name",
-        notes="I have notes",
-        tags=["tag!"]
-    )
+    from citrine.resources.material_run import MaterialRun as CitrineMaterialRun
+
+    gemd_obj = GEMDMaterialRun(name="My Name", notes="I have notes", tags=["tag!"])
     citrine_obj = CitrineMaterialRun(
-        name="My Name",
-        notes="I have notes",
-        tags=["tag!"]
+        name="My Name", notes="I have notes", tags=["tag!"]
     )
     assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence"
     citrine_obj.notes = "Something else"
@@ -527,7 +598,7 @@
 
 
 def test_deep_equals(collection):
-    change_scope('test_deep_equals_scope')
+    change_scope("test_deep_equals_scope")
     cake = make_cake()
     flat_list = flatten(cake)
     # Note that registered turns them into a flat list of Citrine resources
@@ -545,7 +616,7 @@ def test_deep_equals(collection):
 
 
 def test_nonmutating_dry_run(collection):
-    change_scope('test_deep_equals_scope')
+    change_scope("test_deep_equals_scope")
     cake = make_cake()
 
     uid_stash = cake.uids.copy()
@@ -565,9 +636,9 @@ def test_nonmutating_dry_run(collection):
 
 
 def test_args_only(collection):
-    """"Test that only arguments to register_all get registered/tested/returned."""
+    """Test that only arguments to register_all get registered/tested/returned."""
    obj = GEMDRun("name", spec=GEMDSpec("name"))
-    GEMDJson(scope='test_args_only').dumps(obj)  # no-op to populate ids
+    GEMDJson(scope="test_args_only").dumps(obj)  # no-op to populate ids
     dry = collection.register_all([obj], dry_run=True)
     assert obj in dry
     assert obj.spec not in dry
diff --git a/tests/resources/test_material_spec.py b/tests/resources/test_material_spec.py
index e0609eeba..c3089e046 100644
--- a/tests/resources/test_material_spec.py
+++ b/tests/resources/test_material_spec.py
@@ -2,7 +2,10 @@
 
 import pytest
 
-from citrine.resources.material_spec import MaterialSpec as CitrineMaterialSpec, MaterialSpecCollection
+from citrine.resources.material_spec import (
+    MaterialSpec as CitrineMaterialSpec,
+    MaterialSpecCollection,
+)
 from tests.resources.test_data_concepts import run_noop_gemd_relation_search_test
 from tests.utils.factories import MaterialSpecDataFactory
 from tests.utils.session import FakeCall, FakeSession
@@ -18,28 +21,32 @@ def session() -> FakeSession:
 @pytest.fixture
 def collection(session) -> MaterialSpecCollection:
     return MaterialSpecCollection(
-        dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'),
-        team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'),
-        session=session)
+        dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"),
+        team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"),
+        session=session,
+    )
 
 
 def test_create_deprecated_collection(session):
-    project_id = '6b608f78-e341-422c-8076-35adc8828545'
-    session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}})
+    project_id = "6b608f78-e341-422c-8076-35adc8828545"
+    session.set_response(
+        {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}}
+    )
 
     with pytest.deprecated_call():
         MaterialSpecCollection(
             project_id=UUID(project_id),
-            dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'),
-            session=session)
+            dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"),
+            session=session,
+        )
 
-    assert session.calls == 
[FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_template(collection): run_noop_gemd_relation_search_test( - search_for='material-specs', - search_with='material-templates', + search_for="material-specs", + search_with="material-templates", collection=collection, search_fn=collection.list_by_template, ) @@ -47,8 +54,8 @@ def test_list_by_template(collection): def test_get_by_process(collection): run_noop_gemd_relation_search_test( - search_for='material-specs', - search_with='process-specs', + search_for="material-specs", + search_with="process-specs", collection=collection, search_fn=collection.get_by_process, per_page=1, @@ -62,30 +69,31 @@ def test_repeat_serialization_gemd(collection, session): """ from gemd.entity.object.material_spec import MaterialSpec as GEMDMaterial from gemd.entity.object.process_spec import ProcessSpec as GEMDProcess + # Given - session.set_response(MaterialSpecDataFactory(name='Test gemd mutation')) - proc = GEMDProcess(name='Test gemd mutation (process)', uids={'nomutate': 'process'}) - mat = GEMDMaterial(name='Test gemd mutation', uids={'nomutate': 'material'}, process=proc) + session.set_response(MaterialSpecDataFactory(name="Test gemd mutation")) + proc = GEMDProcess( + name="Test gemd mutation (process)", uids={"nomutate": "process"} + ) + mat = GEMDMaterial( + name="Test gemd mutation", uids={"nomutate": "material"}, process=proc + ) # When collection.register(proc) - session.set_response(MaterialSpecDataFactory(name='Test gemd mutation')) - registered = collection.register(mat) # This will serialize the linked process as a side effect + session.set_response(MaterialSpecDataFactory(name="Test gemd mutation")) + registered = collection.register( + mat + ) # This will serialize the linked process as a side effect # Then assert "" == str(registered) def test_equals(): - gemd_obj = GEMDMaterialSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) + gemd_obj = GEMDMaterialSpec(name="My Name", notes="I have notes", tags=["tag!"]) citrine_obj = CitrineMaterialSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] + name="My Name", notes="I have notes", tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_measurement_run.py b/tests/resources/test_measurement_run.py index 9df202700..8b0ca41f1 100644 --- a/tests/resources/test_measurement_run.py +++ b/tests/resources/test_measurement_run.py @@ -15,28 +15,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> MeasurementRunCollection: return MeasurementRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + session=session, + ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): MeasurementRunCollection( project_id=UUID(project_id), - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + session=session, + ) - assert session.calls 
== [FakeCall(method="GET", path=f'projects/{project_id}')] + assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_template(collection: MeasurementRunCollection): run_noop_gemd_relation_search_test( - search_for='measurement-runs', - search_with='measurement-specs', + search_for="measurement-runs", + search_with="measurement-specs", collection=collection, search_fn=collection.list_by_spec, ) @@ -44,8 +48,8 @@ def test_list_by_template(collection: MeasurementRunCollection): def test_list_by_material(collection: MeasurementRunCollection): run_noop_gemd_relation_search_test( - search_for='measurement-runs', - search_with='material-runs', + search_for="measurement-runs", + search_with="material-runs", collection=collection, search_fn=collection.list_by_material, ) @@ -53,18 +57,14 @@ def test_list_by_material(collection: MeasurementRunCollection): def test_equals(): """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals().""" - from citrine.resources.measurement_run import MeasurementRun as CitrineMeasurementRun + from citrine.resources.measurement_run import ( + MeasurementRun as CitrineMeasurementRun, + ) from gemd.entity.object import MeasurementRun as GEMDMeasurementRun - gemd_obj = GEMDMeasurementRun( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) + gemd_obj = GEMDMeasurementRun(name="My Name", notes="I have notes", tags=["tag!"]) citrine_obj = CitrineMeasurementRun( - name="My Name", - notes="I have notes", - tags=["tag!"] + name="My Name", notes="I have notes", tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_measurement_spec.py b/tests/resources/test_measurement_spec.py index 716da9ee6..a6fe2414b 100644 --- a/tests/resources/test_measurement_spec.py +++ b/tests/resources/test_measurement_spec.py @@ -1,10 +1,13 @@ from uuid import UUID import pytest - + from gemd.entity.object import MeasurementSpec as GEMDMeasurementSpec -from citrine.resources.measurement_spec import MeasurementSpec as CitrineMeasurementSpec, MeasurementSpecCollection +from citrine.resources.measurement_spec import ( + MeasurementSpec as CitrineMeasurementSpec, + MeasurementSpecCollection, +) from tests.resources.test_data_concepts import run_noop_gemd_relation_search_test from tests.utils.session import FakeCall, FakeSession @@ -17,28 +20,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> MeasurementSpecCollection: return MeasurementSpecCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + session=session, + ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): MeasurementSpecCollection( project_id=UUID(project_id), - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + session=session, + ) - assert session.calls == [FakeCall(method="GET", path=f'projects/{project_id}')] + 
assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_template(collection: MeasurementSpecCollection): run_noop_gemd_relation_search_test( - search_for='measurement-specs', - search_with='measurement-templates', + search_for="measurement-specs", + search_with="measurement-templates", collection=collection, search_fn=collection.list_by_template, ) @@ -46,15 +53,9 @@ def test_list_by_template(collection: MeasurementSpecCollection): def test_equals(): """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals().""" - gemd_obj = GEMDMeasurementSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) + gemd_obj = GEMDMeasurementSpec(name="My Name", notes="I have notes", tags=["tag!"]) citrine_obj = CitrineMeasurementSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] + name="My Name", notes="I have notes", tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_object_setters.py b/tests/resources/test_object_setters.py index 52e10c3cb..1cc3224d2 100644 --- a/tests/resources/test_object_setters.py +++ b/tests/resources/test_object_setters.py @@ -9,7 +9,6 @@ from citrine.resources.material_run import MaterialRun from citrine.resources.material_spec import MaterialSpec from citrine.resources.measurement_run import MeasurementRun -from citrine.resources.ingredient_run import IngredientRun from citrine.resources.ingredient_spec import IngredientSpec @@ -23,8 +22,11 @@ def test_soft_process_material_attachment(): def test_soft_measurement_material_attachment(): """Test that soft attachments are formed from materials to measurements.""" cake = MaterialRun("A cake") - smell_test = MeasurementRun("use your nose", material=cake, properties=[ - Property(name="Smell", value=DiscreteCategorical("yummy"))]) + smell_test = MeasurementRun( + "use your nose", + material=cake, + properties=[Property(name="Smell", value=DiscreteCategorical("yummy"))], + ) taste_test = MeasurementRun("taste", material=cake) assert cake.measurements == [smell_test, taste_test] @@ -34,11 +36,14 @@ def test_soft_process_ingredient_attachment(): vinegar = MaterialSpec("vinegar") baking_soda = MaterialSpec("baking soda") eruption = ProcessSpec("Volcano eruption") - vinegar_sample = IngredientSpec("a bit of vinegar", material=vinegar, process=eruption) + vinegar_sample = IngredientSpec( + "a bit of vinegar", material=vinegar, process=eruption + ) baking_soda_sample = IngredientSpec("a bit of NaOh", material=baking_soda) baking_soda_sample.process = eruption - assert set(eruption.ingredients) == {vinegar_sample, baking_soda_sample}, \ + assert set(eruption.ingredients) == {vinegar_sample, baking_soda_sample}, ( "Creating an ingredient for a process did not auto-populate that process's ingredient list" + ) def test_object_pointer_serde(): diff --git a/tests/resources/test_predictor.py b/tests/resources/test_predictor.py index efce269ec..6e7ef5f54 100644 --- a/tests/resources/test_predictor.py +++ b/tests/resources/test_predictor.py @@ -1,41 +1,44 @@ """Tests predictor collection""" -import mock -import pytest + import uuid from copy import deepcopy -from citrine.exceptions import BadRequest, Conflict, ModuleRegistrationFailedException, NotFound +import mock +import pytest + +from citrine.exceptions import ModuleRegistrationFailedException, NotFound from citrine.informatics.data_sources import GemTableDataSource from citrine.informatics.descriptors 
import RealDescriptor from citrine.informatics.predictors import ( AutoMLPredictor, ExpressionPredictor, GraphPredictor, - SimpleMixturePredictor + SimpleMixturePredictor, ) -from citrine.resources.predictor import PredictorCollection, _PredictorVersionCollection, AutoConfigureMode -from tests.conftest import build_predictor_entity -from tests.utils.session import ( - FakeCall, - FakeRequestResponse, - FakeSession +from citrine.resources.predictor import ( + AutoConfigureMode, + PredictorCollection, + _PredictorVersionCollection, ) +from tests.conftest import build_predictor_entity from tests.utils.factories import ( - AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory, - TableDataSourceDataFactory + AsyncDefaultPredictorResponseFactory, + AsyncDefaultPredictorResponseMetadataFactory, + TableDataSourceDataFactory, ) +from tests.utils.session import FakeCall, FakeRequestResponse, FakeSession def paging_response(*items): return {"response": items} -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def basic_predictor_report_data(): return { - 'id': str(uuid.uuid4()), - 'status': 'VALID', - 'report': {'descriptors': [], 'models': []} + "id": str(uuid.uuid4()), + "status": "VALID", + "report": {"descriptors": [], "models": []}, } @@ -44,22 +47,30 @@ def test_build(valid_graph_predictor_data, basic_predictor_report_data): session.set_response(basic_predictor_report_data) pc = PredictorCollection(uuid.uuid4(), session) predictor = pc.build(valid_graph_predictor_data) - assert predictor.name == 'Graph predictor' - assert predictor.description == 'description' + assert predictor.name == "Graph predictor" + assert predictor.description == "description" def test_build_with_status(valid_graph_predictor_data, basic_predictor_report_data): session = FakeSession() session.set_response(basic_predictor_report_data) - status_detail_data = {("Info", "info_msg"), ("Warning", "warning msg"), ("Error", "error msg")} + status_detail_data = { + ("Info", "info_msg"), + ("Warning", "warning msg"), + ("Error", "error msg"), + } data = deepcopy(valid_graph_predictor_data) - data["metadata"]["status"]["detail"] = [{"level": level, "msg": msg} for level, msg in status_detail_data] + data["metadata"]["status"]["detail"] = [ + {"level": level, "msg": msg} for level, msg in status_detail_data + ] pc = PredictorCollection(uuid.uuid4(), session) predictor = pc.build(data) - status_detail_tuples = {(detail.level, detail.msg) for detail in predictor.status_detail} + status_detail_tuples = { + (detail.level, detail.msg) for detail in predictor.status_detail + } assert status_detail_tuples == status_detail_data @@ -78,27 +89,35 @@ def test_delete_version(): def test_archive_root(valid_graph_predictor_data): session = FakeSession() pc = PredictorCollection(uuid.uuid4(), session) - predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) + predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) pred_id = valid_graph_predictor_data["id"] session.set_response(None) pc.archive_root(pred_id) - assert session.calls == [FakeCall(method='PUT', path=f"{predictors_path}/{pred_id}/archive", json={})] + assert session.calls == [ + FakeCall(method="PUT", path=f"{predictors_path}/{pred_id}/archive", json={}) + ] def test_restore_root(valid_graph_predictor_data): session = FakeSession() pc = PredictorCollection(uuid.uuid4(), session) - predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) + 
predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) pred_id = valid_graph_predictor_data["id"] session.set_response(None) pc.restore_root(pred_id) - assert session.calls == [FakeCall(method='PUT', path=f"{predictors_path}/{pred_id}/restore", json={})] + assert session.calls == [ + FakeCall(method="PUT", path=f"{predictors_path}/{pred_id}/restore", json={}) + ] def test_root_is_archived(valid_graph_predictor_data): @@ -123,8 +142,8 @@ def test_graph_build(valid_graph_predictor_data, basic_predictor_report_data): session.get_resource.return_value = basic_predictor_report_data pc = PredictorCollection(uuid.uuid4(), session) predictor = pc.build(valid_graph_predictor_data) - assert predictor.name == 'Graph predictor' - assert predictor.description == 'description' + assert predictor.name == "Graph predictor" + assert predictor.description == "description" assert len(predictor.predictors) == 5 assert len(predictor.training_data) == 1 @@ -140,7 +159,12 @@ def test_register(valid_graph_predictor_data): predictors_path = f"/projects/{pc.project_id}/predictors" expected_calls = [ FakeCall(method="POST", path=predictors_path, json=predictor.dump()), - FakeCall(method="PUT", path=f"{predictors_path}/{entity['id']}/train", params={"create_version": True}, json={}), + FakeCall( + method="PUT", + path=f"{predictors_path}/{entity['id']}/train", + params={"create_version": True}, + json={}, + ), ] pc.register(predictor) @@ -174,20 +198,21 @@ def test_graph_register(valid_graph_predictor_data): pc = PredictorCollection(uuid.uuid4(), session) predictor = GraphPredictor.build(valid_graph_predictor_data) registered = pc.register(predictor) - - assert registered.name == 'Graph predictor' + + assert registered.name == "Graph predictor" def test_failed_register(valid_graph_predictor_data): session = mock.Mock() - session.post_resource.side_effect = NotFound("/projects/uuid/not_found", - FakeRequestResponse(400)) + session.post_resource.side_effect = NotFound( + "/projects/uuid/not_found", FakeRequestResponse(400) + ) pc = PredictorCollection(uuid.uuid4(), session) predictor = GraphPredictor.build(valid_graph_predictor_data) with pytest.raises(ModuleRegistrationFailedException) as e: pc.register(predictor) assert 'The "GraphPredictor" failed to register.' 
in str(e.value) - assert '/projects/uuid/not_found' in str(e.value) + assert "/projects/uuid/not_found" in str(e.value) def test_update(valid_graph_predictor_data): @@ -198,11 +223,18 @@ def test_update(valid_graph_predictor_data): predictor = pc.build(entity) - predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) + predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) entity_path = f"{predictors_path}/{entity['id']}" expected_calls = [ FakeCall(method="PUT", path=entity_path, json=predictor.dump()), - FakeCall(method="PUT", path=f"{entity_path}/train", params={"create_version": True}, json={}), + FakeCall( + method="PUT", + path=f"{entity_path}/train", + params={"create_version": True}, + json={}, + ), ] pc.update(predictor) @@ -218,7 +250,9 @@ def test_update_no_train(valid_graph_predictor_data): predictor = pc.build(entity) - predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) + predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) entity_path = f"{predictors_path}/{entity['id']}" expected_calls = [ FakeCall(method="PUT", path=entity_path, json=predictor.dump()), @@ -241,7 +275,7 @@ def test_register_update_checks_status(valid_graph_predictor_data): invalid_entity = build_predictor_entity( instance, status_name="INVALID", - status_detail=[{"level": "Error", "msg": "AHH IT BURNSSSSS!!!!"}] + status_detail=[{"level": "Error", "msg": "AHH IT BURNSSSSS!!!!"}], ) # Register returns first (invalid) response if failed @@ -267,10 +301,17 @@ def test_train(valid_graph_predictor_data): predictor = pc.build(entity) - predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) + predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) entity_path = f"{predictors_path}/{entity['id']}" expected_calls = [ - FakeCall(method="PUT", path=f"{entity_path}/train", params={"create_version": True}, json={}), + FakeCall( + method="PUT", + path=f"{entity_path}/train", + params={"create_version": True}, + json={}, + ), ] pc.train(predictor.uid) @@ -284,22 +325,24 @@ def test_list(valid_graph_predictor_data, valid_graph_predictor_data_empty): collection = PredictorCollection(uuid.uuid4(), session) session.set_responses( { - 'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty], - 'page': 1, - 'per_page': 25 + "response": [valid_graph_predictor_data, valid_graph_predictor_data_empty], + "page": 1, + "per_page": 25, }, basic_predictor_report_data, - basic_predictor_report_data + basic_predictor_report_data, ) # When predictors = list(collection.list(per_page=25)) # Then - expected_call = FakeCall(method='GET', - path='/projects/{}/predictors'.format(collection.project_id), - params={'per_page': 25, 'page': 1, 'archived': False}, - version="v4") + expected_call = FakeCall( + method="GET", + path="/projects/{}/predictors".format(collection.project_id), + params={"per_page": 25, "page": 1, "archived": False}, + version="v4", + ) assert 1 == session.num_calls, session.calls assert expected_call == session.calls[0] assert len(predictors) == 2 @@ -310,19 +353,21 @@ def test_list_all(valid_graph_predictor_data, valid_graph_predictor_data_empty): session = FakeSession() collection = PredictorCollection(uuid.uuid4(), session) session.set_responses( - {'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty]}, + {"response": [valid_graph_predictor_data, 
valid_graph_predictor_data_empty]}, + basic_predictor_report_data, basic_predictor_report_data, - basic_predictor_report_data ) # When predictors = list(collection.list_all(per_page=25)) # Then - expected_call = FakeCall(method='GET', - path='/projects/{}/predictors'.format(collection.project_id), - params={'per_page': 25, 'page': 1}, - version="v4") + expected_call = FakeCall( + method="GET", + path="/projects/{}/predictors".format(collection.project_id), + params={"per_page": 25, "page": 1}, + version="v4", + ) assert 1 == session.num_calls, session.calls assert expected_call == session.calls[0] assert len(predictors) == 2 @@ -331,7 +376,7 @@ def test_list_all(valid_graph_predictor_data, valid_graph_predictor_data_empty): def test_list_archived(valid_graph_predictor_data): # Given session = FakeSession() - session.set_response({'response': [valid_graph_predictor_data]}) + session.set_response({"response": [valid_graph_predictor_data]}) pc = PredictorCollection(uuid.uuid4(), session) # When @@ -339,10 +384,12 @@ def test_list_archived(valid_graph_predictor_data): # Then assert session.num_calls == 1 - assert session.last_call == FakeCall(method='GET', - path=f"/projects/{pc.project_id}/predictors", - params={'per_page': 20, 'page': 1, 'archived': True}, - version="v4") + assert session.last_call == FakeCall( + method="GET", + path=f"/projects/{pc.project_id}/predictors", + params={"per_page": 20, "page": 1, "archived": True}, + version="v4", + ) def test_get(valid_graph_predictor_data): @@ -359,9 +406,9 @@ def test_get(valid_graph_predictor_data): # Then expected_call = FakeCall( - method='GET', - path=f'/projects/{pc.project_id}/predictors/{id}/versions/{version}', - params={} + method="GET", + path=f"/projects/{pc.project_id}/predictors/{id}/versions/{version}", + params={}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -390,7 +437,12 @@ def test_check_update_none(): # then assert update_check is None - expected_call = FakeCall(method='GET', path='/projects/{}/predictors/{}/update-check'.format(pc.project_id, predictor_id)) + expected_call = FakeCall( + method="GET", + path="/projects/{}/predictors/{}/update-check".format( + pc.project_id, predictor_id + ), + ) assert session.calls[0] == expected_call @@ -399,21 +451,25 @@ def test_check_update_some(): # given session = FakeSession() desc = RealDescriptor("spam", lower_bound=0, upper_bound=1, units="kg") - response = GraphPredictor.wrap_instance({ - "type": "Graph", - "name": "foo", - "description": "bar", - "predictors": [ - { - "type": "AnalyticExpression", - "name": "foo", - "description": "bar", - "expression": "2 * x", - "output": RealDescriptor("spam", lower_bound=0, upper_bound=1, units="kg").dump(), - "aliases": {} - } - ] - }) + response = GraphPredictor.wrap_instance( + { + "type": "Graph", + "name": "foo", + "description": "bar", + "predictors": [ + { + "type": "AnalyticExpression", + "name": "foo", + "description": "bar", + "expression": "2 * x", + "output": RealDescriptor( + "spam", lower_bound=0, upper_bound=1, units="kg" + ).dump(), + "aliases": {}, + } + ], + } + ) session.set_responses({"updatable": True, **response}) pc = PredictorCollection(uuid.uuid4(), session) predictor_id = uuid.uuid4() @@ -422,13 +478,11 @@ def test_check_update_some(): update_check = pc.check_for_update(predictor_id) # then - assert pc._api_version == 'v3' - exp = ExpressionPredictor("foo", description="bar", expression="2 * x", output=desc, aliases={}) - expected = GraphPredictor( - name="foo", - 
description="bar", - predictors=[exp] + assert pc._api_version == "v3" + exp = ExpressionPredictor( + "foo", description="bar", expression="2 * x", output=desc, aliases={} ) + expected = GraphPredictor(name="foo", description="bar", predictors=[exp]) assert update_check.dump() == expected.dump() assert update_check.uid == predictor_id @@ -441,9 +495,15 @@ def test_unexpected_pattern(): # Then with pytest.raises(ValueError): - pc.create_default(training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), pattern="yogurt") + pc.create_default( + training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), + pattern="yogurt", + ) with pytest.raises(ValueError): - pc.create_default_async(training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), pattern="yogurt") + pc.create_default_async( + training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), + pattern="yogurt", + ) def test_create_default_mode_pattern(valid_graph_predictor_data): @@ -458,11 +518,14 @@ def test_create_default_mode_pattern(valid_graph_predictor_data): pc = PredictorCollection(uuid.uuid4(), session) # When - pc.create_default(training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), pattern=AutoConfigureMode.INFER) + pc.create_default( + training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), + pattern=AutoConfigureMode.INFER, + ) # Then - assert (session.calls[0].json['pattern'] == "INFER") - assert (session.calls[0].json['prefer_valid'] == True) + assert session.calls[0].json["pattern"] == "INFER" + assert session.calls[0].json["prefer_valid"] is True def test_returned_predictor(valid_graph_predictor_data): @@ -477,7 +540,10 @@ def test_returned_predictor(valid_graph_predictor_data): pc = PredictorCollection(uuid.uuid4(), session) # When - result = pc.create_default(training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), pattern="PLAIN") + result = pc.create_default( + training_data=GemTableDataSource(table_id=uuid.uuid4(), table_version=0), + pattern="PLAIN", + ) # Then the response is parsed in a predictor assert result.name == valid_graph_predictor_data["data"]["name"] @@ -500,7 +566,9 @@ def test_list_versions(valid_graph_predictor_data): predictor_v2 = deepcopy(valid_graph_predictor_data) predictor_v2["metadata"]["version"] = 2 - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) session.set_response(paging_response(predictor_v1, predictor_v2)) @@ -508,7 +576,9 @@ def test_list_versions(valid_graph_predictor_data): listed_predictors = list(pc.list_versions(pred_id, per_page=20)) # Then - assert session.calls == [FakeCall(method='GET', path=versions_path, params={'per_page': 20, 'page': 1})] + assert session.calls == [ + FakeCall(method="GET", path=versions_path, params={"per_page": 20, "page": 1}) + ] assert len(listed_predictors) == 2 @@ -524,7 +594,9 @@ def test_list_archived_versions(valid_graph_predictor_data): predictor_v2 = deepcopy(valid_graph_predictor_data) predictor_v2["metadata"]["version"] = 2 - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) session.set_response(paging_response(predictor_v1, predictor_v2)) @@ -532,8 +604,10 @@ def 
test_list_archived_versions(valid_graph_predictor_data): listed_predictors = list(pc.list_archived_versions(pred_id, per_page=20)) # Then - expected_params = {'per_page': 20, "filter": "archived eq 'true'", 'page': 1} - assert session.calls == [FakeCall(method='GET', path=versions_path, params=expected_params)] + expected_params = {"per_page": 20, "filter": "archived eq 'true'", "page": 1} + assert session.calls == [ + FakeCall(method="GET", path=versions_path, params=expected_params) + ] assert len(listed_predictors) == 2 @@ -543,13 +617,17 @@ def test_archive_version(valid_graph_predictor_data, version): pc = PredictorCollection(uuid.uuid4(), session) pred_id = valid_graph_predictor_data["id"] - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) session.set_response(valid_graph_predictor_data) pc.archive_version(pred_id, version=version) - assert session.calls == [FakeCall(method='PUT', path=f"{versions_path}/{version}/archive", json={})] + assert session.calls == [ + FakeCall(method="PUT", path=f"{versions_path}/{version}/archive", json={}) + ] @pytest.mark.parametrize("version", (2, "1", "latest", "most_recent")) @@ -558,13 +636,17 @@ def test_restore_version(valid_graph_predictor_data, version): pc = PredictorCollection(uuid.uuid4(), session) pred_id = valid_graph_predictor_data["id"] - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) session.set_response(valid_graph_predictor_data) pc.restore_version(pred_id, version=version) - assert session.calls == [FakeCall(method='PUT', path=f"{versions_path}/{version}/restore", json={})] + assert session.calls == [ + FakeCall(method="PUT", path=f"{versions_path}/{version}/restore", json={}) + ] @pytest.mark.parametrize("version", (-2, 0, "1.5", "draft")) @@ -597,14 +679,18 @@ def test_is_stale(valid_graph_predictor_data, is_stale): "id": pred_id, "version": pred_version, "status": "READY", - "is_stale": is_stale + "is_stale": is_stale, } session.set_response(response) resp = pc.is_stale(pred_id, version=pred_version) - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) - assert session.calls == [FakeCall(method='GET', path=f"{versions_path}/{pred_version}/is-stale")] + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) + assert session.calls == [ + FakeCall(method="GET", path=f"{versions_path}/{pred_version}/is-stale") + ] assert resp == is_stale @@ -621,8 +707,14 @@ def test_retrain_stale(valid_graph_predictor_data): pc.retrain_stale(pred_id, version=pred_version) - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) - assert session.calls == [FakeCall(method='PUT', path=f"{versions_path}/{pred_version}/retrain-stale", json={})] + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) + assert session.calls == [ + FakeCall( + method="PUT", path=f"{versions_path}/{pred_version}/retrain-stale", json={} + ) + ] def test_unsupported_archive(): @@ -638,24 +730,38 @@ def test_unsupported_restore(): def test_create_default_async(): session = FakeSession() pc = PredictorCollection(uuid.uuid4(), session) 
- predictors_path = PredictorCollection._path_template.format(project_id=pc.project_id) - + predictors_path = PredictorCollection._path_template.format( + project_id=pc.project_id + ) + mode = "PLAIN" prefer_valid = False ds = GemTableDataSource(table_id=uuid.uuid4(), table_version=1) - data_source_payload = TableDataSourceDataFactory(table_id=str(ds.table_id), table_version=ds.table_version) + data_source_payload = TableDataSourceDataFactory( + table_id=str(ds.table_id), table_version=ds.table_version + ) expected_payload = { "data_source": data_source_payload, "pattern": mode, - "prefer_valid": prefer_valid + "prefer_valid": prefer_valid, } - metadata = AsyncDefaultPredictorResponseMetadataFactory(data_source=data_source_payload) - session.set_response(AsyncDefaultPredictorResponseFactory(metadata=metadata, data=None)) + metadata = AsyncDefaultPredictorResponseMetadataFactory( + data_source=data_source_payload + ) + session.set_response( + AsyncDefaultPredictorResponseFactory(metadata=metadata, data=None) + ) pc.create_default_async(training_data=ds, pattern=mode, prefer_valid=prefer_valid) - assert session.calls == [FakeCall(method="POST", path=f"{predictors_path}/default-async", json=expected_payload)] + assert session.calls == [ + FakeCall( + method="POST", + path=f"{predictors_path}/default-async", + json=expected_payload, + ) + ] def test_get_default_async(valid_graph_predictor_data): @@ -693,9 +799,9 @@ def test_get_featurized_training_data(example_hierarchical_design_material): # Then expected_call = FakeCall( - method='GET', - path=f'/projects/{pc.project_id}/predictors/{id}/versions/{version}/featurized-training-data', - params={} + method="GET", + path=f"/projects/{pc.project_id}/predictors/{id}/versions/{version}/featurized-training-data", + params={}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -714,9 +820,17 @@ def test_rename(valid_graph_predictor_data): session.set_response(valid_graph_predictor_data) pc.rename(pred_id, version=pred_version, name=new_name, description=new_description) # Then - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) expected_payload = {"name": new_name, "description": new_description} - assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + assert session.calls == [ + FakeCall( + method="PUT", + path=f"{versions_path}/{pred_version}/rename", + json=expected_payload, + ) + ] def test_rename_name_only(valid_graph_predictor_data): @@ -733,9 +847,18 @@ def test_rename_name_only(valid_graph_predictor_data): pc.rename(pred_id, version=pred_version, name=new_name) # Then - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) expected_payload = {"name": new_name, "description": None} - assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + assert session.calls == [ + FakeCall( + method="PUT", + path=f"{versions_path}/{pred_version}/rename", + json=expected_payload, + ) + ] + def test_rename_description_only(valid_graph_predictor_data): pred_id = valid_graph_predictor_data["id"] @@ -751,6 +874,14 @@ def test_rename_description_only(valid_graph_predictor_data): 
pc.rename(pred_id, version=pred_version, description=new_description) # Then - versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + versions_path = _PredictorVersionCollection._path_template.format( + project_id=pc.project_id, uid=pred_id + ) expected_payload = {"name": None, "description": new_description} - assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + assert session.calls == [ + FakeCall( + method="PUT", + path=f"{versions_path}/{pred_version}/rename", + json=expected_payload, + ) + ] diff --git a/tests/resources/test_predictor_evaluation_executions.py b/tests/resources/test_predictor_evaluation_executions.py index 53c548282..9b55fe466 100644 --- a/tests/resources/test_predictor_evaluation_executions.py +++ b/tests/resources/test_predictor_evaluation_executions.py @@ -4,10 +4,14 @@ import pytest from citrine._rest.resource import PredictorRef -from citrine.informatics.executions.predictor_evaluation_execution import PredictorEvaluationExecution +from citrine.informatics.executions.predictor_evaluation_execution import ( + PredictorEvaluationExecution, +) from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult -from citrine.resources.predictor_evaluation_execution import PredictorEvaluationExecutionCollection -from tests.utils.session import FakeSession, FakeCall +from citrine.resources.predictor_evaluation_execution import ( + PredictorEvaluationExecutionCollection, +) +from tests.utils.session import FakeCall, FakeSession @pytest.fixture @@ -26,7 +30,10 @@ def collection(session) -> PredictorEvaluationExecutionCollection: @pytest.fixture -def workflow_execution(collection: PredictorEvaluationExecutionCollection, predictor_evaluation_execution_dict) -> PredictorEvaluationExecution: +def workflow_execution( + collection: PredictorEvaluationExecutionCollection, + predictor_evaluation_execution_dict, +) -> PredictorEvaluationExecution: return collection.build(predictor_evaluation_execution_dict) @@ -62,12 +69,15 @@ def test_build_new_execution(collection, predictor_evaluation_execution_dict): assert execution.project_id == collection.project_id assert execution.workflow_id == collection.workflow_id assert execution._session == collection.session - assert execution.in_progress() and not execution.succeeded() and not execution.failed() + assert ( + execution.in_progress() and not execution.succeeded() and not execution.failed() + ) assert execution.status_detail -def test_workflow_execution_results(workflow_execution: PredictorEvaluationExecution, session, - example_cv_result_dict): +def test_workflow_execution_results( + workflow_execution: PredictorEvaluationExecution, session, example_cv_result_dict +): # Given session.set_response(example_cv_result_dict) @@ -75,15 +85,24 @@ def test_workflow_execution_results(workflow_execution: PredictorEvaluationExecu results = workflow_execution["Example Evaluator"] # Then - assert results.evaluator == PredictorEvaluationResult.build(example_cv_result_dict).evaluator - expected_path = '/projects/{}/predictor-evaluation-executions/{}/results'.format( + assert ( + results.evaluator + == PredictorEvaluationResult.build(example_cv_result_dict).evaluator + ) + expected_path = "/projects/{}/predictor-evaluation-executions/{}/results".format( workflow_execution.project_id, workflow_execution.uid, ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"evaluator_name": 
"Example Evaluator"}) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"evaluator_name": "Example Evaluator"} + ) -def test_trigger_workflow_execution(collection: PredictorEvaluationExecutionCollection, predictor_evaluation_execution_dict, session): +def test_trigger_workflow_execution( + collection: PredictorEvaluationExecutionCollection, + predictor_evaluation_execution_dict, + session, +): # Given predictor_id = uuid.uuid4() random_state = 9325 @@ -95,19 +114,23 @@ def test_trigger_workflow_execution(collection: PredictorEvaluationExecutionColl # Then assert str(actual_execution.uid) == predictor_evaluation_execution_dict["id"] - expected_path = '/projects/{}/predictor-evaluation-workflows/{}/executions'.format( + expected_path = "/projects/{}/predictor-evaluation-workflows/{}/executions".format( collection.project_id, collection.workflow_id, ) assert session.last_call == FakeCall( - method='POST', + method="POST", path=expected_path, json=PredictorRef(predictor_id).dump(), - params={"random_state": random_state} + params={"random_state": random_state}, ) -def test_trigger_workflow_execution_with_version(collection: PredictorEvaluationExecutionCollection, predictor_evaluation_execution_dict, session): +def test_trigger_workflow_execution_with_version( + collection: PredictorEvaluationExecutionCollection, + predictor_evaluation_execution_dict, + session, +): # Given predictor_id = uuid.uuid4() predictor_version = random.randint(1, 10) @@ -115,48 +138,79 @@ def test_trigger_workflow_execution_with_version(collection: PredictorEvaluation # When with pytest.deprecated_call(): - actual_execution = collection.trigger(predictor_id, predictor_version=predictor_version) + actual_execution = collection.trigger( + predictor_id, predictor_version=predictor_version + ) # Then assert str(actual_execution.uid) == predictor_evaluation_execution_dict["id"] - expected_path = '/projects/{}/predictor-evaluation-workflows/{}/executions'.format( + expected_path = "/projects/{}/predictor-evaluation-workflows/{}/executions".format( collection.project_id, collection.workflow_id, ) assert session.last_call == FakeCall( - method='POST', + method="POST", path=expected_path, - json=PredictorRef(predictor_id, predictor_version).dump() + json=PredictorRef(predictor_id, predictor_version).dump(), ) @pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) -def test_list(collection: PredictorEvaluationExecutionCollection, session, predictor_version): +def test_list( + collection: PredictorEvaluationExecutionCollection, session, predictor_version +): session.set_response({"page": 1, "per_page": 4, "next": "", "response": []}) predictor_id = uuid.uuid4() with pytest.deprecated_call(): - lst = list(collection.list(per_page=4, predictor_id=predictor_id, predictor_version=predictor_version)) + lst = list( + collection.list( + per_page=4, + predictor_id=predictor_id, + predictor_version=predictor_version, + ) + ) assert not lst - expected_path = '/projects/{}/predictor-evaluation-executions'.format(collection.project_id) - expected_payload = {"per_page": 4, "predictor_id": str(predictor_id), "workflow_id": str(collection.workflow_id), 'page': 1} + expected_path = "/projects/{}/predictor-evaluation-executions".format( + collection.project_id + ) + expected_payload = { + "per_page": 4, + "predictor_id": str(predictor_id), + "workflow_id": str(collection.workflow_id), + "page": 1, + } if predictor_version is not None: expected_payload["predictor_version"] = predictor_version - 
assert session.last_call == FakeCall(method='GET', path=expected_path, params=expected_payload) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params=expected_payload + ) def test_archive(workflow_execution, collection): with pytest.deprecated_call(): collection.archive(workflow_execution.uid) - expected_path = '/projects/{}/predictor-evaluation-executions/archive'.format(collection.project_id) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow_execution.uid)}) + expected_path = "/projects/{}/predictor-evaluation-executions/archive".format( + collection.project_id + ) + assert collection.session.last_call == FakeCall( + method="PUT", + path=expected_path, + json={"module_uid": str(workflow_execution.uid)}, + ) def test_restore(workflow_execution, collection): with pytest.deprecated_call(): collection.restore(workflow_execution.uid) - expected_path = '/projects/{}/predictor-evaluation-executions/restore'.format(collection.project_id) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json={"module_uid": str(workflow_execution.uid)}) + expected_path = "/projects/{}/predictor-evaluation-executions/restore".format( + collection.project_id + ) + assert collection.session.last_call == FakeCall( + method="PUT", + path=expected_path, + json={"module_uid": str(workflow_execution.uid)}, + ) def test_delete(collection): @@ -164,11 +218,12 @@ def test_delete(collection): with pytest.raises(NotImplementedError): collection.delete(uuid.uuid4()) + def test_get(predictor_evaluation_execution_dict, workflow_execution, collection): collection.session.set_response(predictor_evaluation_execution_dict) - + with pytest.deprecated_call(): - execution = collection.get(workflow_execution.uid) - - expected_path = f'/projects/{collection.project_id}/predictor-evaluation-executions/{workflow_execution.uid}' - assert collection.session.last_call == FakeCall(method='GET', path=expected_path) + collection.get(workflow_execution.uid) + + expected_path = f"/projects/{collection.project_id}/predictor-evaluation-executions/{workflow_execution.uid}" + assert collection.session.last_call == FakeCall(method="GET", path=expected_path) diff --git a/tests/resources/test_predictor_evaluation_workflows.py b/tests/resources/test_predictor_evaluation_workflows.py index d231a583c..928335a68 100644 --- a/tests/resources/test_predictor_evaluation_workflows.py +++ b/tests/resources/test_predictor_evaluation_workflows.py @@ -4,7 +4,9 @@ import pytest from citrine.informatics.workflows import PredictorEvaluationWorkflow -from citrine.resources.predictor_evaluation_workflow import PredictorEvaluationWorkflowCollection +from citrine.resources.predictor_evaluation_workflow import ( + PredictorEvaluationWorkflowCollection, +) from tests.utils.session import FakeSession, FakeCall @@ -23,8 +25,10 @@ def collection(session) -> PredictorEvaluationWorkflowCollection: @pytest.fixture -def workflow(collection: PredictorEvaluationWorkflowCollection, - predictor_evaluation_workflow_dict) -> PredictorEvaluationWorkflow: +def workflow( + collection: PredictorEvaluationWorkflowCollection, + predictor_evaluation_workflow_dict, +) -> PredictorEvaluationWorkflow: return collection.build(predictor_evaluation_workflow_dict) @@ -36,17 +40,23 @@ def test_basic_methods(workflow, collection): def test_archive(workflow, collection): with pytest.deprecated_call(): collection.archive(workflow.uid) - expected_path = 
'/projects/{}/predictor-evaluation-workflows/archive'.format(collection.project_id) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, - json={"module_uid": str(workflow.uid)}) + expected_path = "/projects/{}/predictor-evaluation-workflows/archive".format( + collection.project_id + ) + assert collection.session.last_call == FakeCall( + method="PUT", path=expected_path, json={"module_uid": str(workflow.uid)} + ) def test_restore(workflow, collection): with pytest.deprecated_call(): collection.restore(workflow.uid) - expected_path = '/projects/{}/predictor-evaluation-workflows/restore'.format(collection.project_id) - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, - json={"module_uid": str(workflow.uid)}) + expected_path = "/projects/{}/predictor-evaluation-workflows/restore".format( + collection.project_id + ) + assert collection.session.last_call == FakeCall( + method="PUT", path=expected_path, json={"module_uid": str(workflow.uid)} + ) def test_delete(collection): @@ -56,9 +66,11 @@ def test_delete(collection): @pytest.mark.parametrize("predictor_version", (2, "1", "latest", None)) -def test_create_default(predictor_evaluation_workflow_dict: dict, - predictor_version: Optional[Union[int, str]], - workflow: PredictorEvaluationWorkflow): +def test_create_default( + predictor_evaluation_workflow_dict: dict, + predictor_version: Optional[Union[int, str]], + workflow: PredictorEvaluationWorkflow, +): project_id = uuid.uuid4() predictor_id = uuid.uuid4() @@ -66,52 +78,70 @@ def test_create_default(predictor_evaluation_workflow_dict: dict, session.set_response(predictor_evaluation_workflow_dict) with pytest.deprecated_call(): collection = PredictorEvaluationWorkflowCollection( - project_id=project_id, - session=session + project_id=project_id, session=session ) with pytest.deprecated_call(): - default_workflow = collection.create_default(predictor_id=predictor_id, predictor_version=predictor_version) + default_workflow = collection.create_default( + predictor_id=predictor_id, predictor_version=predictor_version + ) - url = f'/projects/{collection.project_id}/predictor-evaluation-workflows/default' + url = f"/projects/{collection.project_id}/predictor-evaluation-workflows/default" expected_payload = {"predictor_id": str(predictor_id)} if predictor_version is not None: expected_payload["predictor_version"] = predictor_version assert session.calls == [FakeCall(method="POST", path=url, json=expected_payload)] assert default_workflow.dump() == workflow.dump() - + + def test_register(predictor_evaluation_workflow_dict, workflow, collection): collection.session.set_response(predictor_evaluation_workflow_dict) - + with pytest.deprecated_call(): collection.register(workflow) - expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows' - assert collection.session.last_call == FakeCall(method='POST', path=expected_path, json=workflow.dump()) - + expected_path = f"/projects/{collection.project_id}/predictor-evaluation-workflows" + assert collection.session.last_call == FakeCall( + method="POST", path=expected_path, json=workflow.dump() + ) + + def test_list(predictor_evaluation_workflow_dict, workflow, collection): - collection.session.set_response({"page": 1, "per_page": 4, "next": "", "response": [predictor_evaluation_workflow_dict]}) - + collection.session.set_response( + { + "page": 1, + "per_page": 4, + "next": "", + "response": [predictor_evaluation_workflow_dict], + } + ) + with pytest.deprecated_call(): 
list(collection.list(per_page=20)) - expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows' - assert collection.session.last_call == FakeCall(method='GET', path=expected_path, params={"per_page": 20, "page": 1}) + expected_path = f"/projects/{collection.project_id}/predictor-evaluation-workflows" + assert collection.session.last_call == FakeCall( + method="GET", path=expected_path, params={"per_page": 20, "page": 1} + ) + def test_update(predictor_evaluation_workflow_dict, workflow, collection): collection.session.set_response(predictor_evaluation_workflow_dict) - + with pytest.deprecated_call(): collection.update(workflow) - expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}' - assert collection.session.last_call == FakeCall(method='PUT', path=expected_path, json=workflow.dump()) + expected_path = f"/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}" + assert collection.session.last_call == FakeCall( + method="PUT", path=expected_path, json=workflow.dump() + ) + def test_get(predictor_evaluation_workflow_dict, workflow, collection): collection.session.set_response(predictor_evaluation_workflow_dict) - + with pytest.deprecated_call(): collection.get(workflow.uid) - expected_path = f'/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}' - assert collection.session.last_call == FakeCall(method='GET', path=expected_path) + expected_path = f"/projects/{collection.project_id}/predictor-evaluation-workflows/{workflow.uid}" + assert collection.session.last_call == FakeCall(method="GET", path=expected_path) diff --git a/tests/resources/test_predictor_evaluations.py b/tests/resources/test_predictor_evaluations.py index fed853601..c3d45b887 100644 --- a/tests/resources/test_predictor_evaluations.py +++ b/tests/resources/test_predictor_evaluations.py @@ -4,12 +4,18 @@ import pytest from citrine.resources.predictor_evaluation import PredictorEvaluationCollection -from citrine.informatics.executions.predictor_evaluation import PredictorEvaluationRequest +from citrine.informatics.executions.predictor_evaluation import ( + PredictorEvaluationRequest, +) from citrine.informatics.predictors import GraphPredictor from citrine.jobs.waiting import wait_while_executing -from tests.utils.factories import CrossValidationEvaluatorFactory, PredictorEvaluationDataFactory,\ - PredictorEvaluationFactory, PredictorInstanceDataFactory, PredictorRefFactory +from tests.utils.factories import ( + CrossValidationEvaluatorFactory, + PredictorEvaluationDataFactory, + PredictorEvaluationFactory, + PredictorRefFactory, +) from tests.utils.session import FakeCall, FakeSession @@ -23,15 +29,15 @@ def test_get(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(evaluation_response) pec.get(id) expected_call = FakeCall( - method='GET', - path=f'/projects/{pec.project_id}/predictor-evaluations/{id}', - params={} + method="GET", + path=f"/projects/{pec.project_id}/predictor-evaluations/{id}", + params={}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -43,15 +49,15 @@ def test_archived(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(evaluation_response) pec.archive(id) expected_call = FakeCall( - method='PUT', - path=f'/projects/{pec.project_id}/predictor-evaluations/{id}/archive', - json={} + method="PUT", + 
path=f"/projects/{pec.project_id}/predictor-evaluations/{id}/archive", + json={}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -63,15 +69,15 @@ def test_restore(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(evaluation_response) pec.restore(id) expected_call = FakeCall( - method='PUT', - path=f'/projects/{pec.project_id}/predictor-evaluations/{id}/restore', - json={} + method="PUT", + path=f"/projects/{pec.project_id}/predictor-evaluations/{id}/restore", + json={}, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -84,15 +90,21 @@ def test_list(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(paging_response(evaluation_response)) evaluations = list(pec.list(predictor_id=pred_id, predictor_version=pred_ver)) expected_call = FakeCall( - method='GET', - path=f'/projects/{pec.project_id}/predictor-evaluations', - params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": False} + method="GET", + path=f"/projects/{pec.project_id}/predictor-evaluations", + params={ + "page": 1, + "per_page": 100, + "predictor_id": str(pred_id), + "predictor_version": pred_ver, + "archived": False, + }, ) assert session.num_calls == 1 @@ -107,15 +119,23 @@ def test_list_archived(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(paging_response(evaluation_response)) - evaluations = list(pec.list_archived(predictor_id=pred_id, predictor_version=pred_ver)) + evaluations = list( + pec.list_archived(predictor_id=pred_id, predictor_version=pred_ver) + ) expected_call = FakeCall( - method='GET', - path=f'/projects/{pec.project_id}/predictor-evaluations', - params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": True} + method="GET", + path=f"/projects/{pec.project_id}/predictor-evaluations", + params={ + "page": 1, + "per_page": 100, + "predictor_id": str(pred_id), + "predictor_version": pred_ver, + "archived": True, + }, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -123,21 +143,30 @@ def test_list_archived(): def test_list_all(): - evaluations = [PredictorEvaluationFactory(), PredictorEvaluationFactory(is_archived=True)] + evaluations = [ + PredictorEvaluationFactory(), + PredictorEvaluationFactory(is_archived=True), + ] pred_id = uuid.uuid4() pred_ver = 2 session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(paging_response(*evaluations)) evaluations = list(pec.list_all(predictor_id=pred_id, predictor_version=pred_ver)) expected_call = FakeCall( - method='GET', - path=f'/projects/{pec.project_id}/predictor-evaluations', - params={"page": 1, "per_page": 100, "predictor_id": str(pred_id), "predictor_version": pred_ver, "archived": None} + method="GET", + path=f"/projects/{pec.project_id}/predictor-evaluations", + params={ + "page": 1, + "per_page": 100, + "predictor_id": str(pred_id), + "predictor_version": pred_ver, + "archived": None, + }, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -151,18 +180,24 @@ def test_trigger(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(evaluation_response) - pec.trigger(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"], 
evaluators=evaluators) + pec.trigger( + predictor_id=pred_ref["predictor_id"], + predictor_version=pred_ref["predictor_version"], + evaluators=evaluators, + ) - expected_payload = PredictorEvaluationRequest(evaluators=evaluators, - predictor_id=pred_ref["predictor_id"], - predictor_version=pred_ref["predictor_version"]) + expected_payload = PredictorEvaluationRequest( + evaluators=evaluators, + predictor_id=pred_ref["predictor_id"], + predictor_version=pred_ref["predictor_version"], + ) expected_call = FakeCall( - method='POST', - path=f'/projects/{pec.project_id}/predictor-evaluations/trigger', - json=expected_payload.dump() + method="POST", + path=f"/projects/{pec.project_id}/predictor-evaluations/trigger", + json=expected_payload.dump(), ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -174,15 +209,18 @@ def test_trigger_default(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(evaluation_response) - pec.trigger_default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"]) + pec.trigger_default( + predictor_id=pred_ref["predictor_id"], + predictor_version=pred_ref["predictor_version"], + ) expected_call = FakeCall( - method='POST', - path=f'/projects/{pec.project_id}/predictor-evaluations/trigger-default', - json=pred_ref + method="POST", + path=f"/projects/{pec.project_id}/predictor-evaluations/trigger-default", + json=pred_ref, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -194,36 +232,40 @@ def test_default(): session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(response) - default_evaluators = pec.default(predictor_id=pred_ref["predictor_id"], predictor_version=pred_ref["predictor_version"]) + default_evaluators = pec.default( + predictor_id=pred_ref["predictor_id"], + predictor_version=pred_ref["predictor_version"], + ) expected_call = FakeCall( - method='POST', - path=f'/projects/{pec.project_id}/predictor-evaluations/default', - json=pred_ref + method="POST", + path=f"/projects/{pec.project_id}/predictor-evaluations/default", + json=pred_ref, ) assert session.num_calls == 1 assert expected_call == session.last_call assert len(default_evaluators) == len(response["evaluators"]) + def test_default_from_config(valid_graph_predictor_data): response = PredictorEvaluationDataFactory() config = GraphPredictor.build(valid_graph_predictor_data) - payload = config.dump()['instance'] + payload = config.dump()["instance"] session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + session.set_response(response) default_evaluators = pec.default_from_config(config) expected_call = FakeCall( - method='POST', - path=f'/projects/{pec.project_id}/predictor-evaluations/default-from-config', - json=payload + method="POST", + path=f"/projects/{pec.project_id}/predictor-evaluations/default-from-config", + json=payload, ) assert session.num_calls == 1 assert expected_call == session.last_call @@ -252,14 +294,16 @@ def test_delete_not_implemented(): def test_wait(): - in_progress_response = PredictorEvaluationFactory(metadata__status={"major": "INPROGRESS", "minor": "EXECUTING", "detail": []}) + in_progress_response = PredictorEvaluationFactory( + metadata__status={"major": "INPROGRESS", "minor": "EXECUTING", "detail": []} + ) completed_response = deepcopy(in_progress_response) completed_response["metadata"]["status"]["major"] = "SUCCEEDED" 
completed_response["metadata"]["status"]["minor"] = "COMPLETED" session = FakeSession() pec = PredictorEvaluationCollection(uuid.uuid4(), session) - + # wait_while_executing makes two additional calls once it's done polling. responses = 4 * [in_progress_response] + 3 * [completed_response] session.set_responses(*responses) @@ -268,7 +312,7 @@ def test_wait(): wait_while_executing(collection=pec, execution=evaluation, interval=0.1) expected_call = FakeCall( - method='GET', - path=f'/projects/{pec.project_id}/predictor-evaluations/{in_progress_response["id"]}' + method="GET", + path=f"/projects/{pec.project_id}/predictor-evaluations/{in_progress_response['id']}", ) assert (len(responses) * [expected_call]) == session.calls diff --git a/tests/resources/test_process_run.py b/tests/resources/test_process_run.py index 3dc6e04c5..fb6585a92 100644 --- a/tests/resources/test_process_run.py +++ b/tests/resources/test_process_run.py @@ -15,28 +15,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> ProcessRunCollection: return ProcessRunCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + session=session, + ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): ProcessRunCollection( project_id=UUID(project_id), - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + session=session, + ) - assert session.calls == [FakeCall(method="GET", path=f'projects/{project_id}')] + assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_spec(collection: ProcessRunCollection): run_noop_gemd_relation_search_test( - search_for='process-runs', - search_with='process-specs', + search_for="process-runs", + search_with="process-specs", collection=collection, search_fn=collection.list_by_spec, ) @@ -47,16 +51,8 @@ def test_equals(): from citrine.resources.process_run import ProcessRun as CitrineProcessRun from gemd.entity.object import ProcessRun as GEMDProcessRun - gemd_obj = GEMDProcessRun( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) - citrine_obj = CitrineProcessRun( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) + gemd_obj = GEMDProcessRun(name="My Name", notes="I have notes", tags=["tag!"]) + citrine_obj = CitrineProcessRun(name="My Name", notes="I have notes", tags=["tag!"]) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" assert gemd_obj != citrine_obj, "GEMD/Citrine detects difference" diff --git a/tests/resources/test_process_spec.py b/tests/resources/test_process_spec.py index fce60075b..2cd659c0d 100644 --- a/tests/resources/test_process_spec.py +++ b/tests/resources/test_process_spec.py @@ -4,7 +4,10 @@ from gemd.entity.object import ProcessSpec as GEMDProcessSpec -from citrine.resources.process_spec import ProcessSpec as CitrineProcesssSpec, ProcessSpecCollection +from citrine.resources.process_spec import ( + ProcessSpec as 
CitrineProcesssSpec, + ProcessSpecCollection, +) from tests.resources.test_data_concepts import run_noop_gemd_relation_search_test from tests.utils.session import FakeCall, FakeSession @@ -17,28 +20,32 @@ def session() -> FakeSession: @pytest.fixture def collection(session) -> ProcessSpecCollection: return ProcessSpecCollection( - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - team_id = UUID('6b608f78-e341-422c-8076-35adc8828000'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828000"), + session=session, + ) def test_create_deprecated_collection(session): - project_id = '6b608f78-e341-422c-8076-35adc8828545' - session.set_response({'project': {'team': {'id': UUID("6b608f78-e341-422c-8076-35adc8828000")}}}) + project_id = "6b608f78-e341-422c-8076-35adc8828545" + session.set_response( + {"project": {"team": {"id": UUID("6b608f78-e341-422c-8076-35adc8828000")}}} + ) with pytest.deprecated_call(): ProcessSpecCollection( project_id=UUID(project_id), - dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), - session=session) + dataset_id=UUID("8da51e93-8b55-4dd3-8489-af8f65d4ad9a"), + session=session, + ) - assert session.calls == [FakeCall(method="GET", path=f'projects/{project_id}')] + assert session.calls == [FakeCall(method="GET", path=f"projects/{project_id}")] def test_list_by_template(collection: ProcessSpecCollection): run_noop_gemd_relation_search_test( - search_for='process-specs', - search_with='process-templates', + search_for="process-specs", + search_with="process-templates", collection=collection, search_fn=collection.list_by_template, ) @@ -46,15 +53,9 @@ def test_list_by_template(collection: ProcessSpecCollection): def test_equals(): """Test basic equality. 
Complex relationships are tested in test_material_run.test_deep_equals().""" - gemd_obj = GEMDProcessSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] - ) + gemd_obj = GEMDProcessSpec(name="My Name", notes="I have notes", tags=["tag!"]) citrine_obj = CitrineProcesssSpec( - name="My Name", - notes="I have notes", - tags=["tag!"] + name="My Name", notes="I have notes", tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.notes = "Something else" diff --git a/tests/resources/test_project.py b/tests/resources/test_project.py index 0a1905d74..836c7aa2e 100644 --- a/tests/resources/test_project.py +++ b/tests/resources/test_project.py @@ -8,16 +8,19 @@ from citrine.exceptions import NotFound, ModuleRegistrationFailedException from citrine.informatics.predictors import GraphPredictor -from citrine.resources.api_error import ApiError, ValidationError +from citrine.resources.api_error import ApiError from citrine.resources.dataset import Dataset, DatasetCollection from citrine.resources.gemtables import GemTableCollection from citrine.resources.process_spec import ProcessSpec from citrine.resources.project import Project, ProjectCollection -from citrine.resources.project_member import ProjectMember -from citrine.resources.project_roles import MEMBER, LEAD, WRITE from tests.utils.factories import ProjectDataFactory, UserDataFactory, TeamDataFactory -from tests.utils.session import FakeSession, FakeCall, FakePaginatedSession, FakeRequestResponse -from citrine.resources.team import READ, TeamMember +from tests.utils.session import ( + FakeSession, + FakeCall, + FakePaginatedSession, + FakeRequestResponse, +) +from citrine.resources.team import READ, TeamMember @pytest.fixture @@ -32,19 +35,17 @@ def paginated_session() -> FakePaginatedSession: @pytest.fixture def paginated_collection(paginated_session) -> ProjectCollection: - return ProjectCollection( - session=paginated_session - ) + return ProjectCollection(session=paginated_session) @pytest.fixture def project(session) -> Project: project = Project( - name='Test Project', + name="Test Project", session=session, - team_id=uuid.UUID('11111111-8baf-433b-82eb-8c7fada847da') + team_id=uuid.UUID("11111111-8baf-433b-82eb-8c7fada847da"), ) - project.uid = uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da') + project.uid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da") return project @@ -60,10 +61,10 @@ def datasets(project) -> DatasetCollection: def test_get_team_id_from_project(session): - team_id = uuid.UUID('6b608f78-e341-422c-8076-35adc8828000') - check_project = {'project': {'team': {'id': team_id}}} + team_id = uuid.UUID("6b608f78-e341-422c-8076-35adc8828000") + check_project = {"project": {"team": {"id": team_id}}} session.set_response(check_project) - p = Project(name='Test Project', session=session) + p = Project(name="Test Project", session=session) assert p.team_id == team_id @@ -78,31 +79,26 @@ def test_publish_resource(project, session): assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/projects/{project.uid}/published-resources/MODULE/batch-publish', - json={ - 'ids': [str(predictor.uid)] - } + method="POST", + path=f"/projects/{project.uid}/published-resources/MODULE/batch-publish", + json={"ids": [str(predictor.uid)]}, ) assert expected_call == session.last_call def test_publish_resource_deprecated(project, datasets, session): dataset_id = str(uuid.uuid4()) - dataset = datasets.build(dict( - id=dataset_id, - name="public dataset", summary="test", 
description="test" - )) + dataset = datasets.build( + dict(id=dataset_id, name="public dataset", summary="test", description="test") + ) with pytest.deprecated_call(): assert project.publish(resource=dataset) assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/projects/{project.uid}/published-resources/DATASET/batch-publish', - json={ - 'ids': [str(dataset.uid)] - } + method="POST", + path=f"/projects/{project.uid}/published-resources/DATASET/batch-publish", + json={"ids": [str(dataset.uid)]}, ) assert expected_call == session.last_call @@ -114,31 +110,26 @@ def test_pull_in_resource(project, datasets, session): assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/teams/{project.team_id}/projects/{project.uid}/outside-resources/MODULE/batch-pull-in', - json={ - 'ids': [str(predictor.uid)] - } + method="POST", + path=f"/teams/{project.team_id}/projects/{project.uid}/outside-resources/MODULE/batch-pull-in", + json={"ids": [str(predictor.uid)]}, ) assert expected_call == session.last_call def test_pull_in_resource_deprecated(project, datasets, session): dataset_id = str(uuid.uuid4()) - dataset = datasets.build(dict( - id=dataset_id, - name="public dataset", summary="test", description="test" - )) + dataset = datasets.build( + dict(id=dataset_id, name="public dataset", summary="test", description="test") + ) with pytest.deprecated_call(): assert project.pull_in_resource(resource=dataset) assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/teams/{project.team_id}/projects/{project.uid}/outside-resources/DATASET/batch-pull-in', - json={ - 'ids': [str(dataset.uid)] - } + method="POST", + path=f"/teams/{project.team_id}/projects/{project.uid}/outside-resources/DATASET/batch-pull-in", + json={"ids": [str(dataset.uid)]}, ) assert expected_call == session.last_call @@ -150,31 +141,26 @@ def test_un_publish_resource(project, datasets, session): assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/projects/{project.uid}/published-resources/MODULE/batch-un-publish', - json={ - 'ids': [str(predictor.uid)] - } + method="POST", + path=f"/projects/{project.uid}/published-resources/MODULE/batch-un-publish", + json={"ids": [str(predictor.uid)]}, ) assert expected_call == session.last_call def test_un_publish_resource_deprecated(project, datasets, session): dataset_id = str(uuid.uuid4()) - dataset = datasets.build(dict( - id=dataset_id, - name="public dataset", summary="test", description="test" - )) + dataset = datasets.build( + dict(id=dataset_id, name="public dataset", summary="test", description="test") + ) with pytest.deprecated_call(): assert project.un_publish(resource=dataset) assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'/projects/{project.uid}/published-resources/DATASET/batch-un-publish', - json={ - 'ids': [str(dataset.uid)] - } + method="POST", + path=f"/projects/{project.uid}/published-resources/DATASET/batch-un-publish", + json={"ids": [str(dataset.uid)]}, ) assert expected_call == session.last_call @@ -189,7 +175,7 @@ def test_datasets_get_project_id(project): def test_property_templates_get_project_id(project): with pytest.deprecated_call(): collection = project.property_templates - + assert project.uid == collection.project_id @@ -345,13 +331,14 @@ def test_ara_definitions_get_project_id(project): def test_failed_register(): team_id = uuid.uuid4() session = mock.Mock() - session.post_resource.side_effect = 
NotFound(f'/teams/{team_id}/projects', - FakeRequestResponse(400)) + session.post_resource.side_effect = NotFound( + f"/teams/{team_id}/projects", FakeRequestResponse(400) + ) project_collection = ProjectCollection(session=session, team_id=team_id) with pytest.raises(ModuleRegistrationFailedException) as e: project_collection.register("Project") assert 'The "Project" failed to register.' in str(e.value) - assert f'/teams/{team_id}/projects' in str(e.value) + assert f"/teams/{team_id}/projects" in str(e.value) def test_failed_register_no_team(session): @@ -362,66 +349,66 @@ def test_failed_register_no_team(session): def test_project_registration(collection: ProjectCollection, session): # Given - create_time = parse('2019-09-10T00:00:00+00:00') + create_time = parse("2019-09-10T00:00:00+00:00") project_data = ProjectDataFactory( - name='testing', - description='A sample project', - created_at=int(create_time.timestamp() * 1000) # The lib expects ms since epoch, which is really odd + name="testing", + description="A sample project", + created_at=int( + create_time.timestamp() * 1000 + ), # The lib expects ms since epoch, which is really odd ) - session.set_response({'project': project_data}) + session.set_response({"project": project_data}) team_id = collection.team_id # When - created_project = collection.register('testing') + created_project = collection.register("testing") # Then assert 1 == session.num_calls expected_call = FakeCall( - method='POST', - path=f'teams/{team_id}/projects', - json={ - 'name': 'testing' - } + method="POST", path=f"teams/{team_id}/projects", json={"name": "testing"} ) assert expected_call == session.last_call - assert 'A sample project' == created_project.description - assert 'CREATED' == created_project.status + assert "A sample project" == created_project.description + assert "CREATED" == created_project.status assert create_time == created_project.created_at def test_get_project(collection: ProjectCollection, session): # Given - project_data = ProjectDataFactory(name='single project') - session.set_response({'project': project_data}) + project_data = ProjectDataFactory(name="single project") + session.set_response({"project": project_data}) # When - created_project = collection.get(project_data['id']) + created_project = collection.get(project_data["id"]) # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='/projects/{}'.format(project_data['id']), + method="GET", + path="/projects/{}".format(project_data["id"]), ) assert expected_call == session.last_call - assert 'single project' == created_project.name + assert "single project" == created_project.name def test_list_projects(collection, session): # Given projects_data = ProjectDataFactory.create_batch(5) - session.set_response({'projects': projects_data}) + session.set_response({"projects": projects_data}) # When projects = list(collection.list()) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='GET', - path=f'/teams/{collection.team_id}/projects', - params={'per_page': 1000, 'page': 1}, - version="v3") + expected_call = FakeCall( + method="GET", + path=f"/teams/{collection.team_id}/projects", + params={"per_page": 1000, "page": 1}, + version="v3", + ) assert expected_call == session.last_call assert 5 == len(projects) @@ -429,17 +416,19 @@ def test_list_projects(collection, session): def test_list_archived_projects(collection, session): # Given projects_data = ProjectDataFactory.create_batch(5) - session.set_response({'projects': projects_data}) + 
session.set_response({"projects": projects_data}) # When projects = list(collection.list_archived()) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='GET', - path=f'/teams/{collection.team_id}/projects', - params={'per_page': 1000, 'page': 1, 'archived': "true"}, - version="v3") + expected_call = FakeCall( + method="GET", + path=f"/teams/{collection.team_id}/projects", + params={"per_page": 1000, "page": 1, "archived": "true"}, + version="v3", + ) assert expected_call == session.last_call assert 5 == len(projects) @@ -447,17 +436,19 @@ def test_list_archived_projects(collection, session): def test_list_active_projects(collection, session): # Given projects_data = ProjectDataFactory.create_batch(5) - session.set_response({'projects': projects_data}) + session.set_response({"projects": projects_data}) # When projects = list(collection.list_active()) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='GET', - path=f'/teams/{collection.team_id}/projects', - params={'per_page': 1000, 'page': 1, 'archived': "false"}, - version="v3") + expected_call = FakeCall( + method="GET", + path=f"/teams/{collection.team_id}/projects", + params={"per_page": 1000, "page": 1, "archived": "false"}, + version="v3", + ) assert expected_call == session.last_call assert 5 == len(projects) @@ -465,12 +456,14 @@ def test_list_active_projects(collection, session): def test_list_no_team(session): project_collection = ProjectCollection(session=session) projects_data = ProjectDataFactory.create_batch(5) - session.set_response({'projects': projects_data}) + session.set_response({"projects": projects_data}) projects = list(project_collection.list()) assert 1 == session.num_calls - expected_call = FakeCall(method='GET', path='/projects', params={'per_page': 1000, 'page': 1}) + expected_call = FakeCall( + method="GET", path="/projects", params={"per_page": 1000, "page": 1} + ) assert expected_call == session.last_call assert 5 == len(projects) @@ -478,81 +471,92 @@ def test_list_no_team(session): def test_list_projects_with_page_params(collection, session): # Given project_data = ProjectDataFactory() - session.set_response({'projects': [project_data]}) + session.set_response({"projects": [project_data]}) # When list(collection.list(per_page=10)) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='GET', path=f'/teams/{collection.team_id}/projects', params={'per_page': 10, 'page': 1}) + expected_call = FakeCall( + method="GET", + path=f"/teams/{collection.team_id}/projects", + params={"per_page": 10, "page": 1}, + ) assert expected_call == session.last_call + def test_search_all_no_team(session): project_collection = ProjectCollection(session=session) projects_data = ProjectDataFactory.create_batch(2) - project_name_to_match = projects_data[0]['name'] + project_name_to_match = projects_data[0]["name"] - search_params = { - 'name': { - 'value': project_name_to_match, - 'search_method': 'EXACT'}} + search_params = {"name": {"value": project_name_to_match, "search_method": "EXACT"}} expected_response = [p for p in projects_data if p["name"] == project_name_to_match] - project_collection.session.set_response({'projects': expected_response}) + project_collection.session.set_response({"projects": expected_response}) # Then results = list(project_collection.search_all(search_params=search_params)) - expected_call = FakeCall(method='POST', path='/projects/search', params={'userId': ''}, json={'search_params': search_params}) + expected_call = FakeCall( + method="POST", + 
path="/projects/search", + params={"userId": ""}, + json={"search_params": search_params}, + ) assert 1 == project_collection.session.num_calls assert expected_call == project_collection.session.last_call assert 1 == len(results) + def test_search_all(collection: ProjectCollection): # Given projects_data = ProjectDataFactory.create_batch(2) - project_name_to_match = projects_data[0]['name'] + project_name_to_match = projects_data[0]["name"] - search_params = { - 'name': { - 'value': project_name_to_match, - 'search_method': 'EXACT'}} + search_params = {"name": {"value": project_name_to_match, "search_method": "EXACT"}} expected_response = [p for p in projects_data if p["name"] == project_name_to_match] - collection.session.set_response({'projects': expected_response}) + collection.session.set_response({"projects": expected_response}) # Then results = list(collection.search_all(search_params=search_params)) - expected_call = FakeCall(method='POST', - path=f'/teams/{collection.team_id}/projects/search', - params={'userId': ''}, - json={'search_params': { - 'name': { - 'value': project_name_to_match, - 'search_method': 'EXACT'}}}) + expected_call = FakeCall( + method="POST", + path=f"/teams/{collection.team_id}/projects/search", + params={"userId": ""}, + json={ + "search_params": { + "name": {"value": project_name_to_match, "search_method": "EXACT"} + } + }, + ) assert 1 == collection.session.num_calls assert expected_call == collection.session.last_call assert 1 == len(results) + def test_search_all_no_search_params(collection: ProjectCollection): # Given projects_data = ProjectDataFactory.create_batch(2) expected_response = projects_data - collection.session.set_response({'projects': expected_response}) + collection.session.set_response({"projects": expected_response}) # Then result = list(collection.search_all(search_params=None)) - expected_call = FakeCall(method='POST', - path=f'/teams/{collection.team_id}/projects/search', - params={'userId': ''}, - json={}) + expected_call = FakeCall( + method="POST", + path=f"/teams/{collection.team_id}/projects/search", + params={"userId": ""}, + json={}, + ) assert 1 == collection.session.num_calls assert expected_call == collection.session.last_call @@ -562,43 +566,49 @@ def test_search_all_no_search_params(collection: ProjectCollection): def test_search_projects(collection: ProjectCollection): # Given projects_data = ProjectDataFactory.create_batch(2) - project_name_to_match = projects_data[0]['name'] + project_name_to_match = projects_data[0]["name"] - search_params = { - 'name': { - 'value': project_name_to_match, - 'search_method': 'EXACT'}} + search_params = {"name": {"value": project_name_to_match, "search_method": "EXACT"}} expected_response = [p for p in projects_data if p["name"] == project_name_to_match] - collection.session.set_response({'projects': expected_response}) + collection.session.set_response({"projects": expected_response}) # Then result = list(collection.search(search_params=search_params)) - expected_call = FakeCall(method='POST', - path=f'/teams/{collection.team_id}/projects/search', - params={'userId': ''}, - json={'search_params': { - 'name': { - 'value': project_name_to_match, - 'search_method': 'EXACT'}}}) + expected_call = FakeCall( + method="POST", + path=f"/teams/{collection.team_id}/projects/search", + params={"userId": ""}, + json={ + "search_params": { + "name": {"value": project_name_to_match, "search_method": "EXACT"} + } + }, + ) assert 1 == collection.session.num_calls assert expected_call == 
collection.session.last_call assert 1 == len(result) + def test_search_projects_no_search_params(collection: ProjectCollection): # Given projects_data = ProjectDataFactory.create_batch(2) expected_response = projects_data - collection.session.set_response({'projects': expected_response}) + collection.session.set_response({"projects": expected_response}) # Then result = list(collection.search()) - expected_call = FakeCall(method='POST', path=f'/teams/{collection.team_id}/projects/search', params={'userId': ''}, json={}) + expected_call = FakeCall( + method="POST", + path=f"/teams/{collection.team_id}/projects/search", + params={"userId": ""}, + json={}, + ) assert 1 == collection.session.num_calls assert expected_call == collection.session.last_call @@ -607,40 +617,40 @@ def test_search_projects_no_search_params(collection: ProjectCollection): def test_archive_project(collection, session): # Given - uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4' + uid = "151199ec-e9aa-49a1-ac8e-da722aaf74c4" # When collection.archive(uid) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='POST', path=f'/projects/{uid}/archive') + expected_call = FakeCall(method="POST", path=f"/projects/{uid}/archive") assert expected_call == session.last_call def test_restore_project(collection, session): # Given - uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4' + uid = "151199ec-e9aa-49a1-ac8e-da722aaf74c4" # When collection.restore(uid) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='POST', path=f'/projects/{uid}/restore') + expected_call = FakeCall(method="POST", path=f"/projects/{uid}/restore") assert expected_call == session.last_call def test_delete_project(collection, session): # Given - uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4' + uid = "151199ec-e9aa-49a1-ac8e-da722aaf74c4" # When collection.delete(uid) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='DELETE', path=f'/projects/{uid}') + expected_call = FakeCall(method="DELETE", path=f"/projects/{uid}") assert expected_call == session.last_call @@ -660,53 +670,54 @@ def test_list_members(project, session): id=str(project.team_id), ) - session.set_responses( - {'team': team_data}, - {'users': [user]} - ) + session.set_responses({"team": team_data}, {"users": [user]}) # When members = project.list_members() # Then assert 2 == session.num_calls - expect_call_1 = FakeCall(method='GET', path=f'/teams/{team_data["id"]}') - expect_call_2 = FakeCall(method='GET', path=f'/teams/{project.team_id}/users') + expect_call_1 = FakeCall(method="GET", path=f"/teams/{team_data['id']}") + expect_call_2 = FakeCall(method="GET", path=f"/teams/{project.team_id}/users") assert expect_call_1 == session.calls[0] assert expect_call_2 == session.calls[1] assert isinstance(members[0], TeamMember) def test_project_batch_delete_no_errors(project, session): - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} # Actual response-like data - note there is no 'failures' array within 'output' successful_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [ + "job_type": "batch_delete", + "status": "Success", + "tasks": [ { - "id": "7b6bafd9-f32a-4567-b54c-7ce594edc018", "task_type": "batch_delete", - "status": "Success", "dependencies": [] - } - ], - 'output': {} + "id": "7b6bafd9-f32a-4567-b54c-7ce594edc018", + "task_type": "batch_delete", + "status": "Success", + "dependencies": [], + } + ], + "output": {}, } session.set_responses(job_resp, successful_job_resp) # When with 
pytest.deprecated_call(): - del_resp = project.gemd_batch_delete([uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = project.gemd_batch_delete( + [uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")] + ) # Then assert len(del_resp) == 0 # When trying with entities session.set_responses(job_resp, successful_job_resp) - entity = ProcessSpec(name="proc spec", uids={'id': '16fd2706-8baf-433b-82eb-8c7fada847da'}) + entity = ProcessSpec( + name="proc spec", uids={"id": "16fd2706-8baf-433b-82eb-8c7fada847da"} + ) with pytest.deprecated_call(): del_resp = project.gemd_batch_delete([entity]) @@ -715,43 +726,40 @@ def test_project_batch_delete_no_errors(project, session): def test_project_batch_delete(project, session): - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} - failures_escaped_json = json.dumps([ - { - "id": { - 'scope': 'somescope', - 'id': 'abcd-1234' - }, - 'cause': { - "code": 400, - "message": "", - "validation_errors": [ - { - "failure_message": "fail msg", - "failure_id": "identifier.coreid.missing" - } - ] + failures_escaped_json = json.dumps( + [ + { + "id": {"scope": "somescope", "id": "abcd-1234"}, + "cause": { + "code": 400, + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + }, } - } - ]) + ] + ) failed_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [], - 'output': { - 'failures': failures_escaped_json - } + "job_type": "batch_delete", + "status": "Success", + "tasks": [], + "output": {"failures": failures_escaped_json}, } session.set_responses(job_resp, failed_job_resp, job_resp, failed_job_resp) # When with pytest.deprecated_call(): - del_resp = project.gemd_batch_delete([uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = project.gemd_batch_delete( + [uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")] + ) # Then assert 2 == session.num_calls @@ -759,22 +767,31 @@ def test_project_batch_delete(project, session): assert len(del_resp) == 1 first_failure = del_resp[0] - expected_api_error = ApiError.build({ - "code": "400", - "message": "", - "validation_errors": [{"failure_message": "fail msg", "failure_id": "identifier.coreid.missing"}] - }) + expected_api_error = ApiError.build( + { + "code": "400", + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + } + ) - assert first_failure[0] == LinkByUID('somescope', 'abcd-1234') + assert first_failure[0] == LinkByUID("somescope", "abcd-1234") assert first_failure[1].dump() == expected_api_error.dump() # And again with tuples of (scope, id) with pytest.deprecated_call(): - del_resp = project.gemd_batch_delete([LinkByUID('id', '16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = project.gemd_batch_delete( + [LinkByUID("id", "16fd2706-8baf-433b-82eb-8c7fada847da")] + ) assert len(del_resp) == 1 first_failure = del_resp[0] - assert first_failure[0] == LinkByUID('somescope', 'abcd-1234') + assert first_failure[0] == LinkByUID("somescope", "abcd-1234") assert first_failure[1].dump() == expected_api_error.dump() @@ -792,22 +809,30 @@ def test_owned_dataset_ids(project, datasets): # Create a set of datasets in the project ids = {uuid.uuid4() for _ in range(5)} for d_id in ids: - dataset = Dataset(name=f"Test Dataset - {d_id}", summary="Test Dataset", description="Test Dataset") + dataset = Dataset( + name=f"Test Dataset - {d_id}", + summary="Test Dataset", + description="Test Dataset", + ) 
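        # Each register() call goes through the fake session, so it counts toward num_calls below.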
datasets.register(dataset)

    # Set the session response to have the list of dataset IDs
-    project.session.set_response({'ids': list(ids)})
+    project.session.set_response({"ids": list(ids)})

    # Fetch the list of dataset UUIDs owned by the current project
    with pytest.deprecated_call():
        owned_ids = project.owned_dataset_ids()

    # Let's mock our expected API call so we can compare and ensure that the one made is the same
-    expect_call = FakeCall(method='GET',
-                           path='/DATASET/authorized-ids',
-                           params={'userId': '',
-                                   'domain': '/projects/16fd2706-8baf-433b-82eb-8c7fada847da',
-                                   'action': 'WRITE'})
+    expect_call = FakeCall(
+        method="GET",
+        path="/DATASET/authorized-ids",
+        params={
+            "userId": "",
+            "domain": "/projects/16fd2706-8baf-433b-82eb-8c7fada847da",
+            "action": "WRITE",
+        },
+    )

    # Compare our calls
    assert expect_call == project.session.last_call
    assert project.session.num_calls == len(ids) + 1
diff --git a/tests/resources/test_project_member.py b/tests/resources/test_project_member.py
index 3eb4e3cdc..918841fa5 100644
--- a/tests/resources/test_project_member.py
+++ b/tests/resources/test_project_member.py
@@ -23,5 +23,6 @@ def project_member(user, project) -> ProjectMember:

 def test_string_representation(project_member):
-    assert project_member.__str__() == ""\
-        .format(project_member.user.screen_name, project_member.project.name)
+    assert project_member.__str__() == "".format(
+        project_member.user.screen_name, project_member.project.name
+    )
diff --git a/tests/resources/test_report.py b/tests/resources/test_report.py
index 5d913703f..e94028d9c 100644
--- a/tests/resources/test_report.py
+++ b/tests/resources/test_report.py
@@ -1,8 +1,8 @@
 """Tests getting a report"""
+
 import random
 import uuid

-import pytest
 from citrine.resources.report import ReportResource

@@ -12,16 +12,22 @@ def test_get_report():
    project_id = uuid.uuid4()
    predictor_id = uuid.uuid4()
-    report_path = f'/projects/{project_id}/predictors/{predictor_id}/versions/most_recent/report'
+    report_path = (
+        f"/projects/{project_id}/predictors/{predictor_id}/versions/most_recent/report"
+    )

    session = FakeSession()
-    session.set_response(dict(status='PENDING',
-                              report=dict(descriptors=[], models=[]),
-                              uid=str(str(uuid.uuid4()))))
+    session.set_response(
+        dict(
+            status="PENDING",
+            report=dict(descriptors=[], models=[]),
+            uid=str(str(uuid.uuid4())),
+        )
+    )

    report = ReportResource(project_id, session).get(predictor_id=predictor_id)

-    assert report.status == 'PENDING'
+    assert report.status == "PENDING"
    assert session.calls == [FakeCall(method="GET", path=report_path)]

@@ -29,14 +35,20 @@ def test_get_report_with_version():
    project_id = uuid.uuid4()
    predictor_id = uuid.uuid4()
    predictor_version = random.randint(1, 10)
-    report_path = f'/projects/{project_id}/predictors/{predictor_id}/versions/{predictor_version}/report'
+    report_path = f"/projects/{project_id}/predictors/{predictor_id}/versions/{predictor_version}/report"

    session = FakeSession()
-    session.set_response(dict(status='PENDING',
-                              report=dict(descriptors=[], models=[]),
-                              uid=str(str(uuid.uuid4()))))
-
-    report = ReportResource(project_id, session).get(predictor_id=predictor_id, predictor_version=predictor_version)
-
-    assert report.status == 'PENDING'
+    session.set_response(
+        dict(
+            status="PENDING",
+            report=dict(descriptors=[], models=[]),
+            uid=str(str(uuid.uuid4())),
+        )
+    )
+
+    report = ReportResource(project_id, session).get(
+        predictor_id=predictor_id, predictor_version=predictor_version
+    )
+
+    assert report.status == "PENDING"
    assert session.calls == 
[FakeCall(method="GET", path=report_path)] diff --git a/tests/resources/test_resources.py b/tests/resources/test_resources.py index 18a263f15..38c03be81 100644 --- a/tests/resources/test_resources.py +++ b/tests/resources/test_resources.py @@ -3,38 +3,71 @@ import pytest from gemd.entity.bounds.real_bounds import RealBounds -from citrine.resources.condition_template import ConditionTemplate, ConditionTemplateCollection +from citrine.resources.condition_template import ( + ConditionTemplate, + ConditionTemplateCollection, +) from citrine.resources.ingredient_run import IngredientRun, IngredientRunCollection from citrine.resources.ingredient_spec import IngredientSpec, IngredientSpecCollection from citrine.resources.material_spec import MaterialSpec, MaterialSpecCollection -from citrine.resources.material_template import MaterialTemplate, MaterialTemplateCollection +from citrine.resources.material_template import ( + MaterialTemplate, + MaterialTemplateCollection, +) from citrine.resources.measurement_run import MeasurementRun, MeasurementRunCollection -from citrine.resources.measurement_spec import MeasurementSpec, MeasurementSpecCollection -from citrine.resources.measurement_template import MeasurementTemplate, MeasurementTemplateCollection -from citrine.resources.parameter_template import ParameterTemplate, ParameterTemplateCollection +from citrine.resources.measurement_spec import ( + MeasurementSpec, + MeasurementSpecCollection, +) +from citrine.resources.measurement_template import ( + MeasurementTemplate, + MeasurementTemplateCollection, +) +from citrine.resources.parameter_template import ( + ParameterTemplate, + ParameterTemplateCollection, +) from citrine.resources.process_run import ProcessRun, ProcessRunCollection from citrine.resources.process_spec import ProcessSpec, ProcessSpecCollection -from citrine.resources.process_template import ProcessTemplate, ProcessTemplateCollection -from citrine.resources.property_template import PropertyTemplate, PropertyTemplateCollection +from citrine.resources.process_template import ( + ProcessTemplate, + ProcessTemplateCollection, +) +from citrine.resources.property_template import ( + PropertyTemplate, + PropertyTemplateCollection, +) from citrine.resources.response import Response arbitrary_uuid = uuid.uuid4() resource_string_data = [ (IngredientRun, {}, ""), - (IngredientSpec, {'name': 'foo'}, ""), - (MaterialSpec, {'name': 'foo'}, ""), - (MaterialTemplate, {'name': 'foo'}, ""), - (MeasurementRun, {'name': 'foo'}, ""), - (MeasurementSpec, {'name': 'foo'}, ""), - (MeasurementTemplate, {'name': 'foo'}, ""), - (ParameterTemplate, {'name': 'foo', 'bounds': RealBounds(0, 1, '')}, ""), - (ProcessRun, {'name': 'foo'}, ""), - (ProcessSpec, {'name': 'foo'}, ""), - (ProcessTemplate, {'name': 'foo'}, ""), - (PropertyTemplate, {'name': 'foo', 'bounds': RealBounds(0, 1, '')}, ""), - (ConditionTemplate, {'name': 'foo', 'bounds': RealBounds(0, 1, '')}, ""), - (Response, {'status_code': 200}, "") + (IngredientSpec, {"name": "foo"}, ""), + (MaterialSpec, {"name": "foo"}, ""), + (MaterialTemplate, {"name": "foo"}, ""), + (MeasurementRun, {"name": "foo"}, ""), + (MeasurementSpec, {"name": "foo"}, ""), + (MeasurementTemplate, {"name": "foo"}, ""), + ( + ParameterTemplate, + {"name": "foo", "bounds": RealBounds(0, 1, "")}, + "", + ), + (ProcessRun, {"name": "foo"}, ""), + (ProcessSpec, {"name": "foo"}, ""), + (ProcessTemplate, {"name": "foo"}, ""), + ( + PropertyTemplate, + {"name": "foo", "bounds": RealBounds(0, 1, "")}, + "", + ), + ( + ConditionTemplate, + 
{"name": "foo", "bounds": RealBounds(0, 1, "")}, + "", + ), + (Response, {"status_code": 200}, ""), ] resource_type_data = [ @@ -54,11 +87,11 @@ ] -@pytest.mark.parametrize('resource_type,kwargs,val', resource_string_data) +@pytest.mark.parametrize("resource_type,kwargs,val", resource_string_data) def test_str_representation(resource_type, kwargs, val): assert val == str(resource_type(**kwargs)) -@pytest.mark.parametrize('collection_type,resource_type', resource_type_data) +@pytest.mark.parametrize("collection_type,resource_type", resource_type_data) def test_collection_type(collection_type, resource_type): assert resource_type == collection_type.get_type() diff --git a/tests/resources/test_response.py b/tests/resources/test_response.py index 3fb28a63c..3350c605d 100644 --- a/tests/resources/test_response.py +++ b/tests/resources/test_response.py @@ -14,19 +14,9 @@ def test_empty_response_repr(): def test_empty_body_present_code(): """Tests that the repr output expresses the absence of body and presence of - status code correctly.""" + status code correctly.""" resp_with_code = Response(status_code=404) no_body_found = re.search("No body available", resp_with_code.__repr__()) status_code_found = re.search("404", resp_with_code.__repr__()) assert no_body_found assert status_code_found - - -def test_empty_body_present_code(): - """Tests that the repr output expresses the presence of body and presence of - status code correctly.""" - resp_with_code_and_body = Response(status_code=404, body={"message": "a quick message"}) - body_found = re.search("a quick message", resp_with_code_and_body.__repr__()) - status_code_found = re.search("404", resp_with_code_and_body.__repr__()) - assert body_found - assert status_code_found diff --git a/tests/resources/test_sample_design_space.py b/tests/resources/test_sample_design_space.py index 005b7aa33..b2a7110e1 100644 --- a/tests/resources/test_sample_design_space.py +++ b/tests/resources/test_sample_design_space.py @@ -3,8 +3,12 @@ from citrine.informatics.design_spaces.design_space import DesignSpace from citrine.informatics.design_spaces.sample_design_space import SampleDesignSpaceInput -from citrine.informatics.executions.sample_design_space_execution import SampleDesignSpaceExecution -from citrine.resources.sample_design_space_execution import SampleDesignSpaceExecutionCollection +from citrine.informatics.executions.sample_design_space_execution import ( + SampleDesignSpaceExecution, +) +from citrine.resources.sample_design_space_execution import ( + SampleDesignSpaceExecutionCollection, +) from tests.utils.session import FakeSession, FakeCall @@ -24,11 +28,13 @@ def collection(session) -> SampleDesignSpaceExecutionCollection: @pytest.fixture def design_space() -> DesignSpace: - return + return @pytest.fixture -def sample_design_space_execution(collection: SampleDesignSpaceExecutionCollection, sample_design_space_execution_dict) -> SampleDesignSpaceExecution: +def sample_design_space_execution( + collection: SampleDesignSpaceExecutionCollection, sample_design_space_execution_dict +) -> SampleDesignSpaceExecution: return collection.build(sample_design_space_execution_dict) @@ -43,48 +49,63 @@ def test_basic_methods(sample_design_space_execution, collection): def test_build_new_execution(collection, sample_design_space_execution_dict): - execution: SampleDesignSpaceExecution = collection.build(sample_design_space_execution_dict) + execution: SampleDesignSpaceExecution = collection.build( + sample_design_space_execution_dict + ) assert str(execution.uid) 
== sample_design_space_execution_dict["id"] assert execution.project_id == collection.project_id assert execution._session == collection.session - assert execution.in_progress() and not execution.succeeded() and not execution.failed() + assert ( + execution.in_progress() and not execution.succeeded() and not execution.failed() + ) -def test_trigger_execution(collection: SampleDesignSpaceExecutionCollection, sample_design_space_execution_dict, session): +def test_trigger_execution( + collection: SampleDesignSpaceExecutionCollection, + sample_design_space_execution_dict, + session, +): # Given session.set_response(sample_design_space_execution_dict) - sample_design_space_execution_input = SampleDesignSpaceInput( - n_candidates=10 - ) + sample_design_space_execution_input = SampleDesignSpaceInput(n_candidates=10) # When actual_execution = collection.trigger(sample_design_space_execution_input) # Then assert str(actual_execution.uid) == sample_design_space_execution_dict["id"] - expected_path = '/projects/{}/design-spaces/{}/sample'.format( + expected_path = "/projects/{}/design-spaces/{}/sample".format( collection.project_id, collection.design_space_id ) assert session.last_call == FakeCall( - method='POST', + method="POST", path=expected_path, json={ - 'n_candidates': sample_design_space_execution_input.n_candidates, - } + "n_candidates": sample_design_space_execution_input.n_candidates, + }, ) def test_execution_completes(): data_success = { - 'id': str(uuid.uuid4()), - 'status': {'major': 'SUCCEEDED', 'minor': 'COMPLETED', 'detail': [], 'info': []}, + "id": str(uuid.uuid4()), + "status": { + "major": "SUCCEEDED", + "minor": "COMPLETED", + "detail": [], + "info": [], + }, } execution_success = SampleDesignSpaceExecution.build(data_success) assert execution_success.succeeded() -def test_sample_design_space_execution_results(sample_design_space_execution: SampleDesignSpaceExecution, session, example_sample_design_space_response): +def test_sample_design_space_execution_results( + sample_design_space_execution: SampleDesignSpaceExecution, + session, + example_sample_design_space_response, +): # Given session.set_response(example_sample_design_space_response) @@ -92,30 +113,36 @@ def test_sample_design_space_execution_results(sample_design_space_execution: Sa list(sample_design_space_execution.results(per_page=4)) # Then - expected_path = '/projects/{}/design-spaces/{}/sample/{}/results'.format( + expected_path = "/projects/{}/design-spaces/{}/sample/{}/results".format( sample_design_space_execution.project_id, sample_design_space_execution.design_space_id, sample_design_space_execution.uid, ) - assert session.last_call == FakeCall(method='GET', path=expected_path, params={"page": 1, "per_page": 4}) + assert session.last_call == FakeCall( + method="GET", path=expected_path, params={"page": 1, "per_page": 4} + ) -def test_sample_design_space_execution_result(sample_design_space_execution: SampleDesignSpaceExecution, session, example_sample_design_space_response): +def test_sample_design_space_execution_result( + sample_design_space_execution: SampleDesignSpaceExecution, + session, + example_sample_design_space_response, +): # Given session.set_response(example_sample_design_space_response["response"][0]) # When - result_id=example_sample_design_space_response["response"][0]["id"] + result_id = example_sample_design_space_response["response"][0]["id"] sample_design_space_execution.result(result_id=result_id) # Then - expected_path = '/projects/{}/design-spaces/{}/sample/{}/results/{}'.format( + 
expected_path = "/projects/{}/design-spaces/{}/sample/{}/results/{}".format( sample_design_space_execution.project_id, sample_design_space_execution.design_space_id, sample_design_space_execution.uid, result_id, ) - assert session.last_call == FakeCall(method='GET', path=expected_path) + assert session.last_call == FakeCall(method="GET", path=expected_path) def test_list(collection: SampleDesignSpaceExecutionCollection, session): @@ -123,14 +150,12 @@ def test_list(collection: SampleDesignSpaceExecutionCollection, session): lst = list(collection.list(per_page=4)) assert len(lst) == 0 - expected_path = '/projects/{}/design-spaces/{}/sample'.format( + expected_path = "/projects/{}/design-spaces/{}/sample".format( collection.project_id, collection.design_space_id, ) assert session.last_call == FakeCall( - method='GET', - path=expected_path, - params={"page": 1, "per_page": 4} + method="GET", path=expected_path, params={"page": 1, "per_page": 4} ) diff --git a/tests/resources/test_table_config.py b/tests/resources/test_table_config.py index 7be645f05..1dbdaf34c 100644 --- a/tests/resources/test_table_config.py +++ b/tests/resources/test_table_config.py @@ -1,26 +1,48 @@ from uuid import UUID, uuid4 -import pytest +import pytest from gemd.entity.link_by_uid import LinkByUID + from citrine.gemd_queries.gemd_query import GemdQuery -from citrine.gemtables.columns import MeanColumn, OriginalUnitsColumn, StdDevColumn, IdentityColumn +from citrine.gemtables.columns import ( + IdentityColumn, + MeanColumn, + OriginalUnitsColumn, + StdDevColumn, +) from citrine.gemtables.rows import MaterialRunByTemplate -from citrine.gemtables.variables import AttributeByTemplate, TerminalMaterialInfo, \ - IngredientQuantityDimension, IngredientQuantityByProcessAndName, \ - IngredientIdentifierByProcessTemplateAndName, TerminalMaterialIdentifier, \ - IngredientQuantityInOutput, IngredientIdentifierInOutput, \ - IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput -from citrine.resources.table_config import TableConfig, TableConfigCollection, TableBuildAlgorithm, \ - TableFromGemdQueryAlgorithm +from citrine.gemtables.variables import ( + AttributeByTemplate, + IngredientIdentifierByProcessTemplateAndName, + IngredientIdentifierInOutput, + IngredientLabelsSetByProcessAndName, + IngredientLabelsSetInOutput, + IngredientQuantityByProcessAndName, + IngredientQuantityDimension, + IngredientQuantityInOutput, + TerminalMaterialIdentifier, + TerminalMaterialInfo, +) from citrine.resources.data_concepts import CITRINE_SCOPE from citrine.resources.material_run import MaterialRun -from citrine.resources.project import Project from citrine.resources.process_template import ProcessTemplate +from citrine.resources.project import Project +from citrine.resources.table_config import ( + TableBuildAlgorithm, + TableConfig, + TableConfigCollection, + TableFromGemdQueryAlgorithm, +) from citrine.resources.team import Team from citrine.seeding.find_or_create import create_or_update -from tests.utils.factories import TableConfigResponseDataFactory, ListTableConfigResponseDataFactory, \ - GemdQueryDataFactory, TableConfigDataFactory, DatasetDataFactory -from tests.utils.session import FakeSession, FakeCall +from tests.utils.factories import ( + DatasetDataFactory, + GemdQueryDataFactory, + ListTableConfigResponseDataFactory, + TableConfigDataFactory, + TableConfigResponseDataFactory, +) +from tests.utils.session import FakeCall, FakeSession @pytest.fixture @@ -30,60 +52,56 @@ def session() -> FakeSession: @pytest.fixture def 
team(session) -> Team: - team = Team(name='Test Team', session=session) - team.uid = UUID('16fd2706-8baf-433b-82eb-8c7fada847da') + team = Team(name="Test Team", session=session) + team.uid = UUID("16fd2706-8baf-433b-82eb-8c7fada847da") return team @pytest.fixture def project(session, team) -> Project: - project = Project( - name="Test GEM Table project", - session=session, - team_id=team.uid - ) - project.uid = UUID('6b608f78-e341-422c-8076-35adc8828545') - session.set_response({ - 'project': { - 'team': { - 'id': str(team.uid) - } - } - }) + project = Project(name="Test GEM Table project", session=session, team_id=team.uid) + project.uid = UUID("6b608f78-e341-422c-8076-35adc8828545") + session.set_response({"project": {"team": {"id": str(team.uid)}}}) return project @pytest.fixture def collection(session) -> TableConfigCollection: return TableConfigCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - project_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + session=session, ) def empty_defn() -> TableConfig: - return TableConfig(name="empty", description="empty", datasets=[], rows=[], variables=[], columns=[]) + return TableConfig( + name="empty", + description="empty", + datasets=[], + rows=[], + variables=[], + columns=[], + ) def test_deprecation_of_positional_arguments(session): with pytest.deprecated_call(): TableConfigCollection( - UUID('6b608f78-e341-422c-8076-35adc8828545'), + UUID("6b608f78-e341-422c-8076-35adc8828545"), session, - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), ) with pytest.raises(TypeError): TableConfigCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=session + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=session ) with pytest.raises(TypeError): TableConfigCollection( - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - project_id=UUID('6b608f78-e341-422c-8076-35adc8828545') + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + project_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), ) @@ -103,7 +121,7 @@ def test_get_table_config(collection, session): assert 1 == session.num_calls expect_call = FakeCall( method="GET", - path=collection._get_path(defn_id, action=["versions", ver_number]) + path=collection._get_path(defn_id, action=["versions", ver_number]), ) assert session.last_call == expect_call assert str(retrieved_table_config.config_uid) == defn_id @@ -112,7 +130,12 @@ def test_get_table_config(collection, session): # Given table_configs_response = ListTableConfigResponseDataFactory() defn_id = table_configs_response["definition"]["id"] - version_number = max([version_dict["version_number"] for version_dict in table_configs_response["versions"]]) + version_number = max( + [ + version_dict["version_number"] + for version_dict in table_configs_response["versions"] + ] + ) session.set_response(table_configs_response) # When @@ -120,10 +143,7 @@ def test_get_table_config(collection, session): # Then assert 2 == session.num_calls - expect_call = FakeCall( - method="GET", - path=collection._get_path(defn_id) - ) + expect_call = FakeCall(method="GET", path=collection._get_path(defn_id)) assert session.last_call == expect_call assert str(retrieved_table_config.config_uid) == defn_id assert retrieved_table_config.version_number == version_number @@ -136,15 +156,23 @@ def 
test_get_table_config_raises(collection): def test_init_table_config(): - table_config = TableConfig(name="foo", description="bar", rows=[], columns=[], variables=[], datasets=[]) + table_config = TableConfig( + name="foo", description="bar", rows=[], columns=[], variables=[], datasets=[] + ) assert table_config.config_uid is None assert table_config.version_number is None def test_uid_aliases_config_uid(): """Test that uid returns config_uid attribute""" - table_config = TableConfig(name="name", description="description", datasets=[], rows=[], variables=[], - columns=[]) + table_config = TableConfig( + name="name", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], + ) table_config.config_uid = uuid4() assert table_config.uid == table_config.config_uid @@ -156,23 +184,26 @@ def test_uid_aliases_config_uid(): def test_create_or_update_config(collection, session): initial_config = TableConfig( - name="Test Config", description="description", datasets=[], rows=[], variables=[], columns=[] + name="Test Config", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], ) # Fake table config data response retrieved_config_response = TableConfigResponseDataFactory() retrieved_config_response["definition"]["name"] = "Test Config" retrieved_id = retrieved_config_response["definition"]["id"] - retrieved_version = retrieved_config_response["version"]["version_number"] + session.set_responses( - {'definitions': [retrieved_config_response]}, - retrieved_config_response + {"definitions": [retrieved_config_response]}, retrieved_config_response ) # Create or update with mocked list, return just fake response updated_table_config = create_or_update( - collection=collection, - resource=initial_config + collection=collection, resource=initial_config ) # Updated config should have UID set from response data @@ -184,22 +215,34 @@ def test_dup_names(): """Make sure that variable name and headers are unique across a table config""" with pytest.raises(ValueError) as excinfo: TableConfig( - name="foo", description="bar", datasets=[], rows=[], columns=[], + name="foo", + description="bar", + datasets=[], + rows=[], + columns=[], variables=[ TerminalMaterialInfo(name="foo", headers=["foo", "bar"], field="name"), - TerminalMaterialInfo(name="foo", headers=["foo", "baz"], field="name") - ] + TerminalMaterialInfo(name="foo", headers=["foo", "baz"], field="name"), + ], ) assert "Multiple" in str(excinfo.value) assert "foo" in str(excinfo.value) with pytest.raises(ValueError) as excinfo: TableConfig( - name="foo", description="bar", datasets=[], rows=[], columns=[], + name="foo", + description="bar", + datasets=[], + rows=[], + columns=[], variables=[ - TerminalMaterialInfo(name="foo", headers=["spam", "eggs"], field="name"), - TerminalMaterialInfo(name="bar", headers=["spam", "eggs"], field="name") - ] + TerminalMaterialInfo( + name="foo", headers=["spam", "eggs"], field="name" + ), + TerminalMaterialInfo( + name="bar", headers=["spam", "eggs"], field="name" + ), + ], ) assert "Multiple" in str(excinfo.value) assert "spam" in str(excinfo.value) @@ -209,10 +252,12 @@ def test_missing_variable(): """Make sure that every data_source matches a name of a variable""" with pytest.raises(ValueError) as excinfo: TableConfig( - name="foo", description="bar", datasets=[], rows=[], variables=[], - columns=[ - MeanColumn(data_source="density") - ] + name="foo", + description="bar", + datasets=[], + rows=[], + variables=[], + columns=[MeanColumn(data_source="density")], ) 
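    # The ValueError message should name the unmatched data_source.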
assert "must match" in str(excinfo.value) assert "density" in str(excinfo.value) @@ -222,19 +267,21 @@ def test_dump_example(): density = AttributeByTemplate( name="density", headers=["Slice", "Density"], - template=LinkByUID(scope="templates", id="density") + template=LinkByUID(scope="templates", id="density"), ) - table_config = TableConfig( + TableConfig( name="Example Table", description="Illustrative example that's meant to show how Table Configs will look serialized", datasets=[uuid4()], variables=[density], - rows=[MaterialRunByTemplate(templates=[LinkByUID(scope="templates", id="slices")])], + rows=[ + MaterialRunByTemplate(templates=[LinkByUID(scope="templates", id="slices")]) + ], columns=[ MeanColumn(data_source=density.name), StdDevColumn(data_source=density.name), OriginalUnitsColumn(data_source=density.name), - ] + ], ) @@ -246,7 +293,7 @@ def test_preview(collection, session): expect_call = FakeCall( method="POST", path=f"teams/{collection.team_id}/ara-definitions/preview", - json={"definition": empty_defn().dump(), "rows": []} + json={"definition": empty_defn().dump(), "rows": []}, ) assert session.last_call == expect_call @@ -255,64 +302,63 @@ def test_default_for_material(collection: TableConfigCollection, session): """Test that default for material hits the right route""" # Given dummy_resp = { - 'config': TableConfigDataFactory(), - 'ambiguous': [ + "config": TableConfigDataFactory(), + "ambiguous": [ [ - TerminalMaterialIdentifier(name='foo', headers=['foo'], scope='id').dump(), - IdentityColumn(data_source='foo').dump(), + TerminalMaterialIdentifier( + name="foo", headers=["foo"], scope="id" + ).dump(), + IdentityColumn(data_source="foo").dump(), ] ], } session.responses.append(dummy_resp) collection.default_for_material( - material='my_id', - name='my_name', - description='my_description', - algorithm=TableBuildAlgorithm.SINGLE_ROW + material="my_id", + name="my_name", + description="my_description", + algorithm=TableBuildAlgorithm.SINGLE_ROW, ) assert 1 == session.num_calls assert session.last_call == FakeCall( method="GET", path=f"teams/{collection.team_id}/table-configs/default", params={ - 'id': 'my_id', - 'scope': CITRINE_SCOPE, - 'algorithm': TableBuildAlgorithm.SINGLE_ROW.value, - 'name': 'my_name', - 'description': 'my_description' - } + "id": "my_id", + "scope": CITRINE_SCOPE, + "algorithm": TableBuildAlgorithm.SINGLE_ROW.value, + "name": "my_name", + "description": "my_description", + }, ) # We allowed for the more forgiving call structure, so test it. 
session.calls.clear()
    session.responses.append(dummy_resp)
    collection.default_for_material(
-        material=MaterialRun('foo', uids={'scope': 'id'}),
+        material=MaterialRun("foo", uids={"scope": "id"}),
        algorithm=TableBuildAlgorithm.FORMULATIONS.value,
-        name='my_name',
-        description='my_description',
+        name="my_name",
+        description="my_description",
    )
    assert 1 == session.num_calls
    assert session.last_call == FakeCall(
        method="GET",
        path=f"teams/{collection.team_id}/table-configs/default",
        params={
-            'id': 'id',
-            'scope': 'scope',
-            'algorithm': TableBuildAlgorithm.FORMULATIONS.value,
-            'name': 'my_name',
-            'description': 'my_description'
-        }
+            "id": "id",
+            "scope": "scope",
+            "algorithm": TableBuildAlgorithm.FORMULATIONS.value,
+            "name": "my_name",
+            "description": "my_description",
+        },
    )


def test_default_for_material_failure(collection: TableConfigCollection):
    with pytest.raises(ValueError):
-        collection.default_for_material(
-            material=MaterialRun('foo'),
-            name='foo'
-        )
+        collection.default_for_material(material=MaterialRun("foo"), name="foo")


def test_from_query(collection: TableConfigCollection, session):
@@ -321,31 +367,33 @@ def test_from_query(collection: TableConfigCollection, session):
    """Test that from_query hits the right route"""
    query = GemdQueryDataFactory()
    config = TableConfigDataFactory()

    config_resp = {
        'config': config,
        'ambiguous': [
            [
-                TerminalMaterialIdentifier(name='foo', headers=['foo'], scope='id').dump(),
-                IdentityColumn(data_source='foo').dump(),
+                TerminalMaterialIdentifier(
+                    name="foo", headers=["foo"], scope="id"
+                ).dump(),
+                IdentityColumn(data_source="foo").dump(),
            ]
        ],
    }
    session.responses.append(config_resp)

    fake_call = FakeCall(
-        method='POST',
-        path=f'teams/{collection.team_id}/table-configs/from-query',
+        method="POST",
+        path=f"teams/{collection.team_id}/table-configs/from-query",
        params={
-            'name': config['name'],
-            'description': config['description'],
-            'algorithm': TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS,
+            "name": config["name"],
+            "description": config["description"],
+            "algorithm": TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS,
        },
        json=query,
    )

    collection.from_query(
-        name=config['name'],
-        description=config['description'],
+        name=config["name"],
+        description=config["description"],
        gemd_query=GemdQuery.build(query),
-        algorithm=TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS
+        algorithm=TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS,
    )
    assert 1 == session.num_calls
    assert session.last_call.method == fake_call.method
@@ -362,15 +410,19 @@ def test_from_nameless_query_and_register(collection: TableConfigCollection, session):
    """Test that from_query works without a name and registers the config"""
    query = GemdQueryDataFactory()
-    config = TableConfigDataFactory(generation_algorithm=TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS)
+    config = TableConfigDataFactory(
+        generation_algorithm=TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS
+    )

-    dataset_resps = [DatasetDataFactory(id=dataset) for dataset in query['datasets']]
+    dataset_resps = [DatasetDataFactory(id=dataset) for dataset in query["datasets"]]

    config_resp = {
        'config': config,
        'ambiguous': [
            [
-                TerminalMaterialIdentifier(name='foo', headers=['foo'], scope='id').dump(),
-                IdentityColumn(data_source='foo').dump(),
+                TerminalMaterialIdentifier(
+                    name="foo", headers=["foo"], scope="id"
+                ).dump(),
+                IdentityColumn(data_source="foo").dump(),
            ]
        ],
    }
@@ -382,10 +434,10 @@ def 
test_from_nameless_query_and_register(collection: TableConfigCollection, ses generated, _ = collection.from_query( gemd_query=GemdQuery.build(query), - description='my_description', - register_config=True + description="my_description", + register_config=True, ) - assert session.num_calls == len(query['datasets']) + 1 + 1 + assert session.num_calls == len(query["datasets"]) + 1 + 1 assert generated != TableConfig.build(config) # Because it has ids assert generated.variables == TableConfig.build(config).variables @@ -399,16 +451,18 @@ def test_add_columns(): with pytest.raises(ValueError) as excinfo: empty.add_columns( variable=TerminalMaterialInfo(name="foo", headers=["bar"], field="name"), - columns=[IdentityColumn(data_source="bar")] + columns=[IdentityColumn(data_source="bar")], ) assert "data_source must be" in str(excinfo.value) # Check desired behavior with_name_col = empty.add_columns( variable=TerminalMaterialInfo(name="name", headers=["bar"], field="name"), - columns=[IdentityColumn(data_source="name")] + columns=[IdentityColumn(data_source="name")], ) - assert with_name_col.variables == [TerminalMaterialInfo(name="name", headers=["bar"], field="name")] + assert with_name_col.variables == [ + TerminalMaterialInfo(name="name", headers=["bar"], field="name") + ] assert with_name_col.columns == [IdentityColumn(data_source="name")] assert with_name_col.config_uid == empty.config_uid @@ -416,7 +470,7 @@ def test_add_columns(): with pytest.raises(ValueError) as excinfo: with_name_col.add_columns( variable=TerminalMaterialInfo(name="name", headers=["bar"], field="name"), - columns=[IdentityColumn(data_source="name")] + columns=[IdentityColumn(data_source="name")], ) assert "already used" in str(excinfo.value) @@ -424,171 +478,281 @@ def test_add_columns(): def test_add_all_ingredients_via_project_deprecated(session, project): """Test the behavior of AraDefinition.add_all_ingredients.""" # GIVEN - process_id = '3a308f78-e341-f39c-8076-35a2c88292ad' - process_name = 'mixing' + process_id = "3a308f78-e341-f39c-8076-35a2c88292ad" + process_name = "mixing" allowed_names = ["gold nanoparticles", "methanol", "acetone"] - process_link = LinkByUID('id', process_id) + process_link = LinkByUID("id", process_id) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients in a volume basis empty = empty_defn() with pytest.deprecated_call(): - def1 = empty.add_all_ingredients(process_template=process_link, project=project, - quantity_dimension=IngredientQuantityDimension.VOLUME) + def1 = empty.add_all_ingredients( + process_template=process_link, + project=project, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def1.config_uid = uuid4() # THEN there should be 3 variables and columns for each name, one for id, quantity, and labels assert len(def1.variables) == len(allowed_names) * 3 assert len(def1.columns) == len(def1.variables) for name in allowed_names: - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientQuantityByProcessAndName)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientIdentifierByProcessTemplateAndName)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientLabelsSetByProcessAndName)), None) is not None + assert 
( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientQuantityByProcessAndName) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientIdentifierByProcessTemplateAndName) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientLabelsSetByProcessAndName) + ), + None, + ) + is not None + ) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients to the same Table Config as absolute quantities with pytest.deprecated_call(): - def2 = def1.add_all_ingredients(process_template=process_link, project=project, - quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg') + def2 = def1.add_all_ingredients( + process_template=process_link, + project=project, + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, + unit="kg", + ) # THEN there should be 1 new variable for each name, corresponding to the quantity # There is already a variable for id and labels # There should be 2 new columns for each name, one for the quantity and one for the units - new_variables = def2.variables[len(def1.variables):] - new_columns = def2.columns[len(def1.columns):] + new_variables = def2.variables[len(def1.variables) :] + new_columns = def2.columns[len(def1.columns) :] assert len(new_variables) == len(allowed_names) assert len(new_columns) == len(allowed_names) * 2 assert def2.config_uid == def1.config_uid for name in allowed_names: - assert next((var for var in new_variables if name in var.headers - and isinstance(var, IngredientQuantityByProcessAndName)), None) is not None + assert ( + next( + ( + var + for var in new_variables + if name in var.headers + and isinstance(var, IngredientQuantityByProcessAndName) + ), + None, + ) + is not None + ) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients to the same Table Config in a volume basis # THEN it raises an exception because these variables and columns already exist with pytest.deprecated_call(): with pytest.raises(ValueError): - def2.add_all_ingredients(process_template=process_link, project=project, - quantity_dimension=IngredientQuantityDimension.VOLUME) + def2.add_all_ingredients( + process_template=process_link, + project=project, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) # If the process template has an empty allowed_names list then an error should be raised - session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}).dump() - ) + session.set_response(ProcessTemplate(process_name, uids={"id": process_id}).dump()) with pytest.deprecated_call(): with pytest.raises(RuntimeError): - empty_defn().add_all_ingredients(process_template=process_link, project=project, - quantity_dimension=IngredientQuantityDimension.VOLUME) + empty_defn().add_all_ingredients( + process_template=process_link, + project=project, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def test_add_all_ingredients_via_team(session, team): """Test the behavior of AraDefinition.add_all_ingredients.""" # GIVEN - process_id = 
'3a308f78-e341-f39c-8076-35a2c88292ad' - process_name = 'mixing' + process_id = "3a308f78-e341-f39c-8076-35a2c88292ad" + process_name = "mixing" allowed_names = ["gold nanoparticles", "methanol", "acetone"] - process_link = LinkByUID('id', process_id) + process_link = LinkByUID("id", process_id) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients in a volume basis empty = empty_defn() - def1 = empty.add_all_ingredients(process_template=process_link, team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME) + def1 = empty.add_all_ingredients( + process_template=process_link, + team=team, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def1.config_uid = uuid4() # THEN there should be 3 variables and columns for each name, one for id, quantity, and labels assert len(def1.variables) == len(allowed_names) * 3 assert len(def1.columns) == len(def1.variables) for name in allowed_names: - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientQuantityByProcessAndName)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientIdentifierByProcessTemplateAndName)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientLabelsSetByProcessAndName)), None) is not None + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientQuantityByProcessAndName) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientIdentifierByProcessTemplateAndName) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientLabelsSetByProcessAndName) + ), + None, + ) + is not None + ) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients to the same Table Config as absolute quantities - def2 = def1.add_all_ingredients(process_template=process_link, team=team, - quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg') + def2 = def1.add_all_ingredients( + process_template=process_link, + team=team, + quantity_dimension=IngredientQuantityDimension.ABSOLUTE, + unit="kg", + ) # THEN there should be 1 new variable for each name, corresponding to the quantity # There is already a variable for id and labels # There should be 2 new columns for each name, one for the quantity and one for the units - new_variables = def2.variables[len(def1.variables):] - new_columns = def2.columns[len(def1.columns):] + new_variables = def2.variables[len(def1.variables) :] + new_columns = def2.columns[len(def1.columns) :] assert len(new_variables) == len(allowed_names) assert len(new_columns) == len(allowed_names) * 2 assert def2.config_uid == def1.config_uid for name in allowed_names: - assert next((var for var in new_variables if name in var.headers - and isinstance(var, IngredientQuantityByProcessAndName)), None) is not None + assert ( + next( + ( + var + for var in new_variables + if name in var.headers + and isinstance(var, 
IngredientQuantityByProcessAndName) + ), + None, + ) + is not None + ) session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}, allowed_names=allowed_names).dump() + ProcessTemplate( + process_name, uids={"id": process_id}, allowed_names=allowed_names + ).dump() ) # WHEN we add all ingredients to the same Table Config in a volume basis # THEN it raises an exception because these variables and columns already exist with pytest.raises(ValueError): - def2.add_all_ingredients(process_template=process_link, team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME) + def2.add_all_ingredients( + process_template=process_link, + team=team, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) # If the process template has an empty allowed_names list then an error should be raised - session.set_response( - ProcessTemplate(process_name, uids={'id': process_id}).dump() - ) + session.set_response(ProcessTemplate(process_name, uids={"id": process_id}).dump()) with pytest.raises(RuntimeError): - empty_defn().add_all_ingredients(process_template=process_link, team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME) + empty_defn().add_all_ingredients( + process_template=process_link, + team=team, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def test_add_all_ingredients_no_principal(session): """Test the behavior of AraDefinition.add_all_ingredients.""" - process_link = LinkByUID('id', '3a308f78-e341-f39c-8076-35a2c88292ad') + process_link = LinkByUID("id", "3a308f78-e341-f39c-8076-35a2c88292ad") with pytest.raises(TypeError): - empty_defn().add_all_ingredients(process_template=process_link, - quantity_dimension=IngredientQuantityDimension.VOLUME) + empty_defn().add_all_ingredients( + process_template=process_link, + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def test_add_all_ingredients_in_output_via_project_deprecated(session, project): """Test the behavior of TableConfig.add_all_ingredients_in_output.""" # GIVEN - process1_id = '3a308f78-e341-f39c-8076-35a2c88292ad' - process1_name = 'mixing' + process1_id = "3a308f78-e341-f39c-8076-35a2c88292ad" + process1_name = "mixing" allowed_names1 = ["gold nanoparticles", "methanol", "acetone"] - process1_link = LinkByUID('id', process1_id) + process1_link = LinkByUID("id", process1_id) - process2_id = '519ab440-fbda-4768-ad63-5e09b420285c' - process2_name = 'solvent_mixing' + process2_id = "519ab440-fbda-4768-ad63-5e09b420285c" + process2_name = "solvent_mixing" allowed_names2 = ["methanol", "acetone", "ethanol", "water"] - process2_link = LinkByUID('id', process2_id) + process2_link = LinkByUID("id", process2_id) union_allowed_names = list(set(allowed_names1) | set(allowed_names2)) session.set_responses( ProcessTemplate( - process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients in a volume basis @@ -597,7 +761,7 @@ def test_add_all_ingredients_in_output_via_project_deprecated(session, project): def1 = empty.add_all_ingredients_in_output( process_templates=[process1_link, process2_link], project=project, - quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) def1.config_uid = uuid4() @@ -605,24 +769,50 @@ def 
test_add_all_ingredients_in_output_via_project_deprecated(session, project): assert len(def1.variables) == len(union_allowed_names) * 3 assert len(def1.columns) == len(def1.variables) for name in union_allowed_names: - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientQuantityInOutput)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientIdentifierInOutput)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientLabelsSetInOutput)), None) is not None + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientQuantityInOutput) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientIdentifierInOutput) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientLabelsSetInOutput) + ), + None, + ) + is not None + ) session.set_responses( ProcessTemplate( - process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients to the same Table Config as absolute quantities with pytest.deprecated_call(): @@ -630,31 +820,37 @@ def test_add_all_ingredients_in_output_via_project_deprecated(session, project): process_templates=[process1_link, process2_link], project=project, quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg' + unit="kg", ) # THEN there should be 1 new variable for each name, corresponding to the quantity # There is already a variable for id and labels # There should be 2 new columns for each name, one for the quantity and one for the units - new_variables = def2.variables[len(def1.variables):] - new_columns = def2.columns[len(def1.columns):] + new_variables = def2.variables[len(def1.variables) :] + new_columns = def2.columns[len(def1.columns) :] assert len(new_variables) == len(union_allowed_names) assert len(new_columns) == len(union_allowed_names) * 2 assert def2.config_uid == def1.config_uid for name in union_allowed_names: - assert next((var for var in new_variables if name in var.headers - and isinstance(var, IngredientQuantityInOutput)), None) is not None + assert ( + next( + ( + var + for var in new_variables + if name in var.headers + and isinstance(var, IngredientQuantityInOutput) + ), + None, + ) + is not None + ) session.set_responses( ProcessTemplate( - process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients to the same Table Config in a volume basis # THEN it raises an exception because these variables and columns already exist @@ -663,55 +859,51 @@ def test_add_all_ingredients_in_output_via_project_deprecated(session, project): def2.add_all_ingredients_in_output( process_templates=[process1_link, process2_link], project=project, 
- quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) # If the process template has an empty allowed_names list then an error should be raised session.set_responses( ProcessTemplate( process1_name, - uids={'id': process1_id}, + uids={"id": process1_id}, ).dump(), ProcessTemplate( process2_name, - uids={'id': process2_id}, - ).dump() + uids={"id": process2_id}, + ).dump(), ) with pytest.deprecated_call(): with pytest.raises(RuntimeError): empty_defn().add_all_ingredients_in_output( process_templates=[process1_link, process2_link], project=project, - quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) def test_add_all_ingredients_in_output_via_team(session, team): """Test the behavior of TableConfig.add_all_ingredients_in_output.""" # GIVEN - process1_id = '3a308f78-e341-f39c-8076-35a2c88292ad' - process1_name = 'mixing' + process1_id = "3a308f78-e341-f39c-8076-35a2c88292ad" + process1_name = "mixing" allowed_names1 = ["gold nanoparticles", "methanol", "acetone"] - process1_link = LinkByUID('id', process1_id) + process1_link = LinkByUID("id", process1_id) - process2_id = '519ab440-fbda-4768-ad63-5e09b420285c' - process2_name = 'solvent_mixing' + process2_id = "519ab440-fbda-4768-ad63-5e09b420285c" + process2_name = "solvent_mixing" allowed_names2 = ["methanol", "acetone", "ethanol", "water"] - process2_link = LinkByUID('id', process2_id) + process2_link = LinkByUID("id", process2_id) union_allowed_names = list(set(allowed_names1) | set(allowed_names2)) session.set_responses( ProcessTemplate( - process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients in a volume basis @@ -719,7 +911,7 @@ def test_add_all_ingredients_in_output_via_team(session, team): def1 = empty.add_all_ingredients_in_output( process_templates=[process1_link, process2_link], team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) def1.config_uid = uuid4() @@ -727,55 +919,87 @@ def test_add_all_ingredients_in_output_via_team(session, team): assert len(def1.variables) == len(union_allowed_names) * 3 assert len(def1.columns) == len(def1.variables) for name in union_allowed_names: - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientQuantityInOutput)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientIdentifierInOutput)), None) is not None - assert next((var for var in def1.variables if name in var.headers - and isinstance(var, IngredientLabelsSetInOutput)), None) is not None + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientQuantityInOutput) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientIdentifierInOutput) + ), + None, + ) + is not None + ) + assert ( + next( + ( + var + for var in def1.variables + if name in var.headers + and isinstance(var, IngredientLabelsSetInOutput) + ), + None, + ) + is not None + ) session.set_responses( ProcessTemplate( - 
process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients to the same Table Config as absolute quantities def2 = def1.add_all_ingredients_in_output( process_templates=[process1_link, process2_link], team=team, quantity_dimension=IngredientQuantityDimension.ABSOLUTE, - unit='kg' + unit="kg", ) # THEN there should be 1 new variable for each name, corresponding to the quantity # There is already a variable for id and labels # There should be 2 new columns for each name, one for the quantity and one for the units - new_variables = def2.variables[len(def1.variables):] - new_columns = def2.columns[len(def1.columns):] + new_variables = def2.variables[len(def1.variables) :] + new_columns = def2.columns[len(def1.columns) :] assert len(new_variables) == len(union_allowed_names) assert len(new_columns) == len(union_allowed_names) * 2 assert def2.config_uid == def1.config_uid for name in union_allowed_names: - assert next((var for var in new_variables if name in var.headers - and isinstance(var, IngredientQuantityInOutput)), None) is not None + assert ( + next( + ( + var + for var in new_variables + if name in var.headers + and isinstance(var, IngredientQuantityInOutput) + ), + None, + ) + is not None + ) session.set_responses( ProcessTemplate( - process1_name, - uids={'id': process1_id}, - allowed_names=allowed_names1 + process1_name, uids={"id": process1_id}, allowed_names=allowed_names1 ).dump(), ProcessTemplate( - process2_name, - uids={'id': process2_id}, - allowed_names=allowed_names2 - ).dump() + process2_name, uids={"id": process2_id}, allowed_names=allowed_names2 + ).dump(), ) # WHEN we add all ingredients to the same Table Config in a volume basis # THEN it raises an exception because these variables and columns already exist @@ -783,42 +1007,51 @@ def test_add_all_ingredients_in_output_via_team(session, team): def2.add_all_ingredients_in_output( process_templates=[process1_link, process2_link], team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) # If the process template has an empty allowed_names list then an error should be raised session.set_responses( ProcessTemplate( process1_name, - uids={'id': process1_id}, + uids={"id": process1_id}, ).dump(), ProcessTemplate( process2_name, - uids={'id': process2_id}, - ).dump() + uids={"id": process2_id}, + ).dump(), ) with pytest.raises(RuntimeError): empty_defn().add_all_ingredients_in_output( process_templates=[process1_link, process2_link], team=team, - quantity_dimension=IngredientQuantityDimension.VOLUME + quantity_dimension=IngredientQuantityDimension.VOLUME, ) def test_add_all_ingredients_in_output_no_principal(session): """Test the behavior of AraDefinition.add_all_ingredients.""" - process_link1 = LinkByUID('id', '3a308f78-e341-f39c-8076-35a2c88292ad') - process_link2 = LinkByUID('id', '519ab440-fbda-4768-ad63-5e09b420285c') + process_link1 = LinkByUID("id", "3a308f78-e341-f39c-8076-35a2c88292ad") + process_link2 = LinkByUID("id", "519ab440-fbda-4768-ad63-5e09b420285c") with pytest.raises(TypeError): - empty_defn().add_all_ingredients_in_output(process_templates=[process_link1, process_link2], - quantity_dimension=IngredientQuantityDimension.VOLUME) + 
empty_defn().add_all_ingredients_in_output( + process_templates=[process_link1, process_link2], + quantity_dimension=IngredientQuantityDimension.VOLUME, + ) def test_register_new(collection, session): """Test the behavior of AraDefinitionCollection.register() on an unregistered AraDefinition""" # Given - table_config = TableConfig(name="name", description="description", datasets=[], rows=[], variables=[], columns=[]) + table_config = TableConfig( + name="name", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], + ) table_config_response = TableConfigResponseDataFactory() defn_uid = table_config_response["definition"]["id"] @@ -841,7 +1074,14 @@ def test_register_new(collection, session): def test_register_existing(collection, session): """Test the behavior of AraDefinitionCollection.register() on a registered AraDefinition""" # Given - table_config = TableConfig(name="name", description="description", datasets=[], rows=[], variables=[], columns=[]) + table_config = TableConfig( + name="name", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], + ) table_config.config_uid = uuid4() table_config_response = TableConfigResponseDataFactory() @@ -858,13 +1098,23 @@ def test_register_existing(collection, session): # Ensure we PUT if we were called with a table config id assert session.last_call.method == "PUT" - assert session.last_call.path == f"projects/{collection.project_id}/ara-definitions/{table_config.config_uid}" + assert ( + session.last_call.path + == f"projects/{collection.project_id}/ara-definitions/{table_config.config_uid}" + ) def test_update(collection, session): """Test the behavior of AraDefinitionCollection.update() on a registered AraDefinition""" # Given - table_config = TableConfig(name="name", description="description", datasets=[], rows=[], variables=[], columns=[]) + table_config = TableConfig( + name="name", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], + ) table_config.config_uid = uuid4() table_config_response = TableConfigResponseDataFactory() @@ -882,7 +1132,10 @@ def test_update(collection, session): # Ensure we POST if we weren't created with a table config id assert session.last_call.method == "PUT" - assert session.last_call.path == f"projects/{collection.project_id}/ara-definitions/{table_config.config_uid}" + assert ( + session.last_call.path + == f"projects/{collection.project_id}/ara-definitions/{table_config.config_uid}" + ) def test_update_unregistered_fail(collection, session): @@ -890,10 +1143,19 @@ def test_update_unregistered_fail(collection, session): # Given - table_config = TableConfig(name="name", description="description", datasets=[], rows=[], variables=[], columns=[]) + table_config = TableConfig( + name="name", + description="description", + datasets=[], + rows=[], + variables=[], + columns=[], + ) # When - with pytest.raises(ValueError, match="Cannot update Table Config without a config_uid."): + with pytest.raises( + ValueError, match="Cannot update Table Config without a config_uid." 
+    ):
         collection.update(table_config)
diff --git a/tests/resources/test_team.py b/tests/resources/test_team.py
index a0aebec7d..a01fe64b1 100644
--- a/tests/resources/test_team.py
+++ b/tests/resources/test_team.py
@@ -1,6 +1,5 @@
 import json
 import uuid
-from uuid import UUID
 
 import pytest
 from dateutil.parser import parse
@@ -8,11 +7,11 @@
 
 from citrine._rest.resource import ResourceTypeEnum
 from citrine.resources.api_error import ApiError
-from citrine.resources.dataset import Dataset, DatasetCollection
+from citrine.resources.dataset import Dataset
 from citrine.resources.process_spec import ProcessSpec
 from citrine.resources.team import Team, TeamCollection, SHARE, READ, WRITE, TeamMember
 from citrine.resources.user import User
-from tests.utils.factories import UserDataFactory, TeamDataFactory, DatasetDataFactory
+from tests.utils.factories import UserDataFactory, TeamDataFactory
 from tests.utils.session import FakeSession, FakeCall, FakePaginatedSession
@@ -30,20 +29,14 @@ def paginated_session() -> FakePaginatedSession:
 
 @pytest.fixture
 def team(session) -> Team:
-    team = Team(
-        name='Test Team',
-        session=session
-    )
-    team.uid = uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da')
+    team = Team(name="Test Team", session=session)
+    team.uid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")
     return team
 
 
 @pytest.fixture
 def other_team(session) -> Team:
-    team = Team(
-        name='Test Team',
-        session=session
-    )
+    team = Team(name="Test Team", session=session)
     team.uid = uuid.uuid4()
     return team
@@ -55,12 +48,10 @@ def collection(session) -> TeamCollection:
 
 def test_team_member_string_representation(team):
     user = User.build(UserDataFactory())
-    team_member = TeamMember(
-        user=user,
-        team=team,
-        actions=[READ]
+    team_member = TeamMember(user=user, team=team, actions=[READ])
+    assert team_member.__str__() == "<TeamMember {!r} is {} of {}>".format(
+        user.screen_name, team_member.actions, team.name
     )
-    assert team_member.__str__() == '<TeamMember {!r} is {} of {}>'.format(user.screen_name, team_member.actions, team.name)
 
 
 def test_string_representation(team):
@@ -74,78 +65,82 @@ def test_team_project_session(team):
 
 def test_team_registration(collection: TeamCollection, session):
     # Given
-    create_time = parse('2019-09-10T00:00:00+00:00')
+    create_time = parse("2019-09-10T00:00:00+00:00")
     team_data = TeamDataFactory(
-        name='testing',
-        description='A sample team',
-        created_at=int(create_time.timestamp() * 1000)  # The lib expects ms since epoch, which is really odd
+        name="testing",
+        description="A sample team",
+        created_at=int(
+            create_time.timestamp() * 1000
+        ),  # The lib expects ms since epoch, which is really odd
     )
     user = UserDataFactory()
 
     session.set_responses(
-        {'team': team_data},
+        {"team": team_data},
         user,
-        {'id': user['id'], 'actions': ['READ', 'WRITE', 'SHARE']}
+        {"id": user["id"], "actions": ["READ", "WRITE", "SHARE"]},
     )
 
     # When
-    created_team = collection.register('testing')
+    created_team = collection.register("testing")
 
     # Then
     assert 3 == session.num_calls
     expected_call_1 = FakeCall(
-        method='POST',
-        path='/teams',
+        method="POST",
+        path="/teams",
         json={
-            'name': 'testing',
-            'description': '',
-            'id': None,
-            'created_at': None,
-        }
+            "name": "testing",
+            "description": "",
+            "id": None,
+            "created_at": None,
+        },
     )
-    expected_call_2 = FakeCall(
-        method="GET",
-        path='/users/me'
+    expected_call_2 = FakeCall(method="GET", path="/users/me")
+    expected_call_3 = FakeCall(
+        method="PUT",
+        path="/teams/{}/users".format(created_team.uid),
+        json={"id": user["id"], "actions": [READ, WRITE, SHARE]},
     )
-    expected_call_3 = 
FakeCall(method="PUT", path="/teams/{}/users".format(created_team.uid), - json={'id': user["id"], 'actions': [READ, WRITE, SHARE]}) assert expected_call_1 == session.calls[0] assert expected_call_2 == session.calls[1] assert expected_call_3 == session.calls[2] - assert 'A sample team' == created_team.description + assert "A sample team" == created_team.description assert create_time == created_team.created_at def test_get_team(collection: TeamCollection, session): # Given - team_data = TeamDataFactory(name='single team') - session.set_response({'team': team_data}) + team_data = TeamDataFactory(name="single team") + session.set_response({"team": team_data}) # When - created_team = collection.get(team_data['id']) + created_team = collection.get(team_data["id"]) # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='/teams/{}'.format(team_data['id']), + method="GET", + path="/teams/{}".format(team_data["id"]), ) assert expected_call == session.last_call - assert 'single team' == created_team.name + assert "single team" == created_team.name def test_list_teams(collection, session): # Given teams_data = TeamDataFactory.create_batch(5) - session.set_response({'teams': teams_data}) + session.set_response({"teams": teams_data}) # When teams = list(collection.list()) # Then assert 1 == session.num_calls - expected_call = FakeCall(method='GET', path='/teams', params={'per_page': 100, 'page': 1}) + expected_call = FakeCall( + method="GET", path="/teams", params={"per_page": 100, "page": 1} + ) assert expected_call == session.last_call assert 5 == len(teams) @@ -171,7 +166,7 @@ def test_list_teams_as_admin(collection, session): def test_update_team(collection: TeamCollection, team, session): team.name = "updated name" - session.set_response({'team': team.dump()}) + session.set_response({"team": team.dump()}) result = collection.update(team) assert result.name == team.name @@ -181,14 +176,14 @@ def test_list_members(team, session): user = UserDataFactory() user["actions"] = READ user.pop("position") - session.set_response({'users': [user]}) + session.set_response({"users": [user]}) # When members = team.list_members() # Then assert 1 == session.num_calls - expect_call = FakeCall(method='GET', path='/teams/{}/users'.format(team.uid)) + expect_call = FakeCall(method="GET", path="/teams/{}/users".format(team.uid)) assert expect_call == session.last_call assert isinstance(members[0], TeamMember) @@ -199,14 +194,16 @@ def test_me(team, session): member = user.copy() member["actions"] = [READ] member.pop("position") - session.set_responses({**user}, {'user': member}) + session.set_responses({**user}, {"user": member}) # When member = team.me() # Then assert 2 == session.num_calls - member_call = FakeCall(method='GET', path='/teams/{}/users/{}'.format(team.uid, user["id"])) + member_call = FakeCall( + method="GET", path="/teams/{}/users/{}".format(team.uid, user["id"]) + ) assert member_call == session.last_call assert isinstance(member, TeamMember) @@ -214,15 +211,20 @@ def test_me(team, session): def test_update_user_actions(team, session): # Given user = UserDataFactory() - session.set_response({'id': user['id'], 'actions': ['READ']}) + session.set_response({"id": user["id"], "actions": ["READ"]}) # When - update_user_role_response = team.update_user_action(user_id=User.build(user), actions=[WRITE, SHARE]) + update_user_role_response = team.update_user_action( + user_id=User.build(user), actions=[WRITE, SHARE] + ) # Then assert 1 == session.num_calls - expect_call = 
FakeCall(method="PUT", path="/teams/{}/users".format(team.uid), - json={'id': user["id"], 'actions': [WRITE, SHARE]}) + expect_call = FakeCall( + method="PUT", + path="/teams/{}/users".format(team.uid), + json={"id": user["id"], "actions": [WRITE, SHARE]}, + ) assert expect_call == session.last_call assert update_user_role_response is True @@ -230,17 +232,18 @@ def test_update_user_actions(team, session): def test_add_user(team, session): # Given user = UserDataFactory() - session.set_response({'id': user["id"], 'actions': ['READ']}) + session.set_response({"id": user["id"], "actions": ["READ"]}) # When add_user_response = team.add_user(User.build(user)) # Then assert 1 == session.num_calls - expect_call = FakeCall(method="PUT", path='/teams/{}/users'.format(team.uid), json={ - "id": user["id"], - "actions": ["READ"] - }) + expect_call = FakeCall( + method="PUT", + path="/teams/{}/users".format(team.uid), + json={"id": user["id"], "actions": ["READ"]}, + ) assert expect_call == session.last_call assert add_user_response is True @@ -248,17 +251,18 @@ def test_add_user(team, session): def test_add_user_with_actions(team, session): # Given user = UserDataFactory() - session.set_response({'id': user["id"], 'actions': ['READ', 'WRITE']}) + session.set_response({"id": user["id"], "actions": ["READ", "WRITE"]}) # When - add_user_response = team.add_user(user["id"], actions=['READ', 'WRITE']) + add_user_response = team.add_user(user["id"], actions=["READ", "WRITE"]) # Then assert 1 == session.num_calls - expect_call = FakeCall(method="PUT", path='/teams/{}/users'.format(team.uid), json={ - "id": user["id"], - "actions": ["READ", "WRITE"] - }) + expect_call = FakeCall( + method="PUT", + path="/teams/{}/users".format(team.uid), + json={"id": user["id"], "actions": ["READ", "WRITE"]}, + ) assert expect_call == session.last_call assert add_user_response is True @@ -266,7 +270,7 @@ def test_add_user_with_actions(team, session): def test_remove_user(team, session): # Given user = UserDataFactory() - session.set_response({'ids': [user["id"]]}) + session.set_response({"ids": [user["id"]]}) # When remove_user_response = team.remove_user(User.build(user)) @@ -276,7 +280,7 @@ def test_remove_user(team, session): expect_call = FakeCall( method="POST", path="/teams/{}/users/batch-remove".format(team.uid), - json={"ids": [user["id"]]} + json={"ids": [user["id"]]}, ) assert expect_call == session.last_call assert remove_user_response is True @@ -298,8 +302,8 @@ def test_share(team, other_team, session): json={ "resource_type": "DATASET", "resource_id": str(dataset.uid), - "target_team_id": str(other_team.uid) - } + "target_team_id": str(other_team.uid), + }, ) assert expect_call == session.last_call assert share_response is True @@ -317,27 +321,29 @@ def test_un_share(team, other_team, session): assert 1 == session.num_calls expect_call = FakeCall( method="DELETE", - path="/teams/{}/shared-resources/{}/{}".format(team.uid, "DATASET", str(dataset.uid)), - json={ - "target_team_id": str(other_team.uid) - } + path="/teams/{}/shared-resources/{}/{}".format( + team.uid, "DATASET", str(dataset.uid) + ), + json={"target_team_id": str(other_team.uid)}, ) assert expect_call == session.last_call assert share_response is True -@pytest.mark.parametrize("resource_type,method", +@pytest.mark.parametrize( + "resource_type,method", [ (ResourceTypeEnum.DATASET, "dataset_ids"), (ResourceTypeEnum.MODULE, "module_ids"), (ResourceTypeEnum.TABLE, "table_ids"), - (ResourceTypeEnum.TABLE_DEFINITION, "table_definition_ids") - ]) + 
(ResourceTypeEnum.TABLE_DEFINITION, "table_definition_ids"), + ], +) def test_list_resource_ids(team, session, resource_type, method): # Given - read_response = {'ids': [uuid.uuid4(), uuid.uuid4()]} - write_response = {'ids': [uuid.uuid4(), uuid.uuid4()]} - share_response = {'ids': [uuid.uuid4(), uuid.uuid4()]} + read_response = {"ids": [uuid.uuid4(), uuid.uuid4()]} + write_response = {"ids": [uuid.uuid4(), uuid.uuid4()]} + share_response = {"ids": [uuid.uuid4(), uuid.uuid4()]} # When # This is equivalent to team.dataset_ids, team.module_ids, etc. @@ -354,47 +360,63 @@ def test_list_resource_ids(team, session, resource_type, method): # Then assert session.num_calls == 3 - assert session.calls[0] == FakeCall(method='GET', - path=f'/{resource_type.value}/authorized-ids', - params={"domain": f"/teams/{team.uid}", "action": READ}) - assert session.calls[1] == FakeCall(method='GET', - path=f'/{resource_type.value}/authorized-ids', - params={"domain": f"/teams/{team.uid}", "action": WRITE}) - assert session.calls[2] == FakeCall(method='GET', - path=f'/{resource_type.value}/authorized-ids', - params={"domain": f"/teams/{team.uid}", "action": SHARE}) - assert readable_ids == read_response['ids'] - assert writeable_ids == write_response['ids'] - assert shareable_ids == share_response['ids'] + assert session.calls[0] == FakeCall( + method="GET", + path=f"/{resource_type.value}/authorized-ids", + params={"domain": f"/teams/{team.uid}", "action": READ}, + ) + assert session.calls[1] == FakeCall( + method="GET", + path=f"/{resource_type.value}/authorized-ids", + params={"domain": f"/teams/{team.uid}", "action": WRITE}, + ) + assert session.calls[2] == FakeCall( + method="GET", + path=f"/{resource_type.value}/authorized-ids", + params={"domain": f"/teams/{team.uid}", "action": SHARE}, + ) + assert readable_ids == read_response["ids"] + assert writeable_ids == write_response["ids"] + assert shareable_ids == share_response["ids"] def test_analyses_get_team_id(team): assert team.uid == team.analyses.team_id + def test_owned_dataset_ids(team): # Create a set of datasets in the project ids = {uuid.uuid4() for _ in range(5)} for d_id in ids: - dataset = Dataset(name=f"Test Dataset - {d_id}", summary="Test Dataset", description="Test Dataset") + dataset = Dataset( + name=f"Test Dataset - {d_id}", + summary="Test Dataset", + description="Test Dataset", + ) team.datasets.register(dataset) # Set the session response to have the list of dataset IDs - team.session.set_response({'ids': list(ids)}) + team.session.set_response({"ids": list(ids)}) # Fetch the list of UUID owned by the current project owned_ids = team.owned_dataset_ids() # Let's mock our expected API call so we can compare and ensure that the one made is the same - expect_call = FakeCall(method='GET', - path='/DATASET/authorized-ids', - params={'userId': '', - 'domain': '/teams/16fd2706-8baf-433b-82eb-8c7fada847da', - 'action': 'WRITE'}) + expect_call = FakeCall( + method="GET", + path="/DATASET/authorized-ids", + params={ + "userId": "", + "domain": "/teams/16fd2706-8baf-433b-82eb-8c7fada847da", + "action": "WRITE", + }, + ) # Compare our calls assert expect_call == team.session.last_call assert team.session.num_calls == len(ids) + 1 assert ids == set(owned_ids) + def test_datasets_get_team_id(team): assert team.uid == team.datasets.team_id @@ -460,34 +482,38 @@ def test_gemd_resource_get_team_id(team): def test_team_batch_delete_no_errors(team, session): - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} # Actual response-like data 
- note there is no 'failures' array within 'output' successful_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [ + "job_type": "batch_delete", + "status": "Success", + "tasks": [ { - "id": "7b6bafd9-f32a-4567-b54c-7ce594edc018", "task_type": "batch_delete", - "status": "Success", "dependencies": [] - } - ], - 'output': {} + "id": "7b6bafd9-f32a-4567-b54c-7ce594edc018", + "task_type": "batch_delete", + "status": "Success", + "dependencies": [], + } + ], + "output": {}, } session.set_responses(job_resp, successful_job_resp) # When - del_resp = team.gemd_batch_delete([uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = team.gemd_batch_delete( + [uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")] + ) # Then assert len(del_resp) == 0 # When trying with entities session.set_responses(job_resp, successful_job_resp) - entity = ProcessSpec(name="proc spec", uids={'id': '16fd2706-8baf-433b-82eb-8c7fada847da'}) + entity = ProcessSpec( + name="proc spec", uids={"id": "16fd2706-8baf-433b-82eb-8c7fada847da"} + ) del_resp = team.gemd_batch_delete([entity]) # Then @@ -495,42 +521,39 @@ def test_team_batch_delete_no_errors(team, session): def test_team_batch_delete(team, session): - job_resp = { - 'job_id': '1234' - } + job_resp = {"job_id": "1234"} - failures_escaped_json = json.dumps([ - { - "id": { - 'scope': 'somescope', - 'id': 'abcd-1234' - }, - 'cause': { - "code": 400, - "message": "", - "validation_errors": [ - { - "failure_message": "fail msg", - "failure_id": "identifier.coreid.missing" - } - ] + failures_escaped_json = json.dumps( + [ + { + "id": {"scope": "somescope", "id": "abcd-1234"}, + "cause": { + "code": 400, + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + }, } - } - ]) + ] + ) failed_job_resp = { - 'job_type': 'batch_delete', - 'status': 'Success', - 'tasks': [], - 'output': { - 'failures': failures_escaped_json - } + "job_type": "batch_delete", + "status": "Success", + "tasks": [], + "output": {"failures": failures_escaped_json}, } session.set_responses(job_resp, failed_job_resp, job_resp, failed_job_resp) # When - del_resp = team.gemd_batch_delete([uuid.UUID('16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = team.gemd_batch_delete( + [uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")] + ) # Then assert 2 == session.num_calls @@ -538,21 +561,30 @@ def test_team_batch_delete(team, session): assert len(del_resp) == 1 first_failure = del_resp[0] - expected_api_error = ApiError.build({ - "code": "400", - "message": "", - "validation_errors": [{"failure_message": "fail msg", "failure_id": "identifier.coreid.missing"}] - }) + expected_api_error = ApiError.build( + { + "code": "400", + "message": "", + "validation_errors": [ + { + "failure_message": "fail msg", + "failure_id": "identifier.coreid.missing", + } + ], + } + ) - assert first_failure[0] == LinkByUID('somescope', 'abcd-1234') + assert first_failure[0] == LinkByUID("somescope", "abcd-1234") assert first_failure[1].dump() == expected_api_error.dump() # And again with tuples of (scope, id) - del_resp = team.gemd_batch_delete([LinkByUID('id', '16fd2706-8baf-433b-82eb-8c7fada847da')]) + del_resp = team.gemd_batch_delete( + [LinkByUID("id", "16fd2706-8baf-433b-82eb-8c7fada847da")] + ) assert len(del_resp) == 1 first_failure = del_resp[0] - assert first_failure[0] == LinkByUID('somescope', 'abcd-1234') + assert first_failure[0] == LinkByUID("somescope", "abcd-1234") assert first_failure[1].dump() == 
expected_api_error.dump() diff --git a/tests/resources/test_templates.py b/tests/resources/test_templates.py index 4bf4bec17..2041bb343 100644 --- a/tests/resources/test_templates.py +++ b/tests/resources/test_templates.py @@ -1,4 +1,5 @@ """Test that templates show expected behavior.""" + import pytest from uuid import uuid4 @@ -6,7 +7,10 @@ from citrine.resources.measurement_template import MeasurementTemplate from citrine.resources.process_template import ProcessTemplate from citrine.resources.process_spec import ProcessSpec -from citrine.resources.property_template import PropertyTemplate, PropertyTemplateCollection +from citrine.resources.property_template import ( + PropertyTemplate, + PropertyTemplateCollection, +) from citrine.resources.condition_template import ConditionTemplate from citrine.resources.parameter_template import ParameterTemplate from citrine.exceptions import BadRequest @@ -16,75 +20,107 @@ from gemd.entity.value.nominal_real import NominalReal from gemd.entity.attribute.condition import Condition -from tests.utils.session import FakeSession, FakeCall +from tests.utils.session import FakeSession def test_object_template_validation(): """Test that attribute templates are validated against given bounds.""" - length_template = PropertyTemplate("Length", bounds=RealBounds(2.0, 3.5, 'cm')) + length_template = PropertyTemplate("Length", bounds=RealBounds(2.0, 3.5, "cm")) dial_template = ConditionTemplate("dial", bounds=IntegerBounds(0, 5)) - color_template = ParameterTemplate("Color", bounds=CategoricalBounds(["red", "green", "blue"])) + color_template = ParameterTemplate( + "Color", bounds=CategoricalBounds(["red", "green", "blue"]) + ) with pytest.raises(TypeError): MaterialTemplate() with pytest.raises(ValueError): - MaterialTemplate("Block", properties=[[length_template, RealBounds(3.0, 4.0, 'cm')]]) + MaterialTemplate( + "Block", properties=[[length_template, RealBounds(3.0, 4.0, "cm")]] + ) with pytest.raises(ValueError): - ProcessTemplate("a process", conditions=[[color_template, CategoricalBounds(["zz"])]]) - + ProcessTemplate( + "a process", conditions=[[color_template, CategoricalBounds(["zz"])]] + ) + with pytest.raises(ValueError): - MeasurementTemplate("A measurement", parameters=[[dial_template, IntegerBounds(-3, -1)]]) + MeasurementTemplate( + "A measurement", parameters=[[dial_template, IntegerBounds(-3, -1)]] + ) def test_template_assignment(): """Test that an object and its attributes can both be assigned templates.""" humidity_template = ConditionTemplate("Humidity", bounds=RealBounds(0.5, 0.75, "")) - template = ProcessTemplate("Dry", conditions=[[humidity_template, RealBounds(0.5, 0.65, "")]]) - ProcessSpec("Dry a polymer", template=template, conditions=[ - Condition("Humidity", value=NominalReal(0.6, ""), template=humidity_template)]) + template = ProcessTemplate( + "Dry", conditions=[[humidity_template, RealBounds(0.5, 0.65, "")]] + ) + ProcessSpec( + "Dry a polymer", + template=template, + conditions=[ + Condition( + "Humidity", value=NominalReal(0.6, ""), template=humidity_template + ) + ], + ) def test_automatic_async_update(): """Update on an object that requires an asynchronous check smoothly transitions to async_update.""" session = FakeSession() - collection = PropertyTemplateCollection(team_id=uuid4(), dataset_id=uuid4(), session=session) + collection = PropertyTemplateCollection( + team_id=uuid4(), dataset_id=uuid4(), session=session + ) this_id = str(uuid4()) - template = PropertyTemplate("dummy template", bounds=RealBounds(0.0, 0.5, ''), 
uids={'id': this_id}) + template = PropertyTemplate( + "dummy template", bounds=RealBounds(0.0, 0.5, ""), uids={"id": this_id} + ) session.set_responses( - BadRequest(""), # Attempted POST throws BadRequest because, for example, the template bounds are being narrowed. + BadRequest( + "" + ), # Attempted POST throws BadRequest because, for example, the template bounds are being narrowed. {"job_id": str(uuid4())}, # Call async route, returning a job_id. - {"job_type": "", "status": "Success", "tasks": []}, # Check job status, it succeeded. - template.dump() # Get the resource. + { + "job_type": "", + "status": "Success", + "tasks": [], + }, # Check job status, it succeeded. + template.dump(), # Get the resource. ) new_template = collection.update(template) assert new_template == template # Check that resource is returned. # First call should be an attempt to POST the resource assert session.calls[0].method == "POST" - assert session.calls[0].path == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates" + assert ( + session.calls[0].path + == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates" + ) # Second call should be a PUT to the async route assert session.calls[1].method == "PUT" - assert session.calls[1].path == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates/id/{this_id}/async" + assert ( + session.calls[1].path + == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates/id/{this_id}/async" + ) # Last call should get the resource assert session.last_call.method == "GET" - assert session.last_call.path == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates/id/{this_id}" + assert ( + session.last_call.path + == f"teams/{collection.team_id}/datasets/{collection.dataset_id}/property-templates/id/{this_id}" + ) def test_process_template_equals(): """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals().""" - from citrine.resources.process_template import ProcessTemplate as CitrineProcessTemplate + from citrine.resources.process_template import ( + ProcessTemplate as CitrineProcessTemplate, + ) from gemd.entity.template import ProcessTemplate as GEMDProcessTemplate - gemd_obj = GEMDProcessTemplate( - name="My Name", - tags=["tag!"] - ) - citrine_obj = CitrineProcessTemplate( - name="My Name", - tags=["tag!"] - ) + gemd_obj = GEMDProcessTemplate(name="My Name", tags=["tag!"]) + citrine_obj = CitrineProcessTemplate(name="My Name", tags=["tag!"]) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.name = "Something else" assert gemd_obj != citrine_obj, "GEMD/Citrine detects difference" @@ -92,17 +128,13 @@ def test_process_template_equals(): def test_material_template_equals(): """Test basic equality. 
Complex relationships are tested in test_material_run.test_deep_equals().""" - from citrine.resources.material_template import MaterialTemplate as CitrineMaterialTemplate + from citrine.resources.material_template import ( + MaterialTemplate as CitrineMaterialTemplate, + ) from gemd.entity.template import MaterialTemplate as GEMDMaterialTemplate - gemd_obj = GEMDMaterialTemplate( - name="My Name", - tags=["tag!"] - ) - citrine_obj = CitrineMaterialTemplate( - name="My Name", - tags=["tag!"] - ) + gemd_obj = GEMDMaterialTemplate(name="My Name", tags=["tag!"]) + citrine_obj = CitrineMaterialTemplate(name="My Name", tags=["tag!"]) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.name = "Something else" assert gemd_obj != citrine_obj, "GEMD/Citrine detects difference" @@ -110,17 +142,13 @@ def test_material_template_equals(): def test_measurement_template_equals(): """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals().""" - from citrine.resources.measurement_template import MeasurementTemplate as CitrineMeasurementTemplate + from citrine.resources.measurement_template import ( + MeasurementTemplate as CitrineMeasurementTemplate, + ) from gemd.entity.template import MeasurementTemplate as GEMDMeasurementTemplate - gemd_obj = GEMDMeasurementTemplate( - name="My Name", - tags=["tag!"] - ) - citrine_obj = CitrineMeasurementTemplate( - name="My Name", - tags=["tag!"] - ) + gemd_obj = GEMDMeasurementTemplate(name="My Name", tags=["tag!"]) + citrine_obj = CitrineMeasurementTemplate(name="My Name", tags=["tag!"]) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.name = "Something else" assert gemd_obj != citrine_obj, "GEMD/Citrine detects difference" @@ -128,18 +156,16 @@ def test_measurement_template_equals(): def test_condition_template_equals(): """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals().""" - from citrine.resources.condition_template import ConditionTemplate as CitrineConditionTemplate + from citrine.resources.condition_template import ( + ConditionTemplate as CitrineConditionTemplate, + ) from gemd.entity.template import ConditionTemplate as GEMDConditionTemplate gemd_obj = GEMDConditionTemplate( - name="My Name", - bounds=CategoricalBounds(categories=["1"]), - tags=["tag!"] + name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"] ) citrine_obj = CitrineConditionTemplate( - name="My Name", - bounds=CategoricalBounds(categories=["1"]), - tags=["tag!"] + name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"] ) assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence" citrine_obj.name = "Something else" @@ -148,18 +174,16 @@ def test_condition_template_equals(): def test_parameter_template_equals(): """Test basic equality. 
Complex relationships are tested in test_material_run.test_deep_equals()."""
-    from citrine.resources.parameter_template import ParameterTemplate as CitrineParameterTemplate
+    from citrine.resources.parameter_template import (
+        ParameterTemplate as CitrineParameterTemplate,
+    )
     from gemd.entity.template import ParameterTemplate as GEMDParameterTemplate
 
     gemd_obj = GEMDParameterTemplate(
-        name="My Name",
-        bounds=CategoricalBounds(categories=["1"]),
-        tags=["tag!"]
+        name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"]
     )
     citrine_obj = CitrineParameterTemplate(
-        name="My Name",
-        bounds=CategoricalBounds(categories=["1"]),
-        tags=["tag!"]
+        name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"]
     )
     assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence"
     citrine_obj.name = "Something else"
@@ -168,18 +192,16 @@ def test_property_template_equals():
     """Test basic equality. Complex relationships are tested in test_material_run.test_deep_equals()."""
-    from citrine.resources.property_template import PropertyTemplate as CitrinePropertyTemplate
+    from citrine.resources.property_template import (
+        PropertyTemplate as CitrinePropertyTemplate,
+    )
     from gemd.entity.template import PropertyTemplate as GEMDPropertyTemplate
 
     gemd_obj = GEMDPropertyTemplate(
-        name="My Name",
-        bounds=CategoricalBounds(categories=["1"]),
-        tags=["tag!"]
+        name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"]
     )
     citrine_obj = CitrinePropertyTemplate(
-        name="My Name",
-        bounds=CategoricalBounds(categories=["1"]),
-        tags=["tag!"]
+        name="My Name", bounds=CategoricalBounds(categories=["1"]), tags=["tag!"]
     )
     assert gemd_obj == citrine_obj, "GEMD/Citrine equivalence"
     citrine_obj.name = "Something else"
diff --git a/tests/resources/test_user.py b/tests/resources/test_user.py
index 9ca16cc22..41d310e48 100644
--- a/tests/resources/test_user.py
+++ b/tests/resources/test_user.py
@@ -4,7 +4,7 @@
 
 from citrine.resources.user import User, UserCollection
 from tests.utils.factories import UserDataFactory
-from tests.utils.session import FakeSession, FakeCall
+from tests.utils.session import FakeCall, FakeSession
 
 
 @pytest.fixture
@@ -15,12 +15,9 @@ def session() -> FakeSession:
 @pytest.fixture
 def user() -> User:
     user = User(
-        screen_name='Test User',
-        email="test@user.io",
-        position="QA",
-        is_admin=False
+        screen_name="Test User", email="test@user.io", position="QA", is_admin=False
     )
-    user.uid = UUID('16fd2706-8baf-433b-82eb-8c7fada847da')
+    user.uid = UUID("16fd2706-8baf-433b-82eb-8c7fada847da")
     return user
 
 
@@ -31,10 +28,10 @@ def collection(session) -> UserCollection:
 
 def test_user_str_representation():
     user = User(
-        screen_name='joe',
-        email='joe@somewhere.com',
-        position='President',
-        is_admin=False
+        screen_name="joe",
+        email="joe@somewhere.com",
+        position="President",
+        is_admin=False,
     )
     assert "<User 'joe'>" == str(user)
 
 
@@ -49,39 +46,39 @@ def test_user_registration(collection, session):
 
     # given
     user = UserDataFactory()
-    session.set_response({'user': user})
+    session.set_response({"user": user})
 
     # When
     created_user = collection.register(
         screen_name=user["screen_name"],
         email=user["email"],
         position=user["position"],
-        is_admin=user["is_admin"]
+        is_admin=user["is_admin"],
     )
 
     # Then
     assert 1 == session.num_calls
     expected_call = FakeCall(
-        method='POST',
-        path='/users',
+        method="POST",
+        path="/users",
         json={
-            'screen_name': user["screen_name"],
-            'position': user["position"],
-            'email': user["email"],
-            'is_admin': 
user["is_admin"], - } + "screen_name": user["screen_name"], + "position": user["position"], + "email": user["email"], + "is_admin": user["is_admin"], + }, ) - assert expected_call.json['screen_name'] == created_user.screen_name - assert expected_call.json['email'] == created_user.email - assert expected_call.json['position'] == created_user.position - assert expected_call.json['is_admin'] == created_user.is_admin + assert expected_call.json["screen_name"] == created_user.screen_name + assert expected_call.json["email"] == created_user.email + assert expected_call.json["position"] == created_user.position + assert expected_call.json["is_admin"] == created_user.is_admin def test_list_users(collection, session): # Given user_data = UserDataFactory.create_batch(5) - session.set_response({'users': user_data}) + session.set_response({"users": user_data}) # When users = list(collection.list()) @@ -89,9 +86,7 @@ def test_list_users(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='/users', - params={'per_page': 100, 'page': 1} + method="GET", path="/users", params={"per_page": 100, "page": 1} ) assert expected_call == session.last_call @@ -101,7 +96,7 @@ def test_list_users(collection, session): def test_list_users_as_admin(collection, session): # Given user_data = UserDataFactory.create_batch(5) - session.set_response({'users': user_data}) + session.set_response({"users": user_data}) # When users = list(collection.list(as_admin=True)) @@ -109,9 +104,9 @@ def test_list_users_as_admin(collection, session): # Then assert 1 == session.num_calls expected_call = FakeCall( - method='GET', - path='/users', - params={'per_page': 100, 'page': 1, 'as_admin': 'true'} + method="GET", + path="/users", + params={"per_page": 100, "page": 1, "as_admin": "true"}, ) assert expected_call == session.last_call @@ -120,7 +115,7 @@ def test_list_users_as_admin(collection, session): def test_get_users(collection, session): # Given - uid = '151199ec-e9aa-49a1-ac8e-da722aaf74c4' + uid = "151199ec-e9aa-49a1-ac8e-da722aaf74c4" # When with pytest.raises(KeyError): @@ -132,12 +127,12 @@ def test_delete_user(collection, session): user = UserDataFactory() # When - collection.delete(user['id']) + collection.delete(user["id"]) - session.set_response({'message': 'User was deleted'}) + session.set_response({"message": "User was deleted"}) expected_call = FakeCall( method="DELETE", - path='/users/{}'.format(user["id"]), + path="/users/{}".format(user["id"]), ) assert 1 == session.num_calls @@ -150,13 +145,10 @@ def test_get_me(collection, session): session.set_response(user) # When - current_user = collection.me() + collection.me() # Then - expected_call = FakeCall( - method="GET", - path='/users/me' - ) + expected_call = FakeCall(method="GET", path="/users/me") assert 1 == session.num_calls assert expected_call == session.last_call diff --git a/tests/resources/test_workflow.py b/tests/resources/test_workflow.py index dafffd672..82e313334 100644 --- a/tests/resources/test_workflow.py +++ b/tests/resources/test_workflow.py @@ -5,37 +5,36 @@ from citrine.informatics.workflows.design_workflow import DesignWorkflow from citrine.resources.design_workflow import DesignWorkflowCollection - from tests.utils.factories import BranchDataFactory -from tests.utils.session import FakeSession, FakeCall +from tests.utils.session import FakeCall, FakeSession -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def basic_design_workflow_data(): return { - 'id': str(uuid.uuid4()), 
- 'name': 'Test Workflow', - 'status': 'SUCCEEDED', - 'status_description': 'READY', - 'design_space_id': str(uuid.uuid4()), - 'predictor_id': str(uuid.uuid4()), - 'branch_id': str(uuid.uuid4()), - 'module_type': 'DESIGN_WORKFLOW', - 'create_time': datetime(2020, 1, 1, 1, 1, 1, 1).isoformat("T"), - 'created_by': str(uuid.uuid4()), + "id": str(uuid.uuid4()), + "name": "Test Workflow", + "status": "SUCCEEDED", + "status_description": "READY", + "design_space_id": str(uuid.uuid4()), + "predictor_id": str(uuid.uuid4()), + "branch_id": str(uuid.uuid4()), + "module_type": "DESIGN_WORKFLOW", + "create_time": datetime(2020, 1, 1, 1, 1, 1, 1).isoformat("T"), + "created_by": str(uuid.uuid4()), } -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def failed_design_workflow_data(basic_design_workflow_data): return { **basic_design_workflow_data, - 'status': 'FAILED', - 'status_description': 'ERROR', - 'status_detail': [ - {'level': 'WARNING', 'msg': 'Something is wrong'}, - {'level': 'Error', 'msg': 'Very wrong'} - ] + "status": "FAILED", + "status_description": "ERROR", + "status_detail": [ + {"level": "WARNING", "msg": "Something is wrong"}, + {"level": "Error", "msg": "Very wrong"}, + ], } @@ -49,7 +48,9 @@ def test_build_design_workflow(session, basic_design_workflow_data): branch_data = BranchDataFactory() session.set_response(branch_data) - workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session) + workflow_collection = DesignWorkflowCollection( + project_id=uuid.uuid4(), session=session + ) # When workflow = workflow_collection.build(basic_design_workflow_data) @@ -61,16 +62,23 @@ def test_build_design_workflow(session, basic_design_workflow_data): def test_list_workflows(session, basic_design_workflow_data): - #Given - workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session) - session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}) + # Given + workflow_collection = DesignWorkflowCollection( + project_id=uuid.uuid4(), session=session + ) + session.set_responses( + {"response": [basic_design_workflow_data], "page": 1, "per_page": 20} + ) # When workflows = list(workflow_collection.list(per_page=20)) # Then - expected_design_call = FakeCall(method='GET', path='/projects/{}/modules'.format(workflow_collection.project_id), - params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'}) + FakeCall( + method="GET", + path="/projects/{}/modules".format(workflow_collection.project_id), + params={"per_page": 20, "module_type": "DESIGN_WORKFLOW"}, + ) assert 1 == session.num_calls assert len(workflows) == 1 assert isinstance(workflows[0], DesignWorkflow) diff --git a/tests/rest/test_ingredient_rest.py b/tests/rest/test_ingredient_rest.py index 12acab09c..157f5ebf7 100644 --- a/tests/rest/test_ingredient_rest.py +++ b/tests/rest/test_ingredient_rest.py @@ -1,4 +1,5 @@ """Test RESTful actions on ingredient runs""" + import pytest from citrine.resources.ingredient_run import IngredientRun @@ -8,23 +9,38 @@ @pytest.fixture def valid_data(): """Return valid data used for these tests.""" - return {"type": "ingredient_run", - "material": {"type": "link_by_uid", "id": "5c913611-c304-4254-bad2-4797c952a3b3", "scope": "ID"}, - "process": {"type": "link_by_uid", "id": "5c913611-c304-4254-bad2-4797c952a3b4", "scope": "ID"}, - "spec": {"type": "link_by_uid", "id": "5c913611-c304-4254-bad2-4797c952a3b5", "scope": "ID"}, - "name": "Good Ingredient Run", - "labels": [], - "mass_fraction": {'nominal': 0.5, 
'units': 'dimensionless', 'type': 'nominal_real'}, - "volume_fraction": None, - "number_fraction": None, - "absolute_quantity": {'nominal': 2, 'units': 'g', 'type': 'nominal_real'}, - "uids": { - "id": "09145273-1ff2-4fbd-ba56-404c0408eb49" - }, - "tags": [], - "notes": "Ingredients!", - "file_links": [] - } + return { + "type": "ingredient_run", + "material": { + "type": "link_by_uid", + "id": "5c913611-c304-4254-bad2-4797c952a3b3", + "scope": "ID", + }, + "process": { + "type": "link_by_uid", + "id": "5c913611-c304-4254-bad2-4797c952a3b4", + "scope": "ID", + }, + "spec": { + "type": "link_by_uid", + "id": "5c913611-c304-4254-bad2-4797c952a3b5", + "scope": "ID", + }, + "name": "Good Ingredient Run", + "labels": [], + "mass_fraction": { + "nominal": 0.5, + "units": "dimensionless", + "type": "nominal_real", + }, + "volume_fraction": None, + "number_fraction": None, + "absolute_quantity": {"nominal": 2, "units": "g", "type": "nominal_real"}, + "uids": {"id": "09145273-1ff2-4fbd-ba56-404c0408eb49"}, + "tags": [], + "notes": "Ingredients!", + "file_links": [], + } def test_ingredient_build(valid_data): diff --git a/tests/rest/test_paginator.py b/tests/rest/test_paginator.py index 9edfd7974..6e0421974 100644 --- a/tests/rest/test_paginator.py +++ b/tests/rest/test_paginator.py @@ -1,8 +1,8 @@ """Test the Paginator""" + from uuid import uuid4 from mock import Mock -import pytest from citrine._rest.paginator import Paginator @@ -47,7 +47,9 @@ def test_pagination_stops_when_initial_item_repeated(): def test_pagination_deduplicates_repeated_intermediate_values(): - result = Paginator().paginate(mocked_fetcher(a, b, b, b, b, b, b, c, c), lambda x: x, per_page=1) + result = Paginator().paginate( + mocked_fetcher(a, b, b, b, b, b, b, c, c), lambda x: x, per_page=1 + ) assert list(result) == [a, b, c] diff --git a/tests/seeding/test_find_or_create.py b/tests/seeding/test_find_or_create.py index 944563ceb..64dcb9bb5 100644 --- a/tests/seeding/test_find_or_create.py +++ b/tests/seeding/test_find_or_create.py @@ -2,24 +2,29 @@ from uuid import UUID, uuid4 import pytest + from citrine._rest.collection import Collection +from citrine.informatics.predictors import AutoMLPredictor, GraphPredictor from citrine.resources.dataset import Dataset, DatasetCollection from citrine.resources.design_workflow import DesignWorkflowCollection -from citrine.resources.process_spec import ProcessSpecCollection, ProcessSpec from citrine.resources.predictor import PredictorCollection +from citrine.resources.process_spec import ProcessSpec, ProcessSpecCollection from citrine.resources.project import ProjectCollection from citrine.resources.team import TeamCollection -from citrine.informatics.predictors import AutoMLPredictor, GraphPredictor -from citrine.seeding.find_or_create import (find_collection, get_by_name_or_create, - get_by_name_or_raise_error, - find_or_create_project, find_or_create_dataset, - create_or_update, find_or_create_team) +from citrine.seeding.find_or_create import ( + create_or_update, + find_collection, + find_or_create_dataset, + find_or_create_project, + find_or_create_team, + get_by_name_or_create, + get_by_name_or_raise_error, +) from tests.utils.factories import BranchDataFactory, DesignWorkflowDataFactory -from tests.utils.fakes.fake_dataset_collection import FakeDatasetCollection from tests.utils.fakes import FakePredictorCollection +from tests.utils.fakes.fake_dataset_collection import FakeDatasetCollection from tests.utils.fakes.fake_project_collection import FakeProjectCollection from 
tests.utils.fakes.fake_team_collection import FakeTeamCollection - from tests.utils.session import FakeSession duplicate_name = "duplicate" @@ -50,11 +55,13 @@ def list(self, page: Optional[int] = None, per_page: int = 100): if page is None: return self.resources else: - return self.resources[(page - 1)*per_page:page*per_page] + return self.resources[(page - 1) * per_page : page * per_page] - collection = FakeCollection(dataset_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), - session=FakeSession()) + collection = FakeCollection( + dataset_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), + session=FakeSession(), + ) for i in range(0, 5): collection.register(ProcessSpec("resource " + str(i))) for i in range(0, 2): @@ -69,8 +76,9 @@ def session() -> FakeSession: @pytest.fixture def project_collection() -> Callable[[bool], ProjectCollection]: - - def _make_project(search_implemented: bool = True, team_id: Optional[Union[UUID, str]] = uuid4()): + def _make_project( + search_implemented: bool = True, team_id: Optional[Union[UUID, str]] = uuid4() + ): projects = FakeProjectCollection(search_implemented, team_id) for i in range(0, 5): projects.register("project " + str(i)) @@ -83,7 +91,6 @@ def _make_project(search_implemented: bool = True, team_id: Optional[Union[UUID, @pytest.fixture def team_collection() -> Callable[[bool], TeamCollection]: - def _make_team(): teams = FakeTeamCollection(True) for i in range(0, 5): @@ -97,24 +104,39 @@ def _make_team(): @pytest.fixture def dataset_collection() -> DatasetCollection: - datasets = FakeDatasetCollection(team_id=UUID('6b608f78-e341-422c-8076-35adc8828545'), session=FakeSession()) + datasets = FakeDatasetCollection( + team_id=UUID("6b608f78-e341-422c-8076-35adc8828545"), session=FakeSession() + ) for i in range(0, 5): num_string = str(i) - datasets.register(Dataset("dataset " + num_string, summary="summ " + num_string, description="desc " + num_string)) + datasets.register( + Dataset( + "dataset " + num_string, + summary="summ " + num_string, + description="desc " + num_string, + ) + ) for i in range(0, 2): - datasets.register(Dataset(duplicate_name, summary="dup", description="duplicate")) + datasets.register( + Dataset(duplicate_name, summary="dup", description="duplicate") + ) return datasets + @pytest.fixture def predictor_collection() -> PredictorCollection: - predictors = FakePredictorCollection(UUID('6b608f78-e341-422c-8076-35adc8828545'), FakeSession()) + predictors = FakePredictorCollection( + UUID("6b608f78-e341-422c-8076-35adc8828545"), FakeSession() + ) # Adding a few predictors in the collection to have something to update for i in range(0, 5): pred = GraphPredictor( name=f"resource {i}", description="", - predictors=[AutoMLPredictor(name="", description="", inputs=[], outputs=[])] + predictors=[ + AutoMLPredictor(name="", description="", inputs=[], outputs=[]) + ], ) predictors.register(pred) @@ -124,7 +146,9 @@ def predictor_collection() -> PredictorCollection: pred = GraphPredictor( name=f"resource {i}", description="", - predictors=[AutoMLPredictor(name="", description="", inputs=[], outputs=[])] + predictors=[ + AutoMLPredictor(name="", description="", inputs=[], outputs=[]) + ], ) predictors.register(pred) return predictors @@ -150,9 +174,15 @@ def test_find_collection_exist_multiple(fake_collection): def test_get_by_name_or_create_no_exist(fake_collection): # test when name doesn't exist - default_provider = lambda: 
fake_collection.register(ProcessSpec("New Resource")) + def default_provider(): + return fake_collection.register(ProcessSpec("New Resource")) + old_resource_count = len(list(fake_collection.list())) - result = get_by_name_or_create(collection=fake_collection, name="New Resource", default_provider=default_provider) + result = get_by_name_or_create( + collection=fake_collection, + name="New Resource", + default_provider=default_provider, + ) new_resource_count = len(list(fake_collection.list())) assert result.name == "New Resource" assert new_resource_count == old_resource_count + 1 @@ -161,9 +191,16 @@ def test_get_by_name_or_create_no_exist(fake_collection): def test_get_by_name_or_create_exist(fake_collection): # test when name exists resource_name = "resource 2" - default_provider = lambda: fake_collection.register(ProcessSpec("New Resource")) + + def default_provider(): + return fake_collection.register(ProcessSpec("New Resource")) + old_resource_count = len(list(fake_collection.list())) - result = get_by_name_or_create(collection=fake_collection, name=resource_name, default_provider=default_provider) + result = get_by_name_or_create( + collection=fake_collection, + name=resource_name, + default_provider=default_provider, + ) new_resource_count = len(list(fake_collection.list())) assert result.name == resource_name assert new_resource_count == old_resource_count @@ -204,14 +241,18 @@ def test_find_or_create_team_exist(team_collection): def test_find_or_create_raise_error_team_no_exist(team_collection): # test when team doesn't exist and raise_error flag is on with pytest.raises(ValueError): - find_or_create_team(team_collection=team_collection(), team_name=absent_name, raise_error=True) + find_or_create_team( + team_collection=team_collection(), team_name=absent_name, raise_error=True + ) def test_find_or_create_project_no_exist(project_collection): # test when project doesn't exist collection = project_collection() old_project_count = len(list(collection.list())) - result = find_or_create_project(project_collection=collection, project_name=absent_name) + result = find_or_create_project( + project_collection=collection, project_name=absent_name + ) new_project_count = len(list(collection.list())) assert result.name == absent_name assert new_project_count == old_project_count + 1 @@ -221,7 +262,9 @@ def test_find_or_create_project_exist(project_collection): # test when project exists collection = project_collection() old_project_count = len(list(collection.list())) - result = find_or_create_project(project_collection=collection, project_name="project 2") + result = find_or_create_project( + project_collection=collection, project_name="project 2" + ) new_project_count = len(list(collection.list())) assert result.name == "project 2" assert new_project_count == old_project_count @@ -231,7 +274,9 @@ def test_find_or_create_project_exist_no_search(project_collection): # test when project exists collection = project_collection(False) old_project_count = len(list(collection.list())) - result = find_or_create_project(project_collection=collection, project_name="project 2") + result = find_or_create_project( + project_collection=collection, project_name="project 2" + ) new_project_count = len(list(collection.list())) assert result.name == "project 2" assert new_project_count == old_project_count @@ -240,20 +285,28 @@ def test_find_or_create_project_exist_no_search(project_collection): def test_find_or_create_project_exist_multiple(project_collection): # test when project exists multiple times 
with pytest.raises(ValueError): - find_or_create_project(project_collection=project_collection(), project_name=duplicate_name) + find_or_create_project( + project_collection=project_collection(), project_name=duplicate_name + ) def test_find_or_create_raise_error_project_no_exist(project_collection): # test when project doesn't exist and raise_error flag is on with pytest.raises(ValueError): - find_or_create_project(project_collection=project_collection(), project_name=absent_name, raise_error=True) + find_or_create_project( + project_collection=project_collection(), + project_name=absent_name, + raise_error=True, + ) def test_find_or_create_raise_error_project_exist(project_collection): # test when project exists and raise_error flag is on collection = project_collection() old_project_count = len(list(collection.list())) - result = find_or_create_project(project_collection=collection, project_name="project 3", raise_error=True) + result = find_or_create_project( + project_collection=collection, project_name="project 3", raise_error=True + ) new_project_count = len(list(collection.list())) assert result.name == "project 3" assert new_project_count == old_project_count @@ -262,7 +315,11 @@ def test_find_or_create_raise_error_project_exist(project_collection): def test_find_or_create_raise_error_project_exist_multiple(project_collection): # test when project exists multiple times and raise_error flag is on with pytest.raises(ValueError): - find_or_create_project(project_collection=project_collection(), project_name=duplicate_name, raise_error=True) + find_or_create_project( + project_collection=project_collection(), + project_name=duplicate_name, + raise_error=True, + ) def test_find_or_create_project_no_team(project_collection): @@ -275,7 +332,9 @@ def test_find_or_create_project_no_team(project_collection): def test_find_or_create_dataset_no_exist(dataset_collection): # test when dataset doesn't exist old_dataset_count = len(list(dataset_collection.list())) - result = find_or_create_dataset(dataset_collection=dataset_collection, dataset_name=absent_name) + result = find_or_create_dataset( + dataset_collection=dataset_collection, dataset_name=absent_name + ) new_dataset_count = len(list(dataset_collection.list())) assert result.name == absent_name assert new_dataset_count == old_dataset_count + 1 @@ -284,7 +343,9 @@ def test_find_or_create_dataset_no_exist(dataset_collection): def test_find_or_create_dataset_exist(dataset_collection): # test when dataset exists old_dataset_count = len(list(dataset_collection.list())) - result = find_or_create_dataset(dataset_collection=dataset_collection, dataset_name="dataset 2") + result = find_or_create_dataset( + dataset_collection=dataset_collection, dataset_name="dataset 2" + ) new_dataset_count = len(list(dataset_collection.list())) assert result.name == "dataset 2" assert new_dataset_count == old_dataset_count @@ -293,19 +354,29 @@ def test_find_or_create_dataset_exist(dataset_collection): def test_find_or_create_dataset_exist_multiple(dataset_collection): # test when dataset exists multiple times with pytest.raises(ValueError): - find_or_create_dataset(dataset_collection=dataset_collection, dataset_name=duplicate_name) + find_or_create_dataset( + dataset_collection=dataset_collection, dataset_name=duplicate_name + ) def test_find_or_create_dataset_raise_error_no_exist(dataset_collection): # test when dataset doesn't exist and raise_error flag is on with pytest.raises(ValueError): - find_or_create_dataset(dataset_collection=dataset_collection, 
dataset_name=absent_name, raise_error=True) + find_or_create_dataset( + dataset_collection=dataset_collection, + dataset_name=absent_name, + raise_error=True, + ) def test_find_or_create_dataset_raise_error_exist(dataset_collection): # test when dataset exists and raise_error flag is on old_dataset_count = len(list(dataset_collection.list())) - result = find_or_create_dataset(dataset_collection=dataset_collection, dataset_name="dataset 3", raise_error=True) + result = find_or_create_dataset( + dataset_collection=dataset_collection, + dataset_name="dataset 3", + raise_error=True, + ) new_dataset_count = len(list(dataset_collection.list())) assert result.name == "dataset 3" assert new_dataset_count == old_dataset_count @@ -314,33 +385,41 @@ def test_find_or_create_dataset_raise_error_exist(dataset_collection): def test_find_or_create_dataset_raise_error_exist_multiple(dataset_collection): # test when dataset exists multiple times and raise_error flag is on with pytest.raises(ValueError): - find_or_create_dataset(dataset_collection=dataset_collection, dataset_name=duplicate_name, raise_error=True) + find_or_create_dataset( + dataset_collection=dataset_collection, + dataset_name=duplicate_name, + raise_error=True, + ) def test_create_or_update_none_found(predictor_collection): # test when resource doesn't exist with listed name and check if new one is created assert not [r for r in list(predictor_collection.list()) if r.name == absent_name] - aml = AutoMLPredictor(name=absent_name, description='', inputs=[], outputs=[]) - pred = GraphPredictor(name=absent_name, description='', predictors=[aml]) - #verify that the returned object is updated + aml = AutoMLPredictor(name=absent_name, description="", inputs=[], outputs=[]) + pred = GraphPredictor(name=absent_name, description="", predictors=[aml]) + # verify that the returned object is updated returned_pred = create_or_update(collection=predictor_collection, resource=pred) assert returned_pred.uid == pred.uid assert returned_pred.name == pred.name assert returned_pred.description == pred.description - #verify that the collection is also updated + # verify that the collection is also updated assert any([r for r in list(predictor_collection.list()) if r.name == absent_name]) def test_create_or_update_unique_found(predictor_collection): # test when there is a single unique resource that exists with the listed name and update - aml = AutoMLPredictor(name="", description='', inputs=[], outputs=[]) - pred = GraphPredictor(name="resource 4", description="I am updated!", predictors=[aml]) - #verify that the returned object is updated + aml = AutoMLPredictor(name="", description="", inputs=[], outputs=[]) + pred = GraphPredictor( + name="resource 4", description="I am updated!", predictors=[aml] + ) + # verify that the returned object is updated returned_pred = create_or_update(collection=predictor_collection, resource=pred) assert returned_pred.name == pred.name assert returned_pred.description == pred.description - #verify that the collection is also updated - updated_pred = [r for r in list(predictor_collection.list()) if r.name == "resource 4"][0] + # verify that the collection is also updated + updated_pred = [ + r for r in list(predictor_collection.list()) if r.name == "resource 4" + ][0] assert updated_pred.description == "I am updated!" 
@@ -360,20 +439,36 @@ def test_create_or_update_unique_found_design_workflow(session): dw2_dict, # Return the updated design workflow ) - collection = LocalDesignWorkflowCollection(project_id=uuid4(), session=session, branch_root_id=root_id, branch_version=version) + collection = LocalDesignWorkflowCollection( + project_id=uuid4(), + session=session, + branch_root_id=root_id, + branch_version=version, + ) dw2 = collection.build(dw2_dict) - #verify that the returned object is updated + # verify that the returned object is updated returned_dw = create_or_update(collection=collection, resource=dw2) assert returned_dw.name == dw2.name - assert returned_dw.branch_root_id == collection.branch_root_id == UUID(branch_data["metadata"]["root_id"]) - assert returned_dw.branch_version == collection.branch_version == branch_data["metadata"]["version"] + assert ( + returned_dw.branch_root_id + == collection.branch_root_id + == UUID(branch_data["metadata"]["root_id"]) + ) + assert ( + returned_dw.branch_version + == collection.branch_version + == branch_data["metadata"]["version"] + ) + def test_create_or_update_raise_error_multiple_found(predictor_collection): # test when there are multiple resources that exist with the same listed name and raise error # resource 1 is not a unique name - aml = AutoMLPredictor(name="", description='', inputs=[], outputs=[]) - pred = GraphPredictor(name="resource 1", description="I am updated!", predictors=[aml]) + aml = AutoMLPredictor(name="", description="", inputs=[], outputs=[]) + pred = GraphPredictor( + name="resource 1", description="I am updated!", predictors=[aml] + ) with pytest.raises(ValueError): create_or_update(collection=predictor_collection, resource=pred) diff --git a/tests/seeding/test_sort_gems.py b/tests/seeding/test_sort_gems.py index 72a0137cc..e1b9316b7 100644 --- a/tests/seeding/test_sort_gems.py +++ b/tests/seeding/test_sort_gems.py @@ -14,18 +14,22 @@ def test_no_templates(): def test_no_data_objects(): - objs = [PropertyTemplate("pt", bounds=CategoricalBounds()), - ConditionTemplate("ct", bounds=CategoricalBounds())] + objs = [ + PropertyTemplate("pt", bounds=CategoricalBounds()), + ConditionTemplate("ct", bounds=CategoricalBounds()), + ] templates, data_objects = split_templates_from_objects(objs) assert len(templates) == 2 assert len(data_objects) == 0 def test_both_present(): - objs = [ProcessSpec("ps"), - PropertyTemplate("pt", bounds=CategoricalBounds()), - MeasurementSpec("ms"), - ConditionTemplate("ct", bounds=CategoricalBounds())] + objs = [ + ProcessSpec("ps"), + PropertyTemplate("pt", bounds=CategoricalBounds()), + MeasurementSpec("ms"), + ConditionTemplate("ct", bounds=CategoricalBounds()), + ] templates, data_objects = split_templates_from_objects(objs) assert len(templates) == 2 assert len(data_objects) == 2 diff --git a/tests/serialization/__init__.py b/tests/serialization/__init__.py index 614ec6eae..4085a882f 100644 --- a/tests/serialization/__init__.py +++ b/tests/serialization/__init__.py @@ -2,7 +2,7 @@ def valid_serialization_output(valid_data): - exclude_fields = ['status', 'status_detail'] + exclude_fields = ["status", "status_detail"] return {x: y for x, y in valid_data.items() if x not in exclude_fields} @@ -16,7 +16,7 @@ def design_space_serialization_check(data, moduleClass): """ module = moduleClass.build(data) serialized = module.dump() - assert serialized == valid_serialization_output(data)['data'] + assert serialized == valid_serialization_output(data)["data"] def predictor_serialization_check(json, module_class): 
diff --git a/tests/serialization/test_attribute_template.py b/tests/serialization/test_attribute_template.py index 673d650ba..8a51d0ab5 100644 --- a/tests/serialization/test_attribute_template.py +++ b/tests/serialization/test_attribute_template.py @@ -1,5 +1,5 @@ """Tests of the attribute template schema.""" -import pytest + from citrine.resources.condition_template import ConditionTemplate from citrine.resources.parameter_template import ParameterTemplate from citrine.resources.property_template import PropertyTemplate @@ -11,8 +11,10 @@ def test_condition_template(): """Test creation and serde of condition templates.""" - bounds = RealBounds(2.5, 10.0, default_units='cm') - template = ConditionTemplate("Chamber width", tags=[], bounds=bounds, description="width of chamber") + bounds = RealBounds(2.5, 10.0, default_units="cm") + template = ConditionTemplate( + "Chamber width", tags=[], bounds=bounds, description="width of chamber" + ) assert template.uids is not None # uids should be added automatically # Take template through a serde cycle and ensure that it is unchanged @@ -25,13 +27,15 @@ def test_condition_template(): def test_parameter_template(): """Test creation and serde of parameter templates.""" bounds = IntegerBounds(-3, 8) - template = ParameterTemplate("Position knob", bounds=bounds, tags=["Tag1", "A::B::C"]) + template = ParameterTemplate( + "Position knob", bounds=bounds, tags=["Tag1", "A::B::C"] + ) assert template.uids is not None # uids should be added automatically assert ParameterTemplate.build(template.dump()) == template def test_property_template(): """Test creation and serde of property templates.""" - bounds = CategoricalBounds(['solid', 'liquid', 'gas']) - template = PropertyTemplate("State", bounds=bounds, uids={'my_id': '0'}) + bounds = CategoricalBounds(["solid", "liquid", "gas"]) + template = PropertyTemplate("State", bounds=bounds, uids={"my_id": "0"}) assert PropertyTemplate.build(template.dump()) == template diff --git a/tests/serialization/test_constraints.py b/tests/serialization/test_constraints.py index 8bf647a35..76a806a40 100644 --- a/tests/serialization/test_constraints.py +++ b/tests/serialization/test_constraints.py @@ -1,18 +1,19 @@ """Tests for citrine.informatics.constraints.""" + import pytest -from citrine.informatics.constraints import Constraint, ScalarRangeConstraint, \ - AcceptableCategoriesConstraint +from citrine.informatics.constraints import ( + Constraint, + ScalarRangeConstraint, + AcceptableCategoriesConstraint, +) @pytest.fixture def scalar_range_constraint() -> ScalarRangeConstraint: """Build a ScalarRangeConstraint.""" return ScalarRangeConstraint( - descriptor_key='z', - lower_bound=1.0, - upper_bound=10.0, - lower_inclusive=False + descriptor_key="z", lower_bound=1.0, upper_bound=10.0, lower_inclusive=False ) @@ -20,20 +21,19 @@ def scalar_range_constraint() -> ScalarRangeConstraint: def acceptable_categories_constraint() -> AcceptableCategoriesConstraint: """Build an AcceptableCategoriesConstraint.""" return AcceptableCategoriesConstraint( - descriptor_key='x', - acceptable_categories=['y', 'z'] + descriptor_key="x", acceptable_categories=["y", "z"] ) def test_scalar_range_dumps(scalar_range_constraint): """Ensure values are persisted through deser.""" result = scalar_range_constraint.dump() - assert result['type'] == 'ScalarRange' - assert result['descriptor_key'] == 'z' - assert result['min'] == 1.0 - assert result['max'] == 10.0 - assert not result['min_inclusive'] - assert result['max_inclusive'] + assert result["type"] == 
"ScalarRange" + assert result["descriptor_key"] == "z" + assert result["min"] == 1.0 + assert result["max"] == 10.0 + assert not result["min_inclusive"] + assert result["max_inclusive"] def test_get_scalar_range_type(scalar_range_constraint): @@ -45,9 +45,9 @@ def test_get_scalar_range_type(scalar_range_constraint): def test_categorical_dumps(acceptable_categories_constraint): """Ensure values are persisted through deser.""" result = acceptable_categories_constraint.dump() - assert result['type'] == 'AcceptableCategoriesConstraint' - assert result['descriptor_key'] == 'x' - assert result['acceptable_classes'] == ['y', 'z'] + assert result["type"] == "AcceptableCategoriesConstraint" + assert result["descriptor_key"] == "x" + assert result["acceptable_classes"] == ["y", "z"] def test_get_categorical_type(acceptable_categories_constraint): diff --git a/tests/serialization/test_dataset.py b/tests/serialization/test_dataset.py index f9fa33158..74f87b73a 100644 --- a/tests/serialization/test_dataset.py +++ b/tests/serialization/test_dataset.py @@ -1,4 +1,5 @@ """Tests of the Dataset schema.""" + import pytest from uuid import uuid4, UUID from citrine.resources.dataset import Dataset @@ -10,10 +11,10 @@ def valid_data(): """Return valid data used for these tests.""" return dict( id=str(uuid4()), - name='Dataset 1', + name="Dataset 1", unique_name=None, - summary='The first dataset', - description='A dummy dataset for performing unit tests', + summary="The first dataset", + description="A dummy dataset for performing unit tests", deleted=True, created_by=None, updated_by=None, @@ -21,19 +22,19 @@ def valid_data(): create_time=1559933807392, update_time=None, delete_time=None, - public=False + public=False, ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Dataset looks sane.""" dataset: Dataset = Dataset.build(valid_data) - assert dataset.uid == UUID(valid_data['id']) - assert dataset.name == 'Dataset 1' - assert dataset.summary == 'The first dataset' - assert dataset.description == 'A dummy dataset for performing unit tests' + assert dataset.uid == UUID(valid_data["id"]) + assert dataset.name == "Dataset 1" + assert dataset.summary == "The first dataset" + assert dataset.description == "A dummy dataset for performing unit tests" assert dataset.deleted - assert dataset.create_time == arrow.get(valid_data['create_time'] / 1000).datetime + assert dataset.create_time == arrow.get(valid_data["create_time"] / 1000).datetime def test_serialization(valid_data): diff --git a/tests/serialization/test_descriptors.py b/tests/serialization/test_descriptors.py index d5c41f178..dce0619e2 100644 --- a/tests/serialization/test_descriptors.py +++ b/tests/serialization/test_descriptors.py @@ -1,4 +1,5 @@ """Tests for citrine.informatics.descriptors serialization.""" + import pytest from citrine.informatics.descriptors import RealDescriptor, Descriptor @@ -8,9 +9,9 @@ def valid_data(): """Produce valid descriptor data.""" return dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ) @@ -19,8 +20,8 @@ def valid_data(): def test_simple_deserialization(valid_data): """Ensure a deserialized RealDescriptor looks sane.""" descriptor = RealDescriptor.build(valid_data) - assert descriptor.key == 'alpha' - assert descriptor.units == '' + assert descriptor.key == "alpha" + assert descriptor.units == "" assert descriptor.lower_bound == 5.0 assert descriptor.upper_bound == 10.0 @@ -28,8 +29,8 @@ def 
test_simple_deserialization(valid_data): def test_polymorphic_deserialization(valid_data): """Ensure a polymorphically deserialized RealDescriptor looks sane.""" descriptor: RealDescriptor = Descriptor.build(valid_data) - assert descriptor.key == 'alpha' - assert descriptor.units == '' + assert descriptor.key == "alpha" + assert descriptor.units == "" assert descriptor.lower_bound == 5.0 assert descriptor.upper_bound == 10.0 diff --git a/tests/serialization/test_design_spaces.py b/tests/serialization/test_design_spaces.py index 515a61c94..60bdcd572 100644 --- a/tests/serialization/test_design_spaces.py +++ b/tests/serialization/test_design_spaces.py @@ -1,34 +1,43 @@ """Tests for citrine.informatics.design_spaces serialization.""" -from copy import copy, deepcopy -from uuid import UUID + +from copy import deepcopy import pytest -from . import design_space_serialization_check, valid_serialization_output from citrine.informatics.constraints import IngredientCountConstraint -from citrine.informatics.descriptors import CategoricalDescriptor, RealDescriptor, ChemicalFormulaDescriptor,\ - FormulationDescriptor -from citrine.informatics.design_spaces import DesignSpace, ProductDesignSpace, EnumeratedDesignSpace,\ - FormulationDesignSpace +from citrine.informatics.descriptors import ( + CategoricalDescriptor, + ChemicalFormulaDescriptor, + FormulationDescriptor, + RealDescriptor, +) +from citrine.informatics.design_spaces import ( + DesignSpace, + EnumeratedDesignSpace, + FormulationDesignSpace, + ProductDesignSpace, +) from citrine.informatics.dimensions import ContinuousDimension, EnumeratedDimension +from . import design_space_serialization_check + def test_product_deserialization(valid_product_design_space_data): """Ensure that a deserialized ProductDesignSpace looks sane.""" for designSpaceClass in [ProductDesignSpace, DesignSpace]: data = deepcopy(valid_product_design_space_data) design_space: ProductDesignSpace = designSpaceClass.build(data) - assert design_space.name == 'my design space' - assert design_space.description == 'does some things' - assert type(design_space.dimensions[0]) == ContinuousDimension + assert design_space.name == "my design space" + assert design_space.description == "does some things" + assert type(design_space.dimensions[0]) is ContinuousDimension assert design_space.dimensions[0].lower_bound == 6.0 - assert type(design_space.dimensions[1]) == EnumeratedDimension - assert design_space.dimensions[1].values == ['red'] - assert type(design_space.subspaces[0]) == FormulationDesignSpace + assert type(design_space.dimensions[1]) is EnumeratedDimension + assert design_space.dimensions[1].values == ["red"] + assert type(design_space.subspaces[0]) is FormulationDesignSpace assert design_space.subspaces[0].uid is None - assert type(design_space.subspaces[1]) == FormulationDesignSpace + assert type(design_space.subspaces[1]) is FormulationDesignSpace assert design_space.subspaces[1].uid is None - assert design_space.subspaces[1].ingredients == {'baz'} + assert design_space.subspaces[1].ingredients == {"baz"} def test_product_serialization(valid_product_design_space_data): @@ -36,9 +45,15 @@ def test_product_serialization(valid_product_design_space_data): original_data = deepcopy(valid_product_design_space_data) design_space = ProductDesignSpace.build(valid_product_design_space_data) serialized = design_space.dump() - serialized['id'] = valid_product_design_space_data['id'] - assert serialized['instance']['subspaces'][0] == original_data['data']['instance']['subspaces'][0] - 
assert serialized['instance']['subspaces'][1] == original_data['data']['instance']['subspaces'][1] + serialized["id"] = valid_product_design_space_data["id"] + assert ( + serialized["instance"]["subspaces"][0] + == original_data["data"]["instance"]["subspaces"][0] + ) + assert ( + serialized["instance"]["subspaces"][1] + == original_data["data"]["instance"]["subspaces"][1] + ) def test_enumerated_deserialization(valid_enumerated_design_space_data): @@ -47,47 +62,59 @@ def test_enumerated_deserialization(valid_enumerated_design_space_data): and polymorphically (using DesignSpace) """ for designSpaceClass in [DesignSpace, EnumeratedDesignSpace]: - design_space: EnumeratedDesignSpace = designSpaceClass.build(valid_enumerated_design_space_data) - assert design_space.name == 'my enumerated design space' - assert design_space.description == 'enumerates some things' + design_space: EnumeratedDesignSpace = designSpaceClass.build( + valid_enumerated_design_space_data + ) + assert design_space.name == "my enumerated design space" + assert design_space.description == "enumerates some things" assert len(design_space.descriptors) == 3 real, categorical, formula = design_space.descriptors - assert type(real) == RealDescriptor - assert real.key == 'x' - assert real.units == '' + assert type(real) is RealDescriptor + assert real.key == "x" + assert real.units == "" assert real.lower_bound == 1.0 assert real.upper_bound == 2.0 - assert type(categorical) == CategoricalDescriptor - assert categorical.key == 'color' - assert categorical.categories == {'red', 'green', 'blue'} + assert type(categorical) is CategoricalDescriptor + assert categorical.key == "color" + assert categorical.categories == {"red", "green", "blue"} - assert type(formula) == ChemicalFormulaDescriptor - assert formula.key == 'formula' + assert type(formula) is ChemicalFormulaDescriptor + assert formula.key == "formula" assert len(design_space.data) == 2 - assert design_space.data[0] == {'x': '1', 'color': 'red', 'formula': 'C44H54Si2'} - assert design_space.data[1] == {'x': '2.0', 'color': 'green', 'formula': 'V2O3'} + assert design_space.data[0] == { + "x": "1", + "color": "red", + "formula": "C44H54Si2", + } + assert design_space.data[1] == {"x": "2.0", "color": "green", "formula": "V2O3"} -def test_enumerated_serialization_data_int_deprecated(valid_enumerated_design_space_data): +def test_enumerated_serialization_data_int_deprecated( + valid_enumerated_design_space_data, +): design_space = EnumeratedDesignSpace.build(valid_enumerated_design_space_data) with pytest.deprecated_call(): - design_space.data = [dict(x=1, color='red', formula='C44H54Si2')] + design_space.data = [dict(x=1, color="red", formula="C44H54Si2")] -def test_enumerated_serialization_data_float_deprecated(valid_enumerated_design_space_data): +def test_enumerated_serialization_data_float_deprecated( + valid_enumerated_design_space_data, +): design_space = EnumeratedDesignSpace.build(valid_enumerated_design_space_data) with pytest.deprecated_call(): - design_space.data = [dict(x=1.0, color='red', formula='C44H54Si2')] + design_space.data = [dict(x=1.0, color="red", formula="C44H54Si2")] def test_enumerated_serialization(valid_enumerated_design_space_data): """Ensure that a serialized EnumeratedDesignSpace looks sane.""" - design_space_serialization_check(valid_enumerated_design_space_data, EnumeratedDesignSpace) + design_space_serialization_check( + valid_enumerated_design_space_data, EnumeratedDesignSpace + ) def 
test_formulation_deserialization(valid_formulation_design_space_data): @@ -97,19 +124,21 @@ def test_formulation_deserialization(valid_formulation_design_space_data): """ expected_descriptor = FormulationDescriptor.hierarchical() expected_constraint = IngredientCountConstraint( - formulation_descriptor=expected_descriptor, - min=0, - max=1 + formulation_descriptor=expected_descriptor, min=0, max=1 ) for designSpaceClass in [DesignSpace, FormulationDesignSpace]: - design_space: FormulationDesignSpace = designSpaceClass.build(valid_formulation_design_space_data) - assert design_space.name == 'formulation design space' - assert design_space.description == 'formulates some things' + design_space: FormulationDesignSpace = designSpaceClass.build( + valid_formulation_design_space_data + ) + assert design_space.name == "formulation design space" + assert design_space.description == "formulates some things" assert design_space.formulation_descriptor.key == expected_descriptor.key - assert design_space.ingredients == {'foo'} - assert design_space.labels == {'bar': {'foo'}} + assert design_space.ingredients == {"foo"} + assert design_space.labels == {"bar": {"foo"}} assert len(design_space.constraints) == 1 - actual_constraint: IngredientCountConstraint = next(iter(design_space.constraints)) + actual_constraint: IngredientCountConstraint = next( + iter(design_space.constraints) + ) assert actual_constraint.formulation_descriptor == expected_descriptor assert actual_constraint.min == expected_constraint.min assert actual_constraint.max == expected_constraint.max @@ -118,4 +147,6 @@ def test_formulation_deserialization(valid_formulation_design_space_data): def test_formulation_serialization(valid_formulation_design_space_data): """Ensure that a serialized FormulationDesignSpace looks sane.""" - design_space_serialization_check(valid_formulation_design_space_data, FormulationDesignSpace) + design_space_serialization_check( + valid_formulation_design_space_data, FormulationDesignSpace + ) diff --git a/tests/serialization/test_dimensions.py b/tests/serialization/test_dimensions.py index 76af408af..8738b47a1 100644 --- a/tests/serialization/test_dimensions.py +++ b/tests/serialization/test_dimensions.py @@ -1,26 +1,29 @@ """Tests for citrine.informatics.dimensions serialization.""" -import uuid import pytest -from citrine.informatics.descriptors import RealDescriptor, CategoricalDescriptor -from citrine.informatics.dimensions import Dimension, ContinuousDimension, EnumeratedDimension +from citrine.informatics.descriptors import CategoricalDescriptor, RealDescriptor +from citrine.informatics.dimensions import ( + ContinuousDimension, + Dimension, + EnumeratedDimension, +) @pytest.fixture def valid_continuous_data(): """Produce valid continuous dimension data.""" return dict( - type='ContinuousDimension', + type="ContinuousDimension", descriptor=dict( - type='Real', - descriptor_key='alpha', - units='', + type="Real", + descriptor_key="alpha", + units="", lower_bound=5.0, upper_bound=10.0, ), lower_bound=6.0, - upper_bound=7.0 + upper_bound=7.0, ) @@ -28,32 +31,32 @@ def valid_continuous_data(): def valid_enumerated_data(): """Produce valid enumerated dimension data.""" return dict( - type='EnumeratedDimension', + type="EnumeratedDimension", descriptor=dict( - type='Categorical', - descriptor_key='color', - descriptor_values=['blue', 'green', 'red'], + type="Categorical", + descriptor_key="color", + descriptor_values=["blue", "green", "red"], ), - list=['red'] + list=["red"], ) def 
test_simple_continuous_deserialization(valid_continuous_data): """Ensure that a deserialized ContinuousDimension looks sane.""" dimension = ContinuousDimension.build(valid_continuous_data) - assert type(dimension) == ContinuousDimension + assert type(dimension) is ContinuousDimension assert dimension.lower_bound == 6.0 assert dimension.upper_bound == 7.0 - assert type(dimension.descriptor) == RealDescriptor + assert type(dimension.descriptor) is RealDescriptor def test_polymorphic_continuous_deserialization(valid_continuous_data): """Ensure that a polymorphically deserialized ContinuousDimension looks sane.""" dimension: ContinuousDimension = Dimension.build(valid_continuous_data) - assert type(dimension) == ContinuousDimension + assert type(dimension) is ContinuousDimension assert dimension.lower_bound == 6.0 assert dimension.upper_bound == 7.0 - assert type(dimension.descriptor) == RealDescriptor + assert type(dimension.descriptor) is RealDescriptor def test_continuous_serialization(valid_continuous_data): @@ -66,17 +69,17 @@ def test_continuous_serialization(valid_continuous_data): def test_simple_enumerated_deserialization(valid_enumerated_data): """Ensure that a deserialized EnumeratedDimension looks sane.""" dimension: EnumeratedDimension = EnumeratedDimension.build(valid_enumerated_data) - assert type(dimension) == EnumeratedDimension - assert dimension.values == ['red'] - assert type(dimension.descriptor) == CategoricalDescriptor + assert type(dimension) is EnumeratedDimension + assert dimension.values == ["red"] + assert type(dimension.descriptor) is CategoricalDescriptor def test_polymorphic_enumerated_deserialization(valid_enumerated_data): """Ensure that a polymorphically deserialized EnumeratedDimension looks sane.""" dimension: EnumeratedDimension = Dimension.build(valid_enumerated_data) - assert type(dimension) == EnumeratedDimension - assert dimension.values == ['red'] - assert type(dimension.descriptor) == CategoricalDescriptor + assert type(dimension) is EnumeratedDimension + assert dimension.values == ["red"] + assert type(dimension.descriptor) is CategoricalDescriptor def test_enumerated_serialization(valid_enumerated_data): diff --git a/tests/serialization/test_file_link.py b/tests/serialization/test_file_link.py index ec4b4ca69..d19b313fd 100644 --- a/tests/serialization/test_file_link.py +++ b/tests/serialization/test_file_link.py @@ -1,18 +1,19 @@ """Tests of FileLink serialization and deserialization.""" + from citrine.resources.file_link import FileLink from tests.utils.factories import FileLinkDataFactory def test_simple_deserialization(): """Ensure that a deserialized File Link looks sane.""" - valid_data = FileLinkDataFactory(url='www.citrine.io', filename='materials.txt') + valid_data = FileLinkDataFactory(url="www.citrine.io", filename="materials.txt") file_link = FileLink.build(valid_data) - assert file_link.url == 'www.citrine.io' - assert file_link.filename == 'materials.txt' + assert file_link.url == "www.citrine.io" + assert file_link.filename == "materials.txt" def test_serialization(): """Ensure that a serialized File Link looks sane.""" - valid_data = FileLinkDataFactory(url='www.citrine.io', filename='materials.txt') + valid_data = FileLinkDataFactory(url="www.citrine.io", filename="materials.txt") file_link = FileLink.build(valid_data) assert file_link.dump() == valid_data diff --git a/tests/serialization/test_gem_table.py b/tests/serialization/test_gem_table.py index 5a9cf1e6f..466a26293 100644 --- a/tests/serialization/test_gem_table.py +++ 
b/tests/serialization/test_gem_table.py @@ -12,14 +12,14 @@ def valid_data(): return dict( id=str(uuid4()), version=randrange(10), - signed_download_url="https://s3.amazonaws.citrine.io/bucketboi" + signed_download_url="https://s3.amazonaws.citrine.io/bucketboi", ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Table looks normal.""" table: GemTable = GemTable.build(valid_data) - assert table.uid == UUID(valid_data['id']) + assert table.uid == UUID(valid_data["id"]) assert table.version == valid_data["version"] assert table.download_url == "https://s3.amazonaws.citrine.io/bucketboi" diff --git a/tests/serialization/test_ingredient_run.py b/tests/serialization/test_ingredient_run.py index 7c294bc5c..2096417b6 100644 --- a/tests/serialization/test_ingredient_run.py +++ b/tests/serialization/test_ingredient_run.py @@ -1,4 +1,5 @@ """Tests of the ingredient run schema.""" + import pytest from uuid import uuid4 @@ -12,43 +13,55 @@ def valid_data(): """Return valid data used for these tests.""" return dict( - uids={'id': str(uuid4())}, + uids={"id": str(uuid4())}, tags=[], notes=None, - material={'type': 'material_run', 'name': 'flour', 'uids': {'id': str(uuid4())}, - 'tags': [], 'file_links': [], 'notes': None, - 'process': None, 'sample_type': 'unknown', 'spec': None, - }, + material={ + "type": "material_run", + "name": "flour", + "uids": {"id": str(uuid4())}, + "tags": [], + "file_links": [], + "notes": None, + "process": None, + "sample_type": "unknown", + "spec": None, + }, process=None, - mass_fraction={'type': 'normal_real', 'mean': 0.5, 'std': 0.1, 'units': 'dimensionless'}, + mass_fraction={ + "type": "normal_real", + "mean": 0.5, + "std": 0.1, + "units": "dimensionless", + }, volume_fraction=None, number_fraction=None, absolute_quantity=None, - name='flour', - labels=['fine', 'bleached'], + name="flour", + labels=["fine", "bleached"], spec=None, file_links=[], - type='ingredient_run' + type="ingredient_run", ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Ingredient Run looks sane.""" ingredient_run: IngredientRun = IngredientRun.build(valid_data) - assert ingredient_run.uids == {'id': valid_data['uids']['id']} + assert ingredient_run.uids == {"id": valid_data["uids"]["id"]} assert ingredient_run.tags == [] assert ingredient_run.notes is None - assert ingredient_run.material.dump() == valid_data['material'] + assert ingredient_run.material.dump() == valid_data["material"] assert ingredient_run.process is None - assert ingredient_run.mass_fraction == NormalReal(0.5, 0.1, '') + assert ingredient_run.mass_fraction == NormalReal(0.5, 0.1, "") assert ingredient_run.volume_fraction is None assert ingredient_run.number_fraction is None assert ingredient_run.absolute_quantity is None - assert ingredient_run.name == 'flour' - assert ingredient_run.labels == ['fine', 'bleached'] + assert ingredient_run.name == "flour" + assert ingredient_run.labels == ["fine", "bleached"] assert ingredient_run.spec is None assert ingredient_run.file_links == [] - assert ingredient_run.typ == 'ingredient_run' + assert ingredient_run.typ == "ingredient_run" def test_serialization(valid_data): @@ -64,8 +77,10 @@ def test_material_attachment(): Check that the ingredient can be built, and that the connection survives ser/de. 
""" - flour = MaterialRun("flour", sample_type='unknown') - flour_ingredient = IngredientRun(material=flour, absolute_quantity=NominalReal(500, 'g')) + flour = MaterialRun("flour", sample_type="unknown") + flour_ingredient = IngredientRun( + material=flour, absolute_quantity=NominalReal(500, "g") + ) flour_ingredient_copy = IngredientRun.build(flour_ingredient.dump()) assert flour_ingredient_copy == flour_ingredient diff --git a/tests/serialization/test_material_run.py b/tests/serialization/test_material_run.py index 0ed9acee8..1d686902c 100644 --- a/tests/serialization/test_material_run.py +++ b/tests/serialization/test_material_run.py @@ -1,44 +1,43 @@ """Tests of the Material Run schema.""" + import json -from typing import Optional, Iterable +from typing import Iterable, Optional -from citrine.resources.material_run import MaterialRun -from citrine.resources.material_spec import MaterialSpec -from citrine.resources.measurement_spec import MeasurementSpec -from citrine.resources.process_run import ProcessRun -from citrine.resources.ingredient_run import IngredientRun -from citrine.resources.ingredient_spec import IngredientSpec -from citrine.resources.measurement_run import MeasurementRun -from gemd.entity.link_by_uid import LinkByUID -from gemd.json import GEMDJson from gemd.demo.cake import make_cake -from gemd.entity.object import MeasurementRun as GEMDMeasurementRun +from gemd.entity.file_link import FileLink +from gemd.entity.link_by_uid import LinkByUID from gemd.entity.object import MaterialRun as GEMDMaterialRun from gemd.entity.object import MaterialSpec as GEMDMaterialSpec -from gemd.entity.object import MeasurementSpec as GEMDMeasurementSpec -from gemd.entity.object import ProcessSpec as GEMDProcessSpec from gemd.entity.object import ProcessRun as GEMDProcessRun -from gemd.entity.object.ingredient_spec import IngredientSpec as GEMDIngredientSpec -from gemd.entity.object.ingredient_run import IngredientRun as GEMDIngredientRun -from gemd.entity.file_link import FileLink +from gemd.entity.object import ProcessSpec as GEMDProcessSpec +from gemd.json import GEMDJson +from citrine.resources.ingredient_run import IngredientRun +from citrine.resources.ingredient_spec import IngredientSpec +from citrine.resources.material_run import MaterialRun +from citrine.resources.material_spec import MaterialSpec +from citrine.resources.measurement_run import MeasurementRun +from citrine.resources.measurement_spec import MeasurementSpec +from citrine.resources.process_run import ProcessRun from tests.utils.factories import MaterialRunDataFactory def test_simple_deserialization(): """Ensure that a deserialized Material Run looks sane.""" - valid_data: dict = MaterialRunDataFactory(name='Cake 1', notes=None, spec=None) + valid_data: dict = MaterialRunDataFactory(name="Cake 1", notes=None, spec=None) material_run: MaterialRun = MaterialRun.build(valid_data) assert isinstance(material_run, MaterialRun) - assert material_run.uids == valid_data['uids'] - assert material_run.name == valid_data['name'] - assert material_run.tags == valid_data['tags'] + assert material_run.uids == valid_data["uids"] + assert material_run.name == valid_data["name"] + assert material_run.tags == valid_data["tags"] assert material_run.notes is None - assert material_run.process == LinkByUID.build(valid_data['process']) - assert material_run.sample_type == valid_data['sample_type'] + assert material_run.process == LinkByUID.build(valid_data["process"]) + assert material_run.sample_type == valid_data["sample_type"] assert 
material_run.template is None assert material_run.spec is None - assert material_run.file_links == [FileLink.build(x) for x in valid_data['file_links']] + assert material_run.file_links == [ + FileLink.build(x) for x in valid_data["file_links"] + ] def test_serialization(): @@ -51,15 +50,15 @@ def test_serialization(): def test_process_attachment(): """Test that a process can be attached to a material, and that the connection survives serde""" - cake = MaterialRun('Final cake') - cake.process = ProcessRun('Icing', uids={'id': '12345'}) + cake = MaterialRun("Final cake") + cake.process = ProcessRun("Icing", uids={"id": "12345"}) cake_data = cake.dump() cake_copy = MaterialRun.build(cake_data).as_dict() - assert cake_copy['name'] == cake.name - assert cake_copy['uids'] == cake.uids - assert cake.process.uids['id'] == cake_copy['process'].uids['id'] + assert cake_copy["name"] == cake.name + assert cake_copy["uids"] == cake.uids + assert cake.process.uids["id"] == cake_copy["process"].uids["id"] reconstituted_cake = MaterialRun.build(cake_copy) assert isinstance(reconstituted_cake, MaterialRun) @@ -74,76 +73,80 @@ def make_ingredient(material: MaterialRun): return IngredientRun(material=material) icing = ProcessRun(name="Icing") - cake = MaterialRun(name='Final cake', process=icing) + cake = MaterialRun(name="Final cake", process=icing) - cake.process.ingredients.append(make_ingredient(MaterialRun('Baked Cake'))) - cake.process.ingredients.append(make_ingredient(MaterialRun('Frosting'))) + cake.process.ingredients.append(make_ingredient(MaterialRun("Baked Cake"))) + cake.process.ingredients.append(make_ingredient(MaterialRun("Frosting"))) baked = cake.process.ingredients[0].material - baked.process = ProcessRun(name='Baking') - baked.process.ingredients.append(make_ingredient(MaterialRun('Batter'))) + baked.process = ProcessRun(name="Baking") + baked.process.ingredients.append(make_ingredient(MaterialRun("Batter"))) batter = baked.process.ingredients[0].material - batter.process = ProcessRun(name='Mixing batter') + batter.process = ProcessRun(name="Mixing batter") - batter.process.ingredients.append(make_ingredient(material=MaterialRun('Butter'))) - batter.process.ingredients.append(make_ingredient(material=MaterialRun('Sugar'))) - batter.process.ingredients.append(make_ingredient(material=MaterialRun('Flour'))) - batter.process.ingredients.append(make_ingredient(material=MaterialRun('Milk'))) + batter.process.ingredients.append(make_ingredient(material=MaterialRun("Butter"))) + batter.process.ingredients.append(make_ingredient(material=MaterialRun("Sugar"))) + batter.process.ingredients.append(make_ingredient(material=MaterialRun("Flour"))) + batter.process.ingredients.append(make_ingredient(material=MaterialRun("Milk"))) cake.dump() def test_measurement_material_connection_rehydration(): """Test that fully-linked GEMD object can be built as fully-linked Citrine-python object.""" - starting_mat_spec = GEMDMaterialSpec("starting material") - starting_mat = GEMDMaterialRun("starting material", spec=starting_mat_spec) - meas_spec = GEMDMeasurementSpec("measurement spec") - meas1 = GEMDMeasurementRun("measurement on starting material", - spec=meas_spec, material=starting_mat) process_spec = GEMDProcessSpec("Transformative process") process = GEMDProcessRun("Transformative process", spec=process_spec) - ingredient_spec = GEMDIngredientSpec(name="ingredient", material=starting_mat_spec, - process=process_spec) - ingredient = GEMDIngredientRun(material=starting_mat, process=process, 
spec=ingredient_spec) ending_mat_spec = GEMDMaterialSpec("ending material", process=process_spec) - ending_mat = GEMDMaterialRun("ending material", process=process, spec=ending_mat_spec) - meas2 = GEMDMeasurementRun("measurement on ending material", - spec=meas_spec, material=ending_mat) + ending_mat = GEMDMaterialRun( + "ending material", process=process, spec=ending_mat_spec + ) copy = MaterialRun.build(json.loads(GEMDJson().dumps(ending_mat))) assert isinstance(copy, MaterialRun), "copy of ending_mat should be a MaterialRun" assert len(copy.measurements) == 1, "copy of ending_mat should have one measurement" - assert isinstance(copy.measurements[0], MeasurementRun), \ + assert isinstance(copy.measurements[0], MeasurementRun), ( "copy of ending_mat should have a measurement that is a MeasurementRun" - assert isinstance(copy.measurements[0].spec, MeasurementSpec), \ + ) + assert isinstance(copy.measurements[0].spec, MeasurementSpec), ( "copy of ending_mat should have a measurement that has a spec that is a MeasurementSpec" - assert isinstance(copy.process, ProcessRun), "copy of ending_mat should have a process" - assert len(copy.process.ingredients) == 1, \ + ) + assert isinstance(copy.process, ProcessRun), ( + "copy of ending_mat should have a process" + ) + assert len(copy.process.ingredients) == 1, ( "copy of ending_mat should have a process with one ingredient" + ) assert isinstance(copy.spec, MaterialSpec), "copy of ending_mat should have a spec" - assert len(copy.spec.process.ingredients) == 1, \ + assert len(copy.spec.process.ingredients) == 1, ( "copy of ending_mat should have a spec with a process that has one ingredient" - assert isinstance(copy.process.spec.ingredients[0], IngredientSpec), \ - "copy of ending_mat should have a spec with a process that has an ingredient " \ + ) + assert isinstance(copy.process.spec.ingredients[0], IngredientSpec), ( + "copy of ending_mat should have a spec with a process that has an ingredient " "that is an IngredientSpec" + ) copy_ingredient = copy.process.ingredients[0] - assert isinstance(copy_ingredient, IngredientRun), \ + assert isinstance(copy_ingredient, IngredientRun), ( "copy of ending_mat should have a process with an ingredient that is an IngredientRun" - assert isinstance(copy_ingredient.material, MaterialRun), \ + ) + assert isinstance(copy_ingredient.material, MaterialRun), ( "copy of ending_mat should have a process with an ingredient that links to a MaterialRun" - assert len(copy_ingredient.material.measurements) == 1, \ - "copy of ending_mat should have a process with an ingredient derived from a material " \ + ) + assert len(copy_ingredient.material.measurements) == 1, ( + "copy of ending_mat should have a process with an ingredient derived from a material " "that has one measurement performed on it" - assert isinstance(copy_ingredient.material.measurements[0], MeasurementRun), \ - "copy of ending_mat should have a process with an ingredient derived from a material " \ + ) + assert isinstance(copy_ingredient.material.measurements[0], MeasurementRun), ( + "copy of ending_mat should have a process with an ingredient derived from a material " "that has one measurement that gets deserialized as a MeasurementRun" - assert isinstance(copy_ingredient.material.measurements[0].spec, MeasurementSpec), \ - "copy of ending_mat should have a process with an ingredient derived from a material " \ + ) + assert isinstance(copy_ingredient.material.measurements[0].spec, MeasurementSpec), ( + "copy of ending_mat should have a process with an 
ingredient derived from a material " "that has one measurement that has a spec" + ) def test_cake(): @@ -158,10 +161,12 @@ def test_cake(): """ gemd_cake = make_cake() cake = MaterialRun.build(json.loads(GEMDJson().dumps(gemd_cake))) - assert [ingred.name for ingred in cake.process.ingredients] == \ - [ingred.name for ingred in gemd_cake.process.ingredients] - assert [ingred.labels for ingred in cake.process.ingredients] == \ - [ingred.labels for ingred in gemd_cake.process.ingredients] + assert [ingred.name for ingred in cake.process.ingredients] == [ + ingred.name for ingred in gemd_cake.process.ingredients + ] + assert [ingred.labels for ingred in cake.process.ingredients] == [ + ingred.labels for ingred in gemd_cake.process.ingredients + ] assert gemd_cake == cake def _by_name(start: MaterialRun, names: Iterable[str]) -> Optional[MaterialRun]: @@ -169,7 +174,10 @@ def _by_name(start: MaterialRun, names: Iterable[str]) -> Optional[MaterialRun]: names = [names] while names: target = names.pop(0) - start = next((i.material for i in start.process.ingredients if i.name == target), None) + start = next( + (i.material for i in start.process.ingredients if i.name == target), + None, + ) if start is None: return None return start @@ -177,4 +185,6 @@ def _by_name(start: MaterialRun, names: Iterable[str]) -> Optional[MaterialRun]: by_cake = _by_name(cake, ["baked cake", "batter", "wet ingredients", "butter"]) by_frosting = _by_name(cake, ["frosting", "butter"]) assert by_cake is by_frosting # Same literal object - assert _by_name(gemd_cake, ["frosting", "butter"]) is not by_frosting # Same literal object + assert ( + _by_name(gemd_cake, ["frosting", "butter"]) is not by_frosting + ) # Same literal object diff --git a/tests/serialization/test_material_spec.py b/tests/serialization/test_material_spec.py index 9bb04c1fa..909e3e1ba 100644 --- a/tests/serialization/test_material_spec.py +++ b/tests/serialization/test_material_spec.py @@ -1,4 +1,5 @@ """Tests of the material spec schema.""" + import pytest from uuid import uuid4 @@ -14,64 +15,67 @@ def valid_data(): """Return valid data used for these tests.""" return dict( - name='spec of material', - uids={'id': str(uuid4())}, + name="spec of material", + uids={"id": str(uuid4())}, tags=[], notes=None, process=None, template=None, properties=[ { - 'type': 'property_and_conditions', - 'property': + "type": "property_and_conditions", + "property": { + "type": "property", + "origin": "specified", + "name": "color", + "template": None, + "notes": None, + "value": {"category": "tan", "type": "nominal_categorical"}, + "file_links": [], + }, + "conditions": [ { - 'type': 'property', - 'origin': 'specified', - 'name': 'color', - 'template': None, - 'notes': None, - 'value': {'category': 'tan', 'type': 'nominal_categorical'}, - 'file_links': [] - }, - 'conditions': - [ - { - 'type': 'condition', - 'origin': 'specified', - 'name': 'temperature', - 'template': None, - 'notes': None, - 'value': { - 'type': 'nominal_real', - 'nominal': 300.0, - 'units': 'kelvin' - }, - 'file_links': [] - } - ] + "type": "condition", + "origin": "specified", + "name": "temperature", + "template": None, + "notes": None, + "value": { + "type": "nominal_real", + "nominal": 300.0, + "units": "kelvin", + }, + "file_links": [], + } + ], } ], file_links=[], - type='material_spec' + type="material_spec", ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Material Spec looks sane.""" material_spec: MaterialSpec = MaterialSpec.build(valid_data) - assert 
material_spec.uids == {'id': valid_data['uids']['id']} - assert material_spec.name == 'spec of material' + assert material_spec.uids == {"id": valid_data["uids"]["id"]} + assert material_spec.name == "spec of material" assert material_spec.tags == [] assert material_spec.notes is None assert material_spec.process is None - assert material_spec.properties[0] == \ - PropertyAndConditions(Property("color", origin='specified', - value=NominalCategorical("tan")), - conditions=[Condition('temperature', origin='specified', - value=NominalReal(300, units='kelvin'))]) + assert material_spec.properties[0] == PropertyAndConditions( + Property("color", origin="specified", value=NominalCategorical("tan")), + conditions=[ + Condition( + "temperature", + origin="specified", + value=NominalReal(300, units="kelvin"), + ) + ], + ) assert material_spec.template is None assert material_spec.file_links == [] - assert material_spec.typ == 'material_spec' + assert material_spec.typ == "material_spec" def test_serialization(valid_data): diff --git a/tests/serialization/test_measurement_run.py b/tests/serialization/test_measurement_run.py index 9dc6953f1..72f80d224 100644 --- a/tests/serialization/test_measurement_run.py +++ b/tests/serialization/test_measurement_run.py @@ -1,7 +1,7 @@ """Tests of the Measurement Run schema""" + import pytest from uuid import uuid4, UUID -from datetime import datetime from gemd.entity.attribute.property import Property from gemd.entity.value.nominal_integer import NominalInteger @@ -13,44 +13,59 @@ def valid_data(): """Return valid data used for these tests.""" return dict( - uids={'id': str(uuid4())}, - name='Taste test', + uids={"id": str(uuid4())}, + name="Taste test", tags=[], notes=None, conditions=[], parameters=[], - properties=[{'name': 'sweetness', 'type': 'property', 'template': None, 'notes': None, - 'origin': 'measured', 'file_links': [], - 'value': {'type': 'nominal_integer', 'nominal': 7}}, - {'type': 'property', 'name': 'fluffiness', 'template': None, 'notes': None, - 'origin': 'measured', 'file_links': [], - 'value': {'type': 'nominal_integer', 'nominal': 10} - }], + properties=[ + { + "name": "sweetness", + "type": "property", + "template": None, + "notes": None, + "origin": "measured", + "file_links": [], + "value": {"type": "nominal_integer", "nominal": 7}, + }, + { + "type": "property", + "name": "fluffiness", + "template": None, + "notes": None, + "origin": "measured", + "file_links": [], + "value": {"type": "nominal_integer", "nominal": 10}, + }, + ], material={ - 'uids': {'id': str(uuid4())}, - 'name': 'sponge cake', - 'tags': [], - 'notes': None, - 'process': None, - 'sample_type': 'experimental', - 'spec': None, - 'file_links': [], - 'type': 'material_run', - 'audit_info': { - 'created_by': str(uuid4()), 'created_at': 1559933807392, - 'updated_by': str(uuid4()), 'updated_at': 1560033807392 + "uids": {"id": str(uuid4())}, + "name": "sponge cake", + "tags": [], + "notes": None, + "process": None, + "sample_type": "experimental", + "spec": None, + "file_links": [], + "type": "material_run", + "audit_info": { + "created_by": str(uuid4()), + "created_at": 1559933807392, + "updated_by": str(uuid4()), + "updated_at": 1560033807392, }, - 'dataset': str(uuid4()), + "dataset": str(uuid4()), }, spec=None, file_links=[], - type='measurement_run', + type="measurement_run", source={ "type": "performed_source", "performed_by": "Marie Curie", - "performed_date": "1898-07-01" + "performed_date": "1898-07-01", }, - audit_info={'created_by': str(uuid4()), 'created_at': 
1560133807392}, + audit_info={"created_by": str(uuid4()), "created_at": 1560133807392}, dataset=str(uuid4()), ) @@ -58,27 +73,36 @@ def valid_data(): def test_simple_deserialization(valid_data): """Ensure that a deserialized Measurement Run looks sane.""" measurement_run: MeasurementRun = MeasurementRun.build(valid_data) - assert measurement_run.uids == {'id': valid_data['uids']['id']} - assert measurement_run.name == 'Taste test' + assert measurement_run.uids == {"id": valid_data["uids"]["id"]} + assert measurement_run.name == "Taste test" assert measurement_run.notes is None assert measurement_run.tags == [] assert measurement_run.conditions == [] assert measurement_run.parameters == [] - assert measurement_run.properties[0] == Property('sweetness', origin="measured", - value=NominalInteger(7)) - assert measurement_run.properties[1] == Property('fluffiness', origin="measured", - value=NominalInteger(10)) + assert measurement_run.properties[0] == Property( + "sweetness", origin="measured", value=NominalInteger(7) + ) + assert measurement_run.properties[1] == Property( + "fluffiness", origin="measured", value=NominalInteger(10) + ) assert measurement_run.file_links == [] assert measurement_run.template is None - assert measurement_run.material == MaterialRun('sponge cake', tags=[], - uids={'id': valid_data['material']['uids']['id']}, - sample_type='experimental') - assert measurement_run.material.audit_info.created_by == UUID(valid_data['material']['audit_info']['created_by']) - assert measurement_run.material.dataset == UUID(valid_data['material']['dataset']) + assert measurement_run.material == MaterialRun( + "sponge cake", + tags=[], + uids={"id": valid_data["material"]["uids"]["id"]}, + sample_type="experimental", + ) + assert measurement_run.material.audit_info.created_by == UUID( + valid_data["material"]["audit_info"]["created_by"] + ) + assert measurement_run.material.dataset == UUID(valid_data["material"]["dataset"]) assert measurement_run.spec is None - assert measurement_run.typ == 'measurement_run' - assert measurement_run.audit_info.created_by == UUID(valid_data['audit_info']['created_by']) - assert measurement_run.dataset == UUID(valid_data['dataset']) + assert measurement_run.typ == "measurement_run" + assert measurement_run.audit_info.created_by == UUID( + valid_data["audit_info"]["created_by"] + ) + assert measurement_run.dataset == UUID(valid_data["dataset"]) def test_serialization(valid_data): @@ -86,17 +110,17 @@ def test_serialization(valid_data): measurement_run: MeasurementRun = MeasurementRun.build(valid_data) serialized = measurement_run.dump() # Audit info & dataset are not included in the dump - serialized['audit_info'] = valid_data['audit_info'] - serialized['dataset'] = valid_data['dataset'] - serialized['material']['audit_info'] = valid_data['material']['audit_info'] - serialized['material']['dataset'] = valid_data['material']['dataset'] + serialized["audit_info"] = valid_data["audit_info"] + serialized["dataset"] = valid_data["dataset"] + serialized["material"]["audit_info"] = valid_data["material"]["audit_info"] + serialized["material"]["dataset"] = valid_data["material"]["dataset"] assert serialized == valid_data def test_material_attachment(): """Test that a material can be attached to a measurement, and the connection survives serde.""" - cake = MaterialRun('Final Cake') - mass = MeasurementRun('Weigh cake', material=cake) + cake = MaterialRun("Final Cake") + mass = MeasurementRun("Weigh cake", material=cake) mass_data = mass.dump() mass_copy = 
MeasurementRun.build(mass_data) assert mass_copy == mass diff --git a/tests/serialization/test_object_template.py b/tests/serialization/test_object_template.py index cfb815045..7460ce099 100644 --- a/tests/serialization/test_object_template.py +++ b/tests/serialization/test_object_template.py @@ -1,4 +1,5 @@ """Tests of the object template schema.""" + from uuid import uuid4 from citrine.resources.material_template import MaterialTemplate @@ -15,61 +16,86 @@ def test_object_template_serde(): """Test serde of an object template.""" - length_template = PropertyTemplate("Length", bounds=RealBounds(2.0, 3.5, 'cm')) - sub_bounds = RealBounds(2.5, 3.0, 'cm') - color_template = PropertyTemplate("Color", bounds=CategoricalBounds(["red", "green", "blue"])) + length_template = PropertyTemplate("Length", bounds=RealBounds(2.0, 3.5, "cm")) + sub_bounds = RealBounds(2.5, 3.0, "cm") + color_template = PropertyTemplate( + "Color", bounds=CategoricalBounds(["red", "green", "blue"]) + ) # Properties are a mixture of property templates and [template, bounds] pairs - block_template = MaterialTemplate("Block", properties=[[length_template, sub_bounds], - color_template]) + block_template = MaterialTemplate( + "Block", properties=[[length_template, sub_bounds], color_template] + ) copy_template = MaterialTemplate.build(block_template.dump()) assert copy_template == block_template # Tests below exercise similar code, but for measurement and process templates - pressure_template = ConditionTemplate("pressure", bounds=RealBounds(0.1, 0.11, 'MPa')) + pressure_template = ConditionTemplate( + "pressure", bounds=RealBounds(0.1, 0.11, "MPa") + ) index_template = ParameterTemplate("index", bounds=IntegerBounds(2, 10)) - meas_template = MeasurementTemplate("A measurement of length", properties=[length_template], - conditions=[pressure_template], description="Description", - parameters=[index_template], tags=["foo"]) + meas_template = MeasurementTemplate( + "A measurement of length", + properties=[length_template], + conditions=[pressure_template], + description="Description", + parameters=[index_template], + tags=["foo"], + ) assert MeasurementTemplate.build(meas_template.dump()) == meas_template - proc_template = ProcessTemplate("Make an object", parameters=[index_template], - conditions=[pressure_template], allowed_labels=["Label"], - allowed_names=["first sample", "second sample"]) + proc_template = ProcessTemplate( + "Make an object", + parameters=[index_template], + conditions=[pressure_template], + allowed_labels=["Label"], + allowed_names=["first sample", "second sample"], + ) assert ProcessTemplate.build(proc_template.dump()) == proc_template # Check that serde still works if the template is a LinkByUID - pressure_template.uids['id'] = '12345' # uids['id'] not populated by default - proc_template.conditions[0][0] = LinkByUID('id', pressure_template.uids['id']) + pressure_template.uids["id"] = "12345" # uids['id'] not populated by default + proc_template.conditions[0][0] = LinkByUID("id", pressure_template.uids["id"]) + assert ProcessTemplate.build(proc_template.dump()) == proc_template def test_bounds_optional(): """Test that each object template can have passthrough bounds for any of its attributes.""" + def link(): return LinkByUID(id=str(uuid4()), scope=str(uuid4())) + for template_type, attribute_args in [ - (MaterialTemplate, [ - ('properties', PropertyTemplate), - ]), - (ProcessTemplate, [ - ('conditions', ConditionTemplate), - ('parameters', ParameterTemplate), - ]), - (MeasurementTemplate, [ - 
('properties', PropertyTemplate), - ('conditions', ConditionTemplate), - ('parameters', ParameterTemplate), - ]), + ( + MaterialTemplate, + [ + ("properties", PropertyTemplate), + ], + ), + ( + ProcessTemplate, + [ + ("conditions", ConditionTemplate), + ("parameters", ParameterTemplate), + ], + ), + ( + MeasurementTemplate, + [ + ("properties", PropertyTemplate), + ("conditions", ConditionTemplate), + ("parameters", ParameterTemplate), + ], + ), ]: kwargs = {} for name, attribute_type in attribute_args: kwargs[name] = [ [link(), IntegerBounds(0, 10)], link(), - attribute_type('foo', bounds=IntegerBounds(0, 10)), - (link(), None) + attribute_type("foo", bounds=IntegerBounds(0, 10)), + (link(), None), ] - template = template_type(name='foo', **kwargs) + template = template_type(name="foo", **kwargs) for name, _ in attribute_args: attributes = getattr(template, name) assert len(attributes) == 4 diff --git a/tests/serialization/test_objectives.py b/tests/serialization/test_objectives.py index 69f135669..3a2b6160d 100644 --- a/tests/serialization/test_objectives.py +++ b/tests/serialization/test_objectives.py @@ -1,23 +1,24 @@ """Tests for citrine.informatics.objectives.""" + import pytest -from citrine.informatics.objectives import Objective, ScalarMaxObjective, ScalarMinObjective +from citrine.informatics.objectives import ( + Objective, + ScalarMaxObjective, + ScalarMinObjective, +) @pytest.fixture def scalar_max_objective() -> ScalarMaxObjective: """Build a ScalarMaxObjective.""" - return ScalarMaxObjective( - descriptor_key="z" - ) + return ScalarMaxObjective(descriptor_key="z") @pytest.fixture def scalar_min_objective() -> ScalarMinObjective: """Build a ScalarMinObjective.""" - return ScalarMinObjective( - descriptor_key="z" - ) + return ScalarMinObjective(descriptor_key="z") def test_scalar_max_dumps(scalar_max_objective): diff --git a/tests/serialization/test_predictors.py b/tests/serialization/test_predictors.py index be55fd1c9..a976ed201 100644 --- a/tests/serialization/test_predictors.py +++ b/tests/serialization/test_predictors.py @@ -1,24 +1,38 @@ """Tests for citrine.informatics.predictors serialization.""" + from copy import deepcopy -from uuid import UUID import pytest -from . import predictor_serialization_check, valid_serialization_output, \ - predictor_node_serialization_check from citrine.informatics.descriptors import RealDescriptor -from citrine.informatics.predictors import * +from citrine.informatics.predictors import ( + AutoMLPredictor, + ExpressionPredictor, + GraphPredictor, + IngredientFractionsPredictor, + IngredientsToFormulationPredictor, + LabelFractionsPredictor, + MeanPropertyPredictor, + PredictorNode, + SimpleMixturePredictor, +) + +from . 
import predictor_node_serialization_check, valid_serialization_output def test_auto_ml_deserialization(valid_auto_ml_predictor_data): """Ensure that a deserialized SimplePredictor looks sane.""" predictor: AutoMLPredictor = AutoMLPredictor.build(valid_auto_ml_predictor_data) - assert predictor.name == 'AutoML predictor' - assert predictor.description == 'Predicts z from input x' + assert predictor.name == "AutoML predictor" + assert predictor.description == "Predicts z from input x" assert len(predictor.inputs) == 1 - assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") + assert predictor.inputs[0] == RealDescriptor( + "x", lower_bound=0, upper_bound=100, units="" + ) assert len(predictor.outputs) == 1 - assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") + assert predictor.outputs[0] == RealDescriptor( + "z", lower_bound=0, upper_bound=100, units="" + ) with pytest.deprecated_call(): assert len(predictor.training_data) == 0 @@ -26,12 +40,16 @@ def test_auto_ml_deserialization(valid_auto_ml_predictor_data): def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data): """Ensure that a polymorphically deserialized SimplePredictor looks sane.""" predictor: AutoMLPredictor = PredictorNode.build(valid_auto_ml_predictor_data) - assert predictor.name == 'AutoML predictor' - assert predictor.description == 'Predicts z from input x' + assert predictor.name == "AutoML predictor" + assert predictor.description == "Predicts z from input x" assert len(predictor.inputs) == 1 - assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") + assert predictor.inputs[0] == RealDescriptor( + "x", lower_bound=0, upper_bound=100, units="" + ) assert len(predictor.outputs) == 1 - assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") + assert predictor.outputs[0] == RealDescriptor( + "z", lower_bound=0, upper_bound=100, units="" + ) with pytest.deprecated_call(): assert len(predictor.training_data) == 0 @@ -46,41 +64,56 @@ def test_graph_serialization(valid_graph_predictor_data): graph_data_copy = deepcopy(valid_graph_predictor_data) predictor = GraphPredictor.build(valid_graph_predictor_data) serialized = predictor.dump() - assert serialized['instance']['predictors'] == graph_data_copy['data']['instance']['predictors'] - assert serialized == valid_serialization_output(graph_data_copy['data']) + assert ( + serialized["instance"]["predictors"] + == graph_data_copy["data"]["instance"]["predictors"] + ) + assert serialized == valid_serialization_output(graph_data_copy["data"]) def test_expression_serialization(valid_expression_predictor_data): """Ensure that a serialized ExpressionPredictor looks sane.""" - predictor_node_serialization_check(valid_expression_predictor_data, ExpressionPredictor) + predictor_node_serialization_check( + valid_expression_predictor_data, ExpressionPredictor + ) def test_ing_to_formulation_serialization(valid_ing_formulation_predictor_data): """Ensure that a serialized IngredientsToFormulationPredictor looks sane.""" - predictor_node_serialization_check(valid_ing_formulation_predictor_data, IngredientsToFormulationPredictor) + predictor_node_serialization_check( + valid_ing_formulation_predictor_data, IngredientsToFormulationPredictor + ) def test_mean_property_serialization(valid_mean_property_predictor_data): """Ensure that a serialized MeanPropertyPredictor looks sane.""" - 
predictor_node_serialization_check(valid_mean_property_predictor_data, MeanPropertyPredictor)
+    predictor_node_serialization_check(
+        valid_mean_property_predictor_data, MeanPropertyPredictor
+    )
 
 
 def test_simple_mixture_predictor_serialization(valid_simple_mixture_predictor_data):
-    predictor_node_serialization_check(valid_simple_mixture_predictor_data, SimpleMixturePredictor)
+    predictor_node_serialization_check(
+        valid_simple_mixture_predictor_data, SimpleMixturePredictor
+    )
 
 
 def test_label_fractions_serialization(valid_label_fractions_predictor_data):
     """Ensure that a serialized LabelFractionPredictor looks sane."""
-    predictor_node_serialization_check(valid_label_fractions_predictor_data, LabelFractionsPredictor)
+    predictor_node_serialization_check(
+        valid_label_fractions_predictor_data, LabelFractionsPredictor
+    )
 
 
 def test_ingredient_fractions_serialization(valid_ingredient_fractions_predictor_data):
-    """"Ensure that a serialized IngredientsFractionsPredictor looks sane."""
-    predictor_node_serialization_check(valid_ingredient_fractions_predictor_data, IngredientFractionsPredictor)
+    """Ensure that a serialized IngredientFractionsPredictor looks sane."""
+    predictor_node_serialization_check(
+        valid_ingredient_fractions_predictor_data, IngredientFractionsPredictor
+    )
 
 
 def test_auto_ml_serialization(valid_auto_ml_predictor_data):
-    """"Ensure that a serialized AutoMLPredictor looks sane."""
+    """Ensure that a serialized AutoMLPredictor looks sane."""
     predictor_node_serialization_check(valid_auto_ml_predictor_data, AutoMLPredictor)
diff --git a/tests/serialization/test_process_run.py b/tests/serialization/test_process_run.py
index 06d538470..3cf3b25a2 100644
--- a/tests/serialization/test_process_run.py
+++ b/tests/serialization/test_process_run.py
@@ -1,4 +1,5 @@
 """Tests of the Process Run schema"""
+
 import pytest
 from uuid import uuid4
 
@@ -13,56 +14,87 @@
 def valid_data():
     """Return valid data used for these tests."""
     return dict(
-        uids={'id': str(uuid4()), 'my_id': 'process1-v1'},
-        name='Process 1',
-        tags=['baking::cakes', 'danger::low'],
-        notes='make sure to use oven mitts',
-        conditions=[{'name': 'oven temp', 'type': 'condition', 'notes': None,
-                     'template': None, 'origin': 'measured', 'file_links': [],
-                     'value': {'nominal': 203.0, 'units': 'dimensionless', 'type': 'nominal_real'}
-                     }],
+        uids={"id": str(uuid4()), "my_id": "process1-v1"},
+        name="Process 1",
+        tags=["baking::cakes", "danger::low"],
+        notes="make sure to use oven mitts",
+        conditions=[
+            {
+                "name": "oven temp",
+                "type": "condition",
+                "notes": None,
+                "template": None,
+                "origin": "measured",
+                "file_links": [],
+                "value": {
+                    "nominal": 203.0,
+                    "units": "dimensionless",
+                    "type": "nominal_real",
+                },
+            }
+        ],
         parameters=[],
-        spec={'type': 'process_spec', 'name': 'Spec for proc 1',
-              'uids': {'id': str(uuid4())}, 'file_links': [], 'notes': None,
-              'conditions': [{'type': 'condition', 'name': 'oven temp', 'origin': 'specified',
-                              'template': None, 'notes': None, 'file_links': [],
-                              'value': {'type': 'uniform_real', 'units': 'dimensionless',
-                                        'lower_bound': 175, 'upper_bound': 225
-                                        }
-                              }],
-              'template': None, 'tags': [], 'parameters': []
-              },
+        spec={
+            "type": "process_spec",
+            "name": "Spec for proc 1",
+            "uids": {"id": str(uuid4())},
+            "file_links": [],
+            "notes": None,
+            "conditions": [
+                {
+                    "type": "condition",
+                    "name": "oven temp",
+                    "origin": "specified",
+                    "template": None,
+                    "notes": None,
+                    "file_links": [],
+                    "value": {
+                        "type": "uniform_real",
+                        "units": "dimensionless",
+                        
"lower_bound": 175, + "upper_bound": 225, + }, + } + ], + "template": None, + "tags": [], + "parameters": [], + }, file_links=[], - type='process_run', + type="process_run", source={ "type": "performed_source", "performed_by": "Marie Curie", - "performed_date": None - } + "performed_date": None, + }, ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Process Run looks sane.""" process_run: ProcessRun = ProcessRun.build(valid_data) - assert process_run.uids == {'id': valid_data['uids']['id'], 'my_id': 'process1-v1'} - assert process_run.tags == ['baking::cakes', 'danger::low'] - assert process_run.conditions[0] == Condition(name='oven temp', - value=NominalReal(203.0, ''), - origin='measured') + assert process_run.uids == {"id": valid_data["uids"]["id"], "my_id": "process1-v1"} + assert process_run.tags == ["baking::cakes", "danger::low"] + assert process_run.conditions[0] == Condition( + name="oven temp", value=NominalReal(203.0, ""), origin="measured" + ) assert process_run.parameters == [] assert process_run.file_links == [] assert process_run.template is None assert process_run.output_material is None - assert process_run.spec == \ - ProcessSpec(name="Spec for proc 1", tags=[], - uids={'id': valid_data['spec']['uids']['id']}, - conditions=[Condition(name='oven temp', value=UniformReal(175, 225, ''), - origin='specified')] - ) - assert process_run.name == 'Process 1' - assert process_run.notes == 'make sure to use oven mitts' - assert process_run.typ == 'process_run' + assert process_run.spec == ProcessSpec( + name="Spec for proc 1", + tags=[], + uids={"id": valid_data["spec"]["uids"]["id"]}, + conditions=[ + Condition( + name="oven temp", value=UniformReal(175, 225, ""), origin="specified" + ) + ], + ) + assert process_run.name == "Process 1" + assert process_run.notes == "make sure to use oven mitts" + assert process_run.typ == "process_run" def test_serialization(valid_data): diff --git a/tests/serialization/test_process_spec.py b/tests/serialization/test_process_spec.py index bf0f8b3e2..162d1506c 100644 --- a/tests/serialization/test_process_spec.py +++ b/tests/serialization/test_process_spec.py @@ -1,4 +1,5 @@ """Tests of the Process Run schema""" + import pytest from uuid import uuid4, UUID @@ -15,74 +16,109 @@ def valid_data(): """Return valid data used for these tests.""" return dict( - uids={'id': str(uuid4())}, - name='Process 1', - tags=['baking::cakes', 'danger::low'], - notes='make sure to use oven mitts', - parameters=[{'name': 'oven temp', 'type': 'parameter', - 'template': None, 'origin': 'specified', 'notes': None, 'file_links': [], - 'value': {'lower_bound': 195, 'upper_bound': 205, - 'units': 'dimensionless', 'type': 'uniform_real'} - }], + uids={"id": str(uuid4())}, + name="Process 1", + tags=["baking::cakes", "danger::low"], + notes="make sure to use oven mitts", + parameters=[ + { + "name": "oven temp", + "type": "parameter", + "template": None, + "origin": "specified", + "notes": None, + "file_links": [], + "value": { + "lower_bound": 195, + "upper_bound": 205, + "units": "dimensionless", + "type": "uniform_real", + }, + } + ], conditions=[], template={ - 'name': 'the template', - 'tags': [], - 'uids': {'id': str(uuid4())}, - 'type': 'process_template', - 'conditions': [], - 'parameters': [ + "name": "the template", + "tags": [], + "uids": {"id": str(uuid4())}, + "type": "process_template", + "conditions": [], + "parameters": [ [ { - 'type': 'parameter_template', - 'name': 'oven temp template', - 'tags': [], - 'bounds': {'type': 
'real_bounds', 'lower_bound': 175, 'upper_bound': 225, 'default_units': 'dimensionless'}, - 'uids': {'id': str(uuid4())}, - 'description': None, + "type": "parameter_template", + "name": "oven temp template", + "tags": [], + "bounds": { + "type": "real_bounds", + "lower_bound": 175, + "upper_bound": 225, + "default_units": "dimensionless", + }, + "uids": {"id": str(uuid4())}, + "description": None, }, { - 'type': 'real_bounds', - 'lower_bound': 175, 'upper_bound': 225, 'default_units': 'dimensionless' - } + "type": "real_bounds", + "lower_bound": 175, + "upper_bound": 225, + "default_units": "dimensionless", + }, ] ], - 'allowed_labels': ['a', 'b'], - 'allowed_names': ['a name'], - 'description': 'a long description', + "allowed_labels": ["a", "b"], + "allowed_names": ["a name"], + "description": "a long description", }, - file_links=[{'type': 'file_link', 'filename': 'cake_recipe.txt', 'url': 'www.baking.com'}], - audit_info={'created_by': str(uuid4()), 'created_at': 1559933807392}, - type='process_spec' + file_links=[ + { + "type": "file_link", + "filename": "cake_recipe.txt", + "url": "www.baking.com", + } + ], + audit_info={"created_by": str(uuid4()), "created_at": 1559933807392}, + type="process_spec", ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Process Spec looks sane.""" process_spec: ProcessSpec = ProcessSpec.build(valid_data) - assert process_spec.uids == {'id': valid_data['uids']['id']} - assert process_spec.tags == ['baking::cakes', 'danger::low'] - assert process_spec.parameters[0] == Parameter(name='oven temp', - value=UniformReal(195, 205, ''), - origin='specified') + assert process_spec.uids == {"id": valid_data["uids"]["id"]} + assert process_spec.tags == ["baking::cakes", "danger::low"] + assert process_spec.parameters[0] == Parameter( + name="oven temp", value=UniformReal(195, 205, ""), origin="specified" + ) assert process_spec.conditions == [] - assert process_spec.template == \ - ProcessTemplate('the template', tags=[], - uids={'id': valid_data['template']['uids']['id']}, - parameters=[ - [ParameterTemplate('oven temp template', tags=[], - bounds=RealBounds(175, 225, ''), - uids={'id': valid_data['template']['parameters'][0][0]['uids']['id']}), - RealBounds(175, 225, '')] - ], - description='a long description', - allowed_labels=['a', 'b'], - allowed_names=['a name']) - assert process_spec.name == 'Process 1' - assert process_spec.notes == 'make sure to use oven mitts' - assert process_spec.file_links == [FileLink('cake_recipe.txt', 'www.baking.com')] - assert process_spec.typ == 'process_spec' - assert process_spec.audit_info.created_by == UUID(valid_data['audit_info']['created_by']) + assert process_spec.template == ProcessTemplate( + "the template", + tags=[], + uids={"id": valid_data["template"]["uids"]["id"]}, + parameters=[ + [ + ParameterTemplate( + "oven temp template", + tags=[], + bounds=RealBounds(175, 225, ""), + uids={ + "id": valid_data["template"]["parameters"][0][0]["uids"]["id"] + }, + ), + RealBounds(175, 225, ""), + ] + ], + description="a long description", + allowed_labels=["a", "b"], + allowed_names=["a name"], + ) + assert process_spec.name == "Process 1" + assert process_spec.notes == "make sure to use oven mitts" + assert process_spec.file_links == [FileLink("cake_recipe.txt", "www.baking.com")] + assert process_spec.typ == "process_spec" + assert process_spec.audit_info.created_by == UUID( + valid_data["audit_info"]["created_by"] + ) def test_serialization(valid_data): @@ -90,5 +126,5 @@ def 
test_serialization(valid_data): process_spec: ProcessSpec = ProcessSpec.build(valid_data) serialized = process_spec.dump() # Audit info & dataset are not included in the dump - serialized['audit_info'] = valid_data['audit_info'] + serialized["audit_info"] = valid_data["audit_info"] assert serialized == valid_data diff --git a/tests/serialization/test_project.py b/tests/serialization/test_project.py index 271a81b8e..08719601a 100644 --- a/tests/serialization/test_project.py +++ b/tests/serialization/test_project.py @@ -1,4 +1,5 @@ """Tests of the Project schema.""" + import pytest from uuid import uuid4, UUID from citrine.resources.project import Project @@ -11,19 +12,19 @@ def valid_data(): return dict( id=str(uuid4()), created_at=1559933807392, - name='my project', - description='a good project', - status='in-progress' + name="my project", + description="a good project", + status="in-progress", ) def test_simple_deserialization(valid_data): """Ensure that a deserialized Project looks sane.""" project: Project = Project.build(valid_data) - assert project.uid == UUID(valid_data['id']) - assert project.created_at == arrow.get(valid_data['created_at'] / 1000).datetime - assert project.name == 'my project' - assert project.status == 'in-progress' + assert project.uid == UUID(valid_data["id"]) + assert project.created_at == arrow.get(valid_data["created_at"] / 1000).datetime + assert project.name == "my project" + assert project.status == "in-progress" def test_serialization(valid_data): diff --git a/tests/serialization/test_reports.py b/tests/serialization/test_reports.py index b895c3d83..605cd5e66 100644 --- a/tests/serialization/test_reports.py +++ b/tests/serialization/test_reports.py @@ -1,21 +1,21 @@ """Tests for citrine.informatics.reports serialization.""" + import logging import pytest from copy import deepcopy -import warnings from uuid import UUID from citrine.informatics.descriptors import RealDescriptor -from citrine.informatics.reports import Report, ModelSummary, FeatureImportanceReport +from citrine.informatics.reports import Report, ModelSummary def test_predictor_report_build(valid_predictor_report_data): """Build a predictor report and verify its structure.""" report = Report.build(valid_predictor_report_data) - assert report.status == 'OK' - assert str(report.uid) == valid_predictor_report_data['id'] + assert report.status == "OK" + assert str(report.uid) == valid_predictor_report_data["id"] x = RealDescriptor("x", lower_bound=0, upper_bound=1, units="") y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="") @@ -23,40 +23,42 @@ def test_predictor_report_build(valid_predictor_report_data): assert report.descriptors == [x, y, z] lolo_model: ModelSummary = report.model_summaries[0] - assert lolo_model.name == 'GeneralLoloModel_1' - assert lolo_model.type_ == 'ML Model' + assert lolo_model.name == "GeneralLoloModel_1" + assert lolo_model.type_ == "ML Model" assert lolo_model.inputs == [x] assert lolo_model.outputs == [y] assert lolo_model.model_settings == { - 'Algorithm': 'Ensemble of non-linear estimators', - 'Number of estimators': 64, - 'Leaf model': 'Mean', - 'Use jackknife': True + "Algorithm": "Ensemble of non-linear estimators", + "Number of estimators": 64, + "Leaf model": "Mean", + "Use jackknife": True, } feature_importance = lolo_model.feature_importances[0] assert feature_importance.importances == {"x": 1.0} assert feature_importance.output_key == "y" - assert lolo_model.predictor_name == 'Predict y from x with ML' + assert lolo_model.predictor_name == 
"Predict y from x with ML" assert lolo_model.predictor_uid is None exp_model: ModelSummary = report.model_summaries[1] - assert exp_model.name == 'GeneralLosslessModel_2' - assert exp_model.type_ == 'Analytic Model' + assert exp_model.name == "GeneralLosslessModel_2" + assert exp_model.type_ == "Analytic Model" assert exp_model.inputs == [x, y] assert exp_model.outputs == [z] - assert exp_model.model_settings == { - "Expression": "(z) <- (x + y)" - } + assert exp_model.model_settings == {"Expression": "(z) <- (x + y)"} assert exp_model.feature_importances == [] - assert exp_model.predictor_name == 'Expression for z' + assert exp_model.predictor_name == "Expression for z" assert exp_model.predictor_uid == UUID("249bf32c-6f3d-4a93-9387-94cc877f170c") def test_empty_report_build(): """Build a predictor report when the 'report' field is somehow unfilled.""" - Report.build(dict(id='7c2dda5d-675a-41b6-829c-e485163f0e43', status='PENDING')) - Report.build(dict(id='7c2dda5d-675a-41b6-829c-e485163f0e43', status='PENDING', report=None)) - Report.build(dict(id='7c2dda5d-675a-41b6-829c-e485163f0e43', status='PENDING', report=dict())) + Report.build(dict(id="7c2dda5d-675a-41b6-829c-e485163f0e43", status="PENDING")) + Report.build( + dict(id="7c2dda5d-675a-41b6-829c-e485163f0e43", status="PENDING", report=None) + ) + Report.build( + dict(id="7c2dda5d-675a-41b6-829c-e485163f0e43", status="PENDING", report=dict()) + ) def test_bad_predictor_report_build(caplog, valid_predictor_report_data): @@ -64,7 +66,7 @@ def test_bad_predictor_report_build(caplog, valid_predictor_report_data): too_many_descriptors = deepcopy(valid_predictor_report_data) # Multiple descriptors with the same key other_x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="") - too_many_descriptors['report']['descriptors'].append(other_x.dump()) + too_many_descriptors["report"]["descriptors"].append(other_x.dump()) with caplog.at_level(logging.WARNING): caplog.clear() Report.build(too_many_descriptors) @@ -73,9 +75,9 @@ def test_bad_predictor_report_build(caplog, valid_predictor_report_data): # A key that appears in inputs and/or outputs, but there is no corresponding descriptor. # This is done twice for coverage, once to catch a missing input and once for a missing output. 
too_few_descriptors = deepcopy(valid_predictor_report_data)
-    too_few_descriptors['report']['descriptors'].pop()
+    too_few_descriptors["report"]["descriptors"].pop()
     with pytest.raises(RuntimeError):
         Report.build(too_few_descriptors)
 
-    too_few_descriptors['report']['descriptors'] = []
+    too_few_descriptors["report"]["descriptors"] = []
     with pytest.raises(RuntimeError):
         Report.build(too_few_descriptors)
diff --git a/tests/serialization/test_scorers.py b/tests/serialization/test_scorers.py
index 8de71ff94..615a51a4a 100644
--- a/tests/serialization/test_scorers.py
+++ b/tests/serialization/test_scorers.py
@@ -1,9 +1,8 @@
 """Tests for citrine.informatics.scores."""
+
 from citrine.informatics.objectives import ScalarMaxObjective
 from citrine.informatics.scores import Score, EIScore, LIScore
 
-from tests.informatics.test_scores import li_score, ei_score
-
 
 def test_li_dumps(li_score):
     """Ensure values are persisted through deser."""
diff --git a/tests/serialization/test_user.py b/tests/serialization/test_user.py
index ee96ae108..a880a9b32 100644
--- a/tests/serialization/test_user.py
+++ b/tests/serialization/test_user.py
@@ -1,4 +1,5 @@
-"""Tests of the Project schema."""
+"""Tests of the User schema."""
+
 import pytest
 from uuid import uuid4
 from citrine.resources.user import User
@@ -9,19 +10,19 @@
 def valid_data():
     """Return valid data used for these tests."""
     return dict(
         id=str(uuid4()),
-        screen_name='bob',
-        position='the builder',
-        email='bob@thebuilder.com',
-        is_admin=True
+        screen_name="bob",
+        position="the builder",
+        email="bob@thebuilder.com",
+        is_admin=True,
     )
 
 
 def test_simple_deserialization(valid_data):
     """Ensure a deserialized User looks sane."""
     user: User = User.build(valid_data)
-    assert user.screen_name == 'bob'
-    assert user.position == 'the builder'
-    assert user.email == 'bob@thebuilder.com'
+    assert user.screen_name == "bob"
+    assert user.position == "the builder"
+    assert user.email == "bob@thebuilder.com"
     assert user.is_admin
diff --git a/tests/serialization/test_workflow.py b/tests/serialization/test_workflow.py
index 3a25a82bf..a2c08857d 100644
--- a/tests/serialization/test_workflow.py
+++ b/tests/serialization/test_workflow.py
@@ -1,4 +1,5 @@
-"""Tests of the Project schema."""
+"""Tests of the DesignWorkflow schema."""
+
 import pytest
 from datetime import datetime
 from uuid import uuid4, UUID
@@ -10,47 +11,57 @@
 def valid_data():
     """Return valid data used for these tests."""
     return dict(
         id=str(uuid4()),
-        name='A rad new workflow',
-        description='All about my workflow',
-        status='SUCCEEDED',
-        status_description='READY',
-        status_detail=[{'level': 'Info', 'msg': 'Things are looking good'}],
+        name="A rad new workflow",
+        description="All about my workflow",
+        status="SUCCEEDED",
+        status_description="READY",
+        status_detail=[{"level": "Info", "msg": "Things are looking good"}],
         archived=False,
         design_space_id=str(uuid4()),
         predictor_id=str(uuid4()),
         created_by=str(uuid4()),
-        create_time=datetime(2020, 1, 1, 1, 1, 1, 1).isoformat("T")
+        create_time=datetime(2020, 1, 1, 1, 1, 1, 1).isoformat("T"),
     )
 
 
 @pytest.fixture
 def valid_serialization_output(valid_data):
-    return {x: y for x, y in valid_data.items() if x not in
-            ['status', 'status_detail', 'status_description', 'created_by', 'create_time']}
+    return {
+        x: y
+        for x, y in valid_data.items()
+        if x
+        not in [
+            "status",
+            "status_detail",
+            "status_description",
+            "created_by",
+            "create_time",
+        ]
+    }
 
 
 def test_simple_deserialization(valid_data):
     """Ensure a deserialized DesignWorkflow looks sane."""
    workflow: DesignWorkflow = DesignWorkflow.build(valid_data)
-    assert 
workflow.design_space_id == UUID(valid_data['design_space_id']) - assert workflow.predictor_id == UUID(valid_data['predictor_id']) + assert workflow.design_space_id == UUID(valid_data["design_space_id"]) + assert workflow.predictor_id == UUID(valid_data["predictor_id"]) def test_deserialization_missing_created_by(valid_data): """Ensure a DesignWorkflow can be deserialized with no created_by field.""" - valid_data['created_by'] = None + valid_data["created_by"] = None workflow: DesignWorkflow = DesignWorkflow.build(valid_data) - assert workflow.design_space_id == UUID(valid_data['design_space_id']) + assert workflow.design_space_id == UUID(valid_data["design_space_id"]) assert workflow.created_by is None def test_deserialization_missing_create_time(valid_data): """Ensure a DesignWorkflow can be deserialized with no created_by field.""" - valid_data['create_time'] = None + valid_data["create_time"] = None workflow: DesignWorkflow = DesignWorkflow.build(valid_data) - assert workflow.design_space_id == UUID(valid_data['design_space_id']) + assert workflow.design_space_id == UUID(valid_data["design_space_id"]) assert workflow.create_time is None @@ -58,7 +69,7 @@ def test_serialization(valid_data, valid_serialization_output): """Ensure a serialized DesignWorkflow looks sane.""" workflow: DesignWorkflow = DesignWorkflow.build(valid_data) serialized = workflow.dump() - serialized['id'] = valid_data['id'] + serialized["id"] = valid_data["id"] # we can have extra fields in the output of `dump` # these support forwards and backwards compatibility for k in valid_serialization_output: diff --git a/tests/test_citrine.py b/tests/test_citrine.py index 539ce2a26..8a99277f2 100644 --- a/tests/test_citrine.py +++ b/tests/test_citrine.py @@ -9,11 +9,8 @@ def refresh_token(expiration: datetime = None) -> dict: - token = jwt.encode( - payload={'exp': expiration.timestamp()}, - key='garbage' - ) - return {'access_token': token} + token = jwt.encode(payload={"exp": expiration.timestamp()}, key="garbage") + return {"access_token": token} token_refresh_response = refresh_token(datetime(2019, 3, 14, tzinfo=timezone.utc)) @@ -21,19 +18,32 @@ def refresh_token(expiration: datetime = None) -> dict: def test_citrine_creation(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) - assert '1234' == Citrine(api_key='1234', host='citrine-testing.fake').session.refresh_token + assert ( + "1234" + == Citrine( + api_key="1234", host="citrine-testing.fake" + ).session.refresh_token + ) def test_citrine_signature(monkeypatch): with requests_mock.Mocker() as m: - m.post('http://citrine-testing.fake:8080/api/v1/tokens/refresh', json=token_refresh_response) - - assert '1234' == Citrine(api_key='1234', - scheme='http', - host='citrine-testing.fake', - port="8080").session.refresh_token + m.post( + "http://citrine-testing.fake:8080/api/v1/tokens/refresh", + json=token_refresh_response, + ) + + assert ( + "1234" + == Citrine( + api_key="1234", scheme="http", host="citrine-testing.fake", port="8080" + ).session.refresh_token + ) # Validate defaults with requests_mock.Mocker() as m: @@ -41,7 +51,9 @@ def test_citrine_signature(monkeypatch): patched_host = "monkeypatch.citrine-testing.fake" monkeypatch.setenv("CITRINE_API_KEY", patched_key) monkeypatch.setenv("CITRINE_API_HOST", patched_host) - m.post(f'https://{patched_host}/api/v1/tokens/refresh', 
json=token_refresh_response) + m.post( + f"https://{patched_host}/api/v1/tokens/refresh", json=token_refresh_response + ) assert patched_key == Citrine().session.refresh_token assert patched_key == Citrine(api_key=patched_key).session.refresh_token @@ -55,45 +67,60 @@ def test_citrine_signature(monkeypatch): def test_citrine_project_session(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) - citrine = Citrine(api_key='foo', host='citrine-testing.fake') + citrine = Citrine(api_key="foo", host="citrine-testing.fake") assert citrine.session == citrine.projects.session def test_citrine_user_session(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - citrine = Citrine(api_key='foo', host='citrine-testing.fake') + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + citrine = Citrine(api_key="foo", host="citrine-testing.fake") assert citrine.session == citrine.users.session def test_citrine_team_session(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - citrine = Citrine(api_key='foo', host='citrine-testing.fake') + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + citrine = Citrine(api_key="foo", host="citrine-testing.fake") assert citrine.session == citrine.teams.session def test_citrine_catalyst_session(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - citrine = Citrine(api_key='foo', host='citrine-testing.fake') + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + citrine = Citrine(api_key="foo", host="citrine-testing.fake") assert citrine.session == citrine.catalyst.session def test_citrine_user_agent(): with requests_mock.Mocker() as m: - m.post('https://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - citrine = Citrine(api_key='foo', host='citrine-testing.fake') + m.post( + "https://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + citrine = Citrine(api_key="foo", host="citrine-testing.fake") - agent_parts = citrine.session.headers['User-Agent'].split() - python_impls = {'CPython', 'IronPython', 'Jython', 'PyPy'} - expected_products = {'python-requests', 'citrine-python'} + agent_parts = citrine.session.headers["User-Agent"].split() + python_impls = {"CPython", "IronPython", "Jython", "PyPy"} + expected_products = {"python-requests", "citrine-python"} for product in agent_parts: - product_name, product_version = product.split('/') + product_name, product_version = product.split("/") assert product_name in {*python_impls, *expected_products} if product_name in python_impls: @@ -102,4 +129,4 @@ def test_citrine_user_agent(): # Check that the version is major.minor.patch but don't # enforce them to be ints. 
It's common to see strings used # as the patch version - assert len(product_version.split('.')) == 3 + assert len(product_version.split(".")) == 3 diff --git a/tests/test_session.py b/tests/test_session.py index f1e105bb3..f44871744 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -6,7 +6,8 @@ Conflict, NonRetryableException, WorkflowNotReadyException, - RetryableException) + RetryableException, +) from datetime import datetime, timedelta, timezone @@ -19,22 +20,20 @@ def refresh_token(expiration: datetime = None) -> dict: - token = jwt.encode( - payload={'exp': expiration.timestamp()}, - key='garbage' - ) - return {'access_token': token} + token = jwt.encode(payload={"exp": expiration.timestamp()}, key="garbage") + return {"access_token": token} @pytest.fixture def session(): token_refresh_response = refresh_token(datetime(2019, 3, 14, tzinfo=timezone.utc)) with requests_mock.Mocker() as m: - m.post('http://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) + m.post( + "http://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) session = Session( - refresh_token='12345', - scheme='http', - host='citrine-testing.fake' + refresh_token="12345", scheme="http", host="citrine-testing.fake" ) # Default behavior is to *not* require a refresh - those tests can clear this out # As rule of thumb, we should be using freezegun or similar to never rely on the system clock @@ -47,12 +46,20 @@ def session(): def test_session_signature(monkeypatch): token_refresh_response = refresh_token(datetime(2019, 3, 14, tzinfo=timezone.utc)) with requests_mock.Mocker() as m: - m.post('ftp://citrine-testing.fake:8080/api/v1/tokens/refresh', json=token_refresh_response) + m.post( + "ftp://citrine-testing.fake:8080/api/v1/tokens/refresh", + json=token_refresh_response, + ) - assert '1234' == Session(refresh_token='1234', - scheme='ftp', - host='citrine-testing.fake', - port="8080").refresh_token + assert ( + "1234" + == Session( + refresh_token="1234", + scheme="ftp", + host="citrine-testing.fake", + port="8080", + ).refresh_token + ) # Validate defaults with requests_mock.Mocker() as m: @@ -60,7 +67,9 @@ def test_session_signature(monkeypatch): patched_host = "monkeypatch.citrine-testing.fake" monkeypatch.setenv("CITRINE_API_KEY", patched_key) monkeypatch.setenv("CITRINE_API_HOST", patched_host) - m.post(f'https://{patched_host}/api/v1/tokens/refresh', json=token_refresh_response) + m.post( + f"https://{patched_host}/api/v1/tokens/refresh", json=token_refresh_response + ) assert patched_key == Session().refresh_token assert patched_key == Session(refresh_token=patched_key).refresh_token @@ -77,14 +86,19 @@ def test_get_refreshes_token(session: Session): token_refresh_response = refresh_token(datetime(2019, 3, 14, tzinfo=timezone.utc)) with requests_mock.Mocker() as m: - m.post('http://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - m.get('http://citrine-testing.fake/api/v1/foo', - json={'foo': 'bar'}, - headers={'content-type': "application/json"}) + m.post( + "http://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + m.get( + "http://citrine-testing.fake/api/v1/foo", + json={"foo": "bar"}, + headers={"content-type": "application/json"}, + ) - resp = session.get_resource('/foo') + resp = session.get_resource("/foo") - assert {'foo': 'bar'} == resp + assert {"foo": "bar"} == resp assert datetime(2019, 3, 14, tzinfo=timezone.utc) == session.access_token_expiration @@ -92,37 +106,41 @@ def 
test_get_refresh_token_failure(session: Session): session.access_token_expiration = datetime.now(timezone.utc) - timedelta(minutes=1) with requests_mock.Mocker() as m: - m.post('http://citrine-testing.fake/api/v1/tokens/refresh', status_code=401) + m.post("http://citrine-testing.fake/api/v1/tokens/refresh", status_code=401) with pytest.raises(UnauthorizedRefreshToken): - session.get_resource('/foo') + session.get_resource("/foo") def test_get_no_refresh(session: Session): with requests_mock.Mocker() as m: - m.get('http://citrine-testing.fake/api/v1/foo', json={'foo': 'bar'}, headers={'content-type': "application/json"}) - resp = session.get_resource('/foo') + m.get( + "http://citrine-testing.fake/api/v1/foo", + json={"foo": "bar"}, + headers={"content-type": "application/json"}, + ) + resp = session.get_resource("/foo") - assert {'foo': 'bar'} == resp + assert {"foo": "bar"} == resp def test_get_not_found(session: Session): with requests_mock.Mocker() as m: - m.get('http://citrine-testing.fake/api/v1/foo', status_code=404) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=404) with pytest.raises(NotFound): - session.get_resource('/foo') + session.get_resource("/foo") def test_status_code_409(session: Session): with requests_mock.Mocker() as m: - url = '/foo' - conflict_message = 'you have a conflict' + url = "/foo" + conflict_message = "you have a conflict" resp_json = { - 'code': 409, - 'message': 'a message', - 'validation_errors': [{'failure_message': conflict_message}] + "code": 409, + "message": "a message", + "validation_errors": [{"failure_message": conflict_message}], } - m.get('http://citrine-testing.fake/api/v1/foo', status_code=409, json=resp_json) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=409, json=resp_json) with pytest.raises(Conflict) as einfo: session.get_resource(url) @@ -133,62 +151,69 @@ def test_status_code_409(session: Session): def test_status_code_425(session: Session): with requests_mock.Mocker() as m: - m.get('http://citrine-testing.fake/api/v1/foo', status_code=425) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=425) with pytest.raises(RetryableException): - session.get_resource('/foo') + session.get_resource("/foo") with pytest.raises(WorkflowNotReadyException): - session.get_resource('/foo') + session.get_resource("/foo") def test_status_code_400(session: Session): with requests_mock.Mocker() as m: resp_json = { - 'code': 400, - 'message': 'a message', - 'validation_errors': [ + "code": 400, + "message": "a message", + "validation_errors": [ { - 'failure_message': 'you have failed', + "failure_message": "you have failed", }, ], } - m.get('http://citrine-testing.fake/api/v1/foo', - status_code=400, - json=resp_json - ) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=400, json=resp_json) with pytest.raises(BadRequest) as einfo: - session.get_resource('/foo') - assert einfo.value.api_error.validation_errors[0].failure_message \ - == resp_json['validation_errors'][0]['failure_message'] + session.get_resource("/foo") + assert ( + einfo.value.api_error.validation_errors[0].failure_message + == resp_json["validation_errors"][0]["failure_message"] + ) def test_status_code_401(session: Session): with requests_mock.Mocker() as m: - m.get('http://citrine-testing.fake/api/v1/foo', status_code=401) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=401) with pytest.raises(NonRetryableException): - session.get_resource('/foo') + session.get_resource("/foo") with pytest.raises(Unauthorized): - 
session.get_resource('/foo') + session.get_resource("/foo") def test_status_code_404(session: Session): with requests_mock.Mocker() as m: - m.get('http://citrine-testing.fake/api/v1/foo', status_code=404) + m.get("http://citrine-testing.fake/api/v1/foo", status_code=404) with pytest.raises(NonRetryableException): - session.get_resource('/foo') + session.get_resource("/foo") def test_connection_error(session: Session): - data = {'stuff': 'not_used'} + data = {"stuff": "not_used"} # Simulate a request using a stale session that raises # a ConnectionError then works on the second call. with requests_mock.Mocker() as m: - m.register_uri('GET', - 'http://citrine-testing.fake/api/v1/foo', - [{'exc': requests.exceptions.ConnectionError}, - {'json': data, 'status_code': 200, 'headers': {'content-type': "application/json"}}]) + m.register_uri( + "GET", + "http://citrine-testing.fake/api/v1/foo", + [ + {"exc": requests.exceptions.ConnectionError}, + { + "json": data, + "status_code": 200, + "headers": {"content-type": "application/json"}, + }, + ], + ) - resp = session.get_resource('/foo') + resp = session.get_resource("/foo") assert resp == data @@ -196,58 +221,68 @@ def test_post_refreshes_token_when_denied(session: Session): token_refresh_response = refresh_token(datetime(2019, 3, 14, tzinfo=timezone.utc)) with requests_mock.Mocker() as m: - m.post('http://citrine-testing.fake/api/v1/tokens/refresh', json=token_refresh_response) - m.register_uri('POST', 'http://citrine-testing.fake/api/v1/foo', [ - {'status_code': 401, 'json': {'reason': 'invalid-token'}}, - {'json': {'foo': 'bar'}, 'headers': {'content-type': "application/json"}} - ]) + m.post( + "http://citrine-testing.fake/api/v1/tokens/refresh", + json=token_refresh_response, + ) + m.register_uri( + "POST", + "http://citrine-testing.fake/api/v1/foo", + [ + {"status_code": 401, "json": {"reason": "invalid-token"}}, + { + "json": {"foo": "bar"}, + "headers": {"content-type": "application/json"}, + }, + ], + ) - resp = session.post_resource('/foo', json={'data': 'hi'}) + resp = session.post_resource("/foo", json={"data": "hi"}) - assert {'foo': 'bar'} == resp + assert {"foo": "bar"} == resp assert datetime(2019, 3, 14, tzinfo=timezone.utc) == session.access_token_expiration # this test exists to provide 100% coverage for the legacy 401 status on Unauthorized responses def test_delete_unauthorized_without_json_legacy(session: Session): with requests_mock.Mocker() as m: - m.delete('http://citrine-testing.fake/api/v1/bar/something', status_code=401) + m.delete("http://citrine-testing.fake/api/v1/bar/something", status_code=401) with pytest.raises(Unauthorized): - session.delete_resource('/bar/something') + session.delete_resource("/bar/something") def test_delete_unauthorized_with_str_json_legacy(session: Session): with requests_mock.Mocker() as m: m.delete( - 'http://citrine-testing.fake/api/v1/bar/something', + "http://citrine-testing.fake/api/v1/bar/something", status_code=401, - json='an error string' + json="an error string", ) with pytest.raises(Unauthorized): - session.delete_resource('/bar/something') + session.delete_resource("/bar/something") def test_delete_unauthorized_without_json(session: Session): with requests_mock.Mocker() as m: - m.delete('http://citrine-testing.fake/api/v1/bar/something', status_code=403) + m.delete("http://citrine-testing.fake/api/v1/bar/something", status_code=403) with pytest.raises(Unauthorized): - session.delete_resource('/bar/something') + session.delete_resource("/bar/something") def 
test_failed_put_with_stacktrace(session: Session): with mock.patch("time.sleep", return_value=None): with requests_mock.Mocker() as m: m.put( - 'http://citrine-testing.fake/api/v1/bad-endpoint', + "http://citrine-testing.fake/api/v1/bad-endpoint", status_code=500, - json={'debug_stacktrace': 'blew up!'} + json={"debug_stacktrace": "blew up!"}, ) with pytest.raises(Exception) as e: - session.put_resource('/bad-endpoint', json={}) + session.put_resource("/bad-endpoint", json={}) assert '{"debug_stacktrace": "blew up!"}' == str(e.value) @@ -258,37 +293,64 @@ def test_cursor_paged_resource(): fake_request = make_fake_cursor_request_function(full_result_set) # varying page size should not affect final result - assert list(Session.cursor_paged_resource(fake_request, 'foo', forward=True, per_page=10)) == full_result_set - assert list(Session.cursor_paged_resource(fake_request, 'foo', forward=True, per_page=26)) == full_result_set - assert list(Session.cursor_paged_resource(fake_request, 'foo', forward=True, per_page=40)) == full_result_set + assert ( + list( + Session.cursor_paged_resource( + fake_request, "foo", forward=True, per_page=10 + ) + ) + == full_result_set + ) + assert ( + list( + Session.cursor_paged_resource( + fake_request, "foo", forward=True, per_page=26 + ) + ) + == full_result_set + ) + assert ( + list( + Session.cursor_paged_resource( + fake_request, "foo", forward=True, per_page=40 + ) + ) + == full_result_set + ) def test_bad_json_response(session: Session): with requests_mock.Mocker() as m: - m.delete('http://citrine-testing.fake/api/v1/bar/something', - status_code=200, - headers={'content-type': "application/json"}) - response_json = session.delete_resource('/bar/something') + m.delete( + "http://citrine-testing.fake/api/v1/bar/something", + status_code=200, + headers={"content-type": "application/json"}, + ) + response_json = session.delete_resource("/bar/something") assert response_json == {} def test_good_json_response(session: Session): with requests_mock.Mocker() as m: json_to_validate = {"bar": "something"} - m.put('http://citrine-testing.fake/api/v1/bar/something', - status_code=200, - json=json_to_validate, - headers={'content-type': "application/json"}) - response_json = session.put_resource('bar/something', {"ignored": "true"}) + m.put( + "http://citrine-testing.fake/api/v1/bar/something", + status_code=200, + json=json_to_validate, + headers={"content-type": "application/json"}, + ) + response_json = session.put_resource("bar/something", {"ignored": "true"}) assert response_json == json_to_validate def test_patch(session: Session): with requests_mock.Mocker() as m: json_to_validate = {"bar": "something"} - m.patch('http://citrine-testing.fake/api/v1/bar/something', - status_code=200, - json=json_to_validate, - headers={'content-type': "application/json"}) - response_json = session.patch_resource('bar/something', {"ignored": "true"}) + m.patch( + "http://citrine-testing.fake/api/v1/bar/something", + status_code=200, + json=json_to_validate, + headers={"content-type": "application/json"}, + ) + response_json = session.patch_resource("bar/something", {"ignored": "true"}) assert response_json == json_to_validate diff --git a/tests/utils/factories.py b/tests/utils/factories.py index dc7f96399..8efc00091 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -4,15 +4,35 @@ # Naming convention here is to use "*DataFactory" for dictionaries used as API input/out, and # Factory for the domain objects themselves +from random import randint, random +from 
typing import Optional, Set + import arrow import factory from faker.providers.date_time import Provider -from random import random, randint -from typing import Set, Optional +from gemd import EmpiricalFormula, FileLink, LinkByUID +from gemd.enumeration import SampleType -from citrine.gemd_queries.gemd_query import * -from citrine.gemd_queries.criteria import * -from citrine.gemd_queries.filter import * +from citrine.gemd_queries.criteria import ( + AndOperator, + ConnectivityClassCriteria, + MaterialClassification, + MaterialRunClassificationCriteria, + MaterialTemplatesCriteria, + NameCriteria, + OrOperator, + PropertiesCriteria, + TagFilterType, + TagsCriteria, + TextSearchType, +) +from citrine.gemd_queries.filter import ( + AllIntegerFilter, + AllRealFilter, + NominalCategoricalFilter, +) +from citrine.gemd_queries.gemd_query import GemdObjectType +from citrine.informatics.data_sources import GemTableDataSource from citrine.informatics.scores import LIScore from citrine.informatics.workflows import DesignWorkflow from citrine.jobs.job import JobStatus @@ -25,9 +45,6 @@ from citrine.resources.process_template import ProcessTemplate from citrine.resources.table_config import TableConfigInitiator -from gemd import LinkByUID, EmpiricalFormula, FileLink -from gemd.enumeration import SampleType - class AugmentedProvider(Provider): def random_formula(self, count: int = None, elements: Set[str] = None) -> str: @@ -38,7 +55,9 @@ def random_formula(self, count: int = None, elements: Set[str] = None) -> str: count = self.generator.random.randrange(1, 5) components = sorted(self.generator.random.sample(elements, count)) # Use weights to bias toward looking more real-ish - amounts = self.generator.random.choices([1, 2, 3, 4, 5], weights=[40, 40, 10, 10, 2], k=count) + amounts = self.generator.random.choices( + [1, 2, 3, 4, 5], weights=[40, 40, 10, 10, 2], k=count + ) return "".join(f"({c}){a}" for c, a in zip(components, amounts)) def random_smiles(self) -> str: @@ -53,7 +72,7 @@ def random_smiles(self) -> str: "F": 1, "Cl": 1, "Br": 1, - "I": 1 + "I": 1, } valence = { "B": 3, @@ -65,9 +84,9 @@ def random_smiles(self) -> str: "F": 1, "Cl": 1, "Br": 1, - "I": 1 + "I": 1, } - bonds = ['', '=', '#', '$'] + bonds = ["", "=", "#", "$"] elements = list(element_weights) weights = list(element_weights.values()) @@ -83,12 +102,17 @@ def random_smiles(self) -> str: else: atom = self.generator.random.choices(elements, weights=weights)[0] max_bond = max(valence[atom], remain[-1]) - bond = 1 + self.generator.random.choices( - range(max_bond), - weights=[0.1 ** i for i in range(max_bond)] - )[0] + bond = ( + 1 + + self.generator.random.choices( + range(max_bond), weights=[0.1**i for i in range(max_bond)] + )[0] + ) remain[-1] -= bond - if remain[-1] > 1 and self.generator.random.randrange(3 ** len(remain)) == 0: + if ( + remain[-1] > 1 + and self.generator.random.randrange(3 ** len(remain)) == 0 + ): # Branch remain.append(None) smiles += "(" @@ -98,9 +122,9 @@ def random_smiles(self) -> str: return smiles[:-1] # Always has a superfluous ) at the end def unix_milliseconds( - self, - end_milliseconds: Optional[int] = None, - start_milliseconds: Optional[int] = None, + self, + end_milliseconds: Optional[int] = None, + start_milliseconds: Optional[int] = None, ) -> float: """ Get a timestamp in milliseconds between January 1, 1970 and now, unless @@ -131,19 +155,19 @@ class UserTimestampDataFactory(factory.DictFactory): class TeamDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - name = 
factory.Faker('company') - description = factory.Faker('catch_phrase') + id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("catch_phrase") created_at = factory.Faker("unix_milliseconds") class ProjectDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - status = 'CREATED' + id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + status = "CREATED" created_at = factory.Faker("unix_milliseconds") - team_id = factory.Faker('uuid4') + team_id = factory.Faker("uuid4") class DataVersionUpdateFactory(factory.DictFactory): @@ -152,8 +176,8 @@ class DataVersionUpdateFactory(factory.DictFactory): class PredictorRefFactory(factory.DictFactory): - predictor_id = factory.Faker('uuid4') - predictor_version = factory.Faker('random_digit_not_null') + predictor_id = factory.Faker("uuid4") + predictor_version = factory.Faker("random_digit_not_null") class BranchDataUpdateFactory(factory.DictFactory): @@ -167,47 +191,47 @@ class NextBranchVersionFactory(factory.DictFactory): class BranchDataFieldFactory(factory.DictFactory): - name = factory.Faker('company') + name = factory.Faker("company") class BranchMetadataFieldFactory(factory.DictFactory): - root_id = factory.Faker('uuid4') - archived = factory.Faker('boolean') - version = factory.Faker('random_digit_not_null') + root_id = factory.Faker("uuid4") + archived = factory.Faker("boolean") + version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class BranchDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(BranchDataFieldFactory) metadata = factory.SubFactory(BranchMetadataFieldFactory) class BranchVersionRefFactory(factory.DictFactory): - id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') + id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") class BranchRootMetadataFieldFactory(factory.DictFactory): latest_branch_version = factory.SubFactory(BranchVersionRefFactory) - archived = factory.Faker('boolean') + archived = factory.Faker("boolean") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class BranchRootDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(BranchDataFieldFactory) metadata = factory.SubFactory(BranchRootMetadataFieldFactory) class UserDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') - screen_name = factory.Faker('name') - position = factory.Faker('job') - email = factory.Faker('email') - is_admin = factory.Faker('boolean') + id = factory.Faker("uuid4") + screen_name = factory.Faker("name") + position = factory.Faker("job") + email = factory.Faker("email") + is_admin = factory.Faker("boolean") class GemTableDataFactory(factory.DictFactory): @@ -218,9 +242,9 @@ class GemTableDataFactory(factory.DictFactory): * table-configs/{table_config_uid_str}/gem-tables """ - id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') - signed_download_url = factory.Faker('uri') + id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") + signed_download_url = factory.Faker("uri") class ListGemTableVersionsDataFactory(factory.DictFactory): @@ -230,17 +254,20 @@ class 
ListGemTableVersionsDataFactory(factory.DictFactory): * gem-tables/ * gem-tables/{table_identity_id} """ + # Explicitly set version numbers so that they are distinct - tables = factory.List([ - factory.SubFactory(GemTableDataFactory, version=1), - factory.SubFactory(GemTableDataFactory, version=4), - factory.SubFactory(GemTableDataFactory, version=2), - ]) + tables = factory.List( + [ + factory.SubFactory(GemTableDataFactory, version=1), + factory.SubFactory(GemTableDataFactory, version=4), + factory.SubFactory(GemTableDataFactory, version=2), + ] + ) class RealFilterDataFactory(factory.DictFactory): type = AllRealFilter.typ - unit = 'dimensionless' + unit = "dimensionless" lower = factory.LazyAttribute(lambda o: min(0, 2 * o.upper) + random() * o.upper) upper = factory.Faker("pyfloat") @@ -253,12 +280,12 @@ class IntegerFilterDataFactory(factory.DictFactory): class CategoryFilterDataFactory(factory.DictFactory): type = NominalCategoricalFilter.typ - categories = factory.Faker('words', unique=True) + categories = factory.Faker("words", unique=True) class PropertiesCriteriaDataFactory(factory.DictFactory): type = PropertiesCriteria.typ - property_templates_filter = factory.List([factory.Faker('uuid4')]) + property_templates_filter = factory.List([factory.Faker("uuid4")]) value_type_filter = factory.SubFactory(RealFilterDataFactory) class Params: @@ -272,95 +299,101 @@ class Params: class NameCriteriaDataFactory(factory.DictFactory): type = NameCriteria.typ - name = factory.Faker('word') - search_type = factory.Faker('enum', enum_cls=TextSearchType) + name = factory.Faker("word") + search_type = factory.Faker("enum", enum_cls=TextSearchType) class MaterialRunClassificationCriteriaDataFactory(factory.DictFactory): type = MaterialRunClassificationCriteria.typ classifications = factory.Faker( - 'random_elements', + "random_elements", elements=[str(x) for x in MaterialClassification], - unique=True + unique=True, ) class MaterialTemplatesCriteriaDataFactory(factory.DictFactory): type = MaterialTemplatesCriteria.typ - material_templates_identifiers = factory.List([factory.Faker('uuid4')]) - tag_filters = factory.Faker('words', unique=True) + material_templates_identifiers = factory.List([factory.Faker("uuid4")]) + tag_filters = factory.Faker("words", unique=True) class ConnectivityClassCriteriaDataFactory(factory.DictFactory): type = ConnectivityClassCriteria.typ - is_consumed = factory.Faker('boolean') - is_produced = factory.Faker('boolean') + is_consumed = factory.Faker("boolean") + is_produced = factory.Faker("boolean") class TagsCriteriaDataFactory(factory.DictFactory): type = TagsCriteria.typ - tags = factory.Faker('words', unique=True) - filter_type = factory.Faker('enum', enum_cls=TagFilterType) + tags = factory.Faker("words", unique=True) + filter_type = factory.Faker("enum", enum_cls=TagFilterType) class AndOperatorCriteriaDataFactory(factory.DictFactory): type = AndOperator.typ - criteria = factory.List([ - factory.SubFactory(NameCriteriaDataFactory), - factory.SubFactory(MaterialRunClassificationCriteriaDataFactory), - factory.SubFactory(MaterialTemplatesCriteriaDataFactory) - ]) + criteria = factory.List( + [ + factory.SubFactory(NameCriteriaDataFactory), + factory.SubFactory(MaterialRunClassificationCriteriaDataFactory), + factory.SubFactory(MaterialTemplatesCriteriaDataFactory), + ] + ) class OrOperatorCriteriaDataFactory(factory.DictFactory): type = OrOperator.typ - criteria = factory.List([ - factory.SubFactory(PropertiesCriteriaDataFactory), - 
factory.SubFactory(PropertiesCriteriaDataFactory, integer=True), - factory.SubFactory(PropertiesCriteriaDataFactory, category=True), - factory.SubFactory(AndOperatorCriteriaDataFactory) - ]) + criteria = factory.List( + [ + factory.SubFactory(PropertiesCriteriaDataFactory), + factory.SubFactory(PropertiesCriteriaDataFactory, integer=True), + factory.SubFactory(PropertiesCriteriaDataFactory, category=True), + factory.SubFactory(AndOperatorCriteriaDataFactory), + ] + ) class GemdQueryDataFactory(factory.DictFactory): criteria = factory.List([factory.SubFactory(OrOperatorCriteriaDataFactory)]) - datasets = factory.List([factory.Faker('uuid4')]) + datasets = factory.List([factory.Faker("uuid4")]) object_types = factory.List([str(x) for x in GemdObjectType]) schema_version = 1 class TableConfigMainMetaDataDataFactory(factory.DictFactory): """This is the metadata for the primary definition ID of the TableConfig.""" - id = factory.Faker('uuid4') + + id = factory.Faker("uuid4") deleted = False create_time = factory.Faker("unix_milliseconds") - created_by = factory.Faker('uuid4') + created_by = factory.Faker("uuid4") update_time = factory.Faker("unix_milliseconds") - updated_by = factory.Faker('uuid4') + updated_by = factory.Faker("uuid4") class TableConfigDataFactory(factory.DictFactory): """This is simply the Blob stored in a Table Config Version.""" + name = factory.Faker("company") - description = factory.Faker('bs') + description = factory.Faker("bs") # TODO Create factories for definitions rows = [] columns = [] variables = [] - datasets = factory.List([factory.Faker('uuid4')]) + datasets = factory.List([factory.Faker("uuid4")]) gemd_query = factory.SubFactory(GemdQueryDataFactory) class TableConfigVersionMetaDataDataFactory(factory.DictFactory): ara_definition = factory.SubFactory(TableConfigDataFactory) - id = factory.Faker('uuid4') - definition_id = factory.Faker('uuid4') - version_number = factory.Faker('random_digit_not_null') + id = factory.Faker("uuid4") + definition_id = factory.Faker("uuid4") + version_number = factory.Faker("random_digit_not_null") deleted = False create_time = factory.Faker("unix_milliseconds") - created_by = factory.Faker('uuid4') + created_by = factory.Faker("uuid4") update_time = factory.Faker("unix_milliseconds") - updated_by = factory.Faker('uuid4') + updated_by = factory.Faker("uuid4") initiator = str(TableConfigInitiator.CITRINE_PYTHON) @@ -371,27 +404,30 @@ class TableConfigResponseDataFactory(factory.DictFactory): * projects/{project_id}/display-tables/{uid}/versions/{version}/definition """ + definition = factory.SubFactory(TableConfigMainMetaDataDataFactory) version = factory.SubFactory(TableConfigVersionMetaDataDataFactory) class ListTableConfigResponseDataFactory(factory.DictFactory): """This encapsulates all of the versions of a table config object.""" + definition = factory.SubFactory(TableConfigMainMetaDataDataFactory) # Explicitly set version numbers so that they are distinct - versions = factory.List([ - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=1), - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=4), - factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=2), - ]) + versions = factory.List( + [ + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=1), + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=4), + factory.SubFactory(TableConfigVersionMetaDataDataFactory, version_number=2), + ] + ) class 


 class TableDataSourceDataFactory(factory.DictFactory):
     type = "hosted_table_data_source"
     table_id = factory.Faker("uuid4")
-    table_version = factory.Faker('random_digit_not_null')
+    table_version = factory.Faker("random_digit_not_null")

-from citrine.informatics.data_sources import GemTableDataSource

 class TableDataSourceFactory(factory.Factory):
     class Meta:
@@ -436,7 +472,7 @@ class PredictorDataDataFactory(factory.DictFactory):

 class PredictorEntityDataFactory(factory.DictFactory):
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     data = factory.SubFactory(PredictorDataDataFactory)
     metadata = factory.SubFactory(PredictorMetadataDataFactory)
@@ -458,7 +494,7 @@ class AsyncDefaultPredictorResponseDataFactory(factory.DictFactory):

 class AsyncDefaultPredictorResponseFactory(factory.DictFactory):
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     metadata = factory.SubFactory(AsyncDefaultPredictorResponseMetadataFactory)
     data = factory.SubFactory(AsyncDefaultPredictorResponseDataFactory)
@@ -493,9 +529,9 @@ class AreaUnderROCFactory(factory.DictFactory):

 class CoverageProbabilityFactory(factory.DictFactory):
     class Meta:
-        exclude = ("_level", )
+        exclude = ("_level",)

-    _level = factory.Faker('pyfloat', max_value=1, min_value=0)
+    _level = factory.Faker("pyfloat", max_value=1, min_value=0)
     coverage_level = factory.LazyAttribute(lambda o: str(o._level))
     type = "CoverageProbability"
@@ -503,22 +539,26 @@ class Meta:

 class CrossValidationEvaluatorFactory(factory.DictFactory):
     name = factory.Faker("company")
     description = factory.Faker("catch_phrase")
-    responses = factory.List(3 * [factory.Faker('company')])
-    n_folds = factory.Faker('random_digit_not_null')
-    n_trials = factory.Faker('random_digit_not_null')
-    metrics = factory.List([factory.SubFactory(RMSEFactory),
-                            factory.SubFactory(NDMEFactory),
-                            factory.SubFactory(RSquaredFactory),
-                            factory.SubFactory(StandardRMSEFactory),
-                            factory.SubFactory(PVALFactory),
-                            factory.SubFactory(F1Factory),
-                            factory.SubFactory(AreaUnderROCFactory),
-                            factory.SubFactory(CoverageProbabilityFactory)])
+    responses = factory.List(3 * [factory.Faker("company")])
+    n_folds = factory.Faker("random_digit_not_null")
+    n_trials = factory.Faker("random_digit_not_null")
+    metrics = factory.List(
+        [
+            factory.SubFactory(RMSEFactory),
+            factory.SubFactory(NDMEFactory),
+            factory.SubFactory(RSquaredFactory),
+            factory.SubFactory(StandardRMSEFactory),
+            factory.SubFactory(PVALFactory),
+            factory.SubFactory(F1Factory),
+            factory.SubFactory(AreaUnderROCFactory),
+            factory.SubFactory(CoverageProbabilityFactory),
+        ]
+    )
     type = "CrossValidationEvaluator"


 class PredictorEvaluationWorkflowFactory(factory.DictFactory):
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     name = factory.Faker("company")
     description = factory.Faker("catch_phrase")
     archived = False
@@ -536,24 +576,26 @@ class PredictorEvaluationDataFactory(factory.DictFactory):

 class PredictorEvaluationMetadataFactory(factory.DictFactory):
     class Meta:
-        exclude = ('is_archived', )
+        exclude = ("is_archived",)

     created = factory.SubFactory(UserTimestampDataFactory)
     updated = factory.SubFactory(UserTimestampDataFactory)
-    archived = factory.Maybe('is_archived', factory.SubFactory(UserTimestampDataFactory), None)
+    archived = factory.Maybe(
+        "is_archived", factory.SubFactory(UserTimestampDataFactory), None
+    )
     predictor_id = factory.Faker("uuid4")
     predictor_version = factory.Faker("random_digit_not_null")
     status = {"major": "SUCCEEDED", "minor": "READY", "detail": []}


 class PredictorEvaluationFactory(factory.DictFactory):
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     data = factory.SubFactory(PredictorEvaluationDataFactory)
     metadata = factory.SubFactory(PredictorEvaluationMetadataFactory)


 class DesignSpaceConfigDataFactory(factory.DictFactory):
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     name = factory.Faker("company")
     descriptor = factory.Faker("catch_phrase")
     subspaces = []  # TODO Create SubspaceDataFactory
@@ -563,7 +605,7 @@ class DesignSpaceConfigDataFactory(factory.DictFactory):

 class DesignSpaceDataFactory(factory.DictFactory):
     config = factory.SubFactory(DesignSpaceConfigDataFactory)
-    id = factory.Faker('uuid4')
+    id = factory.Faker("uuid4")
     display_name = factory.Faker("company")
     archived = False
     module_type = "DESIGN_SPACE"
@@ -578,28 +620,28 @@ class Params:
         branch = factory.SubFactory(BranchDataFactory)
         times = factory.List([factory.Faker("unix_milliseconds") for i in range(3)])
         register = factory.Trait(
-            id = factory.Faker('uuid4'),
-            branch_id = factory.LazyAttribute(lambda o: o.branch["id"]),
-            created_by = factory.Faker('uuid4'),
-            updated_by = factory.LazyAttribute(lambda o: o.created_by),
-            create_time = factory.LazyAttribute(lambda o: sorted(o.times)[0]),
-            update_time = factory.LazyAttribute(lambda o: sorted(o.times)[0]),
+            id=factory.Faker("uuid4"),
+            branch_id=factory.LazyAttribute(lambda o: o.branch["id"]),
+            created_by=factory.Faker("uuid4"),
+            updated_by=factory.LazyAttribute(lambda o: o.created_by),
+            create_time=factory.LazyAttribute(lambda o: sorted(o.times)[0]),
+            update_time=factory.LazyAttribute(lambda o: sorted(o.times)[0]),
             # TODO: Create a Trait for statuses
-            status = "SUCCEEDED",
-            status_description = "READY",
-            status_info = [],
-            status_detail = []
+            status="SUCCEEDED",
+            status_description="READY",
+            status_info=[],
+            status_detail=[],
         )
         update = factory.Trait(
-            register = True,
-            updated_by = factory.Faker('uuid4'),
-            update_time = factory.LazyAttribute(lambda o: sorted(o.times)[1])
+            register=True,
+            updated_by=factory.Faker("uuid4"),
+            update_time=factory.LazyAttribute(lambda o: sorted(o.times)[1]),
         )
         archive = factory.Trait(
-            update = True,
-            archived = True,
-            archived_by = factory.Faker('uuid4'),
-            archive_time = factory.LazyAttribute(lambda o: sorted(o.times)[2]),
+            update=True,
+            archived=True,
+            archived_by=factory.Faker("uuid4"),
+            archive_time=factory.LazyAttribute(lambda o: sorted(o.times)[2]),
         )

     type = DesignWorkflow.typ
@@ -612,41 +654,39 @@ class Params:
     branch_root_id = factory.LazyAttribute(lambda o: o.branch["metadata"]["root_id"])
     branch_version = factory.LazyAttribute(lambda o: o.branch["metadata"]["version"])
     archived = False
-    status_description = ""  # TODO: Should be None, but property not defined as Optional
+    status_description = (
+        ""  # TODO: Should be None, but property not defined as Optional
+    )
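
These factory.Trait declarations chain: update=True switches on register, and archive=True switches on update, so all three timestamps come from the same sorted times list. The enclosing factory's name falls outside this hunk, so DesignWorkflowDataFactory below is an assumed name used purely for illustration:

    data = DesignWorkflowDataFactory(archive=True)  # hypothetical factory name; traits cascade
    assert data["create_time"] <= data["update_time"] <= data["archive_time"]  # sorted(o.times)[0..2]
    assert data["archived"] is True
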


 class IngestFilesResponseDataFactory(factory.DictFactory):
-    team_id = factory.Faker('uuid4')
-    dataset_id = factory.Faker('uuid4')
-    ingestion_id = factory.Faker('uuid4')
+    team_id = factory.Faker("uuid4")
+    dataset_id = factory.Faker("uuid4")
+    ingestion_id = factory.Faker("uuid4")


 class IngestionStatusResponseDataFactory(factory.DictFactory):
-    ingestion_id = factory.Faker('uuid4')
+    ingestion_id = factory.Faker("uuid4")
     status = IngestionStatusType.INGESTION_CREATED
     errors = factory.List([])


 class JobSubmissionResponseDataFactory(factory.DictFactory):
-    job_id = factory.Faker('uuid4')
+    job_id = factory.Faker("uuid4")


 class TaskNodeDataFactory(factory.DictFactory):
     class Params:
         failure = False

-    id = factory.Faker('uuid4')
-    task_type = factory.Faker('word')
+    id = factory.Faker("uuid4")
+    task_type = factory.Faker("word")
     status = factory.Maybe(
-        "failure",
-        yes_declaration=JobStatus.FAILURE,
-        no_declaration=JobStatus.SUCCESS
+        "failure", yes_declaration=JobStatus.FAILURE, no_declaration=JobStatus.SUCCESS
     )
     dependencies = factory.List([])
     failure_reason = factory.Maybe(
-        "failure",
-        yes_declaration=factory.Faker('sentence'),
-        no_declaration=None
+        "failure", yes_declaration=factory.Faker("sentence"), no_declaration=None
     )
@@ -654,15 +694,17 @@ class JobStatusResponseDataFactory(factory.DictFactory):
     class Params:
         failure = False

-    job_type = factory.Faker('word')
+    job_type = factory.Faker("word")
     status = factory.Maybe(
-        "failure",
-        yes_declaration=JobStatus.FAILURE,
-        no_declaration=JobStatus.SUCCESS
+        "failure", yes_declaration=JobStatus.FAILURE, no_declaration=JobStatus.SUCCESS
+    )
+    tasks = factory.List(
+        [
+            factory.RelatedFactory(
+                TaskNodeDataFactory, failure=factory.SelfAttribute("...failure")
+            )
+        ]
     )
-    tasks = factory.List([
-        factory.RelatedFactory(TaskNodeDataFactory, failure=factory.SelfAttribute('...failure'))
-    ])
     output = factory.Dict({})
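
The shared failure parameter threads from the job factory into each task through factory.SelfAttribute("...failure"), whose leading dots walk back up from the list element to the enclosing factory's Params. A sketch of both toggles, assuming the JobStatus enum referenced above:

    ok = JobStatusResponseDataFactory()
    assert ok["status"] == JobStatus.SUCCESS
    assert ok["tasks"][0]["failure_reason"] is None

    failed = JobStatusResponseDataFactory(failure=True)
    assert failed["status"] == JobStatus.FAILURE
    assert failed["tasks"][0]["status"] == JobStatus.FAILURE  # propagated, not re-rolled
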
factory.Faker("color_name")]) properties = [] # TODO make a PropertiesTemplateFactory - description = factory.Faker('catch_phrase') + description = factory.Faker("catch_phrase") class MaterialSpecFactory(factory.Factory): @@ -733,9 +775,9 @@ class Meta: model = MaterialSpec uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) - notes = factory.Faker('catch_phrase') + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) + notes = factory.Faker("catch_phrase") process = factory.SubFactory(LinkByUIDFactory) file_links = factory.List([factory.SubFactory(FileLinkFactory)]) template = factory.SubFactory(LinkByUIDFactory) @@ -747,9 +789,9 @@ class Meta: model = MaterialRun uids = factory.SubFactory(IDDataFactory) - name = factory.Faker('color_name') - tags = factory.List([factory.Faker('color_name'), factory.Faker('color_name')]) - notes = factory.Faker('catch_phrase') + name = factory.Faker("color_name") + tags = factory.List([factory.Faker("color_name"), factory.Faker("color_name")]) + notes = factory.Faker("catch_phrase") process = factory.SubFactory(LinkByUIDFactory) sample_type = factory.Faker("enum", enum_cls=SampleType) spec = factory.SubFactory(LinkByUIDFactory) @@ -759,13 +801,13 @@ class Meta: class LinkByUIDDataFactory(factory.DictFactory): id = LinkByUIDFactory.id scope = LinkByUIDFactory.scope - type = 'link_by_uid' + type = "link_by_uid" class FileLinkDataFactory(factory.DictFactory): url = FileLinkFactory.url filename = FileLinkFactory.filename - type = 'file_link' + type = "file_link" class MaterialSpecDataFactory(factory.DictFactory): @@ -777,7 +819,7 @@ class MaterialSpecDataFactory(factory.DictFactory): file_links = factory.List([factory.SubFactory(FileLinkDataFactory)]) template = factory.SubFactory(LinkByUIDDataFactory) properties = [] # TODO make a PropertiesDataFactory - type = 'material_spec' + type = "material_spec" class MaterialRunDataFactory(factory.DictFactory): @@ -789,16 +831,16 @@ class MaterialRunDataFactory(factory.DictFactory): sample_type = MaterialRunFactory.sample_type spec = factory.SubFactory(LinkByUIDDataFactory) file_links = factory.List([factory.SubFactory(FileLinkDataFactory)]) - type = 'material_run' + type = "material_run" class DatasetFactory(factory.Factory): class Meta: model = Dataset - name = factory.Faker('company') - summary = factory.Faker('catch_phrase') - description = factory.Faker('bs') + name = factory.Faker("company") + summary = factory.Faker("catch_phrase") + description = factory.Faker("bs") unique_name = None # TODO Update tests to include unique_name @@ -809,14 +851,14 @@ class Meta: # TODO Bring _Uploader in line with other library concepts @factory.post_generation def assign_values(obj, create, extracted): - obj.bucket = 'citrine-datasvc' - obj.object_key = '334455' - obj.upload_id = 'dea3a-555' - obj.region_name = 'us-west' - obj.aws_access_key_id = 'dkfjiejkcm' - obj.aws_secret_access_key = 'ifeemkdsfjeijie8759235u2wjr388' - obj.aws_session_token = 'fafjeijfi87834j87woa' - obj.s3_version = '2' + obj.bucket = "citrine-datasvc" + obj.object_key = "334455" + obj.upload_id = "dea3a-555" + obj.region_name = "us-west" + obj.aws_access_key_id = "dkfjiejkcm" + obj.aws_secret_access_key = "ifeemkdsfjeijie8759235u2wjr388" + obj.aws_session_token = "fafjeijfi87834j87woa" + obj.s3_version = "2" class MLIScoreFactory(factory.Factory): @@ -830,17 +872,17 @@ class Meta: class 
CategoricalExperimentValueDataFactory(factory.DictFactory): type = "CategoricalValue" - value = factory.Faker('company') + value = factory.Faker("company") class ChemicalFormulaExperimentValueDataFactory(factory.DictFactory): type = "InorganicValue" - value = factory.Faker('random_formula') + value = factory.Faker("random_formula") class IntegerExperimentValueDataFactory(factory.DictFactory): type = "IntegerValue" - value = factory.Faker('random_int', min=1, max=99) + value = factory.Faker("random_int", min=1, max=99) class MixtureExperimentValueDataFactory(factory.DictFactory): @@ -850,127 +892,170 @@ class MixtureExperimentValueDataFactory(factory.DictFactory): class MolecularStructureExperimentValueDataFactory(factory.DictFactory): type = "OrganicValue" - value = factory.Faker('random_smiles') + value = factory.Faker("random_smiles") class RealExperimentValueDataFactory(factory.DictFactory): type = "RealValue" - value = factory.Faker('pyfloat', min_value=0, max_value=100) + value = factory.Faker("pyfloat", min_value=0, max_value=100) class CandidateExperimentSnapshotDataFactory(factory.DictFactory): - experiment_id = factory.Faker('uuid4') - candidate_id = factory.Faker('uuid4') - workflow_id = factory.Faker('uuid4') - name = factory.Faker('company') - description = factory.Faker('company') + experiment_id = factory.Faker("uuid4") + candidate_id = factory.Faker("uuid4") + workflow_id = factory.Faker("uuid4") + name = factory.Faker("company") + description = factory.Faker("company") updated_time = factory.Faker("unix_milliseconds") # TODO Generate Experiment keys randomly but uniquely - overrides = factory.Dict({ - "ingredient1": factory.SubFactory(CategoricalExperimentValueDataFactory), - "ingredient2": factory.SubFactory(ChemicalFormulaExperimentValueDataFactory), - "ingredient3": factory.SubFactory(IntegerExperimentValueDataFactory), - "Formulation": factory.SubFactory(MixtureExperimentValueDataFactory), - "ingredient4": factory.SubFactory(MolecularStructureExperimentValueDataFactory), - "ingredient5": factory.SubFactory(RealExperimentValueDataFactory) - }) + overrides = factory.Dict( + { + "ingredient1": factory.SubFactory(CategoricalExperimentValueDataFactory), + "ingredient2": factory.SubFactory( + ChemicalFormulaExperimentValueDataFactory + ), + "ingredient3": factory.SubFactory(IntegerExperimentValueDataFactory), + "Formulation": factory.SubFactory(MixtureExperimentValueDataFactory), + "ingredient4": factory.SubFactory( + MolecularStructureExperimentValueDataFactory + ), + "ingredient5": factory.SubFactory(RealExperimentValueDataFactory), + } + ) class ExperimentDataSourceDataDataFactory(factory.DictFactory): - experiments = factory.List([factory.SubFactory(CandidateExperimentSnapshotDataFactory)]) + experiments = factory.List( + [factory.SubFactory(CandidateExperimentSnapshotDataFactory)] + ) class ExperimentDataSourceMetadataDataFactory(factory.DictFactory): - branch_root_id = factory.Faker('uuid4') - version = factory.Faker('random_digit_not_null') + branch_root_id = factory.Faker("uuid4") + version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) class ExperimentDataSourceDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(ExperimentDataSourceDataDataFactory) metadata = factory.SubFactory(ExperimentDataSourceMetadataDataFactory) class AnalysisPlotMetadataDataFactory(factory.DictFactory): - rank = factory.Faker('random_int', min=1, max=10) + rank = 
factory.Faker("random_int", min=1, max=10) created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) class AnalysisPlotDataDataFactory(factory.DictFactory): - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - plot_type = factory.Faker('random_element', elements=('SCATTER', 'VIOLIN')) + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + plot_type = factory.Faker("random_element", elements=("SCATTER", "VIOLIN")) config = {} class AnalysisPlotEntityDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(AnalysisPlotDataDataFactory) metadata = factory.SubFactory(AnalysisPlotMetadataDataFactory) class LatestBuildDataFactory(factory.DictFactory): class Params: - is_failed = factory.LazyAttribute(lambda o: o.status == 'FAILED') + is_failed = factory.LazyAttribute(lambda o: o.status == "FAILED") - status = factory.Faker('random_element', elements=('INPROGRESS', 'SUCCEEDED', 'FAILED')) - failure_reason = factory.Maybe('is_failed', ['This is a test failure message'], []) + status = factory.Faker( + "random_element", elements=("INPROGRESS", "SUCCEEDED", "FAILED") + ) + failure_reason = factory.Maybe("is_failed", ["This is a test failure message"], []) query = factory.SubFactory(GemdQueryDataFactory) class AnalysisWorkflowMetadataDataFactory(factory.DictFactory): class Meta: - exclude = ('is_archived', 'has_build') + exclude = ("is_archived", "has_build") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) - archived = factory.Maybe('is_archived', factory.SubFactory(UserTimestampDataFactory), None) - latest_build = factory.Maybe('has_build', factory.SubFactory(LatestBuildDataFactory), None) + archived = factory.Maybe( + "is_archived", factory.SubFactory(UserTimestampDataFactory), None + ) + latest_build = factory.Maybe( + "has_build", factory.SubFactory(LatestBuildDataFactory), None + ) class AnalysisWorkflowDataDataFactory(factory.DictFactory): class Meta: - exclude = ('has_snapshot', 'plot_count') + exclude = ("has_snapshot", "plot_count") class Params: plot_count = 1 - name = factory.Faker('company') - description = factory.Faker('catch_phrase') - snapshot_id = factory.Maybe('has_snapshot', factory.Faker('uuid4'), None) - plots = factory.LazyAttribute(lambda self: [AnalysisPlotEntityDataFactory() for _ in range(self.plot_count)]) + name = factory.Faker("company") + description = factory.Faker("catch_phrase") + snapshot_id = factory.Maybe("has_snapshot", factory.Faker("uuid4"), None) + plots = factory.LazyAttribute( + lambda self: [AnalysisPlotEntityDataFactory() for _ in range(self.plot_count)] + ) class AnalysisWorkflowEntityDataFactory(factory.DictFactory): - id = factory.Faker('uuid4') + id = factory.Faker("uuid4") data = factory.SubFactory(AnalysisWorkflowDataDataFactory) metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory) class FeatureEffectsResponseResultFactory(factory.DictFactory): - materials = factory.List([ - factory.Faker('uuid4', cast_to=None), - factory.Faker('uuid4', cast_to=None), - factory.Faker('uuid4', cast_to=None) - ]) - outputs = factory.Dict({ - "output1": factory.Dict({ - "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) - }), - "output2": factory.Dict({ - "feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]), - "feature2": 
factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]) - }) - }) + materials = factory.List( + [ + factory.Faker("uuid4", cast_to=None), + factory.Faker("uuid4", cast_to=None), + factory.Faker("uuid4", cast_to=None), + ] + ) + outputs = factory.Dict( + { + "output1": factory.Dict( + { + "feature1": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ) + } + ), + "output2": factory.Dict( + { + "feature1": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ), + "feature2": factory.List( + [ + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + factory.Faker("pyfloat"), + ] + ), + } + ), + } + ) + class FeatureEffectsMetadataFactory(factory.DictFactory): - predictor_id = factory.Faker('uuid4') - predictor_version = factory.Faker('random_digit_not_null') + predictor_id = factory.Faker("uuid4") + predictor_version = factory.Faker("random_digit_not_null") created = factory.SubFactory(UserTimestampDataFactory) updated = factory.SubFactory(UserTimestampDataFactory) - status = 'SUCCEEDED' + status = "SUCCEEDED" class FeatureEffectsResponseFactory(factory.DictFactory): diff --git a/tests/utils/fakes/__init__.py b/tests/utils/fakes/__init__.py index 6176ea9f5..62c68a9e0 100644 --- a/tests/utils/fakes/__init__.py +++ b/tests/utils/fakes/__init__.py @@ -1,10 +1,11 @@ +# ruff: noqa: F403 from .fake_collection import * -from .fake_file_collection import * from .fake_dataset_collection import * from .fake_descriptor_methods import * from .fake_execution_collection import * +from .fake_file_collection import * +from .fake_module_collection import * +from .fake_project_collection import * from .fake_table_collection import * +from .fake_workflow_collection import * from .fake_workflows import * -from .fake_module_collection import FakeDesignSpaceCollection, FakePredictorCollection -from .fake_workflow_collection import FakeDesignWorkflowCollection -from .fake_project_collection import * diff --git a/tests/utils/fakes/fake_collection.py b/tests/utils/fakes/fake_collection.py index 01f035e70..db40e34ee 100644 --- a/tests/utils/fakes/fake_collection.py +++ b/tests/utils/fakes/fake_collection.py @@ -1,16 +1,15 @@ -from uuid import uuid4, UUID -from typing import TypeVar, Optional, Union, Iterable +from typing import Iterable, Optional, TypeVar, Union +from uuid import UUID, uuid4 from citrine._rest.collection import Collection +from citrine._rest.resource import Resource from citrine.exceptions import NotFound - from tests.utils.functions import normalize_uid -ResourceType = TypeVar('ResourceType', bound='Resource') +ResourceType = TypeVar("ResourceType", bound="Resource") class FakeCollection(Collection[ResourceType]): - def __init__(self): self._resources = {} @@ -19,17 +18,21 @@ def register(self, resource: ResourceType) -> ResourceType: resource.uid = uuid4() self._resources[resource.uid] = resource return resource - + def update(self, resource: ResourceType): self._resources.pop(resource.uid, None) return self.register(resource) - - def list(self, page: Optional[int] = None, per_page: int = 100) -> Iterable[ResourceType]: + + def list( + self, page: Optional[int] = None, per_page: int = 100 + ) -> Iterable[ResourceType]: if page is None: return iter(list(self._resources.values())) else: - return iter(list(self._resources.values())[(page - 1)*per_page:page*per_page]) - + return iter( + list(self._resources.values())[(page - 1) * per_page : page * 
diff --git a/tests/utils/fakes/fake_dataset_collection.py b/tests/utils/fakes/fake_dataset_collection.py
index 2346574eb..eb738b53c 100644
--- a/tests/utils/fakes/fake_dataset_collection.py
+++ b/tests/utils/fakes/fake_dataset_collection.py
@@ -6,7 +6,6 @@

 class FakeDataset(Dataset):
-
     def __init__(self):
         pass
@@ -16,7 +15,6 @@ def files(self) -> FileCollection:

 class FakeDatasetCollection(DatasetCollection):
-
     def __init__(self, *, session, team_id):
         super().__init__(team_id=team_id, session=session)
         self.datasets = []
@@ -29,4 +27,4 @@ def list(self, page: Optional[int] = None, per_page: int = 100):
         if page is None:
             return self.datasets
         else:
-            return self.datasets[(page - 1)*per_page:page*per_page]
+            return self.datasets[(page - 1) * per_page : page * per_page]
diff --git a/tests/utils/fakes/fake_descriptor_methods.py b/tests/utils/fakes/fake_descriptor_methods.py
index e3467b683..5841cd15c 100644
--- a/tests/utils/fakes/fake_descriptor_methods.py
+++ b/tests/utils/fakes/fake_descriptor_methods.py
@@ -1,11 +1,17 @@
 from typing import List, Union
 from uuid import uuid4

-from citrine.informatics.descriptors import Descriptor, RealDescriptor, CategoricalDescriptor
+from citrine.informatics.descriptors import (
+    Descriptor,
+    RealDescriptor,
+    CategoricalDescriptor,
+)
 from citrine.informatics.predictors import (
     ChemicalFormulaFeaturizer,
     MolecularStructureFeaturizer,
-    MeanPropertyPredictor, PredictorNode, GraphPredictor
+    MeanPropertyPredictor,
+    PredictorNode,
+    GraphPredictor,
 )
 from citrine.resources.descriptors import DescriptorMethods
 from tests.utils.session import FakeSession
@@ -18,16 +24,26 @@ def __init__(self, num_properties):
         self.num_properties = num_properties

     def from_predictor_responses(
-        self,
-        predictor: Union[PredictorNode, GraphPredictor],
-        inputs: List[Descriptor]
+        self, predictor: Union[PredictorNode, GraphPredictor], inputs: List[Descriptor]
     ):
-        if isinstance(predictor, (MolecularStructureFeaturizer, ChemicalFormulaFeaturizer)):
+        if isinstance(
+            predictor, (MolecularStructureFeaturizer, ChemicalFormulaFeaturizer)
+        ):
             input_descriptor = predictor.input_descriptor
             return [
-                RealDescriptor(f"{input_descriptor.key} real property {i}", lower_bound=0, upper_bound=1, units="")
-                for i in range(self.num_properties)
-            ] + [CategoricalDescriptor(f"{input_descriptor.key} categorical property", categories=["cat1", "cat2"])]
+                RealDescriptor(
+                    f"{input_descriptor.key} real property {i}",
+                    lower_bound=0,
+                    upper_bound=1,
+                    units="",
+                )
+                for i in range(self.num_properties)
+            ] + [
+                CategoricalDescriptor(
+                    f"{input_descriptor.key} categorical property",
+                    categories=["cat1", "cat2"],
+                )
+            ]

         elif isinstance(predictor, MeanPropertyPredictor):
             label_str = predictor.label or "all ingredients"
@@ -36,7 +52,7 @@ def from_predictor_responses(
                     f"mean of {prop.key} for {label_str} in {predictor.input_descriptor.key}",
                     lower_bound=0,
                     upper_bound=1,
-                    units=""
+                    units="",
                 )
                 for prop in predictor.properties
-            ]
\ No newline at end of file
+            ]
diff --git a/tests/utils/fakes/fake_execution_collection.py b/tests/utils/fakes/fake_execution_collection.py
index 1719e4ece..8c3ae2f71 100644
--- a/tests/utils/fakes/fake_execution_collection.py
+++ b/tests/utils/fakes/fake_execution_collection.py
@@ -1,4 +1,3 @@
-from uuid import UUID
 from typing import Optional

 from citrine.informatics.executions import DesignExecution
@@ -8,8 +7,9 @@

 class FakeDesignExecutionCollection(DesignExecutionCollection):
-
-    def trigger(self, execution_input: Score, max_candidates: Optional[int] = None) -> DesignExecution:
+    def trigger(
+        self, execution_input: Score, max_candidates: Optional[int] = None
+    ) -> DesignExecution:
         execution = DesignExecution()
         execution.score = execution_input
         execution.descriptors = []
diff --git a/tests/utils/fakes/fake_file_collection.py b/tests/utils/fakes/fake_file_collection.py
index 10fe0b71c..699eaffc0 100644
--- a/tests/utils/fakes/fake_file_collection.py
+++ b/tests/utils/fakes/fake_file_collection.py
@@ -2,7 +2,6 @@

 class FakeFileCollection(FileCollection):
-
     def __init__(self):
         self.files = []
diff --git a/tests/utils/fakes/fake_module_collection.py b/tests/utils/fakes/fake_module_collection.py
index ac261304b..f4c81adcf 100644
--- a/tests/utils/fakes/fake_module_collection.py
+++ b/tests/utils/fakes/fake_module_collection.py
@@ -1,7 +1,8 @@
 from datetime import datetime
-from typing import TypeVar, Union
-from uuid import uuid4, UUID
+from typing import TypeVar
+from uuid import UUID, uuid4

+from citrine._rest.asynchronous_object import AsynchronousObject
 from citrine._rest.collection import Collection
 from citrine._session import Session
 from citrine.exceptions import BadRequest
@@ -10,15 +11,13 @@
 from citrine.informatics.predictors import GraphPredictor
 from citrine.resources.design_space import DesignSpaceCollection
 from citrine.resources.predictor import PredictorCollection
-
-from tests.utils.functions import normalize_uid
 from tests.utils.fakes import FakeCollection
+from tests.utils.functions import normalize_uid

-ModuleType = TypeVar('ModuleType', bound='Module')
+ModuleType = TypeVar("ModuleType", bound=AsynchronousObject)


 class FakeModuleCollection(FakeCollection[ModuleType], Collection[ModuleType]):
-
     def __init__(self, project_id, session):
         FakeCollection.__init__(self)
         self.project_id = project_id
@@ -34,30 +33,24 @@ def archive(self, module_id: UUID):
             module.archive_time = datetime.now()
         return module

-class FakeDesignSpaceCollection(FakeModuleCollection[DesignSpace], DesignSpaceCollection):
+
+class FakeDesignSpaceCollection(
+    FakeModuleCollection[DesignSpace], DesignSpaceCollection
+):
     def create_default(self, *, predictor_id: UUID) -> DesignSpace:
         return ProductDesignSpace(
-            f"Default design space",
-            description="",
-            dimensions=[],
-            subspaces=[]
+            "Default design space", description="", dimensions=[], subspaces=[]
         )


-class FakePredictorCollection(FakeModuleCollection[GraphPredictor], PredictorCollection):
-
+class FakePredictorCollection(
+    FakeModuleCollection[GraphPredictor], PredictorCollection
+):
     def create_default(
-        self,
-        *,
-        training_data: DataSource,
-        pattern="PLAIN",
-        prefer_valid=True
+        self, *, training_data: DataSource, pattern="PLAIN", prefer_valid=True
     ) -> GraphPredictor:
         return GraphPredictor(
-            name=f"Default {pattern.lower()} predictor",
-            description="",
-            predictors=[]
+            name=f"Default {pattern.lower()} predictor", description="", predictors=[]
         )
-
+
     auto_configure = create_default
diff --git a/tests/utils/fakes/fake_project_collection.py b/tests/utils/fakes/fake_project_collection.py
index a0f6a657f..6852e4026 100644
--- a/tests/utils/fakes/fake_project_collection.py
+++ b/tests/utils/fakes/fake_project_collection.py
@@ -12,8 +12,11 @@

 class FakeProjectCollection(ProjectCollection):
-
-    def __init__(self, search_implemented: bool = True, team_id: Optional[Union[UUID, str]] = None):
+    def __init__(
+        self,
+        search_implemented: bool = True,
+        team_id: Optional[Union[UUID, str]] = None,
+    ):
         super().__init__(session=FakeSession, team_id=team_id)
         self.projects = []
         self.search_implemented = search_implemented
@@ -27,7 +30,7 @@ def list(self, page: Optional[int] = None, per_page: int = 100):
         if page is None:
             return self.projects
         else:
-            return self.projects[(page - 1) * per_page:page * per_page]
+            return self.projects[(page - 1) * per_page : page * per_page]

     def search(self, search_params: Optional[dict] = None, per_page: int = 100):
         if not self.search_implemented:
@@ -56,18 +59,25 @@ def delete(self, uuid):

 class FakeProject(Project):
-
-    def __init__(self, name="foo", description="bar", num_properties=3, session=FakeSession()):
+    def __init__(
+        self, name="foo", description="bar", num_properties=3, session=FakeSession()
+    ):
         super().__init__(name=name, description=description, session=session)
         self.uid = uuid4()
         self.team_id = uuid4()
         self._design_spaces = FakeDesignSpaceCollection(self.uid, self.session)
         self._design_workflows = FakeDesignWorkflowCollection(self.uid, self.session)
         self._descriptor_methods = FakeDescriptorMethods(num_properties)
-        self._datasets = FakeDatasetCollection(team_id=self.team_id, session=self.session)
+        self._datasets = FakeDatasetCollection(
+            team_id=self.team_id, session=self.session
+        )
         self._predictors = FakePredictorCollection(self.uid, self.session)
-        self._tables = FakeGemTableCollection(team_id=self.team_id, project_id=self.uid, session=self.session)
-        self._table_configs = FakeTableConfigCollection(team_id=self.team_id, project_id=self.uid, session=self.session)
+        self._tables = FakeGemTableCollection(
+            team_id=self.team_id, project_id=self.uid, session=self.session
+        )
+        self._table_configs = FakeTableConfigCollection(
+            team_id=self.team_id, project_id=self.uid, session=self.session
+        )

     @property
     def datasets(self) -> FakeDatasetCollection:
diff --git a/tests/utils/fakes/fake_table_collection.py b/tests/utils/fakes/fake_table_collection.py
index 5fe5b2dde..229e75c14 100644
--- a/tests/utils/fakes/fake_table_collection.py
+++ b/tests/utils/fakes/fake_table_collection.py
@@ -1,20 +1,23 @@
-from uuid import uuid4, UUID
-from typing import List, Dict, Tuple, Optional, Union, Iterable, TypeVar, Generic
+from typing import Dict, Generic, Iterable, List, Optional, Tuple, TypeVar, Union
+from uuid import UUID, uuid4

 from gemd.entity.link_by_uid import LinkByUID

+from citrine._rest.resource import Resource
 from citrine._session import Session
 from citrine.exceptions import NotFound
-from citrine.resources.material_run import MaterialRun
-from citrine.resources.gemtables import GemTable, GemTableCollection
-from citrine.resources.table_config import TableConfig, TableConfigCollection, TableBuildAlgorithm
-
 from citrine.gemtables.columns import Column
 from citrine.gemtables.variables import Variable
-
+from citrine.resources.gemtables import GemTable, GemTableCollection
+from citrine.resources.material_run import MaterialRun
+from citrine.resources.table_config import (
+    TableBuildAlgorithm,
+    TableConfig,
+    TableConfigCollection,
+)
 from tests.utils.functions import normalize_uid

-ResourceType = TypeVar('ResourceType', bound='Resource')
+ResourceType = TypeVar("ResourceType", bound="Resource")


 class VersionedResourceStorage(Generic[ResourceType]):
@@ -38,7 +41,9 @@ def _get_latest(self, uid: UUID) -> ResourceType:
         latest_version = max(versions.keys())
         return versions[latest_version]

-    def get(self, uid: Union[str, UUID], *, version: Optional[int] = None) -> Optional[ResourceType]:
+    def get(
+        self, uid: Union[str, UUID], *, version: Optional[int] = None
+    ) -> Optional[ResourceType]:
         uid = normalize_uid(uid)
         if uid not in self.resources:
             return None
@@ -59,7 +64,6 @@ def list_latest(self) -> List[ResourceType]:

 class FakeTableConfigCollection(TableConfigCollection):
-
     def __init__(self, team_id: UUID, project_id: UUID, session: Session):
         super().__init__(team_id=team_id, project_id=project_id, session=session)
         self._storage = VersionedResourceStorage[TableConfig]()
@@ -85,38 +89,43 @@ def register(self, table_config: TableConfig) -> TableConfig:

         return table_config

-    def list(self, page: Optional[int] = None, per_page: int = 100) -> Iterable[TableConfig]:
+    def list(
+        self, page: Optional[int] = None, per_page: int = 100
+    ) -> Iterable[TableConfig]:
         configs = self._storage.list_latest()
         if page is None:
             return iter(configs)
         else:
-            return iter(configs[(page - 1)*per_page:page*per_page])
+            return iter(configs[(page - 1) * per_page : page * per_page])

     def default_for_material(
-        self, *,
+        self,
+        *,
         material: Union[MaterialRun, LinkByUID, str, UUID],
         name: str,
         description: str = None,
         algorithm: Optional[TableBuildAlgorithm] = None,
-        scope: str = None
+        scope: str = None,
     ) -> Tuple[TableConfig, List[Tuple[Variable, Column]]]:
         table_config = TableConfig(
-            name=name, description="", datasets=[],
-            rows=[], variables=[], columns=[]
+            name=name, description="", datasets=[], rows=[], variables=[], columns=[]
         )
         return table_config, []


 class FakeGemTableCollection(GemTableCollection):
-
     def __init__(self, team_id: UUID, project_id: UUID, session: Session):
         super().__init__(team_id=team_id, project_id=project_id, session=session)
         self._config_map = {}  # Map config UID to table UID
         self._table_storage = VersionedResourceStorage[GemTable]()

-    def build_from_config(self, config: Union[TableConfig, str, UUID], *,
-                          version: Union[str, int] = None,
-                          timeout: float = 15 * 60) -> GemTable:
+    def build_from_config(
+        self,
+        config: Union[TableConfig, str, UUID],
+        *,
+        version: Union[str, int] = None,
+        timeout: float = 15 * 60,
+    ) -> GemTable:
         if isinstance(config, TableConfig):
             config_uid = config.config_uid
         else:
@@ -136,9 +145,9 @@ def build_from_config(self, config: Union[TableConfig, str, UUID], *,

         return table

-    def list_by_config(self, table_config_uid: UUID,
-                       *,
-                       per_page: int = 100) -> Iterable[GemTable]:
+    def list_by_config(
+        self, table_config_uid: UUID, *, per_page: int = 100
+    ) -> Iterable[GemTable]:
         config_uid = normalize_uid(table_config_uid)
         if config_uid not in self._config_map:
             return iter([])
@@ -146,9 +155,6 @@ def list_by_config(self, table_config_uid: UUID,
         table_id = self._config_map[config_uid]
         return self.list_versions(table_id, per_page=per_page)

-    def list_versions(self,
-                      uid: UUID,
-                      *,
-                      per_page: int = 100) -> Iterable[GemTable]:
+    def list_versions(self, uid: UUID, *, per_page: int = 100) -> Iterable[GemTable]:
         tables = self._table_storage.list_by_uid(uid)
         return iter(tables)
diff --git a/tests/utils/fakes/fake_team_collection.py b/tests/utils/fakes/fake_team_collection.py
index c2b59eacd..bdfd1cca4 100644
--- a/tests/utils/fakes/fake_team_collection.py
+++ b/tests/utils/fakes/fake_team_collection.py
@@ -4,13 +4,11 @@

 class FakeTeam(Team):
-
     def __init__(self, name):
         self.name = name


 class FakeTeamCollection(TeamCollection):
-
     def __init__(self, session):
         super().__init__(session=session)
         self.teams = []
@@ -24,4 +22,4 @@ def list(self, page: Optional[int] = None, per_page: int = 100):
         if page is None:
             return self.teams
         else:
-            return self.teams[(page - 1) * per_page:page * per_page]
+            return self.teams[(page - 1) * per_page : page * per_page]
diff --git a/tests/utils/fakes/fake_workflow_collection.py b/tests/utils/fakes/fake_workflow_collection.py
index 6c1ba2587..0382cd7cd 100644
--- a/tests/utils/fakes/fake_workflow_collection.py
+++ b/tests/utils/fakes/fake_workflow_collection.py
@@ -1,18 +1,15 @@
 from typing import TypeVar, Union
-from uuid import uuid4, UUID
+from uuid import UUID

 from citrine._session import Session
-from citrine._utils.functions import migrate_deprecated_argument
-from citrine.informatics.workflows import DesignWorkflow
+from citrine.informatics.workflows import DesignWorkflow, Workflow
 from citrine.resources.design_workflow import DesignWorkflowCollection
-
 from tests.utils.fakes import FakeCollection

-WorkflowType = TypeVar('WorkflowType', bound='Workflow')
+WorkflowType = TypeVar("WorkflowType", bound="Workflow")


 class FakeWorkflowCollection(FakeCollection[WorkflowType]):
-
     def __init__(self, project_id, session: Session):
         FakeCollection.__init__(self)
         self.project_id = project_id
@@ -31,5 +28,7 @@ def archive(self, uid: Union[UUID, str]):
         self.update(workflow)


-class FakeDesignWorkflowCollection(FakeWorkflowCollection[DesignWorkflow], DesignWorkflowCollection):
+class FakeDesignWorkflowCollection(
+    FakeWorkflowCollection[DesignWorkflow], DesignWorkflowCollection
+):
     pass
diff --git a/tests/utils/fakes/fake_workflows.py b/tests/utils/fakes/fake_workflows.py
index e1afc1d44..f8d7c3cea 100644
--- a/tests/utils/fakes/fake_workflows.py
+++ b/tests/utils/fakes/fake_workflows.py
@@ -4,11 +4,13 @@

 class FakeDesignWorkflow(DesignWorkflow):
-
     @property
     def design_executions(self) -> FakeDesignExecutionCollection:
         """Return a resource representing all visible executions of this workflow."""
-        if getattr(self, 'project_id', None) is None:
-            raise AttributeError('Cannot initialize execution without project reference!')
+        if getattr(self, "project_id", None) is None:
+            raise AttributeError(
+                "Cannot initialize execution without project reference!"
+            )
         return FakeDesignExecutionCollection(
-            project_id=self.project_id, session=self._session, workflow_id=self.uid)
+            project_id=self.project_id, session=self._session, workflow_id=self.uid
+        )
diff --git a/tests/utils/session.py b/tests/utils/session.py
index 0a6a7b97d..d58b400fa 100644
--- a/tests/utils/session.py
+++ b/tests/utils/session.py
@@ -10,7 +10,15 @@
 class FakeCall:
     """Encapsulates a call to a FakeSession."""

-    def __init__(self, method, path, json=None, params: dict = None, version: str = None, **kwargs):
+    def __init__(
+        self,
+        method,
+        path,
+        json=None,
+        params: dict = None,
+        version: str = None,
+        **kwargs,
+    ):
         self.method = method
         self.path = path
         self.json = json
@@ -19,40 +27,44 @@ def __init__(self, method, path, json=None, params: dict = None, version: str =
         self.kwargs = kwargs

     def __repr__(self):
-        return f'FakeCall({self})'
+        return f"FakeCall({self})"

     def __str__(self) -> str:
         path = self.path
         if self.version:
-            path = path[1:] if path.startswith('/') else path
-            path = f'{self.version}/{path}'
+            path = path[1:] if path.startswith("/") else path
+            path = f"{self.version}/{path}"

         if self.params:
-            path = f'{path}?{urlencode(self.params)}'
+            path = f"{path}?{urlencode(self.params)}"

-        return f'{self.method} {path} : {dumps(self.json)}'
+        return f"{self.method} {path} : {dumps(self.json)}"

     def __eq__(self, other) -> bool:
         if not isinstance(other, FakeCall):
             return NotImplemented

         return (
-            self.method == other.method and
-            self.path.lstrip('/') == other.path.lstrip('/') and  # Leading slashes don't affect results
-            self.json == other.json and
-            self.params == other.params and
-            (not self.version or not other.version or self.version == other.version)  # Allows users to check the URL version without forcing everyone to.
+            self.method == other.method
+            and self.path.lstrip("/")
+            == other.path.lstrip("/")  # Leading slashes don't affect results
+            and self.json == other.json
+            and self.params == other.params
+            and (
+                not self.version or not other.version or self.version == other.version
+            )  # Allows users to check the URL version without forcing everyone to.
         )
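
FakeCall.__eq__ deliberately normalizes leading slashes and treats an unspecified version as a wildcard, so a test can assert a call at whatever precision it cares about. A sketch using only the class as defined above:

    assert FakeCall("GET", "/projects") == FakeCall("GET", "projects")  # slash-insensitive
    assert FakeCall("GET", "projects") == FakeCall("GET", "projects", version="v3")  # missing version matches any
    assert FakeCall("GET", "projects", version="v1") != FakeCall("GET", "projects", version="v3")
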


 class FakeSession(Session):
     """Fake version of Session used to test API interaction."""
+
     def __init__(self):
         self.calls = []
         self.responses = []
         self.s3_endpoint_url = None
         self.s3_use_ssl = True
-        self.s3_addressing_style = 'auto'
+        self.s3_addressing_style = "auto"
         self.use_idempotent_dataset_put = False

     def set_response(self, resp):
@@ -85,23 +97,23 @@ def delete_resource(self, path: str, **kwargs) -> dict:
         return self.checked_delete(path, **kwargs)

     def checked_get(self, path: str, **kwargs) -> dict:
-        self.calls.append(FakeCall('GET', path, **kwargs))
+        self.calls.append(FakeCall("GET", path, **kwargs))
         return self._get_response()

     def checked_post(self, path: str, json: dict, **kwargs) -> dict:
-        self.calls.append(FakeCall('POST', path, json, **kwargs))
+        self.calls.append(FakeCall("POST", path, json, **kwargs))
         return self._get_response(default_response=json)

     def checked_put(self, path: str, json: dict, **kwargs) -> dict:
-        self.calls.append(FakeCall('PUT', path, json, **kwargs))
+        self.calls.append(FakeCall("PUT", path, json, **kwargs))
         return self._get_response(default_response=json)

     def checked_patch(self, path: str, json: dict, **kwargs) -> dict:
-        self.calls.append(FakeCall('PATCH', path, json, **kwargs))
+        self.calls.append(FakeCall("PATCH", path, json, **kwargs))
         return self._get_response(default_response=json)

     def checked_delete(self, path: str, **kwargs) -> dict:
-        self.calls.append(FakeCall('DELETE', path, **kwargs))
+        self.calls.append(FakeCall("DELETE", path, **kwargs))
         return self._get_response()

     def _get_response(self, default_response: dict = None):
@@ -121,39 +133,45 @@ def _get_response(self, default_response: dict = None):
         return response

     @staticmethod
-    def cursor_paged_resource(base_method: Callable[..., dict], path: str,
-                              forward: bool = True, per_page: int = 100,
-                              version: str = 'v2', **kwargs) -> Iterator[dict]:
+    def cursor_paged_resource(
+        base_method: Callable[..., dict],
+        path: str,
+        forward: bool = True,
+        per_page: int = 100,
+        version: str = "v2",
+        **kwargs,
+    ) -> Iterator[dict]:
         """
         Returns a flat generator of results for an API query.

         Results are fetched in chunks of size `per_page` and loaded lazily.
         """
-        params = kwargs.get('params', {})
-        params['forward'] = forward
-        params['ascending'] = forward
-        params['per_page'] = per_page
-        kwargs['params'] = params
+        params = kwargs.get("params", {})
+        params["forward"] = forward
+        params["ascending"] = forward
+        params["per_page"] = per_page
+        kwargs["params"] = params

         while True:
             response_json = base_method(path, version=version, **kwargs)
-            for obj in response_json['contents']:
+            for obj in response_json["contents"]:
                 yield obj
-            cursor = response_json.get('next')
+            cursor = response_json.get("next")
             if cursor is None:
                 break
-            params['cursor'] = cursor
+            params["cursor"] = cursor


 class FakePaginatedSession(FakeSession):
     """Fake version of Session used to test API interaction, with support for pagination params page and per_page."""
+
     def checked_get(self, path: str, **kwargs) -> dict:
-        params = kwargs.get('params')
-        self.calls.append(FakeCall('GET', path, params=params))
+        params = kwargs.get("params")
+        self.calls.append(FakeCall("GET", path, params=params))
         return self._get_response(**params)

     def checked_post(self, path: str, json: dict, **kwargs) -> dict:
-        params = kwargs.get('params')
-        self.calls.append(FakeCall('POST', path, json, params=params))
+        params = kwargs.get("params")
+        self.calls.append(FakeCall("POST", path, json, params=params))
         return self._get_response(**params)

     def _get_response(self, **kwargs):
@@ -162,23 +180,23 @@ def _get_response(self, **kwargs):
         """
         if not self.responses:
             return {}
-
-        page = kwargs.get('page', 1)
-        per_page = kwargs.get('per_page', 20)
+
+        page = kwargs.get("page", 1)
+        per_page = kwargs.get("per_page", 20)

         start_idx = (page - 1) * per_page

-        # in case the response takes the shape of something like
+        # in case the response takes the shape of something like
         # {'projects': [Project1, Project2, etc.]}
         has_collection_key = isinstance(self.responses[0], dict)

         if has_collection_key:
             key = list(self.responses[0].keys())[0]
-            list_values = self.responses[0][key][start_idx:start_idx + per_page]
+            list_values = self.responses[0][key][start_idx : start_idx + per_page]
             return dict.fromkeys([key], list_values)
         else:
-            return self.responses[0][start_idx:start_idx + per_page]
+            return self.responses[0][start_idx : start_idx + per_page]


 class FakeS3Client:
@@ -200,7 +218,7 @@ def put_object(self, *args, **kwargs):
 class FakeRequestResponse:
     """A fake version of a requests.request() response."""

-    def __init__(self, status_code, content=None, text="", reason='BadRequest'):
+    def __init__(self, status_code, content=None, text="", reason="BadRequest"):
         self.content = content
         self.text = text
         self.status_code = status_code
@@ -215,11 +233,19 @@ def json(self):
 # the method to FakeRequest.
 class FakeRequestResponseApiError:
     """A fake version of a requests.request() response that has an ApiError"""
-    def __init__(self, code: int, message: str, validation_errors: List[ValidationError],
-                 reason: str = 'BadRequest'):
-        self.api_error_json = {"code": code,
-                               "message": message,
-                               "validation_errors": [ve.dump() for ve in validation_errors]}
+
+    def __init__(
+        self,
+        code: int,
+        message: str,
+        validation_errors: List[ValidationError],
+        reason: str = "BadRequest",
+    ):
+        self.api_error_json = {
+            "code": code,
+            "message": message,
+            "validation_errors": [ve.dump() for ve in validation_errors],
+        }
         self.text = message
         self.status_code = code
         self.reason = reason
@@ -245,18 +271,20 @@ def make_fake_cursor_request_function(all_results: list):
     all_results: list
         All results in the result set to simulate paging
     """
+
     # TODO add logic for `forward` and `ascending`
     def fake_cursor_request(*_, params=None, **__):
-        page_size = params['per_page']
-        if 'cursor' in params:
-            cursor = int(params['cursor'])
-            contents = all_results[cursor + 1:cursor + page_size + 1]
+        page_size = params["per_page"]
+        if "cursor" in params:
+            cursor = int(params["cursor"])
+            contents = all_results[cursor + 1 : cursor + page_size + 1]
         else:
             contents = all_results[:page_size]
-        response = {'contents': contents}
+        response = {"contents": contents}
         if contents:
-            response['next'] = str(all_results.index(contents[-1]))
-            if 'cursor' in params:
-                response['previous'] = str(all_results.index(contents[0]))
+            response["next"] = str(all_results.index(contents[-1]))
+            if "cursor" in params:
+                response["previous"] = str(all_results.index(contents[0]))
         return response
+
     return fake_cursor_request
diff --git a/tests/utils/wait.py b/tests/utils/wait.py
index 68a010d86..7a166f38e 100644
--- a/tests/utils/wait.py
+++ b/tests/utils/wait.py
@@ -16,7 +16,9 @@ def wait_until(condition, timeout=30, interval=0.5):
     return result


-def generate_fake_wait_while(*, status: str, status_detail: Optional[List[StatusDetail]] = None) -> Callable:
+def generate_fake_wait_while(
+    *, status: str, status_detail: Optional[List[StatusDetail]] = None
+) -> Callable:
     """Generate a wait_while function that mutates a resource with the specified status info."""
     status_detail = status_detail or []

diff --git a/uv.lock b/uv.lock
new file mode 100644
index 000000000..d11fb9239
--- /dev/null
+++ b/uv.lock
@@ -0,0 +1,739 @@
+version = 1
+revision = 3
+requires-python = ">=3.12"
+
+[[package]]
+name = "arrow"
+version = "1.4.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "python-dateutil" },
+    { name = "tzdata" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b9/33/032cdc44182491aa708d06a68b62434140d8c50820a087fac7af37703357/arrow-1.4.0.tar.gz", hash = "sha256:ed0cc050e98001b8779e84d461b0098c4ac597e88704a655582b21d116e526d7", size = 152931, upload-time = "2025-10-18T17:46:46.761Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/ed/c9/d7977eaacb9df673210491da99e6a247e93df98c715fc43fd136ce1d3d33/arrow-1.4.0-py3-none-any.whl", hash = "sha256:749f0769958ebdc79c173ff0b0670d59051a535fa26e8eba02953dc19eb43205", size = 68797, upload-time = "2025-10-18T17:46:45.663Z" },
+]
+
+[[package]]
+name = "boto3"
+version = "1.42.24"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "botocore" },
+    { name = "jmespath" },
+    { name = "s3transfer" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ee/21/8be0e3685c3a4868be48d8d2f6e5b4641727e1d8a5d396b8b401d2b5f06e/boto3-1.42.24.tar.gz", hash = "sha256:c47a2f40df933e3861fc66fd8d6b87ee36d4361663a7e7ba39a87f5a78b2eae1", size = 112788, upload-time = "2026-01-07T20:30:51.019Z" }
"https://files.pythonhosted.org/packages/ee/21/8be0e3685c3a4868be48d8d2f6e5b4641727e1d8a5d396b8b401d2b5f06e/boto3-1.42.24.tar.gz", hash = "sha256:c47a2f40df933e3861fc66fd8d6b87ee36d4361663a7e7ba39a87f5a78b2eae1", size = 112788, upload-time = "2026-01-07T20:30:51.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/75/bbfccb268f9faa4f59030888e859dca9797a980b77d6a074113af73bd4bf/boto3-1.42.24-py3-none-any.whl", hash = "sha256:8ed6ad670a5a2d7f66c1b0d3362791b48392c7a08f78479f5d8ab319a4d9118f", size = 140572, upload-time = "2026-01-07T20:30:49.431Z" }, +] + +[[package]] +name = "botocore" +version = "1.42.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/12/d7/bb4a4e839b238ffb67b002d7326b328ebe5eb23ed5180f2ca10399a802de/botocore-1.42.24.tar.gz", hash = "sha256:be8d1bea64fb91eea08254a1e5fea057e4428d08e61f4e11083a02cafc1f8cc6", size = 14878455, upload-time = "2026-01-07T20:30:40.379Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/d4/f2655d777eed8b069ecab3761454cb83f830f8be8b5b0d292e4b3a980d00/botocore-1.42.24-py3-none-any.whl", hash = "sha256:8fca9781d7c84f7ad070fceffaff7179c4aa7a5ffb27b43df9d1d957801e0a8d", size = 14551806, upload-time = "2026-01-07T20:30:38.103Z" }, +] + +[[package]] +name = "certifi" +version = "2026.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/2d/a891ca51311197f6ad14a7ef42e2399f36cf2f9bd44752b3dc4eab60fdc5/certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120", size = 154268, upload-time = "2026-01-04T02:42:41.825Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c", size = 152900, upload-time = "2026-01-04T02:42:40.15Z" }, +] + +[[package]] +name = "cfgv" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = 
"https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = 
"https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, + { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, + { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, + { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, + { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, + { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, + { url = 
"https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, + { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, + { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, + { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, + { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, + { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, + { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, + { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, + { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, + { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, + { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = 
"2025-10-14T04:41:32.624Z" }, + { url = "https://files.pythonhosted.org/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746, upload-time = "2025-10-14T04:41:33.773Z" }, + { url = "https://files.pythonhosted.org/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889, upload-time = "2025-10-14T04:41:34.897Z" }, + { url = "https://files.pythonhosted.org/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641, upload-time = "2025-10-14T04:41:36.116Z" }, + { url = "https://files.pythonhosted.org/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779, upload-time = "2025-10-14T04:41:37.229Z" }, + { url = "https://files.pythonhosted.org/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035, upload-time = "2025-10-14T04:41:38.368Z" }, + { url = "https://files.pythonhosted.org/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542, upload-time = "2025-10-14T04:41:39.862Z" }, + { url = "https://files.pythonhosted.org/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524, upload-time = "2025-10-14T04:41:41.319Z" }, + { url = "https://files.pythonhosted.org/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395, upload-time = "2025-10-14T04:41:42.539Z" }, + { url = "https://files.pythonhosted.org/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680, upload-time = "2025-10-14T04:41:43.661Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045, upload-time = "2025-10-14T04:41:44.821Z" }, + { url = 
"https://files.pythonhosted.org/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687, upload-time = "2025-10-14T04:41:46.442Z" }, + { url = "https://files.pythonhosted.org/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014, upload-time = "2025-10-14T04:41:47.631Z" }, + { url = "https://files.pythonhosted.org/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044, upload-time = "2025-10-14T04:41:48.81Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940, upload-time = "2025-10-14T04:41:49.946Z" }, + { url = "https://files.pythonhosted.org/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104, upload-time = "2025-10-14T04:41:51.051Z" }, + { url = "https://files.pythonhosted.org/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743, upload-time = "2025-10-14T04:41:52.122Z" }, + { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + +[[package]] +name = "citrine" +version = "0.1.0" +source = { editable = "." 
} +dependencies = [ + { name = "arrow" }, + { name = "boto3" }, + { name = "deprecation" }, + { name = "gemd" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "urllib3" }, +] + +[package.optional-dependencies] +test = [ + { name = "factory-boy" }, + { name = "mock" }, + { name = "pandas" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "requests-mock" }, + { name = "ruff" }, + { name = "ty" }, +] + +[package.metadata] +requires-dist = [ + { name = "arrow" }, + { name = "boto3" }, + { name = "deprecation" }, + { name = "factory-boy", marker = "extra == 'test'" }, + { name = "gemd" }, + { name = "mock", marker = "extra == 'test'" }, + { name = "pandas", marker = "extra == 'test'" }, + { name = "pre-commit", marker = "extra == 'test'" }, + { name = "pytest", marker = "extra == 'test'" }, + { name = "requests" }, + { name = "requests-mock", marker = "extra == 'test'" }, + { name = "ruff", marker = "extra == 'test'" }, + { name = "tqdm" }, + { name = "ty", marker = "extra == 'test'" }, + { name = "urllib3" }, +] +provides-extras = ["test"] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "deprecation" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/d3/8ae2869247df154b64c1884d7346d412fed0c49df84db635aab2d1c40e62/deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff", size = 173788, upload-time = "2020-04-20T14:23:38.738Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a", size = 11178, upload-time = "2020-04-20T14:23:36.581Z" }, +] + +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + +[[package]] +name = "factory-boy" +version = "3.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "faker" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/98/75cacae9945f67cfe323829fc2ac451f64517a8a330b572a06a323997065/factory_boy-3.3.3.tar.gz", hash = 
"sha256:866862d226128dfac7f2b4160287e899daf54f2612778327dd03d0e2cb1e3d03", size = 164146, upload-time = "2025-02-03T09:49:04.433Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/8d/2bc5f5546ff2ccb3f7de06742853483ab75bf74f36a92254702f8baecc79/factory_boy-3.3.3-py2.py3-none-any.whl", hash = "sha256:1c39e3289f7e667c4285433f305f8d506efc2fe9c73aaea4151ebd5cdea394fc", size = 37036, upload-time = "2025-02-03T09:49:01.659Z" }, +] + +[[package]] +name = "faker" +version = "40.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d7/1d/aa43ef59589ddf3647df918143f1bac9eb004cce1c43124ee3347061797d/faker-40.1.0.tar.gz", hash = "sha256:c402212a981a8a28615fea9120d789e3f6062c0c259a82bfb8dff5d273e539d2", size = 1948784, upload-time = "2025-12-29T18:06:00.659Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/23/e22da510e1ec1488966330bf76d8ff4bd535cbfc93660eeb7657761a1bb2/faker-40.1.0-py3-none-any.whl", hash = "sha256:a616d35818e2a2387c297de80e2288083bc915e24b7e39d2fb5bc66cce3a929f", size = 1985317, upload-time = "2025-12-29T18:05:58.831Z" }, +] + +[[package]] +name = "filelock" +version = "3.20.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c1/e0/a75dbe4bca1e7d41307323dad5ea2efdd95408f74ab2de8bd7dba9b51a1a/filelock-3.20.2.tar.gz", hash = "sha256:a2241ff4ddde2a7cebddf78e39832509cb045d18ec1a09d7248d6bfc6bfbbe64", size = 19510, upload-time = "2026-01-02T15:33:32.582Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9a/30/ab407e2ec752aa541704ed8f93c11e2a5d92c168b8a755d818b74a3c5c2d/filelock-3.20.2-py3-none-any.whl", hash = "sha256:fbba7237d6ea277175a32c54bb71ef814a8546d8601269e1bfc388de333974e8", size = 16697, upload-time = "2026-01-02T15:33:31.133Z" }, +] + +[[package]] +name = "flexcache" +version = "0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/b0/8a21e330561c65653d010ef112bf38f60890051d244ede197ddaa08e50c1/flexcache-0.3.tar.gz", hash = "sha256:18743bd5a0621bfe2cf8d519e4c3bfdf57a269c15d1ced3fb4b64e0ff4600656", size = 15816, upload-time = "2024-03-09T03:21:07.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/cd/c883e1a7c447479d6e13985565080e3fea88ab5a107c21684c813dba1875/flexcache-0.3-py3-none-any.whl", hash = "sha256:d43c9fea82336af6e0115e308d9d33a185390b8346a017564611f1466dcd2e32", size = 13263, upload-time = "2024-03-09T03:21:05.635Z" }, +] + +[[package]] +name = "flexparser" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/99/b4de7e39e8eaf8207ba1a8fa2241dd98b2ba72ae6e16960d8351736d8702/flexparser-0.4.tar.gz", hash = "sha256:266d98905595be2ccc5da964fe0a2c3526fbbffdc45b65b3146d75db992ef6b2", size = 31799, upload-time = "2024-11-07T02:00:56.249Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/5e/3be305568fe5f34448807976dc82fc151d76c3e0e03958f34770286278c1/flexparser-0.4-py3-none-any.whl", hash = "sha256:3738b456192dcb3e15620f324c447721023c0293f6af9955b481e91d00179846", size = 27625, upload-time = "2024-11-07T02:00:54.523Z" }, +] + +[[package]] +name = "gemd" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecation" }, + { name = 
"importlib-resources" }, + { name = "pint" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/fd/7b4851eb98a70072aff0eafbc2a7b94f966660dd006b48a5c4a646c97aca/gemd-2.2.0.tar.gz", hash = "sha256:0a95b27bf021e4c2b48936eee49ecb1de69077000a741d07874c416ea1455bd7", size = 152300, upload-time = "2025-10-01T20:34:48.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/ca/6545b00cb8fa49f3e5029f5584dd813930ab5d10914941d58932f7e97695/gemd-2.2.0-py3-none-any.whl", hash = "sha256:12529d6ac9edf2e628a8d37dff6030b2d41438250631c85193e41607fc197628", size = 202969, upload-time = "2025-10-01T20:34:47.13Z" }, +] + +[[package]] +name = "identify" +version = "2.6.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "importlib-resources" +version = "6.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693, upload-time = "2025-01-03T18:51:56.698Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "jmespath" +version = "1.0.1" 
+source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, +] + +[[package]] +name = "mock" +version = "5.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/07/8c/14c2ae915e5f9dca5a22edd68b35be94400719ccfa068a03e0fb63d0f6f6/mock-5.2.0.tar.gz", hash = "sha256:4e460e818629b4b173f32d08bf30d3af8123afbb8e04bb5707a1fd4799e503f0", size = 92796, upload-time = "2025-03-03T12:31:42.911Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/d9/617e6af809bf3a1d468e0d58c3997b1dc219a9a9202e650d30c2fc85d481/mock-5.2.0-py3-none-any.whl", hash = "sha256:7ba87f72ca0e915175596069dbbcc7c75af7b5e9b9bc107ad6349ede0819982f", size = 31617, upload-time = "2025-03-03T12:31:41.518Z" }, +] + +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + +[[package]] +name = "numpy" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/7a/6a3d14e205d292b738db449d0de649b373a59edb0d0b4493821d0a3e8718/numpy-2.4.0.tar.gz", hash = "sha256:6e504f7b16118198f138ef31ba24d985b124c2c469fe8467007cf30fd992f934", size = 20685720, upload-time = "2025-12-20T16:18:19.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/ff/f6400ffec95de41c74b8e73df32e3fff1830633193a7b1e409be7fb1bb8c/numpy-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a8b6bb8369abefb8bd1801b054ad50e02b3275c8614dc6e5b0373c305291037", size = 16653117, upload-time = "2025-12-20T16:16:06.709Z" }, + { url = "https://files.pythonhosted.org/packages/fd/28/6c23e97450035072e8d830a3c411bf1abd1f42c611ff9d29e3d8f55c6252/numpy-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e284ca13d5a8367e43734148622caf0b261b275673823593e3e3634a6490f83", size = 12369711, upload-time = "2025-12-20T16:16:08.758Z" }, + { url = "https://files.pythonhosted.org/packages/bc/af/acbef97b630ab1bb45e6a7d01d1452e4251aa88ce680ac36e56c272120ec/numpy-2.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:49ff32b09f5aa0cd30a20c2b39db3e669c845589f2b7fc910365210887e39344", size = 5198355, upload-time = "2025-12-20T16:16:10.902Z" }, + { url = "https://files.pythonhosted.org/packages/c1/c8/4e0d436b66b826f2e53330adaa6311f5cac9871a5b5c31ad773b27f25a74/numpy-2.4.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = 
"sha256:36cbfb13c152b1c7c184ddac43765db8ad672567e7bafff2cc755a09917ed2e6", size = 6545298, upload-time = "2025-12-20T16:16:12.607Z" }, + { url = "https://files.pythonhosted.org/packages/ef/27/e1f5d144ab54eac34875e79037011d511ac57b21b220063310cb96c80fbc/numpy-2.4.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35ddc8f4914466e6fc954c76527aa91aa763682a4f6d73249ef20b418fe6effb", size = 14398387, upload-time = "2025-12-20T16:16:14.257Z" }, + { url = "https://files.pythonhosted.org/packages/67/64/4cb909dd5ab09a9a5d086eff9586e69e827b88a5585517386879474f4cf7/numpy-2.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc578891de1db95b2a35001b695451767b580bb45753717498213c5ff3c41d63", size = 16363091, upload-time = "2025-12-20T16:16:17.32Z" }, + { url = "https://files.pythonhosted.org/packages/9d/9c/8efe24577523ec6809261859737cf117b0eb6fdb655abdfdc81b2e468ce4/numpy-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98e81648e0b36e325ab67e46b5400a7a6d4a22b8a7c8e8bbfe20e7db7906bf95", size = 16176394, upload-time = "2025-12-20T16:16:19.524Z" }, + { url = "https://files.pythonhosted.org/packages/61/f0/1687441ece7b47a62e45a1f82015352c240765c707928edd8aef875d5951/numpy-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d57b5046c120561ba8fa8e4030fbb8b822f3063910fa901ffadf16e2b7128ad6", size = 18287378, upload-time = "2025-12-20T16:16:22.866Z" }, + { url = "https://files.pythonhosted.org/packages/d3/6f/f868765d44e6fc466467ed810ba9d8d6db1add7d4a748abfa2a4c99a3194/numpy-2.4.0-cp312-cp312-win32.whl", hash = "sha256:92190db305a6f48734d3982f2c60fa30d6b5ee9bff10f2887b930d7b40119f4c", size = 5955432, upload-time = "2025-12-20T16:16:25.06Z" }, + { url = "https://files.pythonhosted.org/packages/d4/b5/94c1e79fcbab38d1ca15e13777477b2914dd2d559b410f96949d6637b085/numpy-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:680060061adb2d74ce352628cb798cfdec399068aa7f07ba9fb818b2b3305f98", size = 12306201, upload-time = "2025-12-20T16:16:26.979Z" }, + { url = "https://files.pythonhosted.org/packages/70/09/c39dadf0b13bb0768cd29d6a3aaff1fb7c6905ac40e9aaeca26b1c086e06/numpy-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:39699233bc72dd482da1415dcb06076e32f60eddc796a796c5fb6c5efce94667", size = 10308234, upload-time = "2025-12-20T16:16:29.417Z" }, + { url = "https://files.pythonhosted.org/packages/a7/0d/853fd96372eda07c824d24adf02e8bc92bb3731b43a9b2a39161c3667cc4/numpy-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a152d86a3ae00ba5f47b3acf3b827509fd0b6cb7d3259665e63dafbad22a75ea", size = 16649088, upload-time = "2025-12-20T16:16:31.421Z" }, + { url = "https://files.pythonhosted.org/packages/e3/37/cc636f1f2a9f585434e20a3e6e63422f70bfe4f7f6698e941db52ea1ac9a/numpy-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:39b19251dec4de8ff8496cd0806cbe27bf0684f765abb1f4809554de93785f2d", size = 12364065, upload-time = "2025-12-20T16:16:33.491Z" }, + { url = "https://files.pythonhosted.org/packages/ed/69/0b78f37ca3690969beee54103ce5f6021709134e8020767e93ba691a72f1/numpy-2.4.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:009bd0ea12d3c784b6639a8457537016ce5172109e585338e11334f6a7bb88ee", size = 5192640, upload-time = "2025-12-20T16:16:35.636Z" }, + { url = "https://files.pythonhosted.org/packages/1d/2a/08569f8252abf590294dbb09a430543ec8f8cc710383abfb3e75cc73aeda/numpy-2.4.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5fe44e277225fd3dff6882d86d3d447205d43532c3627313d17e754fb3905a0e", size = 6541556, upload-time = 
"2025-12-20T16:16:37.276Z" }, + { url = "https://files.pythonhosted.org/packages/93/e9/a949885a4e177493d61519377952186b6cbfdf1d6002764c664ba28349b5/numpy-2.4.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f935c4493eda9069851058fa0d9e39dbf6286be690066509305e52912714dbb2", size = 14396562, upload-time = "2025-12-20T16:16:38.953Z" }, + { url = "https://files.pythonhosted.org/packages/99/98/9d4ad53b0e9ef901c2ef1d550d2136f5ac42d3fd2988390a6def32e23e48/numpy-2.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8cfa5f29a695cb7438965e6c3e8d06e0416060cf0d709c1b1c1653a939bf5c2a", size = 16351719, upload-time = "2025-12-20T16:16:41.503Z" }, + { url = "https://files.pythonhosted.org/packages/28/de/5f3711a38341d6e8dd619f6353251a0cdd07f3d6d101a8fd46f4ef87f895/numpy-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba0cb30acd3ef11c94dc27fbfba68940652492bc107075e7ffe23057f9425681", size = 16176053, upload-time = "2025-12-20T16:16:44.552Z" }, + { url = "https://files.pythonhosted.org/packages/2a/5b/2a3753dc43916501b4183532e7ace862e13211042bceafa253afb5c71272/numpy-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:60e8c196cd82cbbd4f130b5290007e13e6de3eca79f0d4d38014769d96a7c475", size = 18277859, upload-time = "2025-12-20T16:16:47.174Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c5/a18bcdd07a941db3076ef489d036ab16d2bfc2eae0cf27e5a26e29189434/numpy-2.4.0-cp313-cp313-win32.whl", hash = "sha256:5f48cb3e88fbc294dc90e215d86fbaf1c852c63dbdb6c3a3e63f45c4b57f7344", size = 5953849, upload-time = "2025-12-20T16:16:49.554Z" }, + { url = "https://files.pythonhosted.org/packages/4f/f1/719010ff8061da6e8a26e1980cf090412d4f5f8060b31f0c45d77dd67a01/numpy-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:a899699294f28f7be8992853c0c60741f16ff199205e2e6cdca155762cbaa59d", size = 12302840, upload-time = "2025-12-20T16:16:51.227Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/b3d259083ed8b4d335270c76966cb6cf14a5d1b69e1a608994ac57a659e6/numpy-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:9198f447e1dc5647d07c9a6bbe2063cc0132728cc7175b39dbc796da5b54920d", size = 10308509, upload-time = "2025-12-20T16:16:53.313Z" }, + { url = "https://files.pythonhosted.org/packages/31/01/95edcffd1bb6c0633df4e808130545c4f07383ab629ac7e316fb44fff677/numpy-2.4.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74623f2ab5cc3f7c886add4f735d1031a1d2be4a4ae63c0546cfd74e7a31ddf6", size = 12491815, upload-time = "2025-12-20T16:16:55.496Z" }, + { url = "https://files.pythonhosted.org/packages/59/ea/5644b8baa92cc1c7163b4b4458c8679852733fa74ca49c942cfa82ded4e0/numpy-2.4.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0804a8e4ab070d1d35496e65ffd3cf8114c136a2b81f61dfab0de4b218aacfd5", size = 5320321, upload-time = "2025-12-20T16:16:57.468Z" }, + { url = "https://files.pythonhosted.org/packages/26/4e/e10938106d70bc21319bd6a86ae726da37edc802ce35a3a71ecdf1fdfe7f/numpy-2.4.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:02a2038eb27f9443a8b266a66911e926566b5a6ffd1a689b588f7f35b81e7dc3", size = 6641635, upload-time = "2025-12-20T16:16:59.379Z" }, + { url = "https://files.pythonhosted.org/packages/b3/8d/a8828e3eaf5c0b4ab116924df82f24ce3416fa38d0674d8f708ddc6c8aac/numpy-2.4.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1889b3a3f47a7b5bee16bc25a2145bd7cb91897f815ce3499db64c7458b6d91d", size = 14456053, upload-time = "2025-12-20T16:17:01.768Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/a1/17d97609d87d4520aa5ae2dcfb32305654550ac6a35effb946d303e594ce/numpy-2.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85eef4cb5625c47ee6425c58a3502555e10f45ee973da878ac8248ad58c136f3", size = 16401702, upload-time = "2025-12-20T16:17:04.235Z" }, + { url = "https://files.pythonhosted.org/packages/18/32/0f13c1b2d22bea1118356b8b963195446f3af124ed7a5adfa8fdecb1b6ca/numpy-2.4.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6dc8b7e2f4eb184b37655195f421836cfae6f58197b67e3ffc501f1333d993fa", size = 16242493, upload-time = "2025-12-20T16:17:06.856Z" }, + { url = "https://files.pythonhosted.org/packages/ae/23/48f21e3d309fbc137c068a1475358cbd3a901b3987dcfc97a029ab3068e2/numpy-2.4.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:44aba2f0cafd287871a495fb3163408b0bd25bbce135c6f621534a07f4f7875c", size = 18324222, upload-time = "2025-12-20T16:17:09.392Z" }, + { url = "https://files.pythonhosted.org/packages/ac/52/41f3d71296a3dcaa4f456aaa3c6fc8e745b43d0552b6bde56571bb4b4a0f/numpy-2.4.0-cp313-cp313t-win32.whl", hash = "sha256:20c115517513831860c573996e395707aa9fb691eb179200125c250e895fcd93", size = 6076216, upload-time = "2025-12-20T16:17:11.437Z" }, + { url = "https://files.pythonhosted.org/packages/35/ff/46fbfe60ab0710d2a2b16995f708750307d30eccbb4c38371ea9e986866e/numpy-2.4.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b48e35f4ab6f6a7597c46e301126ceba4c44cd3280e3750f85db48b082624fa4", size = 12444263, upload-time = "2025-12-20T16:17:13.182Z" }, + { url = "https://files.pythonhosted.org/packages/a3/e3/9189ab319c01d2ed556c932ccf55064c5d75bb5850d1df7a482ce0badead/numpy-2.4.0-cp313-cp313t-win_arm64.whl", hash = "sha256:4d1cfce39e511069b11e67cd0bd78ceff31443b7c9e5c04db73c7a19f572967c", size = 10378265, upload-time = "2025-12-20T16:17:15.211Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ed/52eac27de39d5e5a6c9aadabe672bc06f55e24a3d9010cd1183948055d76/numpy-2.4.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c95eb6db2884917d86cde0b4d4cf31adf485c8ec36bf8696dd66fa70de96f36b", size = 16647476, upload-time = "2025-12-20T16:17:17.671Z" }, + { url = "https://files.pythonhosted.org/packages/77/c0/990ce1b7fcd4e09aeaa574e2a0a839589e4b08b2ca68070f1acb1fea6736/numpy-2.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:65167da969cd1ec3a1df31cb221ca3a19a8aaa25370ecb17d428415e93c1935e", size = 12374563, upload-time = "2025-12-20T16:17:20.216Z" }, + { url = "https://files.pythonhosted.org/packages/37/7c/8c5e389c6ae8f5fd2277a988600d79e9625db3fff011a2d87ac80b881a4c/numpy-2.4.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3de19cfecd1465d0dcf8a5b5ea8b3155b42ed0b639dba4b71e323d74f2a3be5e", size = 5203107, upload-time = "2025-12-20T16:17:22.47Z" }, + { url = "https://files.pythonhosted.org/packages/e6/94/ca5b3bd6a8a70a5eec9a0b8dd7f980c1eff4b8a54970a9a7fef248ef564f/numpy-2.4.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:6c05483c3136ac4c91b4e81903cb53a8707d316f488124d0398499a4f8e8ef51", size = 6538067, upload-time = "2025-12-20T16:17:24.001Z" }, + { url = "https://files.pythonhosted.org/packages/79/43/993eb7bb5be6761dde2b3a3a594d689cec83398e3f58f4758010f3b85727/numpy-2.4.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36667db4d6c1cea79c8930ab72fadfb4060feb4bfe724141cd4bd064d2e5f8ce", size = 14411926, upload-time = "2025-12-20T16:17:25.822Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/75/d4c43b61de473912496317a854dac54f1efec3eeb158438da6884b70bb90/numpy-2.4.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9a818668b674047fd88c4cddada7ab8f1c298812783e8328e956b78dc4807f9f", size = 16354295, upload-time = "2025-12-20T16:17:28.308Z" }, + { url = "https://files.pythonhosted.org/packages/b8/0a/b54615b47ee8736a6461a4bb6749128dd3435c5a759d5663f11f0e9af4ac/numpy-2.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1ee32359fb7543b7b7bd0b2f46294db27e29e7bbdf70541e81b190836cd83ded", size = 16190242, upload-time = "2025-12-20T16:17:30.993Z" }, + { url = "https://files.pythonhosted.org/packages/98/ce/ea207769aacad6246525ec6c6bbd66a2bf56c72443dc10e2f90feed29290/numpy-2.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e493962256a38f58283de033d8af176c5c91c084ea30f15834f7545451c42059", size = 18280875, upload-time = "2025-12-20T16:17:33.327Z" }, + { url = "https://files.pythonhosted.org/packages/17/ef/ec409437aa962ea372ed601c519a2b141701683ff028f894b7466f0ab42b/numpy-2.4.0-cp314-cp314-win32.whl", hash = "sha256:6bbaebf0d11567fa8926215ae731e1d58e6ec28a8a25235b8a47405d301332db", size = 6002530, upload-time = "2025-12-20T16:17:35.729Z" }, + { url = "https://files.pythonhosted.org/packages/5f/4a/5cb94c787a3ed1ac65e1271b968686521169a7b3ec0b6544bb3ca32960b0/numpy-2.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:3d857f55e7fdf7c38ab96c4558c95b97d1c685be6b05c249f5fdafcbd6f9899e", size = 12435890, upload-time = "2025-12-20T16:17:37.599Z" }, + { url = "https://files.pythonhosted.org/packages/48/a0/04b89db963af9de1104975e2544f30de89adbf75b9e75f7dd2599be12c79/numpy-2.4.0-cp314-cp314-win_arm64.whl", hash = "sha256:bb50ce5fb202a26fd5404620e7ef820ad1ab3558b444cb0b55beb7ef66cd2d63", size = 10591892, upload-time = "2025-12-20T16:17:39.649Z" }, + { url = "https://files.pythonhosted.org/packages/53/e5/d74b5ccf6712c06c7a545025a6a71bfa03bdc7e0568b405b0d655232fd92/numpy-2.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:355354388cba60f2132df297e2d53053d4063f79077b67b481d21276d61fc4df", size = 12494312, upload-time = "2025-12-20T16:17:41.714Z" }, + { url = "https://files.pythonhosted.org/packages/c2/08/3ca9cc2ddf54dfee7ae9a6479c071092a228c68aef08252aa08dac2af002/numpy-2.4.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:1d8f9fde5f6dc1b6fc34df8162f3b3079365468703fee7f31d4e0cc8c63baed9", size = 5322862, upload-time = "2025-12-20T16:17:44.145Z" }, + { url = "https://files.pythonhosted.org/packages/87/74/0bb63a68394c0c1e52670cfff2e309afa41edbe11b3327d9af29e4383f34/numpy-2.4.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e0434aa22c821f44eeb4c650b81c7fbdd8c0122c6c4b5a576a76d5a35625ecd9", size = 6644986, upload-time = "2025-12-20T16:17:46.203Z" }, + { url = "https://files.pythonhosted.org/packages/06/8f/9264d9bdbcf8236af2823623fe2f3981d740fc3461e2787e231d97c38c28/numpy-2.4.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:40483b2f2d3ba7aad426443767ff5632ec3156ef09742b96913787d13c336471", size = 14457958, upload-time = "2025-12-20T16:17:48.017Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d9/f9a69ae564bbc7236a35aa883319364ef5fd41f72aa320cc1cbe66148fe2/numpy-2.4.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6a7664ddd9746e20b7325351fe1a8408d0a2bf9c63b5e898290ddc8f09544", size = 16398394, upload-time = "2025-12-20T16:17:50.409Z" }, + { url = 
"https://files.pythonhosted.org/packages/34/c7/39241501408dde7f885d241a98caba5421061a2c6d2b2197ac5e3aa842d8/numpy-2.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ecb0019d44f4cdb50b676c5d0cb4b1eae8e15d1ed3d3e6639f986fc92b2ec52c", size = 16241044, upload-time = "2025-12-20T16:17:52.661Z" }, + { url = "https://files.pythonhosted.org/packages/7c/95/cae7effd90e065a95e59fe710eeee05d7328ed169776dfdd9f789e032125/numpy-2.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d0ffd9e2e4441c96a9c91ec1783285d80bf835b677853fc2770a89d50c1e48ac", size = 18321772, upload-time = "2025-12-20T16:17:54.947Z" }, + { url = "https://files.pythonhosted.org/packages/96/df/3c6c279accd2bfb968a76298e5b276310bd55d243df4fa8ac5816d79347d/numpy-2.4.0-cp314-cp314t-win32.whl", hash = "sha256:77f0d13fa87036d7553bf81f0e1fe3ce68d14c9976c9851744e4d3e91127e95f", size = 6148320, upload-time = "2025-12-20T16:17:57.249Z" }, + { url = "https://files.pythonhosted.org/packages/92/8d/f23033cce252e7a75cae853d17f582e86534c46404dea1c8ee094a9d6d84/numpy-2.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b1f5b45829ac1848893f0ddf5cb326110604d6df96cdc255b0bf9edd154104d4", size = 12623460, upload-time = "2025-12-20T16:17:58.963Z" }, + { url = "https://files.pythonhosted.org/packages/a4/4f/1f8475907d1a7c4ef9020edf7f39ea2422ec896849245f00688e4b268a71/numpy-2.4.0-cp314-cp314t-win_arm64.whl", hash = "sha256:23a3e9d1a6f360267e8fbb38ba5db355a6a7e9be71d7fce7ab3125e88bb646c8", size = 10661799, upload-time = "2025-12-20T16:18:01.078Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pandas" +version = "2.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223, upload-time = "2025-09-29T23:34:51.853Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/fb/231d89e8637c808b997d172b18e9d4a4bc7bf31296196c260526055d1ea0/pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d21f6d74eb1725c2efaa71a2bfc661a0689579b58e9c0ca58a739ff0b002b53", size = 11597846, upload-time = "2025-09-29T23:19:48.856Z" }, + { url = "https://files.pythonhosted.org/packages/5c/bd/bf8064d9cfa214294356c2d6702b716d3cf3bb24be59287a6a21e24cae6b/pandas-2.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3fd2f887589c7aa868e02632612ba39acb0b8948faf5cc58f0850e165bd46f35", size = 10729618, upload-time = "2025-09-29T23:39:08.659Z" }, + { url = 
"https://files.pythonhosted.org/packages/57/56/cf2dbe1a3f5271370669475ead12ce77c61726ffd19a35546e31aa8edf4e/pandas-2.3.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecaf1e12bdc03c86ad4a7ea848d66c685cb6851d807a26aa245ca3d2017a1908", size = 11737212, upload-time = "2025-09-29T23:19:59.765Z" }, + { url = "https://files.pythonhosted.org/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b3d11d2fda7eb164ef27ffc14b4fcab16a80e1ce67e9f57e19ec0afaf715ba89", size = 12362693, upload-time = "2025-09-29T23:20:14.098Z" }, + { url = "https://files.pythonhosted.org/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98", size = 12771002, upload-time = "2025-09-29T23:20:26.76Z" }, + { url = "https://files.pythonhosted.org/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084", size = 13450971, upload-time = "2025-09-29T23:20:41.344Z" }, + { url = "https://files.pythonhosted.org/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b", size = 10992722, upload-time = "2025-09-29T23:20:54.139Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4b/18b035ee18f97c1040d94debd8f2e737000ad70ccc8f5513f4eefad75f4b/pandas-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56851a737e3470de7fa88e6131f41281ed440d29a9268dcbf0002da5ac366713", size = 11544671, upload-time = "2025-09-29T23:21:05.024Z" }, + { url = "https://files.pythonhosted.org/packages/31/94/72fac03573102779920099bcac1c3b05975c2cb5f01eac609faf34bed1ca/pandas-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdcd9d1167f4885211e401b3036c0c8d9e274eee67ea8d0758a256d60704cfe8", size = 10680807, upload-time = "2025-09-29T23:21:15.979Z" }, + { url = "https://files.pythonhosted.org/packages/16/87/9472cf4a487d848476865321de18cc8c920b8cab98453ab79dbbc98db63a/pandas-2.3.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e32e7cc9af0f1cc15548288a51a3b681cc2a219faa838e995f7dc53dbab1062d", size = 11709872, upload-time = "2025-09-29T23:21:27.165Z" }, + { url = "https://files.pythonhosted.org/packages/15/07/284f757f63f8a8d69ed4472bfd85122bd086e637bf4ed09de572d575a693/pandas-2.3.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:318d77e0e42a628c04dc56bcef4b40de67918f7041c2b061af1da41dcff670ac", size = 12306371, upload-time = "2025-09-29T23:21:40.532Z" }, + { url = "https://files.pythonhosted.org/packages/33/81/a3afc88fca4aa925804a27d2676d22dcd2031c2ebe08aabd0ae55b9ff282/pandas-2.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4e0a175408804d566144e170d0476b15d78458795bb18f1304fb94160cabf40c", size = 12765333, upload-time = "2025-09-29T23:21:55.77Z" }, + { url = "https://files.pythonhosted.org/packages/8d/0f/b4d4ae743a83742f1153464cf1a8ecfafc3ac59722a0b5c8602310cb7158/pandas-2.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93c2d9ab0fc11822b5eece72ec9587e172f63cff87c00b062f6e37448ced4493", size = 13418120, upload-time = "2025-09-29T23:22:10.109Z" }, + { url = 
"https://files.pythonhosted.org/packages/4f/c7/e54682c96a895d0c808453269e0b5928a07a127a15704fedb643e9b0a4c8/pandas-2.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee", size = 10993991, upload-time = "2025-09-29T23:25:04.889Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ca/3f8d4f49740799189e1395812f3bf23b5e8fc7c190827d55a610da72ce55/pandas-2.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:75ea25f9529fdec2d2e93a42c523962261e567d250b0013b16210e1d40d7c2e5", size = 12048227, upload-time = "2025-09-29T23:22:24.343Z" }, + { url = "https://files.pythonhosted.org/packages/0e/5a/f43efec3e8c0cc92c4663ccad372dbdff72b60bdb56b2749f04aa1d07d7e/pandas-2.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74ecdf1d301e812db96a465a525952f4dde225fdb6d8e5a521d47e1f42041e21", size = 11411056, upload-time = "2025-09-29T23:22:37.762Z" }, + { url = "https://files.pythonhosted.org/packages/46/b1/85331edfc591208c9d1a63a06baa67b21d332e63b7a591a5ba42a10bb507/pandas-2.3.3-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6435cb949cb34ec11cc9860246ccb2fdc9ecd742c12d3304989017d53f039a78", size = 11645189, upload-time = "2025-09-29T23:22:51.688Z" }, + { url = "https://files.pythonhosted.org/packages/44/23/78d645adc35d94d1ac4f2a3c4112ab6f5b8999f4898b8cdf01252f8df4a9/pandas-2.3.3-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:900f47d8f20860de523a1ac881c4c36d65efcb2eb850e6948140fa781736e110", size = 12121912, upload-time = "2025-09-29T23:23:05.042Z" }, + { url = "https://files.pythonhosted.org/packages/53/da/d10013df5e6aaef6b425aa0c32e1fc1f3e431e4bcabd420517dceadce354/pandas-2.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a45c765238e2ed7d7c608fc5bc4a6f88b642f2f01e70c0c23d2224dd21829d86", size = 12712160, upload-time = "2025-09-29T23:23:28.57Z" }, + { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233, upload-time = "2025-09-29T23:24:24.876Z" }, + { url = "https://files.pythonhosted.org/packages/04/fd/74903979833db8390b73b3a8a7d30d146d710bd32703724dd9083950386f/pandas-2.3.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ee15f284898e7b246df8087fc82b87b01686f98ee67d85a17b7ab44143a3a9a0", size = 11540635, upload-time = "2025-09-29T23:25:52.486Z" }, + { url = "https://files.pythonhosted.org/packages/21/00/266d6b357ad5e6d3ad55093a7e8efc7dd245f5a842b584db9f30b0f0a287/pandas-2.3.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1611aedd912e1ff81ff41c745822980c49ce4a7907537be8692c8dbc31924593", size = 10759079, upload-time = "2025-09-29T23:26:33.204Z" }, + { url = "https://files.pythonhosted.org/packages/ca/05/d01ef80a7a3a12b2f8bbf16daba1e17c98a2f039cbc8e2f77a2c5a63d382/pandas-2.3.3-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d2cefc361461662ac48810cb14365a365ce864afe85ef1f447ff5a1e99ea81c", size = 11814049, upload-time = "2025-09-29T23:27:15.384Z" }, + { url = "https://files.pythonhosted.org/packages/15/b2/0e62f78c0c5ba7e3d2c5945a82456f4fac76c480940f805e0b97fcbc2f65/pandas-2.3.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ee67acbbf05014ea6c763beb097e03cd629961c8a632075eeb34247120abcb4b", size = 12332638, upload-time = "2025-09-29T23:27:51.625Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/33/dd70400631b62b9b29c3c93d2feee1d0964dc2bae2e5ad7a6c73a7f25325/pandas-2.3.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c46467899aaa4da076d5abc11084634e2d197e9460643dd455ac3db5856b24d6", size = 12886834, upload-time = "2025-09-29T23:28:21.289Z" }, + { url = "https://files.pythonhosted.org/packages/d3/18/b5d48f55821228d0d2692b34fd5034bb185e854bdb592e9c640f6290e012/pandas-2.3.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6253c72c6a1d990a410bc7de641d34053364ef8bcd3126f7e7450125887dffe3", size = 13409925, upload-time = "2025-09-29T23:28:58.261Z" }, + { url = "https://files.pythonhosted.org/packages/a6/3d/124ac75fcd0ecc09b8fdccb0246ef65e35b012030defb0e0eba2cbbbe948/pandas-2.3.3-cp314-cp314-win_amd64.whl", hash = "sha256:1b07204a219b3b7350abaae088f451860223a52cfb8a6c53358e7948735158e5", size = 11109071, upload-time = "2025-09-29T23:32:27.484Z" }, + { url = "https://files.pythonhosted.org/packages/89/9c/0e21c895c38a157e0faa1fb64587a9226d6dd46452cac4532d80c3c4a244/pandas-2.3.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2462b1a365b6109d275250baaae7b760fd25c726aaca0054649286bcfbb3e8ec", size = 12048504, upload-time = "2025-09-29T23:29:31.47Z" }, + { url = "https://files.pythonhosted.org/packages/d7/82/b69a1c95df796858777b68fbe6a81d37443a33319761d7c652ce77797475/pandas-2.3.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0242fe9a49aa8b4d78a4fa03acb397a58833ef6199e9aa40a95f027bb3a1b6e7", size = 11410702, upload-time = "2025-09-29T23:29:54.591Z" }, + { url = "https://files.pythonhosted.org/packages/f9/88/702bde3ba0a94b8c73a0181e05144b10f13f29ebfc2150c3a79062a8195d/pandas-2.3.3-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a21d830e78df0a515db2b3d2f5570610f5e6bd2e27749770e8bb7b524b89b450", size = 11634535, upload-time = "2025-09-29T23:30:21.003Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1e/1bac1a839d12e6a82ec6cb40cda2edde64a2013a66963293696bbf31fbbb/pandas-2.3.3-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e3ebdb170b5ef78f19bfb71b0dc5dc58775032361fa188e814959b74d726dd5", size = 12121582, upload-time = "2025-09-29T23:30:43.391Z" }, + { url = "https://files.pythonhosted.org/packages/44/91/483de934193e12a3b1d6ae7c8645d083ff88dec75f46e827562f1e4b4da6/pandas-2.3.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d051c0e065b94b7a3cea50eb1ec32e912cd96dba41647eb24104b6c6c14c5788", size = 12699963, upload-time = "2025-09-29T23:31:10.009Z" }, + { url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175, upload-time = "2025-09-29T23:31:59.173Z" }, +] + +[[package]] +name = "pint" +version = "0.24.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flexcache" }, + { name = "flexparser" }, + { name = "platformdirs" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/bb/52b15ddf7b7706ed591134a895dbf6e41c8348171fb635e655e0a4bbb0ea/pint-0.24.4.tar.gz", hash = "sha256:35275439b574837a6cd3020a5a4a73645eb125ce4152a73a2f126bf164b91b80", size = 342225, upload-time = "2024-11-07T16:29:46.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/16/bd2f5904557265882108dc2e04f18abc05ab0c2b7082ae9430091daf1d5c/Pint-0.24.4-py3-none-any.whl", hash = 
"sha256:aa54926c8772159fcf65f82cc0d34de6768c151b32ad1deb0331291c38fe7659", size = 302029, upload-time = "2024-11-07T16:29:43.976Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cf/86/0248f086a84f01b37aaec0fa567b397df1a119f73c16f6c7a9aac73ea309/platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda", size = 21715, upload-time = "2025-12-05T13:52:58.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pre-commit" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = 
"sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196", size = 182063, upload-time = "2025-09-25T21:32:11.445Z" }, + { url = "https://files.pythonhosted.org/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0", size = 173973, upload-time = "2025-09-25T21:32:12.492Z" }, + { url = "https://files.pythonhosted.org/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28", size = 775116, upload-time = "2025-09-25T21:32:13.652Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c", size = 844011, upload-time = "2025-09-25T21:32:15.21Z" }, + { url = "https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc", size = 807870, upload-time = "2025-09-25T21:32:16.431Z" }, + { url = "https://files.pythonhosted.org/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e", size = 761089, upload-time = "2025-09-25T21:32:17.56Z" }, + { url = "https://files.pythonhosted.org/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea", size = 790181, upload-time = "2025-09-25T21:32:18.834Z" }, + { url = "https://files.pythonhosted.org/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5", size = 137658, upload-time = "2025-09-25T21:32:20.209Z" }, + { url = "https://files.pythonhosted.org/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b", size = 154003, upload-time = "2025-09-25T21:32:21.167Z" }, + { url = "https://files.pythonhosted.org/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd", size = 140344, upload-time = "2025-09-25T21:32:22.617Z" }, + { url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url = "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, + { url = "https://files.pythonhosted.org/packages/9d/8c/f4bd7f6465179953d3ac9bc44ac1a8a3e6122cf8ada906b4f96c60172d43/pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac", size = 181814, upload-time = "2025-09-25T21:32:35.712Z" }, + { url = "https://files.pythonhosted.org/packages/bd/9c/4d95bb87eb2063d20db7b60faa3840c1b18025517ae857371c4dd55a6b3a/pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310", size = 173809, upload-time = "2025-09-25T21:32:36.789Z" }, + { url = "https://files.pythonhosted.org/packages/92/b5/47e807c2623074914e29dabd16cbbdd4bf5e9b2db9f8090fa64411fc5382/pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7", size = 766454, upload-time = "2025-09-25T21:32:37.966Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/e5e9b168be58564121efb3de6859c452fccde0ab093d8438905899a3a483/pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788", size = 836355, upload-time = "2025-09-25T21:32:39.178Z" }, + { url = "https://files.pythonhosted.org/packages/88/f9/16491d7ed2a919954993e48aa941b200f38040928474c9e85ea9e64222c3/pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5", size = 794175, upload-time = "2025-09-25T21:32:40.865Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3f/5989debef34dc6397317802b527dbbafb2b4760878a53d4166579111411e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = 
"sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764", size = 755228, upload-time = "2025-09-25T21:32:42.084Z" }, + { url = "https://files.pythonhosted.org/packages/d7/ce/af88a49043cd2e265be63d083fc75b27b6ed062f5f9fd6cdc223ad62f03e/pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35", size = 789194, upload-time = "2025-09-25T21:32:43.362Z" }, + { url = "https://files.pythonhosted.org/packages/23/20/bb6982b26a40bb43951265ba29d4c246ef0ff59c9fdcdf0ed04e0687de4d/pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac", size = 156429, upload-time = "2025-09-25T21:32:57.844Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f4/a4541072bb9422c8a883ab55255f918fa378ecf083f5b85e87fc2b4eda1b/pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3", size = 143912, upload-time = "2025-09-25T21:32:59.247Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f9/07dd09ae774e4616edf6cda684ee78f97777bdd15847253637a6f052a62f/pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3", size = 189108, upload-time = "2025-09-25T21:32:44.377Z" }, + { url = "https://files.pythonhosted.org/packages/4e/78/8d08c9fb7ce09ad8c38ad533c1191cf27f7ae1effe5bb9400a46d9437fcf/pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba", size = 183641, upload-time = "2025-09-25T21:32:45.407Z" }, + { url = "https://files.pythonhosted.org/packages/7b/5b/3babb19104a46945cf816d047db2788bcaf8c94527a805610b0289a01c6b/pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c", size = 831901, upload-time = "2025-09-25T21:32:48.83Z" }, + { url = "https://files.pythonhosted.org/packages/8b/cc/dff0684d8dc44da4d22a13f35f073d558c268780ce3c6ba1b87055bb0b87/pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702", size = 861132, upload-time = "2025-09-25T21:32:50.149Z" }, + { url = "https://files.pythonhosted.org/packages/b1/5e/f77dc6b9036943e285ba76b49e118d9ea929885becb0a29ba8a7c75e29fe/pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c", size = 839261, upload-time = "2025-09-25T21:32:51.808Z" }, + { url = "https://files.pythonhosted.org/packages/ce/88/a9db1376aa2a228197c58b37302f284b5617f56a5d959fd1763fb1675ce6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065", size = 805272, upload-time = "2025-09-25T21:32:52.941Z" }, + { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/1c7270340330e575b92f397352af856a8c06f230aa3e76f86b39d01b416a/pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = 
"sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9", size = 174062, upload-time = "2025-09-25T21:32:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "requests-mock" +version = "1.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/32/587625f91f9a0a3d84688bf9cfc4b2480a7e8ec327cefd0ff2ac891fd2cf/requests-mock-1.12.1.tar.gz", hash = "sha256:e9e12e333b525156e82a3c852f22016b9158220d2f47454de9cae8a77d371401", size = 60901, upload-time = "2024-03-29T03:54:29.446Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/ec/889fbc557727da0c34a33850950310240f2040f3b1955175fdb2b36a8910/requests_mock-1.12.1-py2.py3-none-any.whl", hash = "sha256:b1e37054004cdd5e56c84454cc7df12b25f90f382159087f4b6915aaeef39563", size = 27695, upload-time = "2024-03-29T03:54:27.64Z" }, +] + +[[package]] +name = "ruff" +version = "0.14.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/08/52232a877978dd8f9cf2aeddce3e611b40a63287dfca29b6b8da791f5e8d/ruff-0.14.10.tar.gz", hash = "sha256:9a2e830f075d1a42cd28420d7809ace390832a490ed0966fe373ba288e77aaf4", size = 5859763, upload-time = "2025-12-18T19:28:57.98Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/01/933704d69f3f05ee16ef11406b78881733c186fe14b6a46b05cfcaf6d3b2/ruff-0.14.10-py3-none-linux_armv6l.whl", hash = "sha256:7a3ce585f2ade3e1f29ec1b92df13e3da262178df8c8bdf876f48fa0e8316c49", size = 13527080, upload-time = "2025-12-18T19:29:25.642Z" }, + { url = "https://files.pythonhosted.org/packages/df/58/a0349197a7dfa603ffb7f5b0470391efa79ddc327c1e29c4851e85b09cc5/ruff-0.14.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:674f9be9372907f7257c51f1d4fc902cb7cf014b9980152b802794317941f08f", size = 13797320, upload-time = "2025-12-18T19:29:02.571Z" }, + { url = "https://files.pythonhosted.org/packages/7b/82/36be59f00a6082e38c23536df4e71cdbc6af8d7c707eade97fcad5c98235/ruff-0.14.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d85713d522348837ef9df8efca33ccb8bd6fcfc86a2cde3ccb4bc9d28a18003d", size = 12918434, upload-time = "2025-12-18T19:28:51.202Z" }, + { url = "https://files.pythonhosted.org/packages/a6/00/45c62a7f7e34da92a25804f813ebe05c88aa9e0c25e5cb5a7d23dd7450e3/ruff-0.14.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:6987ebe0501ae4f4308d7d24e2d0fe3d7a98430f5adfd0f1fead050a740a3a77", size = 13371961, upload-time = "2025-12-18T19:29:04.991Z" }, + { url = "https://files.pythonhosted.org/packages/40/31/a5906d60f0405f7e57045a70f2d57084a93ca7425f22e1d66904769d1628/ruff-0.14.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:16a01dfb7b9e4eee556fbfd5392806b1b8550c9b4a9f6acd3dbe6812b193c70a", size = 13275629, upload-time = "2025-12-18T19:29:21.381Z" }, + { url = "https://files.pythonhosted.org/packages/3e/60/61c0087df21894cf9d928dc04bcd4fb10e8b2e8dca7b1a276ba2155b2002/ruff-0.14.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7165d31a925b7a294465fa81be8c12a0e9b60fb02bf177e79067c867e71f8b1f", size = 14029234, upload-time = "2025-12-18T19:29:00.132Z" }, + { url = "https://files.pythonhosted.org/packages/44/84/77d911bee3b92348b6e5dab5a0c898d87084ea03ac5dc708f46d88407def/ruff-0.14.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c561695675b972effb0c0a45db233f2c816ff3da8dcfbe7dfc7eed625f218935", size = 15449890, upload-time = "2025-12-18T19:28:53.573Z" }, + { url = "https://files.pythonhosted.org/packages/e9/36/480206eaefa24a7ec321582dda580443a8f0671fdbf6b1c80e9c3e93a16a/ruff-0.14.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bb98fcbbc61725968893682fd4df8966a34611239c9fd07a1f6a07e7103d08e", size = 15123172, upload-time = "2025-12-18T19:29:23.453Z" }, + { url = "https://files.pythonhosted.org/packages/5c/38/68e414156015ba80cef5473d57919d27dfb62ec804b96180bafdeaf0e090/ruff-0.14.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f24b47993a9d8cb858429e97bdf8544c78029f09b520af615c1d261bf827001d", size = 14460260, upload-time = "2025-12-18T19:29:27.808Z" }, + { url = "https://files.pythonhosted.org/packages/b3/19/9e050c0dca8aba824d67cc0db69fb459c28d8cd3f6855b1405b3f29cc91d/ruff-0.14.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59aabd2e2c4fd614d2862e7939c34a532c04f1084476d6833dddef4afab87e9f", size = 14229978, upload-time = "2025-12-18T19:29:11.32Z" }, + { url = "https://files.pythonhosted.org/packages/51/eb/e8dd1dd6e05b9e695aa9dd420f4577debdd0f87a5ff2fedda33c09e9be8c/ruff-0.14.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:213db2b2e44be8625002dbea33bb9c60c66ea2c07c084a00d55732689d697a7f", size = 14338036, upload-time = "2025-12-18T19:29:09.184Z" }, + { url = "https://files.pythonhosted.org/packages/6a/12/f3e3a505db7c19303b70af370d137795fcfec136d670d5de5391e295c134/ruff-0.14.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b914c40ab64865a17a9a5b67911d14df72346a634527240039eb3bd650e5979d", size = 13264051, upload-time = "2025-12-18T19:29:13.431Z" }, + { url = "https://files.pythonhosted.org/packages/08/64/8c3a47eaccfef8ac20e0484e68e0772013eb85802f8a9f7603ca751eb166/ruff-0.14.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1484983559f026788e3a5c07c81ef7d1e97c1c78ed03041a18f75df104c45405", size = 13283998, upload-time = "2025-12-18T19:29:06.994Z" }, + { url = "https://files.pythonhosted.org/packages/12/84/534a5506f4074e5cc0529e5cd96cfc01bb480e460c7edf5af70d2bcae55e/ruff-0.14.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c70427132db492d25f982fffc8d6c7535cc2fd2c83fc8888f05caaa248521e60", size = 13601891, upload-time = "2025-12-18T19:28:55.811Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1e/14c916087d8598917dbad9b2921d340f7884824ad6e9c55de948a93b106d/ruff-0.14.10-py3-none-musllinux_1_2_x86_64.whl", hash = 
"sha256:5bcf45b681e9f1ee6445d317ce1fa9d6cba9a6049542d1c3d5b5958986be8830", size = 14336660, upload-time = "2025-12-18T19:29:16.531Z" }, + { url = "https://files.pythonhosted.org/packages/f2/1c/d7b67ab43f30013b47c12b42d1acd354c195351a3f7a1d67f59e54227ede/ruff-0.14.10-py3-none-win32.whl", hash = "sha256:104c49fc7ab73f3f3a758039adea978869a918f31b73280db175b43a2d9b51d6", size = 13196187, upload-time = "2025-12-18T19:29:19.006Z" }, + { url = "https://files.pythonhosted.org/packages/fb/9c/896c862e13886fae2af961bef3e6312db9ebc6adc2b156fe95e615dee8c1/ruff-0.14.10-py3-none-win_amd64.whl", hash = "sha256:466297bd73638c6bdf06485683e812db1c00c7ac96d4ddd0294a338c62fdc154", size = 14661283, upload-time = "2025-12-18T19:29:30.16Z" }, + { url = "https://files.pythonhosted.org/packages/74/31/b0e29d572670dca3674eeee78e418f20bdf97fa8aa9ea71380885e175ca0/ruff-0.14.10-py3-none-win_arm64.whl", hash = "sha256:e51d046cf6dda98a4633b8a8a771451107413b0f07183b2bef03f075599e44e6", size = 13729839, upload-time = "2025-12-18T19:28:48.636Z" }, +] + +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + +[[package]] +name = "ty" +version = "0.0.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/85/97b5276baa217e05db2fe3d5c61e4dfd35d1d3d0ec95bfca1986820114e0/ty-0.0.10.tar.gz", hash = 
"sha256:0a1f9f7577e56cd508a8f93d0be2a502fdf33de6a7d65a328a4c80b784f4ac5f", size = 4892892, upload-time = "2026-01-07T23:00:23.572Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/7a/5a7147ce5231c3ccc55d6f945dabd7412e233e755d28093bfdec988ba595/ty-0.0.10-py3-none-linux_armv6l.whl", hash = "sha256:406a8ea4e648551f885629b75dc3f070427de6ed099af45e52051d4c68224829", size = 9835881, upload-time = "2026-01-07T22:08:17.492Z" }, + { url = "https://files.pythonhosted.org/packages/3e/7d/89f4d2277c938332d047237b47b11b82a330dbff4fff0de8574cba992128/ty-0.0.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d6e0a733e3d6d3bce56d6766bc61923e8b130241088dc2c05e3c549487190096", size = 9696404, upload-time = "2026-01-07T22:08:37.965Z" }, + { url = "https://files.pythonhosted.org/packages/e8/cd/9dd49e6d40e54d4b7d563f9e2a432c4ec002c0673a81266e269c4bc194ce/ty-0.0.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e4832f8879cb95fc725f7e7fcab4f22be0cf2550f3a50641d5f4409ee04176d4", size = 9181195, upload-time = "2026-01-07T22:59:07.187Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b8/3e7c556654ba0569ed5207138d318faf8633d87e194760fc030543817c26/ty-0.0.10-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:6b58cc78e5865bc908f053559a80bb77cab0dc168aaad2e88f2b47955694b138", size = 9665002, upload-time = "2026-01-07T22:08:30.782Z" }, + { url = "https://files.pythonhosted.org/packages/98/96/410a483321406c932c4e3aa1581d1072b72cdcde3ae83cd0664a65c7b254/ty-0.0.10-py3-none-manylinux_2_24_armv7l.whl", hash = "sha256:83c6a514bb86f05005fa93e3b173ae3fde94d291d994bed6fe1f1d2e5c7331cf", size = 9664948, upload-time = "2026-01-07T23:04:14.655Z" }, + { url = "https://files.pythonhosted.org/packages/1f/5d/cba2ab3e2f660763a72ad12620d0739db012e047eaa0ceaa252bf5e94ebb/ty-0.0.10-py3-none-manylinux_2_24_i686.whl", hash = "sha256:2e43f71e357f8a4f7fc75e4753b37beb2d0f297498055b1673a9306aa3e21897", size = 10125401, upload-time = "2026-01-07T22:08:28.171Z" }, + { url = "https://files.pythonhosted.org/packages/a7/67/29536e0d97f204a2933122239298e754db4564f4ed7f34e2153012b954be/ty-0.0.10-py3-none-manylinux_2_24_ppc64le.whl", hash = "sha256:18be3c679965c23944c8e574be0635504398c64c55f3f0c46259464e10c0a1c7", size = 10714052, upload-time = "2026-01-07T22:08:20.098Z" }, + { url = "https://files.pythonhosted.org/packages/63/c8/82ac83b79a71c940c5dcacb644f526f0c8fdf4b5e9664065ab7ee7c0e4ec/ty-0.0.10-py3-none-manylinux_2_24_s390x.whl", hash = "sha256:5477981681440a35acdf9b95c3097410c547abaa32b893f61553dbc3b0096fff", size = 10395924, upload-time = "2026-01-07T22:08:22.839Z" }, + { url = "https://files.pythonhosted.org/packages/9e/4c/2f9ac5edbd0e67bf82f5cd04275c4e87cbbf69a78f43e5dcf90c1573d44e/ty-0.0.10-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:e206a23bd887574302138b33383ae1edfcc39d33a06a12a5a00803b3f0287a45", size = 10220096, upload-time = "2026-01-07T22:08:13.171Z" }, + { url = "https://files.pythonhosted.org/packages/04/13/3be2b7bfd53b9952b39b6f2c2ef55edeb1a2fea3bf0285962736ee26731c/ty-0.0.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4e09ddb0d3396bd59f645b85eab20f9a72989aa8b736b34338dcb5ffecfe77b6", size = 9649120, upload-time = "2026-01-07T22:08:34.003Z" }, + { url = "https://files.pythonhosted.org/packages/93/e3/edd58547d9fd01e4e584cec9dced4f6f283506b422cdd953e946f6a8e9f0/ty-0.0.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:139d2a741579ad86a044233b5d7e189bb81f427eebce3464202f49c3ec0eba3b", size = 9686033, upload-time = "2026-01-07T22:08:40.967Z" }, + { url = 
"https://files.pythonhosted.org/packages/cc/bc/9d2f5fec925977446d577fb9b322d0e7b1b1758709f23a6cfc10231e9b84/ty-0.0.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6bae10420c0abfe4601fbbc6ce637b67d0b87a44fa520283131a26da98f2e74c", size = 9841905, upload-time = "2026-01-07T23:04:21.694Z" }, + { url = "https://files.pythonhosted.org/packages/7c/b8/5acd3492b6a4ef255ace24fcff0d4b1471a05b7f3758d8910a681543f899/ty-0.0.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7358bbc5d037b9c59c3a48895206058bcd583985316c4125a74dd87fd1767adb", size = 10320058, upload-time = "2026-01-07T22:08:25.645Z" }, + { url = "https://files.pythonhosted.org/packages/35/67/5b6906fccef654c7e801d6ac8dcbe0d493e1f04c38127f82a5e6d7e0aa0e/ty-0.0.10-py3-none-win32.whl", hash = "sha256:f51b6fd485bc695d0fdf555e69e6a87d1c50f14daef6cb980c9c941e12d6bcba", size = 9271806, upload-time = "2026-01-07T22:08:10.08Z" }, + { url = "https://files.pythonhosted.org/packages/42/36/82e66b9753a76964d26fd9bc3514ea0abce0a5ba5ad7d5f084070c6981da/ty-0.0.10-py3-none-win_amd64.whl", hash = "sha256:16deb77a72cf93b89b4d29577829613eda535fbe030513dfd9fba70fe38bc9f5", size = 10130520, upload-time = "2026-01-07T23:04:11.759Z" }, + { url = "https://files.pythonhosted.org/packages/63/52/89da123f370e80b587d2db8551ff31562c882d87b32b0e92b59504b709ae/ty-0.0.10-py3-none-win_arm64.whl", hash = "sha256:7495288bca7afba9a4488c9906466d648ffd3ccb6902bc3578a6dbd91a8f05f0", size = 9626026, upload-time = "2026-01-07T23:04:17.91Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "tzdata" +version = "2025.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf398f3b8a477eeb205cf3b6f32e7ec3a6bac0629ca975/tzdata-2025.3.tar.gz", hash = "sha256:de39c2ca5dc7b0344f2eba86f49d614019d29f060fc4ebc8a417896a620b56a7", size = 196772, upload-time = "2025-12-13T17:45:35.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, +] + +[[package]] +name = "urllib3" +version = "2.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", 
size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, +] + +[[package]] +name = "virtualenv" +version = "20.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/49/87e23d8f742f10f965bce5d6b285fc88a4f436b11daf6b6225d4d66f8492/virtualenv-20.36.0.tar.gz", hash = "sha256:a3601f540b515a7983508113f14e78993841adc3d83710fa70f0ac50f43b23ed", size = 6032237, upload-time = "2026-01-07T17:20:04.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/6a/0af36875e0023a1f2d0b66b4051721fc26740e947696922df1665b75e5d3/virtualenv-20.36.0-py3-none-any.whl", hash = "sha256:e7ded577f3af534fd0886d4ca03277f5542053bedb98a70a989d3c22cfa5c9ac", size = 6008261, upload-time = "2026-01-07T17:20:02.87Z" }, +]