From 3fe84491b401cf2f0bf6b29eb628eb6d58138a35 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 14:34:52 +0000 Subject: [PATCH 1/5] #5: batched download --- app/__init__.py | 11 +++++--- app/tap.py | 57 ++++++++++++++++++++++++++++++++++++++++ plugins/vizier_v2.py | 62 +++++++++++++++++++------------------------- 3 files changed, 90 insertions(+), 40 deletions(-) create mode 100644 app/tap.py diff --git a/app/__init__.py b/app/__init__.py index eb82149..67946c6 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,12 +1,13 @@ +from app.discover import discover_plugins from app.interface import ( - UploaderPlugin, - DefaultTableNamer, BibcodeProvider, + DefaultTableNamer, DescriptionProvider, + UploaderPlugin, ) -from app.discover import discover_plugins -from app.upload import upload from app.log import logger +from app.tap import Constraint, TAPRepository +from app.upload import upload __all__ = [ "UploaderPlugin", @@ -16,4 +17,6 @@ "discover_plugins", "upload", "logger", + "TAPRepository", + "Constraint", ] diff --git a/app/tap.py b/app/tap.py new file mode 100644 index 0000000..f79c0c5 --- /dev/null +++ b/app/tap.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +from astropy import table +from pyvo import registry + +import app + + +@dataclass +class Constraint: + column: str + operator: str + value: str + + +class TAPRepository: + TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync" + + def __init__(self, tap_endpoint: str = TAP_ENDPOINT): + self.tap_endpoint = tap_endpoint + + def _quote_column(self, column: str) -> str: + if any(char in column for char in "()[]."): + return f'"{column}"' + return column + + def _build_where_clause(self, constraints: list[Constraint]) -> str: + if not constraints: + return "" + + conditions = [] + for constraint in constraints: + quoted_column = self._quote_column(constraint.column) + conditions.append(f"{quoted_column} {constraint.operator} {constraint.value}") + + return " WHERE " + " AND ".join(conditions) + + def _build_order_by_clause(self, order_by: str | None) -> str: + if not order_by: + return "" + return f" ORDER BY {order_by}" + + def query( + self, + table_name: str, + constraints: list[Constraint] | None = None, + order_by: str | None = None, + ) -> table.Table: + where_clause = self._build_where_clause(constraints) if constraints else "" + order_by_clause = self._build_order_by_clause(order_by) + + query = f'SELECT *\nFROM "{table_name}"{where_clause}{order_by_clause}' + + app.logger.info("Running TAP query", query=query) + data = registry.regtap.RegistryQuery(self.tap_endpoint, query) + result = data.execute() + return result.to_table() diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index ef856ba..22cbdf5 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -7,7 +7,6 @@ import pandas from astropy import table from astroquery import vizier -from pyvo import registry import app from app.gen.client.adminapi import models, types @@ -25,21 +24,6 @@ def _sanitize_filename(string: str) -> str: ) -def _build_where_clause(constraints: list[tuple[str, str, str]]) -> str: - if not constraints: - return "" - - conditions = [] - for column, sign, value in constraints: - if any(char in column for char in "()[]."): - quoted_column = f'"{column}"' - else: - quoted_column = column - conditions.append(f"{quoted_column} {sign} {value}") - - return " WHERE " + " AND ".join(conditions) - - def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: dtype_str = str(dtype).lower() @@ -66,19 +50,26 @@ def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: class CachedVizierClient: - TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync" - def __init__(self, cache_path: str = ".vizier_cache/"): self.cache_path = cache_path + self.tap_repository = app.TAPRepository() + + def _constraints_to_tuples(self, constraints: list[app.Constraint] | None) -> list[tuple[str, str, str]] | None: + if not constraints: + return None + return [(c.column, c.operator, c.value) for c in constraints] def _obtain_cache_path( - self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None + self, + catalog_name: str, + row_num: int | None = None, + constraints: list[app.Constraint] | None = None, ) -> pathlib.Path: filename = f"{catalog_name}.vot" if row_num is not None: filename = f"{catalog_name}_rows_{row_num}.vot" if constraints: - sorted_constraints = sorted(constraints) + sorted_constraints = sorted(self._constraints_to_tuples(constraints) or []) constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints) filename = f"{catalog_name}_constraints_{constraint_str}.vot" @@ -88,35 +79,32 @@ def _obtain_cache_path( return path def _write_catalog_cache( - self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None + self, + catalog_name: str, + row_num: int | None = None, + constraints: list[app.Constraint] | None = None, ) -> None: app.logger.info( "downloading catalog from Vizier", catalog_name=catalog_name, row_num=row_num, - constraints=constraints, + constraints=self._constraints_to_tuples(constraints), ) - where_clause = _build_where_clause(constraints) if constraints else "" + tbl: table.Table = self.tap_repository.query(catalog_name, constraints=constraints) if row_num is not None: - select_clause = f"SELECT TOP {row_num} *" - else: - select_clause = "SELECT *" - - query = f'{select_clause}\nFROM "{catalog_name}"{where_clause}' - - app.logger.info("Running query", query=query) - data = registry.regtap.RegistryQuery(self.TAP_ENDPOINT, query) - result = data.execute() - tbl = result.to_table() + tbl = table.Table(tbl[:row_num]) cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints) tbl.write(str(cache_filename), format="votable") app.logger.debug("wrote catalog cache", location=str(cache_filename)) def get_table( - self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None + self, + catalog_name: str, + row_num: int | None = None, + constraints: list[app.Constraint] | None = None, ) -> table.Table: cache_path = self._obtain_cache_path(catalog_name, row_num, constraints) if not cache_path.exists(): @@ -146,9 +134,11 @@ def __init__( ): if len(constraints) % 3 != 0: raise ValueError("constraints must be provided in groups of three (column, sign, value)") - self.constraints: list[tuple[str, str, str]] = [] + self.constraints: list[app.Constraint] = [] for i in range(0, len(constraints), 3): - self.constraints.append((constraints[i], constraints[i + 1], constraints[i + 2])) + self.constraints.append( + app.Constraint(column=constraints[i], operator=constraints[i + 1], value=constraints[i + 2]) + ) self.catalog_name = catalog_name self.table_name = table_name self.batch_size = batch_size From b52d3e542f642e889af5f64517910db522e5fa48 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 14:40:43 +0000 Subject: [PATCH 2/5] cleanup --- app/tap.py | 8 ++++---- plugins/vizier_v2.py | 29 +++++++++++++---------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/app/tap.py b/app/tap.py index f79c0c5..95c655b 100644 --- a/app/tap.py +++ b/app/tap.py @@ -5,6 +5,8 @@ import app +TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync" + @dataclass class Constraint: @@ -14,10 +16,8 @@ class Constraint: class TAPRepository: - TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync" - - def __init__(self, tap_endpoint: str = TAP_ENDPOINT): - self.tap_endpoint = tap_endpoint + def __init__(self, endpoint: str = TAP_ENDPOINT): + self.tap_endpoint = endpoint def _quote_column(self, column: str) -> str: if any(char in column for char in "()[]."): diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index 22cbdf5..1dc3264 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -49,15 +49,18 @@ def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: return models.DatatypeEnum.STRING -class CachedVizierClient: +class CachedTAPRepository: def __init__(self, cache_path: str = ".vizier_cache/"): self.cache_path = cache_path - self.tap_repository = app.TAPRepository() + self.repo = app.TAPRepository() - def _constraints_to_tuples(self, constraints: list[app.Constraint] | None) -> list[tuple[str, str, str]] | None: + def _get_filename(self, catalog_name: str, constraints: list[app.Constraint] | None) -> str: if not constraints: - return None - return [(c.column, c.operator, c.value) for c in constraints] + return f"{catalog_name}.vot" + + sorted_constraints = sorted([(c.column, c.operator, c.value) for c in constraints] or []) + constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints) + return f"{catalog_name}_constraints_{constraint_str}.vot" def _obtain_cache_path( self, @@ -69,9 +72,7 @@ def _obtain_cache_path( if row_num is not None: filename = f"{catalog_name}_rows_{row_num}.vot" if constraints: - sorted_constraints = sorted(self._constraints_to_tuples(constraints) or []) - constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints) - filename = f"{catalog_name}_constraints_{constraint_str}.vot" + filename = self._get_filename(catalog_name, constraints) filename = _sanitize_filename(filename) path = pathlib.Path(self.cache_path) / "catalogs" / filename @@ -88,10 +89,9 @@ def _write_catalog_cache( "downloading catalog from Vizier", catalog_name=catalog_name, row_num=row_num, - constraints=self._constraints_to_tuples(constraints), ) - tbl: table.Table = self.tap_repository.query(catalog_name, constraints=constraints) + tbl: table.Table = self.repo.query(catalog_name, constraints=constraints) if row_num is not None: tbl = table.Table(tbl[:row_num]) @@ -113,9 +113,6 @@ def get_table( return table.Table.read(cache_path, format="votable") - def get_catalog_metadata(self, catalog: str) -> dict: - return vizier.Vizier().get_catalog_metadata(catalog=catalog) - @final class VizierV2Plugin( @@ -142,7 +139,7 @@ def __init__( self.catalog_name = catalog_name self.table_name = table_name self.batch_size = batch_size - self.client = CachedVizierClient(cache_path=cache_path) + self.client = CachedTAPRepository(cache_path=cache_path) def prepare(self) -> None: pass @@ -151,11 +148,11 @@ def get_table_name(self) -> str: return _sanitize_filename(self.table_name) def get_bibcode(self) -> str: - resp = self.client.get_catalog_metadata(catalog=self.catalog_name) + resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name) return resp["origin_article"][0] def get_description(self) -> str: - resp = self.client.get_catalog_metadata(catalog=self.catalog_name) + resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name) return resp["title"][0] def get_schema(self) -> list[models.ColumnDescription]: From b1a958cb95d3a4114f9214776bdc1cc7f545888c Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 14:50:57 +0000 Subject: [PATCH 3/5] remove cache --- app/tap.py | 4 ++- plugins/vizier_v2.py | 71 ++------------------------------------------ 2 files changed, 6 insertions(+), 69 deletions(-) diff --git a/app/tap.py b/app/tap.py index 95c655b..389dc06 100644 --- a/app/tap.py +++ b/app/tap.py @@ -45,11 +45,13 @@ def query( table_name: str, constraints: list[Constraint] | None = None, order_by: str | None = None, + limit: int | None = None, ) -> table.Table: where_clause = self._build_where_clause(constraints) if constraints else "" order_by_clause = self._build_order_by_clause(order_by) + limit_clause = f" TOP {limit}" if limit else "" - query = f'SELECT *\nFROM "{table_name}"{where_clause}{order_by_clause}' + query = f'SELECT{limit_clause} *\nFROM "{table_name}"{where_clause}{order_by_clause}' app.logger.info("Running TAP query", query=query) data = registry.regtap.RegistryQuery(self.tap_endpoint, query) diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index 1dc3264..afb05b9 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -49,71 +49,6 @@ def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum: return models.DatatypeEnum.STRING -class CachedTAPRepository: - def __init__(self, cache_path: str = ".vizier_cache/"): - self.cache_path = cache_path - self.repo = app.TAPRepository() - - def _get_filename(self, catalog_name: str, constraints: list[app.Constraint] | None) -> str: - if not constraints: - return f"{catalog_name}.vot" - - sorted_constraints = sorted([(c.column, c.operator, c.value) for c in constraints] or []) - constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints) - return f"{catalog_name}_constraints_{constraint_str}.vot" - - def _obtain_cache_path( - self, - catalog_name: str, - row_num: int | None = None, - constraints: list[app.Constraint] | None = None, - ) -> pathlib.Path: - filename = f"{catalog_name}.vot" - if row_num is not None: - filename = f"{catalog_name}_rows_{row_num}.vot" - if constraints: - filename = self._get_filename(catalog_name, constraints) - - filename = _sanitize_filename(filename) - path = pathlib.Path(self.cache_path) / "catalogs" / filename - path.parent.mkdir(parents=True, exist_ok=True) - return path - - def _write_catalog_cache( - self, - catalog_name: str, - row_num: int | None = None, - constraints: list[app.Constraint] | None = None, - ) -> None: - app.logger.info( - "downloading catalog from Vizier", - catalog_name=catalog_name, - row_num=row_num, - ) - - tbl: table.Table = self.repo.query(catalog_name, constraints=constraints) - - if row_num is not None: - tbl = table.Table(tbl[:row_num]) - - cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints) - tbl.write(str(cache_filename), format="votable") - app.logger.debug("wrote catalog cache", location=str(cache_filename)) - - def get_table( - self, - catalog_name: str, - row_num: int | None = None, - constraints: list[app.Constraint] | None = None, - ) -> table.Table: - cache_path = self._obtain_cache_path(catalog_name, row_num, constraints) - if not cache_path.exists(): - app.logger.debug("did not hit cache for the catalog, downloading") - self._write_catalog_cache(catalog_name, row_num, constraints) - - return table.Table.read(cache_path, format="votable") - - @final class VizierV2Plugin( app.UploaderPlugin, @@ -139,7 +74,7 @@ def __init__( self.catalog_name = catalog_name self.table_name = table_name self.batch_size = batch_size - self.client = CachedTAPRepository(cache_path=cache_path) + self.repo = app.TAPRepository() def prepare(self) -> None: pass @@ -156,7 +91,7 @@ def get_description(self) -> str: return resp["title"][0] def get_schema(self) -> list[models.ColumnDescription]: - t = self.client.get_table(self.table_name) + t = self.repo.query(self.table_name, limit=1) result = [] for _, col in t.columns.items(): @@ -174,7 +109,7 @@ def get_schema(self) -> list[models.ColumnDescription]: def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]: constraints = self.constraints if self.constraints else None - t = self.client.get_table(self.table_name, constraints=constraints) + t = self.repo.query(self.table_name, constraints=constraints) total_rows = len(t) app.logger.info("uploading table", total_rows=total_rows) From e75d7f9d5e1d7b567de985c1879d089be2fa1ae4 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 14:52:47 +0000 Subject: [PATCH 4/5] style --- makefile | 1 + plugins/vizier_v2.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/makefile b/makefile index cb8dc5e..c34fab3 100644 --- a/makefile +++ b/makefile @@ -8,6 +8,7 @@ check: fix: uvx ruff format . + uvx ruff check --fix test: uv run pytest tests diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index afb05b9..6ab4a2a 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -1,11 +1,9 @@ import itertools -import pathlib from collections.abc import Generator from typing import final import numpy as np import pandas -from astropy import table from astroquery import vizier import app From 9824cd632c84b7ae0509f4bfe14683aa2472682d Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 15 Nov 2025 15:02:51 +0000 Subject: [PATCH 5/5] add proper batching --- plugins/vizier_v2.py | 50 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/plugins/vizier_v2.py b/plugins/vizier_v2.py index 6ab4a2a..2fcd972 100644 --- a/plugins/vizier_v2.py +++ b/plugins/vizier_v2.py @@ -1,4 +1,3 @@ -import itertools from collections.abc import Generator from typing import final @@ -58,9 +57,10 @@ def __init__( self, catalog_name: str, table_name: str, + index_column: str, *constraints: str, cache_path: str = ".vizier_cache/", - batch_size: int = 10, + batch_size: int = 1000, ): if len(constraints) % 3 != 0: raise ValueError("constraints must be provided in groups of three (column, sign, value)") @@ -71,6 +71,7 @@ def __init__( ) self.catalog_name = catalog_name self.table_name = table_name + self.index_column = index_column self.batch_size = batch_size self.repo = app.TAPRepository() @@ -106,22 +107,49 @@ def get_schema(self) -> list[models.ColumnDescription]: return result def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]: - constraints = self.constraints if self.constraints else None - t = self.repo.query(self.table_name, constraints=constraints) + last_index_value = None + total_rows_processed = 0 + batch_number = 0 - total_rows = len(t) - app.logger.info("uploading table", total_rows=total_rows) + while True: + constraints = list(self.constraints) if self.constraints else [] - offset = 0 - for batch in itertools.batched(t, self.batch_size, strict=False): # pyright: ignore[reportArgumentType] - offset += len(batch) + if last_index_value is not None: + constraints.append(app.Constraint(column=self.index_column, operator=">", value=str(last_index_value))) + + quoted_index_column = ( + f'"{self.index_column}"' if any(char in self.index_column for char in "()[].") else self.index_column + ) + order_by = f"{quoted_index_column} ASC" + t = self.repo.query( + self.table_name, + constraints=constraints if constraints else None, + order_by=order_by, + limit=self.batch_size, + ) + + if len(t) == 0: + break rows = [] - for row in batch: + for row in t: row_dict = {k: v for k, v in dict(row).items() if v != "--"} rows.append(row_dict) + last_index_value = row[self.index_column] + + total_rows_processed += len(rows) + batch_number += 1 + + app.logger.info( + "uploading batch", + batch_number=batch_number, + rows_in_batch=len(rows), + total_rows_processed=total_rows_processed, + ) + + yield pandas.DataFrame(rows), 0.0 - yield pandas.DataFrame(rows), offset / total_rows + app.logger.info("finished uploading table", total_rows=total_rows_processed) def stop(self) -> None: pass