From a4a8f4b7882b3a00230047142a3b2b0feb826c89 Mon Sep 17 00:00:00 2001 From: Daven Quinn Date: Sat, 21 Dec 2024 18:55:06 -0600 Subject: [PATCH] Added a function to filter statements --- database/CHANGELOG.md | 7 +++++ database/macrostrat/database/__init__.py | 2 +- database/macrostrat/database/utils.py | 25 +++++++++++++--- database/pyproject.toml | 2 +- database/tests/test_database.py | 36 +++++++++++++++++++++++- 5 files changed, 65 insertions(+), 7 deletions(-) diff --git a/database/CHANGELOG.md b/database/CHANGELOG.md index 1231197..9ed4f72 100644 --- a/database/CHANGELOG.md +++ b/database/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [3.5.1] - 2024-12-21 + +- Add a `statement_filter` parameter to the `run_sql` function to allow for + filtering of statements in a SQL file. +- Improve the consistency of the `Database.run_sql` function with the `run_sql` + utility function. + ## [3.5.0] - 2024-11-25 - Add database transfer utilities for asynchronous `pg_load` and `pg_dump` diff --git a/database/macrostrat/database/__init__.py b/database/macrostrat/database/__init__.py index 877fb3d..9dc07ac 100644 --- a/database/macrostrat/database/__init__.py +++ b/database/macrostrat/database/__init__.py @@ -128,7 +128,7 @@ def run_sql(self, fn, params=None, **kwargs): Returns: Iterator of results from the query. """ params = self._setup_params(params, kwargs) - return iter(run_sql(self.session, fn, params, **kwargs)) + return run_sql(self.session, fn, params, **kwargs) def run_query(self, sql, params=None, **kwargs): """Run a single query on the database object, returning the result. 
diff --git a/database/macrostrat/database/utils.py b/database/macrostrat/database/utils.py index 9f745ac..4821c22 100644 --- a/database/macrostrat/database/utils.py +++ b/database/macrostrat/database/utils.py @@ -5,11 +5,11 @@ from typing import IO, Union from warnings import warn +import psycopg2.errors from click import echo, secho from psycopg2.extensions import set_wait_callback from psycopg2.extras import wait_select from psycopg2.sql import SQL, Composable, Composed -import psycopg2.errors from rich.console import Console from sqlalchemy import MetaData, create_engine, text from sqlalchemy.engine import Connection, Engine @@ -18,7 +18,7 @@ InternalError, InvalidRequestError, ProgrammingError, - OperationalError + OperationalError, ) from sqlalchemy.orm import sessionmaker from sqlalchemy.schema import Table @@ -232,6 +232,9 @@ def infer_has_server_binds(sql): return "%s" in sql or search(r"%\(\w+\)s", sql) +_default_statement_filter = lambda sql_text, params: True + + def _run_sql(connectable, sql, params=None, **kwargs): """ Internal function for running a query on a SQLAlchemy connectable, @@ -247,6 +250,7 @@ def _run_sql(connectable, sql, params=None, **kwargs): raise_errors = kwargs.pop("raise_errors", False) has_server_binds = kwargs.pop("has_server_binds", None) ensure_single_query = kwargs.pop("ensure_single_query", False) + statement_filter = kwargs.pop("statement_filter", _default_statement_filter) if stop_on_error: raise_errors = True @@ -288,6 +292,11 @@ def _run_sql(connectable, sql, params=None, **kwargs): if has_server_binds is None: has_server_binds = infer_has_server_binds(sql_text) + should_run = statement_filter(sql_text, params) + if not should_run: + pretty_print(sql_text, dim=True, strikethrough=True) + continue + # This only does something for postgresql, but it's harmless to run it for other engines set_wait_callback(wait_select) @@ -325,7 +334,9 @@ def _run_sql(connectable, sql, params=None, **kwargs): def 
_should_raise_query_error(err): """Determine if an error should be raised for a query or not.""" - if not isinstance(err, (ProgrammingError, IntegrityError, InternalError, OperationalError)): + if not isinstance( + err, (ProgrammingError, IntegrityError, InternalError, OperationalError) + ): return True orig_err = getattr(err, "orig", None) @@ -336,7 +347,10 @@ def _should_raise_query_error(err): # We might want to change this behavior in the future, or support more graceful handling of errors from other # database backends. # Ideally we could handle operational errors more gracefully - if isinstance(orig_err, psycopg2.errors.QueryCanceled) or getattr(orig_err, "pgcode", None) == "57014": + if ( + isinstance(orig_err, psycopg2.errors.QueryCanceled) + or getattr(orig_err, "pgcode", None) == "57014" + ): return True return False @@ -444,6 +458,9 @@ def run_sql(*args, **kwargs): returning a list after completion. ensure_single_query : bool If True, raise an error if multiple queries are passed when only one is expected. + statement_filter : Callable + A function that takes a SQL statement and parameters and returns True if the statement + should be run, and False if it should be skipped. """ res = _run_sql(*args, **kwargs) if kwargs.pop("yield_results", False): diff --git a/database/pyproject.toml b/database/pyproject.toml index a0565e0..fd68f06 100644 --- a/database/pyproject.toml +++ b/database/pyproject.toml @@ -3,7 +3,7 @@ authors = ["Daven Quinn "] description = "A SQLAlchemy-based database toolkit." 
name = "macrostrat.database" packages = [{ include = "macrostrat" }] -version = "3.5.0" +version = "3.5.1" [tool.poetry.dependencies] GeoAlchemy2 = "^0.15.2" diff --git a/database/tests/test_database.py b/database/tests/test_database.py index f30d564..37d88c0 100644 --- a/database/tests/test_database.py +++ b/database/tests/test_database.py @@ -11,7 +11,7 @@ from psycopg2.errors import SyntaxError from psycopg2.extensions import AsIs from psycopg2.sql import SQL, Identifier, Literal, Placeholder -from pytest import fixture, mark, raises, warns +from pytest import fixture, raises, warns from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql import text @@ -106,6 +106,40 @@ def test_sql_text_inference_6(): assert infer_is_sql_text(insert_sample_query) +def test_sql_statement_filtering(db): + sql = """ + INSERT INTO sample (name) VALUES (:name); + + DELETE FROM sample WHERE name = :name; + """ + + assert infer_is_sql_text(sql) + + with db.transaction(rollback="always"): + # Make sure there are no samples + assert _get_sample_count(db) == 0 + + # Run the SQL, filtering out the DELETE statement + + def filter_func(statement, params): + return not statement.startswith("DELETE") + + res = db.run_sql( + sql, + params=dict(name="Test"), + raise_errors=True, + statement_filter=filter_func, + yield_results=False, + ) + + assert len(res) == 1 + assert _get_sample_count(db) == 1 + + +def _get_sample_count(db): + return db.run_query("SELECT count(*) FROM sample").scalar() + + def test_sql_interpolation_psycopg(db): db.run_sql(insert_sample_query, params=dict(name="Test"), raise_errors=True) db.session.commit()