diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 00000000..b94b4507 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,73 @@ +name: "Test" + +on: + push: + paths-ignore: + - "docs/**" + pull_request: + paths-ignore: + - "docs/**" + schedule: + - cron: '40 1 * * 3' + + +jobs: + test: + name: test-python${{ matrix.python-version }}-sa${{ matrix.sqlalchemy-version }}-${{ matrix.db-engine }} + strategy: + matrix: + python-version: +# - "2.7" +# - "3.4" +# - "3.5" +# - "3.6" +# - "3.7" + - "3.8" +# - "3.9" +# - "3.10" +# - "pypy-3.7" + sqlalchemy-version: + - "<1.4" + - ">=1.4" + db-engine: + - sqlite + - postgres + - postgres-native + - mysql + runs-on: ubuntu-latest + services: + mysql: + image: mysql + ports: + - 3306:3306 + env: + MYSQL_DATABASE: sqlalchemy_continuum_test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + options: >- + --health-cmd "mysqladmin ping" + --health-interval 5s + --health-timeout 2s + --health-retries 3 + postgres: + image: postgres + ports: + - 5432:5432 + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: sqlalchemy_continuum_test + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 2s + --health-retries 3 + steps: + - uses: actions/checkout@v1 + - name: Install sqlalchemy + run: pip3 install 'sqlalchemy${{ matrix.sqlalchemy-version }}' + - name: Build + run: pip3 install -e '.[test]' + - name: Run tests + run: pytest + env: + DB: ${{ matrix.db-engine }} + diff --git a/.gitignore b/.gitignore index a015a2d6..fe795a50 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,9 @@ nosetests.xml .mr.developer.cfg .project .pydevproject + +# mypy +.mypy_cache/ + +# Unit test / coverage reports +.cache diff --git a/.travis.yml b/.travis.yml index 6ee80dba..035bf151 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,9 @@ -addons: - postgresql: 9.3 +services: + - mysql + - postgresql + +dist: xenial +sudo: true env: - DB=mysql @@ -15,9 +19,10 @@ before_script: language: python python: - 2.7 - - 3.3 - 3.4 - 3.5 + - 3.6 + - 3.7 install: - pip install -e ".[test]" script: diff --git a/CHANGES.rst b/CHANGES.rst index c9db766a..8499b4f9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,83 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Continuum release. +1.3.13 (2022-09-07) +^^^^^^^^^^^^^^^^^^^ + +- Fixes for Flask 2.2 and Flask-Login 0.6.2 (#288, thanks to AbdealiJK) +- Allow changed_entities to work without TransactionChanges plugin (#268, thanks to TomGoBravo) +- Fix Activity plugin for non-composite primary keys not named id (#210, thanks to dryobates) +- Allow sync_trigger to pass arguments through to create_trigger (#273, thanks to nanvel) +- Fix association tables on Oracle (#291, thanks to AbdealiJK) +- Fix some deprecation warnings in SA 1.4 (#269, #277, #279, #300, #302, thanks to TomGoBravo, edhaz, and indiVar0508) + +1.3.12 (2022-01-18) +^^^^^^^^^^^^^^^^^^^ + +- Support SA 1.4 + +1.3.11 (2020-05-24) +^^^^^^^^^^^^^^^^^^^ + +- Made ModelBuilder create column aliases in version models (#246, courtesy of killthekitten) + + +1.3.10 (2020-05-10) +^^^^^^^^^^^^^^^^^^^ + +- Added explicit "pseudo-backref" relationships for version/parent (#240, courtesy of lgedgar) +- Fixed m2m Bug when an unrelated change is made to a model (#242, courtesy of Andrew-Dickinson) + + +1.3.9 (2019-03-19) +^^^^^^^^^^^^^^^^^^ + +- Added SA 1.3 support +- Reverted trigger creation from 1.3.7 + + +1.3.8 (2019-02-27) +^^^^^^^^^^^^^^^^^^ + +- Fixed revert to ignore non-columns (#197, courtesy of mauler) + + +1.3.7 (2019-01-13) +^^^^^^^^^^^^^^^^^^ + +- Fix trigger creation during alembic migrations (#209, courtesy of lyndsysimon) + + +1.3.6 (2018-07-30) +^^^^^^^^^^^^^^^^^^ + +- Fixed ResourceClosedErrors from connections leaking when using an external transaction (#196, courtesy of vault) + + +1.3.5 (2018-06-03) +^^^^^^^^^^^^^^^^^^ + +- Track cloned connections (#167, courtesy of netcriptus) + + +1.3.4 (2018-03-07) +^^^^^^^^^^^^^^^^^^ + +- Exclude many-to-many properties from versioning if they are added in exclude parameter (#169, courtesy of fuhrysteve) + + +1.3.3 (2017-11-05) +^^^^^^^^^^^^^^^^^^ + +- Fixed changeset when updating object in same transaction as inserting it (#141, courtesy of oinopion) + + +1.3.2 (2017-10-12) +^^^^^^^^^^^^^^^^^^ + +- Fixed multiple schema handling (#132, courtesy of vault) + + 1.3.1 (2017-06-28) ^^^^^^^^^^^^^^^^^^ diff --git a/LICENSE b/LICENSE index d604ce84..cccc1d8f 100644 --- a/LICENSE +++ b/LICENSE @@ -12,8 +12,9 @@ modification, are permitted provided that the following conditions are met: this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -* The names of the contributors may not be used to endorse or promote products - derived from this software without specific prior written permission. +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED diff --git a/README.rst b/README.rst index b3dec67c..367a16a7 100644 --- a/README.rst +++ b/README.rst @@ -41,7 +41,7 @@ In order to make your models versioned you need two things: from sqlalchemy_continuum import make_versioned - make_versioned() + make_versioned(user_cls=None) class Article(Base): @@ -53,27 +53,60 @@ In order to make your models versioned you need two things: content = sa.Column(sa.UnicodeText) - article = Article(name=u'Some article', content=u'Some content') + article = Article(name='Some article', content='Some content') session.add(article) session.commit() # article has now one version stored in database article.versions[0].name - # u'Some article' + # 'Some article' - article.name = u'Updated name' + article.name = 'Updated name' session.commit() article.versions[1].name - # u'Updated name' + # 'Updated name' # lets revert back to first version article.versions[0].revert() article.name - # u'Some article' + # 'Some article' + +For completeness, below is a working example. + +.. code-block:: python + + from sqlalchemy_continuum import make_versioned + from sqlalchemy import Column, Integer, Unicode, UnicodeText, create_engine + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import create_session, configure_mappers + + make_versioned(user_cls=None) + + Base = declarative_base() + class Article(Base): + __versioned__ = {} + __tablename__ = 'article' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(Unicode(255)) + content = Column(UnicodeText) + + configure_mappers() + engine = create_engine('sqlite://') + Base.metadata.create_all(engine) + session = create_session(bind=engine, autocommit=False) + article = Article(name=u'Some article', content=u'Some content') + session.add(article) + session.commit() + article.versions[0].name + article.name = u'Updated name' + session.commit() + article.versions[1].name + article.versions[0].revert() + article.name Resources --------- @@ -86,8 +119,8 @@ Resources .. image:: http://i.imgur.com/UFaRx.gif -.. |Build Status| image:: https://travis-ci.org/kvesteri/sqlalchemy-continuum.png?branch=master - :target: https://travis-ci.org/kvesteri/sqlalchemy-continuum +.. |Build Status| image:: https://github.com/kvesteri/sqlalchemy-continuum/workflows/Test/badge.svg + :target: https://github.com/kvesteri/sqlalchemy-continuum/actions?query=workflow%3ATest .. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Continuum.png :target: https://pypi.python.org/pypi/SQLAlchemy-Continuum/ .. |Downloads| image:: https://img.shields.io/pypi/dm/SQLAlchemy-Continuum.png diff --git a/benchmark.py b/benchmark.py index 9ee0250f..55c6e41c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, close_all_sessions from sqlalchemy_continuum import ( make_versioned, versioning_manager, @@ -50,7 +50,7 @@ def test_versioning( make_versioned(options=options) - dns = 'postgres://postgres@localhost/sqlalchemy_continuum_test' + dns = 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test' versioning_manager.plugins = plugins versioning_manager.transaction_cls = transaction_cls versioning_manager.user_cls = user_cls @@ -106,7 +106,7 @@ class Tag(Model): remove_versioning() versioning_manager.reset() - session.close_all() + close_all_sessions() session.expunge_all() Model.metadata.drop_all(connection) engine.dispose() diff --git a/docs/alembic.rst b/docs/alembic.rst index 562ebb34..d46a4700 100644 --- a/docs/alembic.rst +++ b/docs/alembic.rst @@ -1,6 +1,11 @@ Alembic migrations ================== -Each time you make changes to database structure you should also change the associated history tables. When you make changes to your models SQLAlchemy-Continuum automatically alters the history model definitions, hence you can use `alembic revision --autogenerate` just like before. You just need to make sure `make_versioned` function gets called before alembic gathers all your models. +Each time you make changes to database structure you should also change the associated history tables. When you make changes to your models SQLAlchemy-Continuum automatically alters the history model definitions, hence you can use `alembic revision --autogenerate` just like before. You just need to make sure `make_versioned` function gets called before alembic gathers all your models and `configure_mappers` is called afterwards. Pay close attention when dropping or moving data from parent tables and reflecting these changes to history tables. + +Troubleshooting +############### + +If alembic didn't detect any changes or generates reversed migration (tries to remove `*_version` tables from database instead of creating), make sure that `configure_mappers` was called by alembic command. diff --git a/docs/intro.rst b/docs/intro.rst index e6002b87..8b276994 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -5,7 +5,7 @@ Introduction Why? ^^^^ -SQLAlchemy already has versioning extension. This extension however is very limited. It does not support versioning entire transactions. +SQLAlchemy `already has a versioning extension `_. This extension however is very limited. It does not support versioning entire transactions. Hibernate for Java has Envers, which had nice features but lacks a nice API. Ruby on Rails has papertrail_, which has very nice API but lacks the efficiency and feature set of Envers. @@ -54,7 +54,7 @@ In order to make your models versioned you need two things: from sqlalchemy_continuum import make_versioned - make_versioned() + make_versioned(user_cls=None) class Article(Base): diff --git a/docs/native_versioning.rst b/docs/native_versioning.rst index 60e215fe..4bc63c96 100644 --- a/docs/native_versioning.rst +++ b/docs/native_versioning.rst @@ -30,3 +30,9 @@ When making schema migrations (for example adding new columns to version tables) sync_trigger(conn, 'article_version') + +If you don't use `PropertyModTrackerPlugin`, then you have to disable it: + +:: + + sync_trigger(conn, 'article_version', use_property_mod_tracking=False) diff --git a/docs/plugins.rst b/docs/plugins.rst index 5f88b6fe..09c1e261 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -7,7 +7,7 @@ Using plugins :: - from sqlalchemy.continuum.plugins import PropertyModTrackerPlugin + from sqlalchemy_continuum.plugins import PropertyModTrackerPlugin versioning_manager.plugins.append(PropertyModTrackerPlugin()) diff --git a/docs/version_objects.rst b/docs/version_objects.rst index ecbb71ce..54f0b774 100644 --- a/docs/version_objects.rst +++ b/docs/version_objects.rst @@ -102,7 +102,7 @@ you can easily check the changeset of given object in current transaction. article = Article(name=u'Some article') changeset(article) - # {'name': [u'Some article', None]} + # {'name': [None, u'Some article']} Version relationships diff --git a/setup.py b/setup.py index 439b821d..a0dfb651 100644 --- a/setup.py +++ b/setup.py @@ -28,15 +28,14 @@ def get_version(): 'pytest>=2.3.5', 'flexmock>=0.9.7', 'psycopg2>=2.4.6', - 'PyMySQL==0.6.1', + 'PyMySQL>=0.8.0', 'six>=1.4.0' ], - 'anyjson': ['anyjson>=0.3.3'], 'flask': ['Flask>=0.9'], 'flask-login': ['Flask-Login>=0.2.9'], - 'flask-sqlalchemy': ['Flask-SQLAlchemy>=1.0'], + 'flask-sqlalchemy': ['Flask-SQLAlchemy>=1.0,<3.0.0'], 'flexmock': ['flexmock>=0.9.7'], - 'i18n': ['SQLAlchemy-i18n>=0.8.4'], + 'i18n': ['SQLAlchemy-i18n>=0.8.4,!=1.1.0'], } @@ -77,9 +76,10 @@ def get_version(): 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Software Development :: Libraries :: Python Modules' ] diff --git a/sqlalchemy_continuum/__init__.py b/sqlalchemy_continuum/__init__.py index bc32d298..61081473 100644 --- a/sqlalchemy_continuum/__init__.py +++ b/sqlalchemy_continuum/__init__.py @@ -18,7 +18,7 @@ ) -__version__ = '1.3.1' +__version__ = '1.3.13' versioning_manager = VersioningManager() @@ -70,6 +70,18 @@ def make_versioned( manager.track_association_operations ) + sa.event.listen( + sa.engine.Engine, + 'rollback', + manager.clear_connection + ) + + sa.event.listen( + sa.engine.Engine, + 'set_connection_execution_options', + manager.track_cloned_connections + ) + def remove_versioning( mapper=sa.orm.mapper, @@ -96,3 +108,15 @@ def remove_versioning( 'before_cursor_execute', manager.track_association_operations ) + + sa.event.remove( + sa.engine.Engine, + 'rollback', + manager.clear_connection + ) + + sa.event.remove( + sa.engine.Engine, + 'set_connection_execution_options', + manager.track_cloned_connections + ) diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py index 47bbef35..53134029 100644 --- a/sqlalchemy_continuum/builder.py +++ b/sqlalchemy_continuum/builder.py @@ -1,5 +1,6 @@ from copy import copy from inspect import getmro +from functools import wraps import sqlalchemy as sa from sqlalchemy_utils.functions import get_declarative_base @@ -10,6 +11,18 @@ from .table_builder import TableBuilder +def prevent_reentry(handler): + in_handler = False + @wraps(handler) + def check_reentry(*args, **kwargs): + nonlocal in_handler + if in_handler: + return + in_handler = True + handler(*args, **kwargs) + in_handler = False + return check_reentry + class Builder(object): def build_triggers(self): """ @@ -141,17 +154,20 @@ def build_transaction_class(self): self.manager.create_transaction_model() self.manager.plugins.after_build_tx_class(self.manager) + @prevent_reentry def configure_versioned_classes(self): """ Configures all versioned classes that were collected during - instrumentation process. The configuration has 4 steps: + instrumentation process. The configuration has 6 steps: 1. Build tables for version models. 2. Build the actual version model declarative classes. 3. Build relationships between these models. 4. Empty pending_classes list so that consecutive mapper configuration does not create multiple version classes - 5. Assign all versioned attributes to use active history. + 5. Build aliases for columns. + 6. Assign all versioned attributes to use active history. + """ if not self.manager.options['versioning']: return @@ -168,11 +184,39 @@ def configure_versioned_classes(self): # Create copy of all pending versioned classes so that we can inspect # them later when creating relationships. - pending_copy = copy(self.manager.pending_classes) + pending_classes_copies = copy(self.manager.pending_classes) self.manager.pending_classes = [] - self.build_relationships(pending_copy) + self.build_relationships(pending_classes_copies) + self.enable_active_history(pending_classes_copies) + self.create_column_aliases(pending_classes_copies) - for cls in pending_copy: - # set the "active_history" flag + def enable_active_history(self, version_classes): + """ + Assign all versioned attributes to use active history. + """ + for cls in version_classes: for prop in sa.inspect(cls).iterate_properties: getattr(cls, prop.key).impl.active_history = True + + def create_column_aliases(self, version_classes): + """ + Create aliases for the columns from the original model. + + This, for example, imitates the behavior of @declared_attr columns. + """ + for cls in version_classes: + model_mapper = sa.inspect(cls) + version_class = self.manager.version_class_map.get(cls) + if not version_class: + continue + + version_class_mapper = sa.inspect(version_class) + + for key, column in model_mapper.columns.items(): + if key != column.key: + version_class_column = version_class.__table__.c.get(column.key) + + if version_class_column is None: + continue + + version_class_mapper.add_property(key, sa.orm.column_property(version_class_column)) diff --git a/sqlalchemy_continuum/dialects/postgresql.py b/sqlalchemy_continuum/dialects/postgresql.py index f24d9077..c69a64be 100644 --- a/sqlalchemy_continuum/dialects/postgresql.py +++ b/sqlalchemy_continuum/dialects/postgresql.py @@ -456,7 +456,7 @@ def create_versioning_trigger_listeners(manager, cls): ) -def sync_trigger(conn, table_name): +def sync_trigger(conn, table_name, **kwargs): """ Synchronizes versioning trigger for given table with given connection. @@ -468,6 +468,7 @@ def sync_trigger(conn, table_name): :param conn: SQLAlchemy connection object :param table_name: Name of the table to synchronize versioning trigger for + :params **kwargs: kwargs to pass to create_trigger .. versionadded: 1.1.0 """ @@ -489,7 +490,7 @@ def sync_trigger(conn, table_name): set(c.name for c in version_table.c if not c.name.endswith('_mod')) ) drop_trigger(conn, parent_table.name) - create_trigger(conn, table=parent_table, excluded_columns=excluded_columns) + create_trigger(conn, table=parent_table, excluded_columns=excluded_columns, **kwargs) def create_trigger( diff --git a/sqlalchemy_continuum/factory.py b/sqlalchemy_continuum/factory.py index 5e36dc81..9951f67a 100644 --- a/sqlalchemy_continuum/factory.py +++ b/sqlalchemy_continuum/factory.py @@ -6,7 +6,11 @@ def __call__(self, manager): Create model class but only if it doesn't already exist in declarative model registry. """ - registry = manager.declarative_base._decl_class_registry + Base = manager.declarative_base + try: + registry = Base.registry._class_registry + except AttributeError: # SQLAlchemy < 1.4 + registry = Base._decl_class_registry if self.model_name not in registry: return self.create_class(manager) return registry[self.model_name] diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py index 1ac1a175..a1f2684d 100644 --- a/sqlalchemy_continuum/fetcher.py +++ b/sqlalchemy_continuum/fetcher.py @@ -59,7 +59,7 @@ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None): func = sa.func.max if alias is None: - alias = sa.orm.aliased(obj) + alias = sa.orm.aliased(obj.__class__) table = alias.__table__ if hasattr(alias, 'c'): attrs = alias.c @@ -90,11 +90,22 @@ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None): ) .correlate(table) ) - return query + try: + return query.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + return query.as_scalar() def _next_prev_query(self, obj, next_or_prev='next'): session = sa.orm.object_session(obj) + subquery = self._transaction_id_subquery( + obj, next_or_prev=next_or_prev + ) + try: + subquery = subquery.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + subquery = subquery.as_scalar() + return ( session.query(obj.__class__) .filter( @@ -102,11 +113,7 @@ def _next_prev_query(self, obj, next_or_prev='next'): getattr( obj.__class__, tx_column_name(obj) - ) - == - self._transaction_id_subquery( - obj, next_or_prev=next_or_prev - ), + ) == subquery, *parent_criteria(obj) ) ) @@ -117,7 +124,7 @@ def _index_query(self, obj): Returns the query needed for fetching the index of this record relative to version history. """ - alias = sa.orm.aliased(obj) + alias = sa.orm.aliased(obj.__class__) subquery = ( sa.select([sa.func.count('1')], from_obj=[alias.__table__]) diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index a452153b..7bb5f731 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -24,7 +24,15 @@ def wrapper(self, mapper, connection, target): try: uow = self.units_of_work[conn] except KeyError: - uow = self.units_of_work[conn.engine] + try: + uow = self.units_of_work[conn.engine] + except KeyError: + for connection in self.units_of_work.keys(): + if not connection.closed and connection.connection is conn.connection: + uow = self.unit_of_work(session) + break # The ConnectionFairy is the same, this connection is a clone + else: + raise return func(self, uow, target) return wrapper @@ -357,27 +365,71 @@ def clear(self, session): if session.transaction.nested: return conn = self.session_connection_map.pop(session, None) + if conn is None: + return + if conn in self.units_of_work: uow = self.units_of_work[conn] uow.reset(session) del self.units_of_work[conn] + for connection in dict(self.units_of_work).keys(): + if connection.closed or conn.connection is connection.connection: + uow = self.units_of_work[connection] + uow.reset(session) + del self.units_of_work[connection] + + def clear_connection(self, conn): + if conn in self.units_of_work: + uow = self.units_of_work[conn] + uow.reset() + del self.units_of_work[conn] + + + for session, connection in dict(self.session_connection_map).items(): + if connection is conn: + del self.session_connection_map[session] + + + for connection in dict(self.units_of_work).keys(): + if connection.closed or conn.connection is connection.connection: + uow = self.units_of_work[connection] + uow.reset() + del self.units_of_work[connection] + + def append_association_operation(self, conn, table_name, params, op): """ Append history association operation to pending_statements list. """ - params['operation_type'] = op stmt = ( self.metadata.tables[self.options['table_name'] % table_name] .insert() - .values(params) + .values({**params, 'operation_type': op}) ) try: uow = self.units_of_work[conn] except KeyError: - uow = self.units_of_work[conn.engine] + try: + uow = self.units_of_work[conn.engine] + except KeyError: + for connection in self.units_of_work.keys(): + if not connection.closed and connection.connection is conn.connection: + uow = self.unit_of_work(conn.session) + break # The ConnectionFairy is the same, this connection is a clone + else: + raise uow.pending_statements.append(stmt) + def track_cloned_connections(self, c, opt): + """ + Track cloned connections from association tables. + """ + if c not in self.units_of_work.keys(): + for connection, uow in dict(self.units_of_work).items(): + if not connection.closed and connection.connection is c.connection: # ConnectionFairy is the same - this is a clone + self.units_of_work[c] = uow + def track_association_operations( self, conn, cursor, statement, parameters, context, executemany ): @@ -400,7 +452,8 @@ def track_association_operations( if op is not None: table_name = statement.split(' ')[2] table_names = [ - table.name for table in self.association_tables + table.name if not table.schema else table.schema + '.' + table.name + for table in self.association_tables ] if table_name in table_names: if executemany: diff --git a/sqlalchemy_continuum/model_builder.py b/sqlalchemy_continuum/model_builder.py index 2be6e63b..b2f15114 100644 --- a/sqlalchemy_continuum/model_builder.py +++ b/sqlalchemy_continuum/model_builder.py @@ -107,6 +107,7 @@ class represents). """ conditions = [] foreign_keys = [] + model_keys = [] for key, column in sa.inspect(self.model).columns.items(): if column.primary_key: conditions.append( @@ -117,6 +118,9 @@ class represents). foreign_keys.append( getattr(self.version_class, key) ) + model_keys.append( + getattr(self.model, key) + ) # We need to check if versions relation was already set for parent # class. @@ -130,11 +134,18 @@ class represents). option(self.model, 'transaction_column_name') ), lazy='dynamic', - backref=sa.orm.backref( - 'version_parent' - ), viewonly=True ) + # We must explicitly declare this relationship, instead of + # specifying as a backref to the one above, since they are + # viewonly=True and SQLAlchemy will warn if using backref. + self.version_class.version_parent = sa.orm.relationship( + self.model, + primaryjoin=sa.and_(*conditions), + foreign_keys=model_keys, + viewonly=True, + uselist=False, + ) def build_transaction_relationship(self, tx_class): """ @@ -261,6 +272,7 @@ def mapper_args(cls): name = '%sVersion' % (self.model.__name__,) return type(name, self.base_classes(), args) + def __call__(self, table, tx_class): """ Build history model and relationships to parent model, transaction diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py index 8905079a..10b85d3f 100644 --- a/sqlalchemy_continuum/plugins/activity.py +++ b/sqlalchemy_continuum/plugins/activity.py @@ -191,6 +191,7 @@ import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.inspection import inspect from sqlalchemy_utils import JSONType, generic_relationship from .base import Plugin @@ -254,11 +255,13 @@ def _calculate_tx_id(self, obj): if object_version: return object_version.transaction_id - version_cls = version_class(obj.__class__) + model = obj.__class__ + version_cls = version_class(model) + primary_key = inspect(model).primary_key[0].name return session.query( sa.func.max(version_cls.transaction_id) ).filter( - version_cls.id == obj.id + getattr(version_cls, primary_key) == getattr(obj, primary_key) ).scalar() def calculate_object_tx_id(self): @@ -314,6 +317,8 @@ def target_version_type(cls): class ActivityPlugin(Plugin): + activity_cls = None + def after_build_models(self, manager): self.activity_cls = ActivityFactory()(manager) manager.activity_cls = self.activity_cls diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py index c7b14254..d18a9b8b 100644 --- a/sqlalchemy_continuum/plugins/flask.py +++ b/sqlalchemy_continuum/plugins/flask.py @@ -36,7 +36,7 @@ def fetch_current_user_id(): if _app_ctx_stack.top is None or _request_ctx_stack.top is None: return try: - return current_user.id + return current_user.get_id() except AttributeError: return diff --git a/sqlalchemy_continuum/relationship_builder.py b/sqlalchemy_continuum/relationship_builder.py index f6b114e7..032e09b1 100644 --- a/sqlalchemy_continuum/relationship_builder.py +++ b/sqlalchemy_continuum/relationship_builder.py @@ -47,18 +47,21 @@ def one_to_many_subquery(self, obj): def many_to_one_subquery(self, obj): tx_column = option(obj, 'transaction_column_name') reflector = VersionExpressionReflector(obj, self.property) - - return getattr(self.remote_cls, tx_column) == ( - sa.select( - [sa.func.max(getattr(self.remote_cls, tx_column))] - ).where( - sa.and_( - getattr(self.remote_cls, tx_column) <= - getattr(obj, tx_column), - reflector(self.property.primaryjoin) - ) + subquery = sa.select( + [sa.func.max(getattr(self.remote_cls, tx_column))] + ).where( + sa.and_( + getattr(self.remote_cls, tx_column) <= + getattr(obj, tx_column), + reflector(self.property.primaryjoin) ) ) + try: + subquery = subquery.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + subquery = subquery.as_scalar() + + return getattr(self.remote_cls, tx_column) == subquery def query(self, obj): session = sa.orm.object_session(obj) @@ -249,6 +252,7 @@ def association_subquery(self, obj): FROM article_tag_version as article_tag_version2 WHERE article_tag_version2.tag_id = article_tag_version.tag_id AND article_tag_version2.tx_id <=5 + AND article_tag_version2.article_id = 3 GROUP BY article_tag_version2.tag_id HAVING MAX(article_tag_version2.tx_id) = @@ -260,6 +264,8 @@ def association_subquery(self, obj): """ tx_column = option(obj, 'transaction_column_name') + join_column = self.property.primaryjoin.right.name + object_join_column = self.property.primaryjoin.left.name reflector = VersionExpressionReflector(obj, self.property) association_table_alias = self.association_version_table.alias() @@ -276,6 +282,7 @@ def association_subquery(self, obj): sa.and_( association_table_alias.c[tx_column] <= getattr(obj, tx_column), + association_table_alias.c[join_column] == getattr(obj, object_join_column), *[association_col == self.association_version_table.c[association_col.name] for association_col @@ -316,7 +323,9 @@ def build_association_version_tables(self): column.table ) metadata = column.table.metadata - if metadata.schema: + if builder.parent_table.schema: + table_name = builder.parent_table.schema + '.' + builder.table_name + elif metadata.schema: table_name = metadata.schema + '.' + builder.table_name else: table_name = builder.table_name @@ -344,13 +353,16 @@ def __call__(self): except ClassNotVersioned: self.remote_cls = self.property.mapper.class_ - if self.property.secondary is not None and not self.property.viewonly: + if (self.property.secondary is not None and + not self.property.viewonly and + not self.manager.is_excluded_property( + self.model, self.property.key)): self.build_association_version_tables() # store remote cls to association table column pairs self.remote_to_association_column_pairs = [] for column_pair in self.property.local_remote_pairs: - if column_pair[0] in self.property.table.c.values(): + if column_pair[0] in self.property.target.c.values(): self.remote_to_association_column_pairs.append(column_pair) setattr( diff --git a/sqlalchemy_continuum/schema.py b/sqlalchemy_continuum/schema.py index 659df1b6..83728ef8 100644 --- a/sqlalchemy_continuum/schema.py +++ b/sqlalchemy_continuum/schema.py @@ -25,6 +25,10 @@ def get_end_tx_column_query( ] ) ) + try: + tx_criterion = tx_criterion.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + tx_criterion = tx_criterion.as_scalar() return sa.select( columns=[ getattr(v1.c, column) diff --git a/sqlalchemy_continuum/table_builder.py b/sqlalchemy_continuum/table_builder.py index 8ac0de8f..600d1e02 100644 --- a/sqlalchemy_continuum/table_builder.py +++ b/sqlalchemy_continuum/table_builder.py @@ -20,12 +20,9 @@ def reflect_column(self, column): :param column: SQLAlchemy Column object of parent table """ - # Make a copy of the column so that it does not point to wrong - # table. - column_copy = column.copy() - # Remove unique constraints + # Make a copy of the column so that it does not point to wrong table. + column_copy = column._copy() if hasattr(column, '_copy') else column.copy() column_copy.unique = False - # Remove onupdate triggers column_copy.onupdate = None if column_copy.autoincrement: column_copy.autoincrement = False @@ -150,5 +147,6 @@ def __call__(self, extends=None): extends.name if extends is not None else self.table_name, self.parent_table.metadata, *columns, + schema=self.parent_table.schema, extend_existing=extends is not None ) diff --git a/sqlalchemy_continuum/transaction.py b/sqlalchemy_continuum/transaction.py index ce3b3de1..68f0fd76 100644 --- a/sqlalchemy_continuum/transaction.py +++ b/sqlalchemy_continuum/transaction.py @@ -1,4 +1,5 @@ from datetime import datetime +from functools import partial try: from collections import OrderedDict @@ -22,6 +23,10 @@ def compile_big_integer(element, compiler, **kw): return 'INTEGER' +class NoChangesAttribute(Exception): + pass + + class TransactionBase(object): issued_at = sa.Column(sa.DateTime, default=datetime.utcnow) @@ -29,8 +34,13 @@ class TransactionBase(object): def entity_names(self): """ Return a list of entity names that changed during this transaction. + Raises a NoChangesAttribute exception if the 'changes' column does + not exist, most likely because TransactionChangesPlugin is not enabled. """ - return [changes.entity_name for changes in self.changes] + if hasattr(self, 'changes'): + return [changes.entity_name for changes in self.changes] + else: + raise NoChangesAttribute() @property def changed_entities(self): @@ -47,8 +57,11 @@ def changed_entities(self): session = sa.orm.object_session(self) for class_, version_class in tuples: - if class_.__name__ not in self.entity_names: - continue + try: + if class_.__name__ not in self.entity_names: + continue + except NoChangesAttribute: + pass tx_column = manager.option(class_, 'transaction_column_name') @@ -103,7 +116,7 @@ def create_triggers(cls): class TransactionFactory(ModelFactory): - model_name = 'Transaction' + model_name = 'VersionTransaction' def __init__(self, remote_addr=True): self.remote_addr = remote_addr @@ -112,16 +125,16 @@ def create_class(self, manager): """ Create Transaction class. """ - class Transaction( + class VersionTransaction( manager.declarative_base, TransactionBase ): - __tablename__ = 'transaction' + __tablename__ = 'version_transaction' __versioning_manager__ = manager id = sa.Column( sa.types.BigInteger, - sa.schema.Sequence('transaction_id_seq'), + sa.schema.Sequence('version_transaction_id_seq'), primary_key=True, autoincrement=True ) @@ -131,7 +144,11 @@ class Transaction( if manager.user_cls: user_cls = manager.user_cls - registry = manager.declarative_base._decl_class_registry + Base = manager.declarative_base + try: + registry = Base.registry._class_registry + except AttributeError: # SQLAlchemy < 1.4 + registry = Base._decl_class_registry if isinstance(user_cls, six.string_types): try: @@ -147,9 +164,7 @@ class Transaction( user_id = sa.Column( sa.inspect(user_cls).primary_key[0].type, - sa.ForeignKey( - '%s.%s' % (user_cls.__tablename__, sa.inspect(user_cls).primary_key[0].name) - ), + sa.ForeignKey(sa.inspect(user_cls).primary_key[0]), index=True ) @@ -162,7 +177,7 @@ def __repr__(self): for field in fields if hasattr(self, field) ) - return '' % ', '.join( + return '' % ', '.join( ( '%s=%r' % (field, value) if not isinstance(value, six.integer_types) @@ -175,5 +190,5 @@ def __repr__(self): ) if manager.options['native_versioning']: - create_triggers(Transaction) - return Transaction + create_triggers(VersionTransaction) + return VersionTransaction diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index 5f91b13d..22596b24 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -226,7 +226,7 @@ def version_validity_subquery(self, parent, version_obj, alias=None): return sa.select( [sa.text('max_1')], from_obj=[ - sa.sql.expression.alias(subquery, name='subquery') + sa.sql.expression.alias(subquery.subquery() if hasattr(subquery, 'subquery') else subquery, name='subquery') ] ) return subquery @@ -253,6 +253,11 @@ def update_version_validity(self, parent, version_obj): version_obj, alias=sa.orm.aliased(class_.__table__) ) + try: + subquery = subquery.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + subquery = subquery.as_scalar() + query = ( session.query(class_.__table__) .filter( diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index 372cc317..efca7b6f 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -133,7 +133,11 @@ def version_table(table): :param table: SQLAlchemy Table object """ - if table.metadata.schema: + if table.schema: + return table.metadata.tables[ + table.schema + '.' + table.name + '_version' + ] + elif table.metadata.schema: return table.metadata.tables[ table.metadata.schema + '.' + table.name + '_version' ] @@ -198,7 +202,11 @@ def versioned_column_properties(obj_or_class): cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__ mapper = sa.inspect(cls) - for key in mapper.columns.keys(): + for key, column in mapper.columns.items(): + # Ignores non table columns + if not is_table_column(column): + continue + if not manager.is_excluded_property(obj_or_class, key): yield getattr(mapper.attrs, key) @@ -215,7 +223,7 @@ def versioned_relationships(obj, versioned_column_keys): yield prop -def vacuum(session, model): +def vacuum(session, model, yield_per=1000): """ When making structural changes to version tables (for example dropping columns) there are sometimes situations where some old version records @@ -236,6 +244,7 @@ def vacuum(session, model): :param session: SQLAlchemy session object :param model: SQLAlchemy declarative model class + :param yield_per: how many rows to process at a time """ version_cls = version_class(model) versions = defaultdict(list) @@ -243,15 +252,28 @@ def vacuum(session, model): query = ( session.query(version_cls) .order_by(option(version_cls, 'transaction_column_name')) - ) + ).yield_per(yield_per) + + primary_key_col = sa.inspection.inspect(model).primary_key[0].name for version in query: - if versions[version.id]: - prev_version = versions[version.id][-1] + version_id = getattr(version, primary_key_col) + if versions[version_id]: + prev_version = versions[version_id][-1] if naturally_equivalent(prev_version, version): session.delete(version) else: - versions[version.id].append(version) + versions[version_id].append(version) + + +def is_table_column(column): + """ + Return wheter of not give field is a column over the database table. + + :param column: SQLAclhemy model field. + :rtype: bool + """ + return isinstance(column, sa.Column) def is_internal_column(model, column_name): @@ -398,7 +420,10 @@ def changeset(obj): data = {} session = sa.orm.object_session(obj) if session and obj in session.deleted: - for column in sa.inspect(obj.__class__).columns.values(): + columns = [c for c in sa.inspect(obj.__class__).columns.values() + if is_table_column(c)] + + for column in columns: if not column.primary_key: value = getattr(obj, column.key) if value is not None: diff --git a/sqlalchemy_continuum/version.py b/sqlalchemy_continuum/version.py index 5c3c1ed2..d71e745d 100644 --- a/sqlalchemy_continuum/version.py +++ b/sqlalchemy_continuum/version.py @@ -1,4 +1,5 @@ import sqlalchemy as sa + from .reverter import Reverter from .utils import get_versioning_manager, is_internal_column, parent_class @@ -49,9 +50,6 @@ def changeset(self): and second list value as the new value. """ previous_version = self.previous - if not previous_version and self.operation_type != 0: - return {} - data = {} for key in sa.inspect(self.__class__).columns.keys(): diff --git a/tests/__init__.py b/tests/__init__.py index 310e9bb6..dc9dc878 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,4 @@ + from copy import copy import inspect import itertools as it @@ -6,7 +7,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, column_property, close_all_sessions from sqlalchemy_continuum import ( ClassNotVersioned, version_class, @@ -41,9 +42,9 @@ def log_sql( def get_dns_from_driver(driver): if driver == 'postgres': - return 'postgres://postgres@localhost/sqlalchemy_continuum_test' + return 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test' elif driver == 'mysql': - return 'mysql+pymysql://travis@localhost/sqlalchemy_continuum_test' + return 'mysql+pymysql://root@localhost/sqlalchemy_continuum_test' elif driver == 'sqlite': return 'sqlite:///:memory:' else: @@ -129,7 +130,7 @@ def teardown_method(self, method): QueryPool.queries = [] versioning_manager.reset() - self.session.close_all() + close_all_sessions() self.session.expunge_all() self.drop_tables() self.engine.dispose() @@ -148,6 +149,9 @@ class Article(self.Model): content = sa.Column(sa.UnicodeText) description = sa.Column(sa.UnicodeText) + # Dynamic column cotaining all text content data + fulltext_content = column_property(name + content + description) + class Tag(self.Model): __tablename__ = 'tag' __versioned__ = copy(self.options) diff --git a/tests/builders/test_table_builder.py b/tests/builders/test_table_builder.py index 16323d02..a2255c83 100644 --- a/tests/builders/test_table_builder.py +++ b/tests/builders/test_table_builder.py @@ -3,6 +3,7 @@ import sqlalchemy as sa from sqlalchemy_continuum import version_class from tests import TestCase +from pytest import mark class TestTableBuilder(TestCase): @@ -69,3 +70,31 @@ class Article(self.Model): def test_takes_out_onupdate_triggers(self): table = version_class(self.Article).__table__ assert table.c.last_update.onupdate is None + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestTableBuilderInOtherSchema(TestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = copy(self.options) + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + last_update = sa.Column( + sa.DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + nullable=False + ) + self.Article = Article + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestCase.create_tables(self) + + def test_created_tables_retain_schema(self): + table = version_class(self.Article).__table__ + assert table.schema is not None + assert table.schema == self.Article.__table__.schema + diff --git a/tests/inheritance/test_single_table_inheritance.py b/tests/inheritance/test_single_table_inheritance.py index 9b723c15..73295bea 100644 --- a/tests/inheritance/test_single_table_inheritance.py +++ b/tests/inheritance/test_single_table_inheritance.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from sqlalchemy.ext.declarative import declared_attr from sqlalchemy_continuum import versioning_manager, version_class from tests import TestCase, create_test_cases @@ -18,6 +19,7 @@ class TextItem(self.Model): __mapper_args__ = { 'polymorphic_on': discriminator, + 'polymorphic_identity': u'base', 'with_polymorphic': '*' } @@ -25,6 +27,10 @@ class Article(TextItem): __mapper_args__ = {'polymorphic_identity': u'article'} name = sa.Column(sa.Unicode(255)) + @sa.ext.declarative.declared_attr + def status(cls): + return sa.Column("_status", sa.Unicode(255)) + class BlogPost(TextItem): __mapper_args__ = {'polymorphic_identity': u'blog_post'} title = sa.Column(sa.Unicode(255)) @@ -79,5 +85,8 @@ def test_transaction_changed_entities(self): assert transaction.entity_names == [u'Article'] assert transaction.changed_entities + def test_declared_attr_inheritance(self): + assert self.ArticleVersion.status + create_test_cases(SingleTableInheritanceTestCase) diff --git a/tests/plugins/test_activity.py b/tests/plugins/test_activity.py index 812eb542..4d0ab532 100644 --- a/tests/plugins/test_activity.py +++ b/tests/plugins/test_activity.py @@ -36,6 +36,34 @@ def create_activity(self, object=None, target=None): return activity +class TestActivityNotId(ActivityTestCase): + + def create_models(self): + TestCase.create_models(self) + + class NotIdModel(self.Model): + __tablename__ = 'not_id' + __versioned__ = { + 'base_classes': (self.Model, ) + } + + pk = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255), nullable=False) + self.NotIdModel = NotIdModel + + def test_create_activity_with_pk(self): + not_id_model = self.NotIdModel(name=u'Some model without id PK') + self.session.add(not_id_model) + self.session.commit() + self.create_activity(not_id_model) + self.session.commit() + activity = self.session.query(versioning_manager.activity_cls).first() + assert activity + assert activity.transaction_id + assert activity.object == not_id_model + assert activity.object_version == not_id_model.versions[-1] + + class TestActivity(ActivityTestCase): def test_creates_activity_class(self): assert versioning_manager.activity_cls.__name__ == 'Activity' diff --git a/tests/plugins/test_flask.py b/tests/plugins/test_flask.py index b81d6a65..0db414d0 100644 --- a/tests/plugins/test_flask.py +++ b/tests/plugins/test_flask.py @@ -1,11 +1,12 @@ import os from flask import Flask, url_for -from flask_login import LoginManager +from flask_login import LoginManager, UserMixin, login_user from flask_sqlalchemy import SQLAlchemy, _SessionSignalEvents from flexmock import flexmock import sqlalchemy as sa +from sqlalchemy.orm import close_all_sessions from sqlalchemy_continuum import ( make_versioned, remove_versioning, versioning_manager ) @@ -59,24 +60,10 @@ def teardown_method(self, method): self.client = None self.app = None - def login(self, user): - """ - Log in the user returned by :meth:`create_user`. - - :returns: the logged in user - """ - with self.client.session_transaction() as s: - s['user_id'] = user.id - return user - - def logout(self, user=None): - with self.client.session_transaction() as s: - s['user_id'] = None - def create_models(self): TestCase.create_models(self) - class User(self.Model): + class User(self.Model, UserMixin): __tablename__ = 'user' __versioned__ = { 'base_classes': (self.Model, ) @@ -114,7 +101,7 @@ def test_versioning_inside_request(self): user = self.User(name=u'Rambo') self.session.add(user) self.session.commit() - self.login(user) + login_user(user) self.client.get(url_for('.test_simple_flush')) article = self.session.query(self.Article).first() @@ -125,7 +112,7 @@ def test_raw_sql_and_flush(self): user = self.User(name=u'Rambo') self.session.add(user) self.session.commit() - self.login(user) + login_user(user) self.client.get(url_for('.test_raw_sql_and_flush')) assert ( self.session.query(versioning_manager.transaction_cls).count() == 2 @@ -248,7 +235,7 @@ def teardown_method(self, method): remove_versioning() self.db.session.remove() self.db.drop_all() - self.db.session.close_all() + close_all_sessions() self.db.engine.dispose() self.context.pop() self.context = None @@ -281,5 +268,3 @@ def test_create_transaction_with_scoped_session(self): uow = versioning_manager.unit_of_work(self.db.session) transaction = uow.create_transaction(self.db.session) assert transaction.id - - diff --git a/tests/relationships/test_association_table_relations.py b/tests/relationships/test_association_table_relations.py new file mode 100644 index 00000000..a447b61e --- /dev/null +++ b/tests/relationships/test_association_table_relations.py @@ -0,0 +1,65 @@ +import sqlalchemy as sa +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.orm import relationship +from tests import TestCase, create_test_cases +from packaging import version as py_pkg_version + + +class AssociationTableRelationshipsTestCase(TestCase): + def create_models(self): + super(AssociationTableRelationshipsTestCase, self).create_models() + + class PublishedArticle(self.Model): + __tablename__ = 'published_article' + __table_args__ = ( + PrimaryKeyConstraint("article_id", "author_id"), + {'keep_existing': True} + ) + + article_id = sa.Column(sa.Integer, sa.ForeignKey('article.id')) + author_id = sa.Column(sa.Integer, sa.ForeignKey('author.id')) + relationship_kwargs = {} + if py_pkg_version.parse(sa.__version__) >= py_pkg_version.parse('1.4.0'): + relationship_kwargs.update({'overlaps': 'articles'}) + author = relationship('Author', **relationship_kwargs) + article = relationship('Article', **relationship_kwargs) + + self.PublishedArticle = PublishedArticle + + published_articles_table = sa.Table(PublishedArticle.__tablename__, + PublishedArticle.metadata, + extend_existing=True) + + class Author(self.Model): + __tablename__ = 'author' + __versioned__ = { + 'base_classes': (self.Model, ) + } + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = relationship('Article', secondary=published_articles_table) + + self.Author = Author + + def test_version_relations(self): + article = self.Article() + name = u'Some article' + article.name = name + article.content = u'Some content' + self.session.add(article) + self.session.commit() + assert article.versions[0].name == name + + au = self.Author(name=u'Some author') + self.session.add(au) + self.session.commit() + + pa = self.PublishedArticle(article_id=article.id, author_id=au.id) + self.session.add(pa) + + self.session.commit() + + + +create_test_cases(AssociationTableRelationshipsTestCase) diff --git a/tests/relationships/test_custom_condition_relations.py b/tests/relationships/test_custom_condition_relations.py index b888e424..08c89955 100644 --- a/tests/relationships/test_custom_condition_relations.py +++ b/tests/relationships/test_custom_condition_relations.py @@ -1,6 +1,6 @@ import sqlalchemy as sa from tests import TestCase, create_test_cases - +from packaging import version as py_pkg_version class CustomConditionRelationsTestCase(TestCase): def create_models(self): @@ -26,12 +26,19 @@ class Tag(self.Model): article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) category = sa.Column(sa.Unicode(20)) + if py_pkg_version.parse(sa.__version__) < py_pkg_version.parse('1.4.0'): + primary_key_overlaps = {} + secondary_key_overlaps = {} + else: + primary_key_overlaps = {'overlaps': 'secondary_tags, Article'} + secondary_key_overlaps = {'overlaps': 'primary_tags, Article'} Article.primary_tags = sa.orm.relationship( Tag, primaryjoin=sa.and_( Tag.article_id == Article.id, Tag.category == u'primary' ), + **primary_key_overlaps ) Article.secondary_tags = sa.orm.relationship( @@ -40,6 +47,7 @@ class Tag(self.Model): Tag.article_id == Article.id, Tag.category == u'secondary' ), + **secondary_key_overlaps ) self.Article = Article diff --git a/tests/relationships/test_many_to_many_relations.py b/tests/relationships/test_many_to_many_relations.py index 256178ee..5398fe64 100644 --- a/tests/relationships/test_many_to_many_relations.py +++ b/tests/relationships/test_many_to_many_relations.py @@ -1,4 +1,5 @@ import pytest +from pytest import mark import sqlalchemy as sa from sqlalchemy_continuum import versioning_manager @@ -69,6 +70,33 @@ def test_single_insert(self): self.session.commit() assert len(article.versions[0].tags) == 1 + def test_unrelated_change(self): + tag1 = self.Tag(name=u'some tag') + tag2 = self.Tag(name=u'some tag2') + + self.session.add(tag1) + self.session.add(tag2) + self.session.commit() + + article1 = self.Article(name="Some article", ) + article1.name = u'Some article' + article1.tags.append(tag1) + + self.session.add(article1) + self.session.commit() + + article2 = self.Article() + article2.name = u'Some article2' + article2.tags.append(tag1) + + self.session.add(article2) + self.session.commit() + + article1.name = u'Some other name' + self.session.commit() + + assert len(article1.versions[1].tags) == 1 + def test_multi_insert(self): article = self.Article() article.name = u'Some article' @@ -339,4 +367,110 @@ def test_multiple_inserts_over_multiple_transactions(self): assert reference2.versions[0] in article.versions[2].references assert len(reference1.versions[2].cited_by) == 1 - assert article.versions[2] in reference1.versions[2].cited_by \ No newline at end of file + assert article.versions[2] in reference1.versions[2].cited_by + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestManyToManySelfReferentialInOtherSchema(TestManyToManySelfReferential): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = {} + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_references = sa.Table( + 'article_references', + self.Model.metadata, + sa.Column( + 'referring_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True, + ), + sa.Column( + 'referred_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True + ), + schema='other' + ) + + Article.references = sa.orm.relationship( + Article, + secondary=article_references, + primaryjoin=Article.id == article_references.c.referring_id, + secondaryjoin=Article.id == article_references.c.referred_id, + backref='cited_by' + ) + + self.Article = Article + self.referenced_articles_table = article_references + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestManyToManySelfReferential.create_tables(self) + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class ManyToManyRelationshipsInOtherSchemaTestCase(ManyToManyRelationshipsTestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = { + 'base_classes': (self.Model, ) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_tag = sa.Table( + 'article_tag', + self.Model.metadata, + sa.Column( + 'article_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True, + ), + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('other.tag.id'), + primary_key=True + ), + schema='other' + ) + + class Tag(self.Model): + __tablename__ = 'tag' + __versioned__ = { + 'base_classes': (self.Model, ) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + Tag.articles = sa.orm.relationship( + Article, + secondary=article_tag, + backref='tags' + ) + + self.Article = Article + self.Tag = Tag + + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + ManyToManyRelationshipsTestCase.create_tables(self) + +create_test_cases(ManyToManyRelationshipsInOtherSchemaTestCase) + diff --git a/tests/test_changeset.py b/tests/test_changeset.py index e13b1712..f72ccb88 100644 --- a/tests/test_changeset.py +++ b/tests/test_changeset.py @@ -55,7 +55,11 @@ def test_changeset_for_history_that_does_not_have_first_insert(self): ''' % (self.transaction_column_name, tx_log.id) ) - assert self.session.query(self.ArticleVersion).first().changeset == {} + assert self.session.query(self.ArticleVersion).first().changeset == { + 'content': [None, 'some content'], + 'id': [None, 1], + 'name': [None, 'something'] + } class TestChangeSetWithValidityStrategy(ChangeSetTestCase): @@ -71,7 +75,7 @@ def create_models(self): class Article(self.Model): __tablename__ = 'article' __versioned__ = { - 'base_classes': (self.Model, ) + 'base_classes': (self.Model,) } id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) @@ -82,19 +86,22 @@ class Article(self.Model): class Tag(self.Model): __tablename__ = 'tag' __versioned__ = { - 'base_classes': (self.Model, ) + 'base_classes': (self.Model,) } id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) article = sa.orm.relationship(Article, backref='tags') - - Article.tag_count = sa.orm.column_property( - sa.select([sa.func.count(Tag.id)]) - .where(Tag.article_id == Article.id) - .correlate_except(Tag) - ) + + subquery = (sa.select([sa.func.count(Tag.id)]) + .where(Tag.article_id == Article.id) + .correlate_except(Tag)) + try: + subquery = subquery.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + subquery = subquery.as_scalar() + Article.tag_count = sa.orm.column_property(subquery) self.Article = Article self.Tag = Tag diff --git a/tests/test_column_inclusion_and_exclusion.py b/tests/test_column_inclusion_and_exclusion.py index e8530d2d..e916b383 100644 --- a/tests/test_column_inclusion_and_exclusion.py +++ b/tests/test_column_inclusion_and_exclusion.py @@ -53,3 +53,52 @@ class TextItem(self.Model): content = sa.Column('_content', sa.UnicodeText) self.TextItem = TextItem + + +class TestColumnExclusionWithRelationship(TestCase): + def create_models(self): + + class Word(self.Model): + __tablename__ = 'word' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + word = sa.Column(sa.Unicode(255)) + + class TextItemWord(self.Model): + __tablename__ = 'text_item_word' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + text_item_id = sa.Column(sa.Integer, sa.ForeignKey('text_item.id'), nullable=False) + word_id = sa.Column(sa.Integer, sa.ForeignKey('word.id'), nullable=False) + + class TextItem(self.Model): + __tablename__ = 'text_item' + __versioned__ = { + 'exclude': ['content'] + } + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + content = sa.orm.relationship(Word, secondary='text_item_word') + + self.TextItem = TextItem + self.Word = Word + + def test_excluded_columns_not_included_in_version_class(self): + cls = version_class(self.TextItem) + manager = cls._sa_class_manager + assert 'content' not in manager.keys() + + def test_versioning_with_column_exclusion(self): + item = self.TextItem(name=u'Some textitem', + content=[self.Word(word=u'bird')]) + self.session.add(item) + self.session.commit() + + assert item.versions[0].name == u'Some textitem' + + def test_does_not_create_record_if_only_excluded_column_updated(self): + item = self.TextItem(name=u'Some textitem') + self.session.add(item) + self.session.commit() + item.content.append(self.Word(word=u'Some content')) + self.session.commit() + assert item.versions.count() == 1 diff --git a/tests/test_mapper_args.py b/tests/test_mapper_args.py index 356f85f1..bb14ffb5 100644 --- a/tests/test_mapper_args.py +++ b/tests/test_mapper_args.py @@ -1,3 +1,6 @@ +from pytest import mark +from packaging import version + import sqlalchemy as sa from sqlalchemy_continuum import version_class from tests import TestCase @@ -29,6 +32,7 @@ def test_supports_column_prefix(self): assert self.TextItem._id +@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')") class TestOrderByWithStringArg(TestCase): def create_models(self): class TextItem(self.Model): @@ -55,6 +59,7 @@ def test_reflects_order_by(self): assert self.TextItemVersion.__mapper_args__['order_by'] == 'id' +@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')") class TestOrderByWithInstrumentedAttribute(TestCase): def create_models(self): class TextItem(self.Model): diff --git a/tests/test_sessions.py b/tests/test_sessions.py index f5780f89..6a1fbfb0 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -52,3 +52,19 @@ class TestUnitOfWork(TestCase): def test_with_session_arg(self): uow = versioning_manager.unit_of_work(self.session) assert isinstance(uow, UnitOfWork) + + +class TestExternalTransactionSession(TestCase): + + def test_session_with_external_transaction(self): + conn = self.engine.connect() + t = conn.begin() + session = Session(bind=conn) + + article = self.Article(name=u'My Session Article') + session.add(article) + session.flush() + + session.close() + t.rollback() + conn.close() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index f647130d..2b3e7b56 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,6 +1,9 @@ import sqlalchemy as sa from sqlalchemy_continuum import versioning_manager from tests import TestCase +from pytest import mark +from sqlalchemy_continuum.plugins import TransactionMetaPlugin + class TestTransaction(TestCase): @@ -37,6 +40,19 @@ def test_repr(self): repr(transaction) ) + def test_changed_entities(self): + article_v0 = self.article.versions[0] + transaction = article_v0.transaction + assert transaction.changed_entities == { + self.ArticleVersion: [article_v0], + self.TagVersion: [self.article.tags[0].versions[0]], + } + + +# Check that the tests pass without TransactionChangesPlugin +class TestTransactionWithoutChangesPlugin(TestTransaction): + plugins = [TransactionMetaPlugin()] + class TestAssigningUserClass(TestCase): user_cls = 'User' @@ -56,3 +72,31 @@ class User(self.Model): def test_copies_primary_key_type_from_user_class(self): attr = versioning_manager.transaction_cls.user_id assert isinstance(attr.property.columns[0].type, sa.Unicode) + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestAssigningUserClassInOtherSchema(TestCase): + user_cls = 'User' + + def create_models(self): + class User(self.Model): + __tablename__ = 'user' + __versioned__ = { + 'base_classes': (self.Model,) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Unicode(255), primary_key=True) + name = sa.Column(sa.Unicode(255), nullable=False) + + self.User = User + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestCase.create_tables(self) + + def test_can_build_transaction_model(self): + # If create_models didn't crash this should be good + pass + diff --git a/tox.ini b/tox.ini index 0f2033d4..2d9b2fd2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27, py33, py34, py35 +envlist = py27, py33, py34, py35, py36, py37 [testenv] commands = pip install -e ".[test]"