Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ share/python-wheels/
*.egg
MANIFEST

# Jetbrains stuff
.idea

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
Expand Down
2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/python-libraries.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ format:
poetry run black .

test:
poetry run pytest -s -x
poetry run pytest -s -x
1,445 changes: 800 additions & 645 deletions auth-system/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion auth-system/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ packages = [{ include = "macrostrat" }]

[tool.poetry.dependencies]
python = "^3.10"
"macrostrat.database" = "^3.3.1"
"macrostrat.database" = "^3.3.1||^4.0.0"
"macrostrat.utils" = "^1.2.0"
PyJWT = "^1.7.1 || ^2.0"
werkzeug = "^2.3.7 || ^3.0"
Expand Down
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from docker.client import DockerClient
from dotenv import load_dotenv
from pytest import fixture
from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError

from macrostrat.database.utils import create_engine
from macrostrat.dinosaur.upgrade_cluster.utils import database_cluster, get_unused_port

load_dotenv()
Expand Down
14 changes: 14 additions & 0 deletions database/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Changelog

## [4.0.0] - Unreleased

- Upgrade to `psycopg` version 3 instead of Psycopg2
- `psycopg`'s new
[adaptation system](https://www.psycopg.org/psycopg3/docs/advanced/adapt.html)
means that parameter binding has changed substantially. Consequently, `AsIs`
and other parameter-binding extensions are no longer supported.
- `psycopg`'s new
[async mode](https://www.psycopg.org/psycopg3/docs/advanced/async.html#interrupting-async-operations)
allows us to remove the `set_wait_callback` approach to waiting for
long-running operations
- This is a major breaking change, but it should be mostly transparent, except
for a need to update custom parameter binding code.

## [3.5.3] - 2024-12-23

- Fix errors and add tests for `run_sql` changes
Expand Down
21 changes: 9 additions & 12 deletions database/macrostrat/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
from pathlib import Path
from typing import Optional, Union

from psycopg2.errors import InvalidSavepointSpecification
from psycopg2.sql import Identifier
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect
from sqlalchemy.exc import IntegrityError, InternalError
from psycopg.errors import InvalidSavepointSpecification
from psycopg.sql import Identifier
from sqlalchemy import URL, Engine, MetaData, inspect
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy.sql.expression import Insert

from macrostrat.utils import get_logger

from .mapper import DatabaseMapper
from .postgresql import on_conflict, prefix_inserts # noqa
from .utils import ( # noqa
Expand All @@ -25,6 +24,7 @@
run_fixtures,
run_query,
run_sql,
create_engine,
)

metadata = MetaData()
Expand Down Expand Up @@ -60,12 +60,8 @@ def __init__(self, db_conn: Union[str, URL, Engine], *, echo_sql=False, **kwargs

self.instance_params = kwargs.pop("instance_params", {})

if isinstance(db_conn, Engine):
log.info(f"Set up database connection with engine {db_conn.url}")
self.engine = db_conn
else:
log.info(f"Setting up database connection with URL '{db_conn}'")
self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)
self.engine = create_engine(db_conn, echo=echo_sql, **kwargs)

self.metadata = kwargs.get("metadata", metadata)

# Scoped session for database
Expand Down Expand Up @@ -334,6 +330,7 @@ def savepoint(self, name=None, rollback="on-error", connection=None):

if connection is None:
connection = self.session.connection()

params = {"name": Identifier(name)}
run_query(connection, "SAVEPOINT {name}", params)
should_rollback = rollback == "always"
Expand All @@ -356,7 +353,7 @@ def _clear_savepoint(connection, name, rollback=True):
run_query(connection, "ROLLBACK TO SAVEPOINT {name}", params)
else:
run_query(connection, "RELEASE SAVEPOINT {name}", params)
except InternalError as err:
except OperationalError as err:
if isinstance(err.orig, InvalidSavepointSpecification):
log.warning(
f"Savepoint {name} does not exist; we may have already rolled back."
Expand Down
9 changes: 5 additions & 4 deletions database/macrostrat/database/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING

import psycopg2
from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import CompileError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert, text
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.expression import text

if TYPE_CHECKING:
from ..database import Database
Expand All @@ -26,9 +26,10 @@ def on_conflict(action="restrict"):
_insert_mode.reset(token)


# @compiles(Insert, "postgresql")
@compiles(Insert, "postgresql")
def prefix_inserts(insert, compiler, **kw):
"""Conditionally adapt insert statements to use on-conflict resolution (a PostgreSQL feature)"""

if insert._post_values_clause is not None:
return compiler.visit_insert(insert, **kw)

Expand Down Expand Up @@ -63,7 +64,7 @@ def prefix_inserts(insert, compiler, **kw):
def table_exists(db: Database, table_name: str, schema: str = "public") -> bool:
"""Check if a table exists in a PostgreSQL database."""
sql = """SELECT EXISTS (
SELECT FROM information_schema.tables
SELECT FROM information_schema.tables
WHERE table_schema = :schema
AND table_name = :table_name
);"""
Expand Down
64 changes: 46 additions & 18 deletions database/macrostrat/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
from pathlib import Path
from re import search
from sys import stderr
from time import sleep
from typing import IO, Union
from warnings import warn

import psycopg2.errors
from click import echo, secho
from psycopg2.extensions import set_wait_callback
from psycopg2.extras import wait_select
from psycopg2.sql import SQL, Composable, Composed
from psycopg.errors import QueryCanceled
from psycopg.sql import SQL, Composable, Composed
from rich.console import Console
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import MetaData, text
from sqlalchemy import create_engine as base_create_engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
IntegrityError,
InternalError,
Expand All @@ -29,6 +28,7 @@
from sqlalchemy_utils import create_database as _create_database
from sqlalchemy_utils import database_exists, drop_database
from sqlparse import format, split
from time import sleep

from macrostrat.utils import cmd, get_logger

Expand Down Expand Up @@ -196,6 +196,8 @@ def _get_cursor(connectable):
while hasattr(conn, "driver_connection") or hasattr(conn, "connection"):
if hasattr(conn, "driver_connection"):
conn = conn.driver_connection
elif conn.connection == conn:
break
else:
conn = conn.connection
if callable(conn):
Expand Down Expand Up @@ -296,7 +298,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
if pre_bind_params is not None:
if not isinstance(query, SQL):
query = SQL(query)
# Pre-bind the parameters using PsycoPG2
# Pre-bind the parameters using psycopg
query = query.format(**pre_bind_params)

if isinstance(query, (SQL, Composed)):
Expand Down Expand Up @@ -328,9 +330,6 @@ def _run_sql(connectable, sql, params=None, **kwargs):
)
continue

# This only does something for postgresql, but it's harmless to run it for other engines
set_wait_callback(wait_select)

try:
trans = connectable.begin()
except InvalidRequestError:
Expand Down Expand Up @@ -360,7 +359,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):

_print_error(sql_text, err, file=output_file)
finally:
set_wait_callback(None)
pass


def _should_raise_query_error(err):
Expand All @@ -379,7 +378,7 @@ def _should_raise_query_error(err):
# database backends.
# Ideally we could handle operational errors more gracefully
if (
isinstance(orig_err, psycopg2.errors.QueryCanceled)
isinstance(orig_err, QueryCanceled)
or getattr(orig_err, "pgcode", None) == "57014"
):
return True
Expand Down Expand Up @@ -602,6 +601,26 @@ def create_database(url, **kwargs):
_create_database(url, **kwargs)


def create_engine(db_conn, **kwargs):
if isinstance(db_conn, Engine):
log.info(f"Set up database connection with engine {db_conn.url}")
if db_conn.driver == "psycopg2":
log.warning(
"The psycopg2 driver is deprecated. Please use psycopg3 instead."
)
return db_conn
else:
log.info(f"Setting up database connection with URL '{db_conn}'")
url = db_conn
if isinstance(url, str):
url = make_url(url)
# Set the driver to psycopg if not already set
if url.drivername != "postgresql+psycopg":
url = url.set(drivername="postgresql+psycopg")

return base_create_engine(url, **kwargs)


def connection_args(engine):
"""Get PostgreSQL connection arguments for an engine"""
_psql_flags = {"-U": "username", "-h": "host", "-p": "port", "-P": "password"}
Expand All @@ -617,15 +636,24 @@ def connection_args(engine):
return flags, engine.url.database


def db_isready(engine_or_url):
args, _ = connection_args(engine_or_url)
c = cmd("pg_isready", args, capture_output=True)
return c.returncode == 0
def db_isready(engine_or_url, use_shell_command=False):
if use_shell_command:
args, _ = connection_args(engine_or_url)
c = cmd("pg_isready", args, capture_output=True)
return c.returncode == 0
# Use a more typical sqlalchemy connection approach
engine = create_engine(engine_or_url)
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
except OperationalError:
return False


def wait_for_database(engine_or_url, quiet=False):
def wait_for_database(engine_or_url, *, quiet=False, use_shell_command=False):
msg = "Waiting for database..."
while not db_isready(engine_or_url):
while not db_isready(engine_or_url, use_shell_command=use_shell_command):
if not quiet:
echo(msg, err=True)
log.info(msg)
Expand Down
Loading