Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions database/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [3.5.1] - 2024-12-21

- Add a `statement_filter` parameter to the `run_sql` function to allow for
filtering of statements in a SQL file.
- Improve the consistency of the `Database.run_sql` function with the `run_sql`
utility function.

## [3.5.0] - 2024-11-25

- Add database transfer utilities for asynchronous `pg_load` and `pg_dump`
Expand Down
2 changes: 1 addition & 1 deletion database/macrostrat/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def run_sql(self, fn, params=None, **kwargs):
Returns: Iterator of results from the query.
"""
params = self._setup_params(params, kwargs)
return iter(run_sql(self.session, fn, params, **kwargs))
return run_sql(self.session, fn, params, **kwargs)

def run_query(self, sql, params=None, **kwargs):
"""Run a single query on the database object, returning the result.
Expand Down
25 changes: 21 additions & 4 deletions database/macrostrat/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import IO, Union
from warnings import warn

import psycopg2.errors
from click import echo, secho
from psycopg2.extensions import set_wait_callback
from psycopg2.extras import wait_select
from psycopg2.sql import SQL, Composable, Composed
import psycopg2.errors
from rich.console import Console
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy.engine import Connection, Engine
Expand All @@ -18,7 +18,7 @@
InternalError,
InvalidRequestError,
ProgrammingError,
OperationalError
OperationalError,
)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import Table
Expand Down Expand Up @@ -232,6 +232,9 @@ def infer_has_server_binds(sql):
return "%s" in sql or search(r"%\(\w+\)s", sql)


_default_statement_filter = lambda sql_text, params: True


def _run_sql(connectable, sql, params=None, **kwargs):
"""
Internal function for running a query on a SQLAlchemy connectable,
Expand All @@ -247,6 +250,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
raise_errors = kwargs.pop("raise_errors", False)
has_server_binds = kwargs.pop("has_server_binds", None)
ensure_single_query = kwargs.pop("ensure_single_query", False)
statement_filter = kwargs.pop("statement_filter", _default_statement_filter)

if stop_on_error:
raise_errors = True
Expand Down Expand Up @@ -288,6 +292,11 @@ def _run_sql(connectable, sql, params=None, **kwargs):
if has_server_binds is None:
has_server_binds = infer_has_server_binds(sql_text)

should_run = statement_filter(sql_text, params)
if not should_run:
pretty_print(sql_text, dim=True, strikethrough=True)
continue

# This only does something for postgresql, but it's harmless to run it for other engines
set_wait_callback(wait_select)

Expand Down Expand Up @@ -325,7 +334,9 @@ def _run_sql(connectable, sql, params=None, **kwargs):

def _should_raise_query_error(err):
"""Determine if an error should be raised for a query or not."""
if not isinstance(err, (ProgrammingError, IntegrityError, InternalError, OperationalError)):
if not isinstance(
err, (ProgrammingError, IntegrityError, InternalError, OperationalError)
):
return True

orig_err = getattr(err, "orig", None)
Expand All @@ -336,7 +347,10 @@ def _should_raise_query_error(err):
# We might want to change this behavior in the future, or support more graceful handling of errors from other
# database backends.
# Ideally we could handle operational errors more gracefully
if isinstance(orig_err, psycopg2.errors.QueryCanceled) or getattr(orig_err, "pgcode", None) == "57014":
if (
isinstance(orig_err, psycopg2.errors.QueryCanceled)
or getattr(orig_err, "pgcode", None) == "57014"
):
return True

return False
Expand Down Expand Up @@ -444,6 +458,9 @@ def run_sql(*args, **kwargs):
returning a list after completion.
ensure_single_query : bool
If True, raise an error if multiple queries are passed when only one is expected.
statement_filter : Callable
A function that takes a SQL statement and parameters and returns True if the statement
should be run, and False if it should be skipped.
"""
res = _run_sql(*args, **kwargs)
if kwargs.pop("yield_results", False):
Expand Down
2 changes: 1 addition & 1 deletion database/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ authors = ["Daven Quinn <dev@davenquinn.com>"]
description = "A SQLAlchemy-based database toolkit."
name = "macrostrat.database"
packages = [{ include = "macrostrat" }]
version = "3.5.0"
version = "3.5.1"

[tool.poetry.dependencies]
GeoAlchemy2 = "^0.15.2"
Expand Down
36 changes: 35 additions & 1 deletion database/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from psycopg2.errors import SyntaxError
from psycopg2.extensions import AsIs
from psycopg2.sql import SQL, Identifier, Literal, Placeholder
from pytest import fixture, mark, raises, warns
from pytest import fixture, raises, warns
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import text

Expand Down Expand Up @@ -106,6 +106,40 @@ def test_sql_text_inference_6():
assert infer_is_sql_text(insert_sample_query)


def test_sql_statement_filtering(db):
    """Check that ``statement_filter`` lets ``run_sql`` skip individual statements.

    The script holds an INSERT followed by a DELETE targeting the same name.
    With the DELETE filtered out, exactly one result and one sample row are
    expected.
    """

    def skip_deletes(statement, params):
        # A falsy return value marks the statement as one to skip.
        return not statement.startswith("DELETE")

    # The SQL lines are kept flush-left so the literal's contents are unchanged.
    script = """
INSERT INTO sample (name) VALUES (:name);

DELETE FROM sample WHERE name = :name;
"""

    assert infer_is_sql_text(script)

    with db.transaction(rollback="always"):
        # The table must start out empty
        assert _get_sample_count(db) == 0

        # Run the script with the DELETE statement filtered out
        results = db.run_sql(
            script,
            params={"name": "Test"},
            raise_errors=True,
            statement_filter=skip_deletes,
            yield_results=False,
        )

        assert len(results) == 1
        assert _get_sample_count(db) == 1


def _get_sample_count(db):
return db.run_query("SELECT count(*) FROM sample").scalar()


def test_sql_interpolation_psycopg(db):
db.run_sql(insert_sample_query, params=dict(name="Test"), raise_errors=True)
db.session.commit()
Expand Down
Loading