Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aws_advanced_python_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from logging import DEBUG, getLogger

from .cleanup import release_resources
from .utils.utils import LogUtils
from .wrapper import AwsWrapperConnection

Expand All @@ -23,6 +24,17 @@
threadsafety = 2
paramstyle = "pyformat"

# Public API
__all__ = [
'connect',
'AwsWrapperConnection',
'release_resources',
'set_logger',
'apilevel',
'threadsafety',
'paramstyle'
]


def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None):
LogUtils.setup_logger(getLogger(name), level, format_string)
24 changes: 24 additions & 0 deletions aws_advanced_python_wrapper/cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from aws_advanced_python_wrapper.host_monitoring_plugin import \
MonitoringThreadContainer
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer


def release_resources() -> None:
"""Release all global resources used by the wrapper."""
MonitoringThreadContainer.clean_up()
ThreadPoolContainer.release_resources()
9 changes: 6 additions & 3 deletions aws_advanced_python_wrapper/custom_endpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def _run(self):
len(endpoints),
endpoint_hostnames)

sleep(self._refresh_rate_ns / 1_000_000_000)
if self._stop_event.wait(self._refresh_rate_ns / 1_000_000_000):
break
continue

endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0])
Expand All @@ -178,7 +179,8 @@ def _run(self):
if cached_info is not None and cached_info == endpoint_info:
elapsed_time = perf_counter_ns() - start_ns
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
sleep(sleep_duration / 1_000_000_000)
if self._stop_event.wait(sleep_duration / 1_000_000_000):
break
continue

logger.debug(
Expand All @@ -196,7 +198,8 @@ def _run(self):

elapsed_time = perf_counter_ns() - start_ns
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
sleep(sleep_duration / 1_000_000_000)
if self._stop_event.wait(sleep_duration / 1_000_000_000):
break
continue
except InterruptedError as e:
raise e
Expand Down
8 changes: 5 additions & 3 deletions aws_advanced_python_wrapper/database_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .exception_handling import ExceptionHandler

from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from concurrent.futures import TimeoutError
from contextlib import closing
from enum import Enum, auto

Expand All @@ -37,6 +37,8 @@
from aws_advanced_python_wrapper.host_list_provider import (
ConnectionStringHostListProvider, RdsHostListProvider)
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.decorators import \
preserve_transaction_status_with_timeout
from aws_advanced_python_wrapper.utils.log import Logger
Expand Down Expand Up @@ -638,7 +640,7 @@ class DatabaseDialectManager(DatabaseDialectProvider):
_ENDPOINT_CACHE_EXPIRATION_NS = 30 * 60_000_000_000 # 30 minutes
_known_endpoint_dialects: CacheMap[str, DialectCode] = CacheMap()
_custom_dialect: Optional[DatabaseDialect] = None
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DatabaseDialectManagerExecutor")
_executor_name: ClassVar[str] = "DatabaseDialectManagerExecutor"
_known_dialects_by_code: Dict[DialectCode, DatabaseDialect] = {
DialectCode.MYSQL: MysqlDatabaseDialect(),
DialectCode.RDS_MYSQL: RdsMysqlDialect(),
Expand Down Expand Up @@ -776,7 +778,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
DatabaseDialectManager._executor,
ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name),
timeout_sec,
driver_dialect,
conn)(dialect_candidate.is_dialect)
Expand Down
8 changes: 5 additions & 3 deletions aws_advanced_python_wrapper/driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from aws_advanced_python_wrapper.pep249 import Connection, Cursor

from abc import ABC
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from concurrent.futures import TimeoutError

from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
from aws_advanced_python_wrapper.errors import (QueryTimeoutError,
UnsupportedOperationError)
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.decorators import timeout
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
Expand All @@ -40,7 +42,7 @@ class DriverDialect(ABC):
_QUERY = "SELECT 1"
_ALL_METHODS = "*"

