From ccb2b9eea7cf165ad0e615cd1bfeb990c857222b Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 12:52:01 +0000 Subject: [PATCH 1/2] #5: use get_constraints instead of get_catalogs in vizier-v2 plugin --- app/discover.py | 4 +--- app/upload.py | 9 ++++++++- cli.py | 12 +++++------- plugins/csv_batched.py | 4 +--- plugins/fits.py | 4 +--- plugins/vizier.py | 36 +++++++++--------------------------- plugins/vizier_v2.py | 39 ++++++++++++++------------------------- pyproject.toml | 4 ++++ tests/test_upload.py | 32 ++++++++------------------------ 9 files changed, 51 insertions(+), 93 deletions(-) diff --git a/app/discover.py b/app/discover.py index 0e8ced3..ef4f839 100644 --- a/app/discover.py +++ b/app/discover.py @@ -21,9 +21,7 @@ def discover_plugins(dir: str) -> dict[str, type[interface.UploaderPlugin]]: spec.loader.exec_module(module) if not hasattr(module, "plugin"): - log.logger.warn( - "python file has no declared plugin", filename=str(file_path) - ) + log.logger.warn("python file has no declared plugin", filename=str(file_path)) continue plugin_class = getattr(module, "plugin") diff --git a/app/upload.py b/app/upload.py index 9c4dc63..6e0f1f2 100644 --- a/app/upload.py +++ b/app/upload.py @@ -1,3 +1,4 @@ +import math from typing import Any import click @@ -37,6 +38,12 @@ def handle_call[T: Any](response: types.Response[T | models.HTTPValidationError] return response.parsed +def sanitize_value(val: Any) -> Any: + if isinstance(val, float) and math.isnan(val): + return None + return val + + def _upload( plugin: interface.UploaderPlugin, client: adminapi.AuthenticatedClient, @@ -88,7 +95,7 @@ def _upload( for _, row in data.iterrows(): item = models.AddDataRequestDataItem() for col in data.columns: - item[col] = row[col] + item[col] = sanitize_value(row[col]) request_data.append(item) _ = handle_call( diff --git a/cli.py b/cli.py index 2953f8e..5d3cba3 100644 --- a/cli.py +++ b/cli.py @@ -55,7 +55,9 @@ def discover(plugin_dir: str) -> None: 
table_name_descr = "Table name is a primary identifier of the table in HyperLEDA. It usually is a machine-readable string that will later be user to do any alterations to the table. Example: sdss_dr12." -table_description_descr = "Description of the table is a human-readable string that can later be used for searching of viewing the table." +table_description_descr = ( + "Description of the table is a human-readable string that can later be used for searching of viewing the table." +) bibcode_descr = "Bibcode is an identifier for the publication from the NASA ADS system https://ui.adsabs.harvard.edu/. It allows for easy search of the publication throughout a range of different sources." pub_name_descr = "Name of the internal source. Can be a short description that represents where the data comes from." pub_authors_descr = "Comma-separated list of authors of the internal source." @@ -108,9 +110,7 @@ def upload( try: default_table_name = plugin.get_table_name() except Exception: - app.logger.warning( - "failed to get default table name from plugin", plugin=plugin - ) + app.logger.warning("failed to get default table name from plugin", plugin=plugin) table_name = question( "Enter table name", @@ -207,9 +207,7 @@ def upload( click.echo(parameter("Table type", table_type)) if not auto_proceed: - auto_proceed = question( - "Proceed? (y,n)", default="y", transformer=lambda s: s == "y" - ) + auto_proceed = question("Proceed? 
(y,n)", default="y", transformer=lambda s: s == "y") if auto_proceed: app.upload( diff --git a/plugins/csv_batched.py b/plugins/csv_batched.py index baecf57..dadcd47 100644 --- a/plugins/csv_batched.py +++ b/plugins/csv_batched.py @@ -37,9 +37,7 @@ def prepare(self) -> None: with open(self.filename) as f: self._total_chunks = sum(1 for _ in f) - 1 - self._total_chunks = ( - self._total_chunks + self._chunk_size - 1 - ) // self._chunk_size + self._total_chunks = (self._total_chunks + self._chunk_size - 1) // self._chunk_size self._current_chunk = 0 def get_schema(self) -> list[models.ColumnDescription]: diff --git a/plugins/fits.py b/plugins/fits.py index 0b70799..2725a63 100644 --- a/plugins/fits.py +++ b/plugins/fits.py @@ -38,9 +38,7 @@ def prepare(self) -> None: raise ValueError(f"HDU {self.hdu_index} is not a binary table") self._schema = self._table.columns - self._total_batches = ( - len(self._table.data) + self._batch_size - 1 - ) // self._batch_size + self._total_batches = (len(self._table.data) + self._batch_size - 1) // self._batch_size self._current_batch = 0 def get_schema(self) -> list[models.ColumnDescription]: diff --git a/plugins/vizier.py b/plugins/vizier.py index e02d349..f24a6b2 100644 --- a/plugins/vizier.py +++ b/plugins/vizier.py @@ -76,24 +76,18 @@ def _write_schema_cache(self, catalog_name: str, table_name: str): app.logger.debug("wrote cache", location=str(cache_filename)) - def _obtain_cache_path( - self, type_path: str, catalog_name: str, table_name: str - ) -> pathlib.Path: + def _obtain_cache_path(self, type_path: str, catalog_name: str, table_name: str) -> pathlib.Path: filename = f"{_get_filename(catalog_name, table_name)}.vot" path = pathlib.Path(self.cache_path) / type_path / filename path.parent.mkdir(parents=True, exist_ok=True) return path - def _get_schema_from_cache( - self, catalog_name: str, table_name: str - ) -> tree.VOTableFile: + def _get_schema_from_cache(self, catalog_name: str, table_name: str) -> tree.VOTableFile: 
cache_filename = self._obtain_cache_path("schemas", catalog_name, table_name) return votable.parse(str(cache_filename)) - def _get_table_from_cache( - self, catalog_name: str, table_name: str - ) -> astropy.table.Table: + def _get_table_from_cache(self, catalog_name: str, table_name: str) -> astropy.table.Table: cache_filename = self._obtain_cache_path("tables", catalog_name, table_name) return astropy.table.Table.read(cache_filename, format="votable") @@ -101,9 +95,7 @@ def prepare(self) -> None: pass def get_schema(self) -> list[models.ColumnDescription]: - if not self._obtain_cache_path( - "schemas", self.catalog_name, self.table_name - ).exists(): + if not self._obtain_cache_path("schemas", self.catalog_name, self.table_name).exists(): app.logger.debug("did not hit cache for the schema, downloading") self._write_schema_cache(self.catalog_name, self.table_name) @@ -122,9 +114,7 @@ def get_schema(self) -> list[models.ColumnDescription]: ] def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]: - if not self._obtain_cache_path( - "tables", self.catalog_name, self.table_name - ).exists(): + if not self._obtain_cache_path("tables", self.catalog_name, self.table_name).exists(): app.logger.debug("did not hit cache for the table, downloading") self._write_table_cache(self.catalog_name, self.table_name) @@ -154,9 +144,7 @@ def get_table_name(self) -> str: def get_bibcode(self) -> str: self.get_schema() schema = self._get_schema_from_cache(self.catalog_name, self.table_name) - bibcode_info = next( - filter(lambda info: info.name == "cites", schema.resources[0].infos) - ) + bibcode_info = next(filter(lambda info: info.name == "cites", schema.resources[0].infos)) return bibcode_info.value.split(":")[1] def get_description(self) -> str: @@ -173,9 +161,7 @@ def _get_filename(catalog_name: str, table_name: str) -> str: return f"{_sanitize_filename(catalog_name)}_{_sanitize_filename(table_name)}" -def _get_columns( - client: vizier.VizierClass, catalog_name: str, 
table_name: str -) -> list[str]: +def _get_columns(client: vizier.VizierClass, catalog_name: str, table_name: str) -> list[str]: catalogs = client.get_catalogs(catalog_name) # type: ignore meta = None @@ -190,9 +176,7 @@ def _get_columns( return meta.colnames -def _download_table( - table_name: str, columns: list[str], max_rows: int | None = None -) -> str: +def _download_table(table_name: str, columns: list[str], max_rows: int | None = None) -> str: out_max = "unlimited" if max_rows is None else max_rows payload = [ @@ -216,9 +200,7 @@ def _download_table( "Content-Type": "application/x-www-form-urlencoded", } - response = requests.request( - http.HTTPMethod.POST, VIZIER_URL, data=data, headers=headers - ) + response = requests.request(http.HTTPMethod.POST, VIZIER_URL, data=data, headers=headers) return response.text diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index 910827c..441fa7b 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -3,6 +3,7 @@ from collections.abc import Generator from typing import final +import numpy as np import pandas from astropy import table from astroquery import utils, vizier @@ -15,35 +16,27 @@ def _sanitize_filename(string: str) -> str: return string.replace("/", "_") -def dtype_to_datatype(dtype) -> models.DatatypeEnum: - # Accept both dtypes and strings +def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: dtype_str = str(dtype).lower() - # Typical mappings - if any( - dtype_str.startswith(x) - for x in ("str", "unicode", " pathlib.Path: + def _obtain_cache_path(self, catalog_name: str, row_num: int | None = None) -> pathlib.Path: filename = f"{_sanitize_filename(catalog_name)}.vot" if row_num is not None: filename = f"{_sanitize_filename(catalog_name)}_rows_{row_num}.vot" @@ -64,9 +55,7 @@ def _obtain_cache_path( path.parent.mkdir(parents=True, exist_ok=True) return path - def _write_catalog_cache( - self, catalog_name: str, row_num: int | None = None - ) -> None: + def 
_write_catalog_cache(self, catalog_name: str, row_num: int | None = None) -> None: app.logger.info( "downloading catalog from Vizier", catalog_name=catalog_name, @@ -76,7 +65,7 @@ def _write_catalog_cache( if row_num is not None: client = vizier.Vizier() client.ROW_LIMIT = row_num - catalogs: utils.TableList = client.get_catalogs(catalog_name) # pyright: ignore[reportAttributeAccessIssue] + catalogs: utils.TableList = client.query_constraints(catalog=catalog_name) # pyright: ignore[reportAttributeAccessIssue] if not catalogs: raise ValueError("catalog not found") @@ -109,7 +98,7 @@ def __init__( catalog_name: str, table_name: str, cache_path: str = ".vizier_cache/", - batch_size: int = 500, + batch_size: int = 100, ): self.catalog_name = catalog_name self.table_name = table_name diff --git a/pyproject.toml b/pyproject.toml index 18def46..88f840b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "astropy>=7.1.0", "openapi-python-client>=0.27.1", "pandas>=2.3.3", + "numpy>=2.3.4", ] [tool.pytest.ini_options] @@ -25,3 +26,6 @@ dev = [ "pandas-stubs>=2.2.3.250308", "pytest>=8.4.2", ] + +[tool.ruff] +line-length = 120 diff --git a/tests/test_upload.py b/tests/test_upload.py index 78d3537..b5ca69b 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -49,24 +49,16 @@ def mock_response[T: Any](resp: T) -> types.Response[T]: @patch("app.upload.create_source") @patch("app.upload.create_table") @patch("app.upload.add_data") -def test_upload_with_csv_plugin( - mock_add_data, mock_create_table, mock_create_source, mock_client -): +def test_upload_with_csv_plugin(mock_add_data, mock_create_table, mock_create_source, mock_client): mock_create_source_response = models.APIOkResponseCreateSourceResponse( data=models.CreateSourceResponse(code="test_bibcode") ) mock_create_source.sync.return_value = mock_create_source_response - mock_create_table_response = models.APIOkResponseCreateTableResponse( - data=models.CreateTableResponse(id=1) - ) 
- mock_create_table.sync_detailed.return_value = mock_response( - mock_create_table_response - ) + mock_create_table_response = models.APIOkResponseCreateTableResponse(data=models.CreateTableResponse(id=1)) + mock_create_table.sync_detailed.return_value = mock_response(mock_create_table_response) - mock_add_data_response = models.APIOkResponseAddDataResponse( - data=models.AddDataResponse() - ) + mock_add_data_response = models.APIOkResponseAddDataResponse(data=models.AddDataResponse()) mock_add_data.sync_detailed.return_value = mock_response(mock_add_data_response) plugin = CSVPlugin("tests/test_csv.csv") @@ -90,22 +82,14 @@ def test_upload_with_csv_plugin( @patch("app.upload.create_source") @patch("app.upload.create_table") -def test_plugin_stop_called_on_error( - mock_create_table, mock_create_source, mock_client -): +def test_plugin_stop_called_on_error(mock_create_table, mock_create_source, mock_client): mock_create_source_response = models.APIOkResponseCreateSourceResponse( data=models.CreateSourceResponse(code="test_bibcode") ) - mock_create_source.sync_detailed.return_value = mock_response( - mock_create_source_response - ) + mock_create_source.sync_detailed.return_value = mock_response(mock_create_source_response) - mock_create_table_response = models.APIOkResponseCreateTableResponse( - data=models.CreateTableResponse(id=1) - ) - mock_create_table.sync_detailed.return_value = mock_response( - mock_create_table_response - ) + mock_create_table_response = models.APIOkResponseCreateTableResponse(data=models.CreateTableResponse(id=1)) + mock_create_table.sync_detailed.return_value = mock_response(mock_create_table_response) plugin = StubPlugin(should_raise=True) From f1d00ed34b9f0361af7cfbf14354784553e3bbe1 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 13:10:44 +0000 Subject: [PATCH 2/2] add constraint usage to vizier client --- plugins/vizier_v2.py | 55 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 12 
deletions(-) diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index 441fa7b..ab2803b 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -13,7 +13,15 @@ def _sanitize_filename(string: str) -> str: - return string.replace("/", "_") + return ( + string.replace("/", "_") + .replace("&", "_and_") + .replace(">", "_gt_") + .replace("<", "_lt_") + .replace("=", "_eq_") + .replace(" ", "_") + .replace("!", "_not_") + ) def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: @@ -47,38 +55,54 @@ def __init__(self, cache_path: str = ".vizier_cache/"): self._client = vizier.Vizier() self._client.ROW_LIMIT = -1 - def _obtain_cache_path(self, catalog_name: str, row_num: int | None = None) -> pathlib.Path: - filename = f"{_sanitize_filename(catalog_name)}.vot" + def _obtain_cache_path( + self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None + ) -> pathlib.Path: + filename = f"{catalog_name}.vot" if row_num is not None: - filename = f"{_sanitize_filename(catalog_name)}_rows_{row_num}.vot" + filename = f"{catalog_name}_rows_{row_num}.vot" + if constraints: + sorted_constraints = sorted(constraints.items()) + constraint_str = "_".join(f"{k}_{v}" for k, v in sorted_constraints) + filename = f"{catalog_name}_constraints_{constraint_str}.vot" + + filename = _sanitize_filename(filename) path = pathlib.Path(self.cache_path) / "catalogs" / filename path.parent.mkdir(parents=True, exist_ok=True) return path - def _write_catalog_cache(self, catalog_name: str, row_num: int | None = None) -> None: + def _write_catalog_cache( + self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None + ) -> None: app.logger.info( "downloading catalog from Vizier", catalog_name=catalog_name, row_num=row_num, + constraints=constraints, ) client = self._client if row_num is not None: client = vizier.Vizier() client.ROW_LIMIT = row_num - catalogs: utils.TableList = 
client.query_constraints(catalog=catalog_name) # pyright: ignore[reportAttributeAccessIssue] + query_kwargs = {"catalog": catalog_name} + if constraints: + query_kwargs.update(constraints) + catalogs: utils.TableList = client.query_constraints(**query_kwargs) # pyright: ignore[reportAttributeAccessIssue] if not catalogs: raise ValueError("catalog not found") - cache_filename = self._obtain_cache_path(catalog_name, row_num) + cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints) catalogs[0].write(str(cache_filename), format="votable") app.logger.debug("wrote catalog cache", location=str(cache_filename)) - def get_table(self, catalog_name: str, row_num: int | None = None) -> table.Table: - cache_path = self._obtain_cache_path(catalog_name, row_num) + def get_table( + self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None + ) -> table.Table: + cache_path = self._obtain_cache_path(catalog_name, row_num, constraints) if not cache_path.exists(): app.logger.debug("did not hit cache for the catalog, downloading") - self._write_catalog_cache(catalog_name, row_num) + self._write_catalog_cache(catalog_name, row_num, constraints) return table.Table.read(cache_path, format="votable") @@ -97,9 +121,15 @@ def __init__( self, catalog_name: str, table_name: str, + *constraints: str, cache_path: str = ".vizier_cache/", - batch_size: int = 100, + batch_size: int = 10, ): + if len(constraints) % 2 != 0: + raise ValueError("constraints must be provided in pairs (column, constraint_value)") + self.constraints: dict[str, str] = {} + for i in range(0, len(constraints), 2): + self.constraints[constraints[i]] = constraints[i + 1] self.catalog_name = catalog_name self.table_name = table_name self.batch_size = batch_size @@ -141,7 +171,8 @@ def get_schema(self) -> list[models.ColumnDescription]: return result def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]: - t = self.client.get_table(self.table_name) + constraints 
= self.constraints if self.constraints else None + t = self.client.get_table(self.table_name, constraints=constraints) total_rows = len(t) app.logger.info("uploading table", total_rows=total_rows)