diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_dml_executor.py index e4483b4525..65176a3837 100644 --- a/google/cloud/spanner_dbapi/batch_dml_executor.py +++ b/google/cloud/spanner_dbapi/batch_dml_executor.py @@ -2,7 +2,6 @@ from enum import Enum from typing import TYPE_CHECKING, List -from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, StatementType, @@ -11,6 +10,9 @@ from google.rpc.code_pb2 import ABORTED, OK from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi.transaction_helper import ( + _get_batch_statements_result_checksum, +) from google.cloud.spanner_dbapi.utils import StreamedManyResultSets if TYPE_CHECKING: @@ -69,6 +71,7 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): from google.cloud.spanner_dbapi import OperationalError connection = cursor.connection + transaction_helper = connection._transaction_helper many_result_set = StreamedManyResultSets() statements_tuple = [] for statement in statements: @@ -78,28 +81,23 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): many_result_set.add_iter(res) cursor._row_count = sum([max(val, 0) for val in res]) else: - retried = False while True: try: transaction = connection.transaction_checkout() status, res = transaction.batch_update(statements_tuple) - many_result_set.add_iter(res) - res_checksum = ResultsChecksum() - res_checksum.consume_result(res) - res_checksum.consume_result(status.code) - if not retried: - connection._statements.append((statements, res_checksum)) - cursor._row_count = sum([max(val, 0) for val in res]) - if status.code == ABORTED: connection._transaction = None raise Aborted(status.message) elif status.code != OK: raise OperationalError(status.message) + + checksum = _get_batch_statements_result_checksum(res, status.code) + many_result_set.add_iter(res) + transaction_helper._batch_statements_list.append((statements, checksum)) + cursor._row_count = sum([max(val, 0) for val in res]) return many_result_set except Aborted: - connection.retry_transaction() - retried = True + transaction_helper.retry_transaction() def _do_batch_update(transaction, statements): diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index ec8951493c..e65d55b00b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -13,7 +13,6 @@ # limitations under the License. """DB-API Connection for the Google Cloud Spanner.""" -import time import warnings from google.api_core.exceptions import Aborted @@ -21,13 +20,11 @@ from google.cloud import spanner_v1 as spanner from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement +from google.cloud.spanner_dbapi.transaction_helper import TransactionHelper from google.cloud.spanner_v1 import RequestOptions -from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot from deprecated import deprecated -from google.cloud.spanner_dbapi.checksum import _compare_checksums -from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_dbapi.exceptions import ( InterfaceError, @@ -37,13 +34,10 @@ from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT from google.cloud.spanner_dbapi.version import PY_VERSION -from google.rpc.code_pb2 import ABORTED - CLIENT_TRANSACTION_NOT_STARTED_WARNING = ( "This method is non-operational as a transaction has not been started." ) -MAX_INTERNAL_RETRIES = 50 def check_not_closed(function): @@ -99,9 +93,6 @@ def __init__(self, instance, database=None, read_only=False): self._transaction = None self._session = None self._snapshot = None - # SQL statements, which were executed - # within the current transaction - self._statements = [] self.is_closed = False self._autocommit = False @@ -118,6 +109,7 @@ def __init__(self, instance, database=None, read_only=False): self._spanner_transaction_started = False self._batch_mode = BatchMode.NONE self._batch_dml_executor: BatchDmlExecutor = None + self._transaction_helper = TransactionHelper(self) @property def autocommit(self): @@ -299,76 +291,6 @@ def _release_session(self): self.database._pool.put(self._session) self._session = None - def retry_transaction(self): - """Retry the aborted transaction. - - All the statements executed in the original transaction - will be re-executed in new one. Results checksums of the - original statements and the retried ones will be compared. - - :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` - If results checksum of the retried statement is - not equal to the checksum of the original one. - """ - attempt = 0 - while True: - self._spanner_transaction_started = False - attempt += 1 - if attempt > MAX_INTERNAL_RETRIES: - raise - - try: - self._rerun_previous_statements() - break - except Aborted as exc: - delay = _get_retry_delay(exc.errors[0], attempt) - if delay: - time.sleep(delay) - - def _rerun_previous_statements(self): - """ - Helper to run all the remembered statements - from the last transaction. - """ - for statement in self._statements: - if isinstance(statement, list): - statements, checksum = statement - - transaction = self.transaction_checkout() - statements_tuple = [] - for single_statement in statements: - statements_tuple.append(single_statement.get_tuple()) - status, res = transaction.batch_update(statements_tuple) - - if status.code == ABORTED: - raise Aborted(status.details) - - retried_checksum = ResultsChecksum() - retried_checksum.consume_result(res) - retried_checksum.consume_result(status.code) - - _compare_checksums(checksum, retried_checksum) - else: - res_iter, retried_checksum = self.run_statement(statement, retried=True) - # executing all the completed statements - if statement != self._statements[-1]: - for res in res_iter: - retried_checksum.consume_result(res) - - _compare_checksums(statement.checksum, retried_checksum) - # executing the failed statement - else: - # streaming up to the failed result or - # to the end of the streaming iterator - while len(retried_checksum) < len(statement.checksum): - try: - res = next(iter(res_iter)) - retried_checksum.consume_result(res) - except StopIteration: - break - - _compare_checksums(statement.checksum, retried_checksum) - def transaction_checkout(self): """Get a Cloud Spanner transaction. @@ -461,11 +383,12 @@ def commit(self): if self._spanner_transaction_started and not self._read_only: self._transaction.commit() except Aborted: - self.retry_transaction() + self._transaction_helper.retry_transaction() self.commit() finally: self._release_session() - self._statements = [] + self._transaction_helper._single_statements = [] + self._transaction_helper._batch_statements_list = [] self._transaction_begin_marked = False self._spanner_transaction_started = False @@ -485,7 +408,8 @@ def rollback(self): self._transaction.rollback() finally: self._release_session() - self._statements = [] + self._transaction_helper._single_statements = [] + self._transaction_helper._batch_statements_list = [] self._transaction_begin_marked = False self._spanner_transaction_started = False @@ -504,7 +428,7 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() - def run_statement(self, statement: Statement, retried=False): + def run_statement(self, statement: Statement): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In @@ -524,17 +448,11 @@ def run_statement(self, statement: Statement, retried=False): checksum of this statement results. """ transaction = self.transaction_checkout() - if not retried: - self._statements.append(statement) - - return ( - transaction.execute_sql( - statement.sql, - statement.params, - param_types=statement.param_types, - request_options=self.request_options, - ), - ResultsChecksum() if retried else statement.checksum, + return transaction.execute_sql( + statement.sql, + statement.params, + param_types=statement.param_types, + request_options=self.request_options, ) @check_not_closed diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 47d028d475..1674a5cb59 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -13,7 +13,7 @@ # limitations under the License. """Database cursor for Google Cloud Spanner DB API.""" - +import itertools from collections import namedtuple import sqlparse @@ -47,6 +47,9 @@ Statement, ParsedStatement, ) +from google.cloud.spanner_dbapi.transaction_helper import ( + _get_single_statement_result_checksum, +) from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets @@ -90,9 +93,8 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.lastrowid = None self.connection = connection + self.transaction_helper = self.connection._transaction_helper self._is_closed = False - # the currently running SQL statement results checksum - self._checksum = None # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -275,26 +277,22 @@ def _execute_in_rw_transaction(self, parsed_statement: ParsedStatement): # For every other operation, we've got to ensure that # any prior DDL statements were run. self.connection.run_prior_DDL_statements() + statement = parsed_statement.statement if self.connection._client_transaction_started: - ( - self._result_set, - self._checksum, - ) = self.connection.run_statement(parsed_statement.statement) - while True: try: - self._itr = PeekIterator(self._result_set) - break + self._result_set = self.connection.run_statement(statement) + itr, self._itr = itertools.tee(PeekIterator(self._result_set), 2) + statement.checksum = _get_single_statement_result_checksum(itr) + self.transaction_helper._single_statements.append(statement) + return except Aborted: - self.connection.retry_transaction() - except Exception as ex: - self.connection._statements.remove(parsed_statement.statement) - raise ex + self.transaction_helper.retry_transaction() else: self.connection.database.run_in_transaction( self._do_execute_update_in_autocommit, - parsed_statement.statement.sql, - parsed_statement.statement.params or None, + statement.sql, + statement.params or None, ) @check_not_closed @@ -357,17 +355,12 @@ def fetchone(self): sequence, or None when no more data is available.""" try: res = next(self) - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(res) return res except StopIteration: return except Aborted: if not self.connection.read_only: - self.connection.retry_transaction() + self.transaction_helper.retry_transaction() return self.fetchone() @check_not_closed @@ -378,15 +371,10 @@ def fetchall(self): res = [] try: for row in self: - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(row) res.append(row) except Aborted: if not self.connection.read_only: - self.connection.retry_transaction() + self.transaction_helper.retry_transaction() return self.fetchall() return res @@ -410,17 +398,12 @@ def fetchmany(self, size=None): for _ in range(size): try: res = next(self) - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(res) items.append(res) except StopIteration: break except Aborted: if not self.connection.read_only: - self.connection.retry_transaction() + self.transaction_helper.retry_transaction() return self.fetchmany(size) return items diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 76ac951e0c..29da84ee73 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -24,7 +24,6 @@ from . import client_side_statement_parser from deprecated import deprecated -from .checksum import ResultsChecksum from .exceptions import Error from .parsed_statement import ParsedStatement, StatementType, Statement from .types import DateStr, TimestampStr @@ -230,7 +229,6 @@ def classify_statement(query, args=None): query, args, get_param_types(args or None), - ResultsChecksum(), ) if RE_DDL.match(query): return ParsedStatement(StatementType.DDL, statement) diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py new file mode 100644 index 0000000000..f0f10a4a49 --- /dev/null +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -0,0 +1,105 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List + +import time +from google.api_core.exceptions import Aborted + +from google.cloud.spanner_dbapi.utils import PeekIterator + +if TYPE_CHECKING: + from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_dbapi.checksum import ResultsChecksum, _compare_checksums +from google.cloud.spanner_dbapi.parsed_statement import Statement +from google.cloud.spanner_v1.session import _get_retry_delay + +from google.rpc.code_pb2 import ABORTED + +MAX_INTERNAL_RETRIES = 50 + + +class TransactionHelper: + def __init__(self, connection: "Connection"): + self._connection = connection + # Non-Batch statements, which were executed within the current + # transaction + self._single_statements: List[Statement] = [] + # Batch statements, which were executed within the current transaction + self._batch_statements_list: List[(List[Statement], ResultsChecksum)] = [] + + def retry_transaction(self): + """Retry the aborted transaction. + + All the statements executed in the original transaction + will be re-executed in new one. Results checksums of the + original statements and the retried ones will be compared. + + :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` + If results checksum of the retried statement is + not equal to the checksum of the original one. + """ + attempt = 0 + while True: + self._connection._spanner_transaction_started = False + attempt += 1 + if attempt > MAX_INTERNAL_RETRIES: + raise + + try: + self._rerun_previous_statements() + self._rerun_previous_batch_statements() + break + except Aborted as exc: + delay = _get_retry_delay(exc.errors[0], attempt) + if delay: + time.sleep(delay) + + def _rerun_previous_batch_statements(self): + """ + Helper to run all the remembered statements from the last transaction. + """ + for batch_statements, original_checksum in self._batch_statements_list: + transaction = self._connection.transaction_checkout() + statements_tuple = [] + for single_statement in batch_statements: + statements_tuple.append(single_statement.get_tuple()) + status, res = transaction.batch_update(statements_tuple) + if status.code == ABORTED: + raise Aborted(status.details) + + retried_checksum = _get_batch_statements_result_checksum(res, status.code) + _compare_checksums(original_checksum, retried_checksum) + + def _rerun_previous_statements(self): + for single_statement in self._single_statements: + res_iter = self._connection.run_statement(single_statement) + retried_checksum = _get_single_statement_result_checksum( + PeekIterator(res_iter) + ) + _compare_checksums(single_statement.checksum, retried_checksum) + + +def _get_batch_statements_result_checksum(res, status_code): + retried_checksum = ResultsChecksum() + retried_checksum.consume_result(res) + retried_checksum.consume_result(status_code) + return retried_checksum + + +def _get_single_statement_result_checksum(res_iter): + retried_checksum = ResultsChecksum() + for res in res_iter: + retried_checksum.consume_result(res) + return retried_checksum diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index cf1a01e6dd..89f7bfc7d3 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -13,8 +13,6 @@ # limitations under the License. import datetime -import hashlib -import pickle import pytest import time @@ -89,25 +87,11 @@ def _execute_common_statements(self, cursor): VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') """ ) - cursor.execute( - """ - UPDATE contacts - SET first_name = 'updated-first-name' - WHERE first_name = 'first-name' - """ - ) - cursor.execute( - """ - UPDATE contacts - SET email = 'test.email_updated@domen.ru' - WHERE email = 'test.email@domen.ru' - """ - ) return ( 1, - "updated-first-name", + "first-name", "last-name", - "test.email_updated@domen.ru", + "test.email@domen.ru", ) @pytest.mark.parametrize("client_side", [True, False]) @@ -126,7 +110,6 @@ def test_commit(self, client_side): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") def test_commit_exception(self): """Test that if exception during commit method is caught, then subsequent operations on same Cursor and Connection object works @@ -148,7 +131,6 @@ def test_commit_exception(self): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") def test_rollback_exception(self): """Test that if exception during rollback method is caught, then subsequent operations on same Cursor and Connection object works @@ -170,7 +152,6 @@ def test_rollback_exception(self): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") def test_cursor_execute_exception(self): """Test that if exception in Cursor's execute method is caught when Connection is not in autocommit mode, then subsequent operations on @@ -633,32 +614,6 @@ def test_rollback_on_connection_closing(self, shared_instance, dbapi_database): cursor.close() conn.close() - def test_results_checksum(self): - """Test that results checksum is calculated properly.""" - - self._cursor.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES - (1, 'first-name', 'last-name', 'test.email@domen.ru'), - (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') - """ - ) - assert len(self._conn._statements) == 1 - self._conn.commit() - - self._cursor.execute("SELECT * FROM contacts") - got_rows = self._cursor.fetchall() - - assert len(self._conn._statements) == 1 - self._conn.commit() - - checksum = hashlib.sha256() - checksum.update(pickle.dumps(got_rows[0])) - checksum.update(pickle.dumps(got_rows[1])) - - assert self._cursor._checksum.checksum.digest() == checksum.digest() - def test_execute_many(self): row_data = [ (1, "first-name", "last-name", "test.email@example.com"), diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 022d50c522..4559006623 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -470,83 +470,6 @@ def test_begin(self): self.assertEqual(self._under_test._transaction_begin_marked, True) - def test_run_statement_wo_retried(self): - """Check that Connection remembers executed statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - sql = """SELECT 23 FROM table WHERE id = @a1""" - params = {"a1": "value"} - param_types = {"a1": str} - - connection = self._make_connection() - connection.transaction_checkout = mock.Mock() - statement = Statement(sql, params, param_types, ResultsChecksum()) - connection.run_statement(statement) - - self.assertEqual(connection._statements[0].sql, sql) - self.assertEqual(connection._statements[0].params, params) - self.assertEqual(connection._statements[0].param_types, param_types) - self.assertIsInstance(connection._statements[0].checksum, ResultsChecksum) - - def test_run_statement_w_retried(self): - """Check that Connection doesn't remember re-executed statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - sql = """SELECT 23 FROM table WHERE id = @a1""" - params = {"a1": "value"} - param_types = {"a1": str} - - connection = self._make_connection() - connection.transaction_checkout = mock.Mock() - statement = Statement(sql, params, param_types, ResultsChecksum()) - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - - def test_run_statement_w_heterogenous_insert_statements(self): - """Check that Connection executed heterogenous insert statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - from google.rpc.status_pb2 import Status - from google.rpc.code_pb2 import OK - - sql = "INSERT INTO T (f1, f2) VALUES (1, 2)" - params = None - param_types = None - - connection = self._make_connection() - transaction = mock.MagicMock() - connection.transaction_checkout = mock.Mock(return_value=transaction) - transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) - statement = Statement(sql, params, param_types, ResultsChecksum()) - - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - - def test_run_statement_w_homogeneous_insert_statements(self): - """Check that Connection executed homogeneous insert statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - from google.rpc.status_pb2 import Status - from google.rpc.code_pb2 import OK - - sql = "INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)" - params = ["a", "b", "c", "d"] - param_types = {"f1": str, "f2": str} - - connection = self._make_connection() - transaction = mock.MagicMock() - connection.transaction_checkout = mock.Mock(return_value=transaction) - transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) - statement = Statement(sql, params, param_types, ResultsChecksum()) - - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - @mock.patch("google.cloud.spanner_v1.transaction.Transaction") def test_commit_clears_statements(self, mock_transaction): """ @@ -556,13 +479,13 @@ def test_commit_clears_statements(self, mock_transaction): connection = self._make_connection() connection._spanner_transaction_started = True connection._transaction = mock.Mock() - connection._statements = [{}, {}] + connection._transaction_helper._single_statements = [{}, {}] - self.assertEqual(len(connection._statements), 2) + self.assertEqual(len(connection._transaction_helper._single_statements), 2) connection.commit() - self.assertEqual(len(connection._statements), 0) + self.assertEqual(len(connection._transaction_helper._single_statements), 0) @mock.patch("google.cloud.spanner_v1.transaction.Transaction") def test_rollback_clears_statements(self, mock_transaction): @@ -573,244 +496,30 @@ def test_rollback_clears_statements(self, mock_transaction): connection = self._make_connection() connection._spanner_transaction_started = True connection._transaction = mock_transaction - connection._statements = [{}, {}] + connection._transaction_helper._single_statements = [{}, {}] - self.assertEqual(len(connection._statements), 2) + self.assertEqual(len(connection._transaction_helper._single_statements), 2) connection.rollback() - self.assertEqual(len(connection._statements), 0) - - def test_retry_transaction_w_checksum_match(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - connection = self._make_connection() - checksum = ResultsChecksum() - checksum.consume_result(row) - - retried_checkum = ResultsChecksum() - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - run_mock.assert_called_with(statement, retried=True) - - def test_retry_transaction_w_checksum_mismatch(self): - """ - Check retrying an aborted transaction - with results checksums mismatch. - """ - from google.cloud.spanner_dbapi.exceptions import RetryAborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - retried_row = ["field3", "field4"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - retried_checkum = ResultsChecksum() - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([retried_row], retried_checkum) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with self.assertRaises(RetryAborted): - connection.retry_transaction() + self.assertEqual(len(connection._transaction_helper._single_statements), 0) @mock.patch("google.cloud.spanner_v1.Client") def test_commit_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) mock_transaction = mock.Mock() connection._spanner_transaction_started = True connection._transaction = mock_transaction mock_transaction.commit.side_effect = [Aborted("Aborted"), None] - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], ResultsChecksum()) + run_mock = connection._transaction_helper = mock.Mock() connection.commit() - run_mock.assert_called_with(statement, retried=True) - - @mock.patch("google.cloud.spanner_v1.Client") - def test_retry_aborted_retry(self, mock_client): - """ - Check that in case of a retried transaction failed, - the connection will retry it once again. - """ - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) - metadata_mock = mock.Mock() - metadata_mock.trailing_metadata.return_value = {} - run_mock = connection.run_statement = mock.Mock() - run_mock.side_effect = [ - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ] - - connection.retry_transaction() - - run_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) - - def test_retry_transaction_raise_max_internal_retries(self): - """Check retrying raise an error of max internal retries.""" - from google.cloud.spanner_dbapi import connection as conn - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - conn.MAX_INTERNAL_RETRIES = 0 - row = ["field1", "field2"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with self.assertRaises(Exception): - connection.retry_transaction() - - conn.MAX_INTERNAL_RETRIES = 50 - - @mock.patch("google.cloud.spanner_v1.Client") - def test_retry_aborted_retry_without_delay(self, mock_client): - """ - Check that in case of a retried transaction failed, - the connection will retry it once again. - """ - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) - metadata_mock = mock.Mock() - metadata_mock.trailing_metadata.return_value = {} - run_mock = connection.run_statement = mock.Mock() - run_mock.side_effect = [ - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ] - connection._get_retry_delay = mock.Mock(return_value=False) - - connection.retry_transaction() - - run_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) - - def test_retry_transaction_w_multiple_statement(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - retried_checkum = ResultsChecksum() - - statement = Statement("SELECT 1", [], {}, checksum) - statement1 = Statement("SELECT 2", [], {}, checksum) - connection._statements.append(statement) - connection._statements.append(statement1) - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - - run_mock.assert_called_with(statement1, retried=True) - - def test_retry_transaction_w_empty_response(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = [] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.count = 1 - retried_checkum = ResultsChecksum() - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - - run_mock.assert_called_with(statement, retried=True) + assert run_mock.retry_transaction.called def test_validate_ok(self): connection = self._make_connection() diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 3328b0e17f..5ac3981d35 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -17,6 +17,7 @@ import sys import unittest +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, StatementType, @@ -44,7 +45,7 @@ def _make_connection(self, *args, **kwargs): def _transaction_mock(self, mock_response=[]): from google.rpc.code_pb2 import OK - transaction = mock.Mock(committed=False, rolled_back=False) + transaction = mock.Mock() transaction.batch_update = mock.Mock( return_value=[mock.Mock(code=OK), mock_response] ) @@ -175,8 +176,6 @@ def test_execute_database_error(self): cursor.execute(sql="SELECT 1") def test_execute_autocommit_off(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor.connection._autocommit = False @@ -184,30 +183,24 @@ def test_execute_autocommit_off(self): cursor.execute("sql") self.assertIsInstance(cursor._result_set, mock.MagicMock) - self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_insert_statement_autocommit_off(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.utils import PeekIterator - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor.connection._autocommit = False cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) - cursor._checksum = ResultsChecksum() sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", - return_value=ParsedStatement(StatementType.UPDATE, sql), + return_value=ParsedStatement(StatementType.UPDATE, Statement(sql)), ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=(mock.MagicMock(), ResultsChecksum()), + return_value=(mock.MagicMock()), ): cursor.execute(sql) self.assertIsInstance(cursor._result_set, mock.MagicMock) - self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_statement(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) @@ -547,7 +540,7 @@ def test_executemany_insert_batch_failed(self): connection.autocommit = True cursor = connection.cursor() - transaction = mock.Mock(committed=False, rolled_back=False) + transaction = mock.Mock() transaction.batch_update = mock.Mock( return_value=(mock.Mock(code=UNKNOWN, message=err_details), []) ) @@ -565,7 +558,6 @@ def test_executemany_insert_batch_failed(self): def test_executemany_insert_batch_aborted(self): from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_v1.param_types import INT64 from google.rpc.code_pb2 import ABORTED @@ -574,7 +566,7 @@ def test_executemany_insert_batch_aborted(self): connection = connect("test-instance", "test-database") - transaction1 = mock.Mock(committed=False, rolled_back=False) + transaction1 = mock.Mock() transaction1.batch_update = mock.Mock( side_effect=[(mock.Mock(code=ABORTED, message=err_details), [])] ) @@ -584,7 +576,7 @@ def test_executemany_insert_batch_aborted(self): connection.transaction_checkout = mock.Mock( side_effect=[transaction1, transaction2] ) - connection.retry_transaction = mock.Mock() + connection._transaction_helper.retry_transaction = mock.Mock() cursor = connection.cursor() cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)]) @@ -617,10 +609,10 @@ def test_executemany_insert_batch_aborted(self): ), ] ) - connection.retry_transaction.assert_called_once() + connection._transaction_helper.retry_transaction.assert_called_once() self.assertEqual( - connection._statements[0][0], + connection._transaction_helper._batch_statements_list[0][0], [ Statement( """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", @@ -634,7 +626,9 @@ def test_executemany_insert_batch_aborted(self): ), ], ) - self.assertIsInstance(connection._statements[0][1], ResultsChecksum) + self.assertIsInstance( + connection._transaction_helper._batch_statements_list[0][1], ResultsChecksum + ) @mock.patch("google.cloud.spanner_v1.Client") def test_executemany_database_error(self, mock_client): @@ -650,8 +644,6 @@ def test_executemany_database_error(self, mock_client): sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor._checksum = ResultsChecksum() @@ -665,8 +657,6 @@ def test_fetchone(self): sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) @@ -678,8 +668,6 @@ def test_fetchone_w_autocommit(self): self.assertIsNone(cursor.fetchone()) def test_fetchmany(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor._checksum = ResultsChecksum() @@ -692,8 +680,6 @@ def test_fetchmany(self): self.assertEqual(result, lst[1:]) def test_fetchmany_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) @@ -707,8 +693,6 @@ def test_fetchmany_w_autocommit(self): self.assertEqual(result, lst[1:]) def test_fetchall(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor._checksum = ResultsChecksum() @@ -717,8 +701,6 @@ def test_fetchall(self): self.assertEqual(cursor.fetchall(), lst) def test_fetchall_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) @@ -896,7 +878,10 @@ def test_get_table_column_schema(self): self.assertEqual(result, expected) @mock.patch("google.cloud.spanner_v1.Client") - def test_peek_iterator_aborted(self, mock_client): + @mock.patch( + "google.cloud.spanner_dbapi.cursor._get_single_statement_result_checksum" + ) + def test_peek_iterator_aborted(self, mock_result_checksum, mock_client): """ Checking that an Aborted exception is retried in case it happened while streaming the first element with a PeekIterator. @@ -905,41 +890,38 @@ def test_peek_iterator_aborted(self, mock_client): from google.cloud.spanner_dbapi.connection import connect connection = connect("test-instance", "test-database") - cursor = connection.cursor() with mock.patch( "google.cloud.spanner_dbapi.utils.PeekIterator.__init__", side_effect=(Aborted("Aborted"), None), ): with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + "google.cloud.spanner_dbapi.transaction_helper.TransactionHelper.retry_transaction" ) as retry_mock: with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=((1, 2, 3), None), + return_value=(1, 2, 3), ): cursor.execute("SELECT * FROM table_name") - retry_mock.assert_called_with() + retry_mock.assert_called_with() @mock.patch("google.cloud.spanner_v1.Client") def test_fetchone_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", side_effect=(Aborted("Aborted"), None), ): with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + "google.cloud.spanner_dbapi.transaction_helper.TransactionHelper.retry_transaction" ) as retry_mock: cursor.fetchone() @@ -953,15 +935,15 @@ def test_fetchone_retry_aborted_statements(self, mock_client): from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] + row = ("field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", @@ -969,11 +951,11 @@ def test_fetchone_retry_aborted_statements(self, mock_client): ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), + return_value=([row]), ) as run_mock: cursor.fetchone() - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client): @@ -984,17 +966,17 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client) from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] + row = ("field1", "field2") + row2 = ("updated_field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", @@ -1002,31 +984,28 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client) ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), + return_value=([row2]), ) as run_mock: with self.assertRaises(RetryAborted): cursor.fetchone() - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_fetchall_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", side_effect=(Aborted("Aborted"), iter([])), ): with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + "google.cloud.spanner_dbapi.transaction_helper.TransactionHelper.retry_transaction" ) as retry_mock: cursor.fetchall() @@ -1040,15 +1019,15 @@ def test_fetchall_retry_aborted_statements(self, mock_client): from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] + row = ("field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", @@ -1056,11 +1035,11 @@ def test_fetchall_retry_aborted_statements(self, mock_client): ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), + return_value=[row], ) as run_mock: cursor.fetchall() - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client): @@ -1071,17 +1050,17 @@ def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client) from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] + row = ("field1", "field2") + row2 = ("updated_field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", @@ -1089,31 +1068,29 @@ def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client) ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), + return_value=[row2], ) as run_mock: with self.assertRaises(RetryAborted): cursor.fetchall() - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_fetchmany_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", side_effect=(Aborted("Aborted"), None), ): with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + "google.cloud.spanner_dbapi.transaction_helper.TransactionHelper.retry_transaction" ) as retry_mock: cursor.fetchmany() @@ -1127,15 +1104,15 @@ def test_fetchmany_retry_aborted_statements(self, mock_client): from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] + row = ("field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", @@ -1143,11 +1120,11 @@ def test_fetchmany_retry_aborted_statements(self, mock_client): ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), + return_value=([row]), ) as run_mock: cursor.fetchmany(len(row)) - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client): @@ -1158,17 +1135,17 @@ def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] + row = ("field1", "field2") + row2 = ("updated_field1", "field2") connection = connect("test-instance", "test-database") cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) + checksum = ResultsChecksum() + checksum.consume_result(row) - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + statement = Statement("SELECT 1", [], {}, checksum) + connection._transaction_helper._single_statements.append(statement) with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", @@ -1176,12 +1153,12 @@ def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), + return_value=([row2]), ) as run_mock: with self.assertRaises(RetryAborted): cursor.fetchmany(len(row)) - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement) @mock.patch("google.cloud.spanner_v1.Client") def test_ddls_with_semicolon(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py new file mode 100644 index 0000000000..d246134ab8 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -0,0 +1,340 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest import mock +from unittest.mock import call + +from google.cloud.spanner_dbapi.exceptions import RetryAborted +from google.cloud.spanner_dbapi.checksum import ResultsChecksum +from google.cloud.spanner_dbapi.parsed_statement import Statement +from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi import transaction_helper +from google.rpc.status_pb2 import Status +from google.rpc.code_pb2 import OK + +from google.cloud.spanner_dbapi.transaction_helper import TransactionHelper + + +class TestTransactionHelper(unittest.TestCase): + @mock.patch("google.cloud.spanner_dbapi.connection.Connection") + def setUp(self, mock_connection): + self._under_test = TransactionHelper(mock_connection) + self._mock_connection = mock_connection + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction with different result results in + checksums mismatch and exception thrown. + """ + + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(statement) + + retried_row = ("field3", "field4") + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.return_value = [retried_row] + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + def test_retry_aborted_retry(self): + """ + Check that in case of a retried transaction aborted, + it will be retried once again. + """ + + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + [row], + ] + + self._under_test.retry_transaction() + + run_mock.assert_has_calls( + ( + mock.call(statement), + mock.call(statement), + ) + ) + + def test_retry_transaction_raise_max_internal_retries(self): + """Check retrying raise an error of max internal retries.""" + + transaction_helper.MAX_INTERNAL_RETRIES = 0 + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(statement) + + with self.assertRaises(Exception): + self._under_test.retry_transaction() + + transaction_helper.MAX_INTERNAL_RETRIES = 50 + + def test_retry_aborted_retry_without_delay(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + [row], + ] + self._under_test._get_retry_delay = mock.Mock(return_value=False) + + self._under_test.retry_transaction() + + run_mock.assert_has_calls( + ( + mock.call(statement), + mock.call(statement), + ) + ) + + def test_retry_transaction_w_multiple_statement(self): + """Check retrying an aborted transaction having multiple statements.""" + + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + statement = Statement("SELECT 1", [], {}, checksum) + statement1 = Statement("SELECT 2", [], {}, checksum) + self._under_test._single_statements.append(statement) + self._under_test._single_statements.append(statement1) + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.return_value = [row] + retried_checksum = checksum + + with mock.patch( + "google.cloud.spanner_dbapi.transaction_helper._compare_checksums" + ) as compare_mock: + self._under_test.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checksum) + run_mock.assert_has_calls([call(statement), call(statement1)]) + + def test_retry_transaction_w_empty_response(self): + """Check retrying an aborted transaction with empty response.""" + + row = () + checksum = ResultsChecksum() + statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(statement) + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.return_value = [row] + retried_checksum = ResultsChecksum() + retried_checksum.consume_result(row) + + with mock.patch( + "google.cloud.spanner_dbapi.transaction_helper._compare_checksums" + ) as compare_mock: + self._under_test.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checksum) + run_mock.assert_called_with(statement) + + def test_retry_transaction_batch_statements_checksum_match(self): + """ + Check retrying an aborted transaction with same result, results in + checksums match. + """ + + res = (1, 1) + checksum = ResultsChecksum() + checksum.consume_result(res) + checksum.consume_result(OK) + statement1 = Statement("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None) + statement2 = Statement("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None) + self._under_test._batch_statements_list.append( + ([statement1, statement2], checksum) + ) + + mock_transaction = mock.MagicMock() + self._under_test._connection.transaction_checkout = mock.Mock( + return_value=mock_transaction + ) + mock_transaction.batch_update = mock.Mock(return_value=(Status(code=OK), res)) + + self._under_test.retry_transaction() + + mock_transaction.batch_update.assert_called_with( + [ + (("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None)), + ("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None), + ] + ) + + def test_retry_transaction_w_multiple_batch_statements(self): + """Check retrying an aborted transaction having multiple statements.""" + + res = (1, 1) + checksum = ResultsChecksum() + checksum.consume_result(res) + checksum.consume_result(OK) + statement1 = Statement("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None) + statement2 = Statement("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None) + self._under_test._batch_statements_list.append(([statement1], checksum)) + self._under_test._batch_statements_list.append(([statement2], checksum)) + + mock_transaction = mock.MagicMock() + self._under_test._connection.transaction_checkout = mock.Mock( + return_value=mock_transaction + ) + mock_transaction.batch_update = mock.Mock(return_value=(Status(code=OK), res)) + + self._under_test.retry_transaction() + + mock_transaction.batch_update.assert_has_calls( + [ + call( + [ + (("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None)), + ] + ), + call( + [ + ("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None), + ] + ), + ] + ) + + def test_retry_transaction_batch_statements_checksum_mismatch(self): + """ + Check retrying an aborted transaction with different result, results in + checksums mismatch and exception thrown. + """ + + res = (1, 1) + checksum = ResultsChecksum() + checksum.consume_result(res) + checksum.consume_result(OK) + statement1 = Statement("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None) + statement2 = Statement("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None) + self._under_test._batch_statements_list.append( + ([statement1, statement2], checksum) + ) + + retried_res = (2, 3) + mock_transaction = mock.MagicMock() + self._under_test._connection.transaction_checkout = mock.Mock( + return_value=mock_transaction + ) + mock_transaction.batch_update = mock.Mock( + return_value=(Status(code=OK), retried_res) + ) + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + def test_batch_statements_retry_aborted_retry(self): + """ + Check that in case of a retried transaction aborted, + it will be retried once again. + """ + + res = 1 + checksum = ResultsChecksum() + checksum.consume_result(res) + checksum.consume_result(OK) + statement1 = Statement("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None) + self._under_test._batch_statements_list.append(([statement1], checksum)) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + mock_transaction = mock.MagicMock() + self._under_test._connection.transaction_checkout = mock.Mock( + return_value=mock_transaction + ) + mock_transaction.batch_update.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + (Status(code=OK), res), + ] + + self._under_test.retry_transaction() + + statement_tuple = ("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None) + mock_transaction.batch_update.assert_has_calls( + ( + mock.call([statement_tuple]), + mock.call([statement_tuple]), + ) + ) + + def test_transaction_w_both_batch_and_non_batch_statements(self): + """ + Check transaction having both batch and non batch type of statements and + having same result in retry succeeds. + """ + + row = ("field1", "field2") + checksum = ResultsChecksum() + checksum.consume_result(row) + single_statement = Statement("SELECT 1", [], {}, checksum) + self._under_test._single_statements.append(single_statement) + + res = (1, 1) + checksum = ResultsChecksum() + checksum.consume_result(res) + checksum.consume_result(OK) + batch_statement_1 = Statement( + "INSERT INTO T (f1, f2) VALUES (1, 2)", None, None + ) + batch_statement_2 = Statement( + "INSERT INTO T (f1, f2) VALUES (3, 4)", None, None + ) + self._under_test._batch_statements_list.append( + ([batch_statement_1, batch_statement_2], checksum) + ) + + run_mock = self._under_test._connection.run_statement = mock.Mock() + run_mock.return_value = [row] + mock_transaction = mock.MagicMock() + self._under_test._connection.transaction_checkout = mock.Mock( + return_value=mock_transaction + ) + mock_transaction.batch_update = mock.Mock(return_value=(Status(code=OK), res)) + + self._under_test.retry_transaction() + + run_mock.assert_called_with(single_statement) + mock_transaction.batch_update.assert_called_with( + [ + (("INSERT INTO T (f1, f2) VALUES (1, 2)", None, None)), + ("INSERT INTO T (f1, f2) VALUES (3, 4)", None, None), + ] + )