_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="DriverDialectExecutor")
_executor_name: ClassVar[str] = "DriverDialectExecutor"
_dialect_code: str = DriverDialectCodes.GENERIC
_network_bound_methods: Set[str] = {_ALL_METHODS}
_read_only: bool = False
Expand Down Expand Up @@ -136,7 +138,7 @@ def execute(

if exec_timeout > 0:
try:
execute_with_timeout = timeout(DriverDialect._executor, exec_timeout)(exec_func)
execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func)
return execute_with_timeout()
except TimeoutError as e:
raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e
Expand Down
20 changes: 14 additions & 6 deletions aws_advanced_python_wrapper/host_list_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import uuid
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from concurrent.futures import TimeoutError
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -39,6 +39,8 @@
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249 import (Connection, Cursor,
ProgrammingError)
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
Expand Down Expand Up @@ -148,8 +150,6 @@ class RdsHostListProvider(DynamicHostListProvider, HostListProvider):
# cluster IDs so that connections to the same clusters can share topology info.
_cluster_ids_to_update: CacheMap[str, str] = CacheMap()

_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="RdsHostListProviderExecutor")

def __init__(self, host_list_provider_service: HostListProviderService, props: Properties, topology_utils: TopologyUtils):
self._host_list_provider_service: HostListProviderService = host_list_provider_service
self._props: Properties = props
Expand Down Expand Up @@ -425,6 +425,8 @@ class TopologyUtils(ABC):
to various database engine deployments (e.g. Aurora, Multi-AZ, etc.).
"""

_executor_name: ClassVar[str] = "TopologyUtils"

def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Properties):
self._dialect: db_dialect.TopologyAwareDatabaseDialect = dialect
self._rds_utils = RdsUtils()
Expand Down Expand Up @@ -487,7 +489,7 @@ def query_for_topology(
an empty tuple will be returned.
"""
query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout(
RdsHostListProvider._executor, self._max_timeout, driver_dialect, conn)(self._query_for_topology)
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, conn)(self._query_for_topology)
return query_for_topology_func_with_timeout(conn)

@abstractmethod
Expand Down Expand Up @@ -549,7 +551,7 @@ def create_host(
def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole:
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_role)
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_role)
result = cursor_execute_func_with_timeout(connection)
if result is not None:
is_reader = result[0]
Expand All @@ -572,7 +574,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) ->
"""

cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
RdsHostListProvider._executor, self._max_timeout, driver_dialect, connection)(self._get_host_id)
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_id)
result = cursor_execute_func_with_timeout(connection)
if result:
host_id: str = result[0]
Expand All @@ -586,6 +588,9 @@ def _get_host_id(self, conn: Connection):


class AuroraTopologyUtils(TopologyUtils):

_executor_name: ClassVar[str] = "AuroraTopologyUtils"

def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
"""
Query the database for topology information.
Expand Down Expand Up @@ -636,6 +641,9 @@ def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]:


class MultiAzTopologyUtils(TopologyUtils):

_executor_name: ClassVar[str] = "MultiAzTopologyUtils"

def __init__(
self,
dialect: db_dialect.TopologyAwareDatabaseDialect,
Expand Down
21 changes: 10 additions & 11 deletions aws_advanced_python_wrapper/host_monitoring_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService

from concurrent.futures import (Executor, Future, ThreadPoolExecutor,
TimeoutError)
from concurrent.futures import Future, TimeoutError
from dataclasses import dataclass
from queue import Queue
from threading import Event, Lock, RLock
from time import perf_counter_ns, sleep
from time import perf_counter_ns
from typing import Any, Callable, ClassVar, Dict, FrozenSet, Optional, Set

from _weakref import ReferenceType, ref
Expand All @@ -36,6 +35,8 @@
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin,
PluginFactory)
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
Expand Down Expand Up @@ -548,9 +549,8 @@ def _execute_conn_check(self, conn: Connection, timeout_sec: float):
driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec)
cursor.fetchone()

# Used to help with testing
def sleep(self, duration: int):
sleep(duration)
self._is_stopped.wait(duration)


class MonitoringThreadContainer:
Expand All @@ -565,7 +565,7 @@ class MonitoringThreadContainer:

_monitor_map: ConcurrentDict[str, Monitor] = ConcurrentDict()
_tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict()
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor")
_executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor"

