diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index d388f0a7..fbac6623 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -14,6 +14,7 @@ from logging import DEBUG, getLogger +from .cleanup import release_resources from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection @@ -23,6 +24,17 @@ threadsafety = 2 paramstyle = "pyformat" +# Public API +__all__ = [ + 'connect', + 'AwsWrapperConnection', + 'release_resources', + 'set_logger', + 'apilevel', + 'threadsafety', + 'paramstyle' +] + def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/cleanup.py b/aws_advanced_python_wrapper/cleanup.py new file mode 100644 index 00000000..1327bc47 --- /dev/null +++ b/aws_advanced_python_wrapper/cleanup.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from aws_advanced_python_wrapper.host_monitoring_plugin import \ + MonitoringThreadContainer +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer + + +def release_resources() -> None: + """Release all global resources used by the wrapper.""" + MonitoringThreadContainer.clean_up() + ThreadPoolContainer.release_resources() diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py index 03b672a6..8e490e02 100644 --- a/aws_advanced_python_wrapper/custom_endpoint_plugin.py +++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py @@ -169,7 +169,8 @@ def _run(self): len(endpoints), endpoint_hostnames) - sleep(self._refresh_rate_ns / 1_000_000_000) + if self._stop_event.wait(self._refresh_rate_ns / 1_000_000_000): + break continue endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0]) @@ -178,7 +179,8 @@ def _run(self): if cached_info is not None and cached_info == endpoint_info: elapsed_time = perf_counter_ns() - start_ns sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) - sleep(sleep_duration / 1_000_000_000) + if self._stop_event.wait(sleep_duration / 1_000_000_000): + break continue logger.debug( @@ -196,7 +198,8 @@ def _run(self): elapsed_time = perf_counter_ns() - start_ns sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) - sleep(sleep_duration / 1_000_000_000) + if self._stop_event.wait(sleep_duration / 1_000_000_000): + break continue except InterruptedError as e: raise e diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index 1dafd862..e6f4b973 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -28,7 +28,7 @@ from .exception_handling import ExceptionHandler from abc import ABC, abstractmethod -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError from contextlib import closing from enum import Enum, auto @@ -37,6 +37,8 @@ from aws_advanced_python_wrapper.host_list_provider import ( ConnectionStringHostListProvider, RdsHostListProvider) from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout from aws_advanced_python_wrapper.utils.log import Logger @@ -638,7 +640,7 @@ class DatabaseDialectManager(DatabaseDialectProvider): _ENDPOINT_CACHE_EXPIRATION_NS = 30 * 60_000_000_000 # 30 minutes _known_endpoint_dialects: CacheMap[str, DialectCode] = CacheMap() _custom_dialect: Optional[DatabaseDialect] = None - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DatabaseDialectManagerExecutor") + _executor_name: ClassVar[str] = "DatabaseDialectManagerExecutor" _known_dialects_by_code: Dict[DialectCode, DatabaseDialect] = { DialectCode.MYSQL: MysqlDatabaseDialect(), DialectCode.RDS_MYSQL: RdsMysqlDialect(), @@ -776,7 +778,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props) try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - DatabaseDialectManager._executor, + ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name), timeout_sec, driver_dialect, conn)(dialect_candidate.is_dialect) diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py index c9df891a..9722cf94 100644 --- a/aws_advanced_python_wrapper/driver_dialect.py +++ b/aws_advanced_python_wrapper/driver_dialect.py @@ -21,11 +21,13 @@ from aws_advanced_python_wrapper.pep249 import Connection, Cursor from abc import ABC -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import (QueryTimeoutError, UnsupportedOperationError) +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import timeout from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -40,7 +42,7 @@ class DriverDialect(ABC): _QUERY = "SELECT 1" _ALL_METHODS = "*" - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DriverDialectExecutor") + _executor_name: ClassVar[str] = "DriverDialectExecutor" _dialect_code: str = DriverDialectCodes.GENERIC _network_bound_methods: Set[str] = {_ALL_METHODS} _read_only: bool = False @@ -136,7 +138,7 @@ def execute( if exec_timeout > 0: try: - execute_with_timeout = timeout(DriverDialect._executor, exec_timeout)(exec_func) + execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func) return execute_with_timeout() except TimeoutError as e: raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 9d713817..81409443 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -16,7 +16,7 @@ import uuid from abc import ABC, abstractmethod -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError from contextlib import closing from dataclasses import dataclass from datetime import datetime @@ -39,6 +39,8 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249 import (Connection, Cursor, ProgrammingError) +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -148,8 +150,6 @@ class RdsHostListProvider(DynamicHostListProvider, HostListProvider): # cluster IDs so that connections to the same clusters can share topology info. _cluster_ids_to_update: CacheMap[str, str] = CacheMap() - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="RdsHostListProviderExecutor") - def __init__(self, host_list_provider_service: HostListProviderService, props: Properties, topology_utils: TopologyUtils): self._host_list_provider_service: HostListProviderService = host_list_provider_service self._props: Properties = props @@ -425,6 +425,8 @@ class TopologyUtils(ABC): to various database engine deployments (e.g. Aurora, Multi-AZ, etc.). """ + _executor_name: ClassVar[str] = "TopologyUtils" + def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Properties): self._dialect: db_dialect.TopologyAwareDatabaseDialect = dialect self._rds_utils = RdsUtils() @@ -487,7 +489,7 @@ def query_for_topology( an empty tuple will be returned. """ query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout( - RdsHostListProvider._executor, self._max_timeout, driver_dialect, conn)(self._query_for_topology) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, conn)(self._query_for_topology) return query_for_topology_func_with_timeout(conn) @abstractmethod @@ -549,7 +551,7 @@ def create_host( def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole: try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_role) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_role) result = cursor_execute_func_with_timeout(connection) if result is not None: is_reader = result[0] @@ -572,7 +574,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) -> """ cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_id) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_id) result = cursor_execute_func_with_timeout(connection) if result: host_id: str = result[0] @@ -586,6 +588,9 @@ def _get_host_id(self, conn: Connection): class AuroraTopologyUtils(TopologyUtils): + + _executor_name: ClassVar[str] = "AuroraTopologyUtils" + def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]: """ Query the database for topology information. @@ -636,6 +641,9 @@ def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]: class MultiAzTopologyUtils(TopologyUtils): + + _executor_name: ClassVar[str] = "MultiAzTopologyUtils" + def __init__( self, dialect: db_dialect.TopologyAwareDatabaseDialect, diff --git a/aws_advanced_python_wrapper/host_monitoring_plugin.py b/aws_advanced_python_wrapper/host_monitoring_plugin.py index 415a0a9e..a5e5752f 100644 --- a/aws_advanced_python_wrapper/host_monitoring_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_plugin.py @@ -22,12 +22,11 @@ from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService -from concurrent.futures import (Executor, Future, ThreadPoolExecutor, - TimeoutError) +from concurrent.futures import Future, TimeoutError from dataclasses import dataclass from queue import Queue from threading import Event, Lock, RLock -from time import perf_counter_ns, sleep +from time import perf_counter_ns from typing import Any, Callable, ClassVar, Dict, FrozenSet, Optional, Set from _weakref import ReferenceType, ref @@ -36,6 +35,8 @@ from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, PluginFactory) +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages @@ -548,9 +549,8 @@ def _execute_conn_check(self, conn: Connection, timeout_sec: float): driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec) cursor.fetchone() - # Used to help with testing def sleep(self, duration: int): - sleep(duration) + self._is_stopped.wait(duration) class MonitoringThreadContainer: @@ -565,7 +565,7 @@ class MonitoringThreadContainer: _monitor_map: ConcurrentDict[str, Monitor] = ConcurrentDict() _tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict() - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor") + _executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor" # This logic ensures that this class is a Singleton def __new__(cls, *args, **kwargs): @@ -593,7 +593,9 @@ def _get_or_create_monitor(_) -> Monitor: if supplied_monitor is None: raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone")) self._tasks_map.compute_if_absent( - supplied_monitor, lambda _: MonitoringThreadContainer._executor.submit(supplied_monitor.run)) + supplied_monitor, + lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name) + .submit(supplied_monitor.run)) return supplied_monitor if monitor is None: @@ -648,12 +650,9 @@ def _release_resources(self): for monitor, _ in self._tasks_map.items(): monitor.stop() + ThreadPoolContainer.release_pool(MonitoringThreadContainer._executor_name, wait=False) self._tasks_map.clear() - # Reset the executor. - self._executor.shutdown(wait=False) - MonitoringThreadContainer._executor = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor") - class MonitorService: def __init__(self, plugin_service: PluginService): diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index dd7055c5..d777cb4a 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -20,12 +20,14 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError from inspect import signature from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import timeout from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, @@ -55,7 +57,7 @@ class MySQLDriverDialect(DriverDialect): AUTH_METHOD = "mysql_clear_password" IS_CLOSED_TIMEOUT_SEC = 3 - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor") + _executor_name: ClassVar[str] = "MySQLDriverDialectExecutor" _dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON _network_bound_methods: Set[str] = { @@ -94,7 +96,8 @@ def is_closed(self, conn: Connection) -> bool: if self.can_execute_query(conn): socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props) timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC - is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) # type: ignore + is_connected_with_timeout = timeout( + ThreadPoolContainer.get_thread_pool(MySQLDriverDialect._executor_name), timeout_sec)(conn.is_connected) # type: ignore try: return not is_connected_with_timeout() diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index fbf909e1..fa1879ab 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -41,7 +41,7 @@ from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from abc import abstractmethod -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError +from concurrent.futures import TimeoutError from contextlib import closing from typing import (Any, Callable, Dict, FrozenSet, Optional, Protocol, Set, Tuple) @@ -85,6 +85,8 @@ from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \ SimpleReadWriteSplittingPluginFactory from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout @@ -314,7 +316,7 @@ class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResour _host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap() _status_cache: ClassVar[CacheMap[str, Any]] = CacheMap() - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="PluginServiceImplExecutor") + _executor_name: ClassVar[str] = "PluginServiceImplExecutor" def __init__( self, @@ -611,7 +613,7 @@ def fill_aliases(self, connection: Optional[Connection] = None, host_info: Optio try: timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props) cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - PluginServiceImpl._executor, timeout_sec, driver_dialect, connection)(self._fill_aliases) + ThreadPoolContainer.get_thread_pool(PluginServiceImpl._executor_name), timeout_sec, driver_dialect, connection)(self._fill_aliases) cursor_execute_func_with_timeout(connection, host_info) except TimeoutError as e: raise QueryTimeoutError(Messages.get("PluginServiceImpl.FillAliasesTimeout")) from e diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index a35c0bfb..1c1c0fbd 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -274,6 +274,8 @@ HostResponseTimeMonitor.OpeningConnection=[HostResponseTimeMonitor] Opening a Re HostResponseTimeMonitor.ResponseTime=[HostResponseTimeMonitor] Response time for '{}': {} ms HostResponseTimeMonitor.Stopped=[HostResponseTimeMonitor] Stopped Response time thread for host '{}'. +ThreadPoolContainer.ErrorShuttingDownPool=[ThreadPoolContainer] Error shutting down pool '{}': '{}'. + OpenedConnectionTracker.OpenedConnectionsTracked=[OpenedConnectionTracker] Opened Connections Tracked: {} OpenedConnectionTracker.InvalidatingConnections=[OpenedConnectionTracker] Invalidating opened connections to host: {} OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet=[OpenedConnectionTracker] The driver is unable to track this opened connection because the instance endpoint is unknown. diff --git a/aws_advanced_python_wrapper/thread_pool_container.py b/aws_advanced_python_wrapper/thread_pool_container.py new file mode 100644 index 00000000..6c9cf905 --- /dev/null +++ b/aws_advanced_python_wrapper/thread_pool_container.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional + +from aws_advanced_python_wrapper.utils.log import Logger + +logger = Logger(__name__) + + +class ThreadPoolContainer: + """ + A container class for managing multiple named thread pools. + Provides static methods for getting, creating, and releasing thread pools. + """ + + _pools: Dict[str, ThreadPoolExecutor] = {} + _lock: threading.Lock = threading.Lock() + _default_max_workers: Optional[int] = None # Uses Python's default + + @classmethod + def get_thread_pool( + cls, + name: str, + max_workers: Optional[int] = None + ) -> ThreadPoolExecutor: + """ + Get an existing thread pool or create a new one if it doesn't exist. + + Args: + name: Unique identifier for the thread pool + max_workers: Max worker threads (only used when creating new pool) + If None, uses Python's default: min(32, os.cpu_count() + 4) + + Returns: + ThreadPoolExecutor instance + """ + with cls._lock: + if name not in cls._pools: + workers = max_workers or cls._default_max_workers + cls._pools[name] = ThreadPoolExecutor( + max_workers=workers, + thread_name_prefix=name + ) + return cls._pools[name] + + @classmethod + def release_resources(cls, wait=False) -> None: + """ + Shutdown all thread pools and release resources. + + Args: + wait: If True, wait for all pending tasks to complete + """ + with cls._lock: + for name, pool in cls._pools.items(): + try: + pool.shutdown(wait=wait) + except Exception as e: + logger.warning("ThreadPoolContainer.ErrorShuttingDownPool", name, e) + cls._pools.clear() + + @classmethod + def release_pool(cls, name: str, wait: bool = True) -> bool: + """ + Release a specific thread pool by name. + + Args: + name: The name of the thread pool to release + wait: If True, wait for pending tasks to complete + + Returns: + True if pool was found and released, False otherwise + """ + with cls._lock: + if name in cls._pools: + try: + cls._pools[name].shutdown(wait=wait) + del cls._pools[name] + return True + except Exception as e: + logger.warning("ThreadPoolContainer.ErrorShuttingDownPool", name, e) + return False + + @classmethod + def has_pool(cls, name: str) -> bool: + """Check if a pool with the given name exists.""" + with cls._lock: + return name in cls._pools + + @classmethod + def get_pool_names(cls) -> List[str]: + """Get a list of all active pool names.""" + with cls._lock: + return list(cls._pools.keys()) + + @classmethod + def get_pool_count(cls) -> int: + """Get the number of active pools.""" + with cls._lock: + return len(cls._pools) + + @classmethod + def set_default_max_workers(cls, max_workers: Optional[int]) -> None: + """Set the default max workers for new pools.""" + cls._default_max_workers = max_workers diff --git a/docs/examples/MySQLFailover.py b/docs/examples/MySQLFailover.py index b7e3b67f..d007567c 100644 --- a/docs/examples/MySQLFailover.py +++ b/docs/examples/MySQLFailover.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.pep249 import Connection -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.errors import ( FailoverFailedError, FailoverSuccessError, TransactionResolutionUnknownError) @@ -61,26 +61,29 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O if __name__ == "__main__": - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - database="mysql", - user="admin", - password="pwd", - plugins="failover", - wrapper_dialect="aurora-mysql", - autocommit=True - ) as awsconn: - configure_initial_session_states(awsconn) - execute_queries_with_failover_handling( - awsconn, "CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") - execute_queries_with_failover_handling( - awsconn, "INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) - execute_queries_with_failover_handling( - awsconn, "INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) - - cursor = execute_queries_with_failover_handling(awsconn, "SELECT * FROM bank_test") - for record in cursor: - print(record) - - execute_queries_with_failover_handling(awsconn, "DROP TABLE bank_test") + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + database="mysql", + user="admin", + password="pwd", + plugins="failover", + wrapper_dialect="aurora-mysql", + autocommit=True + ) as awsconn: + configure_initial_session_states(awsconn) + execute_queries_with_failover_handling( + awsconn, "CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + execute_queries_with_failover_handling( + awsconn, "INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + execute_queries_with_failover_handling( + awsconn, "INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + + cursor = execute_queries_with_failover_handling(awsconn, "SELECT * FROM bank_test") + for record in cursor: + print(record) + + execute_queries_with_failover_handling(awsconn, "DROP TABLE bank_test") + finally: + release_resources() diff --git a/docs/examples/MySQLFastestResponseStrategy.py b/docs/examples/MySQLFastestResponseStrategy.py index ad205892..ffb52561 100644 --- a/docs/examples/MySQLFastestResponseStrategy.py +++ b/docs/examples/MySQLFastestResponseStrategy.py @@ -14,7 +14,7 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ @@ -24,43 +24,46 @@ provider = SqlAlchemyPooledConnectionProvider() ConnectionProviderManager.set_connection_provider(provider) - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - database="mysql", - user="user", - password="password", - plugins="read_write_splitting,fastest_response_strategy", - reader_host_selector_strategy="fastest_response", - autocommit=True - ) as conn: - # Set up - with conn.cursor() as setup_cursor: - setup_cursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") - setup_cursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + database="mysql", + user="user", + password="password", + plugins="read_write_splitting,fastest_response_strategy", + reader_host_selector_strategy="fastest_response", + autocommit=True + ) as conn: + # Set up + with conn.cursor() as setup_cursor: + setup_cursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + setup_cursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) - conn.read_only = True - with conn.cursor() as cursor_1: - cursor_1.execute("SELECT * FROM bank_test") - results = cursor_1.fetchall() - for record in results: - print(record) + conn.read_only = True + with conn.cursor() as cursor_1: + cursor_1.execute("SELECT * FROM bank_test") + results = cursor_1.fetchall() + for record in results: + print(record) - # Switch to writer host - conn.read_only = False + # Switch to writer host + conn.read_only = False - # Use cached host when switching back to a reader - conn.read_only = True - with conn.cursor() as cursor_2: - cursor_2.execute("SELECT * FROM bank_test") - results = cursor_2.fetchall() - for record in results: - print(record) + # Use cached host when switching back to a reader + conn.read_only = True + with conn.cursor() as cursor_2: + cursor_2.execute("SELECT * FROM bank_test") + results = cursor_2.fetchall() + for record in results: + print(record) - # Tear down - conn.read_only = False - with conn.cursor() as teardown_cursor: - teardown_cursor.execute("DROP TABLE bank_test") - - # Closes all pools and removes all cached pool connections - ConnectionProviderManager.release_resources() + # Tear down + conn.read_only = False + with conn.cursor() as teardown_cursor: + teardown_cursor.execute("DROP TABLE bank_test") + finally: + # Clean up global resources created by wrapper + release_resources() + # Closes all pools and removes all cached pool connections + ConnectionProviderManager.release_resources() diff --git a/docs/examples/MySQLFederatedAuthentication.py b/docs/examples/MySQLFederatedAuthentication.py index 415b0f5e..65cf1d21 100644 --- a/docs/examples/MySQLFederatedAuthentication.py +++ b/docs/examples/MySQLFederatedAuthentication.py @@ -14,26 +14,30 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-2.rds.amazonaws.com", - database="mysql", - plugins="federated_auth", - idp_name="adfs", - app_id="abcde1fgh3kLZTBz1S5d7", - idp_endpoint="ec2amaz-ab3cdef.example.com", - iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", - iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", - iam_region="us-east-2", - idp_username="some_federated_username@example.com", - idp_password="some_password", - db_user="john", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("SELECT 1") + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + database="mysql", + plugins="federated_auth", + idp_name="adfs", + app_id="abcde1fgh3kLZTBz1S5d7", + idp_endpoint="ec2amaz-ab3cdef.example.com", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT 1") - res = awscursor.fetchone() - print(res) + res = awscursor.fetchone() + print(res) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLIamAuthentication.py b/docs/examples/MySQLIamAuthentication.py index b440c658..32848d92 100644 --- a/docs/examples/MySQLIamAuthentication.py +++ b/docs/examples/MySQLIamAuthentication.py @@ -14,23 +14,27 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - database="mysql", - user="admin", - plugins="iam", - wrapper_dialect="aurora-mysql", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") - awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) - awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) - awscursor.execute("SELECT * FROM bank_test") + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + database="mysql", + user="admin", + plugins="iam", + wrapper_dialect="aurora-mysql", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + awscursor.execute("SELECT * FROM bank_test") - for record in awscursor: - print(record) - awscursor.execute("DROP TABLE bank_test") + for record in awscursor: + print(record) + awscursor.execute("DROP TABLE bank_test") + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLInternalConnectionPoolPasswordWarning.py b/docs/examples/MySQLInternalConnectionPoolPasswordWarning.py index 311d1ec2..1f516d62 100644 --- a/docs/examples/MySQLInternalConnectionPoolPasswordWarning.py +++ b/docs/examples/MySQLInternalConnectionPoolPasswordWarning.py @@ -14,51 +14,56 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider if __name__ == "__main__": - params = { - # In general, you should not use instance URLs to connect. However, we will use one here to simplify this - # example, because internal connection pools are only opened when connecting to an instance URL. Normally the - # internal connection pool would be opened when read_only is set instead of when you are initially connecting. - "host": "database-instance.xyz.us-east-1.rds.amazonaws.com", - "database": "mysql", - "user": "admin", - "plugins": "read_write_splitting,failover", - "autocommit": True - } - - correct_password = "correct_password" - incorrect_password = "incorrect_password" - - provider = SqlAlchemyPooledConnectionProvider() - ConnectionProviderManager.set_connection_provider(provider) + try: + params = { + # In general, you should not use instance URLs to connect. However, we will use one here to simplify this + # example, because internal connection pools are only opened when connecting to an instance URL. Normally the + # internal connection pool would be opened when read_only is set instead of when you are initially connecting. + "host": "database-instance.xyz.us-east-1.rds.amazonaws.com", + "database": "mysql", + "user": "admin", + "plugins": "read_write_splitting,failover", + "autocommit": True + } - # Create an internal connection pool with the correct password - conn = AwsWrapperConnection.connect(mysql.connector.Connect, **params, password=correct_password) - # Finished with connection. The connection is not actually closed here, instead it will be returned to the pool but - # will remain open. - conn.close() + correct_password = "correct_password" + incorrect_password = "incorrect_password" - # Even though we use an incorrect password, the original connection 'conn' will be returned by the pool, and we can - # still use it. - with AwsWrapperConnection.connect( - mysql.connector.Connect, **params, password=incorrect_password) as incorrect_password_conn: - incorrect_password_conn.cursor().execute("SELECT 1") + provider = SqlAlchemyPooledConnectionProvider() + ConnectionProviderManager.set_connection_provider(provider) - # Closes all pools and removes all cached pool connections - ConnectionProviderManager.release_resources() + # Create an internal connection pool with the correct password + conn = AwsWrapperConnection.connect(mysql.connector.Connect, **params, password=correct_password) + # Finished with connection. The connection is not actually closed here, instead it will be returned to the pool but + # will remain open. + conn.close() - try: - # Correctly throws an exception - creates a fresh connection pool which will check the password because there - # are no longer any cached pool connections. + # Even though we use an incorrect password, the original connection 'conn' will be returned by the pool, and we can + # still use it. with AwsWrapperConnection.connect( mysql.connector.Connect, **params, password=incorrect_password) as incorrect_password_conn: - # Will not reach - exception will be thrown - pass - except Exception: - print("Failed to connect - password was incorrect") + incorrect_password_conn.cursor().execute("SELECT 1") + + # Closes all pools and removes all cached pool connections + ConnectionProviderManager.release_resources() + + try: + # Correctly throws an exception - creates a fresh connection pool which will check the password because there + # are no longer any cached pool connections. + with AwsWrapperConnection.connect( + mysql.connector.Connect, **params, password=incorrect_password) as incorrect_password_conn: + # Will not reach - exception will be thrown + pass + except Exception: + print("Failed to connect - password was incorrect") + + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLOktaAuthentication.py b/docs/examples/MySQLOktaAuthentication.py index b4966438..74383013 100644 --- a/docs/examples/MySQLOktaAuthentication.py +++ b/docs/examples/MySQLOktaAuthentication.py @@ -14,25 +14,30 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-2.rds.amazonaws.com", - database="mysql", - plugins="okta", - idp_endpoint="ec2amaz-ab3cdef.example.com", - app_id="abcde1fgh3kLZTBz1S5d7", - iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", - iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", - iam_region="us-east-2", - idp_username="some_federated_username@example.com", - idp_password="some_password", - db_user="john", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("SELECT @@aurora_server_id") + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + database="mysql", + plugins="okta", + idp_endpoint="ec2amaz-ab3cdef.example.com", + app_id="abcde1fgh3kLZTBz1S5d7", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT @@aurora_server_id") - res = awscursor.fetchone() - print(res) + res = awscursor.fetchone() + print(res) + + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLReadWriteSplitting.py b/docs/examples/MySQLReadWriteSplitting.py index 22746538..a102d68d 100644 --- a/docs/examples/MySQLReadWriteSplitting.py +++ b/docs/examples/MySQLReadWriteSplitting.py @@ -22,7 +22,7 @@ import mysql.connector # type: ignore -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.errors import ( @@ -142,3 +142,6 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O """ If connection pools were enabled, close them here """ ConnectionProviderManager.release_resources() + + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLSecretsManager.py b/docs/examples/MySQLSecretsManager.py index 19368a7b..bff460d4 100644 --- a/docs/examples/MySQLSecretsManager.py +++ b/docs/examples/MySQLSecretsManager.py @@ -16,17 +16,21 @@ import mysql.connector -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - mysql.connector.Connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - database="mysql", - secrets_manager_secret_id="arn:aws:secretsmanager:::secret:Secre78tName-6RandomCharacters", - secrets_manager_region="us-east-2", - plugins="aws_secrets_manager" - ) as awsconn, awsconn.cursor() as cursor: - cursor.execute("SELECT @@aurora_server_id") - for record in cursor.fetchone(): - print(record) + try: + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + database="mysql", + secrets_manager_secret_id="arn:aws:secretsmanager:::secret:Secre78tName-6RandomCharacters", + secrets_manager_region="us-east-2", + plugins="aws_secrets_manager" + ) as awsconn, awsconn.cursor() as cursor: + cursor.execute("SELECT @@aurora_server_id") + for record in cursor.fetchone(): + print(record) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/MySQLSimpleReadWriteSplitting.py b/docs/examples/MySQLSimpleReadWriteSplitting.py index bf91e3c0..17079664 100644 --- a/docs/examples/MySQLSimpleReadWriteSplitting.py +++ b/docs/examples/MySQLSimpleReadWriteSplitting.py @@ -21,7 +21,7 @@ import mysql.connector # type: ignore -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.errors import ( FailoverFailedError, FailoverSuccessError, TransactionResolutionUnknownError) @@ -116,3 +116,6 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O finally: with AwsWrapperConnection.connect(mysql.connector.Connect, **params) as conn: execute_queries_with_failover_handling(conn, "DROP TABLE bank_test") + + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGFailover.py b/docs/examples/PGFailover.py index e1ece5a5..c15964fb 100644 --- a/docs/examples/PGFailover.py +++ b/docs/examples/PGFailover.py @@ -18,13 +18,10 @@ import psycopg -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer - if TYPE_CHECKING: from aws_advanced_python_wrapper.pep249 import Connection -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.errors import ( FailoverFailedError, FailoverSuccessError, TransactionResolutionUnknownError) @@ -69,18 +66,18 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O "monitoring-socket_timeout": 10 } - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres", - user="john", - password="pwd", - plugins="failover,host_monitoring", - connect_timeout=30, - socket_timeout=30, - autocommit=True - ) as awsconn: - try: + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres", + user="john", + password="pwd", + plugins="failover,host_monitoring", + connect_timeout=30, + socket_timeout=30, + autocommit=True + ) as awsconn: configure_initial_session_states(awsconn) execute_queries_with_failover_handling( awsconn, "CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") @@ -95,6 +92,6 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O print(record) execute_queries_with_failover_handling(awsconn, "DROP TABLE bank_test") - finally: - # Clean up any remaining resources created by the Host Monitoring Plugin. - MonitoringThreadContainer.clean_up() + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGFastestResponseStrategy.py b/docs/examples/PGFastestResponseStrategy.py index 0cb99d32..2916bdbc 100644 --- a/docs/examples/PGFastestResponseStrategy.py +++ b/docs/examples/PGFastestResponseStrategy.py @@ -14,7 +14,7 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ @@ -24,43 +24,47 @@ provider = SqlAlchemyPooledConnectionProvider() ConnectionProviderManager.set_connection_provider(provider) - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres", - user="user", - password="password", - plugins="read_write_splitting,fastest_response_strategy", - reader_host_selector_strategy="fastest_response", - autocommit=True - ) as conn: - # Set up - with conn.cursor() as setup_cursor: - setup_cursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") - setup_cursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres", + user="user", + password="password", + plugins="read_write_splitting,fastest_response_strategy", + reader_host_selector_strategy="fastest_response", + autocommit=True + ) as conn: + # Set up + with conn.cursor() as setup_cursor: + setup_cursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + setup_cursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) - conn.read_only = True - with conn.cursor() as cursor_1: - cursor_1.execute("SELECT * FROM bank_test") - results = cursor_1.fetchall() - for record in results: - print(record) + conn.read_only = True + with conn.cursor() as cursor_1: + cursor_1.execute("SELECT * FROM bank_test") + results = cursor_1.fetchall() + for record in results: + print(record) - # Switch to writer host - conn.read_only = False + # Switch to writer host + conn.read_only = False - # Use cached host when switching back to a reader - conn.read_only = True - with conn.cursor() as cursor_2: - cursor_2.execute("SELECT * FROM bank_test") - results = cursor_2.fetchall() - for record in results: - print(record) + # Use cached host when switching back to a reader + conn.read_only = True + with conn.cursor() as cursor_2: + cursor_2.execute("SELECT * FROM bank_test") + results = cursor_2.fetchall() + for record in results: + print(record) - # Tear down - conn.read_only = False - with conn.cursor() as teardown_cursor: - teardown_cursor.execute("DROP TABLE bank_test") + # Tear down + conn.read_only = False + with conn.cursor() as teardown_cursor: + teardown_cursor.execute("DROP TABLE bank_test") + finally: + # Closes all pools and removes all cached pool connections + ConnectionProviderManager.release_resources() - # Closes all pools and removes all cached pool connections - ConnectionProviderManager.release_resources() + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGFederatedAuthentication.py b/docs/examples/PGFederatedAuthentication.py index c718e091..20f2a965 100644 --- a/docs/examples/PGFederatedAuthentication.py +++ b/docs/examples/PGFederatedAuthentication.py @@ -14,25 +14,29 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-2.rds.amazonaws.com", - dbname="postgres", - plugins="federated_auth", - idp_name="adfs", - idp_endpoint="ec2amaz-ab3cdef.example.com", - iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", - iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", - iam_region="us-east-2", - idp_username="some_federated_username@example.com", - idp_password="some_password", - db_user="john", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("SELECT 1") + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + dbname="postgres", + plugins="federated_auth", + idp_name="adfs", + idp_endpoint="ec2amaz-ab3cdef.example.com", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT 1") - res = awscursor.fetchone() - print(res) + res = awscursor.fetchone() + print(res) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGIamAuthentication.py b/docs/examples/PGIamAuthentication.py index cde49393..1b9344e3 100644 --- a/docs/examples/PGIamAuthentication.py +++ b/docs/examples/PGIamAuthentication.py @@ -14,24 +14,28 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres", - user="john", - plugins="iam", - wrapper_dialect="aurora-pg", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") - awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) - awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) - awscursor.execute("SELECT * FROM bank_test") + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres", + user="john", + plugins="iam", + wrapper_dialect="aurora-pg", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + awscursor.execute("SELECT * FROM bank_test") - res = awscursor.fetchall() - for record in res: - print(record) - awscursor.execute("DROP TABLE bank_test") + res = awscursor.fetchall() + for record in res: + print(record) + awscursor.execute("DROP TABLE bank_test") + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGInternalConnectionPoolPasswordWarning.py b/docs/examples/PGInternalConnectionPoolPasswordWarning.py index 4404ee31..026870d8 100644 --- a/docs/examples/PGInternalConnectionPoolPasswordWarning.py +++ b/docs/examples/PGInternalConnectionPoolPasswordWarning.py @@ -14,51 +14,55 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider if __name__ == "__main__": - params = { - # In general, you should not use instance URLs to connect. However, we will use one here to simplify this - # example, because internal connection pools are only opened when connecting to an instance URL. Normally the - # internal connection pool would be opened when read_only is set instead of when you are initially connecting. - "host": "database-instance.xyz.us-east-1.rds.amazonaws.com", - "dbname": "postgres", - "user": "john", - "plugins": "read_write_splitting,failover,host_monitoring", - "autocommit": True - } - - correct_password = "correct_password" - incorrect_password = "incorrect_password" - - provider = SqlAlchemyPooledConnectionProvider() - ConnectionProviderManager.set_connection_provider(provider) + try: + params = { + # In general, you should not use instance URLs to connect. However, we will use one here to simplify this + # example, because internal connection pools are only opened when connecting to an instance URL. Normally the + # internal connection pool would be opened when read_only is set instead of when you are initially connecting. + "host": "database-instance.xyz.us-east-1.rds.amazonaws.com", + "dbname": "postgres", + "user": "john", + "plugins": "read_write_splitting,failover,host_monitoring", + "autocommit": True + } - # Create an internal connection pool with the correct password - conn = AwsWrapperConnection.connect(psycopg.Connection.connect, **params, password=correct_password) - # Finished with connection. The connection is not actually closed here, instead it will be returned to the pool but - # will remain open. - conn.close() + correct_password = "correct_password" + incorrect_password = "incorrect_password" - # Even though we use an incorrect password, the original connection 'conn' will be returned by the pool, and we can - # still use it. - with AwsWrapperConnection.connect( - psycopg.Connection.connect, **params, password=incorrect_password) as incorrect_password_conn: - incorrect_password_conn.cursor().execute("SELECT 1") + provider = SqlAlchemyPooledConnectionProvider() + ConnectionProviderManager.set_connection_provider(provider) - # Closes all pools and removes all cached pool connections - ConnectionProviderManager.release_resources() + # Create an internal connection pool with the correct password + conn = AwsWrapperConnection.connect(psycopg.Connection.connect, **params, password=correct_password) + # Finished with connection. The connection is not actually closed here, instead it will be returned to the pool but + # will remain open. + conn.close() - try: - # Correctly throws an exception - creates a fresh connection pool which will check the password because there - # are no longer any cached pool connections. + # Even though we use an incorrect password, the original connection 'conn' will be returned by the pool, and we can + # still use it. with AwsWrapperConnection.connect( psycopg.Connection.connect, **params, password=incorrect_password) as incorrect_password_conn: - # Will not reach - exception will be thrown - pass - except Exception: - print("Failed to connect - password was incorrect") + incorrect_password_conn.cursor().execute("SELECT 1") + + # Closes all pools and removes all cached pool connections + ConnectionProviderManager.release_resources() + + try: + # Correctly throws an exception - creates a fresh connection pool which will check the password because there + # are no longer any cached pool connections. + with AwsWrapperConnection.connect( + psycopg.Connection.connect, **params, password=incorrect_password) as incorrect_password_conn: + # Will not reach - exception will be thrown + pass + except Exception: + print("Failed to connect - password was incorrect") + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGLimitless.py b/docs/examples/PGLimitless.py index a7787b1f..e81bce1a 100644 --- a/docs/examples/PGLimitless.py +++ b/docs/examples/PGLimitless.py @@ -14,19 +14,23 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="limitless-cluster.limitless-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres_limitless", - user="user", - password="password", - plugins="limitless", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("SELECT * FROM pg_catalog.aurora_db_instance_identifier()") + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="limitless-cluster.limitless-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres_limitless", + user="user", + password="password", + plugins="limitless", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT * FROM pg_catalog.aurora_db_instance_identifier()") - res = awscursor.fetchone() - print(res) + res = awscursor.fetchone() + print(res) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGOktaAuthentication.py b/docs/examples/PGOktaAuthentication.py index f1f5eb48..8c365da2 100644 --- a/docs/examples/PGOktaAuthentication.py +++ b/docs/examples/PGOktaAuthentication.py @@ -14,25 +14,29 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-2.rds.amazonaws.com", - dbname="postgres", - plugins="okta", - idp_endpoint="ec2amaz-ab3cdef.example.com", - app_id="abcde1fgh3kLZTBz1S5d7", - iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", - iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", - iam_region="us-east-2", - idp_username="some_federated_username@example.com", - idp_password="some_password", - db_user="john", - autocommit=True - ) as awsconn, awsconn.cursor() as awscursor: - awscursor.execute("SELECT * FROM pg_catalog.aurora_db_instance_identifier()") + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + dbname="postgres", + plugins="okta", + idp_endpoint="ec2amaz-ab3cdef.example.com", + app_id="abcde1fgh3kLZTBz1S5d7", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT * FROM pg_catalog.aurora_db_instance_identifier()") - res = awscursor.fetchone() - print(res) + res = awscursor.fetchone() + print(res) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGOpenTelemetry.py b/docs/examples/PGOpenTelemetry.py index 02b2ed5d..d882fb7a 100644 --- a/docs/examples/PGOpenTelemetry.py +++ b/docs/examples/PGOpenTelemetry.py @@ -29,7 +29,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources SQL_DBLIST = "select datname from pg_database;" @@ -56,25 +56,29 @@ tracer = trace.get_tracer(__name__) with tracer.start_as_current_span("python_otlp_telemetry_app") as segment: - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="db-identifier-postgres.XYZ.us-east-2.rds.amazonaws.com", - dbname="test_db", - user="user", - password="password", - plugins="failover,host_monitoring", - wrapper_dialect="aurora-pg", - autocommit=True, - enable_telemetry=True, - telemetry_submit_toplevel=False, - telemetry_traces_backend="OTLP", - telemetry_metrics_backend="OTLP", - telemetry_failover_additional_top_trace=True - ) as awsconn: - awscursor = awsconn.cursor() - awscursor.execute(SQL_DBLIST) - res = awscursor.fetchall() - for record in res: - print(record) + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="db-identifier-postgres.XYZ.us-east-2.rds.amazonaws.com", + dbname="test_db", + user="user", + password="password", + plugins="failover,host_monitoring", + wrapper_dialect="aurora-pg", + autocommit=True, + enable_telemetry=True, + telemetry_submit_toplevel=False, + telemetry_traces_backend="OTLP", + telemetry_metrics_backend="OTLP", + telemetry_failover_additional_top_trace=True + ) as awsconn: + awscursor = awsconn.cursor() + awscursor.execute(SQL_DBLIST) + res = awscursor.fetchall() + for record in res: + print(record) + finally: + # Clean up global resources created by wrapper + release_resources() print("-- end of application") diff --git a/docs/examples/PGReadWriteSplitting.py b/docs/examples/PGReadWriteSplitting.py index 1e5b46d9..1e1b26a9 100644 --- a/docs/examples/PGReadWriteSplitting.py +++ b/docs/examples/PGReadWriteSplitting.py @@ -22,7 +22,7 @@ import psycopg # type: ignore -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.errors import ( @@ -143,3 +143,6 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O """ If connection pools were enabled, close them here """ ConnectionProviderManager.release_resources() + + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGSecretsManager.py b/docs/examples/PGSecretsManager.py index 7fbe9774..f6126b38 100644 --- a/docs/examples/PGSecretsManager.py +++ b/docs/examples/PGSecretsManager.py @@ -16,17 +16,21 @@ import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources if __name__ == "__main__": - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres", - secrets_manager_secret_id="arn:aws:secretsmanager:::secret:Secre78tName-6RandomCharacters", - secrets_manager_region="us-east-2", - plugins="aws_secrets_manager" - ) as awsconn, awsconn.cursor() as cursor: - cursor.execute("SELECT pg_catalog.aurora_db_instance_identifier()") - for record in cursor.fetchone(): - print(record) + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres", + secrets_manager_secret_id="arn:aws:secretsmanager:::secret:Secre78tName-6RandomCharacters", + secrets_manager_region="us-east-2", + plugins="aws_secrets_manager" + ) as awsconn, awsconn.cursor() as cursor: + cursor.execute("SELECT pg_catalog.aurora_db_instance_identifier()") + for record in cursor.fetchone(): + print(record) + finally: + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGSimpleReadWriteSplitting.py b/docs/examples/PGSimpleReadWriteSplitting.py index 28e65112..944cdee0 100644 --- a/docs/examples/PGSimpleReadWriteSplitting.py +++ b/docs/examples/PGSimpleReadWriteSplitting.py @@ -21,7 +21,7 @@ import psycopg # type: ignore -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.errors import ( FailoverFailedError, FailoverSuccessError, TransactionResolutionUnknownError) @@ -117,3 +117,5 @@ def execute_queries_with_failover_handling(conn: Connection, sql: str, params: O finally: with AwsWrapperConnection.connect(psycopg.Connection.connect, **params) as conn: execute_queries_with_failover_handling(conn, "DROP TABLE bank_test") + # Clean up global resources created by wrapper + release_resources() diff --git a/docs/examples/PGXRayTelemetry.py b/docs/examples/PGXRayTelemetry.py index 2b1888d3..56f6365a 100644 --- a/docs/examples/PGXRayTelemetry.py +++ b/docs/examples/PGXRayTelemetry.py @@ -21,7 +21,7 @@ from aws_xray_sdk.core import xray_recorder from aws_xray_sdk.core.sampling.local.sampler import LocalSampler -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources SQL_DBLIST = "select datname from pg_database;" @@ -33,24 +33,28 @@ global_sdk_config.set_sdk_enabled(True) with xray_recorder.in_segment("python_xray_telemetry_app") as segment: - with AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="db-identifier-postgres.XYZ.us-east-2.rds.amazonaws.com", - dbname="test_db", - user="user", - password="password", - plugins="failover,host_monitoring", - wrapper_dialect="aurora-pg", - autocommit=True, - enable_telemetry=True, - telemetry_submit_toplevel=False, - telemetry_traces_backend="XRAY", - telemetry_metrics_backend="NONE" - ) as awsconn: - awscursor = awsconn.cursor() - awscursor.execute(SQL_DBLIST) - res = awscursor.fetchall() - for record in res: - print(record) + try: + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="db-identifier-postgres.XYZ.us-east-2.rds.amazonaws.com", + dbname="test_db", + user="user", + password="password", + plugins="failover,host_monitoring", + wrapper_dialect="aurora-pg", + autocommit=True, + enable_telemetry=True, + telemetry_submit_toplevel=False, + telemetry_traces_backend="XRAY", + telemetry_metrics_backend="NONE" + ) as awsconn: + awscursor = awsconn.cursor() + awscursor.execute(SQL_DBLIST) + res = awscursor.fetchall() + for record in res: + print(record) + finally: + # Clean up global resources created by wrapper + release_resources() print("-- end of application") diff --git a/docs/using-the-python-driver/UsingThePythonDriver.md b/docs/using-the-python-driver/UsingThePythonDriver.md index 9bb19cd6..bc753267 100644 --- a/docs/using-the-python-driver/UsingThePythonDriver.md +++ b/docs/using-the-python-driver/UsingThePythonDriver.md @@ -44,6 +44,33 @@ These parameters are applicable to any instance of the AWS Advanced Python Drive | tcp_keepalive_interval | Number of seconds to wait before sending additional keepalive probes after the initial probe has been sent. | False | None | | tcp_keepalive_probes | Number of keepalive probes to send before concluding that the connection is invalid. | False | None | +## Resource Management + +The AWS Advanced Python Wrapper creates background threads and thread pools for various plugins during operations such as host monitoring and connection management. To ensure proper cleanup and prevent resource leaks, it's important to release these resources when your application shuts down. + +### Cleaning Up Resources + +Call the following methods before your application terminates: + +```python +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources + +try: + # Your application code here + conn = AwsWrapperConnection.connect(...) + # ... use connection +finally: + # Clean up all resources before application exit + release_resources() +``` + +> [!IMPORTANT] +> Always call `release_resources()` at application shutdown to ensure: +> - All monitoring threads are properly terminated +> - Thread pools are shut down gracefully +> - No resource leaks occur +> - The application exits cleanly without hanging + ## Plugins The AWS Advanced Python Driver uses plugins to execute database API calls. You can think of a plugin as an extensible code module that adds extra logic around any database API calls. The AWS Advanced Python Driver has a number of [built-in plugins](#list-of-available-plugins) available for use. diff --git a/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md index a1a19ad4..bb483326 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md @@ -21,8 +21,8 @@ This plugin only works with drivers that support aborting connections from a sep > [IMPORTANT]\ > The Host Monitoring Plugin creates monitoring threads in the background to monitor all connections established to each cluster instance. The monitoring threads can be cleaned up in two ways: > 1. If there are no connections to the cluster instance the thread is monitoring for over a period of time, the Host Monitoring Plugin will automatically terminate the thread. This period of time is adjustable via the `monitor_disposal_time_ms` parameter. -> 2. Client applications can manually call `MonitoringThreadContainer.clean_up()` to clean up any dangling resources. -> It is best practice to call `MonitoringThreadContainer.clean_up()` at the end of the application to ensure a graceful exit; otherwise, the application may wait until the `monitor_disposal_time_ms` has been passed before terminating. This is because the Python driver waits for all daemon threads to complete before exiting. +> 2. Client applications can manually call `aws_advanced_python_wrapper.release_resources()` to clean up any dangling resources. +> It is best practice to call `aws_advanced_python_wrapper.release_resources()` at the end of the application to ensure a graceful exit; otherwise, the application may wait until the `monitor_disposal_time_ms` has been passed before terminating. This is because the Python driver waits for all daemon threads to complete before exiting. > See [PGFailover](../../examples/PGFailover.py) for an example. ### Enhanced Failure Monitoring Parameters @@ -51,24 +51,27 @@ The Host Monitoring Connection Plugin may create new monitoring connections to c ```python import psycopg -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources props = { "monitoring-connect_timeout": 10, "monitoring-socket_timeout": 10 } - -conn = AwsWrapperConnection.connect( - psycopg.Connection.connect, - host="database.cluster-xyz.us-east-1.rds.amazonaws.com", - dbname="postgres", - user="john", - password="pwd", - plugins="host_monitoring", - # Configure the timeout values for all non-monitoring connections. - connect_timeout=30, socket_timeout=30, - # Configure different timeout values for the monitoring connections. - **props) + +try: + conn = AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-1.rds.amazonaws.com", + dbname="postgres", + user="john", + password="pwd", + plugins="host_monitoring", + # Configure the timeout values for all non-monitoring connections. + connect_timeout=30, socket_timeout=30, + # Configure different timeout values for the monitoring connections. + **props) +finally: + release_resources() ``` > [!IMPORTANT]\ diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 2e23eeba..53c73c6f 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -31,6 +31,8 @@ from aws_advanced_python_wrapper.host_monitoring_plugin import \ MonitoringThreadContainer from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils @@ -140,6 +142,7 @@ def pytest_runtest_setup(item): CustomEndpointPlugin._monitors.clear() CustomEndpointMonitor._custom_endpoint_info_cache.clear() MonitoringThreadContainer.clean_up() + ThreadPoolContainer.release_resources(wait=True) ConnectionProviderManager.reset_provider() DatabaseDialectManager.reset_custom_dialect() diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 5ddf2a69..2634e456 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -22,8 +22,6 @@ from aws_advanced_python_wrapper.errors import ( FailoverSuccessError, TransactionResolutionUnknownError) -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from .utils.conditions import (disable_on_features, enable_on_deployments, @@ -34,7 +32,7 @@ if TYPE_CHECKING: from .utils.test_instance_info import TestInstanceInfo from .utils.test_driver import TestDriver - +from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.wrapper import AwsWrapperConnection from .utils.driver_helper import DriverHelper @@ -59,7 +57,7 @@ def setup_method(self, request): self.logger.info(f"Starting test: {request.node.name}") yield self.logger.info(f"Ending test: {request.node.name}") - MonitoringThreadContainer.clean_up() + release_resources() gc.collect() @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_basic_connectivity.py b/tests/integration/container/test_basic_connectivity.py index 4de78f57..ddce9741 100644 --- a/tests/integration/container/test_basic_connectivity.py +++ b/tests/integration/container/test_basic_connectivity.py @@ -16,8 +16,7 @@ from typing import TYPE_CHECKING -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer +from aws_advanced_python_wrapper import release_resources if TYPE_CHECKING: from .utils.test_instance_info import TestInstanceInfo @@ -150,4 +149,4 @@ def test_wrapper_connection_reader_cluster_with_efm_enabled(self, test_driver: T conn.close() - MonitoringThreadContainer.clean_up() + release_resources() diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index 781423e2..24a6e485 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -17,15 +17,13 @@ import pytest from sqlalchemy import PoolProxiedConnection -from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager from aws_advanced_python_wrapper.errors import ( AwsWrapperError, FailoverFailedError, FailoverSuccessError, ReadWriteSplittingError, TransactionResolutionUnknownError) from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider -from aws_advanced_python_wrapper.host_monitoring_plugin import \ - MonitoringThreadContainer from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider from aws_advanced_python_wrapper.utils.log import Logger @@ -63,7 +61,7 @@ def setup_method(self, request): yield self.logger.info(f"Ending test: {request.node.name}") - MonitoringThreadContainer.clean_up() + release_resources() gc.collect() # Plugin configurations diff --git a/tests/unit/test_monitor.py b/tests/unit/test_monitor.py index 174eec87..9eca560e 100644 --- a/tests/unit/test_monitor.py +++ b/tests/unit/test_monitor.py @@ -18,6 +18,7 @@ import psycopg import pytest +from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.host_monitoring_plugin import ( Monitor, MonitoringContext, MonitoringThreadContainer) from aws_advanced_python_wrapper.hostinfo import HostInfo @@ -83,7 +84,7 @@ def monitor(mock_plugin_service, host_info, props): def release_container(): yield while MonitoringThreadContainer._instance is not None: - MonitoringThreadContainer.clean_up() + release_resources() @pytest.fixture @@ -230,7 +231,7 @@ def test_run__no_contexts(mocker, monitor): assert container._monitor_map.get(host_alias) is None assert container._tasks_map.get(monitor) is None - MonitoringThreadContainer.clean_up() + release_resources() def test_check_connection_status__valid_then_invalid(mocker, monitor): diff --git a/tests/unit/test_monitor_service.py b/tests/unit/test_monitor_service.py index ed8e743b..6ce28f2a 100644 --- a/tests/unit/test_monitor_service.py +++ b/tests/unit/test_monitor_service.py @@ -16,6 +16,7 @@ import pytest from _weakref import ref +from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_monitoring_plugin import ( MonitoringThreadContainer, MonitorService) @@ -46,15 +47,8 @@ def mock_monitor(mocker): @pytest.fixture -def mock_executor(mocker): - return mocker.MagicMock() - - -@pytest.fixture -def thread_container(mock_executor): - container = MonitoringThreadContainer() - MonitoringThreadContainer._executor = mock_executor - return container +def thread_container(): + return MonitoringThreadContainer() @pytest.fixture @@ -80,7 +74,7 @@ def setup_teardown(mocker, mock_thread_container, mock_plugin_service, mock_moni yield while MonitoringThreadContainer._instance is not None: - MonitoringThreadContainer.clean_up() + release_resources() def test_start_monitoring( @@ -99,16 +93,20 @@ def test_start_monitoring( assert aliases == monitor_service_mocked_container._cached_monitor_aliases -def test_start_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_executor, mock_conn): +def test_start_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_conn, mocker): aliases = frozenset({"instance-1"}) + # Mock ThreadPoolContainer.get_thread_pool + mock_thread_pool = mocker.MagicMock() + mocker.patch('aws_advanced_python_wrapper.host_monitoring_plugin.ThreadPoolContainer.get_thread_pool', return_value=mock_thread_pool) + num_calls = 5 for _ in range(num_calls): monitor_service_with_container.start_monitoring( mock_conn, aliases, HostInfo("instance-1"), Properties(), 5000, 1000, 3) assert num_calls == mock_monitor.start_monitoring.call_count - mock_executor.submit.assert_called_once_with(mock_monitor.run) + mock_thread_pool.submit.assert_called_once_with(mock_monitor.run) assert mock_monitor == monitor_service_with_container._cached_monitor() assert aliases == monitor_service_with_container._cached_monitor_aliases diff --git a/tests/unit/test_monitoring_thread_container.py b/tests/unit/test_monitoring_thread_container.py index cfd9e155..4e3dac25 100644 --- a/tests/unit/test_monitoring_thread_container.py +++ b/tests/unit/test_monitoring_thread_container.py @@ -14,21 +14,15 @@ import pytest +from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_monitoring_plugin import \ MonitoringThreadContainer @pytest.fixture -def container(mock_executor): - container = MonitoringThreadContainer() - MonitoringThreadContainer._executor = mock_executor - return container - - -@pytest.fixture -def mock_executor(mocker): - return mocker.MagicMock() +def container(): + return MonitoringThreadContainer() @pytest.fixture @@ -68,16 +62,20 @@ def mock_monitor_supplier(mocker, mock_monitor1, mock_monitor2): def release_container(): yield while MonitoringThreadContainer._instance is not None: - MonitoringThreadContainer.clean_up() + release_resources() def test_get_or_create_monitor__monitor_created( - container, mock_monitor_supplier, mock_stopped_monitor, mock_monitor1, mock_executor, mock_future): + container, mock_monitor_supplier, mock_stopped_monitor, mock_monitor1, mock_future, mocker): + mock_thread_pool = mocker.MagicMock() + mock_thread_pool.submit.return_value = mock_future + mocker.patch('aws_advanced_python_wrapper.host_monitoring_plugin.ThreadPoolContainer.get_thread_pool', return_value=mock_thread_pool) + result = container.get_or_create_monitor(frozenset({"alias-1", "alias-2"}), mock_monitor_supplier) assert mock_monitor1 == result mock_monitor_supplier.assert_called_once() - mock_executor.submit.assert_called_once_with(mock_monitor1.run) + mock_thread_pool.submit.assert_called_once_with(mock_monitor1.run) assert mock_monitor1 == container._monitor_map.get("alias-1") assert mock_monitor1 == container._monitor_map.get("alias-2") @@ -175,3 +173,4 @@ def test_release_instance(mocker, container, mock_monitor1, mock_future): assert 0 == len(container._tasks_map) mock_future.cancel.assert_called_once() assert MonitoringThreadContainer._instance is None + release_resources() diff --git a/tests/unit/test_multithreaded_monitor_service.py b/tests/unit/test_multithreaded_monitor_service.py index db24d115..26cb9160 100644 --- a/tests/unit/test_multithreaded_monitor_service.py +++ b/tests/unit/test_multithreaded_monitor_service.py @@ -20,6 +20,7 @@ import psycopg import pytest +from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.host_monitoring_plugin import ( MonitoringContext, MonitoringThreadContainer, MonitorService) from aws_advanced_python_wrapper.hostinfo import HostInfo @@ -110,7 +111,7 @@ def verify_concurrency(mock_monitor, mock_executor, mock_future, counter, concur assert concurrent_counter.get() > 0 concurrent_counter.set(0) - MonitoringThreadContainer.clean_up() + release_resources() def test_start_monitoring__connections_to_different_hosts( @@ -137,7 +138,7 @@ def test_start_monitoring__connections_to_different_hosts( expected_create_monitor_calls = [mocker.call(host_info, props, MonitoringThreadContainer())] * num_conns mock_create_monitor.assert_has_calls(expected_create_monitor_calls) finally: - release_resources(services) + release_service_resource(services) def test_start_monitoring__connections_to_same_host( @@ -164,7 +165,7 @@ def test_start_monitoring__connections_to_same_host( expected_create_monitor_calls = [mocker.call(host_info, props, MonitoringThreadContainer())] mock_create_monitor.assert_has_calls(expected_create_monitor_calls) finally: - release_resources(services) + release_service_resource(services) def test_stop_monitoring__connections_to_different_hosts( @@ -186,7 +187,7 @@ def test_stop_monitoring__connections_to_different_hosts( expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) finally: - release_resources(services) + release_service_resource(services) def test_stop_monitoring__connections_to_same_host( @@ -208,7 +209,7 @@ def test_stop_monitoring__connections_to_same_host( expected_stop_monitoring_calls = [mocker.call(context) for context in contexts] mock_monitor.stop_monitoring.assert_has_calls(expected_stop_monitoring_calls, True) finally: - release_resources(services) + release_service_resource(services) def generate_host_aliases(num_aliases: int, generate_unique_aliases: bool) -> List[FrozenSet[str]]: @@ -247,7 +248,7 @@ def _generate_contexts(num_contexts: int, generate_unique_contexts) -> List[Moni return _generate_contexts -def release_resources(services): +def release_service_resource(services): for service in services: service.release_resources() diff --git a/tests/unit/test_thread_pool_container.py b/tests/unit/test_thread_pool_container.py new file mode 100644 index 00000000..5f4d415c --- /dev/null +++ b/tests/unit/test_thread_pool_container.py @@ -0,0 +1,103 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from aws_advanced_python_wrapper.thread_pool_container import \ + ThreadPoolContainer + + +@pytest.fixture(autouse=True) +def cleanup_pools(): + """Clean up all pools after each test""" + yield + ThreadPoolContainer.release_resources() + + +def test_get_thread_pool_creates_new_pool(): + pool = ThreadPoolContainer.get_thread_pool("test_pool") + assert isinstance(pool, ThreadPoolExecutor) + assert ThreadPoolContainer.has_pool("test_pool") + + +def test_get_thread_pool_returns_existing_pool(): + pool1 = ThreadPoolContainer.get_thread_pool("test_pool") + pool2 = ThreadPoolContainer.get_thread_pool("test_pool") + assert pool1 is pool2 + + +def test_get_thread_pool_with_max_workers(): + pool = ThreadPoolContainer.get_thread_pool("test_pool", max_workers=5) + assert pool._max_workers == 5 + + +def test_has_pool(): + assert not ThreadPoolContainer.has_pool("nonexistent") + ThreadPoolContainer.get_thread_pool("test_pool") + assert ThreadPoolContainer.has_pool("test_pool") + + +def test_get_pool_names(): + assert ThreadPoolContainer.get_pool_names() == [] + ThreadPoolContainer.get_thread_pool("pool1") + ThreadPoolContainer.get_thread_pool("pool2") + names = ThreadPoolContainer.get_pool_names() + assert "pool1" in names + assert "pool2" in names + assert len(names) == 2 + + +def test_get_pool_count(): + assert ThreadPoolContainer.get_pool_count() == 0 + ThreadPoolContainer.get_thread_pool("pool1") + assert ThreadPoolContainer.get_pool_count() == 1 + ThreadPoolContainer.get_thread_pool("pool2") + assert ThreadPoolContainer.get_pool_count() == 2 + + +def test_release_pool(): + ThreadPoolContainer.get_thread_pool("test_pool") + assert ThreadPoolContainer.has_pool("test_pool") + + result = ThreadPoolContainer.release_pool("test_pool") + assert result is True + assert not ThreadPoolContainer.has_pool("test_pool") + + +def test_release_nonexistent_pool(): + result = ThreadPoolContainer.release_pool("nonexistent") + assert result is False + + +def test_release_resources(): + ThreadPoolContainer.get_thread_pool("pool1") + ThreadPoolContainer.get_thread_pool("pool2") + assert ThreadPoolContainer.get_pool_count() == 2 + + ThreadPoolContainer.release_resources() + assert ThreadPoolContainer.get_pool_count() == 0 + + +def test_set_default_max_workers(): + ThreadPoolContainer.set_default_max_workers(10) + pool = ThreadPoolContainer.get_thread_pool("test_pool") + assert pool._max_workers == 10 + + +def test_thread_name_prefix(): + pool = ThreadPoolContainer.get_thread_pool("custom_name") + # Check that the thread name prefix is set correctly + assert pool._thread_name_prefix == "custom_name"