11 changes: 7 additions & 4 deletions app/__init__.py
@@ -1,12 +1,13 @@
from app.discover import discover_plugins
from app.interface import (
UploaderPlugin,
DefaultTableNamer,
BibcodeProvider,
DefaultTableNamer,
DescriptionProvider,
UploaderPlugin,
)
from app.discover import discover_plugins
from app.upload import upload
from app.log import logger
from app.tap import Constraint, TAPRepository
from app.upload import upload

__all__ = [
"UploaderPlugin",
@@ -16,4 +17,6 @@
"discover_plugins",
"upload",
"logger",
"TAPRepository",
"Constraint",
]
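
With `Constraint` and `TAPRepository` now re-exported from the package root, callers no longer need to import from `app.tap` directly. A minimal usage sketch of the new public API; the table and column names below are hypothetical placeholders, not taken from this PR:

```python
from app import Constraint, TAPRepository

repo = TAPRepository()
# Fetch the ten brightest matching rows; "II/246/out" and "Vmag"
# are illustrative names only.
t = repo.query(
    "II/246/out",
    constraints=[Constraint(column="Vmag", operator="<", value="12")],
    order_by="Vmag ASC",
    limit=10,
)
print(len(t), t.colnames)
```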
59 changes: 59 additions & 0 deletions app/tap.py
@@ -0,0 +1,59 @@
from dataclasses import dataclass

from astropy import table
from pyvo import registry

import app

TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync"


@dataclass
class Constraint:
column: str
operator: str
value: str


class TAPRepository:
def __init__(self, endpoint: str = TAP_ENDPOINT):
self.tap_endpoint = endpoint

def _quote_column(self, column: str) -> str:
if any(char in column for char in "()[]."):
return f'"{column}"'
return column

def _build_where_clause(self, constraints: list[Constraint]) -> str:
if not constraints:
return ""

conditions = []
for constraint in constraints:
quoted_column = self._quote_column(constraint.column)
conditions.append(f"{quoted_column} {constraint.operator} {constraint.value}")

return " WHERE " + " AND ".join(conditions)

def _build_order_by_clause(self, order_by: str | None) -> str:
if not order_by:
return ""
return f" ORDER BY {order_by}"

def query(
self,
table_name: str,
constraints: list[Constraint] | None = None,
order_by: str | None = None,
limit: int | None = None,
) -> table.Table:
where_clause = self._build_where_clause(constraints) if constraints else ""
order_by_clause = self._build_order_by_clause(order_by)
limit_clause = f" TOP {limit}" if limit else ""

query = f'SELECT{limit_clause} *\nFROM "{table_name}"{where_clause}{order_by_clause}'

app.logger.info("Running TAP query", query=query)
data = registry.regtap.RegistryQuery(self.tap_endpoint, query)
result = data.execute()
return result.to_table()
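
One detail worth noting in the new module: `_quote_column` double-quotes any column whose name contains characters ADQL would otherwise parse as syntax, and leaves plain identifiers untouched. A small sketch of the rule (the column names are illustrative, and the method is private, so this is for exposition only):

```python
from app.tap import TAPRepository

repo = TAPRepository()
# Plain identifiers pass through unchanged...
assert repo._quote_column("recno") == "recno"
# ...while names containing any of ( ) [ ] . get wrapped in double quotes.
assert repo._quote_column("Flux(V)") == '"Flux(V)"'
assert repo._quote_column("B.T") == '"B.T"'
```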
1 change: 1 addition & 0 deletions makefile
@@ -8,6 +8,7 @@ check:

fix:
uvx ruff format .
uvx ruff check --fix

test:
uv run pytest tests
146 changes: 47 additions & 99 deletions plugins/vizier_v2.py
@@ -1,13 +1,9 @@
import itertools
import pathlib
from collections.abc import Generator
from typing import final

import numpy as np
import pandas
from astropy import table
from astroquery import vizier
from pyvo import registry

import app
from app.gen.client.adminapi import models, types
@@ -25,21 +21,6 @@ def _sanitize_filename(string: str) -> str:
)


def _build_where_clause(constraints: list[tuple[str, str, str]]) -> str:
if not constraints:
return ""

conditions = []
for column, sign, value in constraints:
if any(char in column for char in "()[]."):
quoted_column = f'"{column}"'
else:
quoted_column = column
conditions.append(f"{quoted_column} {sign} {value}")

return " WHERE " + " AND ".join(conditions)


def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:
dtype_str = str(dtype).lower()

@@ -65,70 +46,6 @@ def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:
return models.DatatypeEnum.STRING


class CachedVizierClient:
TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync"

def __init__(self, cache_path: str = ".vizier_cache/"):
self.cache_path = cache_path

def _obtain_cache_path(
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> pathlib.Path:
filename = f"{catalog_name}.vot"
if row_num is not None:
filename = f"{catalog_name}_rows_{row_num}.vot"
if constraints:
sorted_constraints = sorted(constraints)
constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints)
filename = f"{catalog_name}_constraints_{constraint_str}.vot"

filename = _sanitize_filename(filename)
path = pathlib.Path(self.cache_path) / "catalogs" / filename
path.parent.mkdir(parents=True, exist_ok=True)
return path

def _write_catalog_cache(
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> None:
app.logger.info(
"downloading catalog from Vizier",
catalog_name=catalog_name,
row_num=row_num,
constraints=constraints,
)

where_clause = _build_where_clause(constraints) if constraints else ""

if row_num is not None:
select_clause = f"SELECT TOP {row_num} *"
else:
select_clause = "SELECT *"

query = f'{select_clause}\nFROM "{catalog_name}"{where_clause}'

app.logger.info("Running query", query=query)
data = registry.regtap.RegistryQuery(self.TAP_ENDPOINT, query)
result = data.execute()
tbl = result.to_table()

cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints)
tbl.write(str(cache_filename), format="votable")
app.logger.debug("wrote catalog cache", location=str(cache_filename))

def get_table(
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> table.Table:
cache_path = self._obtain_cache_path(catalog_name, row_num, constraints)
if not cache_path.exists():
app.logger.debug("did not hit cache for the catalog, downloading")
self._write_catalog_cache(catalog_name, row_num, constraints)

return table.Table.read(cache_path, format="votable")

def get_catalog_metadata(self, catalog: str) -> dict:
return vizier.Vizier().get_catalog_metadata(catalog=catalog)


@final
class VizierV2Plugin(
app.UploaderPlugin,
@@ -140,19 +57,23 @@ def __init__(
self,
catalog_name: str,
table_name: str,
index_column: str,
*constraints: str,
cache_path: str = ".vizier_cache/",
batch_size: int = 10,
batch_size: int = 1000,
):
if len(constraints) % 3 != 0:
raise ValueError("constraints must be provided in groups of three (column, sign, value)")
self.constraints: list[tuple[str, str, str]] = []
self.constraints: list[app.Constraint] = []
for i in range(0, len(constraints), 3):
self.constraints.append((constraints[i], constraints[i + 1], constraints[i + 2]))
self.constraints.append(
app.Constraint(column=constraints[i], operator=constraints[i + 1], value=constraints[i + 2])
)
self.catalog_name = catalog_name
self.table_name = table_name
self.index_column = index_column
self.batch_size = batch_size
self.client = CachedVizierClient(cache_path=cache_path)
self.repo = app.TAPRepository()

def prepare(self) -> None:
pass
@@ -161,15 +82,15 @@ def get_table_name(self) -> str:
return _sanitize_filename(self.table_name)

def get_bibcode(self) -> str:
resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name)
return resp["origin_article"][0]

def get_description(self) -> str:
resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
resp = vizier.Vizier().get_catalog_metadata(catalog=self.catalog_name)
return resp["title"][0]

def get_schema(self) -> list[models.ColumnDescription]:
t = self.client.get_table(self.table_name)
t = self.repo.query(self.table_name, limit=1)
result = []

for _, col in t.columns.items():
@@ -186,22 +107,49 @@ def get_schema(self) -> list[models.ColumnDescription]:
return result

def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]:
constraints = self.constraints if self.constraints else None
t = self.client.get_table(self.table_name, constraints=constraints)
last_index_value = None
total_rows_processed = 0
batch_number = 0

while True:
constraints = list(self.constraints) if self.constraints else []

if last_index_value is not None:
constraints.append(app.Constraint(column=self.index_column, operator=">", value=str(last_index_value)))

total_rows = len(t)
app.logger.info("uploading table", total_rows=total_rows)
quoted_index_column = (
f'"{self.index_column}"' if any(char in self.index_column for char in "()[].") else self.index_column
)
order_by = f"{quoted_index_column} ASC"
t = self.repo.query(
self.table_name,
constraints=constraints if constraints else None,
order_by=order_by,
limit=self.batch_size,
)

offset = 0
for batch in itertools.batched(t, self.batch_size, strict=False): # pyright: ignore[reportArgumentType]
offset += len(batch)
if len(t) == 0:
break

rows = []
for row in batch:
for row in t:
row_dict = {k: v for k, v in dict(row).items() if v != "--"}
rows.append(row_dict)
last_index_value = row[self.index_column]

total_rows_processed += len(rows)
batch_number += 1

app.logger.info(
"uploading batch",
batch_number=batch_number,
rows_in_batch=len(rows),
total_rows_processed=total_rows_processed,
)

yield pandas.DataFrame(rows), 0.0

yield pandas.DataFrame(rows), offset / total_rows
app.logger.info("finished uploading table", total_rows=total_rows_processed)

def stop(self) -> None:
pass
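
The rewritten `get_data` replaces the old download-everything-then-batch approach with keyset pagination: each round trip asks the TAP service for at most `batch_size` rows ordered by `index_column`, then resumes strictly past the last index value seen. A standalone sketch of that loop against the new repository; the catalog table, filter, and index column here are hypothetical:

```python
from app import Constraint, TAPRepository

repo = TAPRepository()
cursor = None  # last index value seen; None means "start from the top"

while True:
    constraints = [Constraint(column="Vmag", operator="<", value="12")]  # base filters
    if cursor is not None:
        # Resume strictly after the previous batch's last row.
        constraints.append(Constraint(column="recno", operator=">", value=str(cursor)))

    batch = repo.query(
        "II/246/out",          # hypothetical VizieR table
        constraints=constraints,
        order_by="recno ASC",  # stable ordering is what makes the cursor valid
        limit=1000,
    )
    if len(batch) == 0:
        break

    cursor = batch[-1]["recno"]
    # ... convert the batch to a DataFrame and upload it here ...
```

Because the cursor is a value filter rather than a row offset, each batch is an independent, bounded query, which avoids holding the full catalog in memory and tolerates restarts mid-table.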