From 9d2048e7041bac113e17bb148211fece1ac2fe41 Mon Sep 17 00:00:00 2001 From: Patrick Huck Date: Wed, 5 Mar 2025 11:37:09 -0800 Subject: [PATCH 01/26] exclude gnome for full downloads if needed --- mp_api/client/core/client.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 4cd98958c..50d862926 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -473,6 +473,23 @@ def _query_resource( suffix = infix if suffix == "core" else suffix suffix = suffix.replace("_", "-") + # Check if user has access to GNoMe + has_gnome_access = bool( + self._submit_requests( + url=urljoin(self.endpoint, "materials/summary/"), + criteria={ + "batch_id": "gnome_r2scan_statics", + "_fields": "material_id", + }, + use_document_model=False, + num_chunks=1, + chunk_size=1, + timeout=timeout, + ) + .get("meta", {}) + .get("total_doc", 0) + ) + # Paginate over all entries in the bucket. # TODO: change when a subset of entries needed from DB if "tasks" in suffix: @@ -481,6 +498,11 @@ def _query_resource( bucket_suffix = "build" prefix = f"collections/{db_version}/{suffix}" + # only include prefixes accessible to user + # i.e. append `batch_id=others/core` to `prefix` + if not has_gnome_access: + prefix += "/batch_id=others" + bucket = f"materialsproject-{bucket_suffix}" paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=prefix) From 505ddfe0c311ef4de64d6dd9c19ea78c41b75754 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 22 Oct 2025 17:20:29 -0700 Subject: [PATCH 02/26] query s3 for trajectories --- mp_api/client/routes/materials/tasks.py | 33 +++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index c78650780..b0854d42b 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -3,7 +3,11 @@ from datetime import datetime from typing import TYPE_CHECKING +import pyarrow as pa +from deltalake import DeltaTable, QueryBuilder +from emmet.core.mpid import AlphaID from emmet.core.tasks import CoreTaskDoc +from emmet.core.trajectory import RelaxTrajectory from mp_api.client.core import BaseRester, MPRestError from mp_api.client.core.utils import validate_ids @@ -16,6 +20,7 @@ class TaskRester(BaseRester): suffix: str = "materials/tasks" document_model: type[BaseModel] = CoreTaskDoc # type: ignore primary_key: str = "task_id" + delta_backed = True def get_trajectory(self, task_id): """Returns a Trajectory object containing the geometry of the @@ -26,16 +31,30 @@ def get_trajectory(self, task_id): task_id (str): Task ID """ - traj_data = self._query_resource_data( - {"task_ids": [task_id]}, suburl="trajectory/", use_document_model=False - )[0].get( - "trajectories", None - ) # type: ignore + as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] - if traj_data is None: + traj_tbl = DeltaTable( + "s3a://materialsproject-parsed/core/trajectories/", + storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + ) + + traj_data = pa.table( + QueryBuilder() + .register("traj", traj_tbl) + .execute( + f""" + SELECT * + FROM traj + WHERE identifier='{as_alpha}' + """ + ) + .read_all() + ).to_pylist(maps_as_pydicts="strict") + + if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") - return traj_data + return 
RelaxTrajectory(**traj_data[0]).to_pmg() def search( self, From aee0f8c117e01e514604b4c5996f144a4c3b560d Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 08:39:17 -0700 Subject: [PATCH 03/26] add deltalake query support --- mp_api/client/core/client.py | 192 ++++++++++++++++++++++++++++----- mp_api/client/core/settings.py | 14 +++ mp_api/client/core/utils.py | 63 +++++++++++ mp_api/client/mprester.py | 13 +++ pyproject.toml | 2 + 5 files changed, 258 insertions(+), 26 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 50d862926..5234ef344 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -9,6 +9,7 @@ import itertools import os import platform +import shutil import sys import warnings from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait @@ -18,15 +19,13 @@ from importlib.metadata import PackageNotFoundError, version from json import JSONDecodeError from math import ceil -from typing import ( - TYPE_CHECKING, - ForwardRef, - Optional, - get_args, -) +from typing import TYPE_CHECKING, ForwardRef, Optional, get_args from urllib.parse import quote, urljoin +import pyarrow as pa +import pyarrow.dataset as ds import requests +from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake from emmet.core.utils import jsanitize from pydantic import BaseModel, create_model from requests.adapters import HTTPAdapter @@ -36,7 +35,7 @@ from urllib3.util.retry import Retry from mp_api.client.core.settings import MAPIClientSettings -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import MPDataset, load_json, validate_ids try: import boto3 @@ -71,6 +70,7 @@ class BaseRester: document_model: type[BaseModel] | None = None supports_versions: bool = False primary_key: str = "material_id" + delta_backed: bool = False def __init__( self, @@ -85,6 +85,8 @@ def __init__( timeout: int = 20, headers: dict | None = None, mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: str | os.PathLike = SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, ): """Initialize the REST API helper class. @@ -116,6 +118,9 @@ def __init__( timeout: Time in seconds to wait until a request timeout error is thrown headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to 'materialsproject_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset """ # TODO: think about how to migrate from PMG_MAPI_KEY self.api_key = api_key or os.getenv("MP_API_KEY") @@ -129,6 +134,8 @@ def __init__( self.timeout = timeout self.headers = headers or {} self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self.db_version = BaseRester._get_database_version(self.endpoint) if self.suffix: @@ -212,7 +219,7 @@ def _get_database_version(endpoint): remains unchanged and available for querying via its task_id. The database version is set as a date in the format YYYY_MM_DD, - where "_DD" may be optional. An additional numerical or `postN` suffix + predicate "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. 
Returns: database version as a string @@ -356,10 +363,7 @@ def _patch_resource( raise MPRestError(str(ex)) def _query_open_data( - self, - bucket: str, - key: str, - decoder: Callable | None = None, + self, bucket: str, key: str, decoder: Callable | None = None ) -> tuple[list[dict] | list[bytes], int]: """Query and deserialize Materials Project AWS open data s3 buckets. @@ -463,6 +467,12 @@ def _query_resource( url += "/" if query_s3: + pbar_message = ( # type: ignore + f"Retrieving {self.document_model.__name__} documents" # type: ignore + if self.document_model is not None + else "Retrieving documents" + ) + db_version = self.db_version.replace(".", "-") if "/" not in self.suffix: suffix = self.suffix @@ -474,9 +484,14 @@ def _query_resource( suffix = suffix.replace("_", "-") # Check if user has access to GNoMe + # temp suppress tqdm + re_enable = not self.mute_progress_bars + self.mute_progress_bars = True has_gnome_access = bool( self._submit_requests( - url=urljoin(self.endpoint, "materials/summary/"), + url=urljoin( + "https://api.materialsproject.org/", "materials/summary/" + ), criteria={ "batch_id": "gnome_r2scan_statics", "_fields": "material_id", @@ -489,21 +504,147 @@ def _query_resource( .get("meta", {}) .get("total_doc", 0) ) + self.mute_progress_bars = not re_enable - # Paginate over all entries in the bucket. - # TODO: change when a subset of entries needed from DB if "tasks" in suffix: - bucket_suffix, prefix = "parsed", "tasks_atomate2" + bucket_suffix, prefix = ("parsed", "core/tasks/") else: bucket_suffix = "build" prefix = f"collections/{db_version}/{suffix}" - # only include prefixes accessible to user - # i.e. append `batch_id=others/core` to `prefix` - if not has_gnome_access: - prefix += "/batch_id=others" - bucket = f"materialsproject-{bucket_suffix}" + + if self.delta_backed: + target_path = ( + self.local_dataset_cache + f"/{bucket_suffix}/{prefix}" + ) + os.makedirs(target_path, exist_ok=True) + + if DeltaTable.is_deltatable(target_path): + if self.force_renew: + shutil.rmtree(target_path) + warnings.warn( + f"Regenerating {suffix} dataset at {target_path}...", + MPLocalDatasetWarning, + ) + os.makedirs(target_path, exist_ok=True) + else: + warnings.warn( + f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset " + "or re-run search query with MPRester(force_renew=True)", + MPLocalDatasetWarning, + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + tbl = DeltaTable( + f"s3a://{bucket}/{prefix}", + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + }, + ) + + controlled_batch_str = ",".join( + [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS] + ) + + predicate = ( + " WHERE batch_id NOT IN (" # don't delete leading space + + controlled_batch_str + + ")" + if not has_gnome_access + else "" + ) + + builder = QueryBuilder().register("tbl", tbl) + + # Setup progress bar + num_docs_needed = pa.table( + builder.execute("SELECT COUNT(*) FROM tbl").read_all() + )[0][0].as_py() + + # TODO: Update tasks (+ others?) 
resource to have emmet-api BatchIdQuery operator + # -> need to modify BatchIdQuery operator to handle root level + # batch_id, not only builder_meta.batch_id + # if not has_gnome_access: + # num_docs_needed = self.count( + # {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS} + # ) + + pbar = ( + tqdm( + desc=pbar_message, + total=num_docs_needed, + ) + if not self.mute_progress_bars + else None + ) + + iterator = builder.execute("SELECT * FROM tbl" + predicate) + + file_options = ds.ParquetFileFormat().make_write_options( + compression="zstd" + ) + + def _flush(accumulator, group): + ds.write_dataset( + accumulator, + base_dir=target_path, + format="parquet", + basename_template=f"group-{group}-" + + "part-{i}.zstd.parquet", + existing_data_behavior="overwrite_or_ignore", + max_rows_per_group=1024, + file_options=file_options, + ) + + group = 1 + size = 0 + accumulator = [] + for page in iterator: + # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer + accumulator.append(pa.record_batch(page)) + page_size = page.num_rows + size += page_size + + if pbar is not None: + pbar.update(page_size) + + if size >= SETTINGS.DATASET_FLUSH_THRESHOLD: + _flush(accumulator, group) + group += 1 + size = 0 + accumulator = [] + + if accumulator: + _flush(accumulator, group + 1) + + convert_to_deltalake(target_path) + + warnings.warn( + f"Dataset for {suffix} written to {target_path}. It is recommended to optimize " + "the table according to your usage patterns prior to running intensive workloads, " + "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout", + MPLocalDatasetWarning, + ) + + return { + "data": MPDataset( + path=target_path, + document_model=self.document_model, + use_document_model=self.use_document_model, + ) + } + + # Paginate over all entries in the bucket. 
+ # TODO: change when a subset of entries needed from DB paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=prefix) @@ -540,11 +681,6 @@ def _query_resource( } # Setup progress bar - pbar_message = ( # type: ignore - f"Retrieving {self.document_model.__name__} documents" # type: ignore - if self.document_model is not None - else "Retrieving documents" - ) num_docs_needed = int(self.count()) pbar = ( tqdm( @@ -1372,3 +1508,7 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" + + +class MPLocalDatasetWarning(Warning): + """Raised when unrecoverable actions are performed on a local dataset.""" diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 200b67785..9dbc6a386 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -87,4 +87,18 @@ class MAPIClientSettings(BaseSettings): _MAX_LIST_LENGTH, description="Maximum length of query parameter list" ) + LOCAL_DATASET_CACHE: str = Field( + os.path.expanduser("~") + "/mp_datasets", + description="Target directory for downloading full datasets", + ) + + DATASET_FLUSH_THRESHOLD: int = Field( + 100000, + description="Threshold number of rows to accumulate in memory before flushing dataset to disk", + ) + + ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field( + ["gnome_r2scan_statics"], description="Batch ids with access restrictions" + ) + model_config = SettingsConfigDict(env_prefix="MPRESTER_") diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index c2d03fec2..8fb48c142 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -1,12 +1,17 @@ from __future__ import annotations import re +from functools import cached_property +from itertools import chain from typing import TYPE_CHECKING, Literal import orjson +import pyarrow.dataset as ds +from deltalake import DeltaTable from emmet.core import __version__ as _EMMET_CORE_VER from monty.json import MontyDecoder from packaging.version import parse as parse_version +from pydantic._internal._model_construction import ModelMetaclass from mp_api.client.core.settings import MAPIClientSettings @@ -124,3 +129,61 @@ def validate_monty(cls, v, _): monty_cls.validate_monty_v2 = classmethod(validate_monty) return monty_cls + + +class MPDataset: + def __init__(self, path, document_model, use_document_model): + self._start = 0 + self._path = path + self._document_model = document_model + self._dataset = ds.dataset(path) + self._row_groups = list( + chain.from_iterable( + [ + fragment.split_by_row_group() + for fragment in self._dataset.get_fragments() + ] + ) + ) + self._use_document_model = use_document_model + + @property + def pyarrow_dataset(self) -> ds.Dataset: + return self._dataset + + @property + def pydantic_model(self) -> ModelMetaclass: + return self._document_model + + @property + def use_document_model(self) -> bool: + return self._use_document_model + + @use_document_model.setter + def use_document_model(self, value: bool): + self._use_document_model = value + + @cached_property + def delta_table(self) -> DeltaTable: + return DeltaTable(self._path) + + @cached_property + def num_chunks(self) -> int: + return len(self._row_groups) + + def __getitem__(self, idx): + return list( + map( + lambda x: self._document_model(**x) if self._use_document_model else x, + self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"), + ) + ) + + def __len__(self) -> int: + return 
self.num_chunks + + def __iter__(self): + current = self._start + while current < self.num_chunks: + yield self[current] + current += 1 diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 3fdc07f92..5537736a6 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -133,6 +133,8 @@ def __init__( session: Session | None = None, headers: dict | None = None, mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: str | os.PathLike = _MAPI_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, ): """Initialize the MPRester. @@ -167,6 +169,9 @@ def __init__( session: Session object to use. By default (None), the client will create one. headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to "materialsproject_datasets" in the user's home directory + force_renew: Option to overwrite existing local dataset """ # SETTINGS tries to read API key from ~/.config/.pmgrc.yaml @@ -192,6 +197,8 @@ def __init__( self.use_document_model = use_document_model self.monty_decode = monty_decode self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = local_dataset_cache + self.force_renew = force_renew self._contribs = None self._deprecated_attributes = [ @@ -267,6 +274,8 @@ def __init__( use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) for cls in self._all_resters if cls.suffix in core_suffix @@ -293,6 +302,8 @@ def __init__( use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) # type: BaseRester setattr( self, @@ -323,6 +334,8 @@ def __core_custom_getattr(_self, _attr, _rester_map): use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, + local_dataset_cache=self.local_dataset_cache, + force_renew=self.force_renew, ) # type: BaseRester setattr( diff --git a/pyproject.toml b/pyproject.toml index f202666c7..063e8c9a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ dependencies = [ "smart_open", "boto3", "orjson >= 3.10,<4", + "pyarrow >= 20.0.0", + "deltalake >= 1.2.0", ] dynamic = ["version"] From d5a25b19ca037771010f7c743ce3bae266aba0e6 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 09:01:57 -0700 Subject: [PATCH 04/26] linting + mistaken sed replace on 'where' --- mp_api/client/core/client.py | 2 +- mp_api/client/core/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 5234ef344..c8a492336 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -219,7 +219,7 @@ def _get_database_version(endpoint): remains unchanged and available for querying via its task_id. The database version is set as a date in the format YYYY_MM_DD, - predicate "_DD" may be optional. An additional numerical or `postN` suffix + where "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. 
Returns: database version as a string diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 8fb48c142..9e7003ed2 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -133,6 +133,7 @@ def validate_monty(cls, v, _): class MPDataset: def __init__(self, path, document_model, use_document_model): + """Convenience wrapper for pyarrow datasets stored on disk.""" self._start = 0 self._path = path self._document_model = document_model From 2de051df8fde2dd2d10e611598b0ea2efdf984a0 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:40:13 -0700 Subject: [PATCH 05/26] return trajectory as pmg dict --- mp_api/client/routes/materials/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index b0854d42b..a879a93c4 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -54,7 +54,7 @@ def get_trajectory(self, task_id): if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") - return RelaxTrajectory(**traj_data[0]).to_pmg() + return RelaxTrajectory(**traj_data[0]).to_pmg().as_dict() def search( self, From 7d0b8b749b3f163133a5028b6ee169f9bb39cc05 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:46:29 -0700 Subject: [PATCH 06/26] update trajectory test --- tests/materials/test_tasks.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/materials/test_tasks.py b/tests/materials/test_tasks.py index b35dfd938..1ddf12c58 100644 --- a/tests/materials/test_tasks.py +++ b/tests/materials/test_tasks.py @@ -1,8 +1,9 @@ import os -from core_function import client_search_testing -import pytest +import pytest +from core_function import client_search_testing from emmet.core.utils import utcnow + from mp_api.client.routes.materials.tasks import TaskRester @@ -53,7 +54,6 @@ def test_client(rester): def test_get_trajectories(rester): - trajectories = [traj for traj in rester.get_trajectory("mp-149")] + trajectory = rester.get_trajectory("mp-149") - for traj in trajectories: - assert ("@module", "pymatgen.core.trajectory") in traj.items() + assert trajectory["@module"] == "pymatgen.core.trajectory" From 7195adf9b11394898dae78b502e3235b74a18f75 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 23 Oct 2025 10:48:39 -0700 Subject: [PATCH 07/26] correct docstrs --- mp_api/client/core/client.py | 2 +- mp_api/client/mprester.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index c8a492336..5b74e5dce 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -119,7 +119,7 @@ def __init__( headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. local_dataset_cache: Target directory for downloading full datasets. 
Defaults - to 'materialsproject_datasets' in the user's home directory + to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset """ # TODO: think about how to migrate from PMG_MAPI_KEY diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 5537736a6..a60de0f3d 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -170,7 +170,7 @@ def __init__( headers: Custom headers for localhost connections. mute_progress_bars: Whether to mute progress bars. local_dataset_cache: Target directory for downloading full datasets. Defaults - to "materialsproject_datasets" in the user's home directory + to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset """ From 2664fcdbee06a282de13fa2e8a03513fdfc4c177 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:12:54 -0800 Subject: [PATCH 08/26] get access controlled batch ids from heartbeat --- mp_api/client/core/client.py | 33 ++++++++++++++++++++++++++------- mp_api/client/core/settings.py | 4 ---- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index d3473bc9b..272cadd81 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -154,6 +154,9 @@ def __init__( self.local_dataset_cache = local_dataset_cache self.force_renew = force_renew self.db_version = BaseRester._get_database_version(self.endpoint) + self.access_controlled_batch_ids = BaseRester._get_access_restricted_batch_ids( + self.endpoint + ) if self.suffix: self.endpoint = urljoin(self.endpoint, self.suffix) @@ -243,6 +246,25 @@ def _get_database_version(endpoint): """ return requests.get(url=endpoint + "heartbeat").json()["db_version"] + @staticmethod + @cache + def _get_access_restricted_batch_ids(endpoint): + """Certain contributions to the Materials Project have access + control restrictions that require explicit agreement to the + Terms of Use for the respective datasets prior to access being + granted. + + A full list of the Terms of Use for all contributions in the + Materials Project are available at: + + https://next-gen.materialsproject.org/about/terms + + Returns: a list of strings + """ + return requests.get(url=endpoint + "heartbeat").json()[ + "access_controlled_batch_ids" + ] + def _post_resource( self, body: dict | None = None, @@ -583,13 +605,10 @@ def _query_resource( builder.execute("SELECT COUNT(*) FROM tbl").read_all() )[0][0].as_py() - # TODO: Update tasks (+ others?) 
resource to have emmet-api BatchIdQuery operator - # -> need to modify BatchIdQuery operator to handle root level - # batch_id, not only builder_meta.batch_id - # if not has_gnome_access: - # num_docs_needed = self.count( - # {"batch_id_neq_any": SETTINGS.ACCESS_CONTROLLED_BATCH_IDS} - # ) + if not has_gnome_access: + num_docs_needed = self.count( + {"batch_id_neq_any": self.access_controlled_batch_ids} + ) pbar = ( tqdm( diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 9dbc6a386..8b0e63937 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -97,8 +97,4 @@ class MAPIClientSettings(BaseSettings): description="Threshold number of rows to accumulate in memory before flushing dataset to disk", ) - ACCESS_CONTROLLED_BATCH_IDS: list[str] = Field( - ["gnome_r2scan_statics"], description="Batch ids with access restrictions" - ) - model_config = SettingsConfigDict(env_prefix="MPRESTER_") From b498a762e3befc4d6aaf2c4bffb03e4e88e5eada Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:34:48 -0800 Subject: [PATCH 09/26] refactor --- mp_api/client/core/client.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 272cadd81..c01ef3b7e 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -524,9 +524,7 @@ def _query_resource( self.mute_progress_bars = True has_gnome_access = bool( self._submit_requests( - url=urljoin( - "https://api.materialsproject.org/", "materials/summary/" - ), + url=urljoin(self.base_endpoint, "materials/summary/"), criteria={ "batch_id": "gnome_r2scan_statics", "_fields": "material_id", @@ -653,7 +651,7 @@ def _flush(accumulator, group): _flush(accumulator, group) group += 1 size = 0 - accumulator = [] + accumulator.clear() if accumulator: _flush(accumulator, group + 1) From 948c1086182755aca713e9b2ff49820bc032176b Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 5 Nov 2025 00:04:36 +0000 Subject: [PATCH 10/26] auto dependency upgrades --- .../requirements-ubuntu-latest_py3.11.txt | 23 ++++++++---- ...quirements-ubuntu-latest_py3.11_extras.txt | 35 ++++++++++++------- .../requirements-ubuntu-latest_py3.12.txt | 22 ++++++++---- ...quirements-ubuntu-latest_py3.12_extras.txt | 34 +++++++++++------- 4 files changed, 78 insertions(+), 36 deletions(-) diff --git a/requirements/requirements-ubuntu-latest_py3.11.txt b/requirements/requirements-ubuntu-latest_py3.11.txt index f18ccd754..7713bfc2c 100644 --- a/requirements/requirements-ubuntu-latest_py3.11.txt +++ b/requirements/requirements-ubuntu-latest_py3.11.txt @@ -6,13 +6,15 @@ # annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake bibtexparser==1.4.3 # via pymatgen blake3==1.0.8 # via emmet-core -boto3==1.40.61 +boto3==1.40.66 # via mp-api (pyproject.toml) -botocore==1.40.61 +botocore==1.40.66 # via # boto3 # s3transfer @@ -24,6 +26,10 @@ contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib +deltalake==1.2.1 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake emmet-core==0.86.0rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 @@ -51,7 +57,7 @@ mpmath==1.3.0 # via sympy msgpack==1.1.2 # via mp-api (pyproject.toml) -narwhals==2.10.0 +narwhals==2.10.2 # via plotly networkx==3.5 # via pymatgen @@ -79,8 +85,10 @@ pandas==2.3.3 # via pymatgen pillow==12.0.0 # via matplotlib -plotly==6.3.1 +plotly==6.4.0 # via pymatgen +pyarrow==22.0.0 + # via 
mp-api (pyproject.toml) pybtex==0.25.1 # via emmet-core pydantic==2.12.3 @@ -133,7 +141,7 @@ scipy==1.16.3 # via pymatgen six==1.17.0 # via python-dateutil -smart-open==7.4.1 +smart-open==7.4.4 # via mp-api (pyproject.toml) spglib==2.6.0 # via pymatgen @@ -145,6 +153,7 @@ tqdm==4.67.1 # via pymatgen typing-extensions==4.15.0 # via + # arro3-core # blake3 # emmet-core # mp-api (pyproject.toml) @@ -165,4 +174,6 @@ urllib3==2.5.0 # botocore # requests wrapt==2.0.0 - # via smart-open + # via + # deprecated + # smart-open diff --git a/requirements/requirements-ubuntu-latest_py3.11_extras.txt b/requirements/requirements-ubuntu-latest_py3.11_extras.txt index 6f0769b2b..4fa3f653f 100644 --- a/requirements/requirements-ubuntu-latest_py3.11_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.11_extras.txt @@ -8,6 +8,8 @@ alabaster==1.0.0 # via sphinx annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake arrow==1.4.0 # via isoduration ase==3.26.0 @@ -26,9 +28,9 @@ blake3==1.0.8 # via emmet-core boltons==25.0.0 # via mpcontribs-client -boto3==1.40.61 +boto3==1.40.66 # via mp-api (pyproject.toml) -botocore==1.40.61 +botocore==1.40.66 # via # boto3 # s3transfer @@ -54,6 +56,10 @@ cycler==0.12.1 # via matplotlib decorator==5.2.1 # via ipython +deltalake==1.2.1 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake distlib==0.4.0 # via virtualenv dnspython==2.8.0 @@ -90,7 +96,7 @@ idna==3.11 # via # jsonschema # requests -imageio==2.37.0 +imageio==2.37.2 # via scikit-image imagesize==1.4.1 # via sphinx @@ -181,7 +187,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.10.0 +narwhals==2.10.2 # via plotly networkx==3.5 # via @@ -244,13 +250,13 @@ pillow==12.0.0 # imageio # matplotlib # scikit-image -pint==0.25 +pint==0.25.1 # via mpcontribs-client platformdirs==4.5.0 # via # pint # virtualenv -plotly==6.3.1 +plotly==6.4.0 # via # mpcontribs-client # pymatgen @@ -262,7 +268,7 @@ pre-commit==4.3.0 # via mp-api (pyproject.toml) prompt-toolkit==3.0.52 # via ipython -psutil==7.1.2 +psutil==7.1.3 # via custodian ptyprocess==0.7.0 # via pexpect @@ -271,7 +277,9 @@ pubchempy==1.0.5 pure-eval==0.2.3 # via stack-data pyarrow==22.0.0 - # via emmet-core + # via + # emmet-core + # mp-api (pyproject.toml) pybtex==0.25.1 # via # emmet-core @@ -435,7 +443,7 @@ six==1.17.0 # flatten-dict # python-dateutil # rfc3339-validator -smart-open==7.4.1 +smart-open==7.4.4 # via mp-api (pyproject.toml) snowballstemmer==3.0.1 # via sphinx @@ -491,6 +499,7 @@ types-setuptools==80.9.0.20250822 # via mp-api (pyproject.toml) typing-extensions==4.15.0 # via + # arro3-core # blake3 # bravado # emmet-core @@ -527,11 +536,13 @@ urllib3==2.5.0 # botocore # requests # types-requests -virtualenv==20.35.3 +virtualenv==20.35.4 # via pre-commit wcwidth==0.2.14 # via prompt-toolkit -webcolors==24.11.1 +webcolors==25.10.0 # via jsonschema wrapt==2.0.0 - # via smart-open + # via + # deprecated + # smart-open diff --git a/requirements/requirements-ubuntu-latest_py3.12.txt b/requirements/requirements-ubuntu-latest_py3.12.txt index f1c760b88..b9fb44814 100644 --- a/requirements/requirements-ubuntu-latest_py3.12.txt +++ b/requirements/requirements-ubuntu-latest_py3.12.txt @@ -6,13 +6,15 @@ # annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake bibtexparser==1.4.3 # via pymatgen blake3==1.0.8 # via emmet-core -boto3==1.40.61 +boto3==1.40.66 # via mp-api (pyproject.toml) -botocore==1.40.61 +botocore==1.40.66 # via # boto3 # s3transfer @@ -24,6 +26,10 @@ 
contourpy==1.3.3 # via matplotlib cycler==0.12.1 # via matplotlib +deltalake==1.2.1 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake emmet-core==0.86.0rc1 # via mp-api (pyproject.toml) fonttools==4.60.1 @@ -51,7 +57,7 @@ mpmath==1.3.0 # via sympy msgpack==1.1.2 # via mp-api (pyproject.toml) -narwhals==2.10.0 +narwhals==2.10.2 # via plotly networkx==3.5 # via pymatgen @@ -79,8 +85,10 @@ pandas==2.3.3 # via pymatgen pillow==12.0.0 # via matplotlib -plotly==6.3.1 +plotly==6.4.0 # via pymatgen +pyarrow==22.0.0 + # via mp-api (pyproject.toml) pybtex==0.25.1 # via emmet-core pydantic==2.12.3 @@ -133,7 +141,7 @@ scipy==1.16.3 # via pymatgen six==1.17.0 # via python-dateutil -smart-open==7.4.1 +smart-open==7.4.4 # via mp-api (pyproject.toml) spglib==2.6.0 # via pymatgen @@ -164,4 +172,6 @@ urllib3==2.5.0 # botocore # requests wrapt==2.0.0 - # via smart-open + # via + # deprecated + # smart-open diff --git a/requirements/requirements-ubuntu-latest_py3.12_extras.txt b/requirements/requirements-ubuntu-latest_py3.12_extras.txt index fd592b9fd..1c69da588 100644 --- a/requirements/requirements-ubuntu-latest_py3.12_extras.txt +++ b/requirements/requirements-ubuntu-latest_py3.12_extras.txt @@ -8,6 +8,8 @@ alabaster==1.0.0 # via sphinx annotated-types==0.7.0 # via pydantic +arro3-core==0.6.5 + # via deltalake arrow==1.4.0 # via isoduration ase==3.26.0 @@ -26,9 +28,9 @@ blake3==1.0.8 # via emmet-core boltons==25.0.0 # via mpcontribs-client -boto3==1.40.61 +boto3==1.40.66 # via mp-api (pyproject.toml) -botocore==1.40.61 +botocore==1.40.66 # via # boto3 # s3transfer @@ -54,6 +56,10 @@ cycler==0.12.1 # via matplotlib decorator==5.2.1 # via ipython +deltalake==1.2.1 + # via mp-api (pyproject.toml) +deprecated==1.3.1 + # via deltalake distlib==0.4.0 # via virtualenv dnspython==2.8.0 @@ -90,7 +96,7 @@ idna==3.11 # via # jsonschema # requests -imageio==2.37.0 +imageio==2.37.2 # via scikit-image imagesize==1.4.1 # via sphinx @@ -181,7 +187,7 @@ mypy-extensions==1.1.0 # via # mp-api (pyproject.toml) # mypy -narwhals==2.10.0 +narwhals==2.10.2 # via plotly networkx==3.5 # via @@ -244,13 +250,13 @@ pillow==12.0.0 # imageio # matplotlib # scikit-image -pint==0.25 +pint==0.25.1 # via mpcontribs-client platformdirs==4.5.0 # via # pint # virtualenv -plotly==6.3.1 +plotly==6.4.0 # via # mpcontribs-client # pymatgen @@ -262,7 +268,7 @@ pre-commit==4.3.0 # via mp-api (pyproject.toml) prompt-toolkit==3.0.52 # via ipython -psutil==7.1.2 +psutil==7.1.3 # via custodian ptyprocess==0.7.0 # via pexpect @@ -271,7 +277,9 @@ pubchempy==1.0.5 pure-eval==0.2.3 # via stack-data pyarrow==22.0.0 - # via emmet-core + # via + # emmet-core + # mp-api (pyproject.toml) pybtex==0.25.1 # via # emmet-core @@ -435,7 +443,7 @@ six==1.17.0 # flatten-dict # python-dateutil # rfc3339-validator -smart-open==7.4.1 +smart-open==7.4.4 # via mp-api (pyproject.toml) snowballstemmer==3.0.1 # via sphinx @@ -525,11 +533,13 @@ urllib3==2.5.0 # botocore # requests # types-requests -virtualenv==20.35.3 +virtualenv==20.35.4 # via pre-commit wcwidth==0.2.14 # via prompt-toolkit -webcolors==24.11.1 +webcolors==25.10.0 # via jsonschema wrapt==2.0.0 - # via smart-open + # via + # deprecated + # smart-open From b0aed4f80dab90b051984d2e54f94129e3108dc5 Mon Sep 17 00:00:00 2001 From: Patrick Huck Date: Tue, 4 Nov 2025 16:11:13 -0800 Subject: [PATCH 11/26] Update testing.yml --- .github/workflows/testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 
b96a5eb2d..5f9840ab7 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -57,7 +57,7 @@ jobs: - name: Test with pytest env: MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }} - #MP_API_ENDPOINT: https://api-preview.materialsproject.org/ + MP_API_ENDPOINT: https://api-preview.materialsproject.org/ run: | pip install -e . pytest -n auto -x --cov=mp_api --cov-report=xml From a35bcb72c7d127023a32266a05eaf6792d896ae7 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:16:25 -0800 Subject: [PATCH 12/26] rm overlooked access of removed settings param --- mp_api/client/core/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index c01ef3b7e..35175c71f 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -585,7 +585,7 @@ def _query_resource( ) controlled_batch_str = ",".join( - [f"'{tag}'" for tag in SETTINGS.ACCESS_CONTROLLED_BATCH_IDS] + [f"'{tag}'" for tag in self.access_controlled_batch_ids] ) predicate = ( From 94606015fb7a47dbdc2f7ded80940653c0e82e97 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:24:43 -0800 Subject: [PATCH 13/26] refactor: consolidate requests to heartbeat for meta info --- mp_api/client/core/client.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 35175c71f..c25126803 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -153,9 +153,8 @@ def __init__( self.mute_progress_bars = mute_progress_bars self.local_dataset_cache = local_dataset_cache self.force_renew = force_renew - self.db_version = BaseRester._get_database_version(self.endpoint) - self.access_controlled_batch_ids = BaseRester._get_access_restricted_batch_ids( - self.endpoint + self.db_version, self.access_controlled_batch_ids = ( + BaseRester._get_hearbeat_info(self.endpoint) ) if self.suffix: @@ -231,8 +230,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover @staticmethod @cache - def _get_database_version(endpoint): - """The Materials Project database is periodically updated and has a + def _get_hearbeat_info(endpoint) -> tuple[str, str]: + """DB version: + The Materials Project database is periodically updated and has a database version associated with it. When the database is updated, consolidated data (information about "a material") may and does change, while calculation data about a specific calculation task @@ -242,14 +242,8 @@ def _get_database_version(endpoint): where "_DD" may be optional. An additional numerical or `postN` suffix might be added if multiple releases happen on the same day. - Returns: database version as a string - """ - return requests.get(url=endpoint + "heartbeat").json()["db_version"] - - @staticmethod - @cache - def _get_access_restricted_batch_ids(endpoint): - """Certain contributions to the Materials Project have access + Access Controlled Datasets: + Certain contributions to the Materials Project have access control restrictions that require explicit agreement to the Terms of Use for the respective datasets prior to access being granted. 
@@ -259,11 +253,12 @@ def _get_access_restricted_batch_ids(endpoint): https://next-gen.materialsproject.org/about/terms - Returns: a list of strings + Returns: + tuple with database version as a string and a comma separated + string with all calculation batch identifiers """ - return requests.get(url=endpoint + "heartbeat").json()[ - "access_controlled_batch_ids" - ] + response = requests.get(url=endpoint + "heartbeat").json() + return response["db_version"], response["access_controlled_batch_ids"] def _post_resource( self, From 05f1d0e9153810d0a3b357f77bbeed754066c126 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:25:48 -0800 Subject: [PATCH 14/26] lint --- mp_api/client/core/client.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index c25126803..bbd70ac61 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -153,9 +153,10 @@ def __init__( self.mute_progress_bars = mute_progress_bars self.local_dataset_cache = local_dataset_cache self.force_renew = force_renew - self.db_version, self.access_controlled_batch_ids = ( - BaseRester._get_hearbeat_info(self.endpoint) - ) + ( + self.db_version, + self.access_controlled_batch_ids, + ) = BaseRester._get_hearbeat_info(self.endpoint) if self.suffix: self.endpoint = urljoin(self.endpoint, self.suffix) From e685445c95e7f501c3314d05ff1e186b4aaa057e Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:30:34 -0800 Subject: [PATCH 15/26] fix incomplete docstr --- mp_api/client/core/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index bbd70ac61..8a9e1da24 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -256,7 +256,8 @@ def _get_hearbeat_info(endpoint) -> tuple[str, str]: Returns: tuple with database version as a string and a comma separated - string with all calculation batch identifiers + string with all calculation batch identifiers that have access + restrictions """ response = requests.get(url=endpoint + "heartbeat").json() return response["db_version"], response["access_controlled_batch_ids"] From bb0b238e416079654d302919448835975032b85d Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 5 Nov 2025 10:42:02 -0800 Subject: [PATCH 16/26] typo --- mp_api/client/core/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 8a9e1da24..cbbcd15b2 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -156,7 +156,7 @@ def __init__( ( self.db_version, self.access_controlled_batch_ids, - ) = BaseRester._get_hearbeat_info(self.endpoint) + ) = BaseRester._get_heartbeat_info(self.endpoint) if self.suffix: self.endpoint = urljoin(self.endpoint, self.suffix) @@ -231,7 +231,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover @staticmethod @cache - def _get_hearbeat_info(endpoint) -> tuple[str, str]: + def _get_heartbeat_info(endpoint) -> tuple[str, str]: """DB version: The Materials Project database is periodically updated and has a database version associated with it. 
When the database is updated, From fb84d73e93a311b9ceb5b93ae4cc7cd79ee6a3e3 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Mon, 10 Nov 2025 09:38:18 -0800 Subject: [PATCH 17/26] revert testing endpoint --- .github/workflows/testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 5f9840ab7..7a4ff0d8a 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -57,7 +57,7 @@ jobs: - name: Test with pytest env: MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }} - MP_API_ENDPOINT: https://api-preview.materialsproject.org/ + # MP_API_ENDPOINT: https://api-preview.materialsproject.org/ run: | pip install -e . pytest -n auto -x --cov=mp_api --cov-report=xml From 5bdacf57e020ec5ca7d5a6be6f45cb56b9cb3ed1 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:39:49 -0800 Subject: [PATCH 18/26] no parallel on batch_id_neq_any --- mp_api/client/core/settings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 8b0e63937..d50b84b26 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -50,6 +50,7 @@ class MAPIClientSettings(BaseSettings): "condition_mixing_media", "condition_heating_atmosphere", "operations", + "batch_id_neq_any", "_fields", ], description="List API query parameters that do not support parallel requests.", From 7ee551547c133abebd49f04ecfa06b92f0d3aef8 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:36:01 -0800 Subject: [PATCH 19/26] more resilient dataset path expansion --- mp_api/client/core/settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index d50b84b26..cb0cbfc4c 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -1,4 +1,5 @@ import os +import pathlib from multiprocessing import cpu_count from typing import List @@ -89,7 +90,7 @@ class MAPIClientSettings(BaseSettings): ) LOCAL_DATASET_CACHE: str = Field( - os.path.expanduser("~") + "/mp_datasets", + pathlib.Path("~/mp_datasets").expanduser(), description="Target directory for downloading full datasets", ) From ae7674db1fc0e1208de31721a5eb15634916061f Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:41:11 -0800 Subject: [PATCH 20/26] missed field annotation update --- mp_api/client/core/settings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index cb0cbfc4c..b708ebfeb 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -1,6 +1,6 @@ import os -import pathlib from multiprocessing import cpu_count +from pathlib import Path from typing import List from pydantic import Field @@ -89,8 +89,8 @@ class MAPIClientSettings(BaseSettings): _MAX_LIST_LENGTH, description="Maximum length of query parameter list" ) - LOCAL_DATASET_CACHE: str = Field( - pathlib.Path("~/mp_datasets").expanduser(), + LOCAL_DATASET_CACHE: Path = Field( + Path("~/mp_datasets").expanduser(), description="Target directory for downloading full datasets", ) From 5538c74544a3c4984321ccf943f9b9dac7607afb Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> 
Date: Wed, 12 Nov 2025 10:52:23 -0800 Subject: [PATCH 21/26] coerce Path to str for deltalake lib --- mp_api/client/core/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index cbbcd15b2..48c8a8a10 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -545,8 +545,8 @@ def _query_resource( bucket = f"materialsproject-{bucket_suffix}" if self.delta_backed: - target_path = ( - self.local_dataset_cache + f"/{bucket_suffix}/{prefix}" + target_path = str( + self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}") ) os.makedirs(target_path, exist_ok=True) From f39c0d3178a321a7e9fe0c7a5e405e7e82d21277 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 13 Nov 2025 17:22:40 -0800 Subject: [PATCH 22/26] flush based on bytes --- mp_api/client/core/client.py | 5 +++-- mp_api/client/core/settings.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 48c8a8a10..fcd065de4 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -637,9 +637,10 @@ def _flush(accumulator, group): accumulator = [] for page in iterator: # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer - accumulator.append(pa.record_batch(page)) + rg = pa.record_batch(page) + accumulator.append(rg) page_size = page.num_rows - size += page_size + size += rg.get_total_buffer_size() if pbar is not None: pbar.update(page_size) diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index b708ebfeb..09926fe82 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -95,8 +95,8 @@ class MAPIClientSettings(BaseSettings): ) DATASET_FLUSH_THRESHOLD: int = Field( - 100000, - description="Threshold number of rows to accumulate in memory before flushing dataset to disk", + int(2.75 * 1024**3), + description="Threshold bytes to accumulate in memory before flushing dataset to disk", ) model_config = SettingsConfigDict(env_prefix="MPRESTER_") From a9652552a7ee0a3e796a28949c21b5c6c1854699 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Thu, 13 Nov 2025 17:23:01 -0800 Subject: [PATCH 23/26] iterate over individual rows for local dataset --- mp_api/client/core/utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 9e7003ed2..9b586e7f2 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -173,15 +173,21 @@ def num_chunks(self) -> int: return len(self._row_groups) def __getitem__(self, idx): - return list( - map( - lambda x: self._document_model(**x) if self._use_document_model else x, - self._row_groups[idx].to_table().to_pylist(maps_as_pydicts="strict"), + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + _take = list(range(start, stop, step)) + ds_slice = self._dataset.take(_take).to_pylist(maps_as_pydicts="strict") + return ( + [self._document_model(**_row) for _row in ds_slice] + if self._use_document_model + else ds_slice ) - ) + + _row = self._dataset.take([idx]).to_pylist(maps_as_pydicts="strict")[0] + return self._document_model(**_row) if self._use_document_model else _row def __len__(self) -> int: - return self.num_chunks + return self._dataset.count_rows() def __iter__(self): current = self._start From 
03b38e70874fcd4b536f892a4aef55dcae7605e8 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:17:56 -0800 Subject: [PATCH 24/26] missed bounds check for updated iteration behavior --- mp_api/client/core/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 9b586e7f2..b549d5b2f 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -191,6 +191,6 @@ def __len__(self) -> int: def __iter__(self): current = self._start - while current < self.num_chunks: + while current < len(self): yield self[current] current += 1 From 3a44b4f4314e448b8d2c0abd6c08af5a25ae2427 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:18:18 -0800 Subject: [PATCH 25/26] opt for module level logging over warnings lib --- mp_api/client/core/client.py | 43 ++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index fcd065de4..13f7a3b46 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -7,6 +7,7 @@ import inspect import itertools +import logging import os import platform import shutil @@ -62,6 +63,14 @@ SETTINGS = MAPIClientSettings() # type: ignore +hdlr = logging.StreamHandler() +fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +hdlr.setFormatter(fmt) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(hdlr) + class _DictLikeAccess(BaseModel): """Define a pydantic mix-in which permits dict-like access to model fields.""" @@ -553,16 +562,17 @@ def _query_resource( if DeltaTable.is_deltatable(target_path): if self.force_renew: shutil.rmtree(target_path) - warnings.warn( - f"Regenerating {suffix} dataset at {target_path}...", - MPLocalDatasetWarning, + logger.warning( + f"Regenerating {suffix} dataset at {target_path}..." ) os.makedirs(target_path, exist_ok=True) else: - warnings.warn( - f"Dataset for {suffix} already exists at {target_path}, delete or move existing dataset " - "or re-run search query with MPRester(force_renew=True)", - MPLocalDatasetWarning, + logger.warning( + f"Dataset for {suffix} already exists at {target_path}, returning existing dataset." + ) + logger.info( + "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) " + "to refresh local dataset.", ) return { @@ -654,15 +664,20 @@ def _flush(accumulator, group): if accumulator: _flush(accumulator, group + 1) + if pbar is not None: + pbar.close() + + logger.info(f"Dataset for {suffix} written to {target_path}") + logger.info("Converting to DeltaTable...") + convert_to_deltalake(target_path) - warnings.warn( - f"Dataset for {suffix} written to {target_path}. 
It is recommended to optimize " - "the table according to your usage patterns prior to running intensive workloads, " - "see: https://delta-io.github.io/delta-rs/delta-lake-best-practices/#optimizing-table-layout", - MPLocalDatasetWarning, + logger.info( + "Consult the delta-rs and pyarrow documentation for advanced usage: " + "delta-io.github.io/delta-rs/, arrow.apache.org/docs/python/" ) + return { "data": MPDataset( path=target_path, @@ -1537,7 +1552,3 @@ class MPRestError(Exception): class MPRestWarning(Warning): """Raised when a query is malformed but interpretable.""" - - -class MPLocalDatasetWarning(Warning): - """Raised when unrecoverable actions are performed on a local dataset.""" From b2a832f9ed10784211750a4f46797ae1f1b53438 Mon Sep 17 00:00:00 2001 From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:23:05 -0800 Subject: [PATCH 26/26] lint --- mp_api/client/core/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 13f7a3b46..81e80bd88 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -677,7 +677,6 @@ def _flush(accumulator, group): "delta-io.github.io/delta-rs/, arrow.apache.org/docs/python/" ) - return { "data": MPDataset( path=target_path,
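
Taken together, this series moves full downloads for delta-backed resters off the REST pagination path and onto Delta tables in S3, with results cached locally and wrapped in MPDataset. A minimal sketch of the intended client-side workflow follows; it assumes a valid MP_API_KEY in the environment, that the task rester is reachable as mpr.materials.tasks via the usual suffix wiring, and that an unfiltered search still routes through the S3 branch of BaseRester._query_resource:

    from mp_api.client import MPRester

    with MPRester() as mpr:
        # Trajectories are now read from the Delta table at
        # s3://materialsproject-parsed/core/trajectories/ and returned as a
        # pymatgen trajectory dict (cf. the updated test_get_trajectories).
        traj = mpr.materials.tasks.get_trajectory("mp-149")
        assert traj["@module"] == "pymatgen.core.trajectory"

        # An unfiltered search on a delta-backed rester streams the table into
        # the local cache (~/mp_datasets by default, zstd parquet flushed once
        # roughly 2.75 GiB of record batches accumulate) and returns an
        # MPDataset rather than a list of documents.
        docs = mpr.materials.tasks.search()

        # MPDataset supports len(), integer indexing, slicing, and row-wise
        # iteration; rows are validated into CoreTaskDoc on access when
        # use_document_model is True.
        first = docs[0]
        for doc in docs:
            break

        # Arrow-level handles are exposed for heavier analytics.
        arrow_ds = docs.pyarrow_dataset   # pyarrow.dataset.Dataset
        delta_tbl = docs.delta_table      # deltalake.DeltaTable

Returning a lazy MPDataset rather than materializing every document keeps memory bounded: only the rows actually accessed are converted to Python objects, and repeated runs reuse the cached table unless force_renew=True is passed to MPRester.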