Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions app/data/model/layer0.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class Layer0Object:
object_id: str
data: list[interface.CatalogObject]

def get[T](self, t: type[T]) -> T | None:
for obj in self.data:
if isinstance(obj, t):
return obj

return None


@dataclass
class CIResultObjectNew:
Expand Down
9 changes: 8 additions & 1 deletion app/data/repositories/layer0/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,16 @@ def upsert_objects(self, table_id: int, objects: list[model.Layer0Object]) -> No
def get_table_statistics(self, table_id: int) -> model.TableStatistics:
    """Return aggregate statistics for the table, delegated to the objects repository."""
    statistics = self.objects_repo.get_table_statistics(table_id)
    return statistics

def get_objects(self, table_id: int, limit: int, offset: int) -> list[model.Layer0Object]:
def get_objects_by_id(self, table_id: int, limit: int, offset: int) -> list[model.Layer0Object]:
    """Fetch a page of layer-0 objects for the table with the given numeric id."""
    page = self.objects_repo.get_objects(table_id, limit, offset)
    return page

def get_objects(self, table_name: str, limit: int, offset: int) -> list[model.Layer0Object]:
    """Fetch a page of layer-0 objects for the table identified by name.

    Raises:
        RuntimeError: if the table's metadata has no numeric table_id yet.
    """
    metadata = self.fetch_metadata_by_name(table_name)
    table_id = metadata.table_id
    if table_id is None:
        raise RuntimeError(f"{table_name} has no table_id")

    return self.objects_repo.get_objects(table_id, limit, offset)

def get_processed_objects(
self, table_id: int, limit: int, offset: str | None = None
) -> list[model.Layer0ProcessedObject]:
Expand Down
20 changes: 9 additions & 11 deletions app/data/repositories/layer2/filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import abc
from typing import Any

import astropy
import astropy.units
import astropy.units.quantity

from astropy import units as u

class Filter(abc.ABC):
@abc.abstractmethod
Expand Down Expand Up @@ -80,18 +77,19 @@ def get_params(self):


class ICRSCoordinatesInRadiusFilter(Filter):
def __init__(self, radius: float | astropy.units.quantity.Quantity):
if isinstance(radius, astropy.units.quantity.Quantity):
radius = radius.to(astropy.units.deg).value
def __init__(self, radius: float | u.Quantity):
    """Create a filter matching objects within `radius` of the search point.

    Args:
        radius: Search radius. A plain float is taken as-is (degrees,
            judging by the conversion applied to Quantity inputs); an
            astropy Quantity is converted to degrees first.
    """
    if isinstance(radius, u.Quantity):
        # u.Quantity is the public spelling of u.quantity.Quantity, and
        # to_value() converts and unwraps in one step (.to(u.deg).value).
        radius = radius.to_value(u.deg)

    self._radius = radius

def get_query(self):
    """Return the SQL predicate restricting results to the filter radius.

    Uses ST_DWithin on points built from the search params (sp.params) and
    the stored layer2.icrs coordinates. Both points are built as (dec, ra-180);
    presumably the -180 shift maps RA into the planar coordinate range —
    confirm against the geometry setup. The %s placeholder receives the
    radius (see get_params).
    """
    return """
        ST_DWithin(
            ST_MakePoint((sp.params->>'dec')::float, (sp.params->>'ra')::float-180),
            ST_MakePoint(layer2.icrs.dec, layer2.icrs.ra-180),
            %s
        )
    """

def get_params(self):
Expand Down
10 changes: 8 additions & 2 deletions app/tasks/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@

@final
class ProcessTask(interface.Task):
def __init__(self, table_id: int, batch_size: int = 500, workers: int = 8) -> None:
def __init__(
    self, table_id: int, batch_size: int = 500, workers: int = 8, crossmatch_enabled: bool = False
) -> None:
    """Configure the processing task.

    Args:
        table_id: Identifier of the layer-0 table to process.
        batch_size: Number of objects handled per batch.
        workers: Worker count used by the task.
        crossmatch_enabled: When False, run() stops after marking objects
            and skips the cross-identification stages.
    """
    self.table_id = table_id
    self.batch_size = batch_size
    self.workers = workers
    self.crossmatch_enabled = crossmatch_enabled
    self.log = structlog.get_logger()

@classmethod
def name(cls) -> str:
Expand All @@ -34,12 +37,15 @@ def run(self):
self.log.info("Starting marking of objects", **ctx)
processing.mark_objects(self.layer0_repo, self.table_id, self.batch_size)

if not self.crossmatch_enabled:
return

