def _sanitize_filename(string: str) -> str:
    """Return ``string`` with every ``/`` replaced by ``_``."""
    # Splitting on the separator and re-joining is equivalent to a global
    # replace, including for empty input and leading/trailing/repeated "/".
    return "_".join(string.split("/"))
@final
class VizierV2Plugin(
    app.UploaderPlugin,
    app.DefaultTableNamer,
    app.BibcodeProvider,
    app.DescriptionProvider,
):
    """Plugin that serves a table obtained through ``CachedVizierClient``.

    Provides the table name, bibcode, description, column schema and the
    table rows in batches.
    """

    def __init__(
        self,
        catalog_name: str,
        table_name: str,
        cache_path: str = ".vizier_cache/",
        batch_size: int = 500,
    ):
        """Store the configuration and create the caching client.

        Args:
            catalog_name: Catalog identifier used for metadata lookups
                (bibcode and description).
            table_name: Identifier passed to the client to fetch the table.
            cache_path: Directory handed to ``CachedVizierClient``.
            batch_size: Maximum number of rows per frame yielded by
                ``get_data``.
        """
        self.catalog_name = catalog_name
        self.table_name = table_name
        self.batch_size = batch_size
        self.client = CachedVizierClient(cache_path=cache_path)

    def prepare(self) -> None:
        """Do nothing."""

    def get_table_name(self) -> str:
        """Return the ``ID`` entry of the table metadata as a string.

        Only a single row is requested (``row_num=1``) because just the
        table metadata is read here.

        Raises:
            RuntimeError: If the table has no metadata or the metadata has
                no ``ID`` entry.
        """
        t = self.client.get_table(self.table_name, row_num=1)
        meta = getattr(t, "meta", None)
        # A missing "ID" key is reported the same way as missing metadata
        # instead of leaking a bare KeyError to the caller.
        if meta is None or "ID" not in meta:
            raise RuntimeError("unable to get table name")

        return str(meta["ID"])

    def get_bibcode(self) -> str:
        """Return the first ``origin_article`` entry of the catalog metadata."""
        resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
        return resp["origin_article"][0]

    def get_description(self) -> str:
        """Return the first ``title`` entry of the catalog metadata."""
        resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
        return resp["title"][0]

    def get_schema(self) -> list[models.ColumnDescription]:
        """Build one ``ColumnDescription`` per column of the table.

        ``ucd`` and ``unit`` fall back to ``types.UNSET`` when the column
        has no such information.
        """
        t = self.client.get_table(self.table_name)

        return [
            models.ColumnDescription(
                name=col.name,
                data_type=dtype_to_datatype(col.dtype),
                ucd=col.meta.get("ucd", types.UNSET),
                description=col.description,
                unit=str(col.unit) if col.unit else types.UNSET,
            )
            for col in t.columns.values()
        ]

    def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]:
        """Yield the table in batches together with the upload progress.

        Yields:
            Tuples ``(frame, progress)``: ``frame`` holds at most
            ``batch_size`` rows and ``progress`` is the number of rows
            yielded so far divided by the total row count.
        """
        t = self.client.get_table(self.table_name)

        total_rows = len(t)
        app.logger.info("uploading table", total_rows=total_rows)

        offset = 0
        for batch in itertools.batched(t, self.batch_size, strict=False):  # pyright: ignore[reportArgumentType]
            offset += len(batch)

            # Values comparing equal to "--" are dropped from each row.
            # NOTE(review): presumably this is how missing/masked cells
            # show up here - confirm against the table reader.
            rows = [
                {k: v for k, v in dict(row).items() if v != "--"}
                for row in batch
            ]

            yield pandas.DataFrame(rows), offset / total_rows

    def stop(self) -> None:
        """Do nothing."""