# This logic ensures that this class is a Singleton
def __new__(cls, *args, **kwargs):
Expand Down Expand Up @@ -593,7 +593,9 @@ def _get_or_create_monitor(_) -> Monitor:
if supplied_monitor is None:
raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone"))
self._tasks_map.compute_if_absent(
supplied_monitor, lambda _: MonitoringThreadContainer._executor.submit(supplied_monitor.run))
supplied_monitor,
lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name)
.submit(supplied_monitor.run))
return supplied_monitor

if monitor is None:
Expand Down Expand Up @@ -648,12 +650,9 @@ def _release_resources(self):
for monitor, _ in self._tasks_map.items():
monitor.stop()

ThreadPoolContainer.release_pool(MonitoringThreadContainer._executor_name, wait=False)
self._tasks_map.clear()

# Reset the executor.
self._executor.shutdown(wait=False)
MonitoringThreadContainer._executor = ThreadPoolExecutor(thread_name_prefix="MonitoringThreadContainerExecutor")


class MonitorService:
def __init__(self, plugin_service: PluginService):
Expand Down
9 changes: 6 additions & 3 deletions aws_advanced_python_wrapper/mysql_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection

from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from concurrent.futures import TimeoutError
from inspect import signature

from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.decorators import timeout
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
Expand Down Expand Up @@ -55,7 +57,7 @@ class MySQLDriverDialect(DriverDialect):
AUTH_METHOD = "mysql_clear_password"
IS_CLOSED_TIMEOUT_SEC = 3

_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor")
_executor_name: ClassVar[str] = "MySQLDriverDialectExecutor"

_dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON
_network_bound_methods: Set[str] = {
Expand Down Expand Up @@ -94,7 +96,8 @@ def is_closed(self, conn: Connection) -> bool:
if self.can_execute_query(conn):
socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) # type: ignore
is_connected_with_timeout = timeout(
ThreadPoolContainer.get_thread_pool(MySQLDriverDialect._executor_name), timeout_sec)(conn.is_connected) # type: ignore

try:
return not is_connected_with_timeout()
Expand Down
8 changes: 5 additions & 3 deletions aws_advanced_python_wrapper/plugin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory

from abc import abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from concurrent.futures import TimeoutError
from contextlib import closing
from typing import (Any, Callable, Dict, FrozenSet, Optional, Protocol, Set,
Tuple)
Expand Down Expand Up @@ -85,6 +85,8 @@
from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import \
SimpleReadWriteSplittingPluginFactory
from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.decorators import \
preserve_transaction_status_with_timeout
Expand Down Expand Up @@ -314,7 +316,7 @@ class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResour
_host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap()
_status_cache: ClassVar[CacheMap[str, Any]] = CacheMap()

_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="PluginServiceImplExecutor")
_executor_name: ClassVar[str] = "PluginServiceImplExecutor"

def __init__(
self,
Expand Down Expand Up @@ -611,7 +613,7 @@ def fill_aliases(self, connection: Optional[Connection] = None, host_info: Optio
try:
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
PluginServiceImpl._executor, timeout_sec, driver_dialect, connection)(self._fill_aliases)
ThreadPoolContainer.get_thread_pool(PluginServiceImpl._executor_name), timeout_sec, driver_dialect, connection)(self._fill_aliases)
cursor_execute_func_with_timeout(connection, host_info)
except TimeoutError as e:
raise QueryTimeoutError(Messages.get("PluginServiceImpl.FillAliasesTimeout")) from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ HostResponseTimeMonitor.OpeningConnection=[HostResponseTimeMonitor] Opening a Re
HostResponseTimeMonitor.ResponseTime=[HostResponseTimeMonitor] Response time for '{}': {} ms
HostResponseTimeMonitor.Stopped=[HostResponseTimeMonitor] Stopped Response time thread for host '{}'.

ThreadPoolContainer.ErrorShuttingDownPool=[ThreadPoolContainer] Error shutting down pool '{}': '{}'.

OpenedConnectionTracker.OpenedConnectionsTracked=[OpenedConnectionTracker] Opened Connections Tracked: {}
OpenedConnectionTracker.InvalidatingConnections=[OpenedConnectionTracker] Invalidating opened connections to host: {}
OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet=[OpenedConnectionTracker] The driver is unable to track this opened connection because the instance endpoint is unknown.
Expand Down
Loading
Loading