Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions plugins/vizier_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np
import pandas
from astropy import table
from astroquery import utils, vizier
from astroquery import vizier
from pyvo import registry

import app
from app.gen.client.adminapi import models, types
Expand All @@ -24,6 +25,21 @@ def _sanitize_filename(string: str) -> str:
)


def _build_where_clause(constraints: list[tuple[str, str, str]]) -> str:
if not constraints:
return ""

conditions = []
for column, sign, value in constraints:
if any(char in column for char in "()[]."):
quoted_column = f'"{column}"'
else:
quoted_column = column
conditions.append(f"{quoted_column} {sign} {value}")

return " WHERE " + " AND ".join(conditions)


def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:
dtype_str = str(dtype).lower()

Expand All @@ -50,20 +66,20 @@ def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:


class CachedVizierClient:
TAP_ENDPOINT = "https://tapvizier.cds.unistra.fr/TAPVizieR/tap/sync"

def __init__(self, cache_path: str = ".vizier_cache/"):
self.cache_path = cache_path
self._client = vizier.Vizier()
self._client.ROW_LIMIT = -1

def _obtain_cache_path(
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> pathlib.Path:
filename = f"{catalog_name}.vot"
if row_num is not None:
filename = f"{catalog_name}_rows_{row_num}.vot"
if constraints:
sorted_constraints = sorted(constraints.items())
constraint_str = "_".join(f"{k}_{v}" for k, v in sorted_constraints)
sorted_constraints = sorted(constraints)
constraint_str = "_".join(f"{col}_{sign}_{val}" for col, sign, val in sorted_constraints)
filename = f"{catalog_name}_constraints_{constraint_str}.vot"

filename = _sanitize_filename(filename)
Expand All @@ -72,32 +88,35 @@ def _obtain_cache_path(
return path

def _write_catalog_cache(
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> None:
app.logger.info(
"downloading catalog from Vizier",
catalog_name=catalog_name,
row_num=row_num,
constraints=constraints,
)
client = self._client

where_clause = _build_where_clause(constraints) if constraints else ""

if row_num is not None:
client = vizier.Vizier()
client.ROW_LIMIT = row_num
query_kwargs = {"catalog": catalog_name}
if constraints:
query_kwargs.update(constraints)
catalogs: utils.TableList = client.query_constraints(**query_kwargs) # pyright: ignore[reportAttributeAccessIssue]
select_clause = f"SELECT TOP {row_num} *"
else:
select_clause = "SELECT *"

if not catalogs:
raise ValueError("catalog not found")
query = f'{select_clause}\nFROM "{catalog_name}"{where_clause}'

app.logger.info("Running query", query=query)
data = registry.regtap.RegistryQuery(self.TAP_ENDPOINT, query)
result = data.execute()
tbl = result.to_table()

cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints)
catalogs[0].write(str(cache_filename), format="votable")
tbl.write(str(cache_filename), format="votable")
app.logger.debug("wrote catalog cache", location=str(cache_filename))

def get_table(
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
self, catalog_name: str, row_num: int | None = None, constraints: list[tuple[str, str, str]] | None = None
) -> table.Table:
cache_path = self._obtain_cache_path(catalog_name, row_num, constraints)
if not cache_path.exists():
Expand All @@ -107,7 +126,7 @@ def get_table(
return table.Table.read(cache_path, format="votable")

def get_catalog_metadata(self, catalog: str) -> dict:
return self._client.get_catalog_metadata(catalog=catalog)
return vizier.Vizier().get_catalog_metadata(catalog=catalog)


@final
Expand All @@ -125,11 +144,11 @@ def __init__(
cache_path: str = ".vizier_cache/",
batch_size: int = 10,
):
if len(constraints) % 2 != 0:
raise ValueError("constraints must be provided in pairs (column, constraint_value)")
self.constraints: dict[str, str] = {}
for i in range(0, len(constraints), 2):
self.constraints[constraints[i]] = constraints[i + 1]
if len(constraints) % 3 != 0:
raise ValueError("constraints must be provided in groups of three (column, sign, value)")
self.constraints: list[tuple[str, str, str]] = []
for i in range(0, len(constraints), 3):
self.constraints.append((constraints[i], constraints[i + 1], constraints[i + 2]))
self.catalog_name = catalog_name
self.table_name = table_name
self.batch_size = batch_size
Expand All @@ -139,11 +158,7 @@ def prepare(self) -> None:
pass

def get_table_name(self) -> str:
t = self.client.get_table(self.table_name, row_num=1)
if not hasattr(t, "meta") or t.meta is None:
raise RuntimeError("unable to get table name")

return str(t.meta["ID"])
return _sanitize_filename(self.table_name)

def get_bibcode(self) -> str:
resp = self.client.get_catalog_metadata(catalog=self.catalog_name)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"openapi-python-client>=0.27.1",
"pandas>=2.3.3",
"numpy>=2.3.4",
"pyvo>=1.8",
]

[tool.pytest.ini_options]
Expand Down
Loading