Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from aws_advanced_python_wrapper.errors import FailoverError
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
Expand Down Expand Up @@ -147,13 +148,12 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):


class AuroraConnectionTrackerPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {"*"}
_current_writer: Optional[HostInfo] = None
_need_update_current_writer: bool = False

@property
def subscribed_methods(self) -> Set[str]:
return self._SUBSCRIBED_METHODS
return self._subscribed_methods

def __init__(self,
plugin_service: PluginService,
Expand All @@ -164,6 +164,11 @@ def __init__(self,
self._props = props
self._rds_utils = rds_utils
self._tracker = tracker
self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name,
DbApiMethod.CONNECTION_CLOSE.method_name,
DbApiMethod.CONNECT.method_name,
DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name}
self._subscribed_methods.update(self._plugin_service.network_bound_methods)

def connect(
self,
Expand Down Expand Up @@ -210,5 +215,6 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:


class AuroraConnectionTrackerPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return AuroraConnectionTrackerPlugin(plugin_service, props)
Original file line number Diff line number Diff line change
Expand Up @@ -233,5 +233,6 @@ def _has_no_readers(self) -> bool:


class AuroraInitialConnectionStrategyPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return AuroraInitialConnectionStrategyPlugin(plugin_service)
17 changes: 11 additions & 6 deletions aws_advanced_python_wrapper/aws_secrets_manager_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from aws_advanced_python_wrapper.plugin_service import PluginService

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
Expand All @@ -45,7 +46,7 @@


class AwsSecretsManagerPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}
_SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name}

_SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$"
_ONE_YEAR_IN_SECONDS = 60 * 60 * 24 * 365
Expand Down Expand Up @@ -136,7 +137,8 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)
"""
telemetry_factory = self._plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch credentials", TelemetryTraceLevel.NESTED)
self._fetch_credentials_counter.inc()
if self._fetch_credentials_counter is not None:
self._fetch_credentials_counter.inc()

try:
fetched: bool = False
Expand Down Expand Up @@ -167,11 +169,13 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)

return fetched
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
if context is not None:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()
if context is not None:
context.close_context()

def _fetch_latest_credentials(self):
"""
Expand Down Expand Up @@ -228,5 +232,6 @@ def _get_rds_region(self, secret_id: str, props: Properties) -> str:


class AwsSecretsManagerPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return AwsSecretsManagerPlugin(plugin_service, props)
17 changes: 11 additions & 6 deletions aws_advanced_python_wrapper/blue_green_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.atomic import AtomicInt
from aws_advanced_python_wrapper.utils.concurrent import (ConcurrentDict,
Expand Down Expand Up @@ -471,7 +472,8 @@ def apply(
"SuspendConnectRouting.SwitchoverCompleteContinueWithConnect",
(time.time() - start_time_sec) * 1000))
finally:
telemetry_context.close_context()
if telemetry_context is not None:
telemetry_context.close_context()

# return None so that the next routing can attempt a connection
return None
Expand Down Expand Up @@ -540,7 +542,8 @@ def apply(
host_info.host,
(time.time() - start_time_sec) * 1000))
finally:
telemetry_context.close_context()
if telemetry_context is not None:
telemetry_context.close_context()

# return None so that the next routing can attempt a connection
return None
Expand Down Expand Up @@ -615,15 +618,16 @@ def apply(
method_name,
(time.time() - start_time_sec) * 1000))
finally:
telemetry_context.close_context()
if telemetry_context is not None:
telemetry_context.close_context()

# return empty so that the next routing can attempt a connection
return ValueContainer.empty()


class BlueGreenPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {"connect"}
_CLOSE_METHODS: ClassVar[Set[str]] = {"Connection.close", "Cursor.close"}
_SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.CONNECT.method_name}
_CLOSE_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECTION_CLOSE.method_name, DbApiMethod.CURSOR_CLOSE.method_name}
_status_providers: ClassVar[ConcurrentDict[str, BlueGreenStatusProvider]] = ConcurrentDict()

def __init__(self, plugin_service: PluginService, props: Properties):
Expand Down Expand Up @@ -779,7 +783,8 @@ def get_hold_time_ns(self) -> int:


class BlueGreenPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return BlueGreenPlugin(plugin_service, props)


Expand Down
6 changes: 4 additions & 2 deletions aws_advanced_python_wrapper/connect_time_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from time import perf_counter_ns
from typing import Callable, Set

from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.messages import Messages

Expand All @@ -42,7 +43,7 @@ def reset_connect_time():

@property
def subscribed_methods(self) -> Set[str]:
return {"connect", "force_connect"}
return {DbApiMethod.CONNECT.method_name, DbApiMethod.FORCE_CONNECT.method_name}

