diff --git a/src/albert/__init__.py b/src/albert/__init__.py index 1c3bd16c..5629ed49 100644 --- a/src/albert/__init__.py +++ b/src/albert/__init__.py @@ -4,4 +4,4 @@ __all__ = ["Albert", "AlbertClientCredentials", "AlbertSSOClient"] -__version__ = "1.11.1" +__version__ = "1.11.2" diff --git a/src/albert/collections/btinsight.py b/src/albert/collections/btinsight.py index 38dbc322..7d7dbd3f 100644 --- a/src/albert/collections/btinsight.py +++ b/src/albert/collections/btinsight.py @@ -7,6 +7,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import BTInsightId +from albert.core.utils import ensure_list from albert.resources.btinsight import BTInsight, BTInsightCategory, BTInsightState @@ -132,17 +133,17 @@ def search( """ params = { "offset": offset, - "order": OrderBy(order_by).value if order_by else None, + "order": order_by, "sortBy": sort_by, "text": text, - "name": name, + "name": ensure_list(name), } - if state: - state = state if isinstance(state, list) else [state] - params["state"] = [BTInsightState(x).value for x in state] - if category: - category = category if isinstance(category, list) else [category] - params["category"] = [BTInsightCategory(x).value for x in category] + + state_values = ensure_list(state) + params["state"] = state_values if state_values else None + + category_values = ensure_list(category) + params["category"] = category_values if category_values else None return AlbertPaginator( mode=PaginationMode.OFFSET, diff --git a/src/albert/collections/cas.py b/src/albert/collections/cas.py index eaea9083..47cb81ad 100644 --- a/src/albert/collections/cas.py +++ b/src/albert/collections/cas.py @@ -106,7 +106,7 @@ def get_all( An iterator over Cas entities. """ - params: dict[str, Any] = {"orderBy": order_by.value} + params: dict[str, Any] = {"orderBy": order_by} if id is not None: yield self.get_by_id(id=id) return diff --git a/src/albert/collections/companies.py b/src/albert/collections/companies.py index 027de1e3..a63b9059 100644 --- a/src/albert/collections/companies.py +++ b/src/albert/collections/companies.py @@ -7,6 +7,7 @@ from albert.core.pagination import AlbertPaginator, PaginationMode from albert.core.session import AlbertSession from albert.core.shared.identifiers import CompanyId +from albert.core.utils import ensure_list from albert.exceptions import AlbertException from albert.resources.companies import Company @@ -62,9 +63,8 @@ def get_all( "dupDetection": "false", "startKey": start_key, } - if name: - params["name"] = name if isinstance(name, list) else [name] - params["exactMatch"] = str(exact_match).lower() + params["name"] = ensure_list(name) + params["exactMatch"] = str(exact_match).lower() return AlbertPaginator( mode=PaginationMode.KEY, diff --git a/src/albert/collections/custom_templates.py b/src/albert/collections/custom_templates.py index 68f1f08f..8282c503 100644 --- a/src/albert/collections/custom_templates.py +++ b/src/albert/collections/custom_templates.py @@ -3,13 +3,20 @@ from pydantic import validate_call from albert.collections.base import BaseCollection +from albert.collections.tags import TagCollection from albert.core.logging import logger from albert.core.pagination import AlbertPaginator from albert.core.session import AlbertSession -from albert.core.shared.enums import PaginationMode +from albert.core.shared.enums import OrderBy, PaginationMode, Status from albert.core.shared.identifiers import CustomTemplateId -from albert.exceptions import AlbertHTTPError -from albert.resources.custom_templates import CustomTemplate, CustomTemplateSearchItem +from albert.core.shared.models.patch import PatchOperation +from albert.core.utils import ensure_list +from albert.resources.acls import ACL +from albert.resources.custom_templates import ( + CustomTemplate, + CustomTemplateSearchItem, + TemplateCategory, +) class CustomTemplatesCollection(BaseCollection): @@ -31,6 +38,94 @@ def __init__(self, *, session: AlbertSession): self.base_path = f"/api/{CustomTemplatesCollection._api_version}/customtemplates" @validate_call + def create( + self, + *, + custom_templates: CustomTemplate | list[CustomTemplate], + ) -> list[CustomTemplate]: + """ + Creates one or more custom templates. + + Parameters + ---------- + custom_templates : CustomTemplate | list[CustomTemplate] + The template entities to create. + + Returns + ------- + list[CustomTemplate] + The created CustomTemplate entities. + """ + templates = ensure_list(custom_templates) or [] + if len(templates) == 0: + raise ValueError("At least one CustomTemplate must be provided.") + + payload = [ + template.model_dump( + mode="json", + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) + for template in templates + ] + response = self.session.post(url=self.base_path, json=payload) + response_data = response.json() + created_payloads = ( + (response_data or {}).get("CreatedItems") + if response.status_code == 206 + else response_data + ) or [] + + tag_collection = TagCollection(session=self.session) + + def resolve_tag(tag_id: str | None) -> dict[str, str] | None: + if not tag_id: + return None + tag = tag_collection.get_by_id(id=tag_id) + return {"albertId": tag.id or tag_id, "name": tag.tag} + + def populate_tag_names(section: dict | None) -> None: + if not isinstance(section, dict): + return + tags = section.get("Tags") + if not tags: + return + resolved_tags = [] + for tag in tags: + if isinstance(tag, dict): + tag_id = tag.get("id") or tag.get("albertId") + elif isinstance(tag, str): + tag_id = tag + else: + tag_id = None + + resolved_tag = resolve_tag(tag_id) + if resolved_tag: + resolved_tags.append(resolved_tag) + section["Tags"] = resolved_tags + + for payload in created_payloads: + if not isinstance(payload, dict): + continue + populate_tag_names(payload.get("Data")) + + if response.status_code == 206: + failed_items = response_data.get("FailedItems") or [] + if failed_items: + error_messages = [] + for failed in failed_items: + errors = failed.get("errors") or [] + if errors: + error_messages.extend(err.get("msg", "Unknown error") for err in errors) + joined = " | ".join(error_messages) if error_messages else "Unknown error" + logger.warning( + "Custom template creation partially succeeded. Errors: %s", + joined, + ) + + return [CustomTemplate(**item) for item in created_payloads] + def get_by_id(self, *, id: CustomTemplateId) -> CustomTemplate: """Get a Custom Template by ID @@ -52,8 +147,20 @@ def search( self, *, text: str | None = None, - max_items: int | None = None, offset: int | None = 0, + sort_by: str | None = None, + order_by: OrderBy | None = None, + status: Status | None = None, + created_by: str | None = None, + category: TemplateCategory | list[TemplateCategory] | None = None, + created_by_name: str | list[str] | None = None, + collaborator: str | list[str] | None = None, + facet_text: str | None = None, + facet_field: str | None = None, + contains_field: str | list[str] | None = None, + contains_text: str | list[str] | None = None, + my_role: str | list[str] | None = None, + max_items: int | None = None, ) -> Iterator[CustomTemplateSearchItem]: """ Search for CustomTemplate matching the provided criteria. @@ -64,20 +171,57 @@ def search( Parameters ---------- text : str, optional - Text to filter search results by. - max_items : int, optional - Maximum number of items to return in total. If None, fetches all available items. + Free text search term. offset : int, optional - Offset to begin pagination at. Default is 0. + Starting offset for pagination. + sort_by : str, optional + Field to sort on. + order_by : OrderBy, optional + Sort direction for `sort_by`. + status : Status | str, optional + Filter results by template status. + created_by : str, optional + Filter by creator id. + category : TemplateCategory | list[TemplateCategory], optional + Filter by template categories. + created_by_name : str | list[str], optional + Filter by creator display name(s). + collaborator : str | list[str], optional + Filter by collaborator ids. + facet_text : str, optional + Filter text within a facet. + facet_field : str, optional + Facet field to search inside. + contains_field : str | list[str], optional + Fields to apply contains search to. + contains_text : str | list[str], optional + Text values for contains search. + my_role : str | list[str], optional + Restrict templates to roles held by the calling user. + max_items : int, optional + Maximum number of items to yield client-side. Returns ------- Iterator[CustomTemplateSearchItem] An iterator of CustomTemplateSearchItem items. """ + params = { "text": text, "offset": offset, + "sortBy": sort_by, + "order": order_by, + "status": status, + "createdBy": created_by, + "category": ensure_list(category), + "createdByName": ensure_list(created_by_name), + "collaborator": ensure_list(collaborator), + "facetText": facet_text, + "facetField": facet_field, + "containsField": ensure_list(contains_field), + "containsText": ensure_list(contains_text), + "myRole": ensure_list(my_role), } return AlbertPaginator( @@ -94,32 +238,95 @@ def search( def get_all( self, *, - text: str | None = None, + name: str | list[str] | None = None, + created_by: str | None = None, + category: TemplateCategory | None = None, + start_key: str | None = None, max_items: int | None = None, - offset: int | None = 0, ) -> Iterator[CustomTemplate]: - """ - Retrieve fully hydrated CustomTemplate entities with optional filters. - - This method returns complete entity data using `get_by_id`. - Use :meth:`search` for faster retrieval when you only need lightweight, partial (unhydrated) entities. + """Iterate over CustomTemplate entities with optional filters. Parameters ---------- - text : str, optional - Text filter for template name or content. + name : str | list[str], optional + Filter by template name(s). + created_by : str, optional + Filter by creator id. + category : TemplateCategory, optional + Filter by category. + start_key : str, optional + Provide the `lastKey` from a previous request to resume pagination. max_items : int, optional - Maximum number of items to return in total. If None, fetches all available items. - offset : int, optional - Offset for search pagination. + Maximum number of items to return. Returns ------- Iterator[CustomTemplate] - An iterator of CustomTemplate entities. + An iterator of CustomTemplates. """ - for item in self.search(text=text, max_items=max_items, offset=offset): - try: - yield self.get_by_id(id=item.id) - except AlbertHTTPError as e: - logger.warning(f"Error hydrating custom template {item.id}: {e}") + params = { + "startKey": start_key, + "createdBy": created_by, + "category": category, + } + params["name"] = ensure_list(name) + + return AlbertPaginator( + mode=PaginationMode.KEY, + path=self.base_path, + session=self.session, + params=params, + max_items=max_items, + deserialize=lambda items: [CustomTemplate(**item) for item in items], + ) + + @validate_call + def delete(self, *, id: CustomTemplateId) -> None: + """Delete a custom template by id.""" + + url = f"{self.base_path}/{id}" + self.session.delete(url) + + @validate_call + def update_acl( + self, + *, + custom_template_id: CustomTemplateId, + acl_class: str | None = None, + acls: list[ACL] | None = None, + ) -> CustomTemplate: + """Replace the template's ACL class and/or entries with the provided values and return the updated template.""" + + if acl_class is None and not acls: + raise ValueError("Provide an ACL class and/or ACL entries to update.") + + data = [] + + if acl_class is not None: + data.append( + { + "operation": PatchOperation.UPDATE.value, + "attribute": "class", + "newValue": acl_class, + } + ) + + if acls: + entries = [] + for entry in acls: + payload: dict[str, str] = {"id": entry.id} + if entry.fgc is not None: + payload["fgc"] = getattr(entry.fgc, "value", entry.fgc) + entries.append(payload) + + data.append( + { + "operation": PatchOperation.UPDATE.value, + "attribute": "ACL", + "newValue": entries, + } + ) + + url = f"{self.base_path}/{custom_template_id}/acl" + self.session.patch(url, json={"data": data}) + return self.get_by_id(id=custom_template_id) diff --git a/src/albert/collections/data_columns.py b/src/albert/collections/data_columns.py index 2af8f525..a4ea5906 100644 --- a/src/albert/collections/data_columns.py +++ b/src/albert/collections/data_columns.py @@ -7,6 +7,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import DataColumnId +from albert.core.utils import ensure_list from albert.resources.data_columns import DataColumn @@ -102,12 +103,12 @@ def deserialize(items: list[dict]) -> Iterator[DataColumn]: yield from (DataColumn(**item) for item in items) params = { - "orderBy": order_by.value, + "orderBy": order_by, "startKey": start_key, - "name": [name] if isinstance(name, str) else name, + "name": ensure_list(name), "exactMatch": exact_match, "default": default, - "dataColumns": [ids] if isinstance(ids, str) else ids, + "dataColumns": ensure_list(ids), } return AlbertPaginator( diff --git a/src/albert/collections/data_templates.py b/src/albert/collections/data_templates.py index 9a2715b9..8c9c6aa0 100644 --- a/src/albert/collections/data_templates.py +++ b/src/albert/collections/data_templates.py @@ -285,7 +285,7 @@ def search( """ params = { "offset": offset, - "order": order_by.value, + "order": order_by, "text": name, "userId": user_id, } diff --git a/src/albert/collections/entity_types.py b/src/albert/collections/entity_types.py index 56dd67d5..bbc845dc 100644 --- a/src/albert/collections/entity_types.py +++ b/src/albert/collections/entity_types.py @@ -282,10 +282,10 @@ def get_all( Returns an iterator of EntityType items matching the search criteria. """ params = { - "service": service.value if service else None, + "service": service, "limit": max_items, "startKey": start_key, - "orderBy": order.value if order else None, + "orderBy": order, } return AlbertPaginator( mode=PaginationMode.KEY, diff --git a/src/albert/collections/inventory.py b/src/albert/collections/inventory.py index b4a8bf1a..7171a6f4 100644 --- a/src/albert/collections/inventory.py +++ b/src/albert/collections/inventory.py @@ -16,6 +16,7 @@ SearchProjectId, WorksheetId, ) +from albert.core.utils import ensure_list from albert.resources.facet import FacetItem from albert.resources.inventory import ( ALL_MERGE_MODULES, @@ -88,10 +89,10 @@ def merge( # define merge endpoint url = f"{self.base_path}/merge" - if isinstance(child_id, list): - child_inventories = [{"id": i} for i in child_id] - else: - child_inventories = [{"id": child_id}] + child_ids = ensure_list(child_id) or [] + if not child_ids: + raise ValueError("At least one child inventory id is required for merge operations.") + child_inventories = [{"id": i} for i in child_ids] # define payload using the class payload = MergeInventory( @@ -360,9 +361,9 @@ def _prepare_parameters( params = { "text": text, - "order": order.value if order is not None else None, + "order": order, "sortBy": sort_by if sort_by is not None else None, - "category": [c.value for c in category] if category is not None else None, + "category": category, "tags": tags, "manufacturer": [c.name for c in company] if company is not None else None, "cas": [c.number for c in cas] if cas is not None else None, @@ -447,8 +448,7 @@ def get_facet_by_name( This can be used for example to fetch all remaining tags as part of an iterative refinement of a search. """ - if isinstance(name, str): - name = [name] + name = ensure_list(name) or [] facets = self.get_all_facets( text=text, diff --git a/src/albert/collections/lists.py b/src/albert/collections/lists.py index 224f64ab..0008d700 100644 --- a/src/albert/collections/lists.py +++ b/src/albert/collections/lists.py @@ -92,7 +92,7 @@ def get_all( params = { "startKey": start_key, "name": names, - "category": category.value if isinstance(category, ListItemCategory) else category, + "category": category, "listType": list_type, "orderBy": order_by, } diff --git a/src/albert/collections/locations.py b/src/albert/collections/locations.py index d0ed4b4a..a83b60df 100644 --- a/src/albert/collections/locations.py +++ b/src/albert/collections/locations.py @@ -4,6 +4,7 @@ from albert.core.pagination import AlbertPaginator from albert.core.session import AlbertSession from albert.core.shared.enums import PaginationMode +from albert.core.utils import ensure_list from albert.resources.locations import Location @@ -64,9 +65,8 @@ def get_all( } if ids: params["id"] = ids - if name: - params["name"] = [name] if isinstance(name, str) else name - params["exactMatch"] = exact_match + params["name"] = ensure_list(name) + params["exactMatch"] = exact_match return AlbertPaginator( mode=PaginationMode.KEY, diff --git a/src/albert/collections/lots.py b/src/albert/collections/lots.py index 07cc17af..ac77c6f2 100644 --- a/src/albert/collections/lots.py +++ b/src/albert/collections/lots.py @@ -10,6 +10,7 @@ from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import InventoryId, LotId, TaskId from albert.core.shared.models.patch import PatchDatum, PatchOperation, PatchPayload +from albert.core.utils import ensure_list from albert.resources.inventory import InventoryCategory from albert.resources.lots import Lot, LotSearchItem @@ -188,39 +189,21 @@ def search( search_text = text if (text is None or len(text) < 50) else text[:50] - def _ensure_list(value): - if value is None: - return None - if isinstance(value, list | tuple | set): - return list(value) - return [value] - - def _format_categories(value): - raw = _ensure_list(value) - if raw is None: - return None - formatted: list[str] = [] - for category in raw: - formatted.append( - category.value if isinstance(category, InventoryCategory) else category - ) - return formatted - params = { "offset": offset, - "order": order_by.value, + "order": order_by, "text": search_text, "sortBy": sort_by, "isDropDown": is_drop_down, - "inventoryId": _ensure_list(inventory_id), - "locationId": _ensure_list(location_id), - "storageLocationId": _ensure_list(storage_location_id), - "taskId": _ensure_list(task_id), - "category": _format_categories(category), - "externalBarcodeId": _ensure_list(external_barcode_id), - "searchField": _ensure_list(search_field), - "sourceField": _ensure_list(source_field), - "additionalField": _ensure_list(additional_field), + "inventoryId": ensure_list(inventory_id), + "locationId": ensure_list(location_id), + "storageLocationId": ensure_list(storage_location_id), + "taskId": ensure_list(task_id), + "category": ensure_list(category), + "externalBarcodeId": ensure_list(external_barcode_id), + "searchField": ensure_list(search_field), + "sourceField": ensure_list(source_field), + "additionalField": ensure_list(additional_field), } params = {key: value for key, value in params.items() if value is not None} diff --git a/src/albert/collections/notes.py b/src/albert/collections/notes.py index 83f81118..73cf1c5b 100644 --- a/src/albert/collections/notes.py +++ b/src/albert/collections/notes.py @@ -106,7 +106,7 @@ def get_by_parent_id( """ params = { "parentId": parent_id, - "orderBy": order_by.value, + "orderBy": order_by, } response = self.session.get( url=self.base_path, diff --git a/src/albert/collections/parameter_groups.py b/src/albert/collections/parameter_groups.py index 419fa3ed..1abad4d8 100644 --- a/src/albert/collections/parameter_groups.py +++ b/src/albert/collections/parameter_groups.py @@ -10,6 +10,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import ParameterGroupId +from albert.core.utils import ensure_list from albert.exceptions import AlbertHTTPError from albert.resources.parameter_groups import ( ParameterGroup, @@ -97,9 +98,9 @@ def search( """ params = { "offset": offset, - "order": order_by.value, + "order": order_by, "text": text, - "types": [types] if isinstance(types, PGType) else types, + "types": ensure_list(types), } return AlbertPaginator( diff --git a/src/albert/collections/parameters.py b/src/albert/collections/parameters.py index 0cbb2c64..37ae09fd 100644 --- a/src/albert/collections/parameters.py +++ b/src/albert/collections/parameters.py @@ -8,6 +8,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import ParameterId +from albert.core.utils import ensure_list from albert.resources.parameters import Parameter @@ -137,13 +138,12 @@ def deserialize(items: list[dict]) -> Iterator[Parameter]: yield from (Parameter(**item) for item in items) params = { - "orderBy": order_by.value, + "orderBy": order_by, "parameters": ids, "startKey": start_key, } - if names: - params["name"] = [names] if isinstance(names, str) else names - params["exactMatch"] = exact_match + params["name"] = ensure_list(names) + params["exactMatch"] = exact_match return AlbertPaginator( mode=PaginationMode.KEY, diff --git a/src/albert/collections/projects.py b/src/albert/collections/projects.py index b4d731cf..e8e2543a 100644 --- a/src/albert/collections/projects.py +++ b/src/albert/collections/projects.py @@ -187,7 +187,7 @@ def search( An iterator of matching partial (unhydrated) Project results. """ query_params = { - "order": order_by.value, + "order": order_by, "offset": offset, "text": text, "sortBy": sort_by, diff --git a/src/albert/collections/property_data.py b/src/albert/collections/property_data.py index 6fcfc0bd..9a4a9ca4 100644 --- a/src/albert/collections/property_data.py +++ b/src/albert/collections/property_data.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from contextlib import suppress -from enum import Enum import pandas as pd from pydantic import validate_call @@ -23,6 +22,7 @@ UserId, ) from albert.core.shared.models.patch import PatchOperation +from albert.core.utils import ensure_list from albert.exceptions import NotFoundError from albert.resources.property_data import ( BulkPropertyData, @@ -1086,22 +1086,19 @@ def search( def deserialize(items: list[dict]) -> list[PropertyDataSearchItem]: return [PropertyDataSearchItem.model_validate(x) for x in items] - def ensure_list(v): - if v is None: - return None - return [v] if isinstance(v, str | Enum) else v + category_values = ensure_list(category) params = { "result": result, "text": text, - "order": order.value if order else None, + "order": order, "sortBy": sort_by, "inventoryIds": ensure_list(inventory_ids), "projectIds": ensure_list(project_ids), "lotIds": ensure_list(lot_ids), "dataTemplateId": ensure_list(data_template_ids), "dataColumnId": ensure_list(data_column_ids), - "category": [c.value for c in ensure_list(category)] if category else None, + "category": category_values if category_values else None, "dataTemplates": ensure_list(data_templates), "dataColumns": ensure_list(data_columns), "parameters": ensure_list(parameters), diff --git a/src/albert/collections/storage_locations.py b/src/albert/collections/storage_locations.py index d12b724b..36ca7583 100644 --- a/src/albert/collections/storage_locations.py +++ b/src/albert/collections/storage_locations.py @@ -7,6 +7,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import PaginationMode from albert.core.shared.models.base import EntityLink +from albert.core.utils import ensure_list from albert.exceptions import AlbertHTTPError from albert.resources.locations import Location from albert.resources.storage_locations import StorageLocation @@ -93,9 +94,8 @@ def deserialize(items: list[dict]) -> Iterator[StorageLocation]: "startKey": start_key, } - if name: - params["name"] = [name] if isinstance(name, str) else name - params["exactMatch"] = exact_match + params["name"] = ensure_list(name) + params["exactMatch"] = exact_match return AlbertPaginator( mode=PaginationMode.KEY, diff --git a/src/albert/collections/tags.py b/src/albert/collections/tags.py index e131aad7..8bd7d89c 100644 --- a/src/albert/collections/tags.py +++ b/src/albert/collections/tags.py @@ -9,6 +9,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import TagId +from albert.core.utils import ensure_list from albert.exceptions import AlbertException from albert.resources.tags import Tag @@ -258,12 +259,12 @@ def get_all( An iterator of Tag entities matching the filters. """ params = { - "orderBy": order_by.value, + "orderBy": order_by, "startKey": start_key, } if name: - params["name"] = [name] if isinstance(name, str) else name + params["name"] = ensure_list(name) params["exactMatch"] = exact_match return AlbertPaginator( diff --git a/src/albert/collections/tasks.py b/src/albert/collections/tasks.py index 11c70b4e..78ebdc12 100644 --- a/src/albert/collections/tasks.py +++ b/src/albert/collections/tasks.py @@ -25,6 +25,7 @@ WorkflowId, remove_id_prefix, ) +from albert.core.utils import ensure_list from albert.exceptions import AlbertHTTPError from albert.resources.attachments import AttachmentCategory from albert.resources.data_templates import ImportMode @@ -570,13 +571,12 @@ def search( params = { "offset": offset, - "order": order_by.value, + "order": order_by, "text": text, "sortBy": sort_by, "tags": tags, "taskId": task_id, "linkedTask": linked_task, - "category": category, "albertId": albert_id, "dataTemplate": data_template, "assignedTo": assigned_to, @@ -588,6 +588,9 @@ def search( "projectId": project_id, } + category_values = ensure_list(category) + params["category"] = category_values if category_values else None + return AlbertPaginator( mode=PaginationMode.OFFSET, path=f"{self.base_path}/search", @@ -749,7 +752,7 @@ def get_history( """Fetch the audit history for the specified task.""" params = { "limit": limit, - "orderBy": OrderBy(order).value if order else None, + "orderBy": order, "entity": entity, "blockId": blockId, "startKey": startKey, diff --git a/src/albert/collections/units.py b/src/albert/collections/units.py index 5199514d..abd82524 100644 --- a/src/albert/collections/units.py +++ b/src/albert/collections/units.py @@ -8,6 +8,7 @@ from albert.core.session import AlbertSession from albert.core.shared.enums import OrderBy, PaginationMode from albert.core.shared.identifiers import UnitId +from albert.core.utils import ensure_list from albert.resources.units import Unit, UnitCategory @@ -192,11 +193,11 @@ def get_all( An iterator of Unit entities. """ params = { - "orderBy": order_by.value, - "name": [name] if isinstance(name, str) else name, + "orderBy": order_by, + "name": ensure_list(name), "exactMatch": exact_match, "verified": verified, - "category": category.value if isinstance(category, UnitCategory) else category, + "category": category, "startKey": start_key, } diff --git a/src/albert/collections/users.py b/src/albert/collections/users.py index b962626a..b1dd48c1 100644 --- a/src/albert/collections/users.py +++ b/src/albert/collections/users.py @@ -135,7 +135,7 @@ def search( params = { "text": text, "sortBy": sort_by, - "order": order_by.value, + "order": order_by, "roles": roles, "teams": teams, "locations": locations, @@ -198,7 +198,7 @@ def get_all( """ params = { "status": status, - "type": type.value if type else None, + "type": type, "id": id, "startKey": start_key, } diff --git a/src/albert/core/utils.py b/src/albert/core/utils.py new file mode 100644 index 00000000..7dfa7e86 --- /dev/null +++ b/src/albert/core/utils.py @@ -0,0 +1,20 @@ +"""Utility helpers shared across Albert SDK modules.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import TypeVar + +T = TypeVar("T") + + +def ensure_list(value: T | Iterable[T] | None) -> list[T] | None: + """Return ``value`` as a list, preserving ``None`` and existing lists.""" + + if value is None: + return None + if isinstance(value, list): + return value + if isinstance(value, tuple | set): + return list(value) + return [value] diff --git a/src/albert/resources/custom_templates.py b/src/albert/resources/custom_templates.py index 85b49603..6b5059b5 100644 --- a/src/albert/resources/custom_templates.py +++ b/src/albert/resources/custom_templates.py @@ -5,7 +5,7 @@ from albert.core.base import BaseAlbertModel from albert.core.shared.enums import SecurityClass, Status -from albert.core.shared.identifiers import CustomTemplateId, NotebookId +from albert.core.shared.identifiers import CustomTemplateId, EntityTypeId, NotebookId from albert.core.shared.models.base import BaseResource, EntityLink from albert.core.shared.types import MetadataItem, SerializeAsEntityLink from albert.resources._mixins import HydrationMixin @@ -29,6 +29,11 @@ class DesignLink(EntityLink): type: DesignType +class TemplateEntityType(BaseAlbertModel): + id: EntityTypeId | None = Field(default=None) + custom_category: str | None = Field(default=None, alias="customCategory") + + class TemplateCategory(str, Enum): PROPERTY_LIST = "Property Task" PROPERTY = "Property" @@ -79,7 +84,7 @@ class SamConfig(BaseResource): class Workflow(BaseResource): id: str - name: str + name: str | None = Field(default=None) # Some workflows may have SamConfig sam_config: list[SamConfig] | None = Field(default=None, alias="SamConfig") @@ -191,6 +196,10 @@ class CustomTemplate(BaseTaggedResource): The metadata of the template. Allowed Metadata fields can be found using Custim Fields. data : CustomTemplateData | None The data of the template. + entity_type : TemplateEntityType | None + The entity type associated with the template. + locked : bool | None + Whether the template is locked when loaded in the UI. team : List[TeamACL] | None The team of the template. acl : TemplateACL | None @@ -202,8 +211,10 @@ class CustomTemplate(BaseTaggedResource): category: TemplateCategory = Field(default=TemplateCategory.GENERAL) metadata: dict[str, MetadataItem] | None = Field(default=None, alias="Metadata") data: CustomTemplateData | None = Field(default=None, alias="Data") - team: list[TeamACL] | None = Field(default_factory=list) - acl: TemplateACL | None = Field(default_factory=list, alias="ACL") + entity_type: TemplateEntityType | None = Field(default=None, alias="EntityType") + locked: bool | None = Field(default=None) + team: list[TeamACL] | None = Field(default=None) + acl: TemplateACL | None = Field(default=None, alias="ACL") @model_validator(mode="before") # Must happen before construction so the data are captured @classmethod @@ -241,9 +252,9 @@ class CustomTemplateSearchItem(BaseAlbertModel, HydrationMixin[CustomTemplate]): id: CustomTemplateId = Field(alias="albertId") created_by_name: str = Field(..., alias="createdByName") created_at: str = Field(..., alias="createdAt") - category: str + category: str | None = None status: Status | None = None resource_class: SecurityClass | None = Field(default=None, alias="resourceClass") data: CustomTemplateSearchItemData | None = None - acl: list[CustomTemplateSearchItemACL] - team: list[CustomTemplateSearchItemTeam] + acl: list[CustomTemplateSearchItemACL] | None = None + team: list[CustomTemplateSearchItemTeam] | None = None diff --git a/src/albert/resources/tags.py b/src/albert/resources/tags.py index 7f57d06d..51bf0221 100644 --- a/src/albert/resources/tags.py +++ b/src/albert/resources/tags.py @@ -1,13 +1,10 @@ from __future__ import annotations from enum import Enum -from typing import Any -from pydantic import AliasChoices, Field, model_validator +from pydantic import AliasChoices, Field -from albert.core.logging import logger from albert.core.shared.models.base import BaseResource -from albert.core.shared.types import SerializeAsEntityLink class TagEntity(str, Enum): @@ -62,40 +59,3 @@ def from_string(cls, tag: str) -> Tag: The Tag object created from the string. """ return cls(tag=tag) - - -class BaseTaggedEntity(BaseResource): - """ - BaseTaggedEntity is a Pydantic model that includes functionality for handling tags as either Tag objects or strings. - - Attributes - ---------- - tags : List[Tag | str] | None - A list of Tag objects or strings representing tags. - """ - - tags: list[SerializeAsEntityLink[Tag]] | None = Field(None, alias="Tags") - - @model_validator(mode="before") # must happen before to keep type validation - @classmethod - def convert_tags(cls, data: dict[str, Any]) -> dict[str, Any]: - if not isinstance(data, dict): - return data - tags = data.get("tags") - if not tags: - tags = data.get("Tags") - if tags: - new_tags = [] - for t in tags: - if isinstance(t, Tag): - new_tags.append(t) - elif isinstance(t, str): - new_tags.append(Tag.from_string(t)) - elif isinstance(t, dict): - new_tags.append(Tag(**t)) - else: - # We do not expect this else to be hit because tags should only be Tag or str - logger.warning(f"Unexpected value for Tag. {t} of type {type(t)}") - continue - data["tags"] = new_tags - return data diff --git a/tests/collections/test_inventory.py b/tests/collections/test_inventory.py index fb7cefe3..6d3dba20 100644 --- a/tests/collections/test_inventory.py +++ b/tests/collections/test_inventory.py @@ -1,10 +1,7 @@ -import time - import pytest from albert.client import Albert from albert.collections.inventory import InventoryCategory -from albert.core.shared.enums import SecurityClass from albert.core.shared.identifiers import ensure_inventory_id from albert.exceptions import BadRequestError from albert.resources.cas import Cas @@ -133,31 +130,17 @@ def test_get_by_ids(client: Albert): # assert f"INV{inventory_id}" == inventory.id -def test_inventory_update(client: Albert, seed_prefix: str): - # create a new test inventory item - ii = InventoryItem( - name=f"{seed_prefix} - SDK UPDATE/DELETE TEST", - description="SDK item that will be updated and deleted.", - category=InventoryCategory.RAW_MATERIALS, - unit_category=InventoryUnitCategory.MASS, - security_class=SecurityClass.CONFIDENTIAL, - company="", - ) - created = client.inventory.create(inventory_item=ii) - # Give time for the DB to sync - somewhere between 1 and 4 seconds is needed - # for this test to work - time.sleep(4) +def test_inventory_update(client: Albert, seed_prefix: str, seeded_inventory: list[InventoryItem]): + # get a test inventory item + inventory_item = seeded_inventory[0] - assert client.inventory.exists(inventory_item=created) + assert client.inventory.exists(inventory_item=inventory_item) d = "testing SDK CRUD" - created.description = d + inventory_item.description = d - updated = client.inventory.update(inventory_item=created) + updated = client.inventory.update(inventory_item=inventory_item) assert updated.description == d - assert updated.id == created.id - - client.inventory.delete(id=created.id) - assert not client.inventory.exists(inventory_item=created) + assert updated.id == inventory_item.id def test_collection_blocks_formulation(client: Albert, seeded_projects):