diff --git a/app/data/repositories/common.py b/app/data/repositories/common.py index 63946f84..b82117c2 100644 --- a/app/data/repositories/common.py +++ b/app/data/repositories/common.py @@ -38,3 +38,9 @@ def get_source_by_id(self, source_id: int) -> model.Bibliography: row = self._storage.query_one(template.GET_SOURCE_BY_ID, params=[source_id]) return model.Bibliography(**row) + + def register_pgcs(self, pgcs: list[int]): + self._storage.exec( + f"INSERT INTO common.pgc (id) VALUES {','.join(['(%s)'] * len(pgcs))} ON CONFLICT (id) DO NOTHING", + params=pgcs, + ) diff --git a/app/data/repositories/layer0/objects.py b/app/data/repositories/layer0/objects.py index 7556107d..ae561770 100644 --- a/app/data/repositories/layer0/objects.py +++ b/app/data/repositories/layer0/objects.py @@ -173,21 +173,36 @@ def add_crossmatch_result(self, data: dict[str, model.CIResult]) -> None: self._storage.exec(query, params=params) def upsert_pgc(self, pgcs: dict[str, int | None]) -> None: - values = [] - params = [] + pgcs_to_insert: dict[str, int] = {} + + new_objects = [object_id for object_id, pgc in pgcs.items() if pgc is None] + + if new_objects: + result = self._storage.query( + f"""INSERT INTO common.pgc + VALUES {",".join(["(DEFAULT)"] * len(new_objects))} + RETURNING id""", + ) + + ids = [row["id"] for row in result] + + for object_id, pgc_id in zip(new_objects, ids, strict=False): + pgcs_to_insert[object_id] = pgc_id for object_id, pgc in pgcs.items(): - params.append(object_id) - if pgc is None: - values.append("(%s, DEFAULT)") - else: + if pgc is not None: + pgcs_to_insert[object_id] = pgc + + if pgcs_to_insert: + update_query = "UPDATE rawdata.objects SET pgc = v.pgc FROM (VALUES " + params = [] + values = [] + + for object_id, pgc_id in pgcs_to_insert.items(): values.append("(%s, %s)") - params.append(pgc) - - self._storage.exec( - f""" - INSERT INTO rawdata.pgc (object_id, id) VALUES {",".join(values)} - ON CONFLICT (object_id) DO UPDATE SET id = EXCLUDED.id - """, - params=params, - ) + params.extend([object_id, pgc_id]) + + update_query += ",".join(values) + update_query += ") AS v(object_id, pgc) WHERE rawdata.objects.id = v.object_id" + + self._storage.exec(update_query, params=params) diff --git a/app/data/repositories/layer1.py b/app/data/repositories/layer1.py index cfeb82f4..115ff304 100644 --- a/app/data/repositories/layer1.py +++ b/app/data/repositories/layer1.py @@ -70,16 +70,16 @@ def get_new_observations( query = f"""SELECT * FROM {object_cls.layer1_table()} AS l1 - JOIN rawdata.pgc AS pgc ON l1.object_id = pgc.object_id - WHERE id IN ( - SELECT DISTINCT id + JOIN rawdata.objects AS o ON l1.object_id = o.id + WHERE o.pgc IN ( + SELECT DISTINCT o.pgc FROM {object_cls.layer1_table()} AS l1 - JOIN rawdata.pgc AS pgc ON l1.object_id = pgc.object_id - WHERE modification_time > %s AND pgc.id > %s - ORDER BY id + JOIN rawdata.objects AS o ON l1.object_id = o.id + WHERE o.modification_time > %s AND o.pgc > %s + ORDER BY o.pgc LIMIT %s ) - ORDER BY pgc.id ASC""" + ORDER BY o.pgc ASC""" rows = self._storage.query(query, params=[dt, offset, limit]) @@ -87,7 +87,7 @@ def get_new_observations( for row in rows: object_id = row.pop("object_id") - pgc = int(row.pop("id")) + pgc = int(row.pop("pgc")) catalog_object = object_cls.from_layer1(row) key = (pgc, object_id) diff --git a/postgres/migrations/V014__common_pgc_connections.sql b/postgres/migrations/V014__common_pgc_connections.sql new file mode 100644 index 00000000..2449e0c6 --- /dev/null +++ b/postgres/migrations/V014__common_pgc_connections.sql @@ -0,0 +1,23 @@ +ALTER TABLE layer2.cz +ADD CONSTRAINT cz_pgc_fkey FOREIGN KEY (pgc) REFERENCES common.pgc(id) ON DELETE RESTRICT ON UPDATE CASCADE; +ALTER TABLE layer2.designation +ADD CONSTRAINT cz_pgc_fkey FOREIGN KEY (pgc) REFERENCES common.pgc(id) ON DELETE RESTRICT ON UPDATE CASCADE; +ALTER TABLE layer2.icrs +ADD CONSTRAINT cz_pgc_fkey FOREIGN KEY (pgc) REFERENCES common.pgc(id) ON DELETE RESTRICT ON UPDATE CASCADE; + +CREATE OR REPLACE FUNCTION rawdata_set_modification_time() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.pgc IS DISTINCT FROM OLD.pgc THEN + NEW.modification_time := now(); + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER set_modification_time_on_pgc_update +BEFORE UPDATE OF pgc ON rawdata.objects +FOR EACH ROW +EXECUTE FUNCTION rawdata_set_modification_time(); + +DROP TABLE rawdata.pgc; diff --git a/tests/integration/layer2_import_test.py b/tests/integration/layer2_import_test.py index 64dc0044..42b948e8 100644 --- a/tests/integration/layer2_import_test.py +++ b/tests/integration/layer2_import_test.py @@ -42,6 +42,7 @@ def test_import_two_catalogs(self): ["123", "124"], ) + self.common_repo.register_pgcs([1234, 1245]) self.layer0_repo.upsert_pgc({"123": 1234, "124": 1245}) self.layer1_repo.save_data( [ diff --git a/tests/integration/layer2_repository_test.py b/tests/integration/layer2_repository_test.py index 92e13d27..bae56135 100644 --- a/tests/integration/layer2_repository_test.py +++ b/tests/integration/layer2_repository_test.py @@ -12,6 +12,7 @@ class Layer2RepositoryTest(unittest.TestCase): def setUpClass(cls) -> None: cls.pg_storage = lib.TestPostgresStorage.get() + cls.common_repo = repositories.CommonRepository(cls.pg_storage.get_storage(), structlog.get_logger()) cls.layer2_repo = repositories.Layer2Repository(cls.pg_storage.get_storage(), structlog.get_logger()) def tearDown(self): @@ -23,6 +24,7 @@ def test_one_object(self): model.Layer2CatalogObject(2, model.DesignationCatalogObject(design="test2")), ] + self.common_repo.register_pgcs([1, 2]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query( @@ -42,6 +44,7 @@ def test_several_objects(self): model.Layer2CatalogObject(2, model.ICRSCatalogObject(ra=11, dec=11, e_ra=0.1, e_dec=0.1)), ] + self.common_repo.register_pgcs([1, 2]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query( @@ -65,6 +68,7 @@ def test_several_catalogs(self): model.Layer2CatalogObject(2, model.DesignationCatalogObject(design="test2")), ] + self.common_repo.register_pgcs([1, 2]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query( @@ -94,6 +98,7 @@ def test_several_filters(self): model.Layer2CatalogObject(1, model.DesignationCatalogObject(design="test")), ] + self.common_repo.register_pgcs([1, 2]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query( @@ -134,6 +139,7 @@ def test_pagination(self): model.Layer2CatalogObject(5, model.ICRSCatalogObject(ra=14, dec=14, e_ra=0.1, e_dec=0.1)), ] + self.common_repo.register_pgcs([1, 2, 3, 4, 5]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query( @@ -155,6 +161,7 @@ def test_batch_query(self): model.Layer2CatalogObject(5, model.ICRSCatalogObject(ra=14, dec=14, e_ra=0.1, e_dec=0.1)), ] + self.common_repo.register_pgcs([1, 2, 3, 4, 5]) self.layer2_repo.save_data(objects) actual = self.layer2_repo.query_batch(