diff --git a/app/__init__.py b/app/__init__.py
index eb82149..67946c6 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,12 +1,13 @@
+from app.discover import discover_plugins
 from app.interface import (
-    UploaderPlugin,
-    DefaultTableNamer,
     BibcodeProvider,
+    DefaultTableNamer,
     DescriptionProvider,
+    UploaderPlugin,
 )
-from app.discover import discover_plugins
-from app.upload import upload
 from app.log import logger
+from app.tap import Constraint, TAPRepository
+from app.upload import upload
 
 __all__ = [
     "UploaderPlugin",
@@ -16,4 +17,6 @@
     "discover_plugins",
     "upload",
     "logger",
+    "TAPRepository",
+    "Constraint",
 ]
diff --git a/app/tap.py b/app/tap.py
new file mode 100644
index 0000000..389dc06
--- /dev/null
+++ b/app/tap.py
@@ -0,0 +1,59 @@
+from dataclasses import dataclass
+
+from astropy import table
+from pyvo import registry
+
+import app
+
+TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync"
+
+
+@dataclass
+class Constraint:
+    column: str
+    operator: str
+    value: str
+
+
+class TAPRepository:
+    def __init__(self, endpoint: str = TAP_ENDPOINT):
+        self.tap_endpoint = endpoint
+
+    def _quote_column(self, column: str) -> str:
+        if any(char in column for char in "()[]."):
+            return f'"{column}"'
+        return column
+
+    def _build_where_clause(self, constraints: list[Constraint]) -> str:
+        if not constraints:
+            return ""
+
+        conditions = []
+        for constraint in constraints:
+            quoted_column = self._quote_column(constraint.column)
+            conditions.append(f"{quoted_column} {constraint.operator} {constraint.value}")
+
+        return " WHERE " + " AND ".join(conditions)
+
+    def _build_order_by_clause(self, order_by: str | None) -> str:
+        if not order_by:
+            return ""
+        return f" ORDER BY {order_by}"
+
+    def query(
+        self,
+        table_name: str,
+        constraints: list[Constraint] | None = None,
+        order_by: str | None = None,
+        limit: int | None = None,
+    ) -> table.Table:
+        where_clause = self._build_where_clause(constraints) if constraints else ""
+        order_by_clause = self._build_order_by_clause(order_by)
+        limit_clause = f" TOP {limit}" if limit else ""
+
+        query = f'SELECT{limit_clause} *\nFROM "{table_name}"{where_clause}{order_by_clause}'
+
+        app.logger.info("Running TAP query", query=query)
+        data = registry.regtap.RegistryQuery(self.tap_endpoint, query)
+        result = data.execute()
+        return result.to_table()
diff --git a/makefile b/makefile
index cb8dc5e..c34fab3 100644
--- a/makefile
+++ b/makefile
@@ -8,6 +8,7 @@ check:
 
 fix:
 	uvx ruff format .
+	uvx ruff check --fix
 
 test:
 	uv run pytest tests
diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py
index ef856ba..2fcd972 100644
--- a/plugins/vizier_v2.py
+++ b/plugins/vizier_v2.py
@@ -1,13 +1,9 @@
-import itertools
-import pathlib
 from collections.abc import Generator
 from typing import final
 
 import numpy as np
 import pandas
-from astropy import table
 from astroquery import vizier
-from pyvo import registry
 
 import app
 from app.gen.client.adminapi import models, types
@@ -25,21 +21,6 @@ def _sanitize_filename(string: str) -> str:
     )
 
 
-def _build_where_clause(constraints: list[tuple[str, str, str]]) -> str:
-    if not constraints:
-        return ""
-
-    conditions = []
-    for column, sign, value in constraints:
-        if any(char in column for char in "()[]."):
-            quoted_column = f'"{column}"'
-        else:
-            quoted_column = column
-        conditions.append(f"{quoted_column} {sign} {value}")
-
-    return " WHERE " + " AND ".join(conditions)
-
-
 def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:
     dtype_str = str(dtype).lower()
 
@@ -65,70 +46,6 @@
     return models.DatatypeEnum.STRING
 
 
-class CachedVizierClient:
-    TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync"
-
-    def __init__(self, cache_path: str = ".vizier_cache/"):
-        self.cache_path = cache_path
-
-    def _obtain_cache_path(
-        self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
-    ) -> pathlib.Path:
-        filename = f"{catalog_name}.vot"
-        if row_num is not None:
-            filename = f"{catalog_name}_rows_{row_num}.vot"
-        if constraints:
-            sorted_constraints = sorted(constraints)
-            constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints)
-            filename = f"{catalog_name}_constraints_{constraint_str}.vot"
-
-        filename = _sanitize_filename(filename)
-        path = pathlib.Path(self.cache_path) / "catalogs" / filename
-        path.parent.mkdir(parents=True, exist_ok=True)
-        return path
-
-    def _write_catalog_cache(
-        self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
-    ) -> None:
-        app.logger.info(
-            "downloading catalog from Vizier",
-            catalog_name=catalog_name,
-            row_num=row_num,
-            constraints=constraints,
-        )
-
-        where_clause = _build_where_clause(constraints) if constraints else ""
-
-        if row_num is not None:
-            select_clause = f"SELECT TOP {row_num} *"
-        else:
-            select_clause = "SELECT *"
-
-        query = f'{select_clause}\nFROM "{catalog_name}"{where_clause}'
-
-        app.logger.info("Running query", query=query)
-        data = registry.regtap.RegistryQuery(self.TAP_ENDPOINT, query)
-        result = data.execute()
-        tbl = result.to_table()
-
-        cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints)
-        tbl.write(str(cache_filename), format="votable")
-        app.logger.debug("wrote catalog cache", location=str(cache_filename))
-
-    def get_table(
-        self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
-    ) -> table.Table:
-        cache_path = self._obtain_cache_path(catalog_name, row_num, constraints)
-        if not cache_path.exists():
-            app.logger.debug("did not hit cache for the catalog, downloading")
-            self._write_catalog_cache(catalog_name, row_num, constraints)
-
-        return table.Table.read(cache_path, format="votable")
-
-    def get_catalog_metadata(self, catalog: str) -> dict:
-        return vizier.Vizier().get_catalog_metadata(catalog=catalog)
-
-
 @final
 class VizierV2Plugin(
     app.UploaderPlugin,
@@ -140,19 +57,23 @@ def __init__(
         self,
         catalog_name: str,
         table_name: str,
+        index_column: str,
         *constraints: str,
         cache_path: str = ".vizier_cache/",
-        batch_size: int = 10,
+        batch_size: int = 1000,
     ):
         if len(constraints) % 3 != 0:
             raise ValueError("constraints must be provided in groups of three (column, sign, value)")
-        self.constraints: list[tuple[str, str, str]] = []
+        self.constraints: list[app.Constraint] = []
         for i in range(0, len(constraints), 3):
-            self.constraints.append((constraints[i], constraints[i + 1], constraints[i + 2]))
+            self.constraints.append(
+                app.Constraint(column=constraints[i], operator=constraints[i + 1], value=constraints[i + 2])
+            )
         self.catalog_name = catalog_name
         self.table_name = table_name
+        self.index_column = index_column
         self.batch_size = batch_size
-        self.client = CachedVizierClient(cache_path=cache_path)
+        self.repo = app.TAPRepository()
 
     def prepare(self) -> None:
         pass
@@ -161,15 +82,15 @@ def get_table_name(self) -> str:
         return _sanitize_filename(self.table_name)
 
     def get_bibcode(self) -> str:
-        resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
+        resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name)
         return resp["origin_article"][0]
 
     def get_description(self) -> str:
-        resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
+        resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name)
         return resp["title"][0]
 
     def get_schema(self) -> list[models.ColumnDescription]:
-        t = self.client.get_table(self.table_name)
+        t = self.repo.query(self.table_name, limit=1)
 
         result = []
         for _, col in t.columns.items():
@@ -186,22 +107,49 @@ def get_schema(self) -> list[models.ColumnDescription]:
         return result
 
     def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]:
-        constraints = self.constraints if self.constraints else None
-        t = self.client.get_table(self.table_name, constraints=constraints)
+        last_index_value = None
+        total_rows_processed = 0
+        batch_number = 0
+
+        while True:
+            constraints = list(self.constraints) if self.constraints else []
+
+            if last_index_value is not None:
+                constraints.append(app.Constraint(column=self.index_column, operator=">", value=str(last_index_value)))
 
-        total_rows = len(t)
-        app.logger.info("uploading table", total_rows=total_rows)
+            quoted_index_column = (
+                f'"{self.index_column}"' if any(char in self.index_column for char in "()[].") else self.index_column
+            )
+            order_by = f"{quoted_index_column} ASC"
+            t = self.repo.query(
+                self.table_name,
+                constraints=constraints if constraints else None,
+                order_by=order_by,
+                limit=self.batch_size,
+            )
 
-        offset = 0
-        for batch in itertools.batched(t, self.batch_size, strict=False):  # pyright: ignore[reportArgumentType]
-            offset += len(batch)
+            if len(t) == 0:
+                break
 
             rows = []
-            for row in batch:
+            for row in t:
                 row_dict = {k: v for k, v in dict(row).items() if v != "--"}
                 rows.append(row_dict)
+                last_index_value = row[self.index_column]
+
+            total_rows_processed += len(rows)
+            batch_number += 1
+
+            app.logger.info(
+                "uploading batch",
+                batch_number=batch_number,
+                rows_in_batch=len(rows),
+                total_rows_processed=total_rows_processed,
+            )
+
+            yield pandas.DataFrame(rows), 0.0
 
-            yield pandas.DataFrame(rows), offset / total_rows
+        app.logger.info("finished uploading table", total_rows=total_rows_processed)
 
     def stop(self) -> None:
         pass
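
A minimal sketch (not part of the diff) of how the new TAPRepository/Constraint
API might be driven; the catalog and column names below are hypothetical:

    import app

    repo = app.TAPRepository()

    # First page of a keyset-paginated scan: filter, order by an index
    # column, and cap the page size, mirroring VizierV2Plugin.get_data().
    page = repo.query(
        "J/ApJS/example/table",  # hypothetical VizieR table name
        constraints=[app.Constraint(column="Vmag", operator="<", value="12")],
        order_by="recno ASC",
        limit=1000,
    )

    # `page` is an astropy.table.Table; a follow-up query would pass the last
    # row's index value back as a ">" constraint to fetch the next page.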