4 changes: 1 addition & 3 deletions app/discover.py
@@ -21,9 +21,7 @@ def discover_plugins(dir: str) -> dict[str, type[interface.UploaderPlugin]]:
spec.loader.exec_module(module)

if not hasattr(module, "plugin"):
log.logger.warn(
"python file has no declared plugin", filename=str(file_path)
)
log.logger.warn("python file has no declared plugin", filename=str(file_path))
continue

plugin_class = getattr(module, "plugin")
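Note: the warning above fires when a plugin file lacks a module-level `plugin` attribute. A minimal sketch of a module that discovery would accept (module name, class name, and import path here are hypothetical, not taken from this PR):

```python
# my_plugin.py -- hypothetical plugin file for discover_plugins.
# Declaring a module-level "plugin" attribute avoids the
# "python file has no declared plugin" warning and skip.
import interface  # assumed import path for UploaderPlugin


class MyUploader(interface.UploaderPlugin):
    ...


plugin = MyUploader  # discover_plugins resolves this via getattr(module, "plugin")
```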
9 changes: 8 additions & 1 deletion app/upload.py
@@ -1,3 +1,4 @@
import math
from typing import Any

import click
@@ -37,6 +38,12 @@ def handle_call[T: Any](response: types.Response[T | models.HTTPValidationError]
return response.parsed


def sanitize_value(val: Any) -> Any:
if isinstance(val, float) and math.isnan(val):
return None
return val


def _upload(
plugin: interface.UploaderPlugin,
client: adminapi.AuthenticatedClient,
@@ -88,7 +95,7 @@ def _upload(
for _, row in data.iterrows():
item = models.AddDataRequestDataItem()
for col in data.columns:
item[col] = row[col]
item[col] = sanitize_value(row[col])
request_data.append(item)

_ = handle_call(
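Note: `sanitize_value` exists because pandas stores missing cells as float NaN, which is not valid JSON; mapping NaN to `None` lets each cell serialize as `null`. A quick sketch of the effect, with invented values:

```python
import math


def sanitize_value(val):
    if isinstance(val, float) and math.isnan(val):
        return None
    return val


print(sanitize_value(10.5))          # 10.5 passes through unchanged
print(sanitize_value(float("nan")))  # None -- sent as JSON null
print(sanitize_value("NGC 1275"))    # non-floats pass through unchanged
```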
12 changes: 5 additions & 7 deletions cli.py
@@ -55,7 +55,9 @@ def discover(plugin_dir: str) -> None:


table_name_descr = "Table name is a primary identifier of the table in HyperLEDA. It usually is a machine-readable string that will later be user to do any alterations to the table. Example: sdss_dr12."
table_description_descr = "Description of the table is a human-readable string that can later be used for searching of viewing the table."
table_description_descr = (
"Description of the table is a human-readable string that can later be used for searching of viewing the table."
)
bibcode_descr = "Bibcode is an identifier for the publication from the NASA ADS system https://ui.adsabs.harvard.edu/. It allows for easy search of the publication throughout a range of different sources."
pub_name_descr = "Name of the internal source. Can be a short description that represents where the data comes from."
pub_authors_descr = "Comma-separated list of authors of the internal source."
@@ -108,9 +110,7 @@ def upload(
try:
default_table_name = plugin.get_table_name()
except Exception:
app.logger.warning(
"failed to get default table name from plugin", plugin=plugin
)
app.logger.warning("failed to get default table name from plugin", plugin=plugin)

table_name = question(
"Enter table name",
@@ -207,9 +207,7 @@ def upload(
click.echo(parameter("Table type", table_type))

if not auto_proceed:
auto_proceed = question(
"Proceed? (y,n)", default="y", transformer=lambda s: s == "y"
)
auto_proceed = question("Proceed? (y,n)", default="y", transformer=lambda s: s == "y")

if auto_proceed:
app.upload(
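Note: the joined call keeps the same semantics: the `transformer` turns the raw answer into a boolean, and the `"y"` default auto-proceeds on a bare Enter. The `question` helper itself is not shown in this diff; a sketch of its assumed behavior:

```python
# Assumed semantics of the question() helper (not part of this PR's diff).
def question(prompt, default, transformer=lambda s: s):
    answer = input(f"{prompt} [{default}]: ") or default
    return transformer(answer)


proceed = question("Proceed? (y,n)", default="y", transformer=lambda s: s == "y")
print(proceed)  # True if the user typed "y" or just pressed Enter
```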
4 changes: 1 addition & 3 deletions plugins/csv_batched.py
@@ -37,9 +37,7 @@ def prepare(self) -> None:
with open(self.filename) as f:
self._total_chunks = sum(1 for _ in f) - 1

self._total_chunks = (
self._total_chunks + self._chunk_size - 1
) // self._chunk_size
self._total_chunks = (self._total_chunks + self._chunk_size - 1) // self._chunk_size
self._current_chunk = 0

def get_schema(self) -> list[models.ColumnDescription]:
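Note: the reflowed line keeps the integer ceiling-division idiom `(n + size - 1) // size`, used here and in `plugins/fits.py` below to count batches without floating point. For example:

```python
total_rows = 1001
chunk_size = 500

# 1001 rows in chunks of 500 -> 3 chunks (two full, one with a single row).
total_chunks = (total_rows + chunk_size - 1) // chunk_size
assert total_chunks == 3
assert total_chunks == -(-total_rows // chunk_size)  # equivalent ceiling trick
```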
4 changes: 1 addition & 3 deletions plugins/fits.py
@@ -38,9 +38,7 @@ def prepare(self) -> None:
raise ValueError(f"HDU {self.hdu_index} is not a binary table")

self._schema = self._table.columns
self._total_batches = (
len(self._table.data) + self._batch_size - 1
) // self._batch_size
self._total_batches = (len(self._table.data) + self._batch_size - 1) // self._batch_size
self._current_batch = 0

def get_schema(self) -> list[models.ColumnDescription]:
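Note: the `ValueError` above rejects HDUs that are not binary tables. A hedged sketch of that kind of guard with astropy (file name and HDU index are placeholders):

```python
from astropy.io import fits

with fits.open("catalog.fits") as hdul:  # hypothetical file
    hdu = hdul[1]
    if not isinstance(hdu, fits.BinTableHDU):
        raise ValueError("HDU 1 is not a binary table")
    print(len(hdu.data), "rows in the binary table")
```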
36 changes: 9 additions & 27 deletions plugins/vizier.py
@@ -76,34 +76,26 @@ def _write_schema_cache(self, catalog_name: str, table_name: str):

app.logger.debug("wrote cache", location=str(cache_filename))

def _obtain_cache_path(
self, type_path: str, catalog_name: str, table_name: str
) -> pathlib.Path:
def _obtain_cache_path(self, type_path: str, catalog_name: str, table_name: str) -> pathlib.Path:
filename = f"{_get_filename(catalog_name, table_name)}.vot"
path = pathlib.Path(self.cache_path) / type_path / filename

path.parent.mkdir(parents=True, exist_ok=True)
return path

def _get_schema_from_cache(
self, catalog_name: str, table_name: str
) -> tree.VOTableFile:
def _get_schema_from_cache(self, catalog_name: str, table_name: str) -> tree.VOTableFile:
cache_filename = self._obtain_cache_path("schemas", catalog_name, table_name)
return votable.parse(str(cache_filename))

def _get_table_from_cache(
self, catalog_name: str, table_name: str
) -> astropy.table.Table:
def _get_table_from_cache(self, catalog_name: str, table_name: str) -> astropy.table.Table:
cache_filename = self._obtain_cache_path("tables", catalog_name, table_name)
return astropy.table.Table.read(cache_filename, format="votable")

def prepare(self) -> None:
pass

def get_schema(self) -> list[models.ColumnDescription]:
if not self._obtain_cache_path(
"schemas", self.catalog_name, self.table_name
).exists():
if not self._obtain_cache_path("schemas", self.catalog_name, self.table_name).exists():
app.logger.debug("did not hit cache for the schema, downloading")
self._write_schema_cache(self.catalog_name, self.table_name)

@@ -122,9 +114,7 @@ def get_schema(self) -> list[models.ColumnDescription]:
]

def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]:
if not self._obtain_cache_path(
"tables", self.catalog_name, self.table_name
).exists():
if not self._obtain_cache_path("tables", self.catalog_name, self.table_name).exists():
app.logger.debug("did not hit cache for the table, downloading")
self._write_table_cache(self.catalog_name, self.table_name)

@@ -154,9 +144,7 @@ def get_table_name(self) -> str:
def get_bibcode(self) -> str:
self.get_schema()
schema = self._get_schema_from_cache(self.catalog_name, self.table_name)
bibcode_info = next(
filter(lambda info: info.name == "cites", schema.resources[0].infos)
)
bibcode_info = next(filter(lambda info: info.name == "cites", schema.resources[0].infos))
return bibcode_info.value.split(":")[1]

def get_description(self) -> str:
@@ -173,9 +161,7 @@ def _get_filename(catalog_name: str, table_name: str) -> str:
return f"{_sanitize_filename(catalog_name)}_{_sanitize_filename(table_name)}"


def _get_columns(
client: vizier.VizierClass, catalog_name: str, table_name: str
) -> list[str]:
def _get_columns(client: vizier.VizierClass, catalog_name: str, table_name: str) -> list[str]:
catalogs = client.get_catalogs(catalog_name) # type: ignore

meta = None
@@ -190,9 +176,7 @@ def _get_columns(
return meta.colnames


def _download_table(
table_name: str, columns: list[str], max_rows: int | None = None
) -> str:
def _download_table(table_name: str, columns: list[str], max_rows: int | None = None) -> str:
out_max = "unlimited" if max_rows is None else max_rows

payload = [
@@ -216,9 +200,7 @@ def _download_table(
"Content-Type": "application/x-www-form-urlencoded",
}

response = requests.request(
http.HTTPMethod.POST, VIZIER_URL, data=data, headers=headers
)
response = requests.request(http.HTTPMethod.POST, VIZIER_URL, data=data, headers=headers)

return response.text

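Note: `get_bibcode` assumes the cached VOTable carries a `cites` INFO element whose value follows the `bibcode:<identifier>` convention, so `split(":")[1]` strips the prefix. Illustratively (the value below is invented):

```python
info_value = "bibcode:2015ApJS..219...12A"  # hypothetical INFO value
print(info_value.split(":")[1])             # 2015ApJS..219...12A
```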
82 changes: 51 additions & 31 deletions plugins/vizier_v2.py
@@ -3,6 +3,7 @@
from collections.abc import Generator
from typing import final

import numpy as np
import pandas
from astropy import table
from astroquery import utils, vizier
@@ -12,38 +13,38 @@


def _sanitize_filename(string: str) -> str:
return string.replace("/", "_")


def dtype_to_datatype(dtype) -> models.DatatypeEnum:
# Accept both dtypes and strings
return (
string.replace("/", "_")
.replace("&", "_and_")
.replace(">", "_gt_")
.replace("<", "_lt_")
.replace("=", "_eq_")
.replace(" ", "_")
.replace("!", "_not_")
)


def dtype_to_datatype(dtype: str | np.dtype) -> models.DatatypeEnum:
dtype_str = str(dtype).lower()
# Typical mappings
if any(
dtype_str.startswith(x)
for x in ("str", "unicode", "<u", "|s", "<U", "object", "bytes")
):

if any(dtype_str.startswith(x) for x in ("str", "unicode", "<u", "|s", "<U", "object", "bytes")):
return models.DatatypeEnum.STRING

if any(
dtype_str.startswith(x)
for x in (
"int",
"int8",
"int16",
"int32",
"int64",
"uint",
"uint8",
"uint16",
"uint32",
"uint64",
)
):
return models.DatatypeEnum.INTEGER
if any(
dtype_str.startswith(x)
for x in ("float", "float16", "float32", "float64", "double", "float128")
):

if any(dtype_str.startswith(x) for x in ("int32", "int64", "uint32", "uint64")):
return models.DatatypeEnum.LONG

if any(dtype_str.startswith(x) for x in ("float", "float16", "float32", "float64", "double", "float128")):
return models.DatatypeEnum.DOUBLE
return models.DatatypeEnum.STRING
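Note: the reworked `dtype_to_datatype` accepts numpy dtypes as well as strings, routes 32/64-bit integers to `LONG` (previously `INTEGER`), and keeps `STRING` as the fallback. A sketch of the expected behavior against the branches above:

```python
import numpy as np

print(dtype_to_datatype(np.dtype("int64")))  # DatatypeEnum.LONG
print(dtype_to_datatype("float32"))          # DatatypeEnum.DOUBLE
print(dtype_to_datatype(np.dtype("U16")))    # DatatypeEnum.STRING (str() gives "<U16")
print(dtype_to_datatype("bool"))             # DatatypeEnum.STRING (fallback branch)
```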

@@ -55,41 +56,53 @@ def __init__(self, cache_path: str = ".vizier_cache/"):
self._client.ROW_LIMIT = -1

def _obtain_cache_path(
self, catalog_name: str, row_num: int | None = None
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
) -> pathlib.Path:
filename = f"{_sanitize_filename(catalog_name)}.vot"
filename = f"{catalog_name}.vot"
if row_num is not None:
filename = f"{_sanitize_filename(catalog_name)}_rows_{row_num}.vot"
filename = f"{catalog_name}_rows_{row_num}.vot"
if constraints:
sorted_constraints = sorted(constraints.items())
constraint_str = "_".join(f"{k}_{v}" for k, v in sorted_constraints)
filename = f"{catalog_name}_constraints_{constraint_str}.vot"

filename = _sanitize_filename(filename)
path = pathlib.Path(self.cache_path) / "catalogs" / filename
path.parent.mkdir(parents=True, exist_ok=True)
return path

def _write_catalog_cache(
self, catalog_name: str, row_num: int | None = None
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
) -> None:
app.logger.info(
"downloading catalog from Vizier",
catalog_name=catalog_name,
row_num=row_num,
constraints=constraints,
)
client = self._client
if row_num is not None:
client = vizier.Vizier()
client.ROW_LIMIT = row_num
catalogs: utils.TableList = client.get_catalogs(catalog_name) # pyright: ignore[reportAttributeAccessIssue]
query_kwargs = {"catalog": catalog_name}
if constraints:
query_kwargs.update(constraints)
catalogs: utils.TableList = client.query_constraints(**query_kwargs) # pyright: ignore[reportAttributeAccessIssue]

if not catalogs:
raise ValueError("catalog not found")

cache_filename = self._obtain_cache_path(catalog_name, row_num)
cache_filename = self._obtain_cache_path(catalog_name, row_num, constraints)
catalogs[0].write(str(cache_filename), format="votable")
app.logger.debug("wrote catalog cache", location=str(cache_filename))

def get_table(self, catalog_name: str, row_num: int | None = None) -> table.Table:
cache_path = self._obtain_cache_path(catalog_name, row_num)
def get_table(
self, catalog_name: str, row_num: int | None = None, constraints: dict[str, str] | None = None
) -> table.Table:
cache_path = self._obtain_cache_path(catalog_name, row_num, constraints)
if not cache_path.exists():
app.logger.debug("did not hit cache for the catalog, downloading")
self._write_catalog_cache(catalog_name, row_num)
self._write_catalog_cache(catalog_name, row_num, constraints)

return table.Table.read(cache_path, format="votable")

@@ -108,9 +121,15 @@ def __init__(
self,
catalog_name: str,
table_name: str,
*constraints: str,
cache_path: str = ".vizier_cache/",
batch_size: int = 500,
batch_size: int = 10,
):
if len(constraints) % 2 != 0:
raise ValueError("constraints must be provided in pairs (column, constraint_value)")
self.constraints: dict[str, str] = {}
for i in range(0, len(constraints), 2):
self.constraints[constraints[i]] = constraints[i + 1]
self.catalog_name = catalog_name
self.table_name = table_name
self.batch_size = batch_size
@@ -152,7 +171,8 @@ def get_schema(self) -> list[models.ColumnDescription]:
return result

def get_data(self) -> Generator[tuple[pandas.DataFrame, float]]:
t = self.client.get_table(self.table_name)
constraints = self.constraints if self.constraints else None
t = self.client.get_table(self.table_name, constraints=constraints)

total_rows = len(t)
app.logger.info("uploading table", total_rows=total_rows)
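Note: the new variadic `*constraints` arrive as flat `(column, value)` pairs, validated for even length and folded into `self.constraints`, which is then passed through to astroquery's `query_constraints`. A hypothetical invocation (catalog and column names invented):

```python
uploader = VizierTableUploader(
    "II/246",        # catalog_name (hypothetical)
    "II/246/out",    # table_name (hypothetical)
    "Jmag", "<10",   # first (column, constraint) pair
    "Kmag", ">4",    # second pair
)
print(uploader.constraints)  # {'Jmag': '<10', 'Kmag': '>4'}
```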
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@ dependencies = [
"astropy>=7.1.0",
"openapi-python-client>=0.27.1",
"pandas>=2.3.3",
"numpy>=2.3.4",
]

[tool.pytest.ini_options]
@@ -25,3 +26,6 @@ dev = [
"pandas-stubs>=2.2.3.250308",
"pytest>=8.4.2",
]

[tool.ruff]
line-length = 120
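Note: the new `[tool.ruff]` line-length of 120 is what permits the widespread line-joining in this PR; assuming the project formats with ruff, `ruff format .` reflows code to that limit and `ruff check .` lints against it.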