self.log.info("Erasing previous crossmatching results", **ctx)
self.layer0_repo.erase_crossmatch_results(self.table_id)

self.log.info("Starting cross-identification", **ctx)
for offset, data in containers.read_batches(
self.layer0_repo.get_objects,
self.layer0_repo.get_objects_by_id,
lambda data: len(data) == 0,
0,
lambda _, offset: offset + self.batch_size,
Expand Down
2 changes: 2 additions & 0 deletions experiments/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
data
results
162 changes: 162 additions & 0 deletions experiments/bayes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import sys
from pathlib import Path

import numpy as np
from astropy import coordinates
from astropy import units as u

sys.path.insert(0, str(Path(__file__).parent / ".."))

import math

import pandas
import structlog

from app.data import model
from app.data.repositories.layer2 import filters, params
from app.data.repositories.layer2.repository import Layer2Repository
from experiments.entities import CrossIdentificationResult

logger = structlog.get_logger()


def calculate_bayes_factor(ra1: float, dec1: float, sigma1: float, ra2: float, dec2: float, sigma2: float) -> float:
    """Calculate the Bayes factor for two observations.

    Args:
        ra1, dec1: Coordinates of first observation (degrees)
        sigma1: Astrometric error of first observation (degrees)
        ra2, dec2: Coordinates of second observation (degrees)
        sigma2: Astrometric error of second observation (degrees)

    Returns:
        Bayes factor B(H,K|D)
    """
    position_a = coordinates.SkyCoord(ra1 * u.deg, dec1 * u.deg, frame="icrs")
    position_b = coordinates.SkyCoord(ra2 * u.deg, dec2 * u.deg, frame="icrs")
    separation_rad = position_a.separation(position_b).to(u.rad).value

    # Combined positional variance in radians^2, computed once.
    variance_sum = math.radians(sigma1) ** 2 + math.radians(sigma2) ** 2

    return 2 / variance_sum * math.exp(-(separation_rad**2) / (2.0 * variance_sum))


def bayes_to_posterior(bayes_factor: float, prior: float) -> float:
    """Convert a Bayes factor into a posterior probability of a match.

    Args:
        bayes_factor: Bayes factor from calculate_bayes_factor; non-positive
            values carry no support for the match and map to 0.
        prior: Prior probability that the two objects are the same.

    Returns:
        Posterior probability in [0, 1].
    """
    if bayes_factor <= 0:
        # Return a float (not int 0) so the return type is consistent
        # across both branches.
        return 0.0

    return (1.0 + (1.0 - prior) / (bayes_factor * prior)) ** -1


def posterior_to_bayes(posterior: float, prior: float) -> float:
    """Invert bayes_to_posterior: the Bayes factor producing `posterior` at `prior`.

    Args:
        posterior: Posterior probability, must lie in [0, 1].
        prior: Prior probability that the two objects are the same.

    Returns:
        The corresponding Bayes factor; math.inf when posterior == 1.

    Raises:
        ValueError: if posterior is outside [0, 1]. (Was an assert, which
            is stripped under `python -O`.)
    """
    if not 0.0 <= posterior <= 1.0:
        raise ValueError(f"posterior must be in [0, 1], got {posterior}")

    if posterior == 1.0:
        # The closed-form expression divides by zero here; certainty
        # corresponds to an infinite Bayes factor.
        return math.inf

    return (posterior * (1.0 - prior)) / (prior * (1.0 - posterior))


def cross_identify_objects_bayesian(
    parameters: pandas.DataFrame,
    layer2_repo: Layer2Repository,
    lower_posterior_probability: float,
    upper_posterior_probability: float,
    cutoff_radius_degrees: float,
    prior_probability: float,
    batch_size: int = 50,
) -> dict[str, CrossIdentificationResult]:
    """
    Perform cross-identification using Bayesian approach.

    Args:
        parameters: DataFrame, containing "ra", "dec" and "e_pos" columns.
        layer2_repo: Layer2Repository instance for database queries
        lower_posterior_probability: Threshold for "definitely different"
        upper_posterior_probability: Threshold for "definitely same"
        cutoff_radius_degrees: Search radius in degrees
        prior_probability: Prior probability that two random objects are the same
        batch_size: Number of objects queried per database round-trip

    Returns:
        Dictionary mapping object IDs ("obj_<row index>") to cross-identification results
    """
    results: dict[str, CrossIdentificationResult] = {}

    ra_column = parameters["ra"]
    dec_column = parameters["dec"]
    error_column = parameters["e_pos"]

    # Translate the posterior thresholds into Bayes-factor thresholds once.
    bf_lower = posterior_to_bayes(lower_posterior_probability, prior_probability)
    bf_upper = posterior_to_bayes(upper_posterior_probability, prior_probability)

    logger.info("Bayes factors", b_lower=bf_lower, b_upper=bf_upper)

    # The radius filter does not depend on the batch, so build it once
    # instead of once per batch.
    search_types: dict[str, filters.Filter] = {"icrs": filters.ICRSCoordinatesInRadiusFilter(cutoff_radius_degrees)}

    total_objects = len(ra_column)

    for batch_start in range(0, total_objects, batch_size):
        batch_end = min(batch_start + batch_size, total_objects)

        logger.info("Processing batch", batch_n=batch_start // batch_size + 1, start=batch_start + 1, end=batch_end)

        search_params = {}
        for i in range(batch_start, batch_end):
            object_id = f"obj_{i}"
            search_params[object_id] = params.ICRSSearchParams(ra=ra_column[i], dec=dec_column[i])

        batch_results = layer2_repo.query_batch(
            catalogs=[model.RawCatalog.ICRS],
            search_types=search_types,
            search_params=search_params,
            limit=10000,
            offset=0,
        )

        for i in range(batch_start, batch_end):
            object_id = f"obj_{i}"
            candidates = batch_results.get(object_id, [])

            if len(candidates) == 0:
                results[object_id] = CrossIdentificationResult(status="new")
                continue

            candidate_pgcs = []

            for candidate in candidates:
                # Extract ICRS coordinates from the candidate's catalog data;
                # candidates without ICRS data cannot be scored and are skipped.
                icrs_data: model.ICRSCatalogObject | None = None
                for catalog_obj in candidate.data:
                    if isinstance(catalog_obj, model.ICRSCatalogObject):
                        icrs_data = catalog_obj
                        break

                if icrs_data is None:
                    continue

                bf = calculate_bayes_factor(
                    ra_column[i],
                    dec_column[i],
                    error_column[i],
                    icrs_data.ra,
                    icrs_data.dec,
                    # Combined positional error of the catalog candidate.
                    np.sqrt(icrs_data.e_ra**2 + icrs_data.e_dec**2),
                )

                candidate_pgcs.append((candidate.pgc, bf))

            high_probability_matches = [
                (pgc, bayes_to_posterior(bf, prior_probability)) for pgc, bf in candidate_pgcs if bf > bf_upper
            ]
            low_probability_matches = [
                (pgc, bayes_to_posterior(bf, prior_probability)) for pgc, bf in candidate_pgcs if bf < bf_lower
            ]
            all_matches = [(pgc, bayes_to_posterior(bf, prior_probability)) for pgc, bf in candidate_pgcs]

            # Exactly one confident match with every other candidate confidently
            # ruled out -> existing object; everything ruled out -> new object;
            # anything else is ambiguous -> collision.
            if len(high_probability_matches) == 1 and len(low_probability_matches) == len(candidate_pgcs) - 1:
                pgc, posterior = high_probability_matches[0]
                results[object_id] = CrossIdentificationResult(status="existing", pgc_numbers={pgc: posterior})
            elif len(low_probability_matches) == len(all_matches):
                results[object_id] = CrossIdentificationResult(status="new")
            else:
                results[object_id] = CrossIdentificationResult(status="collision", pgc_numbers=dict(all_matches))

    return results
23 changes: 23 additions & 0 deletions experiments/bayes_analyze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import sys
from pathlib import Path

# Make the repository root importable when the script is run from anywhere.
# NOTE(review): nothing in this script imports from the repository —
# presumably kept for parity with the other experiment scripts; confirm
# before removing.
sys.path.insert(0, str(Path(__file__).parent / ".."))

import pandas
from matplotlib import pyplot as plt

# Which sweep parameter from the results CSV to plot on the x axis.
parameter = "prior"


def main() -> None:
    """Plot the collision-object ratio against the sweep parameter."""
    plt.style.use("petroff10")

    df = pandas.read_csv("experiments/data/results.csv")

    plt.plot(df[parameter], df["collision_objects_ratio"], label="collision")
    # Toggle these to inspect the other outcome ratios:
    # plt.plot(df[parameter], df["existing_objects_ratio"], label="existing")
    # plt.plot(df[parameter], df["new_objects_ratio"], label="new")
    plt.legend()
    plt.xlabel(parameter)
    plt.ylabel("Object ratio")
    plt.grid(True)
    plt.xscale("log")
    plt.show()


# Guard the plotting work so importing this module has no side effects
# beyond the sys.path tweak above.
if __name__ == "__main__":
    main()
Loading
Loading