def connect(
self,
Expand All @@ -65,5 +66,6 @@ def connect(


class ConnectTimePluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return ConnectTimePlugin()
15 changes: 10 additions & 5 deletions aws_advanced_python_wrapper/custom_endpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from boto3 import Session

from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
Expand Down Expand Up @@ -194,7 +195,9 @@ def _run(self):
self._custom_endpoint_host_info.host,
endpoint_info,
CustomEndpointMonitor._CUSTOM_ENDPOINT_INFO_EXPIRATION_NS)
self._info_changed_counter.inc()

if self._info_changed_counter is not None:
self._info_changed_counter.inc()

elapsed_time = perf_counter_ns() - start_ns
sleep_duration = max(0, self._refresh_rate_ns - elapsed_time)
Expand Down Expand Up @@ -228,7 +231,7 @@ class CustomEndpointPlugin(Plugin):
A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding
or removing an instance in the custom endpoint.
"""
_SUBSCRIBED_METHODS: ClassVar[Set[str]] = {"connect"}
_SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name}
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10 # 1 minute
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
Expand All @@ -250,7 +253,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
self._custom_endpoint_host_info: Optional[HostInfo] = None
self._custom_endpoint_id: Optional[str] = None
telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
self._wait_for_info_counter: TelemetryCounter = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")
self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")

CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)

Expand Down Expand Up @@ -312,7 +315,8 @@ def _wait_for_info(self, monitor: CustomEndpointMonitor):
if has_info:
return

self._wait_for_info_counter.inc()
if self._wait_for_info_counter is not None:
self._wait_for_info_counter.inc()
host_info = cast('HostInfo', self._custom_endpoint_host_info)
hostname = host_info.host
logger.debug("CustomEndpointPlugin.WaitingForCustomEndpointInfo", hostname, self._wait_for_info_timeout_ms)
Expand Down Expand Up @@ -343,5 +347,6 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:


class CustomEndpointPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return CustomEndpointPlugin(plugin_service, props)
12 changes: 7 additions & 5 deletions aws_advanced_python_wrapper/default_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel


class DefaultPlugin(Plugin):
_SUBSCRIBED_METHODS: Set[str] = {"*"}
_CLOSE_METHOD = "Connection.close"
_SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.ALL.method_name}

def __init__(self, plugin_service: PluginService, connection_provider_manager: ConnectionProviderManager):
self._plugin_service: PluginService = plugin_service
Expand Down Expand Up @@ -74,7 +74,8 @@ def _connect(
database_dialect = self._plugin_service.database_dialect
conn = conn_provider.connect(target_func, driver_dialect, database_dialect, host_info, props)
finally:
context.close_context()
if context is not None:
context.close_context()

self._plugin_service.set_availability(host_info.all_aliases, HostAvailability.AVAILABLE)
self._plugin_service.update_driver_dialect(conn_provider)
Expand Down Expand Up @@ -106,9 +107,10 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg
try:
result = self._plugin_service.driver_dialect.execute(method_name, execute_func, *args, **kwargs)
finally:
context.close_context()
if context is not None:
context.close_context()

if method_name != DefaultPlugin._CLOSE_METHOD and self._plugin_service.current_connection is not None:
if method_name != DbApiMethod.CONNECTION_CLOSE.method_name and self._plugin_service.current_connection is not None:
self._plugin_service.update_in_transaction()

return result
Expand Down
10 changes: 5 additions & 5 deletions aws_advanced_python_wrapper/developer_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from aws_advanced_python_wrapper.utils.properties import Properties
from aws_advanced_python_wrapper.hostinfo import HostInfo

from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
Expand Down Expand Up @@ -76,8 +77,7 @@ def set_method_callback(method_callback: Optional[ExceptionSimulatorMethodCallba


class DeveloperPlugin(Plugin):
_ALL_METHODS: str = "*"
_SUBSCRIBED_METHODS: Set[str] = {_ALL_METHODS}
_SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.ALL.method_name}

@property
def subscribed_methods(self) -> Set[str]:
Expand All @@ -90,7 +90,7 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:
def raise_method_exception_if_set(
self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> None:
if ExceptionSimulatorManager.next_method_exception is not None:
if DeveloperPlugin._ALL_METHODS == ExceptionSimulatorManager.next_method_name or \
if DbApiMethod.ALL.method_name == ExceptionSimulatorManager.next_method_name or \
method_name == ExceptionSimulatorManager.next_method_name:
self.raise_exception_on_method(ExceptionSimulatorManager.next_method_exception, method_name)
elif ExceptionSimulatorManager.method_callback is not None:
Expand Down Expand Up @@ -158,6 +158,6 @@ def raise_exception_on_connect(self, error: Optional[Exception]) -> None:


class DeveloperPluginFactory(PluginFactory):

def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return DeveloperPlugin()
8 changes: 4 additions & 4 deletions aws_advanced_python_wrapper/driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
from aws_advanced_python_wrapper.errors import (QueryTimeoutError,
UnsupportedOperationError)
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.decorators import timeout
Expand All @@ -40,11 +41,10 @@ class DriverDialect(ABC):
Driver dialects help the driver-agnostic AWS Python Driver interface with the driver-specific functionality of the underlying Python Driver.
"""
_QUERY = "SELECT 1"
_ALL_METHODS = "*"

_executor_name: ClassVar[str] = "DriverDialectExecutor"
_dialect_code: str = DriverDialectCodes.GENERIC
_network_bound_methods: Set[str] = {_ALL_METHODS}
_network_bound_methods: Set[str] = {DbApiMethod.ALL.method_name}
_read_only: bool = False
_autocommit: bool = False
_driver_name: str = "Generic"
Expand Down Expand Up @@ -130,7 +130,7 @@ def execute(
*args: Any,
exec_timeout: Optional[float] = None,
**kwargs: Any) -> Cursor:
if DriverDialect._ALL_METHODS not in self.network_bound_methods and method_name not in self.network_bound_methods:
if DbApiMethod.ALL.method_name not in self.network_bound_methods and method_name not in self.network_bound_methods:
return exec_func()

if exec_timeout is None:
Expand Down Expand Up @@ -161,7 +161,7 @@ def ping(self, conn: Connection) -> bool:
try:
with conn.cursor() as cursor:
query = DriverDialect._QUERY
self.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=10)
self.execute(DbApiMethod.CURSOR_EXECUTE.method_name, lambda: cursor.execute(query), query, exec_timeout=10)
cursor.fetchone()
return True
except Exception:
Expand Down
3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/execute_time_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,6 @@ def execute(self, target: type, method_name: str, execute_func: Callable, *args:


class ExecuteTimePluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
@staticmethod
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:
return ExecuteTimePlugin()